PyTorch优化器

        PyTorch 提供了多种优化算法用于神经网络的参数优化。以下是对 PyTorch 中主要优化器的全面介绍,包括它们的原理、使用方法和适用场景。

一、基本优化器

1. SGD (随机梯度下降)

torch.optim.SGD(params, lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
  • 特点:

    • 最基本的优化器

    • 可以添加动量(momentum)加速收敛

    • 支持Nesterov动量

  • 参数:

    • lr: 学习率(必需)

    • momentum: 动量因子(0-1)

    • weight_decay: L2正则化系数

  • 适用场景: 大多数基础任务

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

2. Adam (自适应矩估计)

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
  • 特点:

    • 自适应学习率

    • 结合了动量法和RMSProp的优点

    • 通常需要较少调参

  • 参数:

    • betas: 用于计算梯度及其平方的移动平均系数

    • eps: 数值稳定项

    • amsgrad: 是否使用AMSGrad变体

  • 适用场景: 深度学习默认选择

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

二、自适应优化器

1. Adagrad

torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
  • 特点:

    • 自适应学习率

    • 为每个参数保留学习率

    • 适合稀疏数据

  • 缺点: 学习率会单调递减

2. RMSprop

torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
  • 特点:

    • 解决Adagrad学习率急剧下降问题

    • 适合非平稳目标

    • 常用于RNN

 3. Adadelta

torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
  • 特点:

    • 不需要设置初始学习率

    • 是Adagrad的扩展

三、其他优化器 

1. AdamW

torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
  • 特点:

    • Adam的改进版

    • 更正确的权重衰减实现

    • 通常优于Adam

2. SparseAdam

torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
  • 特点: 专为稀疏张量优化

3. LBFGS 

torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100)
  • 特点:

    • 准牛顿方法

    • 内存消耗大

    • 适合小批量数据

四、优化器选择指南

优化器适用场景优点缺点
SGD基础任务简单可控需要手动调整学习率
SGD+momentum大多数任务加速收敛需要调参
Adam深度学习默认自适应学习率可能不如SGD泛化好
AdamW带权重衰减的任务更正确的实现-
Adagrad稀疏数据自动调整学习率学习率单调减
RMSpropRNN/非平稳目标解决Adagrad问题-

五、学习率调度器

PyTorch还提供了学习率调度器,可与优化器配合使用:

# 创建优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 创建调度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)# 训练循环中
for epoch in range(100):train(...)validate(...)scheduler.step()  # 更新学习率

常用调度器:

  • LambdaLR: 自定义函数调整

  • MultiplicativeLR: 乘法更新

  • StepLR: 固定步长衰减

  • MultiStepLR: 多步长衰减

  • ExponentialLR: 指数衰减

  • CosineAnnealingLR: 余弦退火

  • ReduceLROnPlateau: 根据指标动态调整

六、优化器使用技巧

  1. 参数分组: 不同层使用不同学习率

    optimizer = torch.optim.SGD([{'params': model.base.parameters(), 'lr': 0.001},{'params': model.classifier.parameters(), 'lr': 0.01}
    ], momentum=0.9)
  2. 梯度裁剪: 防止梯度爆炸

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  3. 零梯度: 每次迭代前清空梯度

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/75445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++的UDP连接解析域名地址错误

背景 使用c开发一个udp连接功能的脚本,可以接收发送数据,而且地址是经过内网穿透到外网的 经过 通常发送数据给目标地址,需要把目的地址结构化,要么使用inet_addr解析ip地址,要么使用inet_pton sockaddr_in target…

Spark,上传文件

上传文件 1.上传 先使用命令打开HDFS的NameNode [roothadoop100 hadoop-3.1.3]$ sbin/start-dfs.sh [roothadoop100 hadoop-3.1.3]$ sbin/stop-dfs.sh 和YARN的Job [roothadoop101 hadoop-3.1.3]$ sbin/start-yarn.sh [roothadoop101 hadoop-3.1.3]$ sbin/stop-yarn.sh 在Nam…

如何为Linux/Android Kernel 5.4和5.15添加 fuse passthrough透传功能 ?

背景 参考:Google文档 FUSE 透传 参考此文档,目前kernel.org提供的fuse passthrough补丁在6.9版本之后,但想要在5.4和5.15版本内核做移植应该如何简单点呢?文档中提到 Android的内核为5.4 和 5.15版本内核做了fuse passthrough功…

Ubuntu 防火墙配置

Ubuntu 的防火墙配置可以参考文章:Firewall - Ubuntu Server documentation 22 端口 需要注意的是,在启动防火墙之前,需要先开放 22 端口。 否则 SSH 将会拒绝你连接防火墙。 开放 22 端口的命令为:sudo ufw allow 22 添加端…

Jetson 设备卸载 OpenCV 4.5.4 并编译安装 OpenCV 4.2.0

‌一、卸载 OpenCV 4.5.4‌ 清除已安装的 OpenCV 库‌ sudo apt-get purge libopencv* python3-opencv # 卸载所有APT安装的OpenCV包‌:ml-citation{ref"1,3" data"citationList"}sudo apt autoremove # 清理残留依赖‌:ml-citation{ref"1,4"…

《AI大模型应知应会100篇》第57篇:LlamaIndex使用指南:构建高效知识库

第57篇:LlamaIndex使用指南:构建高效知识库 摘要 在大语言模型(LLM)驱动的智能应用中,如何高效地管理和利用海量知识数据是开发者面临的核心挑战之一。LlamaIndex(原 GPT Index) 是一个专为构建…

Sentinel[超详细讲解]-4

🚓 主要讲解流控模式的 三种方式中的两种: 直接、链路🚀 1️⃣ 直接模式 🚎 直接模式:对资源本身进行限流,例如对某个接口进行限流,当该接口的访问频率超过设定的阈值时,直接拒绝新的…

工作记录 2017-03-24

工作记录 2017-03-24 序号 工作 相关人员 1 修改了邮件上的问题。 更新RD服务器。 郝 更新的问题 1、修改了New User时 init的保存。 2、文件的查询加了ID。 3、加了 patient insurance secondary 4、修改了payment detail的处理。 识别引擎监控 Ps (iCDA LOG :剔除…

裴蜀定理:整数解的奥秘

裴蜀定理:整数解的奥秘 在数学的世界里,裴蜀定理(Bzout’s Theorem)是数论中一个非常重要的定理,它揭示了二次方程和整数解之间的关系。它不仅仅是纯粹的理论知识,还在计算机科学、密码学、算法优化等多个…

python之 “__init__.py” 文件

提示:python之 “init.py” 文件 文章目录 前言一、Python 中 __init__.py 文件的理解1. What(是什么)2. Why(为什么需要)3. Where(在哪里使用)4. How(如何使用) 二、问题…

Gemini 2.5 Pro与Claude 3.7 Sonnet编程性能对比

AI领域的语言模型竞赛日趋白热化,尤其在编程辅助方面表现突出。 Gemini 2.5 Pro和Claude 3.7 Sonnet作为该领域的佼佼者,本文通过一系列编程测试与基准评估对两者的编码功能进行对比分析。 核心结论: • Gemini 2.5 Pro在SWE Bench硬核编程测试中以63.8%的通过率略胜Clau…

On Superresolution Effects in Maximum Likelihood Adaptive Antenna Arrays论文阅读

On Superresolution Effects in Maximum Likelihood Adaptive Antenna Arrays 1. 论文的研究目标与实际问题意义1.1 研究目标1.2 解决的实际问题1.3 实际意义2. 论文提出的新方法、模型与公式2.1 核心创新:标量化近似表达式关键推导步骤:公式优势:2.2 与经典方法的对比传统方…

GIT 撤销上次推送

注意:在执行下述操作之前先备份现有工作进度,如果不慎未保存,在代码编辑器中正在修改的文件下,使用CtrlZ 撤销试试 撤销推送的方法 情况 1:您刚刚推送到远程仓库 如果您的推送操作刚刚完成,并且没有其他…

透视飞鹤2024财报:如何打赢奶粉罐里的科技战?

去年乳制品行业压力还是不小的,尼尔森IQ指出2024年国内乳品市场仍处在收缩区间。但是,总有龙头能抗住压力,飞鹤最近交出的2024财报中就有很多亮点。 比如,2024年飞鹤营收207.5亿元、同比增长6%,净利润36.5亿元&#x…

解决STM32CubeMX中文注释乱码

本人采用【修改系统环境变量】的方法 1. 使用快捷键 win X,打开【系统R】,点击【高级系统设置】 2. 点击【环境变量】 3. 点击【新建】 4.按图中输入【JAVA_TOOL_OPTIONS】和【-Dfile.encodingUTF-8】,新建环境变量后重启CubeMX即可。 解释…

使用typescript实现游戏中的JPS跳点寻路算法

JPS是一种优化A*算法的路径规划算法,主要用于网格地图,通过跳过不必要的节点来提高搜索效率。它利用路径的对称性,只扩展特定的“跳点”,从而减少计算量。 deepseek生成的总是无法完整运行,因此决定手写一下。 需要注…

Jetpack Compose 状态管理指南:从基础到高级实践

在Jetpack Compose中,界面状态管理是构建响应式UI的核心。以下是Compose状态管理的主要概念和实现方式: 基本状态管理 1. 使用 mutableStateOf Composable fun Counter() {var count by remember { mutableStateOf(0) }Button(onClick { count }) {T…

vant4+vue3上传一个pdf文件并实现pdf的预览。使用插件pdf.js

注意下载的插件的版本"pdfjs-dist": "^2.2.228", npm i pdfjs-dist2.2.228 然后封装一个pdf的遮罩。因为pdf文件有多页,所以我用了swiper轮播的形式展示。因为用到移动端,手动滑动页面这样比点下一页下一页的方便多了。 直接贴代码…

Leetcode hot 100(day 4)

翻转二叉树 做法:递归即可,注意判断为空 class Solution { public:TreeNode* invertTree(TreeNode* root) {if(rootnullptr)return nullptr;TreeNode* noderoot->left;root->leftinvertTree(root->right);root->rightinvertTree(node);retu…

C,C++语言缓冲区溢出的产生和预防

缓冲区溢出的定义 缓冲区是内存中用于存储数据的一块连续区域,在 C 和 C 里,常使用数组、指针等方式来操作缓冲区。而缓冲区溢出指的是当程序向缓冲区写入的数据量超出了该缓冲区本身能够容纳的最大数据量时,额外的数据就会覆盖相邻的内存区…