pytorch 14.3 Batch Normalization综合调参实践

文章目录

    • 一、Batch Normalization与Batch_size综合调参
    • 二、复杂模型上的Batch_normalization表现
      • 1、BN对复杂模型(sigmoid)的影响
      • 2、模型复杂度对模型效果的影响
      • 3、BN对复杂模型(tanh)的影响
    • 三、包含BN层的神经网络的学习率优化
      • 1.学习率敏感度
      • 2.学习率学习曲线
      • 3.不同学习率下不同模型优化效果
    • 四、带BN层的神经网络模型综合调整策略总结

一、Batch Normalization与Batch_size综合调参

  我们知道,BN是一种在长期实践中被证明行之有效的优化方法,但在使用过程中首先需要知道,BN的理论基础(尽管不完全正确)是以BN层能够有效预估输入数据整体均值和方差为前提的,如果不能尽可能的从每次输入的小批数据中更准确的估计整体统计量,则后续的平移和放缩也将是有偏的。而由小批数据估计整体统计量的可信度其实是和小批数据本身数量相关的,如果小批数据数量太少,则进行整体统计量估计时就将有较大偏差,此时会影响模型准确率。
  因此,一般来说,我们在使用BN时,至少需要保证小批数据量(batch_size)在15-30以上,才能进行相对准确的预估。此处我们适当调整小批数据量参数,再进行模型计算。

# 进行数据集切分与加载
# 设置batch_size为50
train_loader, test_loader = split_loader(features, labels, batch_size=50)

在这里插入图片描述
在这里插入图片描述
我们发现,当提升batch_size之后,带BN层的模型效果有明显提升,相比原始模型,带BN层的模型拥有更快的收敛速度。

二、复杂模型上的Batch_normalization表现

1、BN对复杂模型(sigmoid)的影响

  一般来说,BN方法对于复杂模型和复杂数据会更加有效,换而言之,很多简单模型是没必要使用BN层(徒增计算量)。对于上述net_class1来说,由于只存在一个隐藏层,因此也不会存在梯度不平稳的现象,而BN层的优化效果也并不明显。接下来,我们尝试构建更加复杂的模型,来测试BN层的优化效果。

从另一个角度来说,其实我们是建议更频繁的使用更加复杂的模型并带上BN层的,核心原因在于,复杂模型带上BN层之后会有更大的优化空间。

接下来,我们尝试设置更加复杂的数据集,同时增加模型复杂度,测试在更加复杂的环境下BN层表现情况。
此处我们创建满足 y = 2 x 1 2 − x 2 2 + 3 x 3 2 + x 4 2 + 2 x 5 2 y=2x_1^2-x_2^2+3x_3^2+x_4^2+2x_5^2 y=2x12x22+3x32+x42+2x52的回归类数据集。

# 设置随机数种子
torch.manual_seed(420)  # 创建最高项为2的多项式回归数据集
features, labels = tensorGenReg(w=[2, -1, 3, 1, 2], bias=False, deg=2)# 进行数据集切分与加载
train_loader, test_loader = split_loader(features, labels, batch_size=50)

接下来,我们同时创建Sigmoid1-4,并且通过对比带BN层的模型和不带BN层的模型来进行测试。

# class1对比模型
# 设置随机数种子
torch.manual_seed(24)  # 实例化模型  
sigmoid_model1 = net_class1(act_fun= torch.sigmoid, in_features=5)
sigmoid_model1_norm = net_class1(act_fun= torch.sigmoid, in_features=5, BN_model='pre')# 创建模型容器
model_ls1 = [sigmoid_model1, sigmoid_model1_norm]           
name_ls1 = ['sigmoid_model1', 'sigmoid_model1_norm']# 核心参数
lr = 0.03
num_epochs = 40# 模型训练
train_ls1, test_ls1 = model_comparison(model_l = model_ls1, name_l = name_ls1, train_data = train_loader,test_data = test_loader,num_epochs = num_epochs, criterion = nn.MSELoss(), optimizer = optim.SGD, lr = lr, cla = False, eva = mse_cal)# class2对比模型
# 设置随机数种子
torch.manual_seed(24)  # 实例化模型  
sigmoid_model2 = net_class2(act_fun= torch.sigmoid, in_features=5)
sigmoid_model2_norm = net_class2(act_fun= torch.sigmoid, in_features=5, BN_model='pre')# 创建模型容器
model_ls2 = [sigmoid_model2, sigmoid_model2_norm]           
name_ls2 = ['sigmoid_model2', 'sigmoid_model2_norm']# 核心参数
lr = 0.03
num_epochs = 40# 模型训练
train_ls2, test_ls2 = model_comparison(model_l = model_ls2, name_l = name_ls2, train_data = train_loader,test_data = test_loader,num_epochs = num_epochs, criterion = nn.MSELoss(), optimizer = optim.SGD, lr = lr, cla = False, eva = mse_cal)# class3对比模型
# 设置随机数种子
torch.manual_seed(24)  # 实例化模型  
sigmoid_model3 = net_class3(act_fun= torch.sigmoid, in_features=5)
sigmoid_model3_norm = net_class3(act_fun= torch.sigmoid, in_features=5, BN_model='pre')# 创建模型容器
model_ls3 = [sigmoid_model3, sigmoid_model3_norm]           
name_ls3 = ['sigmoid_model3', 'sigmoid_model3_norm']# 核心参数
lr = 0.03
num_epochs = 40# 模型训练
train_ls3, test_ls3 = model_comparison(model_l = model_ls3, name_l = name_ls3, train_data = train_loader,test_data = test_loader,num_epochs = num_epochs, criterion = nn.MSELoss(), optimizer = optim.SGD, lr = lr, cla = False, eva = mse_cal)# class4对比模型
# 设置随机数种子
torch.manual_seed(24)  # 实例化模型  
sigmoid_model4 = net_class4(act_fun= torch.sigmoid, in_features=5)
sigmoid_model4_norm = net_class4(act_fun= torch.sigmoid, in_features=5, BN_model='pre')# 创建模型容器
model_ls4 = [sigmoid_model4, sigmoid_model4_norm]           
name_ls4 = ['sigmoid_model4', 'sigmoid_model4_norm']# 核心参数
lr = 0.03
num_epochs = 40# 模型训练
train_ls4, test_ls4 = model_comparison(model_l = model_ls4, name_l = name_ls4, train_data = train_loader,test_data = test_loader,num_epochs = num_epochs, criterion = nn.MSELoss(), optimizer = optim.SGD, lr = lr, cla = False, eva = mse_cal)
# 训练误差
plt.subplot(221)
for i, name in enumerate(name_ls1):plt.plot(list(range(num_epochs)), train_ls1[i], label=name)
plt.legend(loc = 1)
plt.title('mse_train_ls1')plt.subplot(222)
for i, name in enumerate(name_ls2):plt.plot(list(range(num_epochs)), train_ls2[i], label=name)
plt.legend(loc = 1)
plt.title('mse_train_ls2')plt.subplot(223)
for i, name in enumerate(name_ls3):plt.plot(list(range(num_epochs)), train_ls3[i], label=name)
plt.legend(loc = 1)
plt.title('mse_train_ls3')plt.subplot(224)
for i, name in enumerate(name_ls4):plt.plot(list(range(num_epochs)), train_ls4[i], label=name)
plt.legend(loc = 1)
plt.title('mse_train_ls4')
# 训练误差
plt.subplot(221)
for i, name in enumerate(name_ls1):plt.plot(list(range(num_epochs)), test_ls1[i], label=name)
plt.legend(loc = 1)
plt.title('mse_test_ls1')plt.subplot(222)
for i, name in enumerate(name_ls2):plt.plot(list(range(num_epochs)), test_ls2[i], label=name)
plt.legend(loc = 1)
plt.title('mse_test_ls2')plt.subplot(223)
for i, name in enumerate(name_ls3):plt.plot(list(range(num_epochs)), test_ls3[i], label=name)
plt.legend(loc = 1)
plt.title('mse_test_ls3')plt.subplot(224)
for i, name in enumerate(name_ls4):plt.plot(list(range(num_epochs)), test_ls4[i], label=name)
plt.legend(loc = 1)
plt.title('mse_test_ls4')

在这里插入图片描述
在这里插入图片描述
  由此,我们可以清楚的看到,BN层对更加复杂模型的优化效果更好。换而言之,越复杂的模型对于梯度不平稳的问题就越明显,因此BN层在解决该问题后模型效果提升就越明显。

2、模型复杂度对模型效果的影响

  并且,针对复杂数据集,在一定范围内,伴随模型复杂度提升,模型效果会有显著提升。但是,当模型太过于复杂时,仍然会出现模型效果下降的问题。

for i, name in enumerate(name_ls1):plt.plot(list(range(num_epochs)), test_ls1[i], label=name)
for i, name in enumerate(name_ls2):plt.plot(list(range(num_epochs)), test_ls2[i], label=name)
plt.legend(loc = 1)
plt.title('mse_test')
for i, name in enumerate(name_ls2):plt.plot(list(range(num_epochs)), test_ls2[i], label=name)
for i, name in enumerate(name_ls4):plt.plot(list(range(num_epochs)), test_ls4[i], label=name)
plt.legend(loc = 1)
plt.title('mse_test')

在这里插入图片描述
在这里插入图片描述

3、BN对复杂模型(tanh)的影响

  对于Sigmoid来说,BN层能很大程度上缓解梯度消失问题,从而提升模型收敛速度,并且小幅提升模型效果。而对于激活函数本身就能输出Zero-Centered结果的tanh函数,BN层的优化效果会更好。
训练结果:
在这里插入图片描述
测试结果:
在这里插入图片描述

  相比Sigmoid,使用tanh激活函数本身就是更加复杂的一种选择,因此,BN层在tanh上所表现出的更好的优化效果,也能看成是BN在复杂模型上效果有所提升。

三、包含BN层的神经网络的学习率优化

1.学习率敏感度

学习率lr对复杂模型(tanh)的影响

# 学习率 0.1
# 学习率 0.03
# 学习率 0.01
# 学习率 0.005

在这里插入图片描述
能够看出,随着学习率逐渐变化,拥有BN层的模型表现出更加剧烈的波动,这也说明拥有BN层的模型对学习率变化更加敏感。

2.学习率学习曲线

对于学习率的调整,一般都会出现倒U型曲线。我们能够发现,在当前模型条件下,学习率为0.005左右时模型效果较好。当然,我们这里也只取了四个值进行测试,也有可能最佳学习率在0.006或者0.0051,关于学习率参数的调整策略(LR-scheduler),我们将在下一节进行详细介绍,本节我们将利用此处实验得到的0.005作为学习率进行后续实验。
tanh_model3在不同学习率lr下的loss值
在这里插入图片描述

3.不同学习率下不同模型优化效果

  既然学习率学习曲线是U型曲线,那么U型的幅度其实就代表着学习率对于该模型的优化空间,这里我们可以通过简单实验,来观测不同模型的U型曲线的曲线幅度。首先,对于tanh2来说,带BN层的模型学习率优化效果比不带BN层学习率优化效果更好。

# 设置随机数种子
torch.manual_seed(24)  # 实例化模型  
tanh_model3 = net_class3(act_fun= torch.tanh, in_features=5)
tanh_model3_norm = net_class3(act_fun= torch.tanh, in_features=5, BN_model='pre')
tanh_model4 = net_class4(act_fun= torch.tanh, in_features=5)
tanh_model4_norm = net_class4(act_fun= torch.tanh, in_features=5, BN_model='pre')# 创建模型容器
model_l = [tanh_model3, tanh_model3_norm, tanh_model4, tanh_model4_norm]           
name_l = ['tanh_model3', 'tanh_model3_norm', 'tanh_model4', 'tanh_model4_norm']# 核心参数
lr = 0.001
num_epochs = 40# 模型训练 tanh_model3
train_l001, test_l001 = model_comparison(model_l = model_l, name_l = name_l, train_data = train_loader,test_data = test_loader,num_epochs = num_epochs, criterion = nn.MSELoss(), optimizer = optim.SGD, lr = lr, cla = False, eva = mse_cal)
lr_l = [0.03, 0.01, 0.005, 0.001]
train_ln = [train_l03[1:,-5:].mean(), train_l01[1:,-5:].mean(), train_l005[1:,-5:].mean(), train_l001[1:,-5:].mean()]
test_ln = [test_l03[1:,-5:].mean(), test_l01[1:,-5:].mean(), test_l005[1:,-5:].mean(), test_l001[1:,-5:].mean()]
train_l = [train_l03[0:,-5:].mean(), train_l01[0:,-5:].mean(), train_l005[0:,-5:].mean(), train_l001[0:,-5:].mean()]
test_l = [test_l03[0:,-5:].mean(), test_l01[0:,-5:].mean(), test_l005[0:,-5:].mean(), test_l1[0:,-5:].mean()]plt.subplot(121)
plt.plot(lr_l, train_ln, label='train_mse')
plt.plot(lr_l, test_ln, label='test_mse')
plt.legend(loc = 1)
plt.ylim(4, 25)
plt.title('With BN(tanh3)')plt.subplot(122)
plt.plot(lr_l, train_l, label='train_mse')
plt.plot(lr_l, test_l, label='test_mse')
plt.legend(loc = 1)
plt.ylim(4, 25)
plt.title('Without BN(tanh3)')

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

四、带BN层的神经网络模型综合调整策略总结

  最后,我们总结下截至目前,针对BN层的神经网络模型调参策略。

  • 简单数据、简单模型下不用BN层,加入BN层效果并不显著;
  • BN层的使用需要保持running_mean和running_var的无偏性,因此需要谨慎调整batch_size;
  • 学习率是重要的模型优化的超参数,一般来说学习率学习曲线都是U型曲线;
  • 从学习率调整角度出发,对于加入BN层的模型,学习率调整更加有效;对于带BN层模型角度来说,BN层能够帮助模型拓展优化空间,使得很多优化方法都能在原先无效的模型上生效;
  • 对于复杂问题,在计算能力能够承担的范围内,应当首先构建带BN层的复杂模型,然后再试图进行优化,就像上文所述,很多优化方法只对带BN层的模型有效;

其他拓展方面结论:

  • 关于BN和Xavier/Kaiming方法,一般来说,使用BN层的模型不再会用参数初始化方法,从理论上来看添加BN层能够起到参数初始化的相等效果;(另外,带BN层模型一般也不需要使用Dropout方法)
  • 本节尚未讨论ReLU激活函数的优化,相关优化方法将放在后续进行详细讨论,但需要知道的是,对于ReLU叠加的模型来说,加入BN层之后能够有效缓解Dead
    ReLU Problem,此时无须刻意调小学习率,能够在收敛速度和运算结果间保持较好的平衡。
  • BN层是目前大部分深度学习模型的标配,但前提是你有能力去对其进行优化;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/80226.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Model.eval() 与 torch.no_grad() PyTorch 中的区别与应用

Model.eval() 与 torch.no_grad(): PyTorch 中的区别与应用 在 PyTorch 深度学习框架中,model.eval() 和 torch.no_grad() 是两个在模型推理(inference)阶段经常用到的函数,它们各自有着独特的功能和应用场景。本文将详细解析这两…

Swagger go中文版本手册

Swaggo(github.com/swaggo/swag)的注解语法是基于 OpenAPI 2.0 (以前称为 Swagger 2.0) 规范的,并添加了一些自己的约定。 主要官方文档: swaggo/swag GitHub 仓库: 这是最权威的来源。 链接: https://github.com/swaggo/swag重点关注: README.md: 包含了基本的安装、使用…

物联网设备远程管理:基于代理IP的安全固件更新通道方案

在物联网设备远程管理中,固件更新的安全性直接关系到设备功能稳定性和系统抗攻击能力。结合代理IP技术与安全协议设计,可构建安全、高效的固件更新通道。 一、代理IP在固件更新中的核心作用 网络层隐匿与路由优化 隐藏更新源服务器:通过代理I…

【C++重载操作符与转换】句柄类与继承

目录 一、句柄类的基本概念 1.1 什么是句柄类 1.2 句柄类的设计动机 1.3 句柄类的基本结构 二、句柄类的实现方式 2.1 基于指针的句柄类 2.2 值语义的句柄类 2.3 引用计数的句柄类 三、句柄类与继承的结合应用 3.1 实现多态容器 3.2 实现插件系统 3.3 实现状态模式…

谷歌曾经的开放重定向漏洞(如今已经修复) -- noogle DefCamp 2024

题目描述: 上周,我决定创建自己的搜索引擎。这有点难,所以我背上了另一个。我也在8000端口上尝试了一些东西。 未发现题目任何交互,但是存在一个加密js const _0x43a57f _0x22f9; (function(_0x3d7d57, _0x426e05) {const _0x16c3fa _0x22f9, _0x3187…

【C#】ToArray的使用

在 C# 中&#xff0c;ToArray 方法通常用于将实现了 IEnumerable<T> 接口的集合&#xff08;如 List<T>&#xff09;转换为数组。这个方法是 LINQ 提供的一个扩展方法&#xff0c;位于 System.Linq 命名空间中。因此&#xff0c;在使用 ToArray 方法之前&#xff0…

资产管理平台—chemex

1、简介 Chemex CMDB&#xff08;Configuration Management Database&#xff09;是一个基于现代微服务架构的资产管理与自动化平台&#xff0c;专为 IT 基础设施与业务资产管理而设计。其核心目标是解决大规模系统运维中资产信息混乱、配置分散、数据不一致等问题&#xff0c…

【AI】mcp server是什么玩意儿

文章目录 背景mcp server的必要性mcp server的基本概念mcp server的架构与核心组件总结 背景 劈里啪啦的整了一堆概念&#xff0c;对mcp server还是只停留在知道个词的地步。 虽然目前大模型的对话生成能力很强&#xff0c;但是大模型&#xff08;如deepseek&#xff09;并不能…

c# 数据结构 树篇 入门树与二叉树的一切

事先声明,本文不适合对数据结构完全不懂的小白 请至少学会链表再阅读 c# 数据结构 链表篇 有关单链表的一切_c# 链表-CSDN博客 数据结构理论先导:《数据结构&#xff08;C 语言描述&#xff09;》也许是全站最良心最通俗易懂最好看的数据结构课&#xff08;最迟每周五更新~~&am…

《Cookie Cutter》中2000多张精灵表与10000个2D光源的管理之道

一个小团队如何在多个平台上以优秀的效果展示手绘动画&#xff1f;Subcult Joint 工作室给出了答案。他们用六年时间开发出了游戏《Cookie Cutter》。游戏中使用了数千个使用传统动画技术制作的高分辨率资产&#xff0c;而且这些资产都在 Unity 中进行了优化。由于工作室需要在…

什么是实景VR?实景VR应用场景

实景VR&#xff0c;即基于真实场景的虚拟现实技术&#xff0c;是利用计算机技术生成三维环境&#xff0c;以模拟并再现真实世界场景的技术。 用户通过佩戴VR设备&#xff08;如VR头盔、手柄等&#xff09;或通过电脑设备&#xff0c;可以沉浸在一个高度仿真的虚拟环境中&#…

内核性能测试(60s不丢包性能)

以xGAP-200-SE7K-L&#xff08;双口10G&#xff09;在飞腾D2000上为例&#xff08;单通道最高性能约2.8Gbps) 单口测试 0口&#xff1a; tcp&#xff1a; taskset -c 4 iperf -c 1.1.1.1 -i 1 -t 60 -p 60001 taskset -c 4 iperf -s -i 1 -p 60001 udp&#xff1a; taskse…

58. 区间和

题目链接&#xff1a; 58. 区间和 题目描述&#xff1a; 给定一个整数数组 Array&#xff0c;请计算该数组在每个指定区间内元素的总和。 输入描述 第一行输入为整数数组 Array 的长度 n&#xff0c;接下来 n 行&#xff0c;每行一个整数&#xff0c;表示数组的元素。随后…

C#进阶(2)stack(栈)

前言 我们前面介绍了ArrayList,今天就介绍另一种数据结构——栈。 这是栈的基本形式,博主简单画了一下,你看个意思就行,很明显,这种数据有一种特征:先进后出。因为先进来的数据会在下面,下面是密闭的,所以只能取后面进来的。 C#为我们封好了这种数据结构,我们不用担…

汽车工厂数字孪生实时监控技术从数据采集到三维驱动实现

在工业智能制造推动下&#xff0c;数字孪生技术正成为制造业数字化转型的核心驱动力。今天详细介绍数字孪生实时监控技术在汽车工厂中的应用&#xff0c;重点解析从数据采集到三维驱动实现的全流程技术架构&#xff0c;并展示其在提升生产效率、降低成本和优化决策方面的显著价…

git|gitee仓库同步到github

参考&#xff1a;一次提交更新两个仓库&#xff0c;Get 更优雅的 GitHub/Gitee 仓库镜像同步 文章目录 进入需要使用镜像功能的仓库&#xff0c;进入「管理」找到「仓库镜像管理」选项&#xff0c;点击「添加镜像」按钮绑定github绑定成功后再次点击添加镜像如何申请 GitHub 私…

原生小程序+springboot+vue+协同过滤算法的音乐推荐系统(源码+论文+讲解+安装+部署+调试)

感兴趣的可以先收藏起来&#xff0c;还有大家在毕设选题&#xff0c;项目以及论文编写等相关问题都可以给我留言咨询&#xff0c;我会一一回复&#xff0c;希望帮助更多的人。 系统背景 在数字音乐产业迅猛发展的当下&#xff0c;Spotify、QQ 音乐、网易云音乐等音乐平台的曲…

RustDesk

配置中继服务器 https://rustdesk.com/docs/zh-cn/self-host/windows/ 服务器端 下载Windows版本 rustdesk-server-windows-x86_64.zip&#xff0c;安装路径为&#xff1a;C:\Program Files\RustDeskServer\bin。执行 hbbr.exe 和 hbbs.exe 两个应用程序。这两个应用提供了两…

django中用 InforSuite RDS 替代memcache

在 Django 项目中&#xff0c;InforSuite RDS&#xff08;关系型数据库服务&#xff09;无法直接替代 Memcached&#xff0c;因为两者的设计目标和功能定位完全不同&#xff1a; 特性MemcachedInforSuite RDS核心用途高性能内存缓存&#xff0c;临时存储键值对数据持久化关系型…

leetcode 57. Insert Interval

题目描述 代码&#xff1a;由于intervals已经按照左端点排序&#xff0c;并且intervals中的区间全部不重叠&#xff0c;那么可以断定intervals中所有区间的右端点也已经是有序的。先二分查找intervals中第一个其右端点>newInterval左端点的区间。然后按照类似于56. Merge In…