Softmax回归模型

news/2025/10/29 17:04:13/文章来源:https://www.cnblogs.com/ldra/p/19174921

这段代码是一个完整的 Softmax回归模型 实现,用于解决 Fashion-MNIST数据集的图像分类问题。简单来说,它的作用是:让计算机通过学习大量衣服、鞋子等服饰图片,学会识别新的服饰图片属于哪一类(比如T恤、裤子、运动鞋等,共10个类别)。

具体来说,代码做了这几件事:

  1. 准备数据:加载Fashion-MNIST数据集(包含6万张训练图和1万张测试图),并把图片转换成模型能处理的格式(张量),再分成批量供模型训练。
  2. 定义模型:用Softmax回归模型(一种简单的神经网络),先把28×28的图片“摊平”成784个数字,再通过一个全连接层输出10个结果(对应10个类别的“分数”)。
  3. 设置训练工具:用“交叉熵损失函数”衡量模型预测的错误程度,用“随机梯度下降(SGD)”优化器调整模型参数,让模型慢慢学会正确分类。
  4. 训练模型:重复10轮训练,每轮都用训练集数据让模型“学习”:
    • 先让模型预测图片类别,计算错误(损失);
    • 再根据错误调整模型参数;
    • 最后用测试集检验模型的分类能力(准确率)。
  5. 输出结果:每轮训练后打印损失值和准确率,最终模型在测试集上的准确率约85%,说明它能较好地识别服饰类别。
    简单讲,这是一个“教计算机认衣服”的程序,用的是最基础的深度学习方法(Softmax回归),适合入门理解神经网络的训练过程。

1.导入需要的工具库

点击查看代码
import torch  # PyTorch核心库:处理张量(类似数组但能在GPU上运行)和深度学习基础功能
from torch import nn  # 神经网络模块:提供各种层(如全连接层)、损失函数等
from torch.utils.data import DataLoader  # 数据加载器:帮我们批量处理数据,方便训练
from torchvision import datasets, transforms  # 计算机视觉工具:提供现成数据集和数据转换工具
from d2l import torch as d2l  # D2L工具库:提供一些辅助函数(如计算准确率的工具)
2.数据加载与预处理
点击查看代码
# 数据转换规则:把图像转成Tensor格式(PyTorch能处理的格式),同时自动把像素值从0-255变成0-1
transform = transforms.Compose([transforms.ToTensor()])
batch_size = 256  # 每次训练时一次性喂给模型256张图片(批量大小)# 加载训练数据集(Fashion-MNIST,包含6万张衣服、鞋子等图片)
train_dataset = datasets.FashionMNIST(root='./data',        # 数据存在当前文件夹的data文件夹里train=True,           # 这是训练集(用来训练模型的)download=True,        # 如果本地没有数据,就自动下载(约30MB)transform=transform   # 用上面定义的规则处理图片
)# 加载测试数据集(1万张图片,用来检验模型好坏)
test_dataset = datasets.FashionMNIST(root='./data',train=False,          # 这是测试集(不参与训练,只用来评估)download=True,transform=transform
)# 创建训练数据加载器(把数据分成批次,打乱顺序)
train_iter = DataLoader(train_dataset,batch_size=batch_size,  # 每次256张shuffle=True,           # 训练时打乱顺序,让模型学的更全面num_workers=0           # 单进程加载(Windows系统用多进程会出错,所以设为0)
)# 创建测试数据加载器(不需要打乱顺序)
test_iter = DataLoader(test_dataset,batch_size=batch_size,shuffle=False,          # 测试时不需要打乱,按顺序来就行num_workers=0
)
3.模型定义与参数初始化
点击查看代码
# 定义模型:用Sequential把层按顺序拼起来
net = nn.Sequential(nn.Flatten(),          # 展平层:把28×28的图片(二维)变成784个数字的一维数组nn.Linear(784, 10)     # 全连接层:把784个输入转换成10个输出(对应10个类别)
)# 定义参数初始化函数(给模型的"权重"赋值初始值)
def init_weights(m):# 如果是全连接层(nn.Linear),就初始化它的权重if type(m) == nn.Linear:# 权重用均值0、标准差0.01的正态分布随机数初始化nn.init.normal_(m.weight, std=0.01)# 偏置(类似y=ax+b里的b)默认初始化为0,不用额外设置# 把上面的初始化规则应用到模型的所有层
net.apply(init_weights)
4.损失函数与优化器
点击查看代码
# 定义损失函数:CrossEntropyLoss(自带Softmax功能)
# 作用:比较模型预测结果和真实标签的差距,输出"损失值"
# reduction='none':返回每个样本的损失,不自动求平均
loss = nn.CrossEntropyLoss(reduction='none')# 定义优化器:SGD(随机梯度下降,最常用的优化方法)
# 作用:根据损失调整模型的参数(权重和偏置)
trainer = torch.optim.SGD(net.parameters(),  # 需要优化的参数(模型里的权重和偏置)lr=0.1             # 学习率:控制参数调整的幅度(0.1是比较经典的值)
)
5. 模型评估函数(计算准确率)
点击查看代码
def evaluate_accuracy(net, data_iter):# 如果模型是PyTorch的标准模型,就切换到"评估模式"# (有些层如Dropout在训练和测试时行为不同,这里确保是测试模式)if isinstance(net, torch.nn.Module):net.eval()  # 切换到评估模式# 累加器:记录2个数据——正确预测的数量、总样本数量metric = d2l.Accumulator(2)# 关闭梯度计算(评估时不需要训练,节省内存)with torch.no_grad():# 遍历数据集中的每一批数据for X, y in data_iter:# 计算当前批次中正确预测的数量,累加到metric# d2l.accuracy(net(X), y):模型预测结果和真实标签对比,返回正确数# y.numel():当前批次的总样本数(比如256)metric.add(d2l.accuracy(net(X), y), y.numel())# 准确率 = 正确预测数 / 总样本数return metric[0] / metric[2]
6.模型训练循环
点击查看代码
num_epochs = 10  # 训练10轮(把整个训练集重复用10次来训练)# 遍历每一轮训练
for epoch in range(num_epochs):# 累加器:记录3个数据——总损失、正确预测数、总样本数metric = d2l.Accumulator(3)net.train()  # 切换到训练模式(和评估模式对应,规范写法)# 遍历训练集中的每一批数据for X, y in train_iter:trainer.zero_grad()  # 梯度清零:每次计算前把上一轮的梯度清空y_hat = net(X)       # 前向传播:把输入X喂给模型,得到预测结果y_hatl = loss(y_hat, y)   # 计算损失:比较预测y_hat和真实标签y的差距l.mean().backward()  # 反向传播:计算每个参数的梯度(损失对参数的影响)trainer.step()       # 优化器更新参数:根据梯度调整权重和偏置# 关闭梯度计算,累加当前批次的指标with torch.no_grad():metric.add(l.sum(),          # 累加当前批次的总损失d2l.accuracy(y_hat, y),  # 累加当前批次的正确预测数y.numel()         # 累加当前批次的总样本数)# 计算当前轮次的关键指标train_loss = metric[0] / metric[2]  # 平均训练损失 = 总损失 / 总样本数train_acc = metric[1] / metric[2]   # 训练集准确率 = 正确数 / 总样本数test_acc = evaluate_accuracy(net, test_iter)  # 测试集准确率# 打印训练日志(保留3位小数,方便观察)print(f"epoch {epoch + 1:2d} | loss: {train_loss:.3f} | train_acc: {train_acc:.3f} | test_acc: {test_acc:.3f}")
7.预期输出
点击查看代码
epoch  1 | loss: 0.785 | train_acc: 0.746 | test_acc: 0.796
...
epoch 10 | loss: 0.447 | train_acc: 0.849 | test_acc: 0.851

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/950071.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Oracle的connect by level在MySQL中的华丽变身 - 详解

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

handsontable实现新增删除行(双行)

handsontable实现新增删除行(双行)// 配置方法const tableSettings = computed(() => {return {...hotTableParams,nestedHeaders: false,filters: false,columnSorting: false,height: 358,rowHeaders: false,co…

2025年国产角接触球轴承厂家推荐 一文了解轴承厂家选择标准

角接触球轴承在机械设备中的应用广泛,特别是精密机床、高速电机、电主轴等运行要求高的场合,更需要质量品质好一点的角接触轴承。想要轴承用的好,就得找到合适的生产厂家,下面就来推荐下2025年值得信任的国产角接触…

vxe-table 树形表格显示连接线的方式

vxe-table 树形表格显示连接线的方式 完整连接线 通过 tree-config.showLine 来启用是否显示连接线<template><div><vxe-grid v-bind="gridOptions"></vxe-grid></div> </…

2025年上海衣帽间定制机构权威推荐榜单:衣帽间设计/衣帽间十大品牌/衣帽间装修源头公司精选

在上海,一个规划合理的衣帽间正成为新兴住宅的标配。数据显示,2024年中国家装行业市场规模已突破860亿元,其中定制家具份额持续增长,而衣帽间作为定制家具的重要组成部分,正受到越来越多消费者的青睐。 01 行业趋…

在Web应用开发中状态到底是什么?

在计算机科学中,“状态”(State)这个词经常出现在讨论有状态(Stateful)和无状态(Stateless)系统、服务或组件时。要理解“状态”到底是什么,我们可以从最基本的层面来解释。一、什么是“状态”? 简单来说,“…

前后端不分离的springboot应用,静态文件修改了不更新的问题

当然,还有不依赖idea的解决方案,就是静态文件通过nginx来代理,直接将js和css这些文件代理到我们的代码目录,这样我们修改了代码目录后,配置就立马生效了。这样可以不依赖idea的版本,假如你的idea怎么设置热更新都…

Cookie与缓存的区别

一、本质定义 Cookie:客户端轻量化文本存储,存用户身份、网站偏好等会话相关数据,容量约4KB。 缓存:客户端/服务器临时存储,存网页静态资源(图、JS/CSS),容量几MB到几十GB。 二、3大核心区别 1. 存储内容:Coo…

2025 年铝卷厂家最新推荐榜,聚焦企业技术实力与市场口碑深度解析铝板铝卷/铝卷板/橘皮铝卷/压花铝卷/防锈铝卷/花纹铝卷公司推荐

引言 本次 2025 年铝卷产品推荐榜,由有色金属工业协会联合行业权威检测机构共同测评制定,测评过程严格遵循《铝及铝合金轧制卷材行业质量评价标准》。测评团队从企业生产实力、产品品质、服务能力三大维度入手,涵盖…

无人机航测界的强者——Pix4Dmapper 4.5.6使用教程+图文步骤

软件介绍 Pix4Dmapper 4.5.6是一款专业的无人机航空三维建模软件。它可以将通过无人机拍摄的照片转化为三维地图、模型、点云、高精度的数字高程模型(DEM)、数字表面模型(DSM)和正射影像(DOM)等。该软件具有自动…

qml与html通信

1. 在qml显示html并通讯main.qmlimport QtQuick 2.12 import QtQuick.Window 2.12 import QtWebEngine 1.2 import QtQuick.Controls Window {id:mainWindowwidth: 640height: 480visible: truetitle: qsTr("WebE…

2025 年排烟风机厂家最新推荐榜,技术实力与市场口碑深度解析,筛选高性能低噪音优质企业屋顶/双速/离心式/防排烟风机公司推荐

引言 为助力建筑行业精准选购排烟风机,本次榜单由暖通空调工业协会联合消防设备质量监督检验中心共同测评发布。测评采用 “三维九项” 评估体系,从技术实力(耐高温性能、噪音控制、节能效率)、市场口碑(客户满意…

2025 年建筑模型公司最新推荐榜,技术实力与市场口碑深度解析含沙盘、微缩、高端模型品牌

引言 随着建筑行业数字化升级,建筑模型需求向高精度、智能化、定制化加速迭代,据建筑装饰协会 2025 年行业测评报告显示,超 68% 的高端项目对模型精度要求达 0.1mm 级,且 82% 的客户将 “技术创新能力” 列为选择合…

Session、Cookie、Token 区别

一、本质理解 1. Cookie:就是浏览器里的“小记事本”,只能存少量东西(约4KB),比如记个“身份编号”,本身不是身份,只是个装信息的小本子。 2. Session:相当于服务器里的“用户档案柜”。你登录后,服务器给你编…

2025 年聚脲厂家最新推荐榜,技术实力与市场口碑深度解析,精选行业优质企业聚脲防腐/单组分双组分聚脲/MUL 聚脲/聚脲防水公司推荐

引言 在材料保护领域,聚脲产品因卓越性能需求持续攀升,为精准筛选优质企业,建筑材料联合会防腐保温材料分会联合行业权威机构开展 2025 年聚脲企业测评。测评涵盖生产规模(年产能、厂房面积)、技术实力(专利数量…

Flask零基础入门:5步搭建你的第一个Web应用

本文详细介绍了Flask框架的入门使用方法,包括环境搭建、路由配置、模板渲染和数据库集成等核心内容,通过完整代码示例帮助读者快速掌握Web开发基础。你想快速搭建一个Web应用,却总被复杂框架吓退?😫 惊人事实:F…

2025 年红外测温仪厂家最新推荐榜,技术实力与市场口碑深度解析比色/感应加热/高性价比/单晶炉红外测温仪公司推荐

引言 在工业制造、光伏能源、半导体等核心领域的智能化升级进程中,红外测温仪作为关键检测设备,其精度与稳定性直接决定生产安全与产品合格率。据中国仪器仪表行业协会 2025 年《工业测温设备专项测评报告》显示,国…

2025 年真空计厂家最新推荐榜,技术实力与市场口碑深度解析,涵盖压阻硅、薄膜硅等多类型产品皮拉尼真空计/单晶炉真空计公司推荐

引言 在半导体、光伏等高端制造领域,真空计作为核心测量设备,直接影响生产精度与效率。据仪器仪表行业协会 2025 年报告显示,国内真空计市场规模已达 86.3 亿元,国产化率提升至 45%,但市场产品良莠不齐,37% 的企…

2025年10月企业网站建设开发公司排行榜:前十名精选

摘要 随着数字化转型加速,2025年企业网站建设行业持续增长,中小企业对高效、可靠的建站服务需求激增。本文基于市场调研和用户口碑,整理出排名前十的企业网站建设服务商列表,并提供详细比较和选择指南,帮助企业主…

2025年企业网站建设开发公司口碑排行榜Top 10

摘要 企业网站建设行业在2025年持续蓬勃发展,随着数字化转型加速,中小企业对高效、可靠的建站服务需求激增。行业趋势显示,集成AI技术、移动端优化和全渠道营销成为主流。本文基于市场调研和用户反馈,整理了2025年…