37.深度学习中的梯度下降法及其实现

在深度学习的优化过程中,梯度下降法及其变体是必不可少的工具。通过对梯度下降法的理论学习,我们能够更好地理解深度学习模型的训练过程。本篇文章将介绍梯度下降的基本原理,并通过代码实现展示其具体应用。我们会从二维平面的简单梯度下降开始,逐步过渡到三维,再对比多种优化器的效果。

一、梯度下降法简介

梯度下降法(Gradient Descent)是一种常用的优化算法,广泛应用于机器学习和深度学习中。其基本思想是通过迭代更新参数,使得损失函数逐步减小,最终找到最优解。常见的梯度下降法包括随机梯度下降(SGD)、动量法(Momentum)、自适应学习率方法(Adagrad、RMSprop、Adadelta)和Adam等。

二、梯度下降的二维实现

首先,我们来实现一个简单的二维平面内的梯度下降法。目标是找到函数 \(f(x) = x^2 + 4x + 1\) 的最小值。

import torch
import matplotlib.pyplot as plt# 定义目标函数
def f(x):return x**2 + 4*x + 1# 初始化参数
x = torch.tensor([2.0], requires_grad=True)
learning_rate = 0.7# 记录每次梯度下降的值
xs, ys = [], []# 梯度下降迭代
for i in range(100):y = f(x)y.backward()with torch.no_grad():x -= learning_rate * x.gradx.grad.zero_()xs.append(x.item())ys.append(y.item())# 打印最终结果
print(f"最终x值: {x.item()}")# 可视化
x_vals = torch.linspace(-4, 2, 100)
y_vals = f(x_vals)
plt.plot(x_vals.numpy(), y_vals.numpy(), label='f(x)=x^2 + 4x + 1')
plt.scatter(xs, ys, color='red', label='Gradient Descent')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.show()

通过以上代码,我们能够看到从初始点出发,梯度下降法逐步逼近最小值。

三、梯度下降的三维实现

增加一个维度,函数变为 \(f(x, y) = x^2 + y^2\),我们希望通过梯度下降法找到该函数的最小值。

from mpl_toolkits.mplot3d import Axes3D# 定义目标函数
def f(x, y):return x**2 + y**2# 初始化参数
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([2.0], requires_grad=True)
learning_rate = 0.1# 记录每次梯度下降的值
xs, ys, zs = [], [], []# 梯度下降迭代
for i in range(100):z = f(x, y)z.backward()with torch.no_grad():x -= learning_rate * x.grady -= learning_rate * y.gradx.grad.zero_()y.grad.zero_()xs.append(x.item())ys.append(y.item())zs.append(z.item())# 打印最终结果
print(f"最终x, y值: {x.item()}, {y.item()}")# 可视化
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(xs, ys, zs, label='Gradient Descent Path', marker='o')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.legend()
plt.show()# 等高线图
plt.figure()
X, Y = torch.meshgrid(torch.linspace(-3, 3, 100), torch.linspace(-3, 3, 100))
Z = f(X, Y)
plt.contourf(X.numpy(), Y.numpy(), Z.numpy(), 50)
plt.plot(xs, ys, 'r-o', label='Gradient Descent Path')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()

通过以上代码,我们能够在三维空间中看到梯度下降的路径。

四、不同优化器的对比

接下来,我们生成一个数据集,使用不同的优化器进行对比,观察它们的收敛效果。

 

import torch.utils.data as data# 生成数据集
def generate_data(num_samples=1000):x = torch.rand(num_samples, 1)y = torch.rand(num_samples, 1)z = f(x, y) + torch.randn(num_samples, 1)return x, y, zx, y, z = generate_data()
dataset = data.TensorDataset(torch.cat([x, y], dim=1), z)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = data.random_split(dataset, [train_size, test_size])
train_loader = data.DataLoader(train_dataset, batch_size=32)
test_loader = data.DataLoader(test_dataset, batch_size=32)# 定义模型
class SimpleNN(torch.nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc = torch.nn.Linear(2, 1)def forward(self, x):return self.fc(x)# 初始化模型和优化器
models = [SimpleNN() for _ in range(6)]
optimizers = [torch.optim.SGD(models[0].parameters(), lr=0.01),torch.optim.SGD(models[1].parameters(), lr=0.01, momentum=0.9),torch.optim.Adagrad(models[2].parameters(), lr=0.01),torch.optim.RMSprop(models[3].parameters(), lr=0.01),torch.optim.Adadelta(models[4].parameters()),torch.optim.Adam(models[5].parameters(), lr=0.01)
]
loss_fn = torch.nn.MSELoss()# 训练和测试函数
def train_epoch(model, optimizer, loader):model.train()total_loss = 0for x_batch, y_batch in loader:optimizer.zero_grad()y_pred = model(x_batch)loss = loss_fn(y_pred, y_batch)loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(loader)def test_epoch(model, loader):model.eval()total_loss = 0with torch.no_grad():for x_batch, y_batch in loader:y_pred = model(x_batch)loss = loss_fn(y_pred, y_batch)total_loss += loss.item()return total_loss / len(loader)# 记录误差
train_losses = [[] for _ in range(6)]
test_losses = [[] for _ in range(6)]# 训练和测试过程
num_epochs = 50
for epoch in range(num_epochs):for i in range(6):train_loss = train_epoch(models[i], optimizers[i], train_loader)test_loss = test_epoch(models[i], test_loader)train_losses[i].append(train_loss)test_losses[i].append(test_loss)# 可视化收敛曲线
plt.figure(figsize=(12, 6))
for i, name in enumerate(['SGD', 'Momentum', 'Adagrad', 'RMSprop', 'Adadelta', 'Adam']):plt.plot(train_losses[i], label=f'Train {name}')plt.plot(test_losses[i], '--', label=f'Test {name}')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

通过上述代码,我们能够对比不同优化器的收敛效果,从图中可以看到各个优化器的表现差异。

五、总结

本文通过代码实现详细展示了梯度下降法在二维和三维空间中的应用,并对比了多种优化器的效果。通过这些实践,我们能够更直观地理解梯度下降法的工作原理及其在深度学习中的应用。希望大家通过本篇文章,能够更加熟练地应用梯度下降及其变体进行模型训练。加油!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870010.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用nodejs进行截图

创建一个空文件夹 初始化项目 npm init -y下载插件 yarn add puppeteer根目录下创建app.js放入以下内容: const puppeteer require(puppeteer);(async () > {// 启动 Puppeteer 并创建一个新浏览器实例// const browser await puppeteer.launch(); // 会在…

大白话讲解AI大模型

大白话讲解大模型 大模型的发展重要大模型发展时间线 大模型的简单原理-训练⼤模型是如何训练并应⽤到场景中的?如果训练私有化模型 模型:model 语料库:用于训练模型的数据 大模型的发展 详细信息来源:DataLearner 2022年11月底…

v-bind指令——03

v-bind 指令详解&#xff1a; 1 、这个指令是干嘛的&#xff1f; 可以让html标签的某个属性的值产生动态的效果 2、v-bind指令的语法格式&#xff1a;<HTML 标签 v-bind : 参数 “表达式”> </HTML> 3、v-bind指令的编译原理&#xff1a; 编译前&#xff1a…

声音的转译者:Transformer模型在语音识别中的革命性应用

声音的转译者&#xff1a;Transformer模型在语音识别中的革命性应用 在人工智能领域&#xff0c;语音到文本转换&#xff08;Speech-to-Text&#xff0c;STT&#xff09;技术正迅速发展&#xff0c;成为连接人类语言与机器理解的桥梁。Transformer模型&#xff0c;以其卓越的处…

关于 RK3588刷镜像升级镜像”没有发现设备“ 的解决方法

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/140287339 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…

企业资产管理系统带万字文档公司资产管理系统java项目java课程设计java毕业设计

文章目录 企业资产管理系统一、项目演示二、项目介绍三、万字项目文档四、部分功能截图五、部分代码展示六、底部获取项目源码带万字文档&#xff08;9.9&#xffe5;带走&#xff09; 企业资产管理系统 一、项目演示 企业资产管理系统 二、项目介绍 语言&#xff1a;java 数…

javaweb学习day1《HTML篇》--新浪微博(前端页面的创建思路及其HTML、css代码详解)

一、前言 本篇章为javaweb的开端&#xff0c;也是第一篇综合案例&#xff0c;小编也是看着黑马程序员的视频对里面的知识点进行理解&#xff0c;然后自己找一个新浪微博网页看着做的&#xff0c;主要还是因为懒&#xff0c;不想去领黑马程序员的资料了。 小编任务javaweb和ja…

云端日历同步大师:iCloud让工作与生活井井有条

云端日历同步大师&#xff1a;iCloud让工作与生活井井有条 在快节奏的现代生活中&#xff0c;无论是工作还是个人生活&#xff0c;我们都需要一个可靠的日历应用来帮助我们管理日常事务和重要事件。iCloud作为苹果公司提供的云服务&#xff0c;其日历应用&#xff08;Apple Ca…

力扣-dfs

何为深度优先搜索算法&#xff1f; 深度优先搜索算法&#xff0c;即DFS。就是找一个点&#xff0c;往下搜索&#xff0c;搜索到尽头再折回&#xff0c;走下一个路口。 695.岛屿的最大面积 695. 岛屿的最大面积 题目 给你一个大小为 m x n 的二进制矩阵 grid 。 岛屿 是由一些相…

helm安装解决无授权问题

在安装kubesphere的时候需要先安装镜像管理工具helm它配合着tiller服务能方面地创建拉取地三镜像库更像一个本地的maven工具&#xff0c;安装helm可以通过脚本的方式担是容易被强&#xff0c;下载二进制的软件包解压得到helm把它移动/user/local/bin目录下&#xff0c;然后查看…

华为HCIP Datacom H12-821 卷33

1.判断题 缺省情况下&#xff0c;华为AR路由器的VRRP运行在抢占模式下 A、对 B、错 正确答案&#xff1a; A 解析&#xff1a; 无 2.判断题 一个Route-Policy下可以有多个节点&#xff0c;不同的节点号用节点号标识&#xff0c;不同节点之间的关系是"或"的关…

禁用华为小米?微软中国免费送iPhone15

微软中国将禁用华为和小米手机&#xff0c;要求员工必须使用iPhone。如果还没有iPhone&#xff0c;公司直接免费送你全新的iPhone 15&#xff01; 、 这几天在微软热度最高的话题就是这个免费发iPhone&#xff0c;很多员工&#xff0c;收到公司的通知。因为&#xff0c;登录公司…

精通Postman响应解析:正则表达式的实战应用

&#x1f9d0; 精通Postman响应解析&#xff1a;正则表达式的实战应用 在API测试和开发的世界中&#xff0c;Postman是一个强大的工具&#xff0c;它不仅可以发送请求、管理环境&#xff0c;还能使用正则表达式来解析响应。正则表达式是一种强大的文本处理工具&#xff0c;能够…

如何指定多块GPU卡进行训练-数据并行

训练代码&#xff1a; train.py import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset import torch.nn.functional as F# 假设我们有一个简单的文本数据集 class TextDataset(Dataset):def __init__(self, te…

Nginx中文URL请求404

这两天正在搞我的静态网站。方案是&#xff1a;从思源笔记Markdown笔记&#xff0c;用MkOcs build成静态网站&#xff0c;上传到到Nginx服务器。遇到一个问题&#xff1a;URL含有中文会404&#xff0c;全英文URL则正常访问。 ‍ 比如&#xff1a; ​​ ‍ 设置了utf-8 ht…

【Python基础】代码如何打包成exe可执行文件

本文收录于 《一起学Python趣味编程》专栏&#xff0c;从零基础开始&#xff0c;分享一些Python编程知识&#xff0c;欢迎关注&#xff0c;谢谢&#xff01; 文章目录 一、前言二、安装PyInstaller三、使用PyInstaller打包四、验证打包是否成功五、总结 一、前言 本文介绍如何…

Linux C语言基础 day8

目录 思维导图&#xff1a; 学习目标&#xff1a; 学习内容&#xff1a; 1. 字符数组 1.1 二维字符数组 1.1.1 格式 1.1.2 初始化 1.1.3 二维字符数组输入输出、求最值、排序 2. 函数 2.1 概念 关于函数的相关概念 2.2 函数的定义及调用 2.2.1 定义函数的格式 2.3…

数据采集:如何使用八爪鱼采集BOSS直聘职位数据

大家好&#xff0c;我是水哥&#xff01; 今天给大家分享的是数据采集实战&#xff1a;使用「八爪鱼」第三方工具来采集 BOSS 直聘上的数据分析职位数据。 接下来&#xff0c;我们详细看一看。 不重复造轮子 在工作中&#xff0c;我们一定要形成一个认知&#xff0c;能用第…

最新浪子授权系统网站源码 全开源免授权版本

最新浪子授权系统网站源码 全开源免授权版本 此版本没有任何授权我已经去除授权&#xff0c;随意二开无任何加密。 更新日志 1.修复不能下载 2.修复不能更新 3.修复不能删除用户 4.修复不能删除授权 5.增加代理后台管理 6.重写授权读取文件 7.修复已经知道漏洞 源码下…

土壤分析仪:解密土壤之奥秘的科技先锋

在农业生产和生态保护的道路上&#xff0c;土壤的质量与状况一直是我们关注的焦点。土壤分析仪&#xff0c;作为现代科技在农业和环保领域的杰出代表&#xff0c;以其高效、精准的分析能力&#xff0c;为我们揭示了土壤的奥秘&#xff0c;为农业生产提供了科学指导&#xff0c;…