卷积神经网络--手写数字识别

本文我们通过搭建卷积神经网络模型,实现手写数字识别。

pytorch中提供了手写数字的数据集 ,我们可以直接从pytorch中下载

MNIST中包含70000张手写数字图像:60000张用于训练,10000张用于测试

图像是灰度的,28x28像素

首先,下载数据集

import torch
from torchvision import datasets #封装与图像相关的模型,数据集
from torchvision.transforms import ToTensor # #数据转换,张量,将其他类型的数据转换为tensor张量training_data=datasets.MNIST(root='data',#表示下载的手写数字到哪个路径train=True,#读取下载后数据中的训练集download=True,#如果之前已经下载过,就不用再下载transform=ToTensor(),#张量,图片不能直接传入神经网络模型
)test_data=datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)

打包数据

from torch.utils.data import DataLoader train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

判断当前设备是否支持GPU

device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'using {device} device')

构建卷积神经网络模型

from torch import nn #导入神经网络模块class CNN(nn.Module):def __init__(self):#初始化类super(CNN,self).__init__()#初始化父类self.conv1=nn.Sequential(# 将多个层(如卷积、激活函数、池化等)按顺序打包,输入数据会​​依次通过这些层​​,无需手动编写每一层的传递逻辑。nn.Conv2d(#2D 卷积层,提取空间特征。in_channels=1,#输入通道数out_channels=16,#输出通道数kernel_size=3,#卷积核大小stride=1,#步长padding=1,#填充),nn.ReLU(),#激活函数,引入非线性变换,使得神经网络能够学习复杂的非线性变换,增强表达能力nn.MaxPool2d(kernel_size=2)# 2x2最大池化(尺寸减半))self.conv2=nn.Sequential(nn.Conv2d(16,32,3,1,1),nn.ReLU(),# nn.Conv2d(32,32,3,1,1),# nn.ReLU(),nn.MaxPool2d(2),)self.conv3=nn.Sequential(nn.Conv2d(32,64,3,1,1))self.out=nn.Linear(64*7*7,10)def forward(self,x):#前向传播x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)x=x.view(x.size(0),-1)# 展平为向量(保留batch_size,合并其他维度)output=self.out(x)  # 全连接层输出return output

返回的output结果大致如图所示

 模型传入GPU

model=CNN().to(device)
print(model)

  损失函数,衡量的是​​模型预测的概率分布​​与​​真实的类别分布​​之间的差异。

loss_fn=nn.CrossEntropyLoss()

  优化器,用于在训练神经网络时更新模型参数,目的是​​在神经网络训练过程中,自动调整模型的参数(权重和偏置),以最小化损失函数​​。

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)

 模型训练

def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)loss=loss_fn(pred,y)# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()               #梯度值清零loss.backward()                     #反向传播计算得到每个参数的梯度值optimizer.step()                    #根据梯度更新网络参数loss_value=loss.item()if batch_size_num%100==0:print(f'loss:{loss_value:>7f}[number:{batch_size_num}]')batch_size_num+=1epochs=10for i in range(epochs):print(f'第{i}次训练')train(train_dataloader, model, loss_fn, optimizer)

模型测试

def test(dataloader,model,loss_fn):size = len(dataloader.dataset)# 测试集总样本数num_batches = len(dataloader)# 测试集总批次数model.eval()#进入到模型的测试状态,所有的卷积核权重被设为只读模式test_loss, correct = 0, 0# 初始化累计损失和正确预测数#禁用梯度计算with torch.no_grad():#一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')test(test_dataloader,model,loss_fn)

得到结果如图所示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77667.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大文件分片上传进阶版(新增md5校验、上传进度展示、并行控制,智能分片、加密上传、断点续传、自动重试),实现四位一体的网络感知型大文件传输系统‌

上篇文章我们总结了大文件分片上传的主要核心,但是我对md5校验和上传进度展示这块也比较感兴趣,所以在deepseek的帮助下,扩展了一下我们的代码,如果有任何问题和想法,非常欢迎大家在评论区与我交流,我需要学…

C# 点击导入,将需要的参数传递到弹窗的页面

点击导入按钮,获取本页面的datagridview标题的结构,并传递到导入界面。 新增一个datatable用于存储datagridview的caption和name,这里用的是devexpress组件中的gridview。 DataTable dt new DataTable(); DataColumn CAPTION …

android的 framework 是什么

Android的Framework(框架)是Android系统的核心组成部分,它为开发者提供了一系列的API(应用程序编程接口),使得开发者能够方便地创建各种Android应用。以下是关于它的详细介绍: 位置与架构 在A…

【MySQL】表的约束(主键、唯一键、外键等约束类型详解)、表的设计

目录 1.数据库约束 1.1 约束类型 1.2 null约束 — not null 1.3 unique — 唯一约束 1.4 default — 设置默认值 1.5 primary key — 主键约束 自增主键 自增主键的局限性:经典面试问题(进阶问题) 1.6 foreign key — 外键约束 1.7…

数据结构-C语言版本(三)栈

数据结构中的栈:概念、操作与实战 第一部分 栈分类及常见形式 栈是一种遵循后进先出(LIFO, Last In First Out)原则的线性数据结构。栈主要有以下几种实现形式: 1. 数组实现的栈(顺序栈) #define MAX_SIZE 100typedef struct …

如何以特殊工艺攻克超薄电路板制造难题?

一、超薄PCB的行业定义与核心挑战 超薄PCB通常指厚度低于1.0毫米的电路板,而高端产品可进一步压缩至0.4毫米甚至0.2毫米以下。这类电路板因体积小、重量轻、热传导性能优异,被广泛应用于折叠屏手机、智能穿戴设备、医疗植入器械及新能源汽车等领域。然而…

AI 赋能 3D 创作!Tripo3D 全功能深度解析与实操教程

大家好,欢迎来到本期科技工具分享! 今天要给大家带来一款革命性的 AI 3D 模型生成平台 ——Tripo3D。 无论你是游戏开发者、设计师,还是 3D 建模爱好者,只要想降低创作门槛、提升效率,这款工具都值得深入了解。 接下…

如何理解抽象且不易理解的华为云 API?

API的概念在华为云的使用中非常抽象,且不容易理解,用通俗的语言 形象的比喻来讲清楚——什么是华为云 API,怎么用,背后原理,以及主要元素有哪些,尽量让新手也能明白。 🧠 一句话先理解&#xf…

第 7 篇:总结与展望 - 时间序列学习的下一步

第 7 篇:总结与展望 - 时间序列学习的下一步 (图片来源: Guillaume Hankenne on Pexels) 恭喜你!如果你一路跟随这个系列走到了这里,那么你已经成功地完成了时间序列分析的入门之旅。我们从零开始,一起探索了时间数据的基本概念、…

PPT无法编辑怎么办?原因及解决方法全解析

在日常办公中,我们经常会遇到需要编辑PPT的情况。然而,有时我们会发现PPT文件无法编辑,这可能由多种原因引起。今天我们来看看PPT无法编辑的几种常见原因,并提供实用的解决方法,帮助你轻松应对。 原因1:文…

前端面试题---GET跟POST的区别(Ajax)

GET 和 POST 是两种 HTTP 请求方式,它们在传输数据的方式和所需空间上有一些重要区别: ✅ 一句话概括: GET 数据放在 URL 中,受限较多;POST 数据放在请求体中,空间更大更安全。 📦 1. 所需空间…

第 5 篇:初试牛刀 - 简单的预测方法

第 5 篇:初试牛刀 - 简单的预测方法 经过前面四篇的学习,我们已经具备了处理时间序列数据的基本功:加载、可视化、分解以及处理平稳性。现在,激动人心的时刻到来了——我们要开始尝试预测 (Forecasting) 未来! 预测是…

从代码学习深度学习 - 学习率调度器 PyTorch 版

文章目录 前言一、理论背景二、代码解析2.1. 基本问题和环境设置2.2. 训练函数2.3. 无学习率调度器实验2.4. SquareRootScheduler 实验2.5. FactorScheduler 实验2.6. MultiFactorScheduler 实验2.7. CosineScheduler 实验2.8. 带预热的 CosineScheduler 实验三、结果对比与分析…

k8s 基础入门篇之开启 firewalld

前面在部署k8s时,都是直接关闭的防火墙。由于生产环境需要开启防火墙,只能放行一些特定的端口, 简单记录一下过程。 1. firewall 与 iptables 的关系 1.1 防火墙(Firewall) 定义: 防火墙是网络安全系统&…

RSS 2025|苏黎世提出「LLM-MPC混合架构」增强自动驾驶,推理速度提升10.5倍!

论文题目:Enhancing Autonomous Driving Systems with On-Board Deployed Large Language Models 论文作者:Nicolas Baumann,Cheng Hu,Paviththiren Sivasothilingam,Haotong Qin,Lei Xie,Miche…

list的学习

list的介绍 list文档的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中,在节点中通过指针指向其前一个元素和后一…

生物信息学技能树(Bioinformatics)与学习路径

李升伟 整理 生物信息学是一门跨学科领域,涉及生物学、计算机科学以及统计学等多个方面。以下是关于生物信息学的学习路径及相关技能的详细介绍。 一、基础理论知识 1. 生物学基础知识 需要掌握分子生物学、遗传学、细胞生物学等相关概念。 对基因组结构、蛋白质…

AOSP Android14 Launcher3——远程窗口动画关键类SurfaceControl详解

在 Launcher3 执行涉及其他应用窗口(即“远程窗口”)的动画时,例如“点击桌面图标启动应用”或“从应用上滑回到桌面”的过渡动画,SurfaceControl 扮演着至关重要的角色。它是实现这些跨进程、高性能、精确定制动画的核心技术。 …

超详细实现单链表的基础增删改查——基于C语言实现

文章目录 1、链表的概念与分类1.1 链表的概念1.2 链表的分类 2、单链表的结构和定义2.1 单链表的结构2.2 单链表的定义 3、单链表的实现3.1 创建新节点3.2 头插和尾插的实现3.3 头删和尾删的实现3.4 链表的查找3.5 指定位置之前和之后插入数据3.6 删除指定位置的数据和删除指定…

17.整体代码讲解

从入门AI到手写Transformer-17.整体代码讲解 17.整体代码讲解代码 整理自视频 老袁不说话 。 17.整体代码讲解 代码 import collectionsimport math import torch from torch import nn import os import time import numpy as np from matplotlib import pyplot as plt fro…