MNIST 手写数字分类

转自我的个人博客: https://shar-pen.github.io/2025/05/04/torch-distributed-series/1.MNIST/

基础的单卡训练

本笔记本演示了训练一个卷积神经网络(CNN)来对 MNIST 数据集中的手写数字进行分类的过程。工作流程包括:

  1. 数据准备:加载和预处理 MNIST 数据集。
  2. 模型定义:使用 PyTorch 构建 CNN 模型。
  3. 模型训练:在 MNIST 训练数据集上训练模型。
  4. 模型评估:在 MNIST 测试数据集上测试模型并评估其性能。
  5. 可视化:展示样本图像及其对应的标签。

参考 pytorch 官方示例 https://github.com/pytorch/examples/blob/main/mnist/main.py 。

至于为什么选择 MNIST 分类任务, 因为它就是深度学习里的 Hello World.

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time

深度学习里,真正必要的超参数,大致是下面这些:

  1. 学习率(learning rate)

    • 最最核心的超参数。
    • 决定每次参数更新的步幅大小。
    • 学习率不合适,训练几乎一定失败。
  2. 优化器(optimizer)

    • 比如 SGDAdamAdamW 等。
    • 不同优化器,收敛速度、最终效果差异很大。
    • 有时也需要设置优化器内部超参(比如 Adam 的 β 1 , β 2 \beta_1, \beta_2 β1,β2)。
  3. 批大小(batch size)

    • 多少样本合成一批送进模型训练。
    • 影响训练稳定性、收敛速度、硬件占用。
  4. 训练轮次(epoch)最大步数(max steps)

    • 总共训练多久。
    • 如果训练不够长,模型欠拟合;太久则过拟合或资源浪费。
  5. 损失函数(loss function)

    • 明确训练目标,比如分类用 CrossEntropyLoss,回归用 MSELoss
    • 不同任务必须选对损失。

超参设置

我们设置些最基础的超参: epoch, batch size, device, lr

EPOCHS = 5
BATCH_SIZE = 512
LR = 0.001
LR_DECAY_STEP_NUM = 1
LR_DECAY_FACTOR = 0.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

数据构建

直接用库函数生成 dataset 和 dataloader, 前者其实只是拿来生成 dataloader

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_data = datasets.MNIST(root = './mnist',train=True,       # 设置True为训练数据,False为测试数据transform = transform,# download=True  # 设置True后就自动下载,下载完成后改为False即可
)train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)test_data = datasets.MNIST(root = './mnist',train=False,       # 设置True为训练数据,False为测试数据transform = transform,
)test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)# plot one exampleprint(f'dataset: input shape: {train_data.data.size()}, label shape: {train_data.targets.size()}')
print(f'dataloader iter: input shape: {next(iter(train_loader))[0].size()}, label shape: {next(iter(train_loader))[1].size()}')
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.title(f'Label: {train_data.targets[0]}')
plt.show()

​ dataset: input shape: torch.Size([60000, 28, 28]), label shape: torch.Size([60000])
​ dataloader iter: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512])

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

网络

设计简单的 ConvNet, 几层 CNN + MLP。初始化新模型后,先将其放到 DEVICE 上

class ConvNet(nn.Module):"""A neural network model for MNIST digit classification.This model is designed to classify images from the MNIST dataset, which consists of grayscale images of handwritten digits (0-9). The network architecture includes convolutional layers for feature extraction, followed by fully connected layers for classification.Attributes:features (nn.Sequential): A sequential container of convolutional layers, activation functions, pooling, and dropout for feature extraction.classifier (nn.Sequential): A sequential container of fully connected layers, activation functions, and dropout for classification.Methods:forward(x):Defines the forward pass of the network. Takes an input tensor `x`, processes it through the feature extractor and classifier, and returns the log-softmax probabilities for each class."""def __init__(self):super(ConvNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 32, 3, 1),nn.ReLU(),nn.Conv2d(32, 64, 3, 1),nn.ReLU(),nn.MaxPool2d(2),nn.Dropout(0.25))self.classifier = nn.Sequential(nn.Linear(9216, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, 10))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)output = F.log_softmax(x, dim=1)return output

训练和评估函数

将训练和评估函数分别封装为函数,使主循环更简洁

def train(model, device, train_loader, optimizer):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if (batch_idx + 1) % 30 == 0: print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

主训练循环

model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)start_time = time()  # Record the start time
for epoch in range(EPOCHS):epoch_start_time = time()  # Record the start time of the current epochprint(f'Epoch {epoch}/{EPOCHS}')print(f'Learning Rate: {scheduler.get_last_lr()[0]}')train(model, DEVICE, train_loader, optimizer)test(model, DEVICE, test_loader)scheduler.step()epoch_end_time = time()  # Record the end time of the current epochprint(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")end_time = time()  # Record the end time
print(f"Total training time: {end_time - start_time:.2f} seconds")
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1795609      C   ...st/anaconda3/envs/xprepo/bin/python        448MiB |
|    0   N/A  N/A   1814253      C   ...st/anaconda3/envs/xprepo/bin/python       1036MiB |
|    7   N/A  N/A   4167010      C   ...guest/anaconda3/envs/QDM/bin/python      19416MiB |
+-----------------------------------------------------------------------------------------+

0 卡的占用 1484 MB

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time
import argparseclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 32, 3, 1),nn.ReLU(),nn.Conv2d(32, 64, 3, 1),nn.ReLU(),nn.MaxPool2d(2),nn.Dropout(0.25))self.classifier = nn.Sequential(nn.Linear(9216, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, 10))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)output = F.log_softmax(x, dim=1)return outputdef arg_parser():parser = argparse.ArgumentParser(description="MNIST Training Script")parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")parser.add_argument("--lr_decay_step_num", type=int, default=1, help="Step size for learning rate decay")parser.add_argument("--lr_decay_factor", type=float, default=0.5, help="Factor by which learning rate is decayed")parser.add_argument("--cuda_id", type=int, default=0, help="CUDA device ID to use")return parser.parse_args()def prepare_data(batch_size):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])train_data = datasets.MNIST(root = './mnist',train=True,       # 设置True为训练数据,False为测试数据transform = transform,# download=True  # 设置True后就自动下载,下载完成后改为False即可)train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)test_data = datasets.MNIST(root = './mnist',train=False,       # 设置True为训练数据,False为测试数据transform = transform,)test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)return train_loader, test_loaderdef train(model, device, train_loader, optimizer):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if (batch_idx + 1) % 30 == 0: print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))def train_mnist_classification():args = arg_parser()print(args)EPOCHS = args.epochsBATCH_SIZE = args.batch_sizeLR = args.lrLR_DECAY_STEP_NUM = args.lr_decay_step_numLR_DECAY_FACTOR = args.lr_decay_factorCUDA_ID = args.cuda_idDEVICE = torch.device(f"cuda:{CUDA_ID}")train_loader, test_loader = prepare_data(BATCH_SIZE)model = ConvNet().to(DEVICE)optimizer = optim.Adam(model.parameters(), lr=LR)scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)start_time = time()  # Record the start timefor epoch in range(EPOCHS):epoch_start_time = time()  # Record the start time of the current epochprint(f'Epoch {epoch}/{EPOCHS}')print(f'Learning Rate: {scheduler.get_last_lr()[0]}')train(model, DEVICE, train_loader, optimizer)test(model, DEVICE, test_loader)scheduler.step()epoch_end_time = time()  # Record the end time of the current epochprint(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")end_time = time()  # Record the end timeprint(f"Total training time: {end_time - start_time:.2f} seconds")if __name__ == "__main__":train_mnist_classification()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/82568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库中的 Segment、Extent、Page、Row 详解

在关系型数据库的底层存储架构中,数据并不是随意写入磁盘,而是按照一定的结构分层管理的。理解这些存储单位对于优化数据库性能、理解 SQL 执行过程以及排查性能问题都具有重要意义。 我将从宏观到微观,依次介绍数据库存储中的四个核心概念&…

DAMA车轮图

DAMA车轮图是国际数据管理协会(DAMA International)提出的数据管理知识体系(DMBOK)的图形化表示,它以车轮(同心圆)的形式展示了数据管理的核心领域及其相互关系。以下是基于用户提供的关键词对D…

《QDebug 2025年4月》

一、Qt Widgets 问题交流 1. 二、Qt Quick 问题交流 1.QML单例动态创建的对象,访问外部id提示undefined 先定义一个窗口组件,打印外部的id: // MyWindow.qml import QtQuick 2.15 import QtQuick.Window 2.15Window {id: controlwidth: …

JS | 正则 · 常用正则表达式速查表

以下是前端开发中常用的正则表达式速查表,包含验证规则、用途说明与示例: 📌 常用正则表达式速查表 名称正则表达式描述 / 用途示例手机号/^1[3-9]\d{9}$/中国大陆手机号13812345678 ✅座机号/^0\d{2,3}-?\d{7,8}$/固定电话010-12345678 ✅…

系统思考:个人与团队成长

四年前,我交付的系统思考项目,今天学员的反馈依然深深触动了我。 我常常感叹,系统思考不仅仅是一场培训,更像是一场持续的“修炼”。在这条修炼之路上,最珍贵的,便是有志同道合的伙伴们一路同行&#xff0…

写屏障和读屏障的区别是什么?

写屏障(Write Barrier)与读屏障(Read Barrier)的区别 在计算机科学中,写屏障和读屏障是两种关键的内存同步机制,主要用于解决并发编程中的可见性、有序性问题,或在垃圾回收(GC&…

ssh -T git@github.com 测试失败解决方案:修改hosts文件

问题描述 通过SSH方式测试,使用该方法测试连接可能会遇到连接超时、端口占用的情况,原因是因为DNS配置及其解析的问题 ssh -T gitgithub.com我们可以详细看看建立 ssh 连接的过程中发生了什么,可以使用 ssh -v命令,-v表示 verbo…

大疆无人机搭载树莓派进行目标旋转检测

环境部署 首先是环境创建,创建虚拟环境,名字叫 pengxiang python -m venv pengxiang随后激活环境 source pengxiang/bin/activate接下来便是依赖包安装过程了: pip install onnxruntime #推理框架 pip install fastapi uvicorn[standard] #网络请求…

00 Ansible简介和安装

1. Ansible概述与基本概念 1.1. 什么是Ansible? Ansible 是一款用 Python 编写的开源 IT 自动化工具,主要用于配置管理、软件部署及高级工作流编排。它能够简化应用程序部署、系统更新等操作,并且支持自动化管理大规模的计算机系统。Ansibl…

Linxu实验五——NFS服务器

一.NFS服务器介绍 NFS服务器(Network File System)是一种基于网络的分布式文件系统协议,允许不同操作系统的主机通过网络共享文件和目录3。其核心作用在于实现跨平台的资源透明访问,例如在Linux和Unix系统之间共享静态数据&#…

『 测试 』测试基础

文章目录 1. 调试与测试的区别2. 开发过程中的需求3. 开发模型3.1 软件的生命周期3.2 瀑布模型3.2.1 瀑布模型的特点/缺点 3.3 螺旋模型3.3.1 螺旋模型的特点/缺点 3.4 增量模型与迭代模型3.5 敏捷模型3.5.1 Scrum模型3.5.2 敏捷模型中的测试 4 测试模型4.1 V模型4.2 W模型(双V…

红外遥控键

红外 本章节旨在让用户自定义红外遥控功能,需要有板载红外接收的板卡。 12.1. 获取红外遥控键值 由于不同遥控器厂家定义的按键键值不一样,所以配置不通用,需要获取实际按键对应的键值。 1 2 3 4 5 6 #设置输出等级 echo 7 4 1 7> /pr…

同一个虚拟环境中conda和pip安装的文件存储位置解析

文章目录 存储位置的基本区别conda安装的包pip安装的包 看似相同实则不同的机制实际路径示例这种差异带来的问题如何检查包安装来源最佳实践建议 总结 存储位置的基本区别 conda安装的包 存储在Anaconda(或Miniconda)目录下的pkgs和envs子目录中: ~/anaconda3/en…

机器学习极简入门:从基础概念到行业应用

有监督学习(supervised learning) 让模型学习的数据包含正确答案(标签)的方法,最终模型可以对无标签的数据进行正确处理和预测,可以分为分类与回归两大类 分类问题主要是为了“尽可能分开整个数据而画线”…

split和join的区别‌

split和join是Python中用于处理字符串的两种方法,它们的主要区别在于功能和使用场景。‌ split()方法 ‌split()方法用于将字符串按照指定的分隔符分割成多个子串,并返回这些子串组成的列表‌。如果不指定分隔符,则默认分割所有的空白字符&am…

MySQL从入门到精通(二):Windows和Mac版本MySQL安装教程

目录 MySQL安装流程 (一)、进入MySQL官网 (二)、点击下载(Download) (三)、Windows和Mac版本下载 下载Windows版本 下载Mac版本 (四)、验证并启动MySQL …

LeetCode 解题思路 45(分割等和子集、最长有效括号)

解题思路: dp 数组的含义: 在数组中是否存在一个子集,其和为 i。递推公式: dp[i] | dp[i - num]。dp 数组初始化: dp[0] true。遍历顺序: 从大到小去遍历,从 i target 开始,直到 …

电影感户外哑光人像自拍摄影Lr调色预设,手机滤镜PS+Lightroom预设下载!

调色详情 电影感户外哑光人像自拍摄影 Lr 调色,是借助 Lightroom 软件,针对户外环境下拍摄的人像自拍进行后期处理。旨在模拟电影画面的氛围与质感,通过调色赋予照片独特的艺术气息。强调打造哑光效果,使画面色彩不过于浓烈刺眼&a…

使用 NV‑Ingest、Unstructured 和 Elasticsearch 处理非结构化数据

作者:来自 Elastic Ajay Krishnan Gopalan 了解如何使用 NV-Ingest、Unstructured Platform 和 Elasticsearch 为 RAG 应用构建可扩展的非结构化文档数据管道。 Elasticsearch 原生集成了行业领先的生成式 AI 工具和提供商。查看我们的网络研讨会,了解如…

Android 13 使能user版本进recovery

在 debug 版本上,可以在关机状态下,同时按 电源键 和 音量加键 进 recovery 。 user 版本上不行。 参考 使用 build 变体 debug 版本和 user 版本的差别之一就是 ro.debuggable 属性不同。 顺着这个思路追踪,找到 bootable/recovery/reco…