深入解析:使用 PyTorch 实现 CIFAR-10 图像分类:从数据加载到模型训练全流程

news/2025/10/25 11:55:24/文章来源:https://www.cnblogs.com/lxjshuju/p/19165117

在深度学习入门实践中,CIFAR-10 数据集分类是一个经典案例。本文将详细介绍如何使用 PyTorch 构建一个卷积神经网络 (CNN) 来完成 CIFAR-10 图像分类任务,涵盖数据加载、模型构建、训练过程和结果评估的完整流程。

项目概述

CIFAR-10 是一个包含 10 个类别的彩色图像数据集,每个类别有 6000 张 32×32 像素的图像,共 60000 张图像,分为 50000 张训练集和 10000 张测试集。10 个类别分别是:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。

本项目将实现一个简单的卷积神经网络,使用 PyTorch 框架完成从数据加载到模型评估的全流程,并达到不错的分类效果。

一、数据加载与预处理

首先需要加载 CIFAR-10 数据集并进行必要的预处理:

import torch
import torchvision
import torchvision.transforms as transforms
# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]范围
])
# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='C:\\Users\\Administrator\\Desktop\\Untitled Folder\\cifar-10-batches-py',train=True,download=False,  # 已手动放置数据集,无需下载transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,  # 批次大小为4shuffle=True,  # 打乱数据顺序num_workers=0  # Windows环境建议设为0,避免多进程加载报错
)
# 加载测试集
testset = torchvision.datasets.CIFAR10(root='C:\\Users\\Administrator\\Desktop\\Untitled Folder\\cifar-10-batches-py',train=False,download=False,transform=transform
)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,  # 测试集不需要打乱num_workers=0
)
# CIFAR10类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

核心说明

  • transforms.Compose用于组合多个预处理操作
  • ToTensor()将 PIL 图像转换为 PyTorch 张量,并将像素值从 [0,255] 缩放到 [0,1]
  • Normalize()进行归一化,使数据均值为 0,标准差为 1,有助于模型收敛
  • DataLoader用于批量加载数据,并支持多进程加速(Windows 下建议关闭)
  • shuffle=True确保训练时每个 epoch 的数据顺序都不同,有助于模型泛化

二、数据可视化

加载数据后,我们可以编写一个函数来可视化数据,了解我们要处理的图像:

def show_images(tensor_images, labels, class_names):"""显示批量图像并打印对应标签"""# 反归一化(还原图像亮度)tensor_images = tensor_images / 2 + 0.5# 转换为PIL图像网格img_grid = torchvision.utils.make_grid(tensor_images)img = torchvision.transforms.ToPILImage()(img_grid)# 显示图像(调用系统默认图像查看器)img.show()# 打印对应标签print("图像标签:", ' '.join(f"{class_names[labels[j]]:5s}" for j in range(len(labels))))
# 测试:加载并显示一批训练数据
dataiter = iter(trainloader)
images, labels = next(dataiter)  # 获取一批数据(4张图像)
# 显示图像和标签
show_images(images, labels, classes)

核心说明

  • 由于之前对图像进行了归一化,需要通过tensor_images / 2 + 0.5反归一化才能正确显示
  • torchvision.utils.make_grid()可以将多张图像组合成一个网格图像
  • ToPILImage()将张量转换回 PIL 图像格式以便显示

三、构建卷积神经网络模型

接下来我们构建一个简单的卷积神经网络:

import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()# 第一个卷积层:3输入通道,16输出通道,5x5卷积核self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)# 第一个池化层:2x2池化核,步长为2self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二个卷积层:16输入通道,36输出通道,3x3卷积核self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)# 第二个池化层:2x2池化核,步长为2self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 第一个全连接层self.fc1 = nn.Linear(1296, 128)# 第二个全连接层(输出层,10个类别)self.fc2 = nn.Linear(128, 10)def forward(self, x):# 卷积->激活->池化x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))# 展平特征图x = x.view(-1, 36 * 6 * 6)# 全连接层->激活x = F.relu(self.fc1(x))# 输出层x = self.fc2(x)return x
# 初始化模型并移动到可用设备(GPU或CPU)
net = CNNNet()
net = net.to(device)
# 打印模型总参数数量
print("net have {} parameters in total".format(sum(x.numel() for x in net.parameters())))

核心说明

  • 模型采用经典的卷积 - 池化 - 全连接结构
  • Conv2d层负责提取图像特征,通过卷积核捕获局部特征
  • MaxPool2d层进行下采样,减少特征图尺寸,同时保留重要特征
  • forward方法定义了数据在网络中的流动路径
  • x.view(-1, 36 * 6 * 6)将二维特征图展平为一维向量,以便输入全连接层
  • 代码会自动检测并使用 GPU(如果可用),否则使用 CPU

我们还可以通过以下代码获取网络的特征提取部分(前 4 层):

# 获取网络的特征提取部分(卷积层和池化层)
feature_extractor = nn.Sequential(*list(net.children())[:4])

四、定义损失函数和优化器

训练神经网络需要定义损失函数和优化器:

import torch.optim as optim
# 学习率
LR = 0.001
# 损失函数:交叉熵损失,适用于分类任务
criterion = nn.CrossEntropyLoss()
# 优化器:SGD(随机梯度下降)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 也可以使用Adam优化器,通常收敛更快
# optimizer = optim.Adam(net.parameters(), lr=LR)
# 打印网络结构
print(net)

核心说明

  • 交叉熵损失 (CrossEntropyLoss) 是分类任务的常用损失函数
  • SGD 优化器带有动量 (momentum=0.9) 可以加速收敛并减少震荡
  • Adam 优化器通常收敛更快,但在某些任务上 SGD 可能泛化更好
  • 学习率 (lr) 是重要的超参数,过大会导致不收敛,过小会导致收敛太慢

五、训练模型

模型和数据准备就绪后,就可以开始训练了:

for epoch in range(10):  # 迭代10个epochrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# 获取训练数据并移动到设备inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 权重参数梯度清零optimizer.zero_grad()# 正向传播、计算损失、反向传播、参数更新outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计并显示损失值running_loss += loss.item()if i % 2000 == 1999:    # 每2000个mini-batch打印一次print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0
print('Finished Training')

核心说明

  • 训练过程由多个 epoch 组成,每个 epoch 遍历整个训练集一次
  • 每个 epoch 又分为多个 mini-batch,按批次处理数据
  • optimizer.zero_grad()清除上一次迭代的梯度
  • loss.backward()计算梯度(反向传播)
  • optimizer.step()根据梯度更新参数
  • 定期打印损失值可以监控训练进度,损失总体应该呈下降趋势

六、模型测试与评估

训练完成后,我们需要测试模型的性能:

# 使用训练好的模型进行预测
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

我们还可以批量查看测试集的预测结果:

# 加载测试集并显示预测结果
dataiter = iter(testloader)
for i in range(100):  # 控制显示的批次数try:images, labels = next(dataiter)images, labels = images.to(device), labels.to(device)print(f"第{i+1}批图像:")show_pil_image(images.cpu())  # 转回CPU才能显示# 真实标签print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))# 预测结果outputs = net(images)_, predicted = torch.max(outputs, 1)print('Predicted:   ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))print('-'*50)  # 分隔线except StopIteration:break  # 数据迭代完则停止

核心说明

  • torch.max(outputs, 1)返回每个样本的最大预测值和对应的类别索引
  • 预测时需要将数据移动到与模型相同的设备(GPU 或 CPU)
  • 显示图像时需要将张量转回 CPU
  • 通过对比GroundTruth(真实标签)和Predicted(预测结果)可以直观了解模型性能

七、总结与改进方向

本项目实现了一个简单的 CNN 模型用于 CIFAR-10 分类,通过 10 个 epoch 的训练,通常可以达到 60%-70% 的准确率。这个结果对于基础模型来说已经不错,但还有很大的提升空间:

  1. 增加网络深度和宽度:可以尝试使用更深的网络结构,如 VGG、ResNet 等
  2. 数据增强:增加更多的数据增强手段,如随机裁剪、旋转、翻转等,提高模型泛化能力
  3. 调整超参数:尝试不同的学习率、批次大小、优化器等
  4. 正则化:添加 Dropout 层或 L2 正则化,减少过拟合
  5. 学习率调度:使用学习率衰减策略,使训练更稳定

通过这个项目,我们掌握了使用 PyTorch 进行图像分类的完整流程,包括数据加载与预处理、模型构建、训练过程和结果评估。这些技能可以迁移到其他图像分类任务中,是深度学习入门的重要实践。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/946074.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年冷库保温建材工厂权威推荐榜单:泡沫模块建大棚/检修用围栏/绝缘围栏源头厂家精选

在冷链物流行业快速发展的今天,优质保温建材成为保障冷库效能的关键。 随着2025年全球冷藏室保温板市场规模预计接近367.6亿元,中国已成为全球最大市场,占比升至42%。行业在新版冷库节能设计规范下迎来结构性变革,…

完整教程:营销驱动式增长(MLG)是什么?解析模式、策略与实践案例

完整教程:营销驱动式增长(MLG)是什么?解析模式、策略与实践案例pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: &quo…

2025 年防冻液源头厂家最新推荐口碑排行榜:严检合格技术为先,实力企业权威甄选食品级/空气能专用/长效防冻液公司推荐

引言 在工业生产与交通运输领域,防冻液对设备低温运行至关重要。为精准筛选优质厂家,中国化工产业协会联合设备防护技术联盟开展 2025 年防冻液源头厂家测评,采用 “三维九项” 测评方法。从产品核心性能(冰点控制…

2025 年冷藏车厂家最新推荐排行榜:结合协会测评权威数据,详解优质品牌特点与选购指南 9.6 米 / 解放 / 4.2 米 / 福田 / 小型冷藏车公司推荐

引言 随着冷链物流行业规模持续扩大,冷藏车作为核心运输装备,市场需求年均增速超 15%。但市场产品质量参差不齐,为帮助消费者精准选择,中国物流与采购联合会冷链物流专业委员会联合行业权威机构开展 2025 年度冷藏…

2025 年铣边机源头厂家最新推荐排行榜:含钢板 / 平板 / 板材 / 自走式 / 全自动铣边机机型,结合协会测评数据甄选实力企业

引言 在航空航天、压力容器、电厂、石油化工等关键工业领域,铣边机的性能与品质直接决定生产精度与效率。为助力企业精准选择设备,中国重型机械工业协会联合行业检测机构开展 2025 年度铣边机厂家测评,本次测评覆盖…

2025 年载冷剂厂家推荐排行榜:无醇/安全型/SH-4/SH-5A/多元醇/高低温/超低温/乙二醇/冷库专用/食品级载冷剂公司推荐

引言 在工业制冷与商业冷链领域,载冷剂的品质直接决定系统运行效率与安全,当前市场产品质量差异显著,企业采购面临重重难题。为破解这一困境,本次榜单由中国制冷空调工业协会全程参与测评,采用 “四维权威评估体系…

[网络] [TCP] 使用py脚本简单实现tcp通信发送/储存文件

[网络] [TCP] 使用py脚本简单实现tcp通信发送/储存文件$(".postTitle2").removeClass("postTitle2").addClass("singleposttitle");ChatGPT生成(2025年10月25日11:48:12)目录服务端(…

《手搓》线程池

《手搓》线程池一、什么是《手搓》线程池手搓线程池并不是用来完全代替系统线程池的 你可以把手搓线程池看做系统线程池的一部分 就好比在东海用大的集装箱搞养殖 一个集装箱里养鱼 另一个集装箱里养虾 搞好隔离,鱼虾都…

kali wsl桌面使用

Kali Linux WSL系统及Win-Kex图形界面完整指南 Kali Linux WSL系统介绍 Kali Linux是基于Debian的Linux发行版,专为渗透测试和安全审计设计。Windows Subsystem for Linux (WSL)是微软开发的Windows子系统,允许用户在…

2025 年房屋改造公司最新推荐榜,聚焦企业服务能力与市场口碑深度解析老房 / 旧房 / 局部 / 小户型 / 出租房房屋改造推荐

引言 伴随存量房市场规模持续扩大,房屋改造需求年增长率超 15%,但 “服务断层、质量无保障、价格不透明” 仍是行业突出痛点。为破解选择难题,本次榜单依托中国建筑装饰协会 2025 年三季度《住宅改造服务测评白皮书…

单点登录的完成原理

单点登录的完成原理pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "C…

2025 年桥梁防撞护栏优质厂家最新推荐榜:涵盖锌钢 / ZF01/Q235/Q355B / 景观 / 灯光 / 河道 / 公路 / 喷塑等类型,全方位解析实力企业

引言 当前交通基础设施建设持续推进,桥梁作为核心交通枢纽,其安全防护体系中桥梁防撞护栏的重要性愈发凸显。然而,市场上防撞护栏产品质量参差不齐,部分厂家用劣质原材料降低成本,小作坊式企业因技术落后导致产品…

2025年欧那德语:权威解析课程体系与师资实力

引言:本文将从“课程体系与师资实力”这一核心维度出发,为读者提供一份可量化、可验证、可对照的客观参考,帮助正在比对在线德语培训产品的学习者快速锁定关键信息,减少试错成本。 背景与概况:欧那德语成立于2013…

模拟赛 R18

R18 - A 子集计数 题目描述 给定一长度为 \(n\) 的序列 \(a_1,a_2,\cdots,a_n\),再额外给定一常数 \(m\),对每个 \(k=0,1,2,\cdots,n\),请你求出有多少个 \(S\subset \{1,2,3,\cdots,n\}\),满足存在 \(T\subset S\…

2025年盐趣科研教育深度解析:从录取数据到成果落地的全链路拆解

引言 本文聚焦“录取成果与科研落地”这一核心维度,为计划通过科研背景提升冲刺海外名校的申请者提供一份可量化、可验证的客观参考,避免被营销话术裹挟。 背景与概况 盐趣科研教育(ViaX盐趣,官网www.viax.org)成…

2025年盐趣科研教育深度解析:从“科研背景”维度拆解留学突围路径

引言 本文聚焦“科研背景”这一核心维度,对盐趣科研教育进行针对性拆解,为计划通过科研提升申请竞争力的学生提供一份可落地的客观参考。 背景与概况 盐趣科研教育(ViaX盐趣,www.viax.org)成立于2015年,隶属于北…

大素材数据质量校验实战指南:从0.3%差异率到滴水不漏的核对体系

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

备战CSP:考试环境搭建与使用指南

备战CSP:考试环境搭建与使用指南 大家好!“磨刀不误砍柴工”,今天我们的目标就是磨好手上的这把“刀”——完全熟悉并掌握CSP的官方考试环境。熟练操作环境,可以在考场上为你节省宝贵的时间,避免不必要的慌乱。让…

2025年1月暖风机口碑榜:五款主流机型对比与选购避坑

寒冬深夜,孩子写作业手脚冰凉,老人起床怕冷,上班族回家想立刻暖起来——这些场景让“暖风机”成为搜索热词。可打开电商页面,PTC、石墨烯、远红外、四核等名词扑面而来,价格从百元到千元不等,噪音、耗电、干燥、…

2025 年最新推荐装修公司优质品牌排行榜:聚焦环保与工艺,口碑装修公司权威甄选

引言 随着家装市场需求持续增长,消费者对装修品质、环保标准及服务体验的要求愈发严苛。为助力业主精准筛选优质装修资源,本次榜单由中国建筑装饰协会联合行业权威测评机构共同打造,历经 3 个月调研,覆盖全国 28 个…