PyTorch 中的累积梯度

https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch

有一个小的计算图,两次前向梯度累积的结果,可以看到梯度是严格相等的。

在这里插入图片描述
在这里插入图片描述

代码:

import numpy as np
import torchclass ExampleLinear(torch.nn.Module):def __init__(self):super().__init__()# Initialize the weight at 1self.weight = torch.nn.Parameter(torch.Tensor([1]).float(),requires_grad=True)def forward(self, x):return self.weight * xmodel = ExampleLinear()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)def calculate_loss(x: torch.Tensor) -> torch.Tensor:y = 2 * xy_hat = model(x)temp1 = (y - y_hat)temp2 = temp1**2return temp2# With mulitple batches of size 1
batches = [torch.tensor([4.0]), torch.tensor([2.0])]optimizer.zero_grad()
for i, batch in enumerate(batches):# The loss needs to be scaled, because the mean should be taken across the whole# dataset, which requires the loss to be divided by the number of batches.temp2 = calculate_loss(batch)loss = temp2 / len(batches)loss.backward()print(f"Batch size 1 (batch {i}) - grad: {model.weight.grad}")print(f"Batch size 1 (batch {i}) - weight: {model.weight}")print("="*50)# Updating the model only after all batches
optimizer.step()
print(f"Batch size 1 (final) - grad: {model.weight.grad}")
print(f"Batch size 1 (final) - weight: {model.weight}")

运行结果

Batch size 1 (batch 0) - grad: tensor([-16.])
Batch size 1 (batch 0) - weight: Parameter containing:
tensor([1.], requires_grad=True)
==================================================
Batch size 1 (batch 1) - grad: tensor([-20.])
Batch size 1 (batch 1) - weight: Parameter containing:
tensor([1.], requires_grad=True)
==================================================
Batch size 1 (final) - grad: tensor([-20.])
Batch size 1 (final) - weight: Parameter containing:
tensor([1.2000], requires_grad=True)

然而,如果训练一个真实的模型,结果没有这么理想,比如训练一个bert,𝐵=8,𝑁=1:没有梯度累积(累积每一步),

𝐵=2,𝑁=4:梯度累积(每 4 步累积一次)
使用带有梯度累积的 Batch Normalization 通常效果不佳,原因很简单,因为 BatchNorm 统计数据无法累积。更好的解决方案是使用 Group Normalization 而不是 BatchNorm。
https://ai.stackexchange.com/questions/21972/what-is-the-relationship-between-gradient-accumulation-and-batch-size

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/21164.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言每日一题:13《数据结构》环形链表。

题目链接: 一.环形链表运动基础。 使用快慢指针利用相对移动的思想: 1.第一种情况: 1,令快指针(fast)速度为2. 2.慢指针(slow)速度为1. 3.以慢指针进入环中开始。 4。假设slow刚刚进入环中fast…

【夜深人静学习数据结构与算法 | 第十二篇】动态规划——背包问题

目录 前言: 01背包问题: 二维数组思路: 一维数组思路: 总结: 前言: 在前面我们学习动态规划理论知识的时候,我就讲过要介绍一下背包问题,那么今天我们就来讲解一下背包问题。 在这…

机器学习笔记 - 目标检测中的加权框融合与非极大值抑制的对比

一、对象检测后处理 后处理步骤是目标检测中一个琐碎但重要的组成部分。本文主要是为了了解当拥有多个对象检测模型的集合时,加权框融合(WBF)相对于传统非极大值抑制(NMS)作为对象检测中后处理步骤的差异。 对象检测模型通常使用非极大值抑制 (NMS) 作为默认后处理步骤来过…

权限测试点

1. 只给用户单独设置一个权限,设置之后检查该权限是否生效 2.给用户不设置任何权限,设置之后检查该用户能否使用系统 3.给用户设置全部权限,设置之后检查所有权限是否生效 4.给用户设置部分权限,设置后检查部分权限是否生效 5.用户…

NetApp 入门级全闪存系统 AFF A250:小巧而强大

NetApp 入门级全闪存系统 AFF A250:小巧而强大 作为 AFF A 系列中的入门级全闪存系统,AFF A250 不但可以简化数据管理,还能为您的所有工作负载提供令人惊叹的强劲动力,价格也平易近人。 AFF A250:您的新 IT 专家 AFF…

工厂方法模式(Factory Method)

工厂方法模式就是定义一个用于创建对象的接口,让子类决定实例化哪一个类。工厂方法模式将类的实例化(具体产品的创建)延迟到工厂类的子类(具体工厂)中完成,即由子工厂类来决定该实例化哪一个类。 Define a…

Knife4j系列--解决不显示文件上传的问题

原文网址&#xff1a;Knife4j系列--解决不显示文件上传的问题_IT利刃出鞘的博客-CSDN博客 简介 本文介绍使用Knife4j时无法上传文件的问题。 问题复现 依赖 <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-spring-boot-…

golang定时任务库cron实践

简介 cron一个用于管理定时任务的库&#xff0c;用 Go 实现 Linux 中crontab这个命令的效果。之前我们也介绍过一个类似的 Go 库——gron。gron代码小巧&#xff0c;用于学习是比较好的。但是它功能相对简单些&#xff0c;并且已经不维护了。如果有定时任务需求&#xff0c;还…

为什么马斯克和奥特曼都想重振加密货币?

1、前言 加密货币已经死了吗&#xff1f;这个问题的答案取决于谁来回答。一个加密爱好者会给你一百个不同的理由来解释为什么加密货币没有死。特斯拉CEO埃隆马斯克和OpenAI CEO 山姆奥特曼都对加密货币及其在塑造未来世界中的潜在作用有着浓厚的兴趣。 在过去很长一段时间里&…

国内GitHub加速访问工具-Fetch GitHub Hosts

一、工具介绍 Fetch GitHub Hosts是一款开源跨平台的国内GitHub加速访问工具&#xff0c;主要为解决研究及学习人员访问 Github 过慢或其他问题而提供的 Github Hosts 同步工具。 项目原理&#xff1a;是通过部署此项目本身的服务器来获取 github.com 的 hosts&#xff0c;而…

Stability AI旗舰图像模型 SDXL1.0发布,AI绘画进入新的时代

Stability AI于7月26号开源了SDXL1.0文生图模型&#xff0c;要知道距离SDXL0.9开源发布也不过一个月,只能说AI发展日新月异。 根据官网介绍&#xff0c;SDXL1.0经过迭代更新&#xff0c;已经是目前世界上最好的图像生成模型 官网根据Discord上的几代实验模型和外部测试&#…

【力扣刷题 | 第二十四天】

目录 前言&#xff1a; 416. 分割等和子集 - 力扣&#xff08;LeetCode&#xff09; 总结 前言&#xff1a; 今晚我们爆刷动态规划类型的题目。 416. 分割等和子集 - 力扣&#xff08;LeetCode&#xff09; 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这…

磁共振图像处理中 fft1c 和 ifft1c 函数的 Python 实现

fft1c 和 ifft1c 是 MRI 图像处理的常用函数。通常使用如下的 Matlab 实现 &#xff08;Michael Lustig&#xff0c;2005&#xff09; function res ifft1c(x,dim)% res fft1c(x) % % orthonormal forward 1D FFT %nsize(x,dim); shftzeros(1,5); shft(dim)ceil(n/2);xcirc…

【UE5】UE5与Python Socket通信中文数据接收不全

最近在使用UE的Socket模块与Python服务器进行通信时遇到了一些坑&#xff0c;特此记录一下。 先来复现一下问题&#xff0c;这里只截取关键代码。 UE端&#xff1a; bool ASoc::SendMsg(const FString& Msg) {TSharedRef<FInternetAddr> TargetAddr ISocketSubsy…

Anaconda+Win10安装保姆级教程(合集)

1.安装Anaconda Anaconda3安装及环境配置 2.在使用Anaconda的过程中经常会出现conda的环境默认会安装在C盘 conda安装环境指定位置解决方案 3.在使用pycharm2023版时&#xff0c;尽管cmd能activate自己安装的新环境&#xff0c;但是pychram一直检测不到 pycharm2023找不到…

springboot后端用WebSocket每秒向前端传递数据,python接收数据

1 springboot 1.1 加依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency> 1.2 WebSocketConfig 后端设置前端请求的网址,注册请求的信息 import org.…

广州银行信用卡中心:强化数字引擎安全,实现业务稳步增长

广州银行信用卡中心是全国城商行中仅有的两家信用卡专营机构之一&#xff0c;拥有从金融产品研发至销售及后期风险控制、客户服务完整业务链条&#xff0c;曾获“2016年度最佳创新信用卡银行”。 数字引擎驱动业务增长 安全左移降低开发风险 近年来&#xff0c;广州银行信用卡…

网络:VRP介绍

1. VRP 华为使用的通用路由平台&#xff0c;华为的交换机、防火墙、安全设备、无线和路由器的命令行几乎一样。 2. VRP分为用户视图、系统视图。 3. 用户视图 user view <Huawei>&#xff1a;其中<>代表的是用户视图&#xff0c;Huawei是设备的名称。命令比较少。…

H5实现签字版签名功能

前言&#xff1a;H5时常需要给C端用户签名的功能&#xff0c;以下是基于Taro框架开发的H5页面实现 一、用到的技术库 签字库&#xff1a;react-signature-canvas主流React Hooks 库&#xff1a;ahooks 二、组件具体实现 解决H5样式问题&#xff0c;主要还是通过两套样式实现…

a-date-picker报错TypeError: date4.locale is not a function

问题描述 使用日期选择器&#xff0c;数据从后端获得&#xff0c;再赋值给a-date-picker做数据回显&#xff0c;遇到这个报错&#xff0c;排错后定位到a-date-picker组件本身接收数据的问题。 如果使用了dayjs或moment库来处理时间字符串&#xff0c;并且使用.format对时间数据…