pytorch forward_pytorch使用hook打印中间特征图、计算网络算力等

fc2acf1e55d9cc946bfefc2a72a63ee5.png

0、参考

https://oldpan.me/archives/pytorch-autograd-hook

https://pytorch.org/docs/stable/search.html?q=hook&check_keywords=yes&area=default

https://github.com/pytorch/pytorch/issues/598

https://github.com/sksq96/pytorch-summary

https://github.com/allensll/test/blob/591c7ce3671dbd9687b3e84e1628492f24116dd9/net_analysis/viz_lenet.py

1、背景

在神经网络的反向传播当中个,流程只保存叶子节点的梯度,对于中间变量的梯度没有进行保存。

import torch
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
z.backward()
x.data -= lr*x.grad.data
print(y.grad)

此时输出就是:None,这个时候hook的作用就派上,hook可以通过自定义一些函数,从而完成中间变量的输出,比如中间特征图、中间层梯度修正等。

​ 在pytorch docs搜索hook,可以发现有四个hook相关的函数,分别为register_hook,register_backward_hook,register_forward_hook,register_forward_pre_hook。其中register_hook属于tensor类,而后面三个属于moudule类。

  • register_hook函数属于torch.tensor类,函数在tensor梯度计算的时候就会执行,这个函数主要处理梯度相关的数据,表现形式$hook(grad) rightarrow Tensor or None$.
import torch
x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
y = x * 2
y.register_hook(print)
<torch.utils.hooks.RemovableHandle at 0x7f765e876f60>
z = torch.mean(y)
z.backward()
tensor([ 0.5000,  0.5000])
  • Register_backward_hook等三个属于torch.nn,属于moudule中的方法。
hook(module, grad_input, grad_output) -> Tensor or None

写个demo,参考:

下面的计算为

import torch
import torch.nn as nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def print_hook(grad):print ("register hook:", grad)return gradclass TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.f1 = nn.Linear(4, 1, bias=True)self.weights_init()def weights_init(self):self.f1.weight.data.fill_(4)self.f1.bias.data.fill_(0.1)def forward(self, input):self.input = inputout = input * 0.75out = self.f1(out)out = out / 4return outdef back_hook(self, moudle, grad_input, grad_output):print ("back hook in:", grad_input)print ("back hook out:", grad_output)# 修改梯度# grad_input = list(grad_input)# grad_input[0] = grad_input[0] * 100# print (grad_input)return tuple(grad_input)if __name__ == '__main__':input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)net = TestNet()net.to(device)net.register_backward_hook(net.back_hook)ret = net(input)print ("result", ret)ret.backward()print('input.grad:', input.grad)for param in net.parameters():print('{}:grad->{}'.format(param, param.grad))

输出:

result tensor([7.5250], grad_fn=<DivBackward0>)
back hook in: (tensor([0.2500]), None)
back hook out: (tensor([1.]),)
input.grad: tensor([0.7500, 0.7500, 0.7500, 0.7500])
Parameter containing:
tensor([[4., 4., 4., 4.]], requires_grad=True):grad->tensor([[0.1875, 0.3750, 0.5625, 0.7500]])
Parameter containing:
tensor([0.1000], requires_grad=True):grad->tensor([0.2500])

输出结果以及梯度都很明显,简单分析一下w权重的梯度,

另外,hook中有个bug,假设我们bug,假设我们注释掉out = out / 4这行,可以发现输出变成back hook in: (tensor([1.]), tensor([1.]))。这种情况就不符合上面我们的梯度计算公式,是因为这个时候:

则此时的偏导只是对

进行计算,所以都是1,1。这是pytorch的设计缺陷

c0f3c2eaea4270329b9560d7b13622ff.png
  • register_forward_hook跟Register_backward_hook差不多,就不过多复述。
  • register_forward_pre_hook,可以发现其输入只有hook(module, input) -> None
    其主要是针对推理时的hook.

2、应用

2.1 特征图打印

​ 直接利用pytorch已有的resnet18进行特征图打印,只打印卷积层的特征图,

import torch
from torchvision.models import resnet18
import torch.nn as nn
from torchvision import transformsimport matplotlib.pyplot as pltdef viz(module, input):x = input[0][0]#最多显示4张图min_num = np.minimum(4, x.size()[0])for i in range(min_num):plt.subplot(1, 4, i+1)plt.imshow(x[i].cpu())plt.show()import cv2
import numpy as np
def main():t = transforms.Compose([transforms.ToPILImage(),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = resnet18(pretrained=True).to(device)for name, m in model.named_modules():# if not isinstance(m, torch.nn.ModuleList) and #         not isinstance(m, torch.nn.Sequential) and #         type(m) in torch.nn.__dict__.values():# 这里只对卷积层的feature map进行显示if isinstance(m, torch.nn.Conv2d):m.register_forward_pre_hook(viz)img = cv2.imread('./cat.jpeg')img = t(img).unsqueeze(0).to(device)with torch.no_grad():model(img)if __name__ == '__main__':main()

直接放几张中间层的图

70f5f173e24d17147e5b8704cce87918.png
图1 第一层卷积层输入

f955bb70ffca1cf64b78dbd6e7ce59ae.png
图2 第四层卷积层的输入

2.2 模型大小,算力计算

同样的用法,可以直接参考pytorch-summary这个项目。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/246122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Games101现代图形学入门Lecture 4: Transformation Cont知识点总结

视频链接&#xff1a;https://www.bilibili.com/video/BV1X7411F744?p4 课程主页链接&#xff1a;http://games-cn.org/intro-graphics/ 课件PPT链接&#xff1a;http://games-cn.org/graphics-intro-ppt-video/ 1. 3D变换 缩放和平移矩阵 旋转矩阵 欧拉角&#xff1a;rol…

python3 for_Python3: for 表达式

#1.在控制台输入一个成绩score #2.判断成绩&#xff0c; #*如果成绩小于60输出不及格 #60到70 及格 #70到80 中等 #80到90 良好 #90 100 优秀 def level(score_list): # score input("请输入成绩&#xff1a;") # while score!"stop": for sc in score_li…

Hash和红黑树以及其在C#中的应用

参考资料&#xff1a; .Net 中HashTable&#xff0c;HashMap 和 Dictionary<key,value> 和List<T>和DataTable的比较 - 王若伊_恩赐解脱 - 博客园 c#HashSet源码解析_fdyshlk的博客-CSDN博客_c# hashset 红黑树和哈希表的区别 - 安全技术 - 亿速云 一、基本概念…

networkx 标签_networkx绘制BA无标度网络

step1: 导入networkx复杂网络库、matplotlib.pyplot、pandasimport networkx as nximport matplotlib.pyplot as pltimport pandas as pdstep2: 绘制BA无标度网络Gnx.barabasi_albert_graph(1000,1) #generate BA networkposnx.spring_layout(G) #set layoutnodecolorG.degree(…

Unity URP中的多Pass Shader和Planer shadow

一 .Unity移动端软阴影技术总结&#xff1a; https://blog.csdn.net/jxw167/article/details/82422891 二. 平面阴影的原理 https://zhuanlan.zhihu.com/p/42781261 https://zhuanlan.zhihu.com/p/31504088 王者荣耀游戏使用的就是该方法&#xff0c;已经有上线产品验证过…

java连接mongodb_第78天: Python 操作 MongoDB 数据库介绍

MongoDB 是一款面向文档型的 NoSQL 数据库&#xff0c;是一个基于分布式文件存储的开源的非关系型数据库系统&#xff0c;其内容是以 K/V 形式存储&#xff0c;结构不固定&#xff0c;它的字段值可以包含其他文档、数组和文档数组等。其采用的 BSON(二进制 JSON )的数据结构&am…

URP中的2D Light光照在移动端不生效的问题

最近在尝试用URP推出的还在preview阶段的2D Render系统&#xff0c;发现2D光照在打成APK后失效&#xff0c;尝试了些方法后发现把2d光照用到的shader放进设置中的built in shader后可以解决问题&#xff1a;

大连开发区取暖费能微信支付吗_下半年教资报考人数增加,那到底能不能异地报考呢?...

想要每周获取两篇群文件快扫码进群吧~因为教师资格证认定的问题&#xff0c;最近教师资格证备考又被广大考生提上了日程&#xff0c;由于“先上岗&#xff0c;后考证”政策&#xff0c;小编预测下一年教师资格证考试的通过率肯定没有以前那么高了&#xff0c;不少人就想选择异地…

python3项目源代码下载_2019年最值得关注的34个Python开源项目——Let's go!

踏着人工智能、区块链的东风&#xff0c;近年来一路“横冲直撞”的 Python 在实现了从小众语言到主流的完美转身后&#xff0c;一头扎进了 2019&#xff0c;依旧没有透出丝毫停下来的架势&#xff0c;反倒有些越烧越热的味道。本文将为你介绍 2019 年最值得关注的 34 个 Python…

Unity 音频优化方案

参考资料&#xff1a; https://www.cnblogs.com/bearhb/p/11210136.html https://blog.csdn.net/chenfujun818/article/details/81710895 文件格式 mp3:失真小&#xff0c;适合音质要求高的文件&#xff0c;例如BGM wav:资源大&#xff0c;不推荐 ogg:压缩比高&#xff0c;适…

android home键后计时拉起app_使用React Native完成App软件

搭建开发环境安装react-native-cli&#xff1a;npm i -g react-native-cliAndroid SDK安装Android SDK并启动进行配置&#xff1a;配置环境变量export ANDROID_HOME~/Library/Android/sdk export PATH${PATH}:${ANDROID_HOME}/tools export PATH${PATH}:${ANDROID_HOME}/platfo…

Unity AssetBundle内存管理相关问题

AssetBundle机制相关资料收集 最近网友通过网站搜索Unity3D在手机及其他平台下占用内存太大. 这里写下关于Unity3D对于内存的管理与优化. Unity3D 里有两种动态加载机制&#xff1a;一个是Resources.Load&#xff0c;另外一个通过AssetBundle,其实两者区别不大。 Resources.L…

移动超级sim卡 无法下载卡_中国移动发布超级SIM卡:全变了

近日&#xff0c;中国移动正式公布了《中国移动超级SIM卡技术白皮书》&#xff0c;明确乐中国移动对于个人领域SIM卡的发展方向、架构设计、能力要求&#xff0c;旨在为行业规划设计SIM卡相关技术、产品和解决方案时提供参考和指引。据悉&#xff0c;中国移动的超级SIM卡增强了…

echart中拆线点的偏移_Qt中圆弧和扇形的绘制

在超声软件的开发中&#xff0c;超声成像模块需要绘制圆弧&#xff0c;例如绘制一个扇形的取样框&#xff0c;左右是一条直线&#xff0c;上下是一个圆弧&#xff0c;像这样。Qt中使用QPainter::drawArc绘制圆弧&#xff0c;使用QPainter::drawPie绘制扇形。圆弧和扇形的绘制接…

python xpath定位元素方法_Python爬虫杂记 - Xpath高级用法

xpath 高级用法 1. 匹配当前节点下的所有&#xff1a; .// . 表示当前 // 表示当前标签下的所有标签 注&#xff1a; 要配合使用 2. 匹配某标签的属性值&#xff1a; /属性名称 这里以input里的value值为例&#xff1a; 例&#xff1a;xpath(//input/value) 3. 匹配多个路径&am…

反向Z(Reversed-Z)的深度缓冲原理

参考文章&#xff1a;https://zhuanlan.zhihu.com/p/75517534 https://zjinc36.github.io/2020/03/10/2020-20200309-%E6%B7%B1%E5%85%A5%E7%90%86%E8%A7%A3%E6%B5%AE%E7%82%B9%E6%95%B0%E4%B8%8E%E6%B5%AE%E7%82%B9%E6%95%B0%E7%9A%84%E7%B2%BE%E5%BA%A6%E9%97%AE%E9%A2%98/ …

screenocr怎么卸载_screenocr是什么软件 screenocr软件及其功能介绍

在日常的生活和工作当中不免会遇到一些无法进行复制但是又想要去将它摘录下来的文字。用手去进行输入的话及麻烦又费力&#xff0c;这个时候我们可以使用OCR技术来讲它们识别出来。而screenocr就是这样子的一款软件&#xff0c;还不是很了解screenocr都有哪些功能如何使用的用户…

saspython知乎_SAS入门书籍有哪些值得推荐?

2020年 8月更新&#xff1a;我觉得&#xff0c;我应该推荐下我本人出版的《SAS编程演义》《SAS编程演义》(谷鸿秋)【摘要 书评 试读】- 京东图书​item.jd.com ------------------------------------------------------------ 我觉得这个问题我还是可以唠叨几句的&#xff0c;我…

access exex控制pc_ownCloud/Nextcloud文件访问控制(Files Access Control)

事实上这是一个插件(APP)&#xff0c;也是ownCloud/Nextcloud的一项重要功能&#xff1a;文件访问控制。文件访问控制APP可以提供丰富强大的访问管理功能&#xff0c;从单文件权限到组文件&#xff0c;再到IP地址屏蔽&#xff0c;可以引用访问的时间、文件类型、用户、组等因素…

渲染杂谈:early-z、z-culling、hi-z、z-perpass到底是什么?

渲染杂谈&#xff1a;early-z、z-culling、hi-z、z-perpass到底是什么&#xff1f; 之前一直被这几个和深度缓存&#xff08;z-buffer&#xff09;相关的概念搞得神魂颠倒。今天在翻阅《Real-Time Rendering》时碰巧碰巧看到了这部分的讲解。硬着头皮看了看&#xff0c;姑且算…