通过代码认识 CNN:用 PyTorch 构建卷积神经网络识别手写数字

news/2025/9/21 9:30:23/文章来源:https://www.cnblogs.com/tlnshuju/p/19103198

通过代码认识 CNN:用 PyTorch 构建卷积神经网络识别手写数字

2025-09-21 09:29  tlnshuju  阅读(0)  评论(0)    收藏  举报

目录

一、从代码看 CNN 的核心组件

二、准备工作:库导入与数据加载

三、核心:用代码实现 CNN 并理解各层作用

1.网络层结构

2.重点理解:卷积层参数与输出尺寸计算

四、训练 CNN

五、结果分析


卷积神经网络(CNN)是计算机视觉领域的核心模型,相比全连接网络,它能更高效地提取图像特征。本文不空谈理论,而是通过 PyTorch 代码实现一个完整的 CNN 模型,带你在实战中理解卷积、池化等核心概念,掌握 CNN 的工作原理。


一、从代码看 CNN 的核心组件

在实现模型前,先明确 CNN 的三个核心层 —— 这些是区别于全连接网络的关键,后续代码会逐一对应:

  1. 卷积层(Conv2d):通过滑动窗口提取局部特征(如边缘、纹理);
  2. 激活层(ReLU):引入非线性,让模型学习复杂模式;
  3. 池化层(MaxPool2d):降低特征图尺寸,减少计算量,增强鲁棒性。

我们将用这些组件构建一个识别 MNIST 手写数字的 CNN 模型,边写代码边解释原理。


二、准备工作:库导入与数据加载

首先导入必要的库,加载 MNIST 数据集(28×28 的手写数字图片):

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# 加载MNIST数据集
training_data = datasets.MNIST(
root="data",
train=True,
download=True,
transform=ToTensor()  # 转为张量,形状为[1,28,28]
)
test_data = datasets.MNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
# 按批次加载数据(每批64张图)
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
# 查看数据形状([批次, 通道, 高度, 宽度])
for X, y in test_dataloader:
print(f"数据形状: {X.shape}")  # 输出:torch.Size([64, 1, 28, 28])
break

关键说明:MNIST 图片是单通道(灰度图),所以输入形状为[N,1,28,28](N 为批次大小),这会影响后续卷积层的参数设置。


三、核心:用代码实现 CNN 并理解各层作用

1.网络层结构

我们构建一个包含 4 个卷积块的 CNN 模型,每个卷积块由 “卷积层 + 激活层” 组成,部分块后添加池化层。通过代码注释详细说明每层的作用和参数含义。

# 自动选择设备(优先GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 第一个卷积块:卷积层+激活层+池化层
self.conv1 = nn.Sequential(
# 卷积层:输入1通道,输出16通道,卷积核5×5,步长1,填充2
nn.Conv2d(
in_channels=1,    # 输入通道数(灰度图为1)
out_channels=16,  # 输出通道数(16个不同的卷积核)
kernel_size=5,    # 卷积核大小5×5
stride=1,         # 步长1(每次滑动1个像素)
padding=2         # 填充2(保持输出尺寸与输入一致:28→28)
),
nn.ReLU(),  # 激活层:引入非线性,过滤负值
# 池化层:2×2窗口,步长2,输出尺寸变为14×14(28/2)
nn.MaxPool2d(kernel_size=2, stride=2)
)
# 第二个卷积块:卷积层+激活层(无池化)
self.conv2 = nn.Sequential(
# 输入16通道(上一层输出),输出32通道,卷积核3×3
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU()  # 输出尺寸保持14×14
)
# 第三个卷积块:卷积层+激活层+池化层
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2)  # 输出尺寸变为7×7(14/2)
)
# 第四个卷积块:卷积层+激活层(无池化)
self.conv4 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU()  # 输出尺寸保持7×7
)
# 全连接层:将特征图转为10个类别(0-9)
self.fc = nn.Linear(128 * 7 * 7, 10)  # 128通道×7×7尺寸
def forward(self, x):
# 前向传播:数据依次经过各层
x = self.conv1(x)  # 输出形状:[N,16,14,14]
x = self.conv2(x)  # 输出形状:[N,32,14,14]
x = self.conv3(x)  # 输出形状:[N,64,7,7]
x = self.conv4(x)  # 输出形状:[N,128,7,7]
x = x.view(x.size(0), -1)  # 展平:[N,128×7×7]
x = self.fc(x)     # 输出形状:[N,10](10个类别分数)
return x
# 创建模型并移动到设备
model = CNN().to(device)
print("CNN模型结构:")
print(model)

2.重点理解:卷积层参数与输出尺寸计算

以第一个卷积层为例,输入是[64,1,28,28](64 张图,1 通道,28×28),经过kernel_size=5, padding=2, stride=1的卷积后,输出尺寸计算公式:

输出尺寸 = (输入尺寸 - 卷积核大小 + 2×填充) / 步长 + 1
即:(28 - 5 + 2×2)/1 + 1 = 28

所以输出仍为 28×28,再经 2×2 池化后变为 14×14—— 这就是卷积层如何在保留特征的同时控制尺寸的核心逻辑。


四、训练 CNN

CNN 的训练流程和全连接网络类似,我们将训练轮次调整为 10 轮,既能保证模型收敛,又能节省训练时间。定义训练和测试函数如下:

# 损失函数(多分类用交叉熵)
loss_fn = nn.CrossEntropyLoss()
# 优化器(Adam,学习率0.0001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# 训练函数
def train(dataloader, model, loss_fn, optimizer):
model.train()  # 开启训练模式
batch_num = 1
for X, y in dataloader:
X, y = X.to(device), y.to(device)
# 前向传播:计算预测
pred = model(X)
loss = loss_fn(pred, y)
# 反向传播:更新参数
optimizer.zero_grad()  # 梯度清零
loss.backward()        # 计算梯度
optimizer.step()       # 更新参数
# 每100批次打印一次损失
if batch_num % 100 == 1:
print(f"批次 {batch_num} | 损失: {loss.item():.4f}")
batch_num += 1
# 测试函数
def test(dataloader, model, loss_fn):
model.eval()  # 开启测试模式
size = len(dataloader.dataset)
num_batches = len(dataloader)
correct = 0
test_loss = 0
with torch.no_grad():  # 禁用梯度计算
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 计算准确率和平均损失
test_loss /= num_batches
correct /= size
print(f"\n测试集:准确率 {100*correct:.2f}% | 平均损失 {test_loss:.4f}\n")
# 开始训练(10轮)
print("="*50)
print("开始训练CNN模型(10轮)")
print("="*50)
for epoch in range(10):
print(f"轮次 {epoch+1}/10")
print("-"*30)
train(train_dataloader, model, loss_fn, optimizer)
# 每2轮测试一次
if (epoch+1) % 2 == 0:
test(test_dataloader, model, loss_fn)
print("="*50)
print("训练结束")

五、结果分析

轮次 10/10
------------------------------
批次 1 | 损失: 0.0002
批次 101 | 损失: 0.0000
批次 201 | 损失: 0.0015
批次 301 | 损失: 0.0190
批次 401 | 损失: 0.0003
批次 501 | 损失: 0.0008
批次 601 | 损失: 0.0001
批次 701 | 损失: 0.0065
批次 801 | 损失: 0.0019
批次 901 | 损失: 0.0310
测试集:准确率 99.17% | 平均损失 0.0355
==================================================
训练结束

即使只训练 10 轮,CNN 在测试集上的准确率通常也能达到99% 以上,明显高于同轮次的全连接网络。这体现了 CNN 的高效性,原因在于:

  1. 局部感受野:卷积层通过滑动窗口只关注局部像素,更符合图像的局部相关性;
  2. 权值共享:同一通道的卷积核参数共享,大幅减少参数数量(全连接层 784→128 需要近 10 万个参数,而 5×5 的卷积层 1→16 仅需 400 个参数);
  3. 池化层:通过下采样保留关键特征,增强模型对图像位移、缩放的鲁棒性。

这些特性让 CNN 在较少的训练轮次下就能达到较好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/908682.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQLite数据库 - 教程

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

安全技术深度探讨:从鱿鱼皮肤到AI漏洞挖掘

本文探讨了鱿鱼皮肤的光学特性及其潜在安全隐喻,并深入分析了AI在漏洞挖掘、硬件侧信道攻击、智能合约审计等领域的应用与风险,涉及多项实际技术案例与安全架构思考。周五鱿鱼博客:鱿鱼皮肤如何扭曲光线 新研究显示…

【Bluedroid】A2DP Source 音频流暂停流程解析[3]:AVDTP 协议中 Suspend Accept 响应的处理流程与建立分析(Suspend Accept)

【Bluedroid】A2DP Source 音频流暂停流程解析[3]:AVDTP 协议中 Suspend Accept 响应的处理流程与建立分析(Suspend Accept)pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !…

实用指南:【Linux篇章】再续传输层协议UDP :从低可靠到极速传输的协议重生之路,揭秘无连接通信的二次进化密码!

实用指南:【Linux篇章】再续传输层协议UDP :从低可靠到极速传输的协议重生之路,揭秘无连接通信的二次进化密码!2025-09-21 09:14 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word…

数据建模流程分析

📋 完成的工作 我已经为您创建了一个完整的高速列车轴承智能故障诊断系统,包含以下核心组件: 1. 数据预处理模块 (data_preprocessing.py)✅ 支持.mat文件解析✅ 多采样率统一处理(12kHz/48kHz/32kHz)✅ 时域特征…

第四章:大模型(LLM)】08.Agent 教程-(7)使用 LangGraph 的作文评分架构

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

详细介绍:您必须知道的 10 大 Highcharts 性能优化技巧—— 提升加载速度与交互体验的实战建议

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

PHP8.5 Pipeline Operator 你应该了解的 8 个特性

PHP8.5 将在今年 11 月份发布Pipeline Operator (|>) 是 PHP 中一个令人兴奋的新特性,它从函数式编程中汲取灵感。它提供了一种干净、可读且富有表现力的方式来链接多个操作,无需嵌套括号或创建不必要的中间变量。…

Nvidia Orin DK 本地 ollama 主流 20GB 级模型 gpt-oss, gemma3, qwen3 部署与测试 - 实践

Nvidia Orin DK 本地 ollama 主流 20GB 级模型 gpt-oss, gemma3, qwen3 部署与测试 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !importa…

Mysql查询条件里的字符串不加引导索引失效

View PostMysql查询条件里的字符串不加引导索引失效因为类型不一致,mysql做了隐式转换,就会导致索引失效

详细介绍:在Ubuntu平台搭建RTMP直播服务器使用SRS简要指南

详细介绍:在Ubuntu平台搭建RTMP直播服务器使用SRS简要指南pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consol…

实用指南:在 k8s 上部署 Kafka 4.0 3节点集群

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

Django HttpRequest 对象的常用属性 - 指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

完整教程:Ajax-day2(图书管理)-弹框显示和隐藏

完整教程:Ajax-day2(图书管理)-弹框显示和隐藏pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas",…

实用指南:C语言基础【20】:指针7

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

civil 3d com api 帮助文档

以前很容易搜到, 不知为什么现在搜不到了。 Getting Started

完整教程:【教程4>第8章>第28节】OFDM完整通信链路项目FPGA开发22——提取导频

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

实用指南:万字详解架构设计:业务架构、应用架构、数据架构、技术架构、单体、分布式、微服务都是什么?

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

WebSockets与Socket.io渗透测试实战指南

本文深入探讨如何通过降级WebSocket通信至HTTP协议实现安全测试,涵盖Socket.io传输机制滥用、协议升级中断技术及Burp Suite高级会话管理配置,提供可实操的渗透测试方法。如何渗透测试WebSockets与Socket.io Ethan R…

深入解析:spring boot3.0整合rabbitmq3.13

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …