P66实训题

news/2025/10/15 20:47:06/文章来源:https://www.cnblogs.com/Neflibata1/p/19144240

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import numpy as np

1. 数据加载与预处理

transform = Compose([
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

类别标签

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

查看数据集样本(可选)

def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

dataiter = iter(train_loader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images[:4]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

2. 构建网络

class Net(nn.Module):
def init(self):
super(Net, self).init()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.fc1 = nn.Linear(128 * 4 * 4, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout = nn.Dropout(0.3)

def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.dropout(x)x = self.pool(torch.relu(self.conv2(x)))x = self.dropout(x)x = self.pool(torch.relu(self.conv3(x)))x = self.dropout(x)x = x.view(-1, 128 * 4 * 4)x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

net = Net()

3. 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

4. 训练网络

epochs = 30
train_losses = []
train_accs = []
test_losses = []
test_accs = []

for epoch in range(epochs):
running_loss = 0.0
correct = 0
total = 0
net.train()
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
train_loss = running_loss / len(train_loader)
train_acc = 100. * correct / total
train_losses.append(train_loss)
train_accs.append(train_acc)

# 测试
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:outputs = net(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()
test_loss = test_loss / len(test_loader)
test_acc = 100. * correct / total
test_losses.append(test_loss)
test_accs.append(test_acc)print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.2f}%')

print('Finished Training')

绘制训练曲线

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.legend()
plt.title('Accuracy')
plt.show()

5. 测试模型精度

net.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = net(inputs)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

print(f'测试集精度: {100. * correct / total:.2f}%')

查看各类别预测精度(可选)

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for inputs, labels in test_loader:
outputs = net(inputs)
_, predicted = outputs.max(1)
c = (predicted == labels).squeeze()
for i in range(len(labels)):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1

for i in range(10):
print(f'类别 {classes[i]} 的精度: {100. * class_correct[i] / class_total[i]:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/937783.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

非主流网站程序IndexNow添加方法

第一步:生成API密钥 打开:https://www.bing.com/indexnow/getstarted#implementation 得到一个txt文件,例如:1ad7ba0***4b64b045fbb****0ac5bfd.txt 将这个文件上传到网站根目录,上传之后不要删除。 第二部:新增…

卷积神经网络视频读书报告

《卷积神经网络(CNN)学习感悟》读书报告 24信计2 刘雨坤 摘要 本报告围绕卷积神经网络(CNN)展开深入学习与探讨。通过研读相关资料及观看教学视频,系统梳理了 CNN 的基本概念、核心运算原理、关键组成部分、技术优…

C 语言 - 内存操作函数以及字符串操作函数解析

预先了解 "\0" 标志它是 一个转义字符(escape character),表示的是 数值为 0 的字符,\0 就是 一个字节值为 0 的字符。 char str[] = "ABC"; //在 C语言的字符串 中,\0 用来表示 字符串的结束…

以*this返回局部对象的两种情况

1、以值返回局部对象class Person { public:Person(int age) {this->age = age;}// 以值方式返回局部对象会调用拷贝构造生成一个新的对象返回Person PersonAddPerson(Person p) {this->age += p.age;return *th…

2025.10.15

今天早八上离散数学课,然后上马克思主义原理,老师讲的很好,中午吃了一份沙县小吃的鸡腿饭,然后睡了两个小时觉,起床洗澡,然后上音乐鉴赏课,上课的时候制作了学生会部长成员表。

Kali 自定义ISO镜像

简单自定义 Kali live ISO 简单自定义一下kali 镜像的开机菜单和背景图,没太多技术含量,记录一下留存 # 下载构建脚本,建议在kali系统上构建 git clone https://gitlab.com/kalilinux/build-scripts/live-build-con…

2025秋_12

今天学习了Java

nginx-1.16.1-2.p01.ky10.sw_64.rpm 安装教程(详细步骤,适用于Kylin V10/申威SW64架构)

nginx-1.16.1-2.p01.ky10.sw_64.rpm 安装教程(详细步骤,适用于Kylin V10/申威SW64架构)​ nginx-1.16.1-2.p01.ky10.sw_64.rpm是专门为 ​银河麒麟操作系统 Kylin V10(Ky10)​​ 以及 ​SW64 架构​ 编译打包的 ​…

感知节点@5@ ESP32+arduino+ 第三个程序FreeRTOS 上 LED灯显示 和 串口打印ASCII表

思路: 将 LED灯显示 作为 一个独立的 FreeROTS 任务将串口打印ASCII表 作为 一个独立的 FreeROTS任务 将已经调试好的 LED灯显示代码 和 串口打印ASCII表 可以复制使用。1)观看视频,理解FreeROTS 多任务运…

BIG-Bench:大规模语言模型能力的全面评估与挑战 - 详解

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

pytorch实训题

代码 import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np import time 1. 数据…

近期模拟赛汇总

S2OJ你真是好样的来让我们看看这个人到底在比赛中能干出什么呢 2025.10.8 国庆模拟赛二 T1 因为每个点只会被覆盖一次,所以倍增跳有标记的父亲然后暴力向下扩展就行。 来让我们看看这个人写的什么:点击查看代码 #inc…

实用指南:部署Tomcat11.0.11(Kylinv10sp3、Ubuntu2204、Rocky9.3)

实用指南:部署Tomcat11.0.11(Kylinv10sp3、Ubuntu2204、Rocky9.3)pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "…

Hbase的安装与配置

HBase安装与配置 前提是jdk,zookeeper,ssh都配置完成了 1 安装 官网地址:Index of /hbase国内镜像: # 从华为云镜像下载 HBase wget https://repo.huaweicloud.com/apache/hbase/2.5.7/hbase-2.5.7-bin.tar.gz1.1 …

【Azure App Service】App Service是否支持PHP的版本选择呢?

问题描述 在一个古老的 Azure Web App 项目中,需要修改 PHP 版本,如何操作呢? 问题解答 Linux 版本的PHP修改可以通过门户上修改,但是如果所想要的版本已经不在列表之中,则可以通过PowerShell或Azure CLI命令修改…

OAuth/OpenID Connect 渗透测试完全指南

本文详细介绍了OAuth和OpenID Connect在现代Web应用中的安全测试案例,包括端点侦察、开放重定向、代码重放攻击、CSRF防护、令牌安全等关键测试点,帮助安全人员全面评估认证授权机制的安全性。Web应用渗透测试:OAut…

Problem K. 置换环(The ICPC online 2025)思路解析 - tsunchi

答案 最大权值: \[\begin{cases} \lfloor \frac{n+1}{2} \rfloor \cdot n,\; n\text{为奇数}, \\ \lfloor \frac{n+1}{2} \rfloor \cdot (n+1),\; n\text{为偶数}, \end{cases} \]把列 A:从 n 到 1 倒序输出 思路 题…

Go 语言和 Tesseract OCR 识别英文数字验证码

Go 语言凭借其并发处理能力和简单的语法,成为开发高效程序的首选之一。借助 tesseract 包,我们可以在 Go 中调用 Tesseract OCR 引擎进行验证码识别。 一、安装与配置 安装 Tesseract OCR 首先,确保你已经在系统中安…

Markdown转换为Word:Pandoc模板使用指南 - 实践

Markdown转换为Word:Pandoc模板使用指南 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", …