【深度学习基础 2】 PyTorch 框架

目录

一、 PyTorch 简介

二、安装 PyTorch

三、PyTorch 常用函数和操作

3.1 创建张量(Tensor)

3.2 基本数学运算

3.3 自动求导(Autograd)

3.4 定义神经网络模型

3.5 训练与评估模型

3.6 使用模型进行预测

四、注意事项

五、完整训练示例代码


一、 PyTorch 简介

        PyTorch 是由 Facebook 开发的开源深度学习框架,以动态计算图(Dynamic Computational Graph)著称,允许在运行时即时定义和修改模型结构,便于调试和研究。它支持 GPU 加速,并拥有丰富的生态系统,适用于自然语言处理、计算机视觉等众多领域。

主要特点:

  • 动态计算图:每次运行时构建计算图,便于调试和灵活性高。

  • 自动求导(Autograd):支持自动求导,便于梯度计算与反向传播。

  • 模块化设计:通过 torch.nn 提供丰富的神经网络层及模块,方便构建复杂模型。

  • 丰富的生态:支持 torchvision、torchtext 等扩展库,加速模型开发和实验。

二、安装 PyTorch

        可参考YOLO系列环境配置及训练_yolo环境配置-CSDN博客 中pytorch的安装方法,以下简要概括:(以安装CPU版本为例)

pip install torch torchvision

安装后,可以通过以下代码验证安装及查看版本:

import torch
print("PyTorch 版本:", torch.__version__)

 CPU版本安装成功的输出示例为:

PyTorch 版本: 1.13.0

三、PyTorch 常用函数和操作

3.1 创建张量(Tensor)

        与TensorFlow一样,在 PyTorch 中的张量类似于 NumPy 的数组,同时支持 GPU 加速。

例如:

import torch# 创建标量、向量和矩阵
scalar = torch.tensor(5)
vector = torch.tensor([1, 2, 3])
matrix = torch.tensor([[1, 2], [3, 4]])print("标量:", scalar)
print("向量:", vector)
print("矩阵:\n", matrix)

样例输出:

标量: tensor(5)
向量: tensor([1, 2, 3])
矩阵:tensor([[1, 2],[3, 4]])

3.2 基本数学运算

        PyTorch 同样提供了基本的数学运算,例如:

a = torch.tensor(3.0)
b = torch.tensor(2.0)print("加法:", torch.add(a, b))
print("乘法:", torch.mul(a, b))
# 矩阵乘法
mat1 = torch.tensor([[1, 2]])
mat2 = torch.tensor([[3], [4]])
print("矩阵乘法:\n", torch.matmul(mat1, mat2))

样例输出:

加法: tensor(5.)
乘法: tensor(6.)
矩阵乘法:tensor([[11]])

3.3 自动求导(Autograd)

        PyTorch 的 autograd 功能可以自动计算梯度,非常适合神经网络反向传播的实现。

例如:

# 定义一个需要计算梯度的张量
x = torch.tensor(2.0, requires_grad=True)# 定义函数 y = x³ + 2x + 1
y = x**3 + 2*x + 1# 反向传播,计算 dy/dx
y.backward()print("dy/dx:", x.grad)

样例输出:

dy/dx: tensor(14.)

3.4 定义神经网络模型

        PyTorch 提供了 torch.nn 模块来构建神经网络模型。例如下面使用一个简单的全连接层构建模型:

import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(784, 10)  # 输入 784 维,输出 10 维(例如手写数字分类)def forward(self, x):out = self.fc(x)return out# 实例化模型并打印模型结构
model = SimpleNet()
print(model)

样例输出:

SimpleNet((fc): Linear(in_features=784, out_features=10, bias=True)
)

3.5 训练与评估模型

在3.4的基础上,我们继续完善构建。

例如:

# 假设我们有一个批次的输入数据(如手写数字图像,已展平为784维向量)
batch_size = 32
dummy_input = torch.randn(batch_size, 784)  # 随机生成一批数据
dummy_labels = torch.randint(0, 10, (batch_size,))  # 随机生成对应的标签# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 前向传播
outputs = model(dummy_input)
loss = criterion(outputs, dummy_labels)print("初始损失:", loss.item())# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()# 再次输出损失(注意:仅作为示例,损失值可能不会明显下降)
outputs_after = model(dummy_input)
loss_after = criterion(outputs_after, dummy_labels)
print("更新后损失:", loss_after.item())

样例输出:

初始损失: 2.280543327331543
更新后损失: 2.2781271934509277

注意: 损失值会受到随机数据和权重初始化的影响,实际训练中损失下降情况应更为明显。

3.6 使用模型进行预测

在 3.5 训练结束后,我们可以通过调用模型的 forward 方法,可以对新的数据进行预测:

# 对一条测试数据进行预测
test_sample = torch.randn(1, 784)
pred_logits = model(test_sample)
pred_label = torch.argmax(pred_logits, dim=1)
print("预测类别:", pred_label.item())

样例输出:

预测类别: 7

四、注意事项

        在训练过程中,切忌混淆设备(Device),注意将模型和数据迁移到同一设备:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = inputs.to(device)

五、完整训练示例代码

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(784, 10)  # 输入 784 维,输出 10 维(例如手写数字分类)def forward(self, x):out = self.fc(x)return out# 实例化模型
model = SimpleNet()
print(model)# 设置设备(如果有 GPU 就用 GPU,否则用 CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)# 假设我们有一个批次的输入数据(如手写数字图像,已展平为784维向量)
batch_size = 32
dummy_input = torch.randn(batch_size, 784).to(device)  # 随机生成一批数据并迁移到 device
dummy_labels = torch.randint(0, 10, (batch_size,)).to(device)  # 随机生成对应的标签并迁移到 device# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 前向传播
outputs = model(dummy_input)
loss = criterion(outputs, dummy_labels)
print("初始损失:", loss.item())# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()# 再次输出损失(仅作为示例,损失值可能不会明显下降)
outputs_after = model(dummy_input)
loss_after = criterion(outputs_after, dummy_labels)
print("更新后损失:", loss_after.item())# 对一条测试数据进行预测
test_sample = torch.randn(1, 784).to(device)
pred_logits = model(test_sample)
pred_label = torch.argmax(pred_logits, dim=1)
print("预测类别:", pred_label.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/74612.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp中APP上传文件

uniapp提供了uni.chooseImage(选择图片), uni.chooseVideo(选择视频)这两个api,但是对于打包成APP的话就没有上传文件的api了。因此我采用了plus.android中的方式来打开手机的文件管理从而上传文件。 下面…

推陈换新系列————java8新特性(编程语言的文艺复兴)

文章目录 前言一、新特性秘籍二、Lambda表达式2.1 语法2.2 函数式接口2.3 内置函数式接口2.4 方法引用和构造器引用 三、Stream API3.1 基本概念3.2 实战3.3 优势 四、新的日期时间API4.1 核心概念与设计原则4.2 核心类详解4.2.1 LocalDate(本地日期)4.2…

树莓派5从零开发至脱机脚本运行教程——1.系统部署篇

树莓派5应用实例——工创视觉 前言 哈喽,各位小伙伴,大家好。最近接触了树莓派,然后简单的应用了一下,学习程度并不是很深,不过足够刚入手树莓派5的小伙伴们了解了解。后面的几篇更新的文章都是关于开发树莓派5的内容…

GPT Researcher 的win docker安装攻略

github网址是:https://github.com/assafelovic/gpt-researcher 因为docker安装方法不够清晰,因此写一个使用方法 以下是针对 Windows 系统 使用 Docker 运行 AI-Researcher 项目的 详细分步指南: 步骤 1:安装 Docker 下载 Docke…

【后端】【Django DRF】从零实现RBAC 权限管理系统

Django DRF 实现 RBAC 权限管理系统 在 Web 应用中,权限管理 是一个核心功能,尤其是在多用户系统中,需要精细化控制不同用户的访问权限。本文介绍如何使用 Django DRF 设计并实现 RBAC(基于角色的访问控制)系统&…

C#基础学习(五)函数中的ref和out

1. 引言:为什么需要ref和out? ​问题背景:函数参数默认按值传递,值类型在函数内修改不影响外部变量;引用类型重新赋值时外部对象不变。​核心作用:允许函数内部修改外部变量的值,实现“双向传参…

八纲辨证总则

一、八纲辨证的核心定义 八纲即阴、阳、表、里、寒、热、虚、实,是中医分析疾病共性的纲领性辨证方法。 作用:通过八类证候归纳疾病本质,为所有辨证方法(如脏腑辨证、六经辨证)的基础。 二、八纲分类与对应关系 1. 总…

【linux重设gitee账号密码 克隆私有仓库报错】

出现问题时 Cloning into xxx... remote: [session-1f4b16a4] Unauthorized fatal: Authentication failed for https://gitee.com/xxx/xxx.git/解决方案 先打开~/.git-credentials vim ~/.git-credentials或者创建一个 torch ~/.git-credentials 添加授权信息 username/pa…

绿联NAS安装内网穿透实现无公网IP也能用手机平板远程访问经验分享

文章目录 前言1. 开启ssh服务2. ssh连接3. 安装cpolar内网穿透4. 配置绿联NAS公网地址 前言 大家好,今天给大家带来一个超级炫酷的技能——如何在绿联NAS上快速安装cpolar内网穿透工具。想象一下,即使没有公网IP,你也能随时随地远程访问自己…

CSS 美化页面(一)

一、CSS概念 CSS(Cascading Style Sheets,层叠样式表)是一种用于描述 HTML 或 XML(如 SVG、XHTML)文档 样式 的样式表语言。它控制网页的 外观和布局,包括字体、颜色、间距、背景、动画等视觉效果。 二、CS…

空转 | GetAssayData doesn‘t work for multiple layers in v5 assay.

问题分析 当我分析多个样本的时候,而我的seurat又是v5时,通常就会出现这样的报错。 错误的原因有两个: 一个是参数名有slot变成layer 一个是GetAssayData 不是自动合并多个layers,而是选择保留。 那么如果我们想合并多个样本&…

UE4学习笔记 FPS游戏制作17 让机器人持枪 销毁机器人时也销毁机器人的枪 让机器人射击

添加武器插槽 打开机器人的Idle动画,方便查看武器位置 在动画面板里打开骨骼树,找到右手的武器节点,右键添加一个插槽,重命名为RightWeapon,右键插槽,添加一个预览资产,选择Rifle,根…

【JavaScript】七、函数

文章目录 1、函数的声明与调用2、形参默认值3、函数的返回值4、变量的作用域5、变量的访问原则6、匿名函数6.1 函数表达式6.2 立即执行函数 7、练习8、逻辑中断9、转为布尔型 1、函数的声明与调用 function 函数名(形参列表) {函数体 }eg: // 声明 function sayHi…

硬件基础--05_电压

电压(电势差) 有了电压,电子才能持续且定向移动起来,所有电压是形成电流的必要条件。 电压越大,能“定向移动”起来的电子就越多,电流就会越大。 有电压的同时,形成闭合回路才会有电流,不是有电压就有电流…

ES数据过多,索引拆分

公司企微聊天数据存储在 ES 中,虽然按照企业分储在不同的ES 索引中,但某些常用的企微主体使用量还是很大。4年中一个索引存储数据已经达到46多亿条数据,占用存储3.1tb, ES 配置 由于多一个副本,存储得翻倍,成本考虑…

存储服务器是指什么

今天小编主要来为大家介绍存储服务器主要是指什么,存储服务器与传统的物理服务器和云服务器是不同的,其是为了特定的目标所设计的,在硬件配置方式上也有着一定的区别,存储空间会根据需求的不同而改变。 存储服务器中一般会配备大容…

golang不使用锁的情况下,对slice执行并发写操作,是否会有并发问题呢?

背景 并发问题最简单的解决方案加个锁,但是,加锁就会有资源争用,提高并发能力其中的一个优化方向就是减少锁的使用。 我在之前的这篇文章《开启多个协程,并行对struct中的每个元素操作,是否会引起并发问题?》中讨论过多协程场景下struct的并发问题。 Go语言中的slice在…

Java知识整理round1

一、常见集合篇 1. 为什么数组索引从0开始呢?假如从1开始不行咩 数组(Array):一种用连续的内存空间存储相同数据类型数据的线性数据结构 (1)在根据数组索引获取元素的时候,会用索引和寻址公式…

【C++指针】搭建起程序与内存深度交互的桥梁(下)

🔥🔥 个人主页 点击🔥🔥 每文一诗 💪🏼 往者不可谏,来者犹可追——《论语微子篇》 译文:过去的事情已经无法挽回,未来的岁月还可以迎头赶上。 目录 C内存模型 new与…

JavaScript创建对象的多种方式

在JavaScript中,创建对象有多种方式,每种方式都有其优缺点。本文将介绍四种常见的对象创建模式:工厂模式、构造函数模式、原型模式和组合模式,并分析它们的特点以及如何优化。 1. 工厂模式 工厂模式是一种简单的对象创建方式&am…