Pytorch 网络冻结的三种方法区别:detach、requires_grad、with_no_grad

1、requires_grad

requires_grad=True # 要求计算梯度;
requires_grad=False # 不要求计算梯度;

在pytorch中,tensor有一个 requires_grad参数,如果设置为True,那么它会追踪对于该张量的所有操作。在完成计算时可以通过调用backward()自动计算所有的梯度,并且,该张量的所有梯度会自动累加到张量的.grad属性;反之,如果设置为False,则不会记录这些操作过程,自然而然就不会进行计算梯度的工作 。 tensorrequires_grad的属性默认为False.

x = torch.tensor([1.0, 2.0])
x.requires_grad"""
结果:
False
"""

我们可以先看一下requires_grad参数设置分别为True和False时的情况。

# 设置好requires_grad的值为True
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=False)
y1 = 2.0 * x + 2.0 * yprint(x, x.requires_grad)
print(y, y.requires_grad)
print(y1, y1.requires_grad)y1.backward(torch.tensor([1.0, 1.0]))  
print(x.grad)
print(y.grad)"""
结果:
tensor([1., 2.], requires_grad=True) True
tensor([3., 4.]) False
tensor([ 8., 12.], grad_fn=<AddBackward0>) True
tensor([2., 2.])
None
"""

在上面的实验中,发现在计算中如果存在tensor张量xrequires_gradTrue的情况,那么计算之后的结果y1requires_grad也为True,且计算梯度时仅会计算x的梯度,因为前面设置了y张量的requires_gradFalse,所以最后y张量的grad属性值为None

关于张量tensor的梯度计算可以参考另一篇博客:Tensor及其梯度

所以在深度学习训练时,要冻结部分权值参数不进行参数更新的话,可以在优化器初始化之前将参数组进行筛选,将不想进行训练的参数的requires_grad设置为False。代码示例参考如下:

cnn = CNN() #构建网络for n,p in cnn.named_parameters():print(n,p.requires_grad)if n=="conv1.0.weight":p.requires_grad = Falseoptimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,cnn.parameters()), lr=learning_rate)

也可以把requires_grad属性置为 False这个操作放在optimizer之后,参数都不会进行更新。但是区别在于,先进行requires_grad属性置为False的操作,再optimizer初始化,不会将该层的参数放进优化器中更新,而先进行optimizer初始化,再进行requires_grad属性置为False的操作,会将所有的参数放进优化器中,但不更新该指定层参数,只更新剩下的参数。对比看来,optimizer中的参数量会相比前者会更大一点。

所以一般最好是将requires_grad属性置为 False这个操作放在optimizer之前。

注意事项:

1、requires_grad属性置为 False或者默认时,不能在对该tensor计算梯度,否则会进行报错。因为并没有追踪到任何计算历史,所以就不存在梯度的计算了。

import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=False)
y1 = 2.0 * x + 2.0 * y
# y1.backward(torch.tensor([1.0, 1.0])) 
# print(x.grad)
y.backward(torch.tensor([1.0, 1.0]))"""
结果:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
"""

2、整数型的tensor并没有requires_grad这个属性,只有浮点类型的tensor可以计算梯度

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
d:\test.ipynb Cell 9 in <cell line: 2>()1 a = torch.tensor([1,2])
----> 2 b = torch.tensor([3,4], requires_grad=True)3 c = a+b4 print(a.requires_grad)RuntimeError: Only Tensors of floating point and complex dtype can require gradients

2、detach()

detach方法就是返回了一个新的张量,该张量与当前计算图完全分离,且该张量的计算将不会记录到梯度当中。

import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
z = torch.tensor([3.0, 2.0], requires_grad=True)
x = x * 2
z1 = z.detach()
x1 = x.detach() 
y1 = 2.0 * x1 + 2.0 * y + 3 * z1
y1.backward(torch.tensor([1.0, 1.0])) 
print(x.requires_grad)
print(x.grad)
print(y.requires_grad)
print(y.grad)
print(z.requires_grad)
print(z.grad)
print(z1.requires_grad)
print(z1.grad)"""
结果:
True
None
True
tensor([2., 2.])
True
None
False
None
C:\Users\26973\AppData\Local\Temp\ipykernel_37516\3652236761.py:12: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.print(x.grad)
"""

从上面实验可以看到,使用detach()方法后,可以截断反向传播的梯度流,其作用有点类似于将requires_grad属性置为False的情况。

requires_grad_()requires_grad属性置为False不同的是 detach()函数会返回一个新的Tensor对象b , 并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。 ba共享数据的存储空间,二者指向同一块内存。

requires_grad_()函数会改变Tensorrequires_grad属性并返回Tensor修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=Truerequires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。

关于detach()返回的张量与原张量共享数据的存储空间,二者指向同一块内存可以由以下代码看出:

z1[0] = 5.0
print(z)
print(z1)"""
结果:
tensor([5., 2.], requires_grad=True)
tensor([5., 2.])
"""

当我们修改z1中的数据时,z中的数据也随之修改。

**总结:**当我们在计算到某一步时,不需要在记录某一个张量的梯度时,就可以使用detach()将其从追踪记录当中分离出来,这样一来该张量对应计算产生的梯度就不会被考虑了。比较常见的就是在GAN生成模型中,当训练一次生成器后,再训练判别器时,需要对生成器生成的fake进行损失计算,但是又不希望这部分损失对生成器进行权值的更新,这个时候需要冻结生成器那部分的权值,因此通常将生成器生成的fake张量使用fetch()进行阶段,再输入到判别器进行运算,这样最后使用loss.backward()时仅会对判别器部分的梯度进行计算

import torch
import numpy as npy = torch.tensor([3.0, 4.0], requires_grad=True)
z = torch.tensor([3.0, 2.0], requires_grad=True)
z1 = z.detach()
z2 = z1 + y
y1 = torch.sum(3 * z2)
y1.backward() 
print(z2, z2.requires_grad)
print(y.grad)
print(z2.grad)
print(z1.grad)"""
结果:
tensor([6., 6.], grad_fn=<AddBackward0>) True
tensor([3., 3.])
None
None
"""

这里假设z是生成器生成的图片,z1表示的是使用detch()截断后的张量,y表示的判别器内部的一些运算张量,z2表示经过判别器后的结果,y1假设是计算loss的损失函数,我们可以看到,使用y1.backward() 后,不会对生成器生成的z产生任何的梯度,在优化器优化时自然而然不会对其进行优化。而对于后面的步骤仍然会跟踪记录其所有的计算过程,比如对于z1在判别器中进行运算仍然会记录其过程,并仅会对判别器内部的参数y进行梯度计算,从而进行优化

(注:这里z2.grad之所以也为None,是因为z2节点不是叶子节点,它是由z1y进行累加而来的,所以在z2处不会有grad属性,这部分可以看其给出的警告)

UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

3、with_no_grad

torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。

# 设置好requires_grad的值为True
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2with torch.no_grad():  # 这里使用了no_grad()包裹不需要被追踪的计算过程y2 = y1 * 2y3 = x ** 5y4 = y1 + y2 + y3print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)
print(y4, y4.requires_grad)y4.backward(torch.ones(y4.shape))  # y1.backward() y2.backward()
print(x.grad)"""
结果:
tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.]) False
tensor([ 1., 32.]) False
tensor([ 4., 44.], grad_fn=<AddBackward0>) True
tensor([2., 4.])
"""

可以看出,其实使用with torch.no_grad()这个后,被其包裹的所有运算都是不计算梯度的,其效果与detach()类似,所以使用下列代码的运行结果是一样的:

# 设置好requires_grad的值为True
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2# with torch.no_grad():  # 这里使用了no_grad()包裹不需要被追踪的计算过程
y2 = y1 * 2
y3 = x ** 5
y2 = y2.detach()
y3 = y3.detach()y4 = y1 + y2 + y3print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)
print(y4, y4.requires_grad)y4.backward(torch.ones(y4.shape))  # y1.backward() y2.backward()
print(x.grad)
"""
结果:
tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.]) False
tensor([ 1., 32.]) False
tensor([ 4., 44.], grad_fn=<AddBackward0>) True
tensor([2., 4.])
"""

detach()是考虑将单个张量从追踪记录当中脱离出来;
torch.no_grad()是一个warper,可以将多个计算步骤的张量计算脱离出去,本质上没啥区别。

4、总结:

  1. requires_grad:在最开始创建Tensor时候可以设置的属性,用于表明是否追踪当前Tensor的计算操作。后面也可以通过requires_grad_()方法设置该参数,但是只有叶子节点才可以设置该参数。
  2. detach()方法:则是用于将某一个Tensor从计算图中分离出来。返回的是一个内存共享的Tensor,一变都变。
  3. torch.no_grad():对所有包裹的计算操作进行分离。但是torch.no_grad()将会使用更少的内存,因为从包裹的开始,就表明不需要计算梯度了,因此就不需要保存中间结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/155219.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云服务器公网带宽升级的三种方法

阿里云服务器公网带宽不够用有哪些解决方法&#xff1f;可以更改带宽或带宽临时升级&#xff0c;更改带宽是永久公网带宽&#xff0c;带宽临时升级可以选择升级时间段&#xff0c;也可以绑定弹性公网EIP来修改公网带宽&#xff0c;阿里云服务器网aliyunfuwuqi.com分享阿里云服务…

海外IP代理科普——API代理是什么?怎么用?

随着互联网的不断发展&#xff0c;越来越多的企业开始使用API&#xff08;应用程序接口&#xff09;来实现数据的共享和交流。而在API使用中&#xff0c;海外代理IP也逐渐普及。那么&#xff0c;什么是API代理IP呢&#xff1f;它有什么作用&#xff1f;API接口有何用处&#xf…

从0开始学习JavaScript--JavaScript 函数

JavaScript中的函数是编写可维护、模块化代码的关键。本文将深入研究JavaScript函数的各个方面&#xff0c;包括基本语法、函数作用域、闭包、高阶函数、箭头函数等&#xff0c;并通过丰富的示例代码来帮助读者更好地理解和应用这些概念。 函数的基本语法 函数是一段可被重复…

git merge指定的文件

如果你只想在合并时包含特定类型的文件&#xff08;例如&#xff0c;只合并.java文件&#xff09;&#xff0c;可以使用git checkout命令和git merge命令的组合来实现。以下是一个可能的步骤&#xff1a; 确保在目标分支上&#xff1a; 在执行合并之前&#xff0c;请确保你已经…

openGauss学习笔记-129 openGauss 数据库管理-参数设置-查看参数值

文章目录 openGauss学习笔记-129 openGauss 数据库管理-参数设置-查看参数值129.1 操作步骤129.2 示例 openGauss学习笔记-129 openGauss 数据库管理-参数设置-查看参数值 openGauss安装后&#xff0c;有一套默认的运行参数&#xff0c;为了使openGauss与业务的配合度更高&…

nvim 配置教程

1. 主角&#xff1a; NeoVim sudo apt install -y ninja-build gettext cmake unzip curl sudo apt install -y universal-ctags cscope #函数跳转用到的依赖 git clone https://github.com/neovim/neovim.git --depth1 cd neovim make CMAKE_BUILD_TYPERelWithDebInfo sudo…

C#学习相关系列之Linq用法---where和select用法(二)

一、select用法 Linq中的select可以便捷使我们的对List中的每一项进行操作&#xff0c;生成新的列表。 var ttlist.select(p>p10); //select括号内为List中的每一项&#xff0c;p10即为对每一项的操作&#xff0c;即对每项都加10生成新的List 用法实例&#xff1a; 1、la…

C++ 20类型转换指南:使用场景与最佳实践

C 20类型转换指南&#xff1a;使用场景与最佳实践 类型转换 (Casts) C 提供了五种特定的类型转换&#xff1a;const_cast<>()、static_cast<>()、reinterpret_cast<>()、dynamic_cast<>() 和 C20 引入的 std::bit_cast<>()。 请注意&#xff…

阿里云服务器带宽可以修改吗?不够用怎么办?

阿里云服务器公网带宽不够用有哪些解决方法&#xff1f;可以更改带宽或带宽临时升级&#xff0c;更改带宽是永久公网带宽&#xff0c;带宽临时升级可以选择升级时间段&#xff0c;也可以绑定弹性公网EIP来修改公网带宽&#xff0c;阿里云服务器网aliyunfuwuqi.com分享阿里云服务…

NoSQL 与传统数据库的集成

数据库集成势在必行 随着数据格局以前所未有的复杂性和规模发展&#xff0c;围绕数据库的叙述已经发生了巨大的变化。NoSQL 数据库已成为传统关系数据库的引人注目的替代品&#xff0c;在可扩展性、灵活性和数据模型多样性方面提供了显着的优势。然而&#xff0c;由于其 ACID …

SpringCloud原理-OpenFeign篇(三、FeignClient的动态代理原理)

文章目录 前言正文一、前戏&#xff0c;FeignClientFactoryBean入口方法的分析1.1 从BeanFactory入手1.2 AbstractBeanFactory#doGetBean(...)中对FactoryBean的处理1.3 结论 FactoryBean#getObject() 二、FeignClientFactoryBean实现的getObject()2.1 FeignClientFactoryBean#…

oepnpnp - 自己出图做开口扳手

文章目录 oepnpnp - 自己出图做开口扳手概述笔记做好的一套扳手实拍美图工程图END oepnpnp - 自己出图做开口扳手 概述 我的openpnp设备顶部相机安装支架, 由于结构限制, 螺柱的安装位置和机械挂壁的距离太近了. 导致拧紧(手工或者工具)很困难. 也不能重新做相机支架, 因为将…

构建和应用卡尔曼滤波器 (KF)--扩展卡尔曼滤波器 (EKF)

作为一名数据科学家&#xff0c;我们偶尔会遇到需要对趋势进行建模以预测未来值的情况。虽然人们倾向于关注基于统计或机器学习的算法&#xff0c;但我在这里提出一个不同的选择&#xff1a;卡尔曼滤波器&#xff08;KF&#xff09;。 1960 年代初期&#xff0c;Rudolf E. Kal…

腾讯云CVM标准型S5性能如何?CPU采用什么型号?

腾讯云服务器CVM标准型S5实例具有稳定的计算性能&#xff0c;CVM 2核2G S5活动优惠价格280.8元一年自带1M带宽&#xff0c;15个月313.2元、2核4G配置748.2元15个月&#xff0c;CPU内存配置还可以选择4核8G、8核16G等配置&#xff0c;公网带宽可选1M、3M、5M或10M&#xff0c;腾…

传输层——UDP协议

文章目录 一.传输层1.再谈端口号2.端口号范围划分3.认识知名端口号4.两个问题5.netstat与iostat6.pidof 二.UDP协议1.UDP协议格式2.UDP协议的特点3.面向数据报4.UDP的缓冲区5.UDP使用注意事项6.基于UDP的应用层协议 一.传输层 在学习HTTP等应用层协议时&#xff0c;为了便于理…

C语言初学3:变量和常量

一、变量的定义与初始化 # include <stdio.h> int main() {int age; //定义整型变量float salary; //定义浮点型变量char grade; //定义字符型变量 int *ptr; //定义指针变量 int i, j, k; //定义多个变量int x 10; …

【Python】可再生能源发电与电动汽车的协同调度策略研究

1 主要内容 之前发布了《可再生能源发电与电动汽车的协同调度策略研究》matlab版本程序&#xff0c;本次发布的为Python版本&#xff0c;采用gurobi作为求解器&#xff0c;有需要的可以下载对照学习研究。 首先详细介绍了优化调度模型的求解方案&#xff0c;分别采用二次规划…

初识linux(1)

文章目录 什么是linux什么是操作系统&#xff1f;开源 怎么装linux的环境基础指令lspwdcdtouchmkdirrmdir与rmmancpmv 什么是linux linux是一款开源操作系统 什么是操作系统&#xff1f; 操作系统&#xff1a;一种对计算机所有计算机软硬件进行控制和管理的系统软件 开源 开源&…

npm ERR! Cannot read properties of null (reading ‘pickAlgorithm‘)

node版本问题&#xff0c;版本太高&#xff0c;降低就行&#xff0c;我将到v16.14.1就行了

C#入门(9):多态介绍与代码演示

多态性是面向对象编程的一个核心概念&#xff0c;它允许你使用一个父类引用来指向一个子类对象。这可以使程序具有可扩展性&#xff0c;并且可以用来实现一些高级编程技术&#xff0c;如接口、事件、抽象类等。 多态相关的概念 以下是一些在C#中使用多态性的关键概念&#xf…