Torch 模型 感受野可视化

前言:感受野是卷积神经网络 (CNN) 中一个重要的概念,它表示 CNN 每一层输出的特征图上的像素点在输入图像上映射的区域。感受野的大小和形状直接影响到网络对输入图像的感知范围和精度,进而调整网络结构、卷积核大小和步长等参数,以改善网络的性能。

效果:本文的实验在 torchvision.models 中的 resnet18 上进行,分别绘制了理论感受野、训练前感受野、训练后感受野

5db41ff89046413db29b9d2546c6e5b9.png

开发环境:PyTorch 1.9.0

适用模型:最大池化层使用 nn.MaxPool 而不是 torch.nn.functional.max_pool 的模型

声明:本文所使用代码不开源,觉得本文的思路可行的话,请加 QQ - 1398173074 购买 (¥40,注明来意)

商品仅包含一份 120+ 行的代码。本文所使用的代码基于 torch、matplotlib 以及其它标准库。其中包含一个名为 ReceptiveField 的类,用于绘制图像识别网络的感受野

代码实现

ReceptiveField 提供了以下函数:

  • _replace:将 MaxPool (这种求最大值的操作会影响感受野的正确性) 替换为 AvgPool
  • __init__:注册前向传播的“挂钩”,用于提取目标层的特征图用于反向传播
  • _backward:前向推导图像,利用“挂钩”获取特征图,从特征图中心点反向传播梯度,进行一系列处理后将梯度图转换为感受野图
  • theoretical:结合 _backward 函数求解理论感受野,其结果经过 sum、sqrt 之后即为理论感受野的尺寸
  • effective:默认情况下结合 _backward 函数求解训练前感受野 (即随机权重的模型);给定 state_dict 时将加载权重,求解训练后的感受野
  • compare:使用 matplotlib 绘制理论感受野、训练前感受野、训练后感受野
class ReceptiveField:""" :param model: 需要进行可视化的模型:param tar_layer: 感兴趣的层, 其所输出特征图需有 4 个维度 [B, C, H, W]:param img_size: 测试时使用的图像尺寸"""def make_input(self, n_sample): ...def __init__(self,model: nn.Module,tar_layer: Union[int, nn.Module],img_size: Union[int, Tuple[int, int]],use_cuda: bool = False,use_copy: bool = False): ...def compare(self, theoretical=True, original=True, state_dict=None, **imshow_kw):""" :param theoretical: 是否绘制理论感受野:param original: 是否绘制训练前的感受野:param state_dict: 模型权值, 如果提供则绘制训练后的感受野"""def effective(self, state_dict=None):""" :param state_dict: 模型权值, 如果提供则绘制训练后的感受野"""def theoretical(self, light=1.):""" :param light: 理论感受野的亮度 [0, 1]"""def _replace(self, model): ...def _backward(self, x): ...

在本文的示例中,对 resnet18 的 layer3 进行了可视化,并计算出理论感受野的尺寸为 211×211

if __name__ == "__main__":from torchvision.models import resnet18# Step 1: 刚完成初始化的模型, 权重<完全随机>, 表 "训练前"m = resnet18()# Step 2: 训练完成后的 state_dict, 等待 ReceptiveField 加载state_dict = resnet18(pretrained=True).state_dict()# Step 3: 绘制感受野 (设置 ReceptiveField 的 use_copy=True, 将创建模型的深拷贝副本)with ReceptiveField(m, tar_layer=m.layer3, img_size=256, use_copy=True) as r:r.compare(state_dict=state_dict)# 理论感受野的尺寸s = round(r.theoretical().sum() ** 0.5)print(f"Theoretical RF: {s}×{s}")plt.show()# Step 4: 加载模型的参数m.load_state_dict(state_dict)

如果将 resnet18 中的某一个卷积改成空洞卷积,感受野将进一步增大到 243×243

if __name__ == "__main__":from torchvision.models import resnet18# Step 1: 刚完成初始化的模型, 权重<完全随机>, 表 "训练前"m = resnet18()print(m)m.layer3[1].conv1.dilation = 2m.layer3[1].conv1.padding = 2# Step 2: 训练完成后的 state_dict, 等待 ReceptiveField 加载state_dict = resnet18(pretrained=True).state_dict()# Step 3: 绘制感受野 (设置 ReceptiveField 的 use_copy=True, 将创建模型的深拷贝副本)with ReceptiveField(m, tar_layer=m.layer3, img_size=256, use_copy=True) as r:r.compare(state_dict=state_dict)# 理论感受野的尺寸s = round(r.theoretical().sum() ** 0.5)print(f"Theoretical RF: {s}×{s}")plt.show()# Step 4: 加载模型的参数m.load_state_dict(state_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2165.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javaweb-maven

前端HTML,CSS,JS,Vue&#xff0c;Element&#xff0c;Nginx最后去复习&#xff0c; Java开发工程师 主要学习方向是服务端 所以进入javaweb的服务端的第一个知识点 maven 什么是maven 用于管理和构建java项目的工具 maven的官方网站 Maven – Welcome to Apache Maven …

Flink面试(1)

1.Flink 的并行度的怎么设置的&#xff1f; Flink设置并行度的几种方式 1.代码中设置setParallelism() 全局设置&#xff1a; 1 env.setParallelism(3);  算子设置&#xff08;部分设置&#xff09;&#xff1a; 1 sum(1).setParallelism(3) 2.客户端CLI设置&#xff0…

邀请全球创作者参与 The Sandbox 创作者训练营

作为首屈一指的元宇宙平台之一&#xff0c;The Sandbox 的使命是成为全球创作者的中心。随着我们对 Game Maker 的不断改进、旨在激发创作者灵感的定期 Game Jams、革命性的 "创作者挑战 "以及众多其他活动的开展&#xff0c;我们见证了大量个人加入我们充满活力的创…

opencv_5_图像像素的算术操作

方法1&#xff1a;调用库函数 void ColorInvert::mat_operator(Mat& image) { Mat dst; Mat m Mat::zeros(image.size(), image.type()); m Scalar(2, 2, 2); multiply(image, m, dst); m1 Scalar(50,50, 50); //divide(image, m, dst); //add(im…

WordPress social-warfare插件XSS和RCE漏洞【CVE-2019-9978】

WordPress social-warfare插件XSS和RCE漏洞 ~~ 漏洞编号 : CVE-2019-9978 影响版本 : WordPress social-warfare < 3.5.3 漏洞描述 : WordPress是一套使用PHP语言开发的博客平台&#xff0c;该平台支持在PHP和MySQL的服务器上架设个人博客网站。social-warfare plugin是使用…

AIGC元年大模型发展现状手册

零、AIGC大模型概览 AIGC大模型在人工智能领域取得了重大突破&#xff0c;涵盖了LLM大模型、多模态大模型、图像生成大模型以及视频生成大模型等四种类型。这些模型不仅拓宽了人工智能的应用范围&#xff0c;也提升了其处理复杂任务的能力。a.) LLM大模型通过深度学习和自然语…

SpringBoot学习路线推荐

以下是一个基本的 Spring Boot 学习路线&#xff1a; 1. 基础知识&#xff1a;了解 Java 基础、面向对象编程和设计模式。 2. Spring Boot 概述&#xff1a;学习 Spring Boot 的核心概念和优势。 3. 开发环境设置&#xff1a;配置 IDE 和相关工具。 4. 创建项目&#xff1a;使…

yolov8下实现绿萝识别

目录 一:背景 二:过程 一:背景 上一节我们学习了yolov8自带模型的使用,这一节我们讲解下yolov8的数据训练,生成模型来识别绿萝。 二:过程 1:数据准备,我们可以自己收集绿萝的图片,最起码需要准备几百张的图片。我们这里通过网络下载图片保存到一个目录等待处理。 …

Thinkphp封装统一响应

前言 我们平时开发新项目api接口的时候总是要先自定义自己的响应数据格式&#xff0c;但是每个人的风格习惯不同&#xff0c;导致开发人员封装的响应数据格式不统一&#xff0c;而且需要花时间去重复写。本篇文章主要是统一 API 开发过程中「成功」、「失败」以及「异常」情况…

MSR是个什么寄存器

MSR 这种寄存器专门用于调试、程序执行跟踪、计算机性能监控、简化软件编程、电源控制等等各种实验性功能。 什么是 MSR MSR 的概念是不易理解&#xff0c;所以这一节只说一些 MSR 的外在&#xff0c;比如形容和指令等&#xff0c;然后展开说说&#xff0c;看完整篇文章你应该…

计算机视觉 CV 八股分享 [自用](更新中......)

目录 一、深度学习中解决过拟合方法 二、深度学习中解决欠拟合方法 三、梯度消失和梯度爆炸 解决梯度消失的方法 解决梯度爆炸的方法 四、神经网络权重初始化方法 五、梯度下降法 六、BatchNorm 七、归一化方法 八、卷积 九、池化 十、激活函数 十一、预训练 十二…

【uniapp】 合成海报组件

之前公司的同事写过一个微信小程序用的 合成海报的组件 非常十分好用 最近的项目是uni的 把组件改造一下也可以用 记录一下 <template><view><canvas type"2d" class"_mycanvas" id"my-canvas" canvas-id"my-canvas" …

RT-Thread电源管理组件

电源管理组件 嵌入式系统低功耗管理的目的在于满足用户对性能需求的前提下&#xff0c;尽可能降低系统能耗以延长设备待机时间。 高性能与有限的电池能量在嵌入式系统中矛盾最为突出&#xff0c;硬件低功耗设计与软件低功耗管理的联合应用成为解决矛盾的有效手段。 现在的各种…

UniApp 中的路由魔法:玩转页面导航与跳转

正文&#xff1a; 路由在移动应用开发中是一个至关重要的概念&#xff0c;它决定了用户在应用中导航的方式&#xff0c;以及页面之间的跳转和传参方式。在 UniApp 中&#xff0c;路由配置也有其独特的特点和用法。本文将深入探讨 UniApp 中的路由配置&#xff0c;带你领略其中…

排序算法之桶排序

目录 一、简介二、代码实现三、应用场景 一、简介 算法平均时间复杂度最好时间复杂度最坏时间复杂度空间复杂度排序方式稳定性桶排序O(nk )O(nk)O(n^2)O(nk)Out-place稳定 稳定&#xff1a;如果A原本在B前面&#xff0c;而AB&#xff0c;排序之后A仍然在B的前面&#xff1b; 不…

NIO之非阻塞模式

NIO支持非阻塞模式&#xff0c;以网络连接和网络数据传输为例。如果使用阻塞模式&#xff0c;ServerSocketChannel在调用accept等待客户端建立连接是阻塞的&#xff0c;没有连接就一直阻塞。从Channel中读取客户端传送的数据也是阻塞的&#xff0c;没有数据就一直阻塞。当我们开…

Kotlin语法快速入门--条件控制和循环语句(2)

Kotlin语法入门–条件控制和循环语句&#xff08;2&#xff09; 文章目录 Kotlin语法入门--条件控制和循环语句&#xff08;2&#xff09;二、条件控制和循环语句1、if...else2、when2.1、常规用法2.2、特殊用法--并列&#xff1a;2.3、特殊用法--类型判断&#xff1a;2.4、特殊…

C语言进阶课程学习记录-第48课 - 函数设计原则

C语言进阶课程学习记录 - 函数设计原则 本文学习自狄泰软件学院 唐佐林老师的 C语言进阶课程&#xff0c;图片全部来源于课程PPT&#xff0c;仅用于个人学习记录

无人驾驶 自动驾驶汽车 环境感知 精准定位 决策与规划 控制与执行 高精地图与车联网V2X 深度神经网络学习 深度强化学习 Apollo

无人驾驶 百度apollo课程 1-5 百度apollo课程 6-8 七月在线 无人驾驶系列知识入门到提高 当今,自动驾驶技术已经成为整个汽车产业的最新发展方向。应用自动驾驶技术可以全面提升汽车驾驶的安全性、舒适性,满足更高层次的市场需求等。自动驾驶技术得益于人工智能技术的应用…

Linux i2c-tool工具基础使用

一.i2cdetect i2cdetect 是一个用户空间程序&#xff0c;用于扫描 I2C 总线上的设备。它输出一个表格&#xff0c;其中包含指定总线上检测到的设备列表。以下是 i2cdetect 的使用方法&#xff1a; 运行扫描&#xff1a; 要执行 I2C 扫描&#xff0c;请使用以下命令&#xff1…