《从卷积核到数字解码:CNN 手写数字识别实战解析》

文章目录

  • 一、手写数字识别的本质与挑战
  • 二、使用步骤
    • 1.导入torch库以及与视觉相关的torchvision库
    • 2.下载datasets自带的手写数字的数据集到本地
  • 三、完整代码展示


一、手写数字识别的本质与挑战

手写数字识别的核心是:从二维像素矩阵中提取具有判别性的特征,区分 0-9 这 10 个类别。其难点包括:
手写风格多样性:不同人书写的数字(如 “3” 可能有开口或闭口)、笔画粗细、倾斜角度差异大。
位置与尺度变化:数字在图像中的位置(偏上 / 偏下)、大小可能不一致(如 MNIST 数据集中数字存在轻微平移)。
噪声与形变:实际场景中可能存在笔画断裂、污渍等噪声,或扫描时的图像模糊。
传统方法(如 SVM、KNN)依赖人工设计特征(如 HOG、SIFT、几何矩),需专家经验且泛化能力有限;而 CNN 通过自动化特征学习 + 结构化归纳偏置,天然适配这些挑战。

二、使用步骤

1.导入torch库以及与视觉相关的torchvision库

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

2.下载datasets自带的手写数字的数据集到本地

"""下载测试数据集(包含图片和标签)"""training_data=datasets.MNIST(root='../data',train=True,download=True,transform=ToTensor()
)"""下载测试数据集(包含训练图片+标签)"""test_data=datasets.MNIST(root='../data',train=False,download=True,transform=ToTensor()
)

3、将下载的数据集打包

train_dataloder=DataLoader(training_data,batch_size=64)
test_dataloder=DataLoader(test_data,batch_size=64)

4、指定数据训练的设备

device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"{device}device")

5、定义神经网络框架和前向传播

class NeurakNetwork(nn.Module):     #通过调用类的形式来使用神经网络,神经网络的模型nn.moudledef __init__(self):super().__init__()  #继承父类的初始化self.flatten=nn.Flatten()   #将二位数据展成一维数据self.hidden1=nn.Linear(28*28,128)   #第一个参数时有多少个神经元传进来,第二个参数是有多少个数据传出去self.hidden2=nn.Linear(128,256)self.out=nn.Linear(256,10)      #输出必须与标签类型相同,输入必须是上一层神经元的个数def forward(self,x):    #前向传播,指明数据的流向,使神经网络连接起来,函数名称不能修改x=self.flatten(x)x=self.hidden1(x)x=torch.relu(x)     #激活函数,torch使用relu或者tanh函数作为激活函数x=self.hidden2(x)x=torch.relu(x)x=self.out(x)return x

6、初始化神经网络并将模型加载到设备中

model = NeurakNetwork().to(device)      #将刚刚定义的模型传入到GPU中

7、定义模型训练的函数

def train(dataloader,model,loss_fn,optimizer):model.train()       #告诉模型,即将开始训练,其中的w进行随机化操作,已经更新w,在训练过程中,w会被修改"""pytorch提供两种方式来切换训练和测试的模式,分别是model.train()和model.eval()一般用法是,在训练开始之前写上model.train(),在测试时写上model.eval()"""batch_size_num=1for X,y in dataloader:      #其中batch为每一个数据的编号X,y=X.to(device),y.to(device)   #将训练数据集和标签传入cpu和gpupred=model.forward(X)loss=loss_fn(pred,y)    #通过交叉熵损失函数计算loss#Backpropagation  进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()   #梯度值清零loss.backward()     #反向传播计算得到的每个参数的梯度值woptimizer.step()    #根据梯度更新网络w参数loss_value=loss.item()  #从tensor数据中提取数据出来,tensor获取损失值if batch_size_num%100==0:print(f"loss:{loss_value:>7f}[number:{batch_size_num}]")batch_size_num+=1

8、定义测试的函数

def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)model.eval()test_loss,correct=0,0with torch.no_grad():for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct +=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f"Test result:\n Accurracy:{(100*correct)}%,AVG loss:{test_loss}")

9、初始化损失函数创建优化器

loss_fn=nn.CrossEntropyLoss()   #创建交叉熵损失函数对象,适合做多分类optimizer=torch.optim.SGD(model.parameters(),lr=0.01)   #创建优化器,使用SGD随机梯度下降

10、调用训练和测试的函数,完成训练一次测试一次

train(train_dataloder,model,loss_fn,optimizer)  #训练一次完整的数据,多轮训练
test(test_dataloder,model,loss_fn)

11、训练20轮,测试一次

epochs=20
for epoch in range(epochs):train(train_dataloder,model,loss_fn,optimizer)print(f"epoch{epoch}")
test(test_dataloder,model,loss_fn)

三、完整代码展示


"""手写数字识别"""
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor"""下载测试数据集(包含图片和标签)"""training_data=datasets.MNIST(root='../data',train=True,download=True,transform=ToTensor()
)"""下载测试数据集(包含训练图片+标签)"""test_data=datasets.MNIST(root='../data',train=False,download=True,transform=ToTensor()
)
print(len(training_data))"""展示手写图片,把训练集中的前59000张图片展示一下"""
from matplotlib import pyplot as plt
figure=plt.figure()
for i in range(9):img,label=training_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(),cmap='gray')a=img.squeeze()
plt.show()train_dataloder=DataLoader(training_data,batch_size=64)
test_dataloder=DataLoader(test_data,batch_size=64)device="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"{device}device")"""self参数理解:在类内部开辟出了一个共享空间,所有被定义在这片空间的参数都能够使用self.参数名来调用"""class NeurakNetwork(nn.Module):     #通过调用类的形式来使用神经网络,神经网络的模型nn.moudledef __init__(self):super().__init__()  #继承父类的初始化self.flatten=nn.Flatten()   #将二位数据展成一维数据self.hidden1=nn.Linear(28*28,128)   #第一个参数时有多少个神经元传进来,第二个参数是有多少个数据传出去self.hidden2=nn.Linear(128,256)self.out=nn.Linear(256,10)      #输出必须与标签类型相同,输入必须是上一层神经元的个数def forward(self,x):    #前向传播,指明数据的流向,使神经网络连接起来,函数名称不能修改x=self.flatten(x)x=self.hidden1(x)x=torch.relu(x)     #激活函数,torch使用relu或者tanh函数作为激活函数x=self.hidden2(x)x=torch.relu(x)x=self.out(x)return xmodel = NeurakNetwork().to(device)      #将刚刚定义的模型传入到GPU中
print(model)def train(dataloader,model,loss_fn,optimizer):model.train()       #告诉模型,即将开始训练,其中的w进行随机化操作,已经更新w,在训练过程中,w会被修改"""pytorch提供两种方式来切换训练和测试的模式,分别是model.train()和model.eval()一般用法是,在训练开始之前写上model.train(),在测试时写上model.eval()"""batch_size_num=1for X,y in dataloader:      #其中batch为每一个数据的编号X,y=X.to(device),y.to(device)   #将训练数据集和标签传入cpu和gpupred=model.forward(X)loss=loss_fn(pred,y)    #通过交叉熵损失函数计算loss#Backpropagation  进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()   #梯度值清零loss.backward()     #反向传播计算得到的每个参数的梯度值woptimizer.step()    #根据梯度更新网络w参数loss_value=loss.item()  #从tensor数据中提取数据出来,tensor获取损失值if batch_size_num%100==0:print(f"loss:{loss_value:>7f}[number:{batch_size_num}]")batch_size_num+=1def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)model.eval()test_loss,correct=0,0with torch.no_grad():for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct +=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f"Test result:\n Accurracy:{(100*correct)}%,AVG loss:{test_loss}")loss_fn=nn.CrossEntropyLoss()   #创建交叉熵损失函数对象,适合做多分类optimizer=torch.optim.Adam(model.parameters(),lr=0.01)   #创建优化器,使用Adam优化器#params:要训练的参数,一般传入的都是model.parameters()
#lr是指学习率,也就是步长#loss表示模型训练后的输出结果与样本标签的差距,如果差距越小,就表示模型训练越好,越逼近于真实的模型
train(train_dataloder,model,loss_fn,optimizer)  #训练一次完整的数据,多轮训练
test(test_dataloder,model,loss_fn)epochs=20
for epoch in range(epochs):train(train_dataloder,model,loss_fn,optimizer)print(f"epoch{epoch}")
test(test_dataloder,model,loss_fn)

在这里插入图片描述
可以看到经过20轮的训练模型的正确率为96.91%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77877.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UniOcc:自动驾驶占用预测和预报的统一基准

25年3月来自 UC Riverside、U Wisconsin 和 TAMU 的论文"UniOcc: A Unified Benchmark for Occupancy Forecasting and Prediction in Autonomous Driving"。 UniOcc 是一个全面统一的占用预测基准(即基于历史信息预测未来占用)和基于摄像头图…

模型量化核心技术解析:从算法原理到工业级实践

一、模型量化为何成为大模型落地刚需&#xff1f; 算力困境&#xff1a;175B参数模型FP32推理需0.5TB内存&#xff0c;超出主流显卡容量 速度瓶颈&#xff1a;FP16推理延迟难以满足实时对话需求&#xff08;如客服场景<200ms&#xff09; 能效挑战&#xff1a;边缘设备运行…

AD9253链路训练

传统方式 参考Xilinx官方文档xapp524。对于AD9253器件 - 125M采样率 - DDR模式&#xff0c;ADC器件的DCO采样时钟(500M Hz)和FCO帧时钟是中心对齐的&#xff0c;适合直接采样。但是DCO时钟不能直接被FPGA内部逻辑使用&#xff0c;需要经过BUFIO和BUFR缓冲后&#xff0c;得到s_b…

解决方案:远程shell连不上Ubuntu服务器

服务器是可以通过VNC登录&#xff0c;排除了是服务器本身故障 检查服务是否在全网卡监听 sudo ss -tlnp | grep sshd确保有一行类似 LISTEN 0 128 0.0.0.0:22 0.0.0.0:* users:(("sshd",pid...,fd3))返回无结果&#xff0c;表明系统里并没有任…

关于大数据的基础知识(四)——大数据的意义与趋势

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///计算机爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于大数据的基础知识&#xff08;四&a…

智能指针(weak_ptr )之三

1. std::weak_ptr 1.1 定义与用法 std::weak_ptr 是一种不拥有对象所有权的智能指针&#xff0c;用于观察但不影响对象的生命周期。主要用于解决 shared_ptr 之间的循环引用问题。 主要特性&#xff1a; 非拥有所有权&#xff1a;不增加引用计数。可从 shared_ptr 生成&…

学习海康VisionMaster之卡尺工具

一&#xff1a;进一步学习了 今天学习下VisionMaster中的卡尺工具&#xff1a;主要用于测量物体的宽度、边缘的特征的位置以及图像中边缘对的位置和间距 二&#xff1a;开始学习 1&#xff1a;什么是卡尺工具&#xff1f; 如果我需要检测芯片的每一个PIN的宽度和坐标&#xff…

Java面试实战:从Spring Boot到微服务的深入探讨

Java面试实战&#xff1a;从Spring Boot到微服务的深入探讨 场景&#xff1a;电商场景的面试之旅 在某互联网大厂的面试间&#xff0c;面试官李老师正襟危坐&#xff0c;而对面坐着的是传说中的“水货程序员”赵大宝。 第一轮&#xff1a;核心Java与构建工具 面试官&#x…

深入理解 Spring @Configuration 注解

在 Spring 框架中,@Configuration 注解是一个非常重要的工具,它用于定义配置类,这些类可以包含 Bean 定义方法。通过使用 @Configuration 和 @Bean 注解,开发者能够以编程方式创建和管理应用程序上下文中的 Bean。本文将详细介绍 @Configuration 注解的作用、如何使用它以及…

密码学中的盐值是什么?

目录 1. 盐值的基本概念 2. 盐值的作用 (1) 防止彩虹表攻击 (2) 防止相同的密码生成相同的哈希值 (3) 增加暴力破解的难度 3. 如何使用盐值&#xff1f; (1) 生成盐值 (2) 将盐值附加到密码 (3) 存储盐值和哈希值 (4) 验证密码 4. 盐值如何增加暴力破解的难度 在线暴…

基于瑞芯微RK3576国产ARM八核2.2GHz A72 工业评估板——Docker容器部署方法说明

前 言 本文适用开发环境: Windows开发环境:Windows 7 64bit、Windows 10 64bit Linux开发环境:VMware16.2.5、Ubuntu22.04.5 64bit U-Boot:U-Boot-2017.09 Kernel:Linux-6.1.115 LinuxSDK:LinuxSDK-[版本号](基于rk3576_linux6.1_release_v1.1.0) Docker是一个开…

大数据技术全解析

目录 前言1. Kafka&#xff1a;流数据的传输平台1.1 Kafka概述1.2 Kafka的应用场景1.3 Kafka的特点 2. HBase&#xff1a;分布式列式数据库2.1 HBase概述2.2 HBase的应用场景2.3 HBase的特点 3. Hadoop&#xff1a;大数据处理的基石3.1 Hadoop概述3.2 Hadoop的应用场景3.3 Hado…

mcpo的简单使用

1.安装依赖 conda create -n mcpo python3.11 conda activate mcpo pip install mcpo pip install uv2.随便从https://github.com/modelcontextprotocol/servers?tabreadme-ov-file 找一个mcp服务使用就行&#xff0c;我这里选的是爬虫 然后安装 pip install mcp-server-f…

uniapp-商城-32-shop 我的订单-订单详情和组件goods-list

上面完成了我的订单&#xff0c;通过点击我的订单中每一条数据&#xff0c;可以跳转到订单详情中。 这里就需要展示订单的状态&#xff0c;支付状态&#xff0c;物流状态&#xff0c;取货状态&#xff0c;用户信息&#xff0c;订单中的货物详情等。 1、创建一个订单详情文件 …

XCVU13P-2FHGA2104I Xilinx Virtex UltraScale+ FPGA

XCVU13P-2FHGA2104I 是 Xilinx&#xff08;现为 AMD&#xff09;Virtex UltraScale™ FPGA 系列中的高端 Premium 器件&#xff0c;基于 16nm FinFET 工艺并采用 3D IC 堆叠硅互连&#xff08;SSI&#xff09;技术&#xff0c;提供业内顶级的计算密度和带宽​。该芯片集成约 3,…

【Python3】Django 学习之路

第一章&#xff1a;Django 简介 1.1 什么是 Django&#xff1f; Django 是一个高级的 Python Web 框架&#xff0c;旨在让 Web 开发变得更加快速和简便。它鼓励遵循“不要重复自己”&#xff08;DRY&#xff0c;Don’t Repeat Yourself&#xff09;的原则&#xff0c;并提供了…

Python 设计模式:模板模式

1. 什么是模板模式&#xff1f; 模板模式是一种行为设计模式&#xff0c;它定义了一个操作的算法的骨架&#xff0c;而将一些步骤延迟到子类中。模板模式允许子类在不改变算法结构的情况下&#xff0c;重新定义算法的某些特定步骤。 模板模式的核心思想是将算法的固定部分提取…

【后端】构建简洁的音频转写系统:基于火山引擎ASR实现

在当今数字化时代&#xff0c;语音识别技术已经成为许多应用不可或缺的一部分。无论是会议记录、语音助手还是内容字幕&#xff0c;将语音转化为文本的能力对提升用户体验和工作效率至关重要。本文将介绍如何构建一个简洁的音频转写系统&#xff0c;专注于文件上传、云存储以及…

音频base64

音频 Base64 是一种将二进制音频数据&#xff08;如 MP3、WAV 等格式&#xff09;编码为 ASCII 字符串的方法。通过 Base64 编码&#xff0c;音频文件可以转换为纯文本形式&#xff0c;便于在文本协议&#xff08;如 JSON、XML、HTML 或电子邮件&#xff09;中传输或存储&#…

240422 leetcode exercises

240422 leetcode exercises jarringslee 文章目录 240422 leetcode exercises[237. 删除链表中的节点](https://leetcode.cn/problems/delete-node-in-a-linked-list/)&#x1f501;节点覆盖法 [392. 判断子序列](https://leetcode.cn/problems/is-subsequence/)&#x1f501;…