PyTorch中的CPU和GPU代码实现详解

PyTorch中的CPU和GPU

  • PyTorch中的CPU和GPU代码实现详解
    • 1. 安装PyTorch
    • 2. 编写支持CPU和GPU的PyTorch代码
      • 2.1 模型定义
      • 2.2 数据加载
      • 2.3 将模型和数据移动到GPU
      • 2.4 训练循环
    • 3. 关键步骤详解
      • **3.1 定义设备**
      • **3.2 模型和数据移动到GPU**
      • **3.3 优化器和损失函数**
    • 4. 完整代码示例
    • 5. 结论

PyTorch中的CPU和GPU代码实现详解

在深度学习的开发过程中,计算资源的高效利用是至关重要的。PyTorch作为一种流行的深度学习框架,支持使用CPUGPU进行模型训练和推理。相较于CPU,GPU由于其强大的并行计算能力,能够显著加速深度学习任务。然而,将PyTorch代码从CPU版本迁移到GPU版本需要进行一些额外的代码修改。本文将详细介绍如何在PyTorch中编写支持CPU和GPU的代码,以及需要特别注意的事项。

1. 安装PyTorch

首先,确保你已经安装了支持GPU的PyTorch版本。如果还没有安装,可以参考以下命令进行安装:

# For CUDA 11.1
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu111

2. 编写支持CPU和GPU的PyTorch代码

2.1 模型定义

定义模型的代码在CPU和GPU版本中基本一致。但是,我们需要确保模型可以在GPU上运行。

import torch
import torch.nn as nn
import torch.optim as optimclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x)model = SimpleNN()

2.2 数据加载

数据加载部分对于CPU和GPU是相同的。使用DataLoader类加载数据:

from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

2.3 将模型和数据移动到GPU

在PyTorch中,模型和数据需要显式地移动到GPU上。使用.to(device)方法将模型和数据移动到指定设备(CPU或GPU)上。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)

2.4 训练循环

在训练循环中,我们需要确保输入数据和标签也被移动到GPU上。

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)for epoch in range(5):running_loss = 0.0for inputs, labels in trainloader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}")

3. 关键步骤详解

3.1 定义设备

使用torch.device定义设备,根据当前环境选择使用CPU或GPU。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3.2 模型和数据移动到GPU

将模型和数据显式地移动到GPU上。这一步是关键,没有这一步,模型和数据仍然会在CPU上进行计算。

model.to(device)
inputs, labels = inputs.to(device), labels.to(device)

3.3 优化器和损失函数

优化器和损失函数在CPU和GPU版本中不需要特殊处理,它们会自动适应模型所在的设备。

4. 完整代码示例

以下是完整的代码示例,包括从数据加载到训练循环的所有步骤。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x)model = SimpleNN().to(device)# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 训练循环
for epoch in range(5):running_loss = 0.0for inputs, labels in trainloader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}")

5. 结论

通过本文的详细讲解,我们了解了如何在PyTorch中编写支持CPU和GPU的代码。重点在于将模型和数据显式地移动到GPU上,并确保训练循环中的每一步都在正确的设备上进行计算。掌握这些技巧后,你可以充分利用GPU的强大计算能力,加速深度学习模型的训练和推理过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870535.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

构建实时银行应用程序:英国金融机构 Nationwide 为何选择 MongoDB Atlas

Nationwide Building Society 超过135年的互助合作 Nationwide Building Society(以下简称“Nationwide”) 是一家英国金融服务提供商,拥有超过 1500 万名会员,是全球最大的建房互助会。 Nationwide 的故事可以追溯到 1884 年&am…

web后端开发--请求响应

目录 前言 请求 简单参数 原始方法 Spring方式 Post请求乱码处理 实体参数 简单实体参数 复杂实体参数 ​编辑 数组集合参数 数组参数 ​编辑 集合参数 日期参数 ​编辑 Json参数 ​编辑 传递json数据 json数组 json对象(POJO) jso…

Dify中的知识库API列表

1.知识库API列表 通过文本/文件创建/更新/删除文档/查询文档嵌入状态,知识库创建/知识库查询/文档列表查询,分段增/删/改/查。 接口名字功能描述请求示例POST/datasets/{dataset_id}/document/create_by_text通过文本创建文档此接口基于已存在知识库&a…

tableau人口金字塔,漏斗图,箱线图绘制 - 13

人口金字塔,漏斗图,箱线图 1. 金字塔1.1 定义1.2 金字塔创建1.2.1 数据导入1.2.2 数据异常排查1.2.3 创建度量字段1.2.4 转换属性1.2.5 创建数据桶1.2.6 选择相关属性1.2.7 年龄排序1.2.8 创建计算字段1.2.9 选择相关字段1.2.10 设置轴排序1.2.11 设置颜…

Windows系统服务器远程教程

在远程连接Windows系统服务器之前,需要确保以下几点: 被远程的Windows服务器必须开启远程桌面功能。这一功能在Windows系统中默认是关闭的,需要手动启用。 必须为两台计算机中的一台计算机(即客户端)创建远程桌面连接。…

11、中台-DDD-几种微服务架构模型对比分析

引言 在上一章中,我们深入探讨了DDD分层架构的基本概念和实现方法。这一章将重点介绍几种常用的微服务架构模型,包括洋葱架构、六边形架构,并对这两种架构模型与DDD分层架构进行对比分析。通过了解不同架构模型的优缺点,帮助我们…

C++复合数据类型:指针类型、引用类型、指针和引用之间的关系

复合数据类型 (1)指针 A.What(什么是指针) 用于存放对象地址的复合数据类型 B.Which(有哪些指针) 空指针: int *p nullptr; int *p 0;//(不指向任何对象)void *: void *&…

fastermaker-boot代码生成器

fastermaker-boot 是基于Spring Boot3 、Vue3 的一个代码简洁、结构清晰、开发高效、模块可扩展的单体项目的基础开发框架,包含代码生成器模块,适合初级开发者特别是大学生学习研究使用,也是中小型系统快速开发的利器。 开发技术: JDK 17、Sp…

liunx清理服务器内存和日志

1、查看服务器磁盘占用情况 # 查看磁盘占用大小 df -h 2、删除data文件夹下面的日志 3、查看每个服务下面的日志输出文件,过大就先停掉服务再删除out文件再重启服务 4、先进入想删除输入日志的服务文件夹下,查看服务进程,杀掉进程&#xff…

DW03D是一款用于锂离子/聚合物电池保护的高集成度解决方案。DW03D包含内部功率MOSFET、高精度电压检测电路和延迟电路

一般概述 DW03D产品是单节锂离子/锂聚合物可充电电池组保护的高集成度解决方案。DW03D包括了先进的功率MOSFET,高精度的电压检测电路和延时电路。 DW03D具有非常小的TSS08-8的封装,这使得该器件非常适合应用于空间限制得非常小的可充电电池组应用。…

【备战秋招】——算法题目训练和总结day3

【备战秋招】——算法题目训练和总结day3😎 前言🙌BC149简写单词题解思路分析代码分享: dd爱框框题解思路分析代码分享: 除2!题解思路分析代码分享: 总结撒花💞 😎博客昵称&#xff…

Gradle 介绍

Gradle 定义 Gradle 是一个现代化的构建自动化工具,用于管理软件项目的构建过程和依赖关系。它通过一种灵活且强大的 DSL(领域特定语言)语法来描述项目的构建逻辑和任务,可以用于构建几乎任何类型的软件项目,从简单的应…

【Java数据结构】初识线性表之一:顺序表

使用Java简单实现一个顺序表 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构,一般情况下采用数组存储。在数组上完成数据的增删查改。 线性表大致包含如下的一些方法: public class MyArrayList { private int[] array; pri…

怎么将mkv视频格式转为mp4?这四种转换方法你肯定要试试!

怎么将mkv视频格式转为mp4?你是否曾被MKV格式的魅力深深吸引,仿佛踏入了一个充满奇幻色彩的多媒体秘境,那里,音频如溪流潺潺,视频似画卷铺展,字幕则如同夜空中最亮的星,三者交织成一场视听盛宴&…

【彻底禁用Windows系统的自动更新,让电脑使用更顺心!】

文章底部关注公众号:电脑维修小马 回复关键词即可获取软件及注册表:禁用更新 功能简介 自动更新是Windows系统的一项重要功能,旨在保持操作系统的安全性和最新状态。然而,对于许多用户来说,自动更新并不总是那么受欢迎…

ospf-lsa

区域间路由计算 OSPF 单区域带来问题 1. OSPF 网络规模扩大时,每个设备 LSDB 中的 LSA 数据变多,以及进行 SPF 计算时更加复杂, 增加设备的负担和性能损耗 2. 网络拓扑或者路由信息发生变化,网络中所有的设备需要更新…

linux:vi命令

vi * -p打开多个文件进行切换 .文件间切换 Ctrl6 //两文件间的切换 :bn //下一个文件 :bp //上一个文件 :ls //列出打开的文件,带编号 :b1~n //切换至第n个文件 对于用(v)split在多个窗格中打开的文件,这种方法只会在当前窗格中切换不同的文件Ctrl P&a…

c#验证输入语句是否带有sql入侵的方法

为了在C# WinForms中验证用户输入的数据是否包含SQL注入攻击语句,可以使用多种方法来检测和防止SQL注入。以下是几种常见的方法: 1. 使用参数化查询 参数化查询是防止SQL注入的最佳实践,它通过将用户输入作为参数传递给SQL查询,…

渔人杯——RE

贪吃蛇的秘密 修改代码后,报了一个错 # uncompyle6 version 3.9.1 # Python bytecode version base 3.7.0 (3394) # Decompiled from: Python 3.11.8 (tags/v3.11.8:db85d51, Feb 6 2024, 22:03:32) [MSC v.1937 64 bit (AMD64)] # Embedded file name: snake1.py…

2023 N1CTF-n1canary

文章目录 参考n1canary模板类和模板函数make_unique和unique_ptrstd::unique_ptr示例: std::make_unique示例: 结合使用示例 operator->getrandom逆向源码思路exp 参考 https://nese.team/posts/n1ctf2023/ n1canary 模板类和模板函数 template &…