PyTorch 提供了多种优化算法用于神经网络的参数优化。以下是对 PyTorch 中主要优化器的全面介绍,包括它们的原理、使用方法和适用场景。
一、基本优化器
1. SGD (随机梯度下降)
torch.optim.SGD(params, lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
-
特点:
-
最基本的优化器
-
可以添加动量(momentum)加速收敛
-
支持Nesterov动量
-
-
参数:
-
lr
: 学习率(必需) -
momentum
: 动量因子(0-1) -
weight_decay
: L2正则化系数
-
-
适用场景: 大多数基础任务
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
2. Adam (自适应矩估计)
torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
-
特点:
-
自适应学习率
-
结合了动量法和RMSProp的优点
-
通常需要较少调参
-
-
参数:
-
betas
: 用于计算梯度及其平方的移动平均系数 -
eps
: 数值稳定项 -
amsgrad
: 是否使用AMSGrad变体
-
-
适用场景: 深度学习默认选择
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
二、自适应优化器
1. Adagrad
torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
-
特点:
-
自适应学习率
-
为每个参数保留学习率
-
适合稀疏数据
-
-
缺点: 学习率会单调递减
2. RMSprop
torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
-
特点:
-
解决Adagrad学习率急剧下降问题
-
适合非平稳目标
-
常用于RNN
-
3. Adadelta
torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
-
特点:
-
不需要设置初始学习率
-
是Adagrad的扩展
-
三、其他优化器
1. AdamW
torch.optim.AdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
-
特点:
-
Adam的改进版
-
更正确的权重衰减实现
-
通常优于Adam
-
2. SparseAdam
torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
-
特点: 专为稀疏张量优化
3. LBFGS
torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100)
-
特点:
-
准牛顿方法
-
内存消耗大
-
适合小批量数据
-
四、优化器选择指南
优化器 | 适用场景 | 优点 | 缺点 |
---|---|---|---|
SGD | 基础任务 | 简单可控 | 需要手动调整学习率 |
SGD+momentum | 大多数任务 | 加速收敛 | 需要调参 |
Adam | 深度学习默认 | 自适应学习率 | 可能不如SGD泛化好 |
AdamW | 带权重衰减的任务 | 更正确的实现 | - |
Adagrad | 稀疏数据 | 自动调整学习率 | 学习率单调减 |
RMSprop | RNN/非平稳目标 | 解决Adagrad问题 | - |
五、学习率调度器
PyTorch还提供了学习率调度器,可与优化器配合使用:
# 创建优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 创建调度器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)# 训练循环中
for epoch in range(100):train(...)validate(...)scheduler.step() # 更新学习率
常用调度器:
-
LambdaLR
: 自定义函数调整 -
MultiplicativeLR
: 乘法更新 -
StepLR
: 固定步长衰减 -
MultiStepLR
: 多步长衰减 -
ExponentialLR
: 指数衰减 -
CosineAnnealingLR
: 余弦退火 -
ReduceLROnPlateau
: 根据指标动态调整
六、优化器使用技巧
-
参数分组: 不同层使用不同学习率
optimizer = torch.optim.SGD([{'params': model.base.parameters(), 'lr': 0.001},{'params': model.classifier.parameters(), 'lr': 0.01} ], momentum=0.9)
-
梯度裁剪: 防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
-
零梯度: 每次迭代前清空梯度
optimizer.zero_grad() loss.backward() optimizer.step()