详解混合精度训练(Mixed Precision Training)

介绍

混合精度训练(Mixed Precision Training)是一种在深度学习中提高训练速度和减少内存占用的技术。在PyTorch中,通过使用半精度浮点数(16位浮点数,FP16)和单精度浮点数(32位浮点数,FP32)的组合。

优点

在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更大的 batch size、更大模型和尺寸更大的输入的训练。

FP16 和 FP32

FP16 和 FP32 是两种不同的浮点数表示格式,它们表示浮点数的精度和范围。

FP16(16位浮点数):

  • FP16 是一种半精度浮点数格式,它使用16位(2字节)来表示一个浮点数。
  • 它的格式通常包括1位符号位、5位指数位和10位尾数位。
  • 由于指数位较少,FP16能够表示的数值范围比FP32小,但它需要的内存和计算资源也更少。
  • FP16在深度学习中被用于加速计算和节省内存,尤其是在支持FP16运算的硬件上。

FP32(32位浮点数):

  • FP32 是一种单精度浮点数格式,它使用32位(4字节)来表示一个浮点数。
  • 它的格式包括1位符号位、8位指数位和23位尾数位。
  • 相比于FP16,FP32能够表示更大范围的数值,具有更高的精度,但也需要更多的内存和计算资源。
  • FP32是最常用的浮点数类型,适用于广泛的科学计算和工程应用。

在深度学习中,使用FP16进行训练可以显著减少模型的内存占用,加快数据传输和计算速度,尤其是在配备有Tensor Core的NVIDIA GPU上。然而,由于FP16的数值范围较小,可能会导致数值下溢(underflow)或精度损失,因此在训练过程中可能需要一些特殊的技术(如梯度缩放和混合精度训练)来确保模型的数值稳定性和最终精度。

基本流程

下面是一个使用PyTorch进行混合精度训练的例子:

  1. 准备环境
    首先,确保你的硬件和PyTorch版本支持FP16运算。然后,导入必要的库:
   import torchimport torch.nn as nnimport torch.optim as optimfrom torch.cuda.amp import autocast, GradScaler
  1. 定义模型
    创建一个简单的神经网络模型,例如一个多层感知机(MLP):
   class SimpleMLP(nn.Module):def __init__(self):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 2)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x
  1. 启用混合精度

    使用autocast()上下文管理器来指定哪些操作应该使用FP16执行:

   model = SimpleMLP().cuda()model.train()scaler = GradScaler()for epoch in range(num_epochs):for batch in data_loader:x, y = batchx, y = x.cuda(), y.cuda()with autocast():outputs = model(x)loss = criterion(outputs, y)# 反向传播和权重更新scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

在这个例子中,autocast()将模型的前向传播和损失计算转换为FP16格式。然而,反向传播仍然是在FP32精度下进行的,这是为了保持数值稳定性。

  1. 使用GradScaler

由于FP16的数值范围较小,可能会导致梯度下溢(underflow)。GradScaler在反向传播之前将梯度的值放大,然后在权重更新之后将其缩放回来:

   scaler = GradScaler()

在计算梯度后,使用scaler.step(optimizer)来应用缩放后的梯度。

GradScaler

GradScaler 是 PyTorch 中 torch.cuda.amp 模块提供的一个工具,它用于帮助进行混合精度训练。在混合精度训练中,我们通常使用 FP16 来存储模型的权重和进行前向计算,以减少内存占用和加速计算。

然而,FP16 的数值范围比 FP32 小,这可能导致在梯度计算和权重更新时出现数值下溢(underflow),即梯度的数值变得非常小,以至于在 FP16 格式下无法有效表示。

GradScaler 通过在反向传播之前自动放大(scale up)梯度的值来解决这个问题。然后,在执行权重更新之后,GradScaler 会将放大的梯度缩放(scale down)回原来的大小。这个过程确保了即使在 FP16 格式下,梯度的数值也能保持在可表示的范围内,从而避免了数值下溢的问题。

scaler = torch.cuda.amp.GradScaler()
for inputs, targets in dataloader:with autocast():outputs = model(inputs)loss = loss_fn(outputs, targets)scaler.scale(loss).backward()  # 放大梯度scaler.step(optimizer)  # 应用缩放后的梯度进行权重更新scaler.update()  # 更新缩放因子
  1. 保存和加载模型
   torch.save(model.state_dict(), 'model.pth')model.load_state_dict(torch.load('model.pth'))

在混合精度训练中,虽然模型的权重在训练过程中可能会被转换为 FP16 格式以节省内存和加速计算,但在保存模型时,我们通常会将权重转换回 FP32 格式。这是因为 FP32 提供了更高的数值精度和更广泛的硬件支持,这使得模型在不同环境中的兼容性和可靠性更好。

在 PyTorch 中,当你调用 model.state_dict() 方法时,默认情况下它会返回一个包含 FP32 权重的字典。即使你在训练时使用了 FP16,这个字典也会包含 FP32 权重,因为 PyTorch 会先转换为 FP32 再保存。同样,当你使用 torch.load() 加载模型时,如果模型权重是 FP16 格式,PyTorch 会自动将它们转换为 FP32。

注意,如果你的模型是在 GPU 上训练的,加载模型时应该使用 map_location 参数来指定加载到 CPU,然后再将模型转换为 FP32 并移回 GPU。

可能出现的问题

混合精度训练是一种强大的技术,可以提高深度学习模型的训练速度和效率,同时减少内存使用。然而,尽管它有许多优点,但在实践中也可能遇到一些问题和挑战:

  1. 数值稳定性问题

    • 使用FP16可能会导致数值下溢(underflow),即非常小的数值在FP16格式中无法有效表示,变成零。
    • 由于FP16的精度较低,可能会在训练过程中引入舍入误差,影响模型的收敛和最终性能。
  2. 硬件兼容性

    • 并非所有的硬件都支持FP16运算。在没有专门Tensor Core的GPU上,使用FP16可能不会带来预期的性能提升。
    • 一些旧的或低端的GPU可能完全不支持FP16,这意味着混合精度训练无法在这些硬件上使用。
  3. 软件和库的支持

    • 一些深度学习框架和库可能没有完全支持混合精度训练,或者对FP16的支持不够成熟,这可能需要额外的工作来集成或调试。
  4. 模型和数据类型的转换

    • 在混合精度训练中,需要在FP32和FP16之间转换数据类型,这可能需要仔细管理以避免精度损失。
    • 某些操作可能需要显式地转换为FP32来保证数值稳定性,例如梯度缩放和权重更新。
  5. 调试和分析困难

    • 使用混合精度训练可能会使得模型的调试和性能分析更加复杂,因为需要跟踪哪些操作是在FP16下执行的,哪些是在FP32下执行的。
  6. 模型泛化能力

    • 在某些情况下,混合精度训练可能会影响模型的泛化能力,尤其是在模型对精度非常敏感的情况下。

为了解决这些问题,研究人员和工程师通常会采用一些策略,如使用数值稳定的算法、确保正确的数据类型转换、使用支持混合精度训练的深度学习框架和库,以及在必要时进行模型微调。此外,对于特别需要高精度的任务,可能会选择使用全精度(FP32)训练,以避免潜在的精度问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/825494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“全网最全”LLM推理框架集结营 | 看似微不足道,却决定着AIGC项目的成本、效率与性能!

00-前序 随着ChatGPT、GPT-4等大语言模型的出现,彻底点燃了国内外的学者们与企业家们研发LLM的热情。国内外的大语言模型如雨后春笋一般的出现,这些大语言模型有一部分是开源的,有一部分是闭源的。 伴随着大语言模型的出现,国内外…

树莓派驱动RGB灯-rpi-ws281x库安装

1 树莓派的操作系统安装 1.1 操作系统选择 这个选择64位的操作的系统来驱动,一定不要选择32位的操作系统。笔者在这个地方浪费不少时间,具体原因不是很清楚。如果32位的操作系统,后面在rpi-ws281x的库时候会有报错。 1.2 操作系统链接如下…

windows docker desktop==spark环境搭建

编写文件docker-compose.yml version: 3services:spark-master:image: bde2020/spark-master:3.1.1-hadoop3.2container_name: spark-masterports:- "8080:8080"- "7077:7077"- "2220:22"volumes:- F:\spark-data\m1:/dataenvironment:- INIT_D…

Spring学习(三)——AOP

AOP是在不改原有代码的前提下对其进行增强 AOP(Aspect Oriented Programming)面向切面编程,在不惊动原始设计的基础上为其进行功能增强,前面咱们有技术就可以实现这样的功能即代理模式。Java设计模式——代理模式-CSDN博客 基础概念 连接点&#xff08…

2024经常用且免费的10个网盘对比,看看哪个比较好用!

网盘在我们的工作和学习中经常会用到,也是存储资料的必备工具,有了它,我们就不用走到哪都带着移动硬盘了,而目前市场上的主流网盘还有数十款,其中有免费的也有付费的,各家不一,今天小编就来为您…

[Android]模拟器登录Google Play失败

问题: 模拟器登录Google Play失败,提示couldnt sign in there was a problem communicating with google servers. try again later. 原因: 原因是模拟器没有连接到互联网,打开模拟器中Google浏览器进行搜索一样不行。 解决&am…

移动硬盘(PSSD)中文件占用空间远大于文件大小

定义 文件的大小:文件内容实际具有的字节数,它以Byte为衡量单位,只要文件内容和格式不发生变化,文件大小就不会发生变化。 文件占用空间:文件在磁盘上的所占空间,它最小的计量单位是“簇(Cluster)”。 为…

MySQL高负载排查方法最佳实践(15/16)

高负载排查方法 CPU占用率过高问题排查 使用mpstat查看cpu使用情况。 # mpstat 是一款 CPU 性能指标实时展示工具 # 能展示每个 CPU 核的资源视情况,同时还能将资源使用情况进行汇总展示 # 如果CPU0 的 %idle 已经为 0 ,说明此核已经非常繁忙# 打印所…

Istio介绍

1.什么是Istio Istio是一个开源的服务网格(Service Mesh)框架,它提供了一种简单的方式来为部署在Kubernetes等容器编排平台上的微服务应用添加网络功能。Istio的核心功能包括: 服务治理:Istio能够帮助管理服务之间的…

微服务之CircuitBreaker断路器

一、概述 1.1背景 在一个分布式系统中,每个服务都可能会调用其它的服务器,服务之间是相互调用相互依赖。假如微服务A调用微服务B和微服务C,微服务B和微服务C又调用其他的微服务。这就是构成所谓“扇出”。 如果扇出的链路上某个微服务的调…

状态压缩DP题单

P1433 吃奶酪&#xff08;最短路&#xff09; dp(i, s) 表示从 i 出发经过的点的记录为 s 的路线距离最小值 #include<bits/stdc.h> #define int long long using namespace std; const int N 20; signed main() { int n; cin >> n;vector<double>x(n 1),…

C++项目 -- 负载均衡OJ(三)online_judge

C项目 – 负载均衡OJ&#xff08;三&#xff09;online_judge 文章目录 C项目 -- 负载均衡OJ&#xff08;三&#xff09;online_judge一、基于MVC结构的oj服务设计1.结构与功能 二、oj_model.hpp1.建立文件版题库2.文件版题库的服务模块3. MySQL版题库3.1.创建名为oj_client的用…

【uniapp】引入uni-ui组件库

&#xff08;1&#xff09;新建项目的时候选择 uni-ui项目 &#xff08;2&#xff09;已经创建好的项目去官网单独安装 跳转单独安装组件 https://uniapp.dcloud.net.cn/component/uniui/quickstart.html#%E9%80%9A%E8%BF%87-uni-modules-%E5%8D%95%E7%8B%AC%E5%AE%89%E8%A3%8…

202462读书笔记|《一世珍藏的诗歌200首》——你曾经羞赧地向我问起, 是谁最早在此留下足印

202462读书笔记|《一世珍藏的诗歌200首》——你曾经羞赧地向我问起&#xff0c; 是谁最早在此留下足印 《一世珍藏的诗歌200首》作者金宏宇&#xff0c;很多美好的诗&#xff0c;有徐志摩&#xff0c;戴望舒&#xff0c;林徽因&#xff0c;舒婷等的诗精选&#xff0c;很值得一读…

动态库和静态库

文章目录 一、 静态库二、动态库 一、 静态库 静态库&#xff08;.a&#xff09;&#xff1a;程序在编译链接的时候把库的代码链接到可执行文件中。程序运行的时候将不再需要静态库&#xff0c;因为他已经在你字节写的程序中。 编译静态库 将所有的.h文件拷贝到lib/include中…

2024年腾讯云服务器价格一览表

随着云计算技术的快速发展&#xff0c;越来越多的企业和个人开始选择使用云服务器来满足他们的数据存储和计算需求。腾讯云作为国内领先的云服务提供商&#xff0c;其服务器产品因性能稳定、安全可靠而备受用户青睐。那么&#xff0c;2024年腾讯云服务器的价格情况如何呢&#…

网络运输层之(3)GRE协议

网络运输层之(3)GRE协议 Author: Once Day Date: 2024年4月8日 一位热衷于Linux学习和开发的菜鸟&#xff0c;试图谱写一场冒险之旅&#xff0c;也许终点只是一场白日梦… 漫漫长路&#xff0c;有人对你微笑过嘛… 全系列文档可参考专栏&#xff1a;通信网络技术_Once-Day的…

OpenHarmony多媒体-video_trimmer

简介 videotrimmer是在OpenHarmony环境下&#xff0c;提供视频剪辑能力的三方库。 效果展示&#xff1a; 安装教程 ohpm install ohos/videotrimmerOpenHarmony ohpm环境配置等更多内容&#xff0c;请参考 如何安装OpenHarmony ohpm包 。 使用说明 目前支持MP4格式。 视频…

ansible模块实战-部署rsync服务端

目录 1、根据部署流程所用到的命令找出模块 2.实战部署 2.1 服务部署&#xff1a;yum 安装 2.2 准备好rsync服务的配置文件 &#xff0c;并将配置文件通过copy模块分发给192.168.81.136这台受控主机 2.3 创建虚拟机用户 2.4 创建密码文件和改权限 2.5 模块对应目录&…

《QT实用小工具·二十九》托盘图标控件

1、概述 源码放在文章末尾 托盘图标控件 可设置托盘图标对应所属主窗体。 可设置托盘图标。 可设置提示信息。 自带右键菜单。 下面是demo演示&#xff1a; 项目部分代码如下&#xff1a; #ifndef TRAYICON_H #define TRAYICON_H/*** 托盘图标控件* 1. 可设置托盘图标…