浅谈 PyTorch 中的 tensor 及使用

浅谈 PyTorch 中的 tensor 及使用

转自:浅谈 PyTorch 中的 tensor 及使用

这篇文章主要是围绕 PyTorch 中的 tensor 展开的,讨论了张量的求导机制,在不同设备之间的转换,神经网络中权重的更新等内容。面向的读者是使用过 PyTorch 一段时间的用户。本文中的代码例子基于 Python 3 和 PyTorch 1.1,如果文章中有错误或者没有说明白的地方,欢迎在评论区指正和讨论。

文章具体内容分为以下6个部分:

  1. tensor.requires_grad
  2. torch.no_grad()
  3. 反向传播及网络的更新
  4. tensor.detach()
  5. CPU and GPU
  6. tensor.item()

因为本文大部分内容是听着冷鸟的歌完成的,故用此标题封面。

1. requires_grad

当我们创建一个张量 (tensor) 的时候,如果没有特殊指定的话,那么这个张量是默认是不需要求导的。我们可以通过 tensor.requires_grad 来检查一个张量是否需要求导。

在张量间的计算过程中,如果在所有输入中,有一个输入需要求导,那么输出一定会需要求导;相反,只有当所有输入都不需要求导的时候,输出才会不需要 [1]。

举一个比较简单的例子,比如我们在训练一个网络的时候,我们从 DataLoader 中读取出来的一个 mini-batch 的数据,这些输入默认是不需要求导的,其次,网络的输出我们没有特意指明需要求导吧,Ground Truth 我们也没有特意设置需要求导吧。这么一想,哇,那我之前的那些 loss 咋还能自动求导呢?其实原因就是上边那条规则,虽然输入的训练数据是默认不求导的,但是,我们的 model 中的所有参数,它默认是求导的,这么一来,其中只要有一个需要求导,那么输出的网络结果必定也会需要求的。来看个实例:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# Falsenet = nn.Sequential(nn.Conv2d(3, 16, 3, 1),nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():print(param[0], param[1].requires_grad)
# 0.weight True
# 0.bias True
# 1.weight True
# 1.bias Trueoutput = net(input)
print(output.requires_grad)
# True

诚不欺我!但是,大家请注意前边只是举个例子来说明。在写代码的过程中,不要把网络的输入和 Ground Truth 的 requires_grad 设置为 True。虽然这样设置不会影响反向传播,但是需要额外计算网络的输入和 Ground Truth 的导数,增大了计算量和内存占用不说,这些计算出来的导数结果也没啥用。因为我们只需要神经网络中的参数的导数,用来更新网络,其余的导数都不需要。

好了,有个这个例子做铺垫,那么我们来得寸进尺一下。我们试试把网络参数的 requires_grad 设置为 False 会怎么样,同样的网络:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# Falsenet = nn.Sequential(nn.Conv2d(3, 16, 3, 1),nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():param[1].requires_grad = Falseprint(param[0], param[1].requires_grad)
# 0.weight False
# 0.bias False
# 1.weight False
# 1.bias Falseoutput = net(input)
print(output.requires_grad)
# False

这样有什么用处?用处大了。我们可以通过这种方法,在训练的过程中冻结部分网络,让这些层的参数不再更新,这在迁移学习中很有用处。我们来看一个 官方 Tutorial: FINETUNING TORCHVISION MODELS 给的例子:

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():param.requires_grad = False# 用一个新的 fc 层来取代之前的全连接层
# 因为新构建的 fc 层的参数默认 requires_grad=True
model.fc = nn.Linear(512, 100)# 只更新 fc 层的参数
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)# 通过这样,我们就冻结了 resnet 前边的所有层,
# 在训练过程中只更新最后的 fc 层中的参数。

2. torch.no_grad()

当我们在做 evaluating 的时候(不需要计算导数),我们可以将推断(inference)的代码包裹在 with torch.no_grad(): 之中,以达到 暂时 不追踪网络参数中的导数的目的,总之是为了减少可能存在的计算和内存消耗。看 官方 Tutorial 给出的例子:

x = torch.randn(3, requires_grad = True)
print(x.requires_grad)
# True
print((x ** 2).requires_grad)
# Truewith torch.no_grad():print((x ** 2).requires_grad)# Falseprint((x ** 2).requires_grad)
# True

3. 反向传播及网络的更新

这部分我们比较简单地讲一讲,有了网络输出之后,我们怎么根据这个结果来更新我们的网络参数呢。我们以一个非常简单的自定义网络来讲解这个问题,这个网络包含2个卷积层,1个全连接层,输出的结果是20维的,类似分类问题中我们一共有20个类别,网络如下:

class Simple(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)self.linear = nn.Linear(32*10*10, 20, bias=False)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.linear(x.view(x.size(0), -1))return x

接下来我们用这个网络,来研究一下整个网络更新的流程:

# 创建一个很简单的网络:两个卷积层,一个全连接层
model = Simple()
# 为了方便观察数据变化,把所有网络参数都初始化为 0.1
for m in model.parameters():m.data.fill_(0.1)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)model.train()
# 模拟输入8个 sample,每个的大小是 10x10,
# 值都初始化为1,让每次输出结果都固定,方便观察
images = torch.ones(8, 3, 10, 10)
targets = torch.ones(8, dtype=torch.long)output = model(images)
print(output.shape)
# torch.Size([8, 20])loss = criterion(output, targets)print(model.conv1.weight.grad)
# None
loss.backward()
print(model.conv1.weight.grad[0][0][0])
# tensor([-0.0782, -0.0842, -0.0782])
# 通过一次反向传播,计算出网络参数的导数,
# 因为篇幅原因,我们只观察一小部分结果print(model.conv1.weight[0][0][0])
# tensor([0.1000, 0.1000, 0.1000], grad_fn=<SelectBackward>)
# 我们知道网络参数的值一开始都初始化为 0.1 的optimizer.step()
print(model.conv1.weight[0][0][0])
# tensor([0.1782, 0.1842, 0.1782], grad_fn=<SelectBackward>)
# 回想刚才我们设置 learning rate 为 1,这样,
# 更新后的结果,正好是 (原始权重 - 求导结果) !optimizer.zero_grad()
print(model.conv1.weight.grad[0][0][0])
# tensor([0., 0., 0.])
# 每次更新完权重之后,我们记得要把导数清零啊,
# 不然下次会得到一个和上次计算一起累加的结果。
# 当然,zero_grad() 的位置,可以放到前边去,
# 只要保证在计算导数前,参数的导数是清零的就好。

这里,我们多提一句,我们把整个网络参数的值都传到 optimizer 里面了,这种情况下我们调用 model.zero_grad(),效果是和 optimizer.zero_grad() 一样的。这个知道就好,建议大家坚持用 optimizer.zero_grad()。我们现在来看一下如果没有调用 zero_grad(),会怎么样吧:

# ...
# 代码和之前一样
model.train()# 第一轮
images = torch.ones(8, 3, 10, 10)
targets = torch.ones(8, dtype=torch.long)output = model(images)
loss = criterion(output, targets)
loss.backward()
print(model.conv1.weight.grad[0][0][0])
# tensor([-0.0782, -0.0842, -0.0782])# 第二轮
output = model(images)
loss = criterion(output, targets)
loss.backward()
print(model.conv1.weight.grad[0][0][0])
# tensor([-0.1564, -0.1684, -0.1564])

我们可以看到,第二次的结果正好是第一次的2倍。第一次结束之后,因为我们没有更新网络权重,所以第二次反向传播的求导结果和第一次结果一样,加上上次我们没有将 loss 清零,所以结果正好是2倍。另外大家可以看一下这个博客 (torch 代码解析 为什么要使用 optimizer.zero_grad() ),我觉得讲得很好。

4. tensor.detach()

接下来我们来探讨两个 0.4.0 版本更新产生的遗留问题。第一个,tensor.datatensor.detach()

在 0.4.0 版本以前,.data 是用来取 Variable 中的 tensor 的,但是之后 Variable 被取消,.data 却留了下来。现在我们调用 tensor.data,可以得到 tensor的数据 + requires_grad=False 的版本,而且二者共享储存空间,也就是如果修改其中一个,另一个也会变。因为 PyTorch 的自动求导系统不会追踪 tensor.data 的变化,所以使用它的话可能会导致求导结果出错。官方建议使用 tensor.detach() 来替代它,二者作用相似,但是 detach 会被自动求导系统追踪,使用起来很安全[2]。多说无益,我们来看个例子吧:

a = torch.tensor([7., 0, 0], requires_grad=True)
b = a + 2
print(b)
# tensor([9., 2., 2.], grad_fn=<AddBackward0>)loss = torch.mean(b * b)b_ = b.detach()
b_.zero_()
print(b)
# tensor([0., 0., 0.], grad_fn=<AddBackward0>)
# 储存空间共享,修改 b_ , b 的值也变了loss.backward()
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

这个例子中,b 是用来计算 loss 的一个变量,我们在计算完 loss 之后,进行反向传播之前,修改 b 的值。这么做会导致相关的导数的计算结果错误,因为我们在计算导数的过程中还会用到 b 的值,但是它已经变了(和正向传播过程中的值不一样了)。在这种情况下,PyTorch 选择报错来提醒我们。但是,如果我们使用 tensor.data 的时候,结果是这样的:

a = torch.tensor([7., 0, 0], requires_grad=True)
b = a + 2
print(b)
# tensor([9., 2., 2.], grad_fn=<AddBackward0>)loss = torch.mean(b * b)b_ = b.data
b_.zero_()
print(b)
# tensor([0., 0., 0.], grad_fn=<AddBackward0>)loss.backward()print(a.grad)
# tensor([0., 0., 0.])# 其实正确的结果应该是:
# tensor([6.0000, 1.3333, 1.3333])

这个导数计算的结果明显是错的,但没有任何提醒,之后再 Debug 会非常痛苦。所以,建议大家都用 tensor.detach() 啊。上边这个代码例子是受 这里 启发。

5. CPU and GPU

接下来我们来说另一个问题,是关于 tensor.cuda()tensor.to(device) 的。后者是 0.4.0 版本之后后添加的,当 device 是 GPU 的时候,这两者并没有区别。那为什么要在新版本增加后者这个表达呢,是因为有了它,我们直接在代码最上边加一句话指定 device ,后面的代码直接用to(device) 就可以了:

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")a = torch.rand([3,3]).to(device)
# 干其他的活
b = torch.rand([3,3]).to(device)
# 干其他的活
c = torch.rand([3,3]).to(device)

而之前版本的话,当我们每次在不同设备之间切换的时候,每次都要用 if cuda.is_available() 判断能否使用 GPU,很麻烦。这个精彩的解释来自于 这里 。

if torch.cuda.is_available():a = torch.rand([3,3]).cuda()
# 干其他的活
if  torch.cuda.is_available():b = torch.rand([3,3]).cuda()
# 干其他的活
if  torch.cuda.is_available():c = torch.rand([3,3]).cuda()

关于使用 GPU 还有一个点,在我们想把 GPU tensor 转换成 Numpy 变量的时候,需要先将 tensor 转换到 CPU 中去,因为 Numpy 是 CPU-only 的。其次,如果 tensor 需要求导的话,还需要加一步 detach,再转成 Numpy 。例子如下:

x  = torch.rand([3,3], device='cuda')
x_ = x.cpu().numpy()y  = torch.rand([3,3], requires_grad=True, device='cuda').
y_ = y.cpu().detach().numpy()
# y_ = y.detach().cpu().numpy() 也可以
# 二者好像差别不大?我们来比比时间:
start_t = time.time()
for i in range(10000):y_ = y.cpu().detach().numpy()
print(time.time() - start_t)
# 1.1049120426177979start_t = time.time()
for i in range(10000):y_ = y.detach().cpu().numpy()
print(time.time() - start_t)
# 1.115112543106079
# 时间差别不是很大,当然,这个速度差别可能和电脑配置
# (比如 GPU 很贵,CPU 却很烂)有关。

6. tensor.item()

我们在提取 loss 的纯数值的时候,常常会用到 loss.item(),其返回值是一个 Python 数值 (python number)。不像从 tensor 转到 numpy (需要考虑 tensor 是在 cpu,还是 gpu,需不需要求导),无论什么情况,都直接使用 item() 就完事了。如果需要从 gpu 转到 cpu 的话,PyTorch 会自动帮你处理。

但注意 item() 只适用于 tensor 只包含一个元素的时候。因为大多数情况下我们的 loss 就只有一个元素,所以就经常会用到 loss.item()。如果想把含多个元素的 tensor 转换成 Python list 的话,要使用 tensor.tolist()

x  = torch.randn(1, requires_grad=True, device='cuda')
print(x)
# tensor([-0.4717], device='cuda:0', requires_grad=True)y = x.item()
print(y, type(y))
# -0.4717346727848053 <class 'float'>x = torch.randn([2, 2])
y = x.tolist()
print(y)
# [[-1.3069953918457031, -0.2710231840610504], [-1.26217520236969, 0.5559719800949097]]

结语

以上内容就是我平时在写代码的时候,觉得需要注意的地方。文章中用了一些简单的代码作为例子,旨在帮助大家理解。文章内容不少,看到这里的大家都辛苦了, 感谢阅读。

最后还是那句话,希望本文能对大家学习和理解 PyTorch 有所帮助。

参考

  1. PyTorch Docs: AUTOGRAD MECHANICS https://pytorch.org/docs/stable/notes/autograd.html
  2. PyTorch 0.4.0 release notes https://github.com/pytorch/pytorch/releases/tag/v0.4.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532399.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

简述springmvc过程_spring mvc的工作流程是什么?

展开全部SpringMVC工作流程描述向服务器发送HTTP请求&#xff0c;请求被前端控制器 DispatcherServlet 捕获。DispatcherServlet 根据 -servlet.xml 中的配置对请62616964757a686964616fe59b9ee7ad9431333365646233求的URL进行解析&#xff0c;得到请求资源标识符(URI)。 然后根…

PyTorch 的 Autograd

PyTorch 的 Autograd 转自&#xff1a;PyTorch 的 Autograd PyTorch 作为一个深度学习平台&#xff0c;在深度学习任务中比 NumPy 这个科学计算库强在哪里呢&#xff1f;我觉得一是 PyTorch 提供了自动求导机制&#xff0c;二是对 GPU 的支持。由此可见&#xff0c;自动求导 (a…

商场楼层导视牌图片_百宝图商场电子导视软件中预约产品功能简介

百宝图商场电子导视软件中预约产品功能简介 管理端&#xff0c;可配合百宝图商场电子导视软件配套使用 1&#xff1a;数据展示&#xff1a;图形展示总预约数/预约时间峰值/预约途径/各途径数量对比 2&#xff1a;数据统计&#xff1a;有效预约数量/无效预约数量/无效预约原因备…

Pytorch autograd.grad与autograd.backward详解

Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时&#xff0c;都是下面这种无脑按步骤走&#xff1a; outputs model(inputs) # 模型前向推理 optimizer.zero_grad() # 清除累积梯度 loss.backward() # 模型反向求导 optimizer.step()…

相对熵与交叉熵_熵、KL散度、交叉熵

公众号关注 “ML_NLP”设为 “星标”&#xff0c;重磅干货&#xff0c;第一时间送达&#xff01;机器学习算法与自然语言处理出品公众号原创专栏作者 思婕的便携席梦思单位 | 哈工大SCIR实验室KL散度 交叉熵 - 熵1. 熵(Entropy)抽象解释&#xff1a;熵用于计算一个随机变量的信…

动手实现一个带自动微分的深度学习框架

动手实现一个带自动微分的深度学习框架 转自&#xff1a;Automatic Differentiation Tutorial 参考代码&#xff1a;https://github.com/borgwang/tinynn-autograd (主要看 core/tensor.py 和 core/ops.py) 目录 简介自动求导设计自动求导实现一个例子总结参考资料 简介 梯度…

git安装后找不见版本_结果发现git版本为1.7.4,(git --version)而官方提示必须是1.7.10及以后版本...

结果发现git版本为1.7.4,(git --version)而官方提示必须是1.7.10及以后版本升级增加ppasudo apt-add-repository ppa:git-core/ppasudo apt-get updatesudo apt-get install git如果本地已经安装过Git&#xff0c;可以使用升级命令&#xff1a;sudo apt-get dist-upgradeapt命令…

随机数生成算法:K进制逐位生成+拒绝采样

随机数生成算法&#xff1a;K进制逐位生成拒绝采样 转自&#xff1a;【宫水三叶】k 进制诸位生成 拒绝采样 基本分析 给定一个随机生成 1 ~ 7 的函数&#xff0c;要求实现等概率返回 1 ~ 10 的函数。 首先需要知道&#xff0c;在输出域上进行定量整体偏移&#xff0c;仍然满…

深入理解NLP Subword算法:BPE、WordPiece、ULM

深入理解NLP Subword算法&#xff1a;BPE、WordPiece、ULM 本文首发于微信公众号【AI充电站】&#xff0c;感谢大家的赞同、收藏和转发(▽) 转自&#xff1a;深入理解NLP Subword算法&#xff1a;BPE、WordPiece、ULM 前言 Subword算法如今已经成为了一个重要的NLP模型性能提升…

http 错误 404.0 - not found_电脑Regsvr32 用法和错误消息的说明

​ 对于那些可以自行注册的对象链接和嵌入 (OLE) 控件&#xff0c;例如动态链接库 (DLL) 文件或 ActiveX 控件 (OCX) 文件&#xff0c;您可以使用 Regsvr32 工具 (Regsvr32.exe) 来将它们注册和取消注册。Regsvr32.exe 的用法RegSvr32.exe 具有以下命令行选项&#xff1a; Regs…

mysql error 1449_MySql错误:ERROR 1449 (HY000)

笔者系统为 mac &#xff0c;不知怎的&#xff0c;Mysql 竟然报如下错误&#xff1a;ERROR 1449 (HY000): The user specified as a definer (mysql.infoschemalocalhost) does not exist一时没有找到是什么操作导致的这个错误。然后经过查询&#xff0c;参考文章解决了问题。登…

MobileNet 系列:从V1到V3

MobileNet 系列&#xff1a;从V1到V3 转自&#xff1a;轻量级神经网络“巡礼”&#xff08;二&#xff09;—— MobileNet&#xff0c;从V1到V3 自从2017年由谷歌公司提出&#xff0c;MobileNet可谓是轻量级网络中的Inception&#xff0c;经历了一代又一代的更新。成为了学习轻…

mysql 查询表的key_mysql查询表和字段的注释

1,新建表以及添加表和字段的注释.create table auth_user(ID INT(19) primary key auto_increment comment 主键,NAME VARCHAR(300) comment 姓名,CREATE_TIME date comment 创建时间)comment 用户信息表;2,修改表/字段的注释.alter table auth_user comment 修改后的表注…

mysql 高级知识点_这是我见过最全的《MySQL笔记》,涵盖MySQL所有高级知识点!...

作为运维和编程人员&#xff0c;对MySQL一定不会陌生&#xff0c;尤其是互联网行业&#xff0c;对MySQL的使用是比较多的。MySQL 作为主流的数据库&#xff0c;是各大厂面试官百问不厌的知识点&#xff0c;但是需要了解到什么程度呢&#xff1f;仅仅停留在 建库、创表、增删查改…

teechart mysql_TeeChart 的应用

TeeChart 是一个很棒的绘图控件&#xff0c;不过由于里面没有注释&#xff0c;网上相关的资料也很少&#xff0c;所以在应用的时候只能是一点点的试。为了防止以后用到的时候忘记&#xff0c;我就把自己用到的东西都记录下来&#xff0c;以便以后使用的时候查询。1、进制缩放图…

NLP新宠——浅谈Prompt的前世今生

NLP新宠——浅谈Prompt的前世今生 转自&#xff1a;NLP新宠——浅谈Prompt的前世今生 作者&#xff1a;闵映乾&#xff0c;中国人民大学信息学院硕士&#xff0c;目前研究方向为自然语言处理。 《Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in…

mysql key_len_浅谈mysql explain中key_len的计算方法

mysql的explain命令可以分析sql的性能&#xff0c;其中有一项是key_len(索引的长度)的统计。本文将分析mysql explain中key_len的计算方法。1、创建测试表及数据CREATE TABLE member (id int(10) unsigned NOT NULL AUTO_INCREMENT,name varchar(20) DEFAULT NULL,age tinyint(…

requestfacade 这个是什么类?_Java 的大 Class 到底是什么?

作者在之前工作中&#xff0c;面试过很多求职者&#xff0c;发现有很多面试者对Java的 Class 搞不明白&#xff0c;理解的不到位&#xff0c;一知半解&#xff0c;一到用的时候&#xff0c;就不太会用。想写一篇关于Java Class 的文章&#xff0c;没有那么多专业名词&#xff0…

初学机器学习:直观解读KL散度的数学概念

初学机器学习&#xff1a;直观解读KL散度的数学概念 转自&#xff1a;初学机器学习&#xff1a;直观解读KL散度的数学概念 译自&#xff1a;https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8 解读…

php mysql读取数据查询_PHP MySQL 读取数据

PHP MySQL 读取数据从 MySQL 数据库读取数据SELECT 语句用于从数据表中读取数据:SELECT column_name(s) FROM table_name我们可以使用 * 号来读取所有数据表中的字段&#xff1a;SELECT * FROM table_name如需学习更多关于 SQL 的知识&#xff0c;请访问我们的 SQL 教程。使用 …