手写体识别

news/2025/10/21 20:39:49/文章来源:https://www.cnblogs.com/inian/p/19156372

手写数字识别:基于PyTorch的卷积神经网络实现

一、项目概述

使用PyTorch实现一个基于卷积神经网络(CNN)的手写手写数字识别模型,通过MNIST数据集训练,实现对手写数字(0-9)的分类识别。

二、环境依赖

  • Python 3.x
  • PyTorch
  • torchvision
  • matplotlib

三、代码实现与解析

1. 导入必要库

# 导入PyTorch核心库
import torch
# 导入PyTorch神经网络模块
import torch.nn as nn
# 导入PyTorch优化器模块
import torch.optim as optim
# 导入数据加载工具
from torch.utils.data import DataLoader
# 导入计算机视觉相关的数据集和数据转换工具
from torchvision import datasets, transforms
# 导入matplotlib用于可视化
import matplotlib.pyplot as plt

2. 数据准备与预处理

# 定义数据转换管道:将图像转为Tensor并进行标准化
transform = transforms.Compose([transforms.ToTensor(),  # 将PIL图像转为Tensor格式,并将像素值从[0,255]归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # 用MNIST数据集的均值(0.1307)和标准差(0.3081)标准化数据
])# 加载MNIST训练数据集(手写数字0-9)
train_dataset = datasets.MNIST(root='./data',  # 数据集存储路径train=True,     # 加载训练集download=True,  # 如果本地没有数据集则自动下载transform=transform  # 应用定义好的数据转换
)
# 加载MNIST测试数据集
test_dataset = datasets.MNIST(root='./data',  # 数据集存储路径train=False,    # 加载测试集download=True,  # 如果本地没有数据集则自动下载transform=transform  # 应用定义好的数据转换
)# 创建训练数据加载器:按批次加载数据,支持打乱顺序
train_loader = DataLoader(train_dataset,  # 要加载的数据集batch_size=64,  # 每个批次包含64个样本shuffle=True    # 训练时打乱数据顺序,增加随机性
)
# 创建测试数据加载器:批次更大,不需要打乱
test_loader = DataLoader(test_dataset,   # 要加载的数据集batch_size=1000,# 每个批次包含1000个样本(测试时可更大)shuffle=False   # 测试时不需要打乱顺序
)

3. 定义神经网络模型

class HandwritingRecognizer(nn.Module):def __init__(self):# 调用父类构造函数super(HandwritingRecognizer, self).__init__()# 定义卷积层序列:用于提取图像特征self.conv_layers = nn.Sequential(# 第一个卷积层:输入1通道(灰度图),输出32通道,卷积核3x3,步长1,填充1nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),  # 激活函数:引入非线性,增强模型表达能力# 最大池化层:2x2窗口,步长2,输出尺寸变为14x14(原28x28)nn.MaxPool2d(kernel_size=2, stride=2),# 第二个卷积层:输入32通道,输出64通道,卷积核3x3,步长1,填充1nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),  # 激活函数# 最大池化层:2x2窗口,步长2,输出尺寸变为7x7(原14x14)nn.MaxPool2d(kernel_size=2, stride=2))# 定义全连接层序列:用于分类决策self.fc_layers = nn.Sequential(# 第一个全连接层:输入为64通道×7×7特征图展平后的向量,输出128维nn.Linear(64 * 7 * 7, 128),nn.ReLU(),  # 激活函数nn.Dropout(0.5),  # 随机丢弃50%神经元,防止过拟合nn.Linear(128, 10)  # 输出层:10个类别(对应数字0-9))# 定义前向传播过程def forward(self, x):x = self.conv_layers(x)  # 输入经过卷积层提取特征# 将特征图展平为一维向量:-1表示自动计算批次维度,64*7*7为特征维度x = x.view(-1, 64 * 7 * 7)x = self.fc_layers(x)    # 展平后的特征经过全连接层得到分类结果return x

4. 初始化模型、损失函数和优化器

model = HandwritingRecognizer()  # 创建手写数字识别模型实例
criterion = nn.CrossEntropyLoss()  # 定义损失函数:多分类交叉熵损失(适用于分类任务)
# 定义优化器:Adam优化器,学习率0.001(控制参数更新速度)
optimizer = optim.Adam(model.parameters(), lr=0.001)

5. 模型训练函数

def train(model, train_loader, criterion, optimizer, epochs=1):model.train()  # 设置模型为训练模式(启用dropout等训练特定层)# 遍历训练轮次for epoch in range(epochs):running_loss = 0.0  # 累计当前轮次的损失# 遍历训练数据的每个批次for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()  # 清零优化器的梯度(防止梯度累积)output = model(data)   # 前向传播:输入数据经过模型得到预测结果loss = criterion(output, target)  # 计算预测结果与真实标签的损失loss.backward()        # 反向传播:计算损失对各参数的梯度optimizer.step()       # 更新模型参数(基于梯度和优化器规则)running_loss += loss.item()  # 累加当前批次的损失值# 每300个批次打印一次平均损失(便于监控训练过程)if batch_idx % 300 == 299:print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {running_loss/300:.4f}')running_loss = 0.0  # 重置累计损失

6. 模型测试函数

def test(model, test_loader):model.eval()  # 设置模型为评估模式(关闭dropout等训练特定层)correct = 0   # 记录正确预测的样本数total = 0     # 记录总样本数# 关闭梯度计算(测试时不需要反向传播,节省计算资源)with torch.no_grad():# 遍历测试数据的每个批次for data, target in test_loader:output = model(data)  # 前向传播得到预测结果# 取预测概率最大的类别作为最终预测(dim=1表示按行取最大值)_, predicted = torch.max(output.data, 1)total += target.size(0)  # 累加总样本数# 累加预测正确的样本数(预测类别与真实标签相等)correct += (predicted == target).sum().item()# 计算并打印测试集准确率print(f'Test Accuracy: {100 * correct / total:.2f}%')

7. 执行训练和测试

# 执行训练和测试:训练1轮,然后在测试集上评估
train(model, train_loader, criterion, optimizer, epochs=1)
test(model, test_loader)

8. 可视化预测结果

def visualize_prediction(model, test_dataset, idx=0):model.eval()  # 设置模型为评估模式image, label = test_dataset[idx]  # 获取测试集中指定索引的图像和真实标签# 关闭梯度计算with torch.no_grad():# 为图像增加批次维度(模型输入需要[批次, 通道, 高, 宽]格式)output = model(image.unsqueeze(0))predicted = torch.argmax(output).item()  # 取预测概率最大的类别# 显示图像:squeeze()去除多余维度,cmap='gray'设置为灰度图plt.imshow(image.squeeze().numpy(), cmap='gray')# 设置标题:显示真实标签和预测结果plt.title(f'True: {label}, Predicted: {predicted}')plt.show()  # 显示图像# 可视化测试集中索引为42的样本的预测结果
visualize_prediction(model, test_dataset, idx=42)

四、模型结构说明

  1. 卷积层部分

    • 第一层卷积:32个3×3卷积核,提取基础边缘和纹理特征
    • 最大池化:将特征图尺寸从28×28降为14×14
    • 第二层卷积:64个3×3卷积核,提取更复杂的组合特征
    • 最大池化:将特征图尺寸从14×14降为7×7
  2. 全连接层部分

    • 第一个全连接层:将64×7×7的特征展平后映射到128维
    • Dropout层:随机丢弃50%神经元,防止过拟合
    • 输出层:10个神经元,对应0-9十个数字的分类结果

五、训练与评估流程

  1. 训练过程:

    • 前向传播:输入数据通过网络得到预测结果
    • 计算损失:使用交叉熵损失衡量预测与真实标签的差距
    • 反向传播:计算损失对各参数的梯度
    • 参数更新:使用Adam优化器根据梯度更新网络参数
  2. 评估过程:

    • 关闭梯度计算,节省计算资源
    • 计算模型在测试集上的准确率
    • 通过可视化查看具体样本的预测结果

六、扩展方向

  1. 增加训练轮次,观察模型性能变化
  2. 调整网络结构(如增加卷积层、调整通道数)
  3. 尝试不同的优化器和学习率
  4. 增加数据增强操作,提高模型泛化能力
  5. 在GPU上运行以加速训练过程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/942640.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AGC 合集 1.0

AGC001~030。2025.3.10 - 2025.10.21。 收录了前 30 场 AGC 中我写了题解的题目。 如果附带了题目大意的话就是最近才做,或者是我过了几个月看不懂自己写的啥了重写了一遍。 如果您认为某些题目的理解不够到位,非常欢…

20231302邱之钊密码系统设计实验一第二

1.参考相关内容,在Ubuntu或openEuler中(推荐openEuler)中使用OpenSSL库编程实现调用SM2(加密解密,签名验签),SM3(摘要计算,HMAC 计算),SM4(加密解密)算法,使用Markdown记录详细记录实践过程,每完成一项…

你好,我是肆闲:C语言的学习,成长与分享旅程

大家好,我是肆闲。 今天,我写下了我的第一篇博客,就像一个程序员运行了第一段 print("Hello World")一样。代码简单,却标志着一个充满无限可能的新世界,在我眼前打开了大门。作为一个刚入门的小白,我对…

深入BERT内核:用数学解密掩码语言模型的工作原理

传统语言模型有个天生缺陷——只能从左往右读,就像你现在读这段文字一样,一个词一个词往下看,完全不知道后面会出现什么。人类可不是这么学语言的。看到"被告被判**_**",大脑会根据上下文直接联想到&quo…

ZR 2025 NOIP 二十连测 Day 6

100 + 72 + 35 + 0 = 207, Rank 61/131.啊啊啊第一次上 200 /oh25noip二十连测day6 链接:link 题解:题目内 时间:4.5h (2025.10.21 13:40~18:10) 题目数:4 难度:A B C D\(\color{#F39C11} 橙\)*1200估分:100 + 7…

20251021

上午工程实训课接触了电工基础,老师演示了万用表测量、简单电路连接和安全操作规范 下午英语课围绕旅游主题展开听力, 晚上写离散数学作业 间隙用碎片时间整理了课堂笔记,还讨论了实训课上的电路连接问题。 (工程实…

[论文笔记] Precision-Guided Context Sensitivity for Pointer Analysis

Introduction Context-sensitivity 会带来静态分析的精度提升,但是也会带来巨大的开销,这引出一个关键的问题:能否在某些对整体分析的精度有重要影响的函数上选择性的使用 context-sensitivity?这个问题的难点在于…

英语_备忘_疑难

好的,这里有一些关于 **How** 和 **What** 在感叹句中使用的例题,涵盖了常见的规则和易错点。 **规则回顾:*** **What + (a/an) + 形容词 + 名词*** **How + 形容词/副词 + (主谓)** --- **例题:** 请选择正…

朋友圈文案不会写?这个AI指令可能帮得上忙

最近在整理AI提示词的时候,顺手写了个朋友圈文案生成的指令。本来只是自己用,后来发现身边朋友也有这个需求,就想着分享出来。写在前面 刷朋友圈的时候,你有没有发现:有些人随便发个照片配几个字,点赞评论一大堆…

「JOISC2020-掃除」题解

题解记录掃除 (Sweeping) sol 从 Subtask 3 的特殊性质入手,可以发现一个关键性质:无论之后如何操作,这个单调性在任何时刻均满足。其原因可以简单考虑一下操作的效力范围与结果得到。 理解之后容易推广到全局,不难…

职责分离的艺术:剖析主从Reactor模型如何实现极致的并发性能

职责分离的艺术:剖析主从Reactor模型如何实现极致的并发性能Reactor单线程模型 在Reactor单线程模型中,所谓的“单线程”主要针对I/O操作而言,即所有的I/O操作(如accept()、read()、write()和connect())都在同一个…

数学题刷题记录(数学、数论、组合数学)

P5686 [CSP-S2019 江西] 和积和简单题,直接将区间求和转换成前缀和,设 \(A_i = \sum_{i = 1}^n a_i,B_i = \sum_{i = 1}^n b_i\),那么式子为: \[\sum_{l = 1}^n \sum_{r = l}^n (A_r-A_{l-1})(B_r-B_{l-1}) \]\[=\…

记录一次raid恢复之后数据库故障处理(ora-01200,ORA-26101,ORA-600)---惜分飞

记录一次raid恢复之后数据库故障处理(ora-01200,ORA-26101,ORA-600)---惜分飞联系:手机/微信(+86 17813235971) QQ(107644445) 标题:记录一次raid恢复之后数据库故障处理(ora-01200,ORA-26101,ORA-600) 作者:惜分飞…

CF简单构造小计

记录在这的都是感觉比较妙的或者看了题解的( CF2155D Batteries有 \(n\) 个元素,其中有 \(a\) 个是好的( \(a\) 未知)。 每次你可以询问一对元素,返回1当且仅当两个元素都是好的,否则返回0。 在 \(\lfloor\frac{…

软件工程第三次作业:四则运算题目生成器 - Nyanya-

四则运算题目生成器 - 结对项目报告项目信息 详情课程 软件工程作业要求 结对项目项目目标 实现一个四则运算题目生成器,支持有理数运算,规范软件开发流程,熟悉结对编程结对成员 姓名1: [杨浩] 学号1: [3123004462]…

ORA-600 kokasgi1故障处理(sys被重命名)---惜分飞

ORA-600 kokasgi1故障处理(sys被重命名)---惜分飞联系:手机/微信(+86 17813235971) QQ(107644445) 标题:ORA-600 kokasgi1故障处理(sys被重命名) 作者:惜分飞©版权所有[未经本人同意,不得以任何形式转载,否则有…

简单页面聊天

import express from express import http from http import { Server } from socket.io import cors from corsconst app = express() const PORT = process.env.PORT || 3001app.use(cors({ origin: [http://localho…

深入认识ClassLoader - 一次投产失败的复盘

问题背景 投产日,同事负责的项目新版本发布,版本包是SpringBoot v2.7.18的一个FatJar,java -jar启动报错停止了,输出的异常日志如下: Caused by: org.springframework.beans.factory.BeanCreationException: Erro…

python 包来源镜像

python 镜像python安装包,默认地址非常的慢,可改用国内相关镜像‌清华大学开源软件镜像站‌ 地址:https://pypi.tuna.tsinghua.edu.cn/simple‌阿里云开源镜像站‌ 地址:https://mirrors.aliyun.com/pypi/simple/‌…