pytorch实训题

news/2025/10/15 20:25:40/文章来源:https://www.cnblogs.com/linleo123/p/19144201

代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

1. 数据预处理与加载

transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 数据增强:随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # CIFAR-10标准归一化参数
])

加载数据集

trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=100, shuffle=False, num_workers=2
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

2. 定义卷积神经网络模型

class CIFAR10CNN(nn.Module):
def init(self):
super(CIFAR10CNN, self).init()
# 卷积层部分
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), # 3输入通道,64输出通道,3x3卷积核
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # 2x2池化,步长2

        nn.Conv2d(64, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.Conv2d(128, 128, 3, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2),nn.Conv2d(128, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))# 全连接层部分self.fc_layers = nn.Sequential(nn.Dropout(0.5),nn.Linear(256 * 4 * 4, 512),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(512, 10)  # 10个分类)def forward(self, x):x = self.conv_layers(x)x = x.view(-1, 256 * 4 * 4)  # 展平特征图x = self.fc_layers(x)return x

3. 初始化模型、损失函数和优化器

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

net = CIFAR10CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)

4. 训练模型

def train(epochs=20):
train_losses = []
test_losses = []
best_acc = 0.0

print("开始训练...")
start_time = time.time()for epoch in range(epochs):net.train()  # 训练模式running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 清零梯度optimizer.zero_grad()# 前向传播、计算损失、反向传播、参数更新outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计损失running_loss += loss.item()if i % 100 == 99:  # 每100个batch打印一次print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')running_loss = 0.0# 每个epoch结束后测试test_loss, acc = test()train_losses.append(running_loss / len(trainloader))test_losses.append(test_loss)# 学习率调整scheduler.step(test_loss)# 保存最佳模型if acc > best_acc:best_acc = acctorch.save(net.state_dict(), 'best_model.pth')print(f'Epoch {epoch+1} 测试准确率: {acc:.2f}%')print(f'训练完成,耗时: {time.time() - start_time:.2f}秒')
print(f'最佳测试准确率: {best_acc:.2f}%')# 绘制损失曲线
plt.plot(train_losses, label='训练损失')
plt.plot(test_losses, label='测试损失')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_curve.png')
plt.close()return train_losses, test_losses

5. 测试模型

def test():
net.eval() # 评估模式
correct = 0
total = 0
test_loss = 0.0

with torch.no_grad():  # 不计算梯度for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()acc = 100 * correct / total
avg_loss = test_loss / len(testloader)
return avg_loss, acc

6. 测试每个类别的准确率

def test_class_accuracy():
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'类别 {classes[i]} 的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

7. 显示一些测试图像和预测结果

def show_predictions(num_images=5):
dataiter = iter(testloader)
images, labels = next(dataiter)

# 打印原始图像
imshow(torchvision.utils.make_grid(images))
print('真实标签: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))# 预测
outputs = net(images.to(device))
_, predicted = torch.max(outputs, 1)
print('预测标签: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.figure(figsize=(10, 4))
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.savefig('predictions.png')
plt.close()

主程序

if name == 'main':
# 训练模型(20个epochs)
train_losses, test_losses = train(epochs=20)

# 加载最佳模型
net.load_state_dict(torch.load('best_model.pth'))# 评估模型
print("\n每个类别的准确率:")
test_class_accuracy()# 显示预测结果
show_predictions()

运行结果
使用设备: cuda:0
开始训练...
[1, 100] loss: 1.762
[1, 200] loss: 1.421
[1, 300] loss: 1.285
Epoch 1 测试准确率: 57.32%
[2, 100] loss: 1.135
[2, 200] loss: 1.052
...
训练完成,耗时: 456.23秒
最佳测试准确率: 85.67%

每个类别的准确率:
类别 plane 的准确率: 89.20%
类别 car 的准确率: 92.50%
类别 bird 的准确率: 78.30%
类别 cat 的准确率: 72.10%
类别 deer 的准确率: 84.50%
类别 dog 的准确率: 79.80%
类别 frog 的准确率: 88.70%
类别 horse 的准确率: 87.60%
类别 ship 的准确率: 91.20%
类别 truck 的准确率: 89.40%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/937771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

近期模拟赛汇总

S2OJ你真是好样的来让我们看看这个人到底在比赛中能干出什么呢 2025.10.8 国庆模拟赛二 T1 因为每个点只会被覆盖一次,所以倍增跳有标记的父亲然后暴力向下扩展就行。 来让我们看看这个人写的什么:点击查看代码 #inc…

实用指南:部署Tomcat11.0.11(Kylinv10sp3、Ubuntu2204、Rocky9.3)

实用指南:部署Tomcat11.0.11(Kylinv10sp3、Ubuntu2204、Rocky9.3)pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "…

Hbase的安装与配置

HBase安装与配置 前提是jdk,zookeeper,ssh都配置完成了 1 安装 官网地址:Index of /hbase国内镜像: # 从华为云镜像下载 HBase wget https://repo.huaweicloud.com/apache/hbase/2.5.7/hbase-2.5.7-bin.tar.gz1.1 …

【Azure App Service】App Service是否支持PHP的版本选择呢?

问题描述 在一个古老的 Azure Web App 项目中,需要修改 PHP 版本,如何操作呢? 问题解答 Linux 版本的PHP修改可以通过门户上修改,但是如果所想要的版本已经不在列表之中,则可以通过PowerShell或Azure CLI命令修改…

OAuth/OpenID Connect 渗透测试完全指南

本文详细介绍了OAuth和OpenID Connect在现代Web应用中的安全测试案例,包括端点侦察、开放重定向、代码重放攻击、CSRF防护、令牌安全等关键测试点,帮助安全人员全面评估认证授权机制的安全性。Web应用渗透测试:OAut…

Problem K. 置换环(The ICPC online 2025)思路解析 - tsunchi

答案 最大权值: \[\begin{cases} \lfloor \frac{n+1}{2} \rfloor \cdot n,\; n\text{为奇数}, \\ \lfloor \frac{n+1}{2} \rfloor \cdot (n+1),\; n\text{为偶数}, \end{cases} \]把列 A:从 n 到 1 倒序输出 思路 题…

Go 语言和 Tesseract OCR 识别英文数字验证码

Go 语言凭借其并发处理能力和简单的语法,成为开发高效程序的首选之一。借助 tesseract 包,我们可以在 Go 中调用 Tesseract OCR 引擎进行验证码识别。 一、安装与配置 安装 Tesseract OCR 首先,确保你已经在系统中安…

Markdown转换为Word:Pandoc模板使用指南 - 实践

Markdown转换为Word:Pandoc模板使用指南 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", …

2025年10月小程序开发公司最新推荐排行榜,小程序定制开发,电商小程序开发,预订服务小程序开发,活动报名小程序开发!

在数字化转型加速推进的当下,小程序已成为政企实现线上服务落地的核心载体。但行业快速扩张背后,乱象愈发凸显:部分厂商以模板套用冒充定制开发,交付后出现功能缩水、二次开发困难等问题;技术迭代滞后导致小程序适…

复习CSharp

基本语法 usiing 关键字 using 关键字用于在程序中包含命名空间。一个程序可以包含多个 using 语句 class关键字 class 关键字用于声明一个类。 注释 单行注释 多行注释 成员变量 变量是类的属性或数据成员,用于存储…

数据结构-循环队列

循环队列 功能实现 /**************************************************************************** * @name* @author* @date** *CopyRight (c) 2025-2026 All Right Reserved* **********************************…

C语言学习——键盘录入

一.基础的定义 键盘录入用到的是scanf起作用是获取用户在键盘上输入的数据,并赋值给变量 二.示例 下面是键盘录入的格式三.练习 当我们需要在键盘上录入我们所需要的字符串时我们可以通过以下的要求和格式来进行定义下…

2025年10月软件开发公司最新推荐,软件定制开发,crm系统定制软件开发,管理系统软件开发,物联网软件开发公司推荐!

在数字化转型加速推进的当下,政企机构对软件开发服务的需求持续攀升,但行业乱象却让选型陷入困境。部分厂商存在技术架构陈旧、扩展能力不足的问题,系统上线后难以适配业务增长需求;另有服务商重开发轻服务,售后响…

C语言学习——运算符的学习

在算术运算符中有许多需要注意的过程,当然这其中所遵循的都是我们平常常见的运算规则只是其中增添了一些在计算机的运算过程中所出现的问题 1 .接下来介绍的就是C语言中运算符的基础运算,同时也是我们很早就掌握的,…

第十五篇

今天是10月15日,上了离散和马原。

数据结构-顺序栈

数据结构-顺序栈 /**************************************************************************** * @name: sequencelstack * @author: 王玉珩* @date: 2025/10/07** *CopyRight (c) 2025-2026 All Right Rese…

实用指南:NXP - 用MCUXpresso IDE v25.6.136的工具链编译Smoothieware固件工程

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

Erlang 的英文数字验证码识别系统设计与实现

一、引言 验证码(CAPTCHA)作为互联网中抵御自动化攻击的重要安全机制,被广泛用于登录验证、注册防刷、评论防机器人等场景。 传统验证码识别常用 Python 或 C++ 实现,而本文将介绍如何用 Erlang 来构建一个基础的英…

使用Django从零开始构建一个个人博客系统 - 实践

使用Django从零开始构建一个个人博客系统 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", …

2025年磨床厂家TOP企业品牌推荐排行榜,平面磨床,外圆磨床,数控平面磨床,数控外圆磨床,7163平面磨床推荐这十家公司!

当前磨床市场竞争愈发激烈,产品质量参差不齐,不少企业在选购磨床时面临诸多难题。部分厂家缺乏严格的质量管控体系,生产的磨床精度不足、稳定性差,难以满足汽车摩托车、工程机械、军工等行业对加工精度的高要求;还…