梯度下降优化算法-Adam

Adam(Adaptive Moment Estimation)是一种结合了动量法(Momentum)和 RMSProp 的自适应学习率优化算法。它通过计算梯度的一阶矩(均值)和二阶矩(未中心化的方差)来调整每个参数的学习率,从而在深度学习中表现出色。


1. Adam 的数学原理

1.1 动量法和 RMSProp 的回顾

  • 动量法:通过引入动量变量,加速梯度下降并减少震荡。
  • RMSProp:通过指数加权移动平均计算历史梯度平方和,自适应调整学习率。

Adam 结合了这两种方法的优点,同时计算梯度的一阶矩和二阶矩。


1.2 Adam 的更新规则

Adam 的更新规则分为以下几个步骤:

1.2.1 梯度计算

首先,计算当前时刻的梯度:

g t = ∇ θ J ( θ t ) g_t = \nabla_\theta J(\theta_t) gt=θJ(θt)

其中:

  • g t g_t gt 是当前时刻的梯度向量,形状与参数 θ t \theta_t θt 相同。

1.2.2 一阶矩估计(动量)

Adam 使用指数加权移动平均来计算梯度的一阶矩(均值):

m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t mt=β1mt1+(1β1)gt

其中:

  • m t m_t mt 是梯度的一阶矩估计。
  • β 1 \beta_1 β1 是一阶矩的衰减率,通常取值在 [ 0.9 , 0.99 ) [0.9, 0.99) [0.9,0.99) 之间。
  • 初始时, m 0 m_0 m0 通常设置为 0。

1.2.3 二阶矩估计(RMSProp)

Adam 使用指数加权移动平均来计算梯度的二阶矩(未中心化的方差):

v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 vt=β2vt1+(1β2)gt2

其中:

  • v t v_t vt 是梯度的二阶矩估计。
  • β 2 \beta_2 β2 是二阶矩的衰减率,通常取值在 [ 0.99 , 0.999 ) [0.99, 0.999) [0.99,0.999) 之间。
  • g t 2 g_t^2 gt2 表示对梯度向量 g t g_t gt 逐元素平方。
  • 初始时, v 0 v_0 v0 通常设置为 0。

1.2.4 偏差校正

由于 m t m_t mt v t v_t vt 初始值为 0,在训练初期会偏向 0,因此需要进行偏差校正:

m ^ t = m t 1 − β 1 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t} m^t=1β1tmt

v ^ t = v t 1 − β 2 t \hat{v}_t = \frac{v_t}{1 - \beta_2^t} v^t=1β2tvt

其中:

  • m ^ t \hat{m}_t m^t 是校正后的一阶矩估计。
  • v ^ t \hat{v}_t v^t 是校正后的二阶矩估计。
  • t t t 是当前时间步。

1.2.5 参数更新

最后,Adam 的参数更新公式为:

θ t + 1 = θ t − η v ^ t + ϵ ⋅ m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t θt+1=θtv^t +ϵηm^t

其中:

  • η \eta η 是全局学习率。
  • ϵ \epsilon ϵ 是一个很小的常数(通常为 1 0 − 8 10^{-8} 108),用于避免分母为零。
  • v ^ t + ϵ \sqrt{\hat{v}_t} + \epsilon v^t +ϵ 是对校正后的二阶矩估计逐元素开平方。

2. Adam 的详细推导

2.1 一阶矩和二阶矩的意义

  • 一阶矩 m t m_t mt:类似于动量法,表示梯度的指数加权移动平均,用于加速收敛。
  • 二阶矩 v t v_t vt:类似于 RMSProp,表示梯度平方的指数加权移动平均,用于自适应调整学习率。

2.2 偏差校正的作用

偏差校正的目的是解决初始阶段 m t m_t mt v t v_t vt 偏向 0 的问题。通过除以 1 − β 1 t 1 - \beta_1^t 1β1t 1 − β 2 t 1 - \beta_2^t 1β2t,可以校正估计值,使其更接近真实值。


2.3 小常数 ϵ \epsilon ϵ 的作用

小常数 ϵ \epsilon ϵ 的作用是避免分母为零。具体来说:

  • v ^ t \hat{v}_t v^t 很小时, v ^ t + ϵ \sqrt{\hat{v}_t} + \epsilon v^t +ϵ 接近于 ϵ \epsilon ϵ,避免学习率过大。
  • v ^ t \hat{v}_t v^t 很大时, ϵ \epsilon ϵ 的影响可以忽略不计。

3. PyTorch 中的 Adam 实现

在 PyTorch 中,Adam 通过 torch.optim.Adam 实现。以下是 torch.optim.Adam 的主要参数:

参数名含义
params需要优化的参数(通常是模型的参数)。
lr全局学习率(learning rate),即 η \eta η,默认值为 1 0 − 3 10^{-3} 103
betas一阶矩和二阶矩的衰减率,即 ( β 1 , β 2 ) (\beta_1, \beta_2) (β1,β2),默认值为 (0.9, 0.999)。
eps分母中的小常数 ϵ \epsilon ϵ,用于避免除零,默认值为 1 0 − 8 10^{-8} 108
weight_decay权重衰减(L2 正则化)系数,默认值为 0。
amsgrad是否使用 AMSGrad 变体,默认值为 False

3.1 使用 Adam 的代码示例

以下是一个使用 Adam 的完整代码示例:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的线性模型
model = nn.Linear(10, 1)# 定义损失函数
criterion = nn.MSELoss()# 定义优化器,使用 Adam
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)# 模拟输入数据和目标数据
inputs = torch.randn(32, 10)  # 32 个样本,每个样本 10 维
targets = torch.randn(32, 1)  # 32 个目标值# 训练过程
for epoch in range(100):# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 反向传播optimizer.zero_grad()  # 清空梯度loss.backward()        # 计算梯度# 更新参数optimizer.step()       # 更新参数# 打印损失if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch+1}/100], Loss: {loss.item():.4f}")

3.2 参数设置说明

  1. 学习率 (lr)

    • 学习率 η \eta η 控制每次参数更新的步长。
    • 在 Adam 中,学习率会自适应调整,因此初始学习率可以设置得稍小一些。
  2. 衰减率 (betas)

    • 一阶矩衰减率 β 1 \beta_1 β1 和二阶矩衰减率 β 2 \beta_2 β2 分别控制一阶矩和二阶矩的衰减速度。
    • 默认值为 (0.9, 0.999),适用于大多数情况。
  3. 小常数 (eps)

    • 小常数 ϵ \epsilon ϵ 用于避免分母为零,通常设置为 1 0 − 8 10^{-8} 108
  4. 权重衰减 (weight_decay)

    • 权重衰减系数用于 L2 正则化,防止过拟合。
  5. AMSGrad (amsgrad)

    • 如果设置为 True,则使用 AMSGrad 变体,解决 Adam 在某些情况下的收敛问题。

4. 总结

  • Adam 的核心思想:结合动量法和 RMSProp,通过计算梯度的一阶矩和二阶矩,自适应调整学习率。
  • Adam 的更新公式
    m t = β 1 ⋅ m t − 1 + ( 1 − β 1 ) ⋅ g t m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t mt=β1mt1+(1β1)gt
    v t = β 2 ⋅ v t − 1 + ( 1 − β 2 ) ⋅ g t 2 v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 vt=β2vt1+(1β2)gt2
    m ^ t = m t 1 − β 1 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t} m^t=1β1tmt
    v ^ t = v t 1 − β 2 t \hat{v}_t = \frac{v_t}{1 - \beta_2^t} v^t=1β2tvt
    θ t + 1 = θ t − η v ^ t + ϵ ⋅ m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t θt+1=θtv^t +ϵηm^t
  • PyTorch 实现:使用 torch.optim.Adam,设置 lrbetaseps 等参数。
  • 优缺点
    • 优点:自适应学习率,适合非凸优化问题,收敛速度快。
    • 缺点:需要手动调整超参数(如 β 1 \beta_1 β1 β 2 \beta_2 β2)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67212.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文笔记(六十三)Understanding Diffusion Models: A Unified Perspective(六)(完结)

Understanding Diffusion Models: A Unified Perspective(六)(完结) 文章概括指导(Guidance)分类器指导无分类器引导(Classifier-Free Guidance) 总结 文章概括 引用: …

【PySide6快速入门】信号与槽的使用

文章目录 前言什么是信号与槽信号与槽的功能最简单的信号与槽控件连接信号与信号的连接总结 前言 在 PySide6 中,信号与槽机制是核心概念之一,它是 Qt 库中事件通信的基础。通过信号与槽,开发者能够实现不同组件之间的解耦,从而使…

GOGOGO 枚举

含义:一种类似于类的一种结构 作用:是Java提供的一个数据类型,可以设置值是固定的 【当某一个数据类型受自身限制的时候,使用枚举】 语法格式: public enum 枚举名{…… }有哪些成员? A、对象 public …

AWTK 骨骼动画控件发布

Spine 是一款广泛使用的 2D 骨骼动画工具,专为游戏开发和动态图形设计设计。它通过基于骨骼的动画系统,帮助开发者创建流畅、高效的角色动画。本项目是基于 Spine 实现的 AWTK 骨骼动画控件。 代码:https://gitee.com/zlgopen/awtk-widget-s…

[免费]基于Python的Django博客系统【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的基于Python的Django博客系统,分享下哈。 项目视频演示 【免费】基于Python的Django博客系统 Python毕业设计_哔哩哔哩_bilibili 项目介绍 随着互联网技术的飞速发展,信息的传播与…

如何将电脑桌面默认的C盘设置到D盘?详细操作步骤!

将电脑桌面默认的C盘设置到D盘的详细操作步骤! 本博文介绍如何将电脑桌面(默认为C盘)设置在D盘下。 首先,在D盘建立文件夹Desktop,完整的路径为D:\Desktop。winR,输入Regedit命令。(或者单击【…

C++ 写一个简单的加减法计算器

************* C topic:结构 ************* Structure is a very intersting issue. I really dont like concepts as it is boring. I would like to cases instead. If I want to learn something, donot hesitate to make shits. Like building a house. Wh…

excel如何查找一个表的数据在另外一个表是否存在

比如“Sheet1”有“张三”、“李四”“王五”三个人的数据,“Sheet2”只有“张三”、“李四”的数据。我们通过修改“Sheet1”的“民族”或者其他空的列,修改为“Sheet2”的某一列。这样修改后筛选这个修改的列为空的或者为出错的,就能找到两…

MySQL 基础学习(2): INSERT 操作

在这篇文章中,我们将专注于 MySQL 中的 INSERT 操作,深入了解如何高效地向表中插入数据,并探索插入操作中的一些常见错误与解决方案。 一、基础 INSERT 语法 在 MySQL 中,INSERT 操作用于向表中插入新记录,基本语法如…

CVE-2023-38831 漏洞复现:win10 压缩包挂马攻击剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 影响范围 防御措施 复现步骤 环境准备 具体操作 前言 在网络安全这片没有硝烟的战场上,新型漏洞如同隐匿的暗箭,时刻威胁着我们的数字生活。其中,CVE - 2023 - 38831 这个关联 Win10 压缩包挂…

论文阅读(二):理解概率图模型的两个要点:关于推理和学习的知识

1.论文链接:Essentials to Understand Probabilistic Graphical Models: A Tutorial about Inference and Learning 摘要: 本章的目的是为没有概率图形模型背景或没有深入背景的科学家提供一个高级教程。对于更熟悉这些模型的读者,本章将作为…

记录 | 基于Docker Desktop的MaxKB安装

目录 前言一、MaxKBStep 1Step2 二、运行MaxKB更新时间 前言 参考文章:如何利用智谱全模态免费模型,生成大家都喜欢的图、文、视并茂的文章! MaxKB的Github下载地址 参考视频:【2025最新MaxKB教程】10分钟学会一键部署本地私人专属…

Go反射指南

概念: 官方对此有个非常简明的介绍,两句话耐人寻味: 反射提供一种让程序检查自身结构的能力反射是困惑的源泉 第1条,再精确点的描述是“反射是一种检查interface变量的底层类型和值的机制”。 第2条,很有喜感的自嘲…

第26篇 基于ARM A9处理器用C语言实现中断<二>

Q:基于ARM A9处理器怎样编写C语言工程,使用按键中断将数字显示在七段数码管上呢? A:基本原理:主程序需要首先调用子程序set_A9_IRQ_stack()初始化IRQ模式的ARM A9堆栈指针;然后主程序调用子程序config_GIC…

基于GS(Gaussian Splatting)的机器人Sim2Real2Sim仿真平台

项目地址:RoboGSim 背景简介 已有的数据采集方法中,遥操作(下左)是数据质量高,但采集成本高、效率低下;传统仿真流程成本低(下右),但真实度(如纹理、物理&…

「 机器人 」利用冲程对称性调节实现仿生飞行器姿态与方向控制

前言 在仿生扑翼飞行器中,通过改变冲程对称性这一技术手段,可以在上冲与下冲两个阶段引入不对称性,进而产生额外的力或力矩,用于实现俯仰或其他姿态方向的控制。以下从原理、在仿生飞行器中的应用和典型实验示例等方面进行梳理与阐述。 1. 冲程对称性原理 1.1 概念:上冲与…

MongoDB部署模式

目录 单节点模式(Standalone) 副本集模式(Replica Set) 分片集群模式(Sharded Cluster) MongoDB有多种部署模式,可以根据业务需求选择适合的架构和部署方式。 单节点模式(Standa…

微服务搭建----springboot接入Nacos2.x

springboot接入Nacos2.x nacos之前用的版本是1.0的,现在重新搭建一个2.0版本的,学如逆水行舟,不进则退,废话不多说,开搞 1、 nacos2.x搭建 1,首先第一步查询下项目之间的版本对照,不然后期会…

react-native网络调试工具Reactotron保姆级教程

在React Native开发过程中,调试和性能优化是至关重要的环节。今天,就来给大家分享一个非常强大的工具——Reactotron,它就像是一个贴心的助手,能帮助我们更轻松地追踪问题、优化性能。下面就是一份保姆级教程哦! 一、…

npm启动前端项目时报错(vue) error:0308010C:digital envelope routines::unsupported

vue 启动项目时,npm run serve 报下面的错: error:0308010C:digital envelope routines::unsupported at new Hash (node:internal/crypto/hash:67:19) at Object.createHash (node:crypto:133:10) at FSReqCallback.readFileAfterClose [as on…