多层感知机笔记

news/2025/10/20 17:39:17/文章来源:https://www.cnblogs.com/inian/p/19153221

Fashion-MNIST 分类任务代码笔记

一、整体概述

本代码基于 PyTorch 实现了一个简单的全连接神经网络,用于解决 Fashion-MNIST 图像分类任务(10个类别)。核心流程包括:网络定义、权重初始化、超参数设置、数据加载、训练循环实现及模型评估。

二、代码分块解析

(一)导入依赖库

import torch
from torch import nn
from d2l import torch as d2l
  • 核心库说明
    • torch:PyTorch 核心库,提供张量操作、自动求导等基础功能。
    • torch.nn:PyTorch 神经网络模块,包含层、损失函数等组件。
    • d2l.torch:《动手学深度学习》工具库,提供数据加载等便捷功能。

(二)定义神经网络结构

net = nn.Sequential(nn.Flatten(),          # 将28x28图像展平为784维向量nn.Linear(784, 256),   # 隐藏层:784→256nn.ReLU(),             # 激活函数nn.Linear(256, 10)     # 输出层:256→10(10个类别)
)
  • 网络组件详解
    1. nn.Flatten():图像预处理层,将输入的 2D 图像张量(28×28)转换为 1D 向量(784 维),适配全连接层输入格式。
    2. nn.Linear(784, 256):全连接隐藏层,接收 784 维输入,输出 256 维特征,通过矩阵乘法实现特征映射。
    3. nn.ReLU():激活函数,引入非线性,公式为 ReLU(x) = max(0, x),解决线性模型表达能力不足的问题。
    4. nn.Linear(256, 10):全连接输出层,将 256 维特征映射到 10 维,对应 10 个类别的原始得分(logits)。

(三)权重初始化函数

# 初始化权重
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)  # 正态分布初始化权重net.apply(init_weights);
  • 功能说明
    • 定义 init_weights 函数,仅对全连接层(nn.Linear)进行权重初始化。
    • 使用 nn.init.normal_ 按正态分布 N(0, 0.01²) 初始化权重,避免初始权重过大/过小导致训练不稳定(如梯度消失/爆炸)。
    • net.apply(init_weights):递归遍历网络所有层,对符合条件的层执行初始化操作。

(四)超参数与训练组件设置

# 超参数
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss()  # 交叉熵损失(含Softmax)
trainer = torch.optim.SGD(net.parameters(), lr=lr)  # 随机梯度下降优化器
  • 关键组件说明
    1. 超参数
      • batch_size=256:每次训练迭代的样本数量,平衡训练速度与稳定性。
      • lr=0.1:学习率,控制参数更新幅度。
      • num_epochs=10:训练轮次,即遍历整个训练集的次数。
    2. 损失函数nn.CrossEntropyLoss() 适用于多分类任务,内部集成了 Softmax 函数,直接接收网络输出的 logits 计算损失。
    3. 优化器torch.optim.SGD 为随机梯度下降优化器,接收网络参数和学习率,负责更新权重以最小化损失。

(五)数据加载

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
  • 功能说明:通过 d2l 工具库加载 Fashion-MNIST 数据集,返回训练集迭代器 train_iter 和测试集迭代器 test_iter
  • 数据特性:每张图像为 28×28 灰度图,训练集 60000 样本,测试集 10000 样本,共 10 个服装类别。

(六)训练与评估函数实现

1. 单轮训练函数

def train_epoch(net, train_iter, loss, trainer):"""训练一个epoch"""net.train()  # 切换到训练模式(启用Dropout、BatchNorm等训练特性)total_loss = 0.0total_correct = 0total_samples = 0for X, y in train_iter:# 前向传播:计算预测值和损失y_hat = net(X)l = loss(y_hat, y)# 反向传播 + 参数更新trainer.zero_grad()  # 清空上一轮梯度(避免累积)l.backward()         # 自动计算梯度(基于计算图)trainer.step()       # 根据梯度更新参数# 统计训练指标total_loss += l.item() * X.shape[0]  # 累计总损失(乘以批量大小还原真实损失)# 计算正确预测数:argmax(dim=1)取预测概率最大的类别索引total_correct += (y_hat.argmax(dim=1) == y).sum().item()total_samples += X.shape[0]  # 累计处理样本数# 返回平均损失和训练准确率return total_loss / total_samples, total_correct / total_samples

2. 测试集评估函数

def evaluate_accuracy(net, test_iter):"""评估测试集准确率"""net.eval()  # 切换到评估模式(禁用Dropout、固定BatchNorm统计量)total_correct = 0total_samples = 0with torch.no_grad():  # 禁用梯度计算,减少内存占用并加速for X, y in test_iter:y_hat = net(X)total_correct += (y_hat.argmax(dim=1) == y).sum().item()total_samples += X.shape[0]return total_correct / total_samples  # 返回测试准确率

(七)执行训练流程

for epoch in range(num_epochs):train_loss, train_acc = train_epoch(net, train_iter, loss, trainer)test_acc = evaluate_accuracy(net, test_iter)print(f" epoch {epoch+1}:")print(f"  训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.4f}")print(f"  测试准确率: {test_acc:.4f}")
  • 流程说明
    1. 遍历 num_epochs 个训练轮次。
    2. 每轮调用 train_epoch 完成训练集的一次遍历,获取训练损失和训练准确率。
    3. 调用 evaluate_accuracy 评估模型在测试集上的性能。
    4. 打印当前轮次的训练损失、训练准确率和测试准确率,监控模型训练进度。

三、核心知识点总结

  1. 全连接网络结构:通过 nn.Sequential 堆叠层,Flatten 层适配图像输入,Linear 层实现特征映射,ReLU 引入非线性。
  2. 权重初始化:合理的初始化(如小方差正态分布)是训练稳定的关键。
  3. 训练三要素:交叉熵损失(多分类任务)、SGD 优化器(参数更新)、批量训练(效率与稳定性平衡)。
  4. 训练/评估模式net.train()net.eval() 切换网络状态,torch.no_grad() 优化评估过程。
  5. 指标统计:训练损失(平均损失)、准确率(正确预测数/总样本数)用于监控模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/941428.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学会使用树研究和实现递归算法

本文章的核心思想来自labuladong的算法笔记网站,加上了我一些自己的学习心得,只用于学习用途。文章中的图片和代码都是原创,非转载。 背景 本人本科审计学,硕士软件工程,目前研究方向是ai在数据库领域的应用(目前…

Sql查询优化方案

Mybatis 分页查询统计方法重写,在查询方法后面固定追加:"_COUNT" 比如: 分页查询方法:pageQuery 重写分页查询统计:pageQuery_COUNT 单表查询统计:select count(1) from xxxdb.t_order sql 查询优化 利…

计算机思维的数与位

计算机思维的数与位Posted on 2025-10-20 17:34 夜owl 阅读(0) 评论(0) 收藏 举报n进制的数与位 在计算机的代码世界中,是以二进制的位的基础来组成数,至此我还是混淆,二级制的1000(十进制的8)中的1是第4个数…

实用指南:深入解析HarmonyOS ArkTS:从语法特性到实战应用

实用指南:深入解析HarmonyOS ArkTS:从语法特性到实战应用pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consol…

2025 防水背衬板厂家最新推荐榜:剖析质量与口碑,优选品牌助您精准采购

引言 在建筑防水工程愈发受重视的当下,防水背衬板的品质直接决定防水层的耐久性与建筑结构安全。但当前市场呈现 “劣币扰市” 乱象:部分企业用劣质原料生产的产品,短期内即出现渗漏、开裂问题,导致后期维护成本激…

如何安装fluentd 和fluentd-mongo的插件?然后收集nginx的 json格式的数据写到mongodb

手动安装 Fluentd + MongoDB 插件并收集 Nginx JSON 日志 以下是完整步骤,从安装 Fluentd 到配置 Nginx JSON 日志存储到 MongoDB。手动安装 Fluentd(td-agent)如果已通过 RPM 安装 td-agent,跳过此步。否则: 下载…

2025年气柱袋厂家推荐排行榜,防震/防摔/食品级气柱袋,奶瓶/奶粉/电子产品/化妆品气柱袋,缓冲包装与物流运输优选方案

2025年气柱袋厂家推荐排行榜:防震/防摔/食品级气柱袋,奶瓶/奶粉/电子产品/化妆品气柱袋,缓冲包装与物流运输优选方案 行业背景与发展趋势 随着电商物流行业的蓬勃发展,气柱袋作为现代包装领域的重要缓冲材料,正经…

详细介绍:EfficientNet:复合缩放

详细介绍:EfficientNet:复合缩放pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco&…

2025 年防火涂料厂家最新推荐排行榜:精选优质企业,涵盖钢结构各类型涂料,助您精准选品

引言 在建筑与工业安全领域,防火涂料是抵御火灾、守护生命财产安全的关键防线。但当前市场乱象频发,部分企业生产的防火涂料未达国家标准,防火性能堪忧;品牌数量繁杂,产品质量差距悬殊,消费者和企业在选购时常常…

Docker 代理配置的迷思:为什么 127.0.0.1 不总是本地? - 若

在使用 Docker 时配置代理是一个常见的需求,但很多开发者都会遇到一个令人困惑的现象:明明代理服务运行在本机,使用 127.0.0.1 却无法正常工作。本文将深入探讨这个问题背后的原理。 问题现象 让我们先看两个相似的…

惠普打印机驱动下载与安装教程(图文详解 + 常见问题解决方案)

本文详细介绍了惠普打印机驱动的下载安装与配置教程,支持 Windows7/10/11 系统。通过官方安全下载渠道,提供全型号兼容驱动与图文安装步骤,并针对打印乱码、驱动不识别、扫描失败等常见问题提供解决方案。无论家用或…

PHP码农的微信业务开发利器

微擎系统:PHP码农的微信业务开发利器 作为一名深耕PHP开发的码农,我深知在微信生态中开发业务系统面临的挑战:接口对接繁琐、多平台适配复杂、功能迭代周期长。直到公司承接微信端业务需求时,我在网上偶然发现微擎…

深入解析:Matlab通过GUI实现点云的PCA配准(附最简版)

深入解析:Matlab通过GUI实现点云的PCA配准(附最简版)2025-10-20 17:26 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; di…

词向量:从 One-Hot 到 BERT Embedding,NLP 文本表示的核心技术 - 实践

词向量:从 One-Hot 到 BERT Embedding,NLP 文本表示的核心技术 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: …

2025年深圳网站建设/外贸独立站推广/阿里巴巴代运营/1688店铺代运营/短视频运营推广/微信小程序开发服务商权威推荐榜

2025年深圳网站建设/外贸独立站推广/阿里巴巴代运营/1688店铺代运营/短视频运营推广/微信小程序开发服务商权威推荐榜 行业背景与发展趋势 随着数字化转型浪潮的深入推进,深圳作为中国科技创新中心,其数字营销服务行…

计算机毕业设计Hadoop+Spatk+Hive滴滴出行分析 出租车供需平衡优化系统 出租车分析预测 大资料毕业设计(源码+LW+PPT+讲解)

计算机毕业设计Hadoop+Spatk+Hive滴滴出行分析 出租车供需平衡优化系统 出租车分析预测 大资料毕业设计(源码+LW+PPT+讲解)2025-10-20 17:23 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !importan…

微信小脚本入门学习教程,从入门到精通,微信小程序开发进阶(7)

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

Android studio build报错 - show

build报错> Task :app:checkDebugAarMetadata FAILED Execution failed for task :app:checkDebugAarMetadata. > Could not resolve all files for configuration :app:debugRuntimeClasspath.> Could not r…

2025 彩石瓦厂家最新推荐排行榜:权威解析金属瓦 / 屋顶瓦优质厂商,金属/屋顶/凉亭/昆明/云南彩石瓦厂家推荐

引言 随着绿色建筑理念深化与屋面材料升级,彩石瓦凭借耐候性强、装饰性佳的优势,成为别墅、文旅项目及新农村建设的核心选材。但市场现状令人担忧:既有产品基材厚度不达标、彩砂脱落等质量隐患,又存在新锐品牌与传…

2025 年最新干燥剂厂家推荐排行榜:深度剖析各品牌实力,涵盖氯化钙 / 氯化镁 / 硅胶等多类型干燥剂优选指南

在工业生产与日常生活中,干燥剂的防潮、保鲜作用愈发关键,小到食品药品储存,大到集装箱海运防潮,都离不开优质干燥剂的支撑。但当前干燥剂市场品牌繁杂,部分小品牌产品吸湿能力弱、持久度差,难以满足不同行业的专…