深度学习——卷积神经网络实现手写数字识别

一、准备工作

导入所需的依赖库:

# import torch # print(torch.__version__) ''' MNIST包含70000张手写数字图像:60000用于训练,10000用于测试 图像是灰度的,28×28像素的,并且居中的,以减少预处理和加快运行 ''' import torch from torch import nn #导入神经网络模块 from torch.utils.data import DataLoader #数据包管理工具,打包数据 from torchvision import datasets #封装了很多与图像相关的模型,数据集 from torchvision.transforms import ToTensor #数据转换,张量,将其他类型的数据转换为tensor张量,numpy array

torch:PyTorch 的核心包,提供张量运算和深度学习构建的基础。

nn:神经网络模块,用于搭建层结构(卷积层、全连接层等)。

DataLoader:数据加载器,可以自动打包数据,支持批量读取。

datasets:提供常用的数据集(如 MNIST、CIFAR10)。

ToTensor:将图片转换为张量格式,方便神经网络使用。

二、加载数据集

'''下载训练数据集(包含训练图片+标签)''' training_data = datasets.MNIST( root="data", train=True, download=True, transform=ToTensor(), ) '''下载测试数据集(包含测试图片+标签)''' test_data = datasets.MNIST( root="data", train=False, download=True, transform=ToTensor(), ) print(len(training_data))

三、数据可视化(非必须)

from matplotlib import pyplot as plt figure = plt.figure() for i in range(9): img,label = training_data[i+59000] figure.add_subplot(3,3,i+1) plt.title(label) plt.axis("off") plt.imshow(img.squeeze(),cmap='gray') a = img.squeeze() plt.show()

四、创建数据加载器

train_dataloader = DataLoader(training_data, batch_size=64) # 是一个类,现在初始化了,但没开始打包,训练开始才打包 test_dataloader = DataLoader(test_data, batch_size=64) for X, y in test_dataloader: print(f"Shape of X [N, C, H, W]: {X.shape}") # N:批次大小, C:通道数(灰度图为1), H:高度, W:宽度 print(f"Shape of y: {y.shape} {y.dtype}") # 标签的形状和数据类型 break device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") print(f"using{device} device")

train_dataloader = DataLoader(training_data, batch_size=64)表示现在是一个类,初始化了,但没开始打包,训练开始才打包

batch_size=64:每次从数据集中读取 64 张图片作为一个批次


for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}") # N:批次大小, C:通道数(灰度图为1), H:高度, W:宽度
print(f"Shape of y: {y.shape} {y.dtype}") # 标签的形状和数据类型
break # 只查看一个批次

上述代码是用来查看加载器中一个批次的数据形状

五、选择运行设备

选择选用CPU或者GPU

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using {device} device")
  • cuda:表示选择使用 NVIDIA GPU。

  • mps:Apple M 系列芯片的 GPU。

  • cpu:若电脑种没有 GPU,则使用 CPU

六、定义卷积神经网络(CNN)

''' 定义神经网络 类的继承这种方式''' class CNN(nn.Module): def __init__(self): super(CNN,self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(1,16,3,1,1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) self.conv2 = nn.Sequential( nn.Conv2d(16, 16, 3, 1, 1), nn.ReLU(), nn.Conv2d(16, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) self.conv3 = nn.Sequential( nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(), ) self.out = nn.Linear(64*7*7,10) def forward(self,x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = x.view(x.size(0), -1) output = self.out(x) return output model = CNN().to(device) print(model)

模型结构说明:

输入1*28*28(64 张图片作为一个批次。故 64*1*28*28)

conv1(一维):卷积 + ReLU + 池化 → 输出 16*14*14

conv2(二维):多层卷积 + ReLU + 池化 → 输出 32*7*7

conv3(三维):卷积层 → 输出 64*7*7

Linear 全连接层:输入 64*7*7,输出 10(对应 0~9 的数字分类)。

七、训练函数

def train(dataloader,model,loss_fn,optimizer): model.train() batch_size_num = 1 for X,y in dataloader: X,y = X.to(device),y.to(device) pred = model.forward(X) loss = loss_fn(pred,y) optimizer.zero_grad() loss.backward() optimizer.step() loss_value = loss.item() if batch_size_num %100 ==0: print(f"loss: {loss_value:>7f} [number:{batch_size_num}]") batch_size_num += 1
  1. 前向传播:计算预测结果pred

  2. 计算损失:loss_fn(pred,y)

  3. 反向传播:loss.backward()计算梯度。

  4. 参数更新:optimizer.step()

八、测试函数

def Test(dataloader,model,loss_fn): size = len(dataloader.dataset) num_batches = len(dataloader) model.eval() test_loss,correct =0,0 with torch.no_grad(): for X,y in dataloader: X,y = X.to(device),y.to(device) pred = model.forward(X) test_loss += loss_fn(pred,y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size print(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")

九、定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(),lr=0.005)
  • CrossEntropyLoss:常用于分类任务。

  • Adam 优化器:比 SGD 收敛更快。

十、训练模型和测试

epochs = 10 for t in range(epochs): print(f"epoch {t+1}\n---------------") train(train_dataloader,model,loss_fn,optimizer) print("Done!") Test(test_dataloader,model,loss_fn)



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1183144.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GCC

"GCC stands for GNU Compiler Collection. GCC is an integrated distribution of compilers for several major programming languages." 注意不要误认为 GCC 是 GNU C Compiler 虽然很多人习惯叫它“GNU …

东方博宜OJ 2379:最少交通费 ← 堆优化 Dijkstra + 链式前向星

​【题目来源】https://oj.czos.cn/p/2379https://www.acwing.com/problem/content/852/【题目描述】Mar 星球上共有 n 个城市(编号为 1~n),城市之间为了方便交通修建了 m 条单向高速公路。有些公路是为了交通方便连…

一条龙服务的封头厂家哪家好,河南有推荐吗? - 工业品牌热点

随着工业设备对核心部件质量要求的不断提升,封头作为压力容器、锅炉等设备的安全心脏,其选择与采购成为众多企业的关键决策点。本文围绕企业关注的封头采购核心问题展开问答,结合新乡市光大机械有限公司的实践经验,…

计算机毕业设计 | SpringBoot+vue图书电子商务网站 图书商城(附源码+论文)

1,绪论 1.1 研究背景 互联网时代不仅仅是通过各种各样的电脑进行网络连接的时代,也包含了移动终端连接互联网进行复杂处理的一些事情。传统的互联网时代一般泛指就是PC端,也就是电脑互联网时代,但是最近几十年,是移动…

使用AI开源免费系统,精准识别店内可疑人员行为

如果你手里有零售门店AI无人巡店需求,并且想免费体验这套系统,赶紧去开源社区下载这个项目,开源项目的地址我放在的文章最后。 技术实现:基于轻量级CNN与时序分析的组合识别 思通数科AI视频卫士致力于提供一个轻量…

计算机毕业设计 | SpringBoot+vue健身房管理系统(附源码+论文)

1,研究背景 互联网概念的产生到如今的蓬勃发展,用了短短的几十年时间就风靡全球,使得全球各个行业都进行了互联网的改造升级,标志着互联网浪潮的来临。在这个新的时代,各行各业都充分考虑互联网是否能与本行业进行结合…

全网最全9个AI论文平台,助继续教育学生轻松搞定论文写作!

全网最全9个AI论文平台,助继续教育学生轻松搞定论文写作! AI 工具如何让论文写作更高效? 在当前的学术环境中,继续教育学生面临着越来越多的挑战,尤其是论文写作这一环节。传统的写作方式不仅耗时耗力,还容…

代码随想录算法训练营第五十九天|dijkstra(堆优化版)精讲,Bellman_ford 算法精讲

Bellman_ford 算法精讲bellman ford算法的三部曲:1. initialization(可以设置n1点)1到1的距离为0, 1到2, 1到3,。。1到n的距离为∞2. 进行(v-1)轮松弛(relax the edge) (对每一条边…

基于51/STM32单片机激光测距超声波液位倒车防撞雷达图像显示无线设计(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

基于51/STM32单片机激光测距超声波液位倒车防撞雷达图像显示无线设计超声波测距蜂鸣器报警波动开关C51-10 超声波测距数码管显示蜂鸣器报警按键阈值设置C51-60 水位-超声波水位OLED屏舵机水泵蜂鸣器按键高低阈值加水排水C51-61XN 水位-蓝牙无线超声波水位OLED屏舵机水泵蜂鸣器按…

基于STM32单片机智能低压断路器交流电压电流温度检测设计24-259(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

基于STM32单片机智能低压断路器交流电压电流温度检测设计24-259(设计源文件万字报告讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码 24-259、STM32智能低压断路器设计-TFT1.44-交流电压电流互感器-DS18B20-KEY-BELL 产品功能描述: 本设…

基于STM32单片机智能二维码条形码门禁控制语音播报设计24-304(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

基于STM32单片机智能二维码条形码门禁控制语音播报设计24-304(设计源文件万字报告讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码24-304、STM32的二维码门禁控制系统设计-GM65二维码-电磁-ISD1820 产品功能描述: 本设计由STM32F103C8T…

芦曲泊帕Lusutrombopag在特定人群中的剂量调整与血栓风险监测

芦曲泊帕(Lusutrombopag)在特定人群中的应用需根据个体特征调整剂量,并加强血栓风险监测,以确保治疗的安全性和有效性。特定人群剂量调整:个体化方案,精准治疗Child-Pugh C级肝硬化患者:由于肝功…

基于STM32单片机智能垃圾桶图像识别分类满溢报警无线APP设计S96(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

基于STM32单片机智能垃圾桶图像识别分类满溢报警无线APP设计S96(设计源文件万字报告讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码STM32-S96-图像识别垃圾分类4组舵机满溢报警按键分类TFT彩屏(无线方式选择) STM32-S96N无无线-无APP板: STM32-S9…

我们的系统经常遇到DAO360.DLL丢失找不到问题 免费下载方法分享

在使用电脑系统时经常会出现丢失找不到某些文件的情况,由于很多常用软件都是采用 Microsoft Visual Studio 编写的,所以这类软件的运行需要依赖微软Visual C运行库,比如像 QQ、迅雷、Adobe 软件等等,如果没有安装VC运行库或者安装…

计算机毕业设计springboot课堂教学管理系统 基于SpringBoot的智慧课堂互动管理平台 SpringBoot+MySQL构建的混合式教学综合系统

计算机毕业设计springboot课堂教学管理系统5l4h8y1j(配套有源码 程序 mysql数据库 论文) 本套源码可以在文本联xi,先看具体系统功能演示视频领取,可分享源码参考。在“互联网教育”快速演进的当下,传统课堂的纸质签到、口头布置作…

基于STM32单片机智能白光LED可见光通信音频传输系统设计25-072(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

基于STM32单片机智能白光LED可见光通信音频传输系统设计25-072(设计源文件万字报告讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码25-072、基于STM32单片机智能白光LED可见光通信音频传输系统设计 产品功能描述: 基于白光LED可见光通信…

基于STM32单片机智能摄像头识别病虫害诊断预警蓝牙APP设计22-077(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

基于STM32单片机智能摄像头识别病虫害诊断预警蓝牙APP设计22-077 22-077、 STM32F103ZET6智能化识别植物病虫害诊断及快速预警设计-TFT2.8-摄像头-DHT11-蓝牙产品功能描述: 本系统由STM32F103ZET6单片机核心板(可插TF卡)电路2.8寸TFT彩屏显示…

全场景音视频赋能:三大综合管理平台技术与落地实践

综合管理平台系列产品以现代音视频技术发展趋势及实际应用需求为导向,依托高清混合矩阵搭建基础架构,针对不同行业场景的差异化需求迭代优化,形成集信号处理、功能集成、场景适配于一体的综合处理系统。该系列包含分布式交互管理平台、图像综…

计算机毕业设计springboot零食销售管理信息系统的设计与开发 基于SpringBoot的休闲食品线上进销存平台构建 SpringBoot驱动的零食电商运营支撑系统研发

计算机毕业设计springboot零食销售管理信息系统的设计与开发e1r04j82 (配套有源码 程序 mysql数据库 论文) 本套源码可以在文本联xi,先看具体系统功能演示视频领取,可分享源码参考。当“宅经济”与“即时满足”碰撞,零食赛道爆发式…

FTP 快捷批处理混淆钓鱼方式利用和防范

FTP 快捷批处理混淆钓鱼是一种文件分片 + 合法进程代理的攻击手法,核心是通过.link快捷方式调用系统自带的ftp.exe,配合分片存储的批处理脚本,实现恶意载荷的隐蔽执行。免责声明:本文所涉及的技术仅供学习和参考,…