用PyTorch搭建卷积神经网络实现MNIST手写数字识别

用PyTorch搭建卷积神经网络实现MNIST手写数字识别

在深度学习领域,卷积神经网络(Convolutional Neural Network,简称CNN)是处理图像数据的强大工具。它通过卷积层、池化层和全连接层等组件,自动提取图像特征,在图像分类、目标检测等任务中表现卓越。本文将使用PyTorch框架,搭建一个CNN模型来实现MNIST手写数字识别,并详细解析每一步代码。

一、MNIST数据集介绍

MNIST数据集是深度学习领域经典的入门数据集,包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28x28像素,并且已经做了居中处理,这在一定程度上减少了预处理的工作量,能够加快模型的训练和运行速度。

二、环境准备与数据加载

2.1 导入必要的库

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

上述代码导入了PyTorch的核心库、神经网络模块、数据加载工具以及用于图像数据处理和数据集管理的库。

2.2 下载并加载数据集

training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor()
)test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)

通过datasets.MNIST函数分别下载训练集和测试集。root参数指定数据下载的路径;train=True表示下载训练集数据,train=False则表示下载测试集数据;download=True确保如果数据尚未下载,会自动进行下载;transform=ToTensor()将图像数据转换为PyTorch能够处理的张量格式。

2.3 数据可视化

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

这段代码使用matplotlib库展示了训练数据集中的部分手写数字图像,通过plt.imshow函数将张量格式的图像数据可视化,直观感受MNIST数据集的内容。

2.4 创建数据加载器

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

DataLoader用于将数据集打包成批次,batch_size参数指定每个批次包含的数据样本数量。将数据集分成批次进行训练,能够有效减少内存使用,并提高训练速度。

三、设备配置

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这段代码检测当前设备是否支持GPU(CUDA)或苹果M系列芯片的GPU(MPS),如果都不支持,则使用CPU进行计算。后续模型和数据都会被移动到选定的设备上运行,以充分利用硬件资源加速训练。

四、定义卷积神经网络模型

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU())self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

在这个自定义的CNN类中,继承自nn.Module__init__方法中定义了网络的结构:

  • 卷积层(nn.Conv2d:用于提取图像特征,通过设置in_channels(输入通道数)、out_channels(输出通道数,即卷积核个数)、kernel_size(卷积核大小)、stride(步长)和padding(填充)等参数,控制卷积操作。
  • 激活函数层(nn.ReLU:引入非线性,增强网络的表达能力。
  • 池化层(nn.MaxPool2d:对特征图进行下采样,减少数据量和计算量,同时保留主要特征。
  • 全连接层(nn.Linear:将卷积层和池化层提取的特征映射到输出类别(MNIST数据集中有10个数字类别)。

forward方法定义了数据在网络中的前向传播路径,确保数据按照网络结构依次经过各层处理,最终输出预测结果。

五、训练与测试模型

5.1 定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

nn.CrossEntropyLoss是适用于多分类任务的交叉熵损失函数,用于计算模型预测结果与真实标签之间的差距。torch.optim.Adam是一种常用的优化器,通过调整模型的参数(model.parameters())来最小化损失函数,lr参数设置学习率,控制参数更新的步长。

5.2 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1

在训练函数中:

  • model.train()将模型设置为训练模式,此时模型中的一些层(如Dropout层)会按照训练规则工作。
  • 遍历数据加载器中的每一个批次数据,将数据和标签移动到指定设备上。
  • 通过模型进行预测,计算损失值。
  • 使用optimizer.zero_grad()清零梯度,loss.backward()进行反向传播计算梯度,optimizer.step()根据梯度更新模型参数。
  • 每隔100个批次,打印当前的损失值,以便观察训练过程中的损失变化。

5.3 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

测试函数中:

  • model.eval()将模型设置为测试模式,关闭一些在训练过程中起作用但在测试时不需要的操作(如Dropout)。
  • 使用with torch.no_grad()上下文管理器,关闭梯度计算,因为在测试阶段不需要更新模型参数,这样可以节省计算资源。
  • 遍历测试数据,计算每个批次的损失值并累加,同时统计预测正确的样本数量。
  • 最后计算并打印测试集上的平均损失和准确率,评估模型的性能。

5.4 执行训练和测试

epoch = 9
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)

通过设置训练轮数(epoch),循环调用训练函数进行模型训练,每一轮训练结束后,调用测试函数评估模型在测试集上的性能。

六、总结

本文通过详细的代码解析,展示了如何使用PyTorch搭建一个简单的卷积神经网络来实现MNIST手写数字识别任务。从数据加载、模型定义,到训练和测试,每一个步骤都体现了CNN在图像分类任务中的核心思想和实现方法。通过不断调整模型结构、超参数等,还可以进一步提升模型的性能。卷积神经网络在图像领域的应用远不止于此,它在更复杂的图像任务和其他领域也有着广泛的应用前景,希望本文能为大家深入学习深度学习提供一个良好的开端。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/79857.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Tensorrt 基础入门

什么是tensorrt? 其他厂商: Qualcomm, Hailo, google TPU tensorrt的优劣势 使用tensorrt的pipeline tensorrt使用中存在的问题以及解决方案 tensorrt的应用场景 自动驾驶模型部署需要关注的问题: 边端硬件资源有限 散热(不能水冷) 实时性&…

Qt 显示QRegExp 和 QtXml 不存在问题

QRegExp 和 QtXml 问题 在Qt6 中 已被弃用; 1)QRegExp 已被弃用,改用 QRegularExpression Qt5 → Qt6 重大变更:QRegExp 被移到了 Qt5Compat 模块,默认不在 Qt6 核心模块中。 错误类型解决方法QRegExp 找不到改用 Q…

玩玩OCR

一、Tesseract: 1.下载windows版: tesseract 2. 安装并记下路径,等会要填 3.保存.py文件 import pytesseract from PIL import Image def ocr_local_image(image_path):try:pytesseract.pytesseract.tesseract_cmd rD:\Programs\Tesseract-OCR\tesse…

Dify 完全指南(一):从零搭建开源大模型应用平台(Ollama/VLLM本地模型接入实战)》

文章目录 1. 相关资源2. 核心特性3. 安装与使用(Docker Compose 部署)3.1 部署Dify3.2 更新Dify3.3 重启Dify3.4 访问Dify 4. 接入本地模型4.1 接入 Ollama 本地模型4.1.1 步骤4.1.2 常见问题 4.2 接入 Vllm 本地模型 5. 进阶应用场景6. 总结 1. 相关资源…

C++ Windows 打包exe运行方案(cmake)

文章目录 背景动态库梳理打包方案一、使用 Vcpkg 安装静态库(关键基础配置)1. 初始化 Vcpkg2. 安装静态库(注意 x64-windows-static 后缀) 二、CMakeLists.txt 关键配置三、编译四、验证 不同平台代码兼容\_\_attribute\_\_((pack…

Java学习手册:Hibernate/JPA 使用指南

一、Hibernate 和 JPA 的核心概念 实体(Entity) :实体是 JPA 中用于表示数据库表的 Java 对象。通过在实体类上添加 Entity 注解,JPA 可以将实体类映射到数据库表。例如,定义一个 User 实体类: import ja…

字符串匹配 之 拓展 KMP算法(Z算法)

文章目录 习题2223.构造字符串的总得分和3031.将单词恢复初始状态所需的最短时间 II 灵神代码模版 区别与KMP算法 KMP算法可用于求解在线性时间复杂度0(n)内求解模式串p在主串s中匹配的未知当然,由于在KMP算法中,预处理求解出了next数组,也就…

安全为上,在系统威胁建模中使用量化分析

*注:Open FAIR™ 知识体系是一种开放和独立的信息风险分析方法。它为理解、分析和度量信息风险提供了分类和方法。Open FAIR作为领先的风险分析方法论,已得到越来越多的大型组织认可。 在数字化风险与日俱增的今天,企业安全决策正面临双重挑战…

游戏引擎学习第259天:OpenGL和软件渲染器清理

回顾并为今天的内容做好铺垫 今天,我们将对游戏的分析器进行升级。在之前的修复中,我们解决了分析器的一些敏感问题,例如它无法跨代码重新加载进行分析,以及一些复杂的小问题。现在,我们的分析器看起来已经很稳定了。…

讯睿CMS模版常用标签参数汇总

一、模板调用标签 1、首页 网站名称:{SITE_NAME} 标题:{$meta_title}(列表页通用) Keywords:{$meta_keywords} Description:{$meta_description}2、列表页 迅睿cms调用本栏目基础信息标签代码 当前栏目…

【C#】Buffer.BlockCopy的使用

Buffer.BlockCopy 是 C# 中的一个方法,用于在数组之间高效地复制字节块。它主要用于操作字节数组(byte[]),但也可以用于其他类型的数组,因为它直接基于内存操作。 以下是关于 Buffer.BlockCopy 的详细说明和使用示例&…

记一次pdf转Word的技术经历

一、发现问题 前几天在打开一个pdf文件时,遇到了一些问题,在Win10下使用WPS PDF、万兴PDF、Adobe Acrobat、Chrome浏览器打开都是正常显示的;但是在macOS 10.13中使用系统自带的预览程序和Chrome浏览器(由于macOS版本比较老了&am…

在Laravel 12中实现4A日志审计

以下是在Laravel 12中实现4A(认证、授权、账户管理、审计)日志审计并将日志存储到MongoDB的完整方案(包含性能优化和安全增强措施): 一、环境配置 安装MongoDB扩展包 composer require jenssegers/mongodb配置.env …

链表高级操作与算法

链表是数据结构中的基础,但也是面试和实际开发中的重点考察对象。今天我们将深入探讨链表的高级操作和常见算法,让你能够轻松应对各种链表问题。 1. 链表翻转 - 最经典的链表问题 链表翻转是面试中的常见题目,也是理解链表指针操作的绝佳练…

架构思维:构建高并发读服务_使用懒加载架构实现高性能读服务

文章目录 一、引言二、读服务的功能性需求三、两大基本设计原则1. 架构尽量不要分层2. 代码尽可能简单 四、实战方案:懒加载架构及其四大挑战五、改进思路六、总结与思考题 一、引言 在任何后台系统设计中,「读多写少」的业务场景占据主流:浏…

在运行 Hadoop 作业时,遇到“No such file or directory”,如何在windows里打包在虚拟机里运行

最近在学习Hadoop集群map reduce分布运算过程中,经多方面排查可能是电脑本身配置的原因导致每次运行都会报“No such file or directory”的错误,最后我是通过打包文件到虚拟机里运行得到结果,具体步骤如下: 前提是要保证maven已经…

软考-软件设计师中级备考 11、计算机网络

1、计算机网络的分类 按分布范围分类 局域网(LAN):覆盖范围通常在几百米到几千米以内,一般用于连接一个建筑物内或一个园区内的计算机设备,如学校的校园网、企业的办公楼网络等。其特点是传输速率高、延迟低、误码率低…

【C#】.net core6.0无法访问到控制器方法,直接404。由于自己的不仔细,出现个低级错误,这让DeepSeek看出来了,是什么错误呢,来瞧瞧

🌹欢迎来到《小5讲堂》🌹 🌹这是《C#》系列文章,每篇文章将以博主理解的角度展开讲解。🌹 🌹温馨提示:博主能力有限,理解水平有限,若有不对之处望指正!&#…

当LLM遇上Agent:AI三大流派的“复仇者联盟”

你一定听说过ChatGPT和DeepSeek,也知道它们背后的LLM(大语言模型)有多牛——能写诗、写代码、甚至假装人类。但如果你以为这就是AI的极限,那你就too young too simple了! 最近,**Agent(智能体&a…

Spring Boot多模块划分设计

在Spring Boot多模块项目中,模块划分主要有两种思路:​​技术分层划分​​和​​业务功能划分​​。两种方式各有优缺点,需要根据项目规模、团队结构和业务特点来选择。 ​​1. 技术分层划分(横向拆分)​​ 结构示例&…