Pytorch 项目实战-1: MNIST 手写数字识别

刚接触深度学习的小伙伴们,是不是经常听说 MNIST 数据集和 PyTorch 框架?今天就带大家从零开始,用 PyTorch 实现 MNIST 手写数字识别,轻松迈出深度学习实践的第一步!

一、MNIST 数据集:深度学习界的 “Hello World”​

MNIST 数据集就像是深度学习领域的 “新手村”,里面包含了 6 万张手写数字训练图片和 1 万张测试图片,每张图片都是 28×28 像素的灰度图像,对应的数字标签是 0 - 9。就好比是一个装满数字 “小卡片” 的百宝箱,我们的任务就是教会计算机 “看懂” 这些卡片上的数字。

二、PyTorch 基础库导入

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

 这段代码是在引入我们需要的工具包。torch是 PyTorch 的核心库,就像搭建房屋的砖块;DataLoader是数据加载器,帮我们把数据分批处理;torchvision里的transforms和MNIST,一个用来转换数据格式,一个用来获取 MNIST 数据集;matplotlib.pyplot则是绘图工具,能帮我们直观看到预测结果;nn和nn.functional用于搭建神经网络模型和定义激活函数等操作。

三、搭建神经网络模型

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(28*28, 64)self.fc2 = nn.Linear(64, 64)self.fc3 = nn.Linear(64, 64)self.fc4 = nn.Linear(64, 10)def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))x = self.fc4(x)  return x

这里定义了一个名为Net的神经网络类,继承自nn.Module。__init__函数是初始化操作,fc1到fc4是全连接层,全连接层就像是一个 “信息加工厂”,将输入数据进行变换。比如fc1把 28×28(784)个像素点组成的数据转换为 64 个特征。forward函数定义了数据的前向传播过程,F.relu是激活函数,它就像一个 “开关”,让神经网络具备了学习非线性关系的能力,能让数据在不同层之间更好地传递和处理。

四、数据加载与预处理

def get_data_loader():train_data = MNIST(root="mnist_data", train=True, transform=transforms.ToTensor(), download=True)test_data = MNIST(root="mnist_data", train=False, transform=transforms.ToTensor(), download=True)train_loader = DataLoader(train_data, batch_size=64, shuffle=True)test_loader = DataLoader(test_data, batch_size=64)return train_loader, test_loader

get_data_loader函数负责获取 MNIST 数据集。MNIST函数从指定路径(root)下载数据,transforms.ToTensor()将图片转换为 PyTorch 能处理的张量格式。DataLoader则把数据打包成一批一批的,batch_size=64表示每批有 64 张图片,shuffle=True让训练数据每次都打乱顺序,这样能让模型学习得更好。就像把小卡片分成一叠叠,每次训练随机抽取一叠,避免模型 “记住” 固定顺序。

五、模型评估函数

def evaluate(test_loader, net):net.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs = inputs.view(-1, 28*28)outputs = net(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()return correct / total

evaluate函数用来评估模型的准确率。net.eval()将模型设置为评估模式,with torch.no_grad()表示在这个过程中不计算梯度,节省计算资源。遍历测试数据,把图片数据整理成合适格式后输入模型,得到输出结果,用torch.max找到概率最大的类别作为预测结果,最后计算预测正确的比例。

六、主函数:模型训练与测试

def main():train_loader, test_loader = get_data_loader()net = Net()loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(net.parameters(), lr=0.001)print("初始准确率:", evaluate(test_loader, net))for epoch in range(5): net.train()running_loss = 0.0for i, (inputs, labels) in enumerate(train_loader):inputs = inputs.view(-1, 28*28)optimizer.zero_grad()outputs = net(inputs)loss = loss_fn(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{5}], Batch [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')running_loss = 0.0accuracy = evaluate(test_loader, net)print(f'Epoch {epoch+1}, 测试准确率: {accuracy:.4f}')net.eval()fig, axes = plt.subplots(1, 5, figsize=(15, 3))for i, (images, labels) in enumerate(test_loader, 1):if i > 5: breakwith torch.no_grad():outputs = net(images[0].view(-1, 28*28))_, predicted = torch.max(outputs, 1)axes[i-1].imshow(images[0].view(28, 28), cmap='gray')axes[i-1].set_title(f'预测: {predicted.item()}')axes[i-1].axis('off')plt.tight_layout()plt.show()

在main函数里,先获取数据加载器,创建模型,定义损失函数nn.CrossEntropyLoss()(它结合了 Softmax 和交叉熵损失计算)和优化器torch.optim.Adam(用来更新模型参数)。然后进入训练循环,epoch表示训练的轮数,每轮中遍历训练数据,通过前向传播、计算损失、反向传播和参数更新,不断调整模型参数。训练过程中打印损失值,每轮结束后评估模型准确率。最后可视化 5 张测试图片的预测结果,直观看到模型的识别效果。 

if __name__ == "__main__":main()

这行代码确保只有直接运行脚本时才执行main函数,避免被其他脚本导入时意外执行。​

通过以上步骤,我们就完成了 MNIST 手写数字识别模型的搭建、训练和测试。希望这篇博客能帮助小白们理解深度学习的基本流程,快动手试试,开启你的 AI 探索之旅吧!如果在实践过程中有任何问题,欢迎在评论区交流~

彩蛋:

点赞+收藏+关注一键三联+评论“冲冲冲”可免费获取下面的项目实战大礼包!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/84059.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据量下Redis分片的5种策略

随着业务规模的增长,单一Redis实例面临着内存容量、网络带宽和计算能力的瓶颈。 分片(Sharding)成为扩展Redis的关键策略,它将数据分散到多个Redis节点上,每个节点负责整个数据集的一个子集。 本文将分享5种Redis分片策略。 1. 取模分片(M…

CentOS 7上搭建高可用BIND9集群指南

在 CentOS 7 上搭建一个高可用的 BIND9 集群通常涉及以下几种关键技术和策略的组合:主从复制 (Master-Slave Replication)、负载均衡 (Load Balancing) 以及可能的浮动 IP (Floating IP) 或 Anycast。 我们将主要关注主从复制和负载均衡的实现,这是构成高…

LangChain4j入门AI(六)整合提示词(Prompt)

前言 提示词(Prompt)是用户输入给AI模型的一段文字或指令,用于引导模型生成特定类型的内容。通过提示词,用户可以告诉AI“做什么”、 “如何做”以及“输出格式”,从而在满足需求的同时最大程度减少无关信息的生成。有…

【MySQL】笔记

📚 博主的专栏 🐧 Linux | 🖥️ C | 📊 数据结构 | 💡C 算法 | 🅒 C 语言 | 🌐 计算机网络 在ubuntu中,改配置文件: sudo nano /etc/mysql/mysql.conf.d/mysq…

TDengine 运维—容量规划

概述 若计划使用 TDengine 搭建一个时序数据平台,须提前对计算资源、存储资源和网络资源进行详细规划,以确保满足业务场景的需求。通常 TDengine 会运行多个进程,包括 taosd、taosadapter、taoskeeper、taos-explorer 和 taosx。 在这些进程…

OpenCV CUDA模块图像过滤------创建一个盒式滤波器(Box Filter)函数createBoxFilter()

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::cuda::createBoxFilter 是 OpenCV CUDA 模块中的一个工厂函数,用于创建一个 盒式滤波器(Box Filter)&…

谷歌I/O 2025 完全指南:由Gemini开创的AI新时代及其对我们未来的影响

在这个朝着智能化一切狂奔的世界中,谷歌I/O 2025不仅展示了人工智能创新——它传递了一个明确的信息:未来已来临,而且由Gemini驱动。 从突破性的模型进步到沉浸式通信工具和个性化AI助手,谷歌正在重塑人机交互的本质。让我们一起了解最重大的公告及其对开发者、用户和AI生…

5000 字总结CSS 中的过渡、动画和变换详解

CSS 中的过渡、动画和变换详解 一、CSS 过渡(Transitions) 1. 基本概念 CSS 过渡是一种平滑改变 CSS 属性值的机制,允许属性值在一定时间内从一个值逐渐变化到另一个值,从而创建流畅的动画效果。过渡只能用于具有中间值的属性(如颜色、大小、位置等),不能用于 displa…

【图像生成大模型】CogVideoX-5b:开启文本到视频生成的新纪元

CogVideoX-5b:开启文本到视频生成的新纪元 项目背景与目标模型架构与技术亮点项目运行方式与执行步骤环境准备模型加载与推理量化推理 执行报错与问题解决内存不足模型加载失败生成质量不佳 相关论文信息总结 在人工智能领域,文本到视频生成技术一直是研…

辨析Spark 运行方式、运行模式(master)、部署方式(deploy-mode)

为了理清 Spark 运行方式、部署模式(master)、部署方式(deploy-mode) 之间的关系,我们先明确几个核心概念,再对比它们的联系与区别。 一、核心概念解析 1. Spark 运行方式(代码执行方式&#…

从芯片互连到机器人革命:英伟达双线出击,NVLink开放生态+GR00T模型定义AI计算新时代

5月19日,在台湾举办的Computex 2025上,英伟达推出新技术“NVLink Fusion”,允许非英伟达CPU和GPU,同英伟达产品以及高速GPU互连技术NVLink结合使用,加速AI芯片连接。新技术的推出旨在保持英伟达在人工智能开发和计算领…

04算法学习_209.长度最小的子数组

04算法学习_209.长度最小的子数组题目描述:个人代码:学习思路:第一种写法:题解关键点: 第二种写法:题解关键点: 个人学习时疑惑点解答: 04算法学习_209.长度最小的子数组 力扣题目链…

【已解决】docker search --limit 1 centos Error response from daemon

在docker search的时候你是否遇到过这样的问题? Error response from daemon: Get "https://index.docker.io/v1/search?qcentos&n1": dial tcp 103.56.16.112:443: i/o timeout解决方案 可以尝试一下加一层docker镜像代理: 以mysql:5.…

vue好用插件

自动导入插件 cnpm i -D unplugin-auto-import配置 //在vite.config.js文件加入AutoImport({imports:["vue","vue-router","pinia"],dts:true,}),

算法--js--电话号码的字母组合

题:给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。答案可以按 任意顺序 返回。给出数字到字母的映射如下(与电话按键相同)。注意 1 不对应任何字母。 function letterCombinations (digits){if (!digits.length)…

OSI 网络七层模型中的物理层、数据链路层、网络层

一、OSI 七层模型 物理层、数据链路层、网络层、传输层、会话层、表示层、应用层 1. 物理层(Physical Layer) 功能:传输原始的比特流(0和1),通过物理介质(如电缆、光纤、无线电波)…

Linux 文件(3)

文章目录 1. Linux下一切皆文件2. 文件缓冲区2.1 缓冲区是什么2.2 缓冲区的刷新策略2.3 为什么要有缓冲区2.4 一个理解缓冲区刷新的例子 3. 标准错误 1. Linux下一切皆文件 在刚开始学习Linux的时候,我们就说Linux下一切皆文件——键盘是文件,显示器是文…

STM32之串口通信蓝牙(BLE)

一、串口通信的原理与应用 通信的方式 处理器与外部设备之间或者处理器与处理器之间通信的方式分两种:串行通信和并行通信。 串行通信 传输原理:数据按位依次顺序传输(每一位占据固定的时间长度 MSB or LSB) 优点&#xff1a…

基于python的机器学习(七)—— 数据特征选择

目录 一、特征选择概念 二、特征选择的方法 2.1 过滤式特征选择 2.1.1 方差分析 2.1.2 相关系数 2.1.3 卡方检验 2.2 包裹式特征选择 2.2.1 递归特征消除 2.3 嵌入式特征选择 2.3.1 决策树特征重要性 一、特征选择概念 特征选择是机器学习非常重要的一个步骤&#x…

《AI工程技术栈》:三层结构解析,AI工程如何区别于ML工程与全栈工程

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…