66ye

news/2025/10/23 23:22:59/文章来源:https://www.cnblogs.com/rjsyk/p/19161847

`import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import numpy as np

1. 数据加载与预处理

transform = Compose([
ToTensor(),
Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # CIFAR-10均值和标准差
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

2. 定义网络模型(CNN)

class CIFAR10Net(nn.Module):
def init(self):
super(CIFAR10Net, self).init()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
self.relu2 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2, 2)
self.drop1 = nn.Dropout(0.25)

    self.conv3 = nn.Conv2d(32, 64, 3, padding=1)self.relu3 = nn.ReLU()self.conv4 = nn.Conv2d(64, 64, 3, padding=1)self.relu4 = nn.ReLU()self.pool2 = nn.MaxPool2d(2, 2)self.drop2 = nn.Dropout(0.25)self.flatten = nn.Flatten()self.fc1 = nn.Linear(64 * 8 * 8, 512)self.relu5 = nn.ReLU()self.drop3 = nn.Dropout(0.5)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.conv1(x)x = self.relu1(x)x = self.conv2(x)x = self.relu2(x)x = self.pool1(x)x = self.drop1(x)x = self.conv3(x)x = self.relu3(x)x = self.conv4(x)x = self.relu4(x)x = self.pool2(x)x = self.drop2(x)x = self.flatten(x)x = self.fc1(x)x = self.relu5(x)x = self.drop3(x)x = self.fc2(x)return x

model = CIFAR10Net()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

3. 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

4. 训练网络

epochs = 20
train_acc_history = []
val_acc_history = []
train_loss_history = []
val_loss_history = []

for epoch in range(epochs):
# 训练阶段
model.train()
running_loss = 0.0
correct = 0
total = 0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

    running_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()train_loss = running_loss / len(train_loader)
train_acc = correct / total
train_loss_history.append(train_loss)
train_acc_history.append(train_acc)# 验证阶段
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():for inputs, targets in test_loader:inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)loss = criterion(outputs, targets)val_loss += loss.item()_, predicted = outputs.max(1)val_total += targets.size(0)val_correct += predicted.eq(targets).sum().item()val_loss = val_loss / len(test_loader)
val_acc = val_correct / val_total
val_loss_history.append(val_loss)
val_acc_history.append(val_acc)print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f} | Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f}')

5. 测试模型精度

model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
test_total += targets.size(0)
test_correct += predicted.eq(targets).sum().item()
test_acc = test_correct / test_total
print(f'测试集准确率: {test_acc:.3f}')

绘制训练曲线

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_acc_history, label='训练准确率')
plt.plot(val_acc_history, label='验证准确率')
plt.legend()
plt.title('准确率变化')

plt.subplot(1, 2, 2)
plt.plot(train_loss_history, label='训练损失')
plt.plot(val_loss_history, label='验证损失')
plt.legend()
plt.title('损失变化')
plt.show()

可视化部分测试结果

class_names = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']
model.eval()
with torch.no_grad():
inputs, targets = next(iter(test_loader))
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)

plt.figure(figsize=(10, 5))
for i in range(10):plt.subplot(2, 5, i+1)img = inputs[i].cpu().numpy().transpose((1, 2, 0))img = img * np.array([0.2023, 0.1994, 0.2010]) + np.array([0.4914, 0.4822, 0.4465])  # 反归一化img = np.clip(img, 0, 1)plt.imshow(img)plt.title(f'预测: {class_names[predicted[i]]}\n真实: {class_names[targets[i]]}')plt.axis('off')

plt.tight_layout()
plt.show()`

cb9001ca7a3e26279bacf9d8f4eb5d89
7392e107e8050a2271f2bd0c7a6f0b4d
79dcc099471f8af9055f42ded600b8c4

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/944767.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Gin笔记一之项目建立与运行

本文首发于公众号:Hunter后端 原文链接:Gin笔记一之项目建立与运行本篇笔记开始介绍 Golang 的 web 框架 Gin 的相关内容。 本系列笔记预计通过四到五篇笔记内容介绍 Gin 框架的核心内容,然后通过一个简单的示例介绍…

【题解】P14254 分割(divide)

想了想,感觉这道题还是总结一下为好。 这个题需要涉及大量证明,也是很恶心人了。 引理一:当第 \(1\) 个点选择了深度为 \(i\),那后续所有节点的深度就只能为 \(i\)。 证明:因为选择的点的深度是不降得,所以不能选…

10.23日学习笔记

一、应用层在 TCP/IP 中的位置 最顶层,直接为用户的应用程序提供网络服务。 不关心底下几层如何传输,只关心“消息语义”与“交换规则”。 典型交互模型: 客户/服务器(C/S) P2P(对等) 混合(边缘 P2P + 索引服务…

埃氏筛及扩展质因数筛——埃拉托斯特尼筛法变种

质数筛这段代码用 “埃拉托斯特尼筛法” 找 2 到 N 之间的所有素数,逻辑很直接:先假设所有数都是素数(用vis数组标记,初始全为true); 排除 0 和 1(它们不是素数,标记为false); 从 2 开始,对每个没被排除的数…

Day2路径,相对与绝对

路径指的是查找文件时,从起点到终点经历的路程 路径也分为绝对路径与相对路径 相对路径是从当前文件出发查找目标额文件 绝对路径是从盘符出发找目标文件 Windows电脑是从盘符出发的,而Mac电脑则是从根目录出发 在…

第九届强网杯线上赛PWN_flag-market

第九届强网杯线上赛PWN_flag-market第九届强网杯线上赛PWN_flag-market 一、题目二、信息搜集 下载题目给的附件,查看文件ctf.xinetd之后,知道我们的可执行程序名为chall:这个文件在附件中的bin目录下。 通过file命…

ISFB银行木马家族演化史:从Gozi到LDR4的技术剖析

本文深入分析ISFB银行木马家族十年演化历程,详细解析其技术架构、功能模块和分支变种,包括加载器、键盘记录、Web注入、VNC远程控制等核心功能,揭示网络犯罪组织的运作模式和技术演进。第1章 — 从Gozi到ISFB:一个…

exgcd板子

void exgcd(int &x,int &y,int a,int b) {if(!b){x=1;y=0;return;}exgcd(x,y,b,a%b);int t=x;x=y;y=t-a/b*y; }

2025.10.23

今天上午算法与数据结构的早八,然后跆拳道前半节课训练,后半节课进行体测,我对这次的成绩非常满意,中午没有点外卖,去食堂买了一个饼,回宿舍休息睡觉,晚上去科技楼制作本周六竞赛的PPT,一直到10点半。

Codeforces Round 976 (Div. 2) A. Find Minimum Operations

这个问题实际上是K进制取位和: 举例:2进制 n=110100 使用几次2的x次幂可以将n置0,ans=3 10进制 n=9924 使用几次10的x次幂可以将n置0,ans=9+9+2+4 k进制也相同 ,代码如下: `#include <bits/stdc++.h> using nam…

102302142罗伟钊第一次作业

1. 作业①: **1)、核心代码与输出 ** o 要求:用requests和BeautifulSoup库方法定向爬取给定网址(http://www.shanghairanking.cn/rankings/bcur/2020 )的数据,屏幕打印爬取的大学排名信息。代码是一个大学排名数…

一个基于 .NET 开源、功能强大的分布式微服务开发框架

前言 今天大姚给大家分享一个基于 .NET 开源、功能强大的分布式微服务开发框架:Anno.Core。Anno.Core 项目介绍 Anno.Core 是一个基于 .NET 开源、功能强大的分布式微服务开发框架,致力于简化分布式、微服务系统的构…

UE4学习笔记

基本操作窗口这里可以打卡很多视口设置可以通过设置书签到自己想要的视角视口世界大纲

20251021 NOIP模拟赛

T2 题目大意; 有一棵大小为 \(n\) 的树和 \(m\) 个关键点,你要从这 \(m\) 个关键点中随机选择 \(k\) 个点,问这 \(k\) 个点两两之间最长距离的期望是多少。 \(n \le 2000, m \le 300\) 解题思路: 最暴力的做法肯定…

RocketMQ+Spring Boot的简单实现及其深入分析

Producer搭建导入RocketMQ依赖和配置RocketMQ地址及producer的group:name<dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId>…

xcode程序创建文件存储位置

xcode创建的文件不在cpp文件所在位置,经过查找发现在下面这个地方/Users/用户名/Library/Developer/Xcode/DerivedData/employeesystem-dlmmqxmyqxjljjcoskekmpsbtstd/Build/Products/Debug employeesystem是项目名称…

欧拉操作系统搭建docker

欧拉安装dockerdocker官方没有支持欧拉的,因此使用的是centos7的docker源2者底层是类似的1、配置yum源和安装docker yum-config-manager --add-repo https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.re…

关于2025年暑假自主巡航小车脚本文件的学习笔记

脚本:gnome-terminal --window -e bash -c "roscore; exec bash" \ gnome-terminal命令用于新建一个GNOME桌面环境的终端程序(顶级窗口) 选项--window,新建一个窗口与默认行为一致(属于是显式写法,提高…

3dmax下载安装教程及激活教程(附安装包)3dmax2025超详细下载安装步骤

很多新手想装 3dmax 2025 却不知道从哪下手,别担心,这份 3dmax 2025 详细安装教程从下载到激活,再到软件用法,一步一步教你,保证看了就会,轻松解决 3dmax 2025 安装难题。目录3dmax 2025 到底好用在哪?3dmax 20…

RFSOC学习记录(五)带通采样定理

RFSOC学习记录(五),在配置adda的混频模式之前通过公式推导介绍了带通采样定理以及奈奎斯特分区​花了三篇文章的时间大致讲了讲我对于rfsoc时钟树的理解,非常的浅薄与浅应用,现在我再从原理层面记录一下我对于rf …