用卷积神经网络 (CNN) 实现 MNIST 手写数字识别

在深度学习领域,MNIST 手写数字识别是经典的入门级项目,就像编程世界里的 “Hello, World”。卷积神经网络(Convolutional Neural Network,CNN)作为处理图像数据的强大工具,在该任务中展现出卓越的性能。本文将结合具体的 PyTorch 代码,详细解析如何利用 CNN 实现 MNIST 手写数字识别,带大家从代码实践深入理解背后的技术原理。

一、数据准备:加载与预处理 MNIST 数据集

MNIST 数据集包含 6 万张训练图像和 1 万张测试图像,涵盖 0 - 9 这十个数字的手写体。我们借助torchvision库中的datasets.MNIST函数来加载数据,具体代码如下:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortraining_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)

上述代码中,root="data"指定数据集的存储路径;train=True表示加载训练集,train=False用于加载测试集;download=True确保本地无数据集时自动下载;transform=ToTensor()将图像数据转换为 PyTorch 张量格式,并把像素值从 0 - 255 归一化到 0 - 1 区间,便于后续处理。

为直观感受数据,我们用matplotlib库绘制 9 张训练图像及其标签:

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")a = img.squeeze()
plt.show()

完成数据加载后,使用DataLoader将数据封装成批次,方便模型训练和测试:

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

batch_size=64意味着每次训练或测试,模型会同时处理 64 个样本,能提高计算效率和训练稳定性。

二、模型构建:搭建卷积神经网络架构

我们定义一个名为CNN的类,继承自nn.Module,用于构建卷积神经网络:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1,),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1),nn.ReLU(),)self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

  • 卷积层(nn.Conv2d:在conv1conv2conv3中,通过卷积层提取图像特征。例如conv1中的nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)in_channels=1表示输入图像为单通道灰度图,out_channels=16表示输出 16 个特征图,kernel_size=3指定 3×3 的卷积核,stride=1是步长,padding=1用于保持图像尺寸不变。
  • 激活函数(nn.ReLU:紧跟在卷积层之后,为模型引入非线性,帮助模型学习复杂的模式。
  • 池化层(nn.MaxPool2d:通过下采样操作,如nn.MaxPool2d(2)将图像尺寸减半,减少数据量和模型参数,同时保留重要特征,防止过拟合。
  • 全连接层(nn.Linearself.out = nn.Linear(64 * 7 * 7, 10)将卷积层输出的特征图展平后连接到全连接层,输出 10 个神经元对应 0 - 9 十个数字类别,完成最终分类。

最后,将模型移动到合适的计算设备(GPU、MPS 或 CPU)上:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = CNN().to(device)
print(model)

三、模型训练与测试:优化与评估

3.1 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

在训练函数中,model.train()将模型设为训练模式。遍历数据加载器,将每一批数据和标签移至指定设备,前向传播计算预测值,通过交叉熵损失函数nn.CrossEntropyLoss()计算损失,optimizer.zero_grad()清空梯度,loss.backward()反向传播计算梯度,optimizer.step()更新模型参数,每 100 个批次打印一次损失值。

3.2 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return test_loss, correct

测试函数中,model.eval()将模型设为评估模式,关闭如 Dropout 等训练时的操作。在with torch.no_grad()下遍历测试数据,计算测试损失和正确预测的样本数,最后计算平均损失和准确率并输出。

3.3 执行训练与测试

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

我们选用交叉熵损失函数和 Adam 优化器,学习率设为 0.01,通过 10 个训练周期不断优化模型,训练完成后在测试集上评估模型性能,得到最终的准确率和平均损失。

四、总结与展望

通过上述代码实践,我们成功利用卷积神经网络实现了 MNIST 手写数字识别。从数据加载、模型构建到训练测试,每个环节都紧密相连,展示了 CNN 在图像识别任务中的强大能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/81910.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从 MDM 到 Data Fabric:下一代数据架构如何释放 AI 潜能

从 MDM 到 Data Fabric:下一代数据架构如何释放 AI 潜能 —— 传统治理与新兴架构的范式变革与协同进化 引言:AI 规模化落地的数据困境 在人工智能技术快速发展的今天,企业对 AI 的期望已从 “单点实验” 转向 “规模化落地”。然而&#…

苍穹外卖部署到云服务器使用Docker

部署前端 1.创建nginx镜像 docker pull nginx 2.宿主机(云服务器)创建挂载目录和文件 最好手动创建 而不是通过docker run创建,否则nginx.conf 默认会被创建为文件夹 nginx.conf 和html可以直接从黑马给的资料里导入 3.运行nginx容器&am…

C++ 渗透 数据结构中的二叉搜索树

欢迎来到干货小仓库 "沙漠尽头必是绿洲。" --面对技术难题时,坚持终会看到希望。 1.二叉搜索树的概念 二叉搜索树又称二叉排序树,它或者是一颗空树,或者是具有以下性质的二叉树: a、若它的左子树不为空,则…

实现滑动选择器从离散型的数组中选择

1.使用原生的input 详细代码如下&#xff1a; <template><div class"slider-container"><!-- 滑动条 --><inputtype"range"v-model.number"sliderIndex":min"0":max"customValues.length - 1"step&qu…

ARM寻址方式

寻址方式指的是确定操作数位置的方式。 寻址方式&#xff1a; 立即数寻址 直接寻址&#xff08;绝对寻址&#xff09;&#xff0c;ARM不支持这种寻址方式&#xff0c;但所有CISC处理器都支持 寄存器间接寻址 3种寻址方式总结如下&#xff1a; 助记符 RTL格式 描述 ADD r0,r1…

学苑教育杂志学苑教育杂志社学苑教育编辑部2025年第9期目录

专题研究 核心素养下合作学习在初中数学中的应用 郑铁洪; 4-6 教育管理 小学班级管理应用赏识教育的策略研究 芮望; 7-9 课堂教学 小学数学概念教学的实践策略 刘淑萍; 10-12 “减负提质”下小学五年级语文课堂情境教学 王利;梁岩; 13-15 小练笔的美丽转身…

关于类型转换的细节(隐式类型转换的临时变量和理解const权限)

文章目录 前言类型转换的细节1. 类型转换的临时变量细节二&#xff1a;const与指针 前言 关于类型转换的细节&#xff0c;这里小编和大家探讨两个方面&#xff1a; 关于类型转化的临时变量的问题const关键字的权限问题 — 即修改权限。小编或通过一道例题&#xff08;配图&am…

技术对暴力的削弱

信息时代的大政治分析&#xff1a;效率对暴力的颠覆 一、工业时代勒索逻辑的终结 工厂罢工的消亡 1930年代通用汽车罢工依赖工厂的物理集中、高资本投入和流水线脆弱性&#xff0c;通过暴力瘫痪生产实现勒索。 信息时代企业分散化、资产虚拟化&#xff08;如软件公司可携带代码…

深入理解分布式锁——以Redis为例

一、分布式锁简介 1、什么是分布式锁 分布式锁是一种在分布式系统环境下&#xff0c;通过多个节点对共享资源进行访问控制的一种同步机制。它的主要目的是防止多个节点同时操作同一份数据&#xff0c;从而避免数据的不一致性。 线程锁&#xff1a; 也被称为互斥锁&#xff08…

yolo训练用的数据集的数据结构

Football Players Detection using YOLOV11 可以在roboflow上标注 Sign in to Roboflow 训练数据集只看这个data.yaml 里面是train的image地址和classnames 每个image一一对应一个label 第一个位是分类&#xff0c;0是classnames[0]对应的物体&#xff0c;现在是cuboid &…

Redis 使用及命令操作

文章目录 一、基本命令二、redis 设置键的生存时间或过期时间三、SortSet 排序集合类型操作四、查看中文五、密码设置和查看密码的方法六、关于 Redis 的 database 相关基础七、查看内存占用 一、基本命令 # 查看版本 redis-cli --version 结果&#xff1a;redis-cli 8.0.0red…

Java大师成长计划之第13天:Java中的响应式编程

&#x1f4e2; 友情提示&#xff1a; 本文由银河易创AI&#xff08;https://ai.eaigx.com&#xff09;平台gpt-4o-mini模型辅助创作完成&#xff0c;旨在提供灵感参考与技术分享&#xff0c;文中关键数据、代码与结论建议通过官方渠道验证。 随着现代应用程序的复杂性增加&…

华为私有协议Hybrid

实验top图 理论环节 1. 基本概念 Hybrid接口&#xff1a; 支持同时处理多个VLAN流量&#xff0c;且能针对不同VLAN配置是否携带标签&#xff08;Tagged/Untagged&#xff09;。 核心特性&#xff1a; 灵活控制数据帧的标签处理方式&#xff0c;适用于复杂网络场景。 2. 工作…

K8s 常用命令、对象名称缩写汇总

K8s 常用命令、对象名称缩写汇总 前言 在之前的文章中已经陆续介绍过 Kubernetes 的部分命令&#xff0c;本文将专题介绍 Kubernetes 的常用命令&#xff0c;处理日常工作基本够用了。 集群相关 1、查看集群信息 kubectl cluster-info # 输出信息Kubernetes master is run…

【HDLBits刷题】Verilog Language——1.Basics

目录 一、题目与题解 1.Simple wire&#xff08;简单导线&#xff09; 2.Four wires&#xff08;4线&#xff09; 3.Inverter&#xff08;逆变器&#xff08;非门&#xff09;&#xff09; 4.AND gate &#xff08;与门&#xff09; 5. NOR gate &#xff08;或非门&am…

C语言|递归求n!

C语言| 函数的递归调用 【递归求n!】 0!1; 1!1 n! n*(n-1)*(n-2)*(n-3)*...*3*2*1; 【分析过程】 定义一个求n&#xff01;的函数&#xff0c;主函数直接调用 [ Factorial()函数 ] 1 用if语句去实现&#xff0c;把求n!的情况列举出来 2 if条件有3个&#xff0c;n<0; n0||n…

Android第四次面试总结之Java基础篇(补充)

一、设计原则高频面试题&#xff08;附大厂真题解析&#xff09; 1. 单一职责原则&#xff08;SRP&#xff09;在 Android 开发中的应用&#xff08;字节跳动真题&#xff09; 真题&#xff1a;“你在项目中如何体现单一职责原则&#xff1f;举例说明。”考点&#xff1a;结合…

OpenHarmony GPIO应用开发-LED

学习于&#xff1a; https://docs.openharmony.cn/pages/v5.0/zh-cn/device-dev/driver/driver-platform-gpio-develop.md https://docs.openharmony.cn/pages/v5.0/zh-cn/device-dev/driver/driver-platform-gpio-des.md 通过OpenHarmony官方文档指导可获知&#xff1a;芯片厂…

XILINX原语之——xpm_fifo_async(异步FIFO灵活设置位宽、深度)

目录 一、"fwft"模式&#xff08;First-Word-Fall-Through read mode&#xff09; 1、写FIFO 2、读FIFO 二、"std"模式&#xff08;standard read mode&#xff09; 1、写FIFO 2、读FIFO 调用方式和xpm_fifo_sync基本一致&#xff1a; XILINX原语之…

系统学习算法:动态规划(斐波那契+路径问题)

题目一&#xff1a; 思路&#xff1a; 作为动态规划的第一道题&#xff0c;这个题很有代表性且很简单&#xff0c;适合入门 先理解题意&#xff0c;很简单&#xff0c;就是斐波那契数列的加强版&#xff0c;从前两个数变为前三个数 算法原理&#xff1a; 这五步可以说是所有…