如何用更少的显存训练 PyTorch 模型

文章目录

1、引言

2、自动混合精度训练

3、低精度训练

4、梯度检查点

5、通过梯度累积减小批量大小

6、张量分片与分布式训练

7、高效数据加载

8、使用 In-Place 操作

9、Activation and Parameter Offloading

10、使用更精简的优化器

11、高级策略

12、总结


1、引言

在训练大型深度学习模型(包括LLM和视觉Transformer)时,最常见的瓶颈之一就是显存消耗达到峰值。由于大多数人无法使用大规模的GPU集群,因此在本文中将概述一些技术和策略,在不牺牲模型性能和预测准确性的情况下,将显存消耗降低近20倍。请记住,这些技术中的大多数应用并不互相排斥,可以很容易地结合使用,以提高显存效率。

2、自动混合精度训练

混合精度训练结合了16位(FP16)和32位(FP32)浮点格式。其核心思想是在低精度下执行大部分数学运算,从而降低显存带宽和存储需求,同时在计算的关键环节保留必要的精度保障。通过使用FP16存储激活值和梯度,这些张量的显存占用量可减少约一半。但需注意,某些网络层或运算仍需保持FP32精度以避免数值不稳定问题。值得庆幸的是,PyTorch对自动混合精度(AMP)的原生支持极大简化了这一过程。

注意这里是混合精度训练而不是低精度训练

什么是混合精度训练?

混合精度训练结合使用16位(FP16)和32位(FP32)浮点格式以保持模型精度。通过使用16位精度计算梯度,相比全32位精度计算,这一过程可大幅加快运算速度并显著减少显存占用。这种方法在显存或计算资源受限的场景下尤为实用。

之所以采用混合精度而非低精度这一表述,是因为并非所有参数或运算都被转换为16位格式。实际上,训练过程会在32位与16位运算之间动态切换,这种精度层级的有序交替正是该技术被称为混合精度的根本原因。

如上述示意图所示,混合精度训练流程首先将权重转换为低精度格式(FP16)以加速计算,随后梯度计算在低精度环境下完成,但为确保数值稳定性,这些梯度会被重新转换为高精度格式(FP32),最终经过缩放处理的梯度将用于更新原始权重。因此,通过这种机制既能提升训练效率,又不会牺牲网络的整体精度与稳定性。

如前所述,使用 torch.cuda.amp.autocast( ) 可以轻松启用该功能,一个简单的代码示例片段如下:

import torch
from torch.cuda.amp import autocast, GradScaler# Assume your model and optimizer have been defined elsewhere.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
for data, target in data_loader:optimizer.zero_grad()# Enable mixed precisionwith autocast():output = model(data)loss = loss_fn(output, target)# Scale the loss and backpropagatescaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

3、低精度训练

如原文所述,理论上可以更进一步尝试完全使用16位低精度(而非混合精度)进行训练。但此时可能因16位浮点数的固有精度限制出现NaN值异常。为解决这一问题,业界开发了多种新型浮点格式,其中由谷歌专门为此研发的BF16应用较为广泛。简而言之,相较于标准的FP16,BF16拥有更大的动态范围——这种扩展的动态范围使其能够更精确地表示极大或极小的数值,从而更适配可能遭遇广泛数值区间的深度学习场景。虽然其较低的尾数精度在某些情况下可能影响计算准确性或引发舍入误差,但在大多数实践中对模型性能的影响微乎其微。

FP16与BF16的动态范围对比

虽然这种格式最初是为TPU开发的,但在大多数现代GPU(Nvidia Ampere架构及更高版本)也支持这种格式。大家可以使用以下方法检查您的GPU是否支持这种格式:

import torch
print(torch.cuda.is_bf16_supported())  # should print True

4、梯度检查点

即使采用混合精度与低精度训练,这些大型模型仍会生成大量中间张量,消耗可观的显存。梯度检查点技术通过在前向传播过程中选择性存储部分中间结果来解决这一问题——未被保存的中间张量将在反向传播阶段重新计算。尽管这会引入额外的计算开销,却能显著节省显存资源。

通过策略性选择需设置检查点的网络层,大家可通过动态重新计算激活值而非存储它们来减少显存使用。这种时间与内存的折中策略对于具有深层架构的模型特别有益,因为中间激活值占用了大部分内存消耗。以下是一个简单的使用示例:

import torch
from torch.utils.checkpoint import checkpoint
def checkpointed_segment(input_tensor):# This function represents a portion of your model# which will be recomputed during the backward pass.# You can create a custom forward pass for this segment.return model_segment(input_tensor)
# Instead of a conventional forward pass, wrap the segment with checkpoint.
output = checkpoint(checkpointed_segment, input_tensor)

采用该方法,在大多数情况下可使激活值的显存占用量降低40%至50%。尽管反向传播阶段因此增加了额外的计算量,但在GPU显存成为瓶颈的场景下,这种以时间换空间的策略通常是可接受的。

5、通过梯度累积减小批量大小

通过最初的方法,你可能会问自己:

为什么不干脆减少batchsize大小?

通过减小批量大小的确是减少显存占用最直接的方法,但需注意的是,这种方式在多数情况下会导致模型预测性弱于使用更大批量训练的模型。因此需要在显存限制与模型效果之间谨慎权衡。

那么如何达到平衡呢?

这正是梯度累积技术发挥作用之处!该方法通过在训练过程中虚拟增大有效批量规模:其核心原理是先在较小的批量上计算梯度,并经过多次迭代的累积(通过采用累加或平均方式),而非在每批次处理后立即更新模型参数。当累积梯度达到目标“虚拟”批量规模时,才使用聚合后的梯度一次性完成模型权重的更新。

这种技术的一个主要缺点是大大增加了训练时间。

6、张量分片与分布式训练

对于单个GPU无法容纳的庞大训练模型(即使经过上述优化),完全分片数据并行(FSDP)是不可或缺的。FSDP将模型参数、梯度和优化器状态分散到多个GPU上。这不仅能将巨大的模型放入显存,还能通过更好地分配通信开销提高训练效率。

FSDP不在每个GPU上维护模型的完整副本,而是在可用设备之间分配模型参数。在执行前向或后向传递时,只有相关的分片被加载到显存中。这种分片机制大大降低了对每台设备显存的需求,结合上述技术,在某些情况下甚至可以将显存需求降低10倍。

Tensor Parallel

样例如下:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Initialize your model and ensure it is on the correct device.
model = MyLargeModel().cuda()
# Wrap the model in FSDP for sharded training across GPUs.
fsdp_model = FSDP(model)

7、高效数据加载

在显存优化实践中,数据加载环节常被忽视。虽然优化重点通常集中在模型内部结构与计算过程上,但低效的数据处理可能引发不必要的性能瓶颈,同时影响显存占用与训练速度。若不确定如何优化数据加载器,可遵循以下经验法则:优先启用固定内存(Pinned Memory)与多工作进程(Multiple Workers)配置。

from torch.utils.data import DataLoader# Create your dataset instance and then the DataLoader with pinned memory enabled.
train_loader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,      # Adjust based on your CPU capabilitiespin_memory=True     # Enables faster host-to-device transfers
)

8、使用 In-Place 操作

在张量运算中,若未谨慎管理内存,每次操作都可能生成新对象。原地(In-Place)操作通过直接修改现有张量而非创建副本,可有效减少内存碎片化与总体内存占用。这种特性尤其有利于降低迭代训练循环中的临时内存分配开销。例如:

import torch
x = torch.randn(100, 100, device='cuda')
y = torch.randn(100, 100, device='cuda')
# Using in-place addition
x.add_(y)  # Here x is modified directly instead of creating a new tensor

9、Activation and Parameter Offloading

即便综合运用前述所有优化技术,在训练超大规模模型时,仍可能因海量中间激活值的瞬时占用而触及GPU显存容量极限。此时,中间数据卸载技术可作为额外的安全阀机制——其核心思路是将部分非即时必需的中间数据临时转换至CPU内存,从而为GPU显存腾出关键空间,确保训练流程持续进行。

我们通过策略性将部分激活值和或模型参数卸载至CPU内存,从而将GPU显存专用于核心计算任务。虽然如DeepSpeed、Fabric等专业框架已内置管理此类数据迁移的机制,大家仍可通过以下方式自主实现该功能。

def offload_activation(tensor):# Move tensor to CPU to save GPU memoryreturn tensor.cpu()def process_batch(data):# Offload some activations explicitlyintermediate = model.layer1(data)intermediate = offload_activation(intermediate)intermediate = intermediate.cuda()  # Move back when neededoutput = model.layer2(intermediate)return output

10、使用更精简的优化器

并非所有优化器对内存的需求均等。以广泛使用的Adam优化器为例,其针对每个模型参数需额外维护两个状态变量(均值与方差),导致内存占用倍增。相比之下,采用无状态优化器(如SGD)可将参数总量减少近三分之二——这对于训练大语言模型(LLMs)及其他大规模模型具有显著意义。

尽管普通SGD优化器存在收敛性能较弱的缺陷,但通过引入余弦衰减学习率调整策略(Cosine Decay Learning Rate Scheduler)可有效补偿这一不足。简而言之:

# instead of this
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# use this
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_steps = NUM_EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

通过这一调整,大家可以在显著改变峰值内存占有量的同时(具体取决于实际任务需求),仍能保持模型精度接近97%的水平。

11、高级策略

虽然上面列出的技术确实为我们奠定了坚实的基础,但我还想列出一些其他高级策略,我们可以考虑将 GPU 提升到极限:

  • 内存剖析和高速缓存管理

如果无法测量,就很难优化。PyTorch 提供了一些检查 GPU 内存使用情况的默认实用程序。使用方法如下:

import torch
# print a detailed report of current GPU memory usage and fragmentation
print(torch.cuda.memory_summary(device=None, abbreviated=False))
# free up cached memory that’s no longer needed by PyTorch
torch.cuda.empty_cache()
  • 使用TorchScript进行JIT编译

PyTorch 的即时(JIT)编译器使大家使用 TorchScript 将 Python 模型转换为优化的、可序列化的程序。通过优化内核启动和减少开销,这种转换可以提高内存和性能。您可以通过以下方式轻松访问它:

import torch
# Suppose `model` is an instance of your PyTorch network.
scripted_model = torch.jit.script(model)
# Now, you can run the scripted model just like before.
output = scripted_model(input_tensor)

尽管框架原生方法已能实现基础功能,但模型编译技术通常能带来更深层次的性能优化。

  • 自定义内核融合

编译的另一个主要好处是将多个操作融合到一个内核中。这有助于减少内存读/写,提高整体吞吐量。融合后的操作如下:

  • 使用torch.compile()进行动态内存分配

再次从编译中获益--使用 JIT 编译器可通过利用跟踪和图形优化技术的编译时优化来优化动态内存分配,从而进一步压缩内存并提高性能,尤其是在大型模型和Transformer架构中。

12、总结

随着 GPU 和云计算变得异常昂贵,只有充分利用现有资源才有意义。这有时可能意味着要在单个 GPU 工作站/笔记本电脑上对 LLM 或视觉Transformer进行训练/微调。上面列出的技术是研究人员/专业人士在算力紧张的情况下进行训练所使用的众多策略中的一部分。

参考资料:AI算法之道

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/80201.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

极速轻量,Rust 网络开发新选择:Hyperlane 框架深度解析

极速轻量,Rust 网络开发新选择:Hyperlane 框架深度解析 在高性能网络服务开发领域,Rust 凭借其内存安全与高效并发的特性备受青睐。今天,我们迎来一款专为现代 Web 服务打造的明星框架——Hyperlane,它以“轻量高效、…

单片机裸机环境下临界区保护

目录 1、直接中断屏蔽法 2、嵌套计数优化法 3、BASEPRI寄存器应用 4、动态优先级调整策略 5、LDREX/STREX指令应用 6、位带别名区原子访问 7、上下文感知保护 8、中断延迟优化技术 在嵌入式系统开发中,临界区保护是确保系统可靠性的关键技术。本文以ARM Cor…

【deepseek教学应用】001:deepseek如何撰写教案并自动实现word排版

本文讲述利用deepseek如何撰写教案并自动实现word高效完美排版。 文章目录 一、访问deepseek官网二、输入教案关键词三、格式转换四、word进一步排版 一、访问deepseek官网 官网:https://www.deepseek.com/ 进入主页后,点击【开始对话】,如…

springboot使用mybatisPlus进行数据库增删改查

springboot使用mybatisPlus进行数据库增删改查 提示:帮帮志会陆续更新非常多的IT技术知识,希望分享的内容对您有用。本章分享的是springboot的使用。前后每一小节的内容是存在的有:学习and理解的关联性。【帮帮志系列文章】:每个…

基于SpringBoot的校园周边美食探索及分享平台的设计与实现

资源详情: 私信我或点击链接获取: 基于SpringBoot的校园周边美食探索及分享平台的设计与实现资源-CSDN文库 摘要 美食一直是与人们日常生活息息相关的产业。传统的电话订餐或者到店消费已经不能适应市场发展的需求。随着网络的迅速崛起,互联…

到达最后一个房间的最少时间II 类似棋盘转移规律查找

文章目录 3342.到达最后一个房间的最少时间II 思路分析:最短路径问题,当然,由于不同的格子之间的移动的代价不统一,所以这个最短路径需要使用Dijkstra算法进行求解,对于直接使用Dijkstra算法模版的题目,大家可以先去做…

基于开源AI大模型AI智能名片S2B2C商城小程序源码的私域流量稳定性构建研究

摘要:在私域流量时代,传统实体零售的"时间积累"逻辑被直播电商等新业态颠覆。完美日记等新锐品牌通过构建私域流量池,实现了从0到1的指数级增长,而传统品牌却陷入"流量焦虑"。本文提出以开源AI大模型AI智能名…

做 iOS 调试时,我尝试了 5 款抓包工具

日常做开发的人,特别是和客户端接口打交道的同学,应该对“抓包”这件事不陌生。 调试登录流程、分析接口格式、排查错误返回、分析网络性能、甚至研究第三方 App 的数据通信……说到底,都绕不开“抓 HTTPS 包”这一步。 而这一步&#xff0…

Algolia - Docsearch的申请配置安装【以踩坑解决版】

👨‍🎓博主简介 🏅CSDN博客专家   🏅云计算领域优质创作者   🏅华为云开发者社区专家博主   🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入&#xff01…

nginx 配置后端健康检查模块

nginx自带的针对后端节点健康检查的功能比较简单,通过默认自带的ngx_http_proxy_module 模块和ngx_http_upstream_module模块中的参数来完成,当后端节点出现故障时,自动切换到健康节点来提供访问。但是nginx不能事先知道后端节点状态是否健康,后端即使有不健康节点,负载均…

平板收银系统、国产系统,鸿蒙系统,小键盘的封装与应用—仙盟创梦IDE

数字小键盘封装 数组小键盘封装是指将与数组小键盘相关的功能、操作、数据等进行整合,形成一个独立的、可复用的模块。封装数组小键盘具有以下几方面重要意义: 提高代码可维护性 降低复杂度:数组小键盘在实际应用中,可能涉及到…

网工实验——OSPF配置

网络拓扑图 配置 1.为每个路由器配置接口(略)(详细见RIP实验) 2.配置OSPF AR1 [AR1]ospf [AR1-ospf-1]area 1 [AR1-ospf-1-area-0.0.0.1]network 172.16.1.1 0.0.0.0 #精确配置网络,也可以像下面那条命令那样配置 …

Kubernetes client-go 客户端类型与初始化指南

Kubernetes client-go 客户端类型与初始化指南 在 Kubernetes 的 client-go 库中,存在多种客户端用于与 API 服务器交互。以下介绍主要客户端类型,包括用途、初始化方式及 Demo。 1. RESTClient 用途 RESTClient 是底层 REST 客户端,直接…

java加强 -泛型

概念 定义类、接口、方法时&#xff0c;同时声明了一个或多个类型变量&#xff08;如<E>&#xff09;&#xff0c;称为泛型类、泛型接口、泛型方法、它们统称为泛型。 语法 public class ArrayList<E>{} E可以接收不同类型的数据&#xff0c;可以是字符串&…

C++ 项目 -- 高并发内存池

目录 项目介绍 内存池概念 池化技术 内存池 内存池主要解决的问题 malloc 定长内存池 申请内存 释放内存 整体框架设计 thread cache 申请内存 释放内存 central cache 申请内存 释放内存 page cache 申请内存 释放内存 大块内存申请实现 定长内存…

高效C/C++之九:Coverity修复问题:关于数组操作 和 内存操作

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 高效C/C之九&#xff1a;Coverity修复问题&#xff1a;关于数组操作 和 内存操作 目录 【关注我&#xff0c;后…

vfrom表单设计器使用事件机制控制字段显示隐藏

1. 使用表单设计器进行debug调试 依据 vform3.0开发者文档 https://www.ganweicloud.com/docs/6.1.0/pages/d3e6d9/ 对switch组件设置事件逻辑 调试中

iPhone 和 Android 在日期格式方面的区别

整篇文章由iPhone 和 Android 在日期格式方面有所不同引起,大致介绍了,两种时间标准,以及在 JavaScript 下的格式转换方法。 Unix 时间戳是从1970年1月1日(UTC/GMT的午夜)开始所经过的秒数,不考虑闰秒。 iPhone 和 Android 在日期格式方面有所不同。其中,iPhone(iOS)使…

985高校查重率“隐性阈值”:低于5%可能被重点审查!

你是不是也以为&#xff1a; “查重率越低越好&#xff0c;最好压到1%、0%&#xff0c;导师看了都感动哭&#x1f979;” 但是你不知道的是——在985/211等重点高校&#xff0c;查重率太低反而可能引起导师和学术办公室的“特别关注”&#xff01; 今天就来扒一扒这个查重圈“…

【NLP】33. Pinecone + OpenAI :构建自定义语义搜索系统

Pinecone OpenAI 中文教学教程&#xff1a;构建自定义语义搜索系统 一、背景介绍 当下 AI 问答系统、矩阵检索、短文本分类等场景中&#xff0c;都需要很好地实现 “根据输入进行相似给点搜索”。这种算法基础称为 “向量搜索”&#xff0c;它的核心是将文本转换为向量后&am…