[预备知识]6. 优化理论(二)

优化理论

本章节介绍深度学习中的高级优化技术,包括学习率衰减、梯度裁剪和批量归一化。这些技术能够显著提升模型的训练效果和稳定性。

学习率衰减(Learning Rate Decay)

数学原理与可视化

学习率衰减策略的数学表达:

  1. 步进式衰减
    α t = α 0 × γ ⌊ t / s ⌋ \alpha_t = \alpha_0 \times \gamma^{\lfloor t/s \rfloor} αt=α0×γt/s
    其中 s s s为衰减周期, γ \gamma γ为衰减因子

  2. 指数衰减
    α t = α 0 × e − γ t \alpha_t = \alpha_0 \times e^{-\gamma t} αt=α0×eγt

  3. 余弦衰减
    α t = α min + 1 2 ( α 0 − α min ) ( 1 + cos ⁡ ( t π T ) ) \alpha_t = \alpha_{\text{min}} + \frac{1}{2}(\alpha_0 - \alpha_{\text{min}})(1 + \cos(\frac{t\pi}{T})) αt=αmin+21(α0αmin)(1+cos(Ttπ))

import matplotlib.pyplot as plt# 衰减策略可视化
epochs = 100
initial_lr = 0.1# 计算各策略学习率
step_lrs = [initial_lr * (0.1 ** (i//30)) for i in range(epochs)]
expo_lrs = [initial_lr * (0.95 ** i) for i in range(epochs)]
cosine_lrs = [0.01 + 0.5*(0.1-0.01)*(1 + np.cos(np.pi*i/epochs)) for i in range(epochs)]# 绘制对比图
plt.figure(figsize=(12,6))
plt.plot(step_lrs, label='Step Decay (每30步×0.1)')
plt.plot(expo_lrs, label='Exponential Decay (γ=0.95)')
plt.plot(cosine_lrs, label='Cosine Decay (T=100)')
plt.title("不同学习率衰减策略对比", fontsize=14)
plt.xlabel("训练周期", fontsize=12)
plt.ylabel("学习率", fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)

在这里插入图片描述

最佳实践

# 组合使用多种调度器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# 前50步使用余弦衰减
scheduler1 = CosineAnnealingLR(optimizer, T_max=50)
# 之后使用步进衰减
scheduler2 = StepLR(optimizer, step_size=10, gamma=0.5)for epoch in range(100):train(...)if epoch < 50:scheduler1.step()else:scheduler2.step()

梯度裁剪(Gradient Clipping)

数学原理

梯度裁剪通过限制梯度范数防止参数更新过大:
if  ∥ g ∥ > c : g ← c ∥ g ∥ g \text{if } \|g\| > c: \quad g \gets \frac{c}{\|g\|}g if g>c:ggcg
其中 c c c为裁剪阈值, ∥ g ∥ \|g\| g为梯度范数

梯度动态可视化

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 初始化模型、优化器和损失函数
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()grad_norms = []
clipped_grad_norms = []for _ in range(1000):# 生成随机输入和目标inputs = torch.randn(32, 10)targets = torch.randn(32, 1)# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 反向传播optimizer.zero_grad()loss.backward()# 记录裁剪前梯度grad_norms.append(torch.norm(torch.cat([p.grad.view(-1) for p in model.parameters()])))# 执行裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)# 记录裁剪后梯度clipped_grad_norms.append(torch.norm(torch.cat([p.grad.view(-1) for p in model.parameters()])))# 更新参数optimizer.step()# 绘制梯度变化
plt.figure(figsize=(12, 6))
plt.plot(grad_norms, alpha=0.6, label='Original Gradient Norm')
plt.plot(clipped_grad_norms, alpha=0.6, label='Clipped Gradient Norm')
plt.axhline(1.0, color='r', linestyle='--', label='Clipping Threshold')
plt.yscale('log')
plt.title("Gradient Clipping Effect Monitoring", fontsize=14)
plt.xlabel("Training Steps", fontsize=12)
plt.ylabel("Gradient L2 Norm (log scale)", fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

在这里插入图片描述

实践技巧

  1. RNN中推荐值:LSTM/GRU 中 max_norm 取 1.0 或 2.0
  2. 结合学习率:较高学习率需配合较小裁剪阈值
  3. 监控策略:定期输出梯度统计量
print(f"梯度均值: {grad.mean().item():.3e} ± {grad.std().item():.3e}")

批量归一化(Batch Normalization)

数学推导

对于输入批次 B = { x 1 , . . . , x m } B = \{x_1,...,x_m\} B={x1,...,xm}

  1. 计算统计量:
    μ B = 1 m ∑ i = 1 m x i σ B 2 = 1 m ∑ i = 1 m ( x i − μ B ) 2 \mu_B = \frac{1}{m}\sum_{i=1}^m x_i \\ \sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2 μB=m1i=1mxiσB2=m1i=1m(xiμB)2
  2. 标准化:
    x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵ xiμB
  3. 仿射变换:
    y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β

训练/评估模式对比

import torch.nn as nn
import torch
# 创建BN层
bn = nn.BatchNorm1d(64)# 训练模式
bn.train()
for _ in range(100):x = torch.randn(32, 64)  # 批大小32y = bn(x)
print("训练模式统计:", bn.running_mean[:5].detach().numpy())  # 显示部分通道# 评估模式
bn.eval()
with torch.no_grad():x = torch.randn(32, 64)y = bn(x)
print("评估模式统计:", bn.running_mean[:5].detach().numpy())

可视化BN效果

# 生成模拟数据
data = torch.cat([torch.normal(2.0, 1.0, (100, 1)),torch.normal(-1.0, 0.5, (100, 1))
], dim=1)# 应用BN
bn = nn.BatchNorm1d(2)
output = bn(data)# 绘制分布对比
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
sns.histplot(data[:,0], kde=True, label='原始特征1')
sns.histplot(data[:,1], kde=True, label='原始特征2')
plt.title("Distribution of features before BN")plt.subplot(1,2,2)
sns.histplot(output[:,0], kde=True, label='BN后特征1')
sns.histplot(output[:,1], kde=True, label='BN后特征2') 
plt.title("Distribution of features after BN")plt.tight_layout()

在这里插入图片描述


技术组合应用案例

图像分类任务

# 自定义CNN模型
class CustomCNN(nn.Module):def __init__(self):super().__init__()# 卷积层 使用BNself.conv_layers = nn.Sequential(nn.Conv2d(3, 64, 3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(64, 128, 3),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2),)# 全连接层self.fc = nn.Linear(128*5*5, 10)def forward(self, x):x = self.conv_layers(x)return self.fc(x.view(x.size(0), -1))# 初始化模型、优化器和调度器
model = CustomCNN()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=200)# 带梯度裁剪的训练循环
max_grad_norm = 5.0  # 裁剪阈值
for epoch in range(200):model.train()  # 模型进入训练模式for inputs, targets in train_loader:  # 训练数据加载器outputs = model(inputs)  # 前向传播loss = F.cross_entropy(outputs, targets)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)optimizer.step()  # 参数更新scheduler.step()  # 学习率更新

关键技术总结

技术主要作用典型应用场景注意事项
学习率衰减精细收敛深层网络训练配合warmup效果更佳
梯度裁剪稳定训练RNN、Transformer阈值需随batch size调整
批量归一化加速收敛CNN、全连接网络小batch效果差

组合策略建议

  1. CNN架构:BN + 动量SGD + 余弦衰减
  2. RNN架构:梯度裁剪 + Adam + 步进衰减
  3. Transformer:预热 + 梯度裁剪 + AdamW
# Transformer优化示例
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.98))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda step: min(step**-0.5, step*(4000**-1.5))  # 预热
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/78137.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【计算机视觉】语义分割:Mask2Former:统一分割框架的技术突破与实战指南

深度解析Mask2Former&#xff1a;统一分割框架的技术突破与实战指南 技术架构与创新设计核心设计理念关键技术组件 环境配置与安装指南硬件要求安装步骤预训练模型下载 实战全流程解析1. 数据准备2. 配置文件定制3. 训练流程4. 推理与可视化 核心技术深度解析1. 掩膜注意力机制…

数字智慧方案5857丨智慧机场解决方案与应用(53页PPT)(文末有下载方式)

资料解读&#xff1a;智慧机场解决方案与应用 详细资料请看本解读文章的最后内容。 随着科技的飞速发展&#xff0c;智慧机场的建设已成为现代机场发展的重要方向。智慧机场不仅提升了旅客的出行体验&#xff0c;还极大地提高了机场的运营效率。本文将详细解读沃土数字平台在…

【C到Java的深度跃迁:从指针到对象,从过程到生态】第五模块·生态征服篇 —— 第二十章 项目实战:从C系统到Java架构的蜕变

一、跨语言重构&#xff1a;用Java重写Redis核心模块 1.1 Redis的C语言基因解析 Redis 6.0源码核心结构&#xff1a; // redis.h typedef struct redisObject { unsigned type:4; // 数据类型&#xff08;String/List等&#xff09; unsigned encoding:4; // …

ES6异步编程中Promise与Proxy对象

Promise 对象 Promise对象用于解决Javascript中的地狱回调问题&#xff0c;有效的减少了程序回调的嵌套调用。 创建 如果要创建一个Promise对象&#xff0c;最简单的方法就是直接new一个。但是&#xff0c;如果深入学习&#xff0c;会发现使用Promise下的静态方法Promise.re…

UE自动索敌插件Target System Component

https://www.fab.com/zh-cn/listings/9088334d-3bde-4e10-a937-baeb780f880f ​ 一个完全用 C 编写的 UE插件&#xff0c;添加了对简单相机锁定/瞄准系统的支持。它最初​​在蓝图中开发和测试&#xff0c;然后转换并重写为 C 模块和插件。 特征&#xff1a; 可通过一组可在…

中小企业MES系统概要设计

版本&#xff1a;V1.0 日期&#xff1a;2025年5月2日 一、系统架构设计 1.1 整体架构模式 采用分层微服务架构&#xff0c;实现模块解耦与灵活扩展&#xff0c;支持混合云部署&#xff1a; #mermaid-svg-drxS3XaKEg8H8rAJ {font-family:"trebuchet ms",verdana,ari…

STM32移植U8G2

STM32 移植 U8G2 u8g2 &#xff08;Universal 8bit Graphics Library version2 的缩写&#xff09;是用于嵌入式设备的单色图形库&#xff0c;可以在单色屏幕中绘制 GUI。u8g2 内部附带了例如 SSD13xx&#xff0c;ST7xx 等很多 OLED&#xff0c;LCD 驱动。内置多种不同大小和风…

Langchain,为何要名为langchian?

来听听 DeepSeek 怎么说 Human 2025-05-02T01:13:43.627Z langchain 是一个大语言模型开发框架。我的理解中&#xff0c;lang 是词根"语言"&#xff0c;chain是单词"链"&#xff0c;langchain 便是将语言模型和组件串联成链的框架。而 langchain 的图标是…

Windows下Python3脚本传到Linux下./example.py执行失败

1. 背景 大多数情况下通过pycharm编写Python代码&#xff0c;编写调试完&#xff0c;到Linux下发布执行。 以example.py脚本为例 #! /usr/bin/env python3 #! -*- encoding: utf-8 -*- def test(x,y): xint x yint y cxy return c if _name_"__main__": print(test(2…

当MCP撞进云宇宙:多芯片封装如何重构云计算的“芯“未来?

当MCP撞进云宇宙:多芯片封装如何重构云计算的"芯"未来? 2024年3月,AMD发布了震撼业界的MI300A/B芯片——这颗为AI计算而生的"超级芯片",首次在单封装内集成了13个计算芯片(包括3D V-Cache缓存、CDNA3 GPU和Zen4 CPU),用多芯片封装(Multi-Chip Pac…

用定时器做微妙延时注意事项

注意定时器来着APB1还是APB2&#xff0c;二者频率不一样&#xff0c;配置PSC要注意 &#xff08;1&#xff09;高级定时器timer1&#xff0c; timer8以及通用定时器timer9&#xff0c; timer10&#xff0c; timer11的时钟来源是APB2总线 &#xff08;2&#xff09;通用定时器ti…

三类思维坐标空间与时空序位信息处理架构

三类思维坐标空间与时空序位信息处理架构 一、静态信息元子与元组的数据结构设计 三维思维坐标空间定义 形象思维轴&#xff08;x&#xff09;&#xff1a;存储多媒体数据元子&#xff08;图像/音频/视频片段&#xff09; 元子结构&#xff1a;{ID, 数据块, 特征向量, 语义…

spring boot中@Validated

在 Spring Boot 中&#xff0c;Validated 是用于触发参数校验的注解&#xff0c;通常与 ​​JSR-303/JSR-380​​&#xff08;Bean Validation&#xff09;提供的校验注解一起使用。以下是常见的校验注解及其用法&#xff1a; ​1. 基本校验注解​​ 这些注解可以直接用于字段…

Hadoop 单机模式(Standalone Mode)部署与 WordCount 测试

通过本次实验&#xff0c;成功搭建了 Hadoop 单机环境并运行了基础 MapReduce 程序&#xff0c;为后续分布式计算学习奠定了基础。 掌握 Hadoop 单机模式的安装与配置方法。 熟悉 Hadoop 环境变量的配置及 Java 依赖管理。 使用 Hadoop 自带的 WordCount 示例程序进行简单的 …

历史数据分析——运输服务

运输服务板块简介: 运输服务板块主要是为货物与人员流动提供核心服务的企业的集合,涵盖铁路、公路、航空、海运、物流等细分领域。该板块具有强周期属性,与经济复苏、政策调控、供需关系密切关联,尤其是海运领域。有不少国内股市的铁路、公路等相关的上市公司同时属于红利…

openEuler 22.03 安装 Mysql 5.7,TAR离线安装

目录 一、检查系统是否安装其他版本Mariadb数据库二、环境检查2.1 必要环境检查2.2 在线安装&#xff08;有网络&#xff09;2.3 离线安装&#xff08;无网络&#xff09; 三、下载Mysql2.1 在线下载2.2 离线下载 四、安装Mysql五、配置Mysql六、开放防火墙端口七、数据备份八、…

喷泉码技术在现代物联网中的应用 设计

喷泉码技术在现代物联网中的应用 摘 要 喷泉码作为一种无速率编码技术,凭借其动态生成编码包的特性,在物联网通信中展现出独特的优势。其核心思想在于接收端只需接收到足够数量的任意编码包即可恢复原始数据,这种特性使其特别适用于动态信道和多用户场景。喷泉码的实现主要…

GZIPInputStream 类详解

GZIPInputStream 类详解 GZIPInputStream 是 Java 中用于解压缩 GZIP 格式数据的流类,属于 java.util.zip 包。它是 InflaterInputStream 的子类,专门处理 GZIP 压缩格式(.gz 文件)。 1. 核心功能 解压 GZIP 格式数据(RFC 1952 标准)自动处理 GZIP 头尾信息(校验和、时…

网络编程——TCP和UDP详细讲解

文章目录 TCP/UDP全面详解什么是TCP和UDP&#xff1f;TCP如何保证可靠性&#xff1f;1. 序列号&#xff08;Sequence Number&#xff09;2. 确认应答&#xff08;ACK&#xff09;3. 超时重传&#xff08;Timeout Retransmission&#xff09;4. 窗口控制&#xff08;Sliding Win…

性能测试工具篇

文章目录 目录1. JMeter介绍1.1 安装JMeter1.2 打开JMeter1.3 JMeter基础配置1.4 JMeter基本使用流程1.5 JMeter元件作用域和执行顺序 2. 重点组件2.1 线程组2.2 HTTP取样器2.3 查看结果树2.4 HTTP请求默认值2.5 JSON提取器2.6 用户定义的变量2.7 JSON断言2.8 同步定时器&#…