大语言模型 05 运行、微调的显存计算详解与优化 全量微调、LoRA 优化策略

写在前面

随着Transformer架构的大语言模型(LLM)不断发展,其参数规模也在迅速增加。无论是进行模型推理还是微调训练,GPU显存消耗都是开发和应用LLM时的重要考量。本文将详细探讨大模型运行(推理)与微调时的显存计算方式。

随着Transformer架构的大语言模型(LLM)不断发展,模型参数规模急剧膨胀,显存消耗成为推理和微调过程中的核心瓶颈。本文系统梳理了大模型在推理与微调阶段的显存计算方法,详细分析了模型参数、优化器状态、中间激活值、批处理大小等因素对显存需求的影响。通过具体案例(如1.5B、7B模型)量化分析显存占用,并总结了常用的显存优化策略,包括混合精度训练、梯度检查点、模型并行、量化剪枝等。此外,还特别介绍了LoRA微调技术在降低显存压力方面的优势。通过本文,读者将能清晰掌握大模型显存消耗的计算逻辑,从而更科学地进行资源规划与优化实践。

在这里插入图片描述

显存计算的重要性

GPU显存限制直接决定了你能运行的模型规模、批处理大小(batch size)和序列长度(sequence length)。因此,掌握显存的计算方法对于优化和合理使用GPU资源尤为重要。

全量微调

我们在微调的过程中,显存主要分布状况梯度计算:
● 反向传播需要存储梯度,相比推理阶段显存占用会大幅增加。
● 优化器状态:例如 Adam 之类的优化器会存储额外的参数状态(如一阶、二阶动量),通常会使显存占用增加 2-3 倍。
● 计算图:PyTorch 计算图在反向传播时需要保留中间激活值,会比前向推理额外消耗显存。
● Batch Size:数据集的 batch size 会影响显存消耗,较大的 batch 会显著增加显存占用。
● 混合精度训练(FP16 vs FP32):如果是 FP32 训练,会比 FP16 训练占用更大显存。
● 梯度累积:如果使用梯度累积(gradient accumulation),每一步的显存占用会相对降低,但不会减少总需求。

请添加图片描述

模型占用

首先根据模型参数大小计算出模型的大小:
比如7B模型:
● 如果是FP32则:7B4 = 28GB
● 如果是FP16则:7B
2 = 14GB

优化器占用

Adam 优化器一般会占用 2~3倍参数大小的显存:
● FP32 大约 28GB*3,大概是 56GB~84GB
● FP16 大约 28GB ~ 42GB

如果是 SGD 之类的优化器,占用会小很多(1 倍参数量)。

中间激活值

这个很难精确计算,但一般会比参数量多 1.5~3 倍。
例如,7B 模型 FP16 训练时,激活值大概 21GB - 42GB。

批处理

每个 batch 会加载部分数据到显存,每个 token 可能会占用 2-4B(FP16 vs FP32)。
如果 batch_size=8,每个序列 2048 个 token,假设 FP16,则是 8 * 2048 * 2B = 32MB
但是 Transformer 计算中间层需要额外的显存,可能会放大到 3~4 倍,则 96MB~128MB

计算实例1.5B

模型本身

假设是一个未量化的 1.5B 模型,假设是 FP16(每个参数2B):
1.5B * 2B = 3GB,模型本身大约是3GB。

优化器

用 AdamW 的话,存储:
● 权重参数(1x)
● 一阶动量(1x)
● 二阶动量(1x)

合计 3 倍参数量,AdamW 需要 9GB。

提取存储

梯度和模型参数大小相同,FP16的话:1.5B * 2B = 3GB

激活中间值

Transformer 涉及到多个层的值,通常是模型参数的 1.5~3倍
3GB * 2 = 6GB

梯度存储

gradient_accumulation_steps,梯度积累的话,我设置8,需要消耗 3GB * 8 = 24GB

总占用

总共显存占用:54GB,PyTorch、CUDA等缓存还需要10%~20%,最终大约在 55GB ~ 65GB。

计算实例7B

模型本身

7B参数 FP16 训练(每个参数2B)
7B * 2 = 14GB

优化器

● 权重参数(1x)
● 一阶动量(1x)
● 二阶动量(1x)

合计是3倍的参数量:14GB * 3 = 42GB

梯度占用

7B * 2 = 14GB

激活值

Transformer 涉及到多个层,激活值通常是模型参数的 1.5 ~ 3 倍
假设是2倍的话,14GB * 2 = 28GB

梯度累积

gradient_accumulation_steps 假设是8,则需要存储8个梯度,14*8 = 112 GB

额外占用

PyTorch、CUDA等可能增加 10% ~ 20% 的显存

总计占用

大概在 220GB ~ 230GB 之间,峰值可能更高(240GB 以内)。

优化显存使用的策略

为降低显存占用,常用以下策略:

  • 使用混合精度训练(如FP16)Mixed Precision:通过使用FP16或BF16等低精度数据类型,显著减少模型参数和梯度存储的显存需求,同时提高训练速度。
  • 梯度检查点(Gradient Checkpointing)以减少激活占用:梯度检查点技术通过重新计算部分前向传播结果,显著减少训练过程中需要存储的激活内存,从而降低整体显存消耗。
  • 模型并行、流水线并行、张量并行:通过将模型的不同部分分配到多个GPU设备上,分担单个GPU的显存压力。
  • 量化和剪枝模型:通过减少参数精度或去除冗余参数,减少模型参数总量,有效降低模型的显存需求和计算成本。
  • 流水线并行(Pipeline Parallelism):模型各层或子模块在不同GPU上流水线执行,有效提高显存和计算资源的利用率。

LoRA微调

LoRA 极大减少显存需求,适合在消费级 GPU(如 24GB 4090)上微调大模型,而全参数微调需要多个高端 GPU(如 4×A100 80GB)。

参数计算

LoRA 只修改部分的 Transformer 层(通常是 Wq 和 Wv),所以显存占比会比较低。
每层参数量通常缩小到 0.1% ~ 1%,设 rank = 16,则大概是 70M ~ 700M 参数,使用 FP16 存储的话,大约140MB ~ 1.4GB 的显存。
在这里插入图片描述

推理要求

● FP32(单精度浮点数):4字节(32位)
● FP16(半精度浮点数):2字节(16位)
● BF16(bfloat16):2字节(16位)
● INT8(8-bit整数):1字节(8位)
● INT4(4-bit整数):0.5字节(4位)

假设是 7B模型,7B * 0.5 / 10的9次 = 3.5GB
缓存、PyTorch、CUDA等缓存大约需要 1~2GB
显存大约6GB左右

假设是 14B模型,14B * 0.5 / 10的9次 = 7GB
缓存、PyTorch、CUDA等缓存大约需要 1~2GB
显存大约10GB左右

假设是32B模型,32B * 0.5 / 10的9次 = 16GB
缓存、PyTorch、CUDA等缓存大约需要 1~2GB
显存大约18GB左右

假设是70B模型,70B * 0.5 / 10的9次 = 35GB
缓存、PyTorch、CUDA等缓存大约需要 1~2GB (双卡可能要 2~4GB)
每张卡大约 35 GB / 2 + 2 = 20GB

暂时小节

正确评估显存需求对合理分配计算资源和优化模型运行性能至关重要。理解以上显存计算的基本公式,有助于高效地利用现有硬件资源,推动大模型的应用和开发。

希望本文能帮助读者更深入地理解大模型在运行与微调阶段显存消耗的具体计算方法,进而优化自己的训练与推理任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/903798.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对Electron打包的exe文件进行反解析

一、了解 Electron 打包的 exe,本质上就是打包了网页 (HTMLCSSJS),核心文件是 app.asar。超级容易还原,还原率接近 100% 为什么 Electron 特别容易? 因为 Electron 根本没有真正编译成机器码,它只是把网页资源&…

【Vue2】1-创建一个Vue实例

Vue2官方文档 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </head&g…

【C语言练习】015. 声明和初始化指针

015. 声明和初始化指针 015. 声明和初始化指针1. 声明指针示例1:声明一个指向整数的指针2. 初始化指针示例2:将指针初始化为`NULL`示例3:将指针初始化为某个变量的地址示例4:将指针初始化为动态分配的内存地址3. 使用指针访问和修改变量的值示例5:使用指针访问和修改变量的…

好未来golang后端开发

OSI网络模型 TCP和UDP对比 HTTP和HTTPS对比 B树 HTTP常见状态码 线程和进程的区别 goroutine的调度模型GMP 常见的排序了解哪些 快速排序 func quickSort(data []int) {if len(data) < 1 {return}base : data[0]l, r : 0, len(data)-1for i : 1; i < r; {if data[i] &g…

(持续更新)Ubuntu搭建LNMP(Linux + Nginx + MySQL + PHP)环境

LNMP&#xff08;Linux Nginx MySQL PHP&#xff09;环境是在Linux操作系统上构建的一个高性能Web服务器环境。M也可以指代其他数据库&#xff0c;P也可以指代Python 1. 准备Linux系统 确保你已经在一台服务器或虚拟机上安装了Linux操作系统。推荐使用Ubuntu、CentOS或Debi…

服务器频繁重启日志分析与诊断

从你提供的日志来看&#xff0c;系统确实经历了多次重启。这个日志行显示的是&#xff1a; reboot system boot 6.8.0-58-generic Tue Apr 29 17:54 - 14:26 (20:31)这表示系统在4月29日17:54启动&#xff0c;运行了约20小时31分钟后&#xff0c;于次日14:26结束&#xff08;可…

如何提升个人的稳定性?

提升自我的稳定性是一个系统性工程&#xff0c;需要从内在认知、情绪管理、行为习惯到外在环境等多个维度进行优化。 以下是一些具体建议&#xff0c;帮助你逐步增强内心的稳定感&#xff1a; 一、内在认知调整 1. 建立清晰的自我认知 通过反思&#xff08;如写日记、冥想…

数值求解Eikonal方程的方法及开源实现

Eikonal方程是一类非线性偏微分方程&#xff0c;形式为 ( |\nabla u(x)| f(x) )&#xff0c;常见于波传播、几何光学、最短路径等问题。以下是数值求解Eikonal方程的方法及开源实现参考&#xff1a; 一、数值求解方法 有限差分法&#xff08;FDM&#xff09; 快速行进法&#…

基于Redis实现-用户签到

基于Redis实现-用户签到 这个功能将使用到Redis中的BitMap来实现。 我们按照月来统计用户签到信息&#xff0c;签到记录为1&#xff0c;未签到则记录为0 把每一个bit位对应当月的每一天&#xff0c;形成了映射关系。用0和1标示业务状态&#xff0c;这种思路称为位图(BitMap)。…

如何用GPU Instancing来优化树木草石重复模型

1&#xff09;如何用GPU Instancing来优化树木草石重复模型 2&#xff09;Unity ASTC压缩后的纹理在部分安卓机型上不显示 3&#xff09;现在大部分项目的竖版UI设计分辨率是多少 4&#xff09;Android上拖拽物体不实时跟随手指的问题 这是第430篇UWA技术知识分享的推送&#x…

Java面试高频问题(31-33)

三十一、服务网格&#xff1a;东西向流量治理与故障注入 服务网格架构分层 mermaid graph BT subgraph Control Plane APilot --> BEnvoy Sidecar CMixer --> B DCitadel --> B end subgraph Data Plane B --> E服务A B --> F服务B B --> G服务C end 核心能…

初学python的我开始Leetcode题8-3

提示&#xff1a;100道LeetCode热题-8-3主要是二叉树相关&#xff0c;包括三题&#xff1a;将有序数组转换为二叉搜索树、验证二叉搜索树、二叉搜索树中第K小的元素。由于初学&#xff0c;所以我的代码部分仅供参考。 目录 前言 题目1&#xff1a;将有序数组转换为二叉搜索树…

1996-2022年全国31省ZF干预度数据/财政干预度数据(含原始数据+计算过程+结果)

1996-2022年全国31省ZF干预度数据/财政干预度数据&#xff08;含原始数据计算过程结果&#xff09; 1、时间&#xff1a;1996-2022年 2、来源&#xff1a;国家统计局和各省年鉴 3、指标&#xff1a;地方财政一般预算支出、地区生产总值&#xff08;GDP&#xff09;、ZF干预度…

g4f升级到0.5.2.0版本了,但是有些机器无法运行,只能降级到0.5.1.2版本

g4f升级到0.5.2.0版本了&#xff0c;跟0.5.1.2更以前的版本相比&#xff0c;主要更新为增加了可以设置Huggingface等供应商的key Providers API key HuggingFace:Get API key HuggingSpace: 因为很多模型都会调用Huggingface&#xff0c;所以最好设置Huggingface的API key。…

C语言教程(二十五):C 语言函数可变参数详解

引言: 在 C 语言编程中,有时我们需要处理参数数量不固定的情况,比如常见的 printf 函数,它可以根据格式化字符串的要求接受任意数量的参数。这种能接受不确定数量参数的函数,就是可变参数函数。下面将深入探讨其定义、实现原理、使用方式、示例以及注意事项。 一、可变参…

OpenStack Yoga版安装笔记(25)Nova Cell理解

1、Nova Cell概述 &#xff08;官方文档&#xff1a;Cells (v2) — nova 25.2.2.dev5 documentation&#xff09; Nova中的cells功能的目的是允许较大的部署将其多个计算节点分割成多个cell。所有的nova部署都默认是cell部署&#xff0c;即使大多数情况下只有单一cell。这意味…

Java Set<String>:如何高效判断是否包含指定字符串?

在 Java 开发中&#xff0c;我们经常使用 Set 集合来存储一组唯一性的元素。特别是 HashSet&#xff0c;由于其基于哈希表的实现&#xff0c;在进行元素查找&#xff08;判断是否包含&#xff09;时通常具有非常高的效率&#xff08;平均时间复杂度 O(1)&#xff09;。 那么&a…

MySQL 查找指定表名的表的主键

原理 SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE WHERE TABLE_NAME 表名 AND CONSTRAINT_NAME PRIMARY方法 public static String getPk(String tableName) {String sql "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE WHERE TA…

Java大厂面试突击:从Spring Boot自动配置到Kafka分区策略实战解析

第一轮核心知识 面试官:请解释Spring Boot中自动配置的工作原理并演示如何自定义一个@ConfigurationProperties组件? xbhog:自动配置通过EnableAutoConfiguration注解触发,结合当前环境判断(如是否检测到MyBatis依赖)和条件注解(@ConditionalOnClass)来决定是否启用配…

开发板型号 ESP32-DevKitC-32模块型号 ESP32-WROOM-32 和主控芯片 ESP32-D0WDQ6-V3

以下是关于开发板型号 ESP32-DevKitC-32、模块型号 ESP32-WROOM-32 和主控芯片 ESP32-D0WDQ6-V3 的详细介绍&#xff1a; 开发板型号&#xff1a;ESP32-DevKitC-32 概述&#xff1a;ESP32-DevKitC 是乐鑫推出的一款基于 ESP32 模组的小型开发板&#xff0c;板上模组的绝大部…