深度学习手写字符识别:训练模型

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoaderfrom dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnxparser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:torch.backends.cudnn.benchmark = Truedef train():os.makedirs('./output', exist_ok=True)if True: #not os.path.exists('output/total.txt'):ml.image_list(args.datapath, 'output/total.txt')ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)model = Net(10)#model = models.vgg16(num_classes=10)#model = models.resnet18(num_classes=10)  # 调用内置模型#model.load_state_dict(torch.load('./output/params_10.pth'))#from torchsummary import summary#summary(model, (3, 28, 28))if args.cuda:print('training with cuda')model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)loss_func = nn.CrossEntropyLoss()for epoch in range(args.epochs):# training-----------------------------------model.train()train_loss = 0train_acc = 0for batch, (batch_x, batch_y) in enumerate(train_loader):if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)  # 256x3x28x28  out 256x10loss = loss_func(out, batch_y)train_loss += loss.item()pred = torch.max(out, 1)[1]train_correct = (pred == batch_y).sum()train_acc += train_correct.item()print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'% (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),loss.item(), train_correct.item() / len(batch_x)))optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()  # 更新learning rateprint('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),train_acc / (len(train_data))))# evaluation--------------------------------model.eval()eval_loss = 0eval_acc = 0for batch_x, batch_y in val_loader:if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)loss = loss_func(out, batch_y)eval_loss += loss.item()pred = torch.max(out, 1)[1]num_correct = (pred == batch_y).sum()eval_acc += num_correct.item()print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),eval_acc / (len(val_data))))# 保存模型。每隔多少帧存模型,此处可修改------------if (epoch + 1) % 1 == 0:# torch.save(model, 'output/model_' + str(epoch+1) + '.pth')torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')#to_onnx(model, 3, 28, 28, 'params.onnx')if __name__ == '__main__':train()
  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/667383.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vcruntime140.dll最新的修复方法,一键修复vcruntime140.dll的手段

在这篇文章中,我们将深入探讨并详细介绍各种修复vcruntime140.dll文件缺失或损坏问题的方法。鉴于此类问题广泛存在并影响了众多用户,本文目的是向大家展示不同的修复策略,希望能够帮助每个人解决这些棘手的技术难题。下面一起来看看vcruntim…

【RT-DETR有效改进】UNetv2提出的一种SDI多层次特征融合模块(细节高效涨点)

👑欢迎大家订阅本专栏,一起学习RT-DETR👑 一、本文介绍 本问给大家带来的改进机制是UNetv2提出的一种多层次特征融合模块(SDI)其是一种用于替换Concat操作的模块,SDI模块的主要思想是通过整合编码器生成的层级特征图来增强图像中的语义信息和细节信息。包括皮肤…

黑豹程序员-ElementPlus选择图标器

ElementPlus组件提供了很多图标svg 如何在你的系统中&#xff0c;用户可以使用呢&#xff1f; 这就是图标器&#xff0c;去调用ElementPlus的icon组件库&#xff0c;展示到页面&#xff0c;用户选择&#xff0c;返回选择的组件名称。 效果 代码 <template><el-inpu…

HarmonyOS ArkTS Button基本使用(十八)

HarmonyOS ArkTS是一种应用于鸿蒙系统的应用开发语言&#xff0c;它在TypeScript的基础上&#xff0c;扩展了声明式UI、状态管理等能力。在HarmonyOS中&#xff0c;Button是一种常用的组件&#xff0c;用于实现页面间的跳转和交互。下面详细介绍HarmonyOS ArkTS中Button的基本使…

深度学习环境指南【1】:Nvidia 驱动

系列文章目录 文章目录 系列文章目录前言选择合适的驱动可能遇到的问题安全模式下删除显卡现有的驱动删除在电脑上安装的 DDU 总结 前言 本文作为深度学习环境指南系列的第一篇文章&#xff0c;主要讲解当你第一次拿到显卡完成装机后需要做的步骤&#xff0c;或者是显卡驱动不…

如何使用GPT提问三元操作符?

英语10分钟&#xff1a; 现在chatgpt非常智能&#xff0c;使用的也越来越广泛&#xff0c;今天学习了使用chatgpt4提问时&#xff0c;应该遵循的提示原则&#xff0c;第一个原则&#xff0c;是要写清晰明确的、具体的说明&#xff0c;第二个原则是要给予模型思考的时间。可以安…

机器学习 - 梯度下降

场景 上一章学习了代价函数&#xff0c;在机器学习中&#xff0c;代价模型是用于衡量模型预测值与真实值之间的差异的函数。它是优化算法的核心&#xff0c;目标是通过调整模型的参数来最小化代价模型的值&#xff0c;从而使模型的预测结果更接近真实值。常见的代价模型是均方…

红黑树,以及其在C++的set、map等数据结构中应用

红黑树介绍&#xff1a; 红黑树&#xff08;Red-Black Tree&#xff09;是一种自平衡的二叉搜索树&#xff0c;它在插入和删除操作后通过一系列的旋转和着色操作来维持平衡。红黑树的命名来自于节点上的额外颜色属性&#xff0c;每个节点要么是红色&#xff0c;要么是黑色。 红…

【Boost】:searcher的建立(四)

searcher的建立 一.初始化二.搜索功能三.完整源代码 sercher主要分为两部分&#xff1a;初始化和查找。 一.初始化 初始化分为两步&#xff1a;1.创建Index对象&#xff1b;2.建立索引 二.搜索功能 搜索分为四个步骤 分词&#xff1b;触发&#xff1a;根据分词找到对应的文档…

架构设计特训

一、考点分布 软件架构风格&#xff08;※※※※&#xff09;层次型软件架构风格&#xff08;※※※※&#xff09;面向服务的软件架构风格&#xff08;※※※※&#xff09;云原生架构风格&#xff08;※※※※&#xff09;质量属性与架构评估&#xff08;※※※※※&#xff…

Transformer实战-系列教程1:Transformer算法解读1

&#x1f6a9;&#x1f6a9;&#x1f6a9;Transformer实战-系列教程总目录 有任何问题欢迎在下面留言 Transformer实战-系列教程1&#xff1a;Transformer算法解读1 Transformer实战-系列教程2&#xff1a;Transformer算法解读2 现在最火的AI内容&#xff0c;chatGPT、视觉大模…

网络安全-端口扫描和服务识别的几种方式

禁止未授权测试&#xff01;&#xff01;&#xff01; 前言 在日常的渗透测试中&#xff0c;我们拿到一个ip或者域名之后&#xff0c;需要做的事情就是搞清楚这台主机上运行的服务有哪些&#xff0c;开放的端口有哪些。如果我们连开放的端口和服务都不知道&#xff0c;下一步针…

反洗钱_2_反洗钱国际组织和国际标准

文章目录 二、反洗钱国际组织和国际标准2.1 反洗钱国际组织2.2 反洗钱国际标准2.3 中国与FAFT 二、反洗钱国际组织和国际标准 2.1 反洗钱国际组织 金融行动特别工作组&#xff1a;Financial Action Task Force on Money Laundering (FATF)FATF成立于1989年在巴黎召开的西方七…

206. 反转链表-递归反转链表

206. 反转链表-递归反转链表 解题思路 基本情况处理&#xff1a; 开始时&#xff0c;首先检查链表是否为空或只包含一个节点。若是&#xff0c;直接返回原链表头部。 递归调用&#xff1a; 对于包含两个或更多节点的链表&#xff0c;将递归调用 reverseList 方法&#xff0c;…

Golang切片与数组

在Go语言中&#xff0c;切片&#xff08;Slice&#xff09;和数组&#xff08;Array&#xff09;是两个核心的数据结构&#xff0c;它们在内存管理、灵活性以及性能方面有着显著的区别。接下来将解析Golang中的切片与数组&#xff0c;通过清晰的概念解释、案例代码和实际应用场…

Vue2基础

前端技术了解&#xff08;了解&#xff09;ES6常见语法&#xff08;掌握&#xff09;Vue入门&#xff08;掌握&#xff09;Vue表达式&#xff08;掌握&#xff09;Vue指令&#xff08;掌握&#xff09;计算属性与侦听器&#xff08;了解&#xff09; 一、前端技术了解&#xf…

小林Coding_操作系统_读书笔记

一、硬件结构 1. CPU是如何执行的 冯诺依曼模型&#xff1a;中央处理器&#xff08;CPU&#xff09;、内存、输入设备、输出设备、总线 CPU中&#xff1a;寄存器&#xff08;程序计数器、通用暂存器、指令暂存器&#xff09;&#xff0c;控制单元&#xff08;控制CPU工作&am…

[word] word页面视图放大后,影响打印吗? #笔记#学习方法

word页面视图放大后&#xff0c;影响打印吗&#xff1f; word文档的页面视图又叫普通视图&#xff0c;又叫打印视图&#xff0c;是系统默认的视图&#xff0c;是用户用的最多最常见的视图。 问&#xff1a;怎样打开页面视图&#xff1f; 答&#xff1a;两种方法 方法一、点…

JS 基本语句

函数调用&#xff0c;分支&#xff0c;循环&#xff0c;语句示例。 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"&g…

【Funny guys】龙年专属测试鼠标寿命小游戏...... 用Python给大家半年了......

目录 【Funny guys】龙年专属测试鼠标寿命小游戏...... 用Python给大家半年了...... 龙年专属测试鼠标寿命小游戏用Python给大家半年了贪吃龙游戏 文章所属专区 码农新闻 欢迎各位编程大佬&#xff0c;技术达人&#xff0c;以及对编程充满热情的朋友们&#xff0c;来到我们的程…