cifar10

news/2025/10/15 19:49:43/文章来源:https://www.cnblogs.com/xiaoguo1111/p/19144124

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from multiprocessing import freeze_support
import sys

1. 加载和预处理数据

def load_data():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True,num_workers=2
)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False,num_workers=2
)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')return trainloader, testloader, classes

2. 构建网络

class Net(nn.Module):
def init(self):
super().init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = torch.flatten(x, 1)  # 展平x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x

3. 编译网络(定义损失函数和优化器)

def compile_model(net):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
return criterion, optimizer

4. 训练网络(已同步设备)

def train(net, trainloader, criterion, optimizer, device, epochs=2):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 核心:数据与模型设备同步
inputs, labels = inputs.to(device), labels.to(device)

        # 梯度清零optimizer.zero_grad()# 前向计算 + 反向传播 + 优化参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印训练日志running_loss += loss.item()if i % 2000 == 1999:  # 每2000个batch打印一次print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('训练完成')

5. 测试网络(已同步设备)

def test(net, testloader, classes, device):
correct = 0
total = 0
# 测试时不计算梯度,加快速度
with torch.no_grad():
for data in testloader:
images, labels = data
# 数据与模型设备同步
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'测试集整体准确率: {100 * correct // total} %')# 按类别统计准确率
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predictions = torch.max(outputs, 1)# 统计每个类别的预测结果for label, prediction in zip(labels, predictions):if label == prediction:correct_pred[classes[label]] += 1total_pred[classes[label]] += 1# 打印各类别准确率
for classname, correct_count in correct_pred.items():accuracy = 100 * float(correct_count) / total_pred[classname]print(f'类别: {classname:5s} 准确率: {accuracy:.1f} %')

if name == 'main':
freeze_support() # 解决Windows多进程问题
# 自动选择设备(有GPU用GPU,无则用CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {device}")

# 加载数据、初始化模型和优化器
trainloader, testloader, classes = load_data()
net = Net().to(device)  # 模型放到指定设备
criterion, optimizer = compile_model(net)# 重定向输出到文件,同时保留控制台打印
original_stdout = sys.stdout
with open('cifar10_result.txt', 'w') as f:sys.stdout = fprint(f"当前使用设备: {device}")train(net, trainloader, criterion, optimizer, device)test(net, testloader, classes, device)sys.stdout = original_stdout  # 恢复控制台输出print("训练完成!结果已保存到 cifar10_result.txt ")

image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/937750.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[LangChain] 02. 模型接口

LangChain 支持两类主流语言模型:文本补全模型 对话模型文本补全模型 Text Completion Models 这类模型以一段纯文本作为输入,输出结果是一段连续生成的文字(这里的输出文本其实就是对前面输入文本的一个补全),不…

摄像头调试

camera调试经验分享 收藏 一 关于Sensor预览时有条纹: 1。电源不稳定,CMOS sensor对电源的稳定度蛮高的。 2。同步信号受干扰,彩色条纹显然是每行数据中有信号丢失造成。 3。检查mclk和pclk以及他们的ratio,软件…

软件工程作业-报告1 - 实践

软件工程作业-报告1 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", …

C语言学习——字符串数据类型

字符串的数据类型只有char来进行定义,相对之前的来讲较少同样打印的内容需要用引号来进行标注,同时也可以用与整数和小数的方法来进行测量字节 接下来我们对以上三种数据类型进行一个总结和概括: > 所有整数,小…

感知节点@4@ ESP32+arduino+ 第二个程序 LED灯显示

1、查看电路图,那个ESP32的引脚连接LED灯 图中看到是IO2 2、查找和打开例程Blink 3、按照电路图,定义引脚编号 4)编译下载固件 点击“上传”按钮,同时一直按住电路板上的BOOT(IO0)按钮,直到开始下载固件…

WebGL学习及项目实战(第02期:绘制一个点)

@目录目标WebGL原理示意图着色器顶点着色器:片元着色器:着色器代码如下web端(js)js代码代码结构梳理流程图完整代码(可直接在浏览器中查看)运行效果 目标使用WebgL绘制一个点 了解整个绘制的编写流程并进行梳理和…

2025 年 10 月国内加工中心制造商最新推荐排行榜:涵盖立式、卧式、龙门及多规格型号!

当前加工中心市场厂商数量繁杂,产品质量、技术实力与服务水平差异显著,汽车摩托车、工程机械、军工等行业采购方在挑选设备时,常面临不知如何辨别优质厂商、耗费大量时间调研却难觅适配产品的困境。部分厂商存在技术…

display ip routing-table protocol ospf 概念及题目 - 详解

display ip routing-table protocol ospf 概念及题目 - 详解pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Conso…

C语言学习——小数数据类型

> 小数的数据类型分为:float、double > 不同的数据类型所表示的范围和内存大小都不一样可以用,由编译器来决定,可以用sizeof来决定 > 小数的取值范围比整数的大 > c语言的的小数默认double联系的…

高敏感人应对焦虑

2️⃣ 注意力外投法:从内心的泥潭中抽身 烦躁和无法学习,往往是因为注意力被困在了内心的思虑中。我们需要强行把它拉出来,投向外部世界。5-4-3-2-1感官法:具体操作:无论在何处,立刻(在心里或小声地)说出:5 个…

kali构建PHP_MYSQL

kali构建PHP_MYSQL配置Mysql sudo mysql -u root //第一次可以直接进入 alter user root@localhost identified by 123456; create database usr;配置PHP 进入目录:cd /etc/php/8.2/apache2 执行:sudo vim +904 php.in…

Palantir本体论以及对智能体建设的价值与意义

Palantir本体论以及对智能体建设的价值与意义 赋能智能体:Palantir Foundry本体工程如何构建企业级AI的“可编程数字孪生”摘要: 随着大语言模型(LLM)驱动的智能体(AI Agent)成为企业数字化的核心驱动力,传统的…

2025 年执业兽医资格证备考服务机构推荐榜,执业兽医资格证培训机构/执兽考试机构/考试辅导机构获得行业推荐

随着养殖业规模化发展与行业规范化推进,执业兽医资格证已成为从业人员开展专业工作的核心凭证,对应的备考需求逐年增长。但当前市场上,执业兽医资格证备考服务机构在课程适配性、师资专业性、服务响应效率等方面存在…

[LangChain] 基本介绍

在大模型时代,LangChain 是一个帮助开发者快速构建“智能应用” 的工具框架。它像是你搭建 AI 应用时的“万能胶水”——把大模型(如 OpenAI、LLM API)、工具(如搜索引擎、数据库)、记忆能力、链式调用等模块统统…

题解:P6755 [BalticOI 2013] Pipes (Day1)

P6755:构造、图论、拓扑排序、线性代数。题目等价于:给定一个无向图和所有点的点权,给每条边确定一个边权,使得每个点的点权等于与其相连所有边的边权和除以二。特别地,如果无解或有无数解,只需输出 \(0\) 即可。…

深度学习调试记录 - 详解

深度学习调试记录 - 详解2025-10-15 19:18 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; fo…

Palantir 的“本体工程”的核心思路、技术架构与实践示例

Palantir 的“本体工程”的核心思路、技术架构与实践示例引言:为什么“本体工程”在当下越来越被强调? 在 AI+业务落地的浪潮中,很多团队一开始聚焦于模型(LLM、对话模型、检索模型等)、Prompt 设计、上下文检索(…

P14164 [ICPC 2022 Nanjing R] 命题作文

给定一个包含 \(n\) 个点的链,\(m\) 次每次额外添加一条边,操作之间不独立。每次操作完询问有多少种方案选出两条边使得删除这两条边之后图不联通。 \[n,m \le 2.5\times10^5 \] 称额外添加的边为额外边,原来的 \(n…

C语言学习——整数变量

一.整数变量有四种类型数据分别是以下四种形式二.以下是测量数据字节的方式,当需要测量数据的字节时可以通过以下方式进行测量三.有符号整数和无符号整数的定义情况> 注意: > 在用有符号整数定义负数时占位符也…

语音合成技术从1秒样本学习表达风格

某中心研究人员开发的新型语音合成系统仅需1秒语音样本即可学习表达风格,通过变分自编码器和标准化流技术实现表达风格转换,用户评价显示合成语音自然度提升9%。语音合成器从一秒语音样本学习表达风格 用户评价显示,…