PyTorch : torch.cuda.amp: 自动混合精度详解

amp : 全称为 Automatic mixed precision,自动混合精度

amp功能:

可以在神经网络推理过程中,针对不同的层,采用不同的数据精度进行计算,从而实现节省显存和加快速度的目的。

通常,深度学习中使用的精度为32位(单精度)浮点数,而使用16位(半精度)浮点数可以将内存使用减半,同时还可以加快计算速度。然而,16位浮点数的精度较低,可能导致数值下溢或溢出,从而影响训练结果。

混合精度: 有不止一种精度的Tensor(不同精度的数值计算混合使用来加速训练和减少显存占用) :

  • torch.FloatTensor(浮点型 32位)(torch默认的tensor精度类型是torch.FloatTensor)
  • torch.HalfTensor(半精度浮点型 16位)

自动:

  • 预示着Tensor的dtype类型会自动变化,也就是框架按需自动调整tensor的dtype

使用自动混合精度 (amp) 的原因

  • torch.HalfTensor优势是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快
  • torch.HalfTensor劣势是数值范围小(更容易Overflow /Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)

混合精度训练机制

amp.autocast

  • amp.autocast():用户不需要手动对模型参数 dtype 转换,amp 会自动为算子选择合适的数值精度;是PyTorch中一种混合精度的技术,可在保持数值精度的情况下提高训练速度和减少显存占用。
  • amp.autocast():能够自动将16位浮点数转换为32位浮点数进行数值计算,并在必要时将结果转换回16位浮点数。这种自动转换可以帮助避免数值下溢或溢出的问题,并在保持数值精度的同时提高计算速度和减少显存占用。
  • 使用torch.cuda.amp.autocast()的过程如下:
  • 将模型和数据移动到GPU上
  • 使用torch.cuda.amp.autocast()上下文管理器包装模型的前向传递和损失计算
  • 使用scaler(即torch.cuda.amp.GradScaler对象)将反向传播的梯度缩放回16位
  • 执行梯度更新

问:使用 torch.cuda.amp.autocast() 将数据 从32位(单精度) 转换为 16位(半精度),会导致精度丢失嘛?
答:使用 torch.cuda.amp.autocast() 将数据从32位(单精度)转换为16位(半精度)会导致精度损失。由于16位浮点数只能表示更少的有效位数,因此它们的精度不如32位浮点数。在混合精度训练中,为了平衡精度和性能,通常会将网络的前向传播和反向传播过程中的参数和梯度计算使用半精度浮点数来加速计算。这种方法可以在一定程度上降低计算精度要求,但会带来一定的精度损失。
·
尽管存在精度损失,使用半精度浮点数的优点在于它们可以显著降低计算时间和显存消耗,从而使模型可以在更大的批量下进行训练,提高训练效率。此外,在实际应用中,对于某些任务,半精度精度的计算误差对于结果的影响可能不是很大,因此,半精度计算可以在保证结果准确性的前提下,大幅度提高模型的训练速度和效率。
原文链接:https://blog.csdn.net/weixin_37804469/article/details/129733868

amp.GradScaler

  • 对于反向传播的时候,FP16 的梯度数值溢出的问题,amp 提供了梯度 scaling 操作,而且在优化器更新参数前,会自动对梯度unscaling,所以,对用于模型优化的超参数不会有任何影响;
  • 具体来说,GradScaler 可以将梯度缩放到较小的范围,以避免数值下溢或溢出的问题,同时保持足够的精度以避免模型的性能下降。

以下是一个示例,展示了如何在 PyTorch 中使用 GradScaler:

import torch
from torch.cuda.amp import GradScaler, autocast# 创建 GradScaler 和模型
scaler = GradScaler()
model = torch.nn.Linear(10, 1).cuda()# 定义损失函数和优化器
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# 定义一些训练数据和目标
x = torch.randn(32, 10).cuda()
y = torch.randn(32, 1).cuda()# 使用 GradScaler 进行自动混合精度训练
for i in range(1000):optimizer.zero_grad()# 将前向传递包装在autocast中以启用混合精度with autocast():y_pred = model(x)loss = loss_fn(y_pred, y)# 调用 GradScaler 的 backward() 方法计算梯度并缩放scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()if i % 100 == 0:print(f"Step {i}, loss={loss.item():.4f}")

在这个示例中:

  1. 创建一个 GradScaler 对象 scaler
  2. 定义模型和优化器
  3. 在训练循环中,使用 autocast() 上下文管理器将前向传递操作包装起来,这样就可以使用混合精度进行计算
  4. 调用 scaler.scale(loss) 计算损失的缩放版本,并调用 scaler.step(optimizer) 来更新模型参数
  5. 最后使用 scaler.update() 更新 GradScaler 对象的内部状态
  6. 这个过程可以重复进行多次,直到训练结束。

原文链接:https://blog.csdn.net/qq_43369406/article/details/130393078

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/219625.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C# 编写Windows服务程序

1.什么是windows服务? Microsoft Windows 服务(即,以前的 NT 服务)使您能够创建在它们自己的 Windows 会话中可长时间运行的可执行应用程序。这些服务可以在计算机启动时自动启动,可以暂停和重新启动而且不显示任何用…

Web前端 ---- 【Vue】Vue路由传参(query和params)

目录 前言 为什么用路由 路由route和路由器router Vue中路由的工作原理 安装配置vue-router 使用VueRouter 多级路由 路由传参 query传参 params传参 前言 本文介绍路由相关知识路由传参 为什么用路由 为了单页面应用开发,只更换组件,不频繁刷…

什么软件可以压缩视频大小?超级简单

什么软件可以压缩视频大小?当我们想将视频上传到网上时,有时候会遇到视频因为体积太大而无法上传的问题,这种情况就需要将视频进行压缩了。那什么软件可以压缩视频大小呢?下面小编就来为大家介绍压缩视频的方法,支持批…

django-release-debug-apache-mod-wsgi-原理解析

文章目录 1.django-release2.mod_wsgi2.1.winnt模式2.2.worker模式2.3.preforker模式2.4.小节 3.apache配置参数3.1.全局参数3.2.主机参数 4.总结 1.django-release 由于django处理静态资源的效率偏低,顾在release模式不支持静态资源,这种情况需要在apa…

极新AIGC行业峰会 | 圆桌对话:探索中国AGI迭代之路

“AGI正处在一个巨大的研发范式革命的起点。” 整理 | 周梦婕 编辑 | 小白 出品|极新 2023年11月28日,极新AIGC行业峰会在北京东升国际科学院拉开帷幕,峰会上午的圆桌环节由凡卓资本合伙人王梦菲主持,深势科技战略副总裁何雯…

SpringCache使用配置

项目中引入SpringCache pom文件引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-cache</artifactId> </dependency>配置文件指定缓存类型 spring:cache:type: redis启动类开启缓存注解…

TikTok卖家攻略!如何保证账号安全和多账号防关联?

TIKTOK的火爆程度&#xff0c;大家都有目共睹。随之而来的是越来越多的人在上面开展在线业务。作为TikTok的卖家&#xff0c;保障账号安全和防止多账号关联是非常重要的。在这篇博客文章中&#xff0c;我们将分享一些关于如何保护TikTok账号安全和防止多账号关联的实用建议。 …

Spring Boot 3.x.x Spring Security 6.x.x @PreAuthorize 失效

Spring Boot 3.x.x Spring Security 6.x.x PreAuthorize 失效 背景问题解决备注 背景 最近在搞一个后端项目&#xff0c;登录、接口权限、token认证。 版本 Spring Boot 3.2.0 JDK 21 Spring Security 6.2.0 问题 PreAuthorize 失效&#xff0c;没有走认证。 解决 给PreAu…

生成小程序URLlink链接遇到的坑

这里写自定义目录标题 前端生成小程序URL link背景用户打开小程序的常用方法短链接短链接优缺点优点缺点 生成短链接步骤 可能会遇到的问题&#xff1a;其他 注意&#x1f4e2; 前端生成小程序URL link ![h5打开小程序](https://img-blog.csdnimg.cn/direct/a4cfe3ef6d184c6d9…

OfficeWeb365 SaveDraw 文件上传漏洞复现

0x01 产品简介 OfficeWeb365 是专注于 Office 文档在线预览及PDF文档在线预览云服务,包括 Microsoft Word 文档在线预览、Excel 表格在线预览、Powerpoint 演示文档在线预览,WPS 文字处理、WPS 表格、WPS 演示及 Adobe PDF 文档在线预览。 0x02 漏洞概述 OfficeWeb365 Sav…

mapbox修改样式

mapbox有些其实document绘制而成&#xff0c;比如control控件 故而会涉及到样式修改&#xff0c;以适配系统主题 先决条件 必须要安装mapbox-gl&#xff0c;申请access_token yarn add mapbox-gl// or npm install mapbox-gl修改样式 新建一个_mapbox-gl.scss文件&#xff…

[字符串操作]删除单词后缀

删除单词后缀 题目描述 给一组各分别以er、ly和ing结尾的单词&#xff0c; 请删除每个单词的结尾的er、ly或ing&#xff0c; 然后按原顺序输出删除后缀后的单词&#xff08;删除后缀后的单词长度不为0&#xff09;。 关于输入 输入的第一行是一个整数n&#xff08;n≤50&am…

Python 反射

Python 反射是什么&#xff1f; 学习了几天&#xff0c;做个总结留给自己看。 感觉跟 SQL 入门要掌握的原理一样&#xff0c;Python 反射看起来也会做4件事&#xff0c;“增删查获” 增 - 增加属性&#xff0c;方法 setattr 删 - 删除属性&#xff0c;方法 delattr 查 - …

用 MATLAB 实现的计算机CT断层扫描图像重建项目源码

完整源码资源下载链接 计算机断层扫描图像重建 介绍 计算机断层扫描是堆叠在一起的 X 射线图像的集合&#xff0c;以获得作为诊断图像第三维的深度信息。这些“堆叠的” X 射线图像作为正弦图从 CT 机架接收&#xff0c;代表对象单层的 X 射线吸收剖面。该项目的目标是重建该…

Springboot整合篇Druid

一、概述 1.1简介 Druid 是阿里巴巴开源平台上一个数据库连接池实现&#xff0c;结合了 C3P0、DBCP 等 DB 池的优点&#xff0c;同时加入了日志监控。 它本身还自带一个监控平台&#xff0c;可以查看时时产生的sql、uri等监控数据&#xff0c;可以排查慢sql、慢请求&#xff0…

【如何理解select、poll、epoll?】

如何理解select、poll、epoll&#xff1f; select、poll、epollselectpollepoll 知识扩展三者之间的主要区别是什么&#xff1f;epoll的两种模式是什么&#xff1f; select、poll、epoll select、poll、epoll都是Linux中常见的I/O多路复用技术&#xff0c;他们可以用于同时监听…

广西岑溪市火灾通报:1人死亡 AI科技助力预防悲剧

近日&#xff0c;广西岑溪市玉梧大道紫坭工业园一厂房发生一起令人心痛的火灾事件&#xff0c;造成1人不幸丧生。这起悲剧再次提醒我们&#xff0c;火灾的防范工作是多么的重要。在这样的背景下&#xff0c;我想分享一个能够有效预防类似悲剧的技术——北京富维图像公司开发的F…

【Java】网络编程-UDP回响服务器客户端简单代码编写

这一篇文章我们将讲述网络编程中UDP服务器客户端的编程代码 1、前置知识 UDP协议全称是用户数据报协议&#xff0c;在网络中它与TCP协议一样用于处理数据包&#xff0c;是一种无连接的协议。 UDP的特点有&#xff1a;无连接、尽最大努力交付、面向报文、没有拥塞控制 本文讲…

如何处理PHP开发中的单元测试和自动化测试?

如何处理PHP开发中的单元测试和自动化测试&#xff0c;需要具体代码示例 随着软件开发行业的日益发展&#xff0c;单元测试和自动化测试成为了开发者们重视的环节。PHP作为一种广泛应用于Web开发的脚本语言&#xff0c;单元测试和自动化测试同样也在PHP开发中扮演着重要的角色…

基于Spring Boot、Mybatis、Redis和Layui的企业电子招投标系统源码实现与立项流程

招投标管理系统是一款适用于招标代理、政府采购、企业采购和工程交易等领域的企业级应用平台。该平台以项目为主线&#xff0c;从项目立项到项目归档&#xff0c;实现了全流程的高效沟通和协作。通过该平台&#xff0c;用户可以实时共享项目数据信息&#xff0c;实现规范化管理…