【深度学习】残差网络(ResNet)

如果按照李沐老师书上来,学完 VGG 后还有 NiN 和 GoogLeNet 要学,但是这两个我之前听都没听过,而且我看到我导师有发过 ResNet 相关的论文,就想跳过它们直接看后面的内容。

现在看来这不算是不踏实,因为李沐老师说如果卷积神经网络只学一个架构的话,那就学这个 ResNet(Residual Network)。

随着我们设计越来越深的网络,深刻理解“新添加的层如何提升神经网络的性能”变得至关重要。

加了更多层一定更有用吗?如果不是的话,怎么样加入新层可以有效提高精度呢?

我们通过下图来进行理解。以前我们在网络加入新的卷积层或者全连接层有点像左图中的非嵌套函数类。

尽管随着层数变多(函数 f 1 − f 6 f_1-f_6 f1f6),能覆盖的最优值范围变大,但不一定能很有效的接近全局最优(蓝色五角星)。

例如左图中实际上加到 f 3 f_3 f3 时的最优值距离五角星更近。
嵌套函数类和非嵌套函数类
针对这一问题,何恺明等人(2016)提出了残差网络,其核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。

就像右图中的嵌套函数类,每次新加入函数都能保证不会离五角星更远,进而一步一步逼近全局最优。

下面我们看看如何实现“嵌套”。

一、残差块

块的思想我们在 VGG 中就了解过,可以帮助我们设计深层网络。

以前我们是通过串联起各层来扩大函数类(下左图),而残差块(下右图)通过加入一侧的快速通道,来得到 f ( x ) = x + g ( x ) f(x)=x+g(x) f(x)=x+g(x) 的结构。

正常块(左图)与残差块(右图)

如此的话,就算虚线框中的 g ( x ) g(x) g(x) 没有起到效果,我们也不会退步。

如果虚线框中的各层使得通道数改变,我们就需要加入 1$\times$1 卷积层来进行调整,以保证能加法顺利进行。

不包含以及包含 1$\times$1 卷积层的残差块

对于上图这类特殊的架构,我们需要采用自定义层的方式来实现。

class Residual(nn.Module):  # 定义残差块def __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

ResNet 沿用了 VGG 完整的 3$\times$3 卷积层设计。残差块里首先有两个相同输出通道数的卷积层,每个卷积层后面接一个批量规范化层和激活函数。

上述代码通过调整参数use_1x1conv参数的取值,来决定是否添加 1$\times$1 卷积层。

一般我们在增加通道数时,我们会通过调整strides来使得高宽减半。

实际上,我们还可以改变块中组件的位置,可得到各种残差块的变体。

残差块变体

二、ResNet 模型

ResNet 的第一层为输出通道数 64、步幅 2 的 7$\times 7 卷积层,随后接 B N 层和步幅为 2 的 3 7 卷积层,随后接 BN 层和步幅为 2 的 3 7卷积层,随后接BN层和步幅为23\times$3 的最大汇聚层。

    b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

之后使用 4 个由残差块组成的模块,每个模块由若干个同样输出通道数的残差块组成。

第一个模块的通道数同输入通道数一致。由于之前已经使用了步幅为 2 的最大汇聚层,因此无需减小高和宽。

之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。

def resnet_block(input_channels, num_channels, num_residuals,first_block=False):    # 生成由残差块组成的模块blk = []for i in range(num_residuals):# 除了第一个模块,其他模块的第一个残差块需要宽高减半if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk

我们这里每个模块使用 2 个残差块,其中其第一个模块使用first_block参数来避免宽高减半。

    # b2 不需要通道数翻倍,宽高减半b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))b3 = nn.Sequential(*resnet_block(64, 128, 2))b4 = nn.Sequential(*resnet_block(128, 256, 2))b5 = nn.Sequential(*resnet_block(256, 512, 2))

最后,加入自适应平均汇聚层、展平层和全连接输出层。AdaptiveAvgPool2d的使用可以保证最后的输出为 (1, 1),不用去管池化窗口的大小。

    net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(), nn.Linear(512, 10))

4 个模块,每个模块两个残差块,一个残差块 2 个卷积层,加上最初的 7$\times$7 卷积层和最后的全连接层,共 18 层,故上述模型通常称为 ResNet-18。

在训练模型之前,我们来观察一下各个模块的输入形状是如何变化的。

    # 查看各模块输出形状X = torch.rand(size=(1, 1, 224, 224))for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
---------------------------------
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 128, 28, 28])
Sequential output shape:	 torch.Size([1, 256, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])       

第一个模块出来后是 56$\times$56,我开始算不到,因为光算卷积就已经是小数了,没往下算。后面上网查了下,发现是向下取整的,才明白。

这里放上尺寸的计算公式吧,参考这个:https://www.jianshu.com/p/612edc845ad5

卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长 +1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长 +1

后面 3 个模块都是通道数加倍,宽高减半,减为 7 × \times × 7 后,最后通过汇聚层变为 1 × \times × 1,聚集所有特征。

三、训练模型

同之前一样,我们在 Fashion-MNIST 数据集上训练 ResNet。

因为之前定义好了很多训练相关函数,所以训练代码可以非常轻松的写下来。

我都有点想写一个自己的工具包了,这样就不用每次都复制前面的代码,而是像李沐老师的 d2l 一样。

    lr, num_epochs, batch_size = 0.05, 10, 128    # ResNet使用的参数train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)train(net, train_iter, test_iter, num_epochs, lr, try_gpu())

原本书上的batch_size是 256 的,但是我的 GPU 内存不够,报错,调成了 128。

这次训练是最久的,早知道resize成更小的尺寸了。我看到沐神说他改成 96 只是为了更快运行,就没用 96,想同样用 224,好和前面的模型对比。

训练结果如下:

测试/训练精度变化图

训练损失变化图

10 轮的训练损失为 0.01110 轮的训练精度为 0.99810 轮的测试集精度为 0.926
运行在 cuda:0 上,处理速度为 228.1 样本/

这次的处理速度只是 VGG 的一半,但是效果是很不错的,训练损失仅有 0.011,训练精度都接近 100%了都,而且测试集精度也不低。

可以看出 ResNet 确实是非常有效的网络,它对后面的深层网络也产生了非常深远的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/83687.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue3学习(组合式API——父、子组件间通信详解)

目录 一、组合式API下的父组件传子组件。(自定义属性) (1)基本思想。 (2)核心注意点。(defineProps) (3)传递简单类型数据。 (4)传递对象类型数据。(v-bind"对象类型数据"…

W5500使用ioLibrary库创建TCP客户端

1、WIZnet全硬件TCP/IP协议栈 WIZnet全硬件TCP/IP协议栈,支持TCP,UDP,IPv4,ICMP,ARP,IGMP以及PPPoE协议。 以太网:支持BSD和WIZCHIP(W5500/W5300/W5200/W5100/W5100S)的SOCKET APIs驱动程序。 互联网: DHCP客户端 DNS客户端 FTP客…

管理Oracle Data Guard的最佳实践

Oracle Data Guard的中文名字叫数据卫士,顾名思义,它是生产库的一道保障。所以管理Data Guard是DBA的一项重要工作之一,管理Data Guard时主要有以下几个注意点需要引起重视。 备份库的归档日志积压 一般情况下,生产库的归档日志是…

BootCDN介绍(Bootstrap主导的前端开源项目免费CDN加速服务)

文章目录 BootCDN前端开源项目CDN加速服务全解析什么是BootCDN技术原理与架构CDN技术基础BootCDN架构特点1. 全球分布式节点网络2. 智能DNS解析系统3. 高效缓存管理机制4. 自动同步更新机制5. HTTPS和HTTP/2协议支持 BootCDN的核心优势速度与稳定性开源免费资源丰富度技术规范遵…

2025 Java 微信小程序根据code获取openid,二次code获取手机号【工具类】拿来就用

一、controller调用 /*** 登录** author jiaketao* since 2024-04-10*/ RestController RequestMapping("/login") public class LoginController {/*** 【小程序】登录获取session_key和openid** param code 前端传code* return*/GetMapping("/getWXSessionKe…

软件架构风格系列(3):管道 - 过滤器架构

文章目录 前言一、从生活场景到架构原理,看懂管道 - 过滤器的核心逻辑(一)什么是管道 - 过滤器架构?(二)核心组件拆解 二、架构设计图:一图看懂管道 - 过滤器架构全貌三、Java 示例代码&#xf…

【VIM】vim 常用命令

文章目录 插入模式光标移动拷贝/粘贴/删除/撤销块操作分屏代码缩进命令组合使用其他PowerVim 前言:本文内容大部分摘抄自酷壳和博客园   –   CoolShell – 陈皓   博客园 – 易先讯 插入模式 a → 在光标后插入o → 在当前行后插入一个新行O → 在当前行前插…

polarctf-web-[简单rce]

考点&#xff1a; (1)RCE(eval函数) (2)执行函数(passthru函数) (3)/顶级(根)目录查看 (4)sort排序查看函数 题目来源&#xff1a;Polarctf-web-[简单rce] 解题&#xff1a; 代码审计 <?php/*​PolarD&N CTF​*/highlight_file(__FILE__);function no($txt){ # …

HarmonyOs开发之———使用HTTP访问网络资源

谢谢关注&#xff01;&#xff01; 前言&#xff1a;上一篇文章主要介绍HarmonyOs开发之———Video组件的使用:HarmonyOs开发之———Video组件的使用_华为 video标签查看-CSDN博客 HarmonyOS 网络开发入门&#xff1a;使用 HTTP 访问网络资源 HarmonyOS 作为新一代智能终端…

Vue 图片预览功能(含缩略图)

众所周知&#xff0c;常见的组件库如Element、Ant Design&#xff0c;自带的图片预览功能都没有缩略图&#xff0c;所以 需要单独封装一个图片预览的服务。 第三方库&#xff1a;v-viewer 安装&#xff1a; npm install v-viewer viewerjs 若使用报错&#xff0c;可安装指定…

手写tomcat:基本功能实现(4)

逻辑架构 HTTP 请求与 Socket&#xff1a; 左侧的 “HTTP 请求” 箭头指向 “socket”&#xff0c;表示客户端发送的 HTTP 请求通过 socket 传输到服务器。Socket 负责接收请求&#xff0c;并提取出其中的 请求路径&#xff08;如 /first&#xff09;和 请求方法&#xff08;如…

jvm安全点(一)openjdk17 c++源码垃圾回收安全点信号函数处理线程阻塞

1. 信号处理入口​​ ​​JVM_HANDLE_XXX_SIGNAL​​ 是 JVM 处理信号的统一入口&#xff0c;负责处理 SIGSEGV、SIGBUS 等信号。​​javaSignalHandler​​ 是实际注册到操作系统的信号处理函数&#xff0c;直接调用 JVM_HANDLE_XXX_SIGNAL。 ​​2. 安全点轮询页的识别​​ …

微信小程序:封装表格组件并引用

一、效果 封装表格组件,在父页面中展示表格组件并显示数据 二、表格组件 1、创建页面 创建一个components文件夹,专门用于存储组件的文件夹 创建Table表格组件 2、视图层 (1)表头数据 这里会从父组件中传递表头数据,这里为columns,后续会讲解数据由来 循环表头数组,…

【FMC216】基于 VITA57.1 的 2 路 TLK2711 发送、2 路 TLK2711 接收 FMC 子卡模块

产品概述 FMC216 是一款基于 VITA57.1 标准规范的 2 路 TLK2711 接收、2 路 TLK2711 发送 FMC 子卡模块。该板卡支持 2 路 TLK2711 数据的收发&#xff0c;支持线速率 1.6Gbps&#xff0c;经过 TLK2711 高速串行收发器&#xff0c;可以将 1.6Gbps 的高速串行数据解串为 16 位并…

K8S Gateway API 快速开始、胎教级教程

假设有如下三个节点的 K8S 集群&#xff1a; ​​ k8s31master 是控制节点 k8s31node1、k8s31node2 是工作节点 容器运行时是 containerd 一、Gateway 是什么 背景和目的 入口&#xff08;Ingress&#xff09;目前已停止更新。新的功能正在集成至网关 API 中。在 Kubernetes …

时序数据库IoTDB分布式架构解析与运维指南

一、IoTDB分布式架构概述 分布式系统由一组独立的计算机组成&#xff0c;通过网络通信&#xff0c;对外表现为一个统一的整体。IoTDB的原生分布式架构将服务分为两个核心部分&#xff1a; ‌ConfigNode&#xff08;CN&#xff09;‌&#xff1a;管理节点&#xff0c;负责管理…

Ubuntu 20.04 LTS 中部署 网页 + Node.js 应用 + Nginx 跨域配置 的详细步骤

Ubuntu 20.04 LTS 中部署 网页 Node.js 应用 Nginx 跨域配置 的详细步骤 一、准备工作1、连接服务器2、更新系统 二、安装 Node.js 环境1、安装 Node.js 官方 PPA&#xff08;用于获取最新稳定版&#xff09;&#xff1a;2、安装 Node.js 和 npm&#xff08;LTS 长期支持版本…

3DVR制作的工具或平台

3DVR&#xff08;三维虚拟现实&#xff09;是利用三维图像技术和虚拟现实技术&#xff0c;将真实场景进行三维扫描并转换成计算机可识别的三维模型&#xff0c;使用户能够在虚拟空间中自由漫游&#xff0c;体验身临其境的感觉。3DVR技术结合了全景拍摄和虚拟现实&#xff0c;提…

垂直智能体:企业AI落地的正确打开方式

在当前AI浪潮中&#xff0c;许多企业急于跟进&#xff0c;推出自己的AI智能体解决方案。然而&#xff0c;市场上大量出现的"万能型"智能体却鲜有真正解决实际问题的产品。本文将探讨为何企业应该专注于开发垂直领域智能体&#xff0c;而非追求表面上的全能&#xff0…

软件工程各种图总结

目录 1.数据流图 2.N-S盒图 3.程序流程图 4.UML图 UML用例图 UML状态图 UML时序图 5.E-R图 首先要先了解整个软件生命周期&#xff1a; 通常包含以下五个阶段&#xff1a;需求分析-》设计-》编码 -》测试-》运行和维护。 软件工程中应用到的图全部有&#xff1a;系统…