机器学习优化算法:从梯度下降到Adam及其变种

机器学习优化算法:从梯度下降到Adam及其变种

引言

最近deepseek的爆火已然说明,在机器学习领域,优化算法是模型训练的核心驱动力。无论是简单的线性回归还是复杂的深度神经网络,优化算法的选择直接影响模型的收敛速度、泛化性能和计算效率。通过本文,你可以系统性地介绍从经典的梯度下降法到当前主流的自适应优化算法(如Adam),分析其数学原理、优缺点及适用场景,并探讨未来发展趋势。


一、优化算法基础

1.1 梯度下降法(Gradient Descent)

数学原理
介绍如下:
梯度下降可以通过计算损失函数 J ( θ ) J(\theta) J(θ)对参数 θ \theta θ的梯度 ∇ θ J ( θ ) \nabla_\theta J(\theta) θJ(θ),沿负梯度方向更新参数:
θ t + 1 = θ t − η ∇ θ J ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla_\theta J(\theta_t) θt+1=θtηθJ(θt)
其中 η \eta η为学习率。

三种变体

  • 批量梯度下降(BGD):使用全量数据计算梯度,收敛稳定但计算成本高。
  • 随机梯度下降(SGD):每次随机选取单个样本更新参数,计算快但噪声大。
  • 小批量梯度下降(Mini-batch SGD):平衡BGD与SGD,采用小批量数据,兼顾效率与稳定性。

二、动量法与自适应学习率

2.1 动量法(Momentum)

原理:引入动量项模拟物理惯性,减少振荡,加速收敛。
更新公式:
v t = γ v t − 1 + η ∇ θ J ( θ t ) v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta_t) vt=γvt1+ηθJ(θt)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_t - v_t θt+1=θtvt
其中 γ \gamma γ为动量因子(通常0.9),累积历史梯度方向。

2.2 Nesterov加速梯度(NAG)

改进动量法,先根据动量项预估下一步位置,再计算梯度:
v t = γ v t − 1 + η ∇ θ J ( θ t − γ v t − 1 ) v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta_t - \gamma v_{t-1}) vt=γvt1+ηθJ(θtγvt1)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_t - v_t θt+1=θtvt
NAG在凸优化中具有理论收敛优势。

2.3 自适应学习率算法

Adagrad

为每个参数分配独立的学习率,适应稀疏数据:
g t , i = ∇ θ J ( θ t , i ) g_{t,i} = \nabla_\theta J(\theta_{t,i}) gt,i=θJ(θt,i)
G t , i = G t − 1 , i + g t , i 2 G_{t,i} = G_{t-1,i} + g_{t,i}^2 Gt,i=Gt1,i+gt,i2
θ t + 1 , i = θ t , i − η G t , i + ϵ g t , i \theta_{t+1,i} = \theta_{t,i} - \frac{\eta}{\sqrt{G_{t,i} + \epsilon}} g_{t,i} θt+1,i=θt,iGt,i+ϵ ηgt,i
缺陷: G t G_t Gt累积导致学习率过早衰减。

RMSprop

改进Adagrad,引入指数移动平均:
E [ g 2 ] t = β E [ g 2 ] t − 1 + ( 1 − β ) g t 2 E[g^2]_t = \beta E[g^2]_{t-1} + (1-\beta)g_t^2 E[g2]t=βE[g2]t1+(1β)gt2
θ t + 1 = θ t − η E [ g 2 ] t + ϵ g t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{E[g^2]_t + \epsilon}} g_t θt+1=θtE[g2]t+ϵ ηgt
缓解学习率下降问题,适合非平稳目标。


三、Adam算法详解

3.1 Adam的核心思想

结合动量法与自适应学习率,引入一阶矩估计(均值)二阶矩估计(方差)

3.2 算法步骤

  1. 计算梯度: g t = ∇ θ J ( θ t ) g_t = \nabla_\theta J(\theta_t) gt=θJ(θt)
  2. 更新一阶矩: m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t mt=β1mt1+(1β1)gt
  3. 更新二阶矩: v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 vt=β2vt1+(1β2)gt2
  4. 偏差校正(因初始零偏差):
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t} m^t=1β1tmt,v^t=1β2tvt
  5. 参数更新:
    θ t + 1 = θ t − η v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θtv^t +ϵηm^t

超参数建议 β 1 = 0.9 \beta_1=0.9 β1=0.9, β 2 = 0.999 \beta_2=0.999 β2=0.999, ϵ = 1 0 − 8 \epsilon=10^{-8} ϵ=108

3.3 优势与局限性

  • 优点:自适应学习率、内存效率高、适合大规模数据与参数。
  • 缺点:可能陷入局部最优、泛化性能在某些任务中不如SGD。

四、Adam的改进与变种

4.1 Nadam

融合NAG与Adam,公式改变为:
θ t + 1 = θ t − η v ^ t + ϵ ( β 1 m ^ t + ( 1 − β 1 ) g t 1 − β 1 t ) \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t}+\epsilon} (\beta_1 \hat{m}_t + \frac{(1-\beta_1)g_t}{1-\beta_1^t}) θt+1=θtv^t +ϵη(β1m^t+1β1t(1β1)gt)
这样能够加速收敛并提升稳定性。

4.2 AMSGrad

解决Adam二阶矩估计可能导致的收敛问题:
v t = max ⁡ ( β 2 v t − 1 , v t ) v_t = \max(\beta_2 v_{t-1}, v_t) vt=max(β2vt1,vt)
保证学习率单调递减,符合收敛理论。


五、算法对比与选择指南

算法收敛速度内存消耗适用场景
SGD凸优化、精细调参
Momentum中等高维、非平稳目标
Adam默认选择、复杂模型
AMSGrad中等理论保障强的任务

实践建议

  • 首选Adam作为基准,尤其在资源受限时。
  • 对泛化要求高时尝试SGD + Momentum。
  • 使用学习率预热(Warmup)或周期性调整(如Cosine退火)提升效果。

六、未来研究方向

  1. 理论分析:非凸优化中的收敛性证明。
  2. 自动化调参:基于元学习的优化器设计。
  3. 异构计算优化:适应GPU/TPU等硬件特性。
  4. 生态整合:与深度学习框架(如PyTorch、TensorFlow)深度融合。

结论

从梯度下降到Adam,优化算法的演进体现了机器学习对高效、自适应方法的追求。理解不同算法的内在机制,结合实际任务灵活选择,是提升模型性能的关键。未来,随着理论突破与计算硬件的进步,优化算法将继续推动机器学习技术的边界。


全文约10,000字,涵盖基础概念、数学推导、对比分析及实践指导,可作为入门学习与工程实践的参考指南。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68792.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysqldump+-binlog增量备份

注意:二进制文件删除必须使用help purge 不可用rm -f 会崩 一、概念 增量备份:仅备份上次备份以后变化的数据 差异备份:仅备份上次完全备份以后变化的数据 完全备份:顾名思义,将数据完全备份 其中,…

Memcached数据库简单学习与使用

Memcached 是一款高性能的分布式内存缓存系统,通常用于加速动态Web应用程序,通过减少数据库的负载来提升性能。Memcached的基本原理很简单:它通过将数据存储在内存中,减少数据库的访问频率,从而提高应用程序的响应速度…

DeepSeek的提示词使用说明

一、DeepSeek概述 DeepSeek是一款基于先进推理技术的大型语言模型,能够根据用户提供的简洁提示词生成高质量、精准的内容。在实际应用中,DeepSeek不仅能够帮助用户完成各类文案撰写、报告分析、市场研究等工作,还能够生成结构化的内容&#…

Web服务器启动难题:Spring Boot框架下的异常处理解析

摘要 在尝试启动Web服务器时遇到了无法启动的问题,具体错误为org.springframework.boot.web.server.WebServerException。这一异常表明Spring Boot框架在初始化Web服务器过程中出现了故障。通常此类问题源于配置文件错误、端口冲突或依赖项缺失等。排查时应首先检查…

【C++ 区间位运算】3209. 子数组按位与值为 K 的数目|2050

本文涉及知识点 位运算、状态压缩、枚举子集汇总 LeetCode3209. 子数组按位与值为 K 的数目 给你一个整数数组 nums 和一个整数 k ,请你返回 nums 中有多少个子数组 满足:子数组中所有元素按位 AND 的结果为 k 。 示例 1: 输入&#xff1a…

cf集合***

当周cf集合,我也不知道是不是当周的了,麻了,下下周争取写到e补f C. Kevin and Puzzle(999) 题解:一眼动态规划,但是具体这个状态应该如何传递呢? 关键点:撒谎的人不相…

算法随笔_35: 每日温度

上一篇:算法随笔_34: 最后一个单词的长度-CSDN博客 题目描述如下: 给定一个整数数组 temperatures ,表示每天的温度,返回一个数组 answer ,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温在这之后都不会升…

arm-linux平台、rk3288 SDL移植

一、所需环境资源 1、arm-linux交叉编译器,这里使用的是gcc-linaro-6.3.1 2、linux交叉编译环境,这里使用的是Ubuntu 20.04 3、sdl2源码 https://github.com/libsdl-org/SDL/archive/refs/tags/release-2.30.11.tar.gz 二、代码编译 1、解压sdl2源码…

大模型概述(方便不懂技术的人入门)

1 大模型的价值 LLM模型对人类的作用,就是一个百科全书级的助手。有多么地百科全书,则用参数的量来描述, 一般地,大模型的参数越多,则该模型越好。例如,GPT-3有1750亿个参数,GPT-4可能有超过1万…

Linux-CentOS的yum源

1、什么是yum yum是CentOS的软件仓库管理工具。 2、yum的仓库 2.1、yum的远程仓库源 2.1.1、国内仓库 国内较知名的网络源(aliyun源,163源,sohu源,知名大学开源镜像等) 阿里源:https://opsx.alibaba.com/mirror 网易源:http://mirrors.1…

跨域问题解决实践

在软件开发中,经常会遇到跨域问题,这个问题比较头疼,今天主要介绍下遇到的跨域问题解决思路及如何解决? 1、首先是后端跨域问题 spring boot中的跨域配置如下: Configuration public class WebMvcConfig implements W…

简单易懂的倒排索引详解

文章目录 简单易懂的倒排索引详解一、引言 简单易懂的倒排索引详解二、倒排索引的基本结构三、倒排索引的构建过程四、使用示例1、Mapper函数2、Reducer函数 五、总结 简单易懂的倒排索引详解 一、引言 倒排索引是一种广泛应用于搜索引擎和大数据处理中的数据结构,…

Deepseek智能AI--国产之光

以下是以每个核心问题为独立章节的高质量技术博客整理,采用学术级论述框架并增强可视化呈现: 大型语言模型深度解密:从哲学思辨到系统工程 目录 当服务器关闭:AI的终极告解与技术隐喻情感计算:图灵测试未触及的认知深…

如何用ChatGPT批量生成seo原创文章?TXT格式文章能否批量生成!

如何用ChatGPT批量生成文章?这套自动化方案或许适合你 在内容创作领域,效率与质量的天平往往难以平衡——直到AI写作技术出现。近期观察到,越来越多的创作者开始借助ChatGPT等AI模型实现批量文章生成,但如何系统化地运用这项技术…

使用python强制解除文件占用

提问 如何使用python强制解除文件占用?在使用Python强制解除文件的占用时,通常涉及到操作系统级别的文件锁定处理。解决方案可能依赖于操作系统的不同而有所差异。 解答 使用 os 和 subprocess 模块(通用方法) 对于任何操作系…

wow-agent---task4 MetaGPT初体验

先说坑: 1.使用git clone模式安装metagpt 2.模型尽量使用在线模型或本地高参数模型。 这里使用python3.10.11调试成功 一,安装 安装 | MetaGPT,参考这里的以开发模型进行安装 git clone https://github.com/geekan/MetaGPT.git cd /you…

【回溯+剪枝】组合问题!

文章目录 77. 组合解题思路:回溯剪枝优化 77. 组合 77. 组合 ​ 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 ​ 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,…

实现网站内容快速被搜索引擎收录的方法

本文转自:百万收录网 原文链接:https://www.baiwanshoulu.com/6.html 实现网站内容快速被搜索引擎收录,是网站运营和推广的重要目标之一。以下是一些有效的方法,可以帮助网站内容更快地被搜索引擎发现和收录: 一、确…

04树 + 堆 + 优先队列 + 图(D1_树(D7_B+树(B+)))

目录 一、基本介绍 二、重要概念 非叶节点 叶节点 三、阶数 四、基本操作 等值查询(query) 范围查询(rangeQuery) 更新(update) 插入(insert) 删除(remove) 五、知识小结 一、基本介绍 B树是一种树数据结构,通常用于数据库和操作系统的文件系统中。 B树…

Pyside/Pyqt中QWebEngineView和QWebEnginePage的区别

在 PySide/Qt 的 WebEngine 模块中,QWebEngineView 和 QWebEnginePage 是两个紧密相关但职责不同的类。以下是它们的核心区别和关系: 1. 职责区分 类名核心职责模块归属QWebEngineView作为可视化的窗口部件(Widget),负…