28-29【动手学深度学习】批量归一化 + ResNet

1. 批量归一化

1.1 原理

当神经网络比较深的时候会发现:数据在下面,损失函数在上面,这样会出现什么问题?

  • 正向传递的时候,数据是从下往上一步一步往上传递
  • 反向传递的时候,数据是从上面往下传递,这时候就会出现问题:梯度在上面的时候比较大,越到下面就越容易变小(因为是n个很小的数进行相乘,越到后面结果就越小,也就是说越靠近数据的,层的梯度就越小
  • 上面的梯度比较大,那么每次更新的时候上面的层就会不断地更新;但是下面层因为梯度比较小,所以对权重地更新就比较少,这样的话就会导致上面的收敛比较快,而下面的收敛比较慢,这样就会导致底层靠近数据的内容(网络所尝试抽取的网络底层的特征:简单的局部边缘、纹理等信息)变化比较慢,上层靠近损失的内容(高层语义信息)收敛比较快,所以每一次底层发生变化,所有的层都得跟着变(底层的信息发生变化就导致上层的权重全部白学了),这样就会导致模型的收敛比较慢

所以提出了假设:能不能在改变底部信息的时候,避免顶部不断的重新训练?(这也是批量归一化所考虑的问题)

\varepsilon 是为了避免除以0

全连接层

通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。 设全连接层的输入为x,权重参数和偏置参数分别为W和b,激活函数为ϕ,批量规范化的运算符为BN。 那么,使用批量规范化的全连接层的输出的计算详情如下:

h=ϕ(BN(Wx+b)).

回想一下,均值和方差是在应用变换的"相同"小批量上计算的。

卷积层

同样,对于卷积层,我们可以在卷积层之后和非线性激活函数之前应用批量规范化。 当卷积有多个输出通道时,我们需要对这些通道的“每个”输出执行批量规范化,每个通道都有自己的拉伸(scale)和偏移(shift)参数,这两个参数都是标量。 假设我们的小批量包含m个样本,并且对于每个通道,卷积的输出具有高度p和宽度q。 那么对于卷积层,我们在每个输出通道的m⋅p⋅q个元素上同时执行每个批量规范化。 因此,在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

批量归一化需要在激活函数之前,因为BN是线性的吗,而激活函数是非线性的

 使用BN,可以增大学习率,因此可以加速收敛速度

预测过程中的批量规范化

正如我们前面提到的,批量规范化在训练模式和预测模式下的行为通常不同。 首先,将训练好的模型用于预测时,我们不再需要样本均值中的噪声以及在微批次上估计每个小批次产生的样本方差了。 其次,例如,我们可能需要使用我们的模型对逐个样本进行预测。 一种常用的方法是通过移动平均估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出。 可见,和暂退法一样,批量规范化层在训练模式和预测模式下的计算结果也是不一样的。

1.2 代码

从零实现

import torch
from torch import nn
from d2l import torch as d2ldef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data

创建一个正确的 BatchNorm 图层

class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

应用BatchNorm 于LeNet模型

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16,kernel_size=5), BatchNorm(16, num_dims=4),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(), nn.Linear(16 * 4 * 4, 120),BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2),nn.Sigmoid(), nn.Linear(84, 10))

在Fashion-MNIST数据集上训练网络

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

 拉伸参数 gamma 和偏移参数 beta

net[1].gamma.reshape((-1, )), net[1].beta.reshape((-1, ))

 简明实现

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(), nn.Linear(256, 120), nn.BatchNorm1d(120),nn.Sigmoid(), nn.Linear(120, 84), nn.BatchNorm1d(84),nn.Sigmoid(), nn.Linear(84, 10))
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

小结

  • 在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量规范化在全连接层和卷积层的使用略有不同。
  • 批量规范化层和暂退层一样,在训练模式和预测模式下计算不同。
  • 批量规范化有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

2. ResNet

2.1 原理

        只有当较复杂的函数类包含较小的函数类时,我们才能确保提高它们的性能。 对于深度神经网络,如果我们能将新添加的层训练成恒等映射(identity function)f(x)=x,新模型和原模型将同样有效。 同时,由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。 

当经过很多层卷积之后,可能通道数会产生变化,所以要加上1×1的卷积转换通道数。 (通常情况下是,当高宽减半时,通道数变为原来的一倍)

2.2 代码

残差块

import torch 
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3,padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 =  nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)# 当 inplace=True 时,ReLU 会直接在输入张量上修改数据(覆盖原值),不分配额外内存存储输出self.relu = nn.ReLU(inplace=True)  def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y = Y + Xreturn F.relu(Y)

输入和输出形状一致

blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape

 增加输出通道数的同时,减半输出的高和宽

blk = Residual(3, 6, use_1x1conv=True, strides=2)
blk(X).shape

 ResNet模型

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True,strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# *的含义是将list展开,变成一个个的输入
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5, nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(), nn.Linear(512, 10))

观察一下ResNet中不同模块的输入形状是如何变化的

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)

训练模型

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

小结

  • 学习嵌套函数(nested function)是训练神经网络的理想情况。在深层神经网络中,学习另一层作为恒等映射(identity function)较容易(尽管这是一个极端情况)。
  • 残差映射可以更容易地学习同一函数,例如将权重层中的参数近似为零。
  • 利用残差块(residual blocks)可以训练出一个有效的深层神经网络:输入可以通过层间的残余连接更快地向前传播。
  • 残差网络(ResNet)对随后的深层神经网络设计产生了深远影响。

ResNet的梯度计算

3. 第二次kaggle竞赛

竞赛地址:https://www.kaggle.com/c/classify-leaves

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/80720.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux网络】Http服务优化 - 增加请求后缀、状态码描述、重定向、自动跳转及注册多功能服务

📢博客主页:https://blog.csdn.net/2301_779549673 📢博客仓库:https://gitee.com/JohnKingW/linux_test/tree/master/lesson 📢欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正! &…

AIGC(生成式AI)试用 32 -- AI做软件程序测试 3

总结之前的AI做程序测试过程,试图优化提问方式,整合完成的AI程序测试提问,探索更多可能的AI测试 AIGC(生成式AI)试用 30 -- AI做软件程序测试 1 AIGC(生成式AI)试用 31 -- AI做软件程序…

C语言实现迪杰斯特拉算法进行路径规划

使用C语言实现迪杰斯特拉算法进行路径规划 迪杰斯特拉算法是一种用于寻找加权图中最短路径的经典算法。它特别适合用于计算从一个起点到其他所有节点的最短路径,前提是图中的边权重为非负数。 一、迪杰斯特拉算法的基本原理 迪杰斯特拉算法的核心思想是“贪心法”…

引领印尼 Web3 变革:Mandala Chain 如何助力 1 亿用户迈向数字未来?

当前 Web3 的发展正处于关键转折点,行业亟需吸引新用户以推动 Web3 的真正大规模采用。然而,大规模采用面临着核心挑战:数据泄露风险、集中存储的安全漏洞、跨系统互操作性障碍,以及低效的服务访问等问题。如何才能真正突破这些瓶…

WebSocket是h5定义的,双向通信,节省资源,更好的及时通信

浏览器和服务器之间的通信更便利,比http的轮询等效率提高很多, WebSocket并不是权限的协议,而是利用http协议来建立连接 websocket必须由浏览器发起请求,协议是一个标准的http请求,格式如下 GET ws://example.com:3…

Kaamel白皮书:IoT设备安全隐私评估实践

1. IoT安全与隐私领域的现状与挑战 随着物联网技术的快速发展,IoT设备在全球范围内呈现爆发式增长。然而,IoT设备带来便捷的同时,也引发了严峻的安全与隐私问题。根据NSF(美国国家科学基金会)的研究表明,I…

php安装swoole扩展

PHP安装swoole扩展 Swoole官网 安装准备 安装前必须保证系统已经安装了下列软件 4.8 版本需要 PHP-7.2 或更高版本5.0 版本需要 PHP-8.0 或更高版本6.0 版本需要 PHP-8.1 或更高版本gcc-4.8 或更高版本makeautoconf 安装Swool扩展 安装官方文档安装后需要再php.ini中增加…

服务器传输数据存储数据建议 传输慢的原因

一、JSON存储的局限性 1. 性能瓶颈 全量读写:JSON文件通常需要整体加载到内存中才能操作,当数据量大时(如几百MB),I/O延迟和内存占用会显著增加。 无索引机制:查找数据需要遍历所有条目(时间复…

Android四大核心组件

目录 一、为什么需要四大组件? 二、Activity:看得见的界面 核心功能 生命周期图解 代码示例 三、Service:看不见的劳动者 两大类型 生命周期对比 注意陷阱 四、BroadcastReceiver:消息传递专员 两种注册方式 广播类型 …

「Mac畅玩AIGC与多模态01」架构篇01 - 展示层到硬件层的架构总览

一、概述 AIGC(AI Generated Content)系统由多个结构层级组成,自上而下涵盖交互界面、API 通信、模型推理、计算框架、底层驱动与硬件支持。本篇梳理 AIGC 应用的六层体系结构,明确各组件在系统中的职责与上下游关系,…

[MERN 项目实战] MERN Multi-Vendor 电商平台开发笔记(v2.0 从 bug 到结构优化的工程记录)

[MERN 项目实战] MERN Multi-Vendor 电商平台开发笔记(v2.0 从 bug 到结构优化的工程记录) 其实之前没想着这么快就能把 2.0 的笔记写出来的,之前的预期是,下一个阶段会一直维持到将 MERN 项目写完,毕竟后期很多东西都…

互斥量函数组

头文件 #include <pthread.h> pthread_mutex_init 函数原型&#xff1a; int pthread_mutex_init(pthread_mutex_t *restrict mutex, const pthread_mutexattr_t *restrict attr); 函数参数&#xff1a; mutex&#xff1a;指向要初始化的互斥量的指针。 attr&#xf…

互联网的下一代脉搏:深入理解 QUIC 协议

互联网的下一代脉搏&#xff1a;深入理解 QUIC 协议 互联网是现代社会的基石&#xff0c;而数据在其中高效、安全地传输是其运转的关键。长期以来&#xff0c;传输层的 TCP&#xff08;传输控制协议&#xff09;一直是互联网的主力军。然而&#xff0c;随着互联网应用场景的日…

全球城市范围30米分辨率土地覆盖数据(1985-2020)

Global urban area 30 meter resolution land cover data (1985-2020) 时间分辨率年空间分辨率10m - 100m共享方式保护期 277 天 5 时 42 分 9 秒数据大小&#xff1a;8.98 GB数据时间范围&#xff1a;1985-2020元数据更新时间2024-01-11 数据集摘要 1985~2020全球城市土地覆…

【Vue】单元测试(Jest/Vue Test Utils)

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Vue 文章目录 1. Vue 单元测试简介1.1 为什么需要单元测试1.2 测试工具介绍 2. 环境搭建2.1 安装依赖2.2 配置 Jest 3. 编写第一个测试3.1 组件示例3.2 编写测试用例3.3 运行测试 4. Vue Test Utils 核心 API4.1 挂载组件4.2 常…

数据湖的管理系统管什么?主流产品有哪些?

一、数据湖的管理系统管什么&#xff1f; 数据湖的管理系统主要负责管理和优化存储在数据湖中的大量异构数据&#xff0c;确保这些数据能够被有效地存储、处理、访问和治理。以下是数据湖管理系统的主要职责&#xff1a; 数据摄入管理&#xff1a;管理系统需要支持从多种来源&…

英文中日期读法

英文日期的读法和写法因地区&#xff08;英式英语与美式英语&#xff09;和正式程度有所不同&#xff0c;以下是详细说明&#xff1a; 一、日期格式 英式英语 (日-月-年) 写法&#xff1a;1(st) January 2023 或 1/1/2023读法&#xff1a;"the first of January, twenty t…

衡量矩阵数值稳定性的关键指标:矩阵的条件数

文章目录 1. 定义2. 为什么要定义条件数&#xff1f;2.1 分析线性系统 A ( x Δ x ) b Δ b A(x \Delta x) b \Delta b A(xΔx)bΔb2.2 分析线性系统 ( A Δ A ) ( x Δ x ) b (A \Delta A)(x \Delta x) b (AΔA)(xΔx)b2.3 定义矩阵的条件数 3. 性质及几何意义3…

4月22日复盘-开始卷积神经网络

4月24日复盘 一、CNN 视觉处理三大任务&#xff1a;图像分类、目标检测、图像分割 上游&#xff1a;提取特征&#xff0c;CNN 下游&#xff1a;分类、目标、分割等&#xff0c;具体的业务 1. 概述 ​ 卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领…

【网络原理】从零开始深入理解TCP的各项特性和机制.(三)

上篇介绍了网络原理传输层TCP协议的知识,本篇博客给大家带来的是网络原理剩余的内容, 总体来说,这部分内容没有上两篇文章那么重要,本篇知识有一个印象即可. &#x1f40e;文章专栏: JavaEE初阶 &#x1f680;若有问题 评论区见 ❤ 欢迎大家点赞 评论 收藏 分享 如果你不知道分…