pytorch简单神经网络模型训练

目录

一、导入包

二、数据预处理

三、定义神经网络

四、训练模型和测试模型

五、程序入口


一、导入包

import torch
import torch.nn as nn
import torch.optim as optim # 导入优化器
from torchvision import datasets, transforms # 导入数据集和数据预处理库
from torch.utils.data import DataLoader # 数据加载库

二、数据预处理

def data_loader():'''数据的预处理'''# 定义数据预处理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5), (0.5))])# 加载FashionMNIST数据集train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.FashionMNIST(root='.data', train=False, download=True, transform=transform)# 数据集加载器train_loader = DataLoader( train_dataset,batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset,batch_size=64, shuffle=False)return train_loader, test_loader

        这是一个常用于机器学习和深度学习研究中的数据集,包含了10类不同时尚商品的图像,每类有6000张训练图像和1000张测试图像。使用了PyTorch框架中的torchvision库来下载和加载Fashion-MNIST数据集。代码中定义了一个transform,它会将图像转换为张量,并对其进行归一化处理。然后,分别创建训练集和测试集的数据加载器train_loadertest_loader,这些加载器会在训练过程中以批量的形式提供数据。 

三、定义神经网络

# 定义神经网络
class QYNN(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28 * 28,128)self.fc2 = nn.Linear(128,10)def forward(self, x):# 将数据展平x = torch.flatten(x,start_dim=1)# 激活x,方便数据全联接x = torch.relu(self.fc1(x))# 输出10分类x = self.fc2(x)return x

代码定义了一个简单的全连接神经网络,适用于Fashion-MNIST这样的图像分类任务。这个网络包含两个全连接层(fc1fc2),分别用于特征提取和分类。

这里是您定义的QYNN类的一些解释:

  • __init__方法定义了网络的结构。网络接受28x28像素的灰度图像作为输入,首先通过一个线性层fc1将784个像素值映射到128个特征,然后通过第二个线性层fc2将128个特征映射到10个输出,对应于10个类别

  • forward方法定义了数据通过网络的前向传播过程。输入数据首先被展平成一个一维向量,然后通过fc1层,接着是ReLU激活函数,最后通过fc2层输出每个类别的得分。

四、训练模型和测试模型

训练模型

def train(model, train_loader):'''训练模型'''# 训练轮数epochs = 10for epoch in range(epochs):running_loss = 0.0for inputs, labels in train_loader:# 梯度清零optimizer.zero_grad()# 将图片塞进去outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()# 损失值的累加running_loss += loss.item()print(f'Epoch:{epoch+1}/{epochs} | Loss: {running_loss/len(train_loader)}')

测试模型

test函数使用了torch.no_grad()来禁用梯度计算,因为在测试阶段我们不需要计算梯度。函数遍历测试数据加载器中的每个批次,将输入数据传递给模型以获取输出,然后使用torch.max函数来获取每个样本的最高得分类别作为预测结果。最后,函数计算预测正确的样本数量与总样本数量,从而得到准确率。

def test(model, test_loader):'''测试模型'''correct = 0 # 正确的数量total = 0 # 样本的总量with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs, 1)# 获取本次样本的数量total += labels.size(0)# 预测值 和 标签 相同则正确, 对预测值进行累加correct += (predicted == labels).sum().item()print(f'Test Accuracy: {correct / total:.2%}')

五、程序入口


if __name__ == '__main__':# 设置随机种子torch.manual_seed(21)# 实例化神经网络model = QYNN()# 交叉商criterion = nn.CrossEntropyLoss()# 优化器optimizer = optim.SGD(model.parameters(),lr = 0.01)# 数据集train_loader, test_loader = data_loader()# 训练样本train(model, train_loader)# 测设样本test(model, test_loader)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831648.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python中type,object,class 三者关系

type,object,class 三者关系 在python中&#xff0c;所有类的创建关系遵循&#xff1a; type -> int -> 1 type -> class -> obj例如&#xff1a; a 1 b "abc" print(type(1)) # <class int> 返回对象的类型 print(type(int)) …

基于OpenCv的图像金字塔

⚠申明&#xff1a; 未经许可&#xff0c;禁止以任何形式转载&#xff0c;若要引用&#xff0c;请标注链接地址。 全文共计3077字&#xff0c;阅读大概需要3分钟 &#x1f308;更多学习内容&#xff0c; 欢迎&#x1f44f;关注&#x1f440;【文末】我的个人微信公众号&#xf…

【讲解如何OpenCV入门】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

需求规格说明书编制书(word原件)

1 范围 1.1 系统概述 1.2 文档概述 1.3 术语及缩略语 2 引用文档 3 需求 3.1 要求的状态和方式 3.2 系统能力需求 3.3 系统外部接口需求 3.3.1 管理接口 3.3.2 业务接口 3.4 系统内部接口需求 3.5 系统内部数据需求 3.6 适应性需求 3.7 安全性需求 3.8 保密性需…

GiantPandaCV | FasterTransformer Decoding 源码分析(二)-Decoder框架介绍

本文来源公众号“GiantPandaCV”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;FasterTransformer Decoding 源码分析(二)-Decoder框架介绍 作者丨进击的Killua 来源丨https://zhuanlan.zhihu.com/p/669303360 编辑丨GiantPand…

【Python编程实践1/3】模块

目录 目标 模块 import ​编辑 代码小结 题目 from...import 随机模块 代码小结 randint函数 骰子大战 choice函数 总结 目标 拧一颗螺丝&#xff0c;只会用到螺丝刀&#xff1b;但是修一台汽车&#xff0c;需要一整套汽修的工具。函数就像螺丝刀&#xff0c;可以帮…

python项目==一个web项目,配置模板指定文件清洗规则,调用模板规则清洗文件

代码地址 一个小工具。 一个web项目&#xff0c;配置模板指定文件清洗规则&#xff0c;调用模板规则清洗文件 https://github.com/hebian1994/csv-transfer-all 技术栈&#xff1a; SQLite python flask vue3 elementplus 功能介绍&#xff1a; A WEB tool for cleaning…

JavaScript:Web APIs(三)

本篇文章的内容包括&#xff1a; 一&#xff0c;事件流 二&#xff0c;移除事件监听 三&#xff0c;其他事件 四&#xff0c;元素尺寸与位置 一&#xff0c;事件流 事件流是什么呢&#xff1f; 事件流是指事件执行过程中的流动路径。 我们发现&#xff0c;一个完整的事件执行…

Delta lake with Java--利用spark sql操作数据1

今天要解决的问题是如何使用spark sql 建表&#xff0c;插入数据以及查询数据 1、建立一个类叫 DeltaLakeWithSparkSql1&#xff0c;具体代码如下&#xff0c;例子参考Delta Lake Up & Running第3章内容 import org.apache.spark.sql.SaveMode; import org.apache.spark.…

区域文本提示的实时文本到图像生成;通过一致性自注意力机制的视频生成工具保持视频的一致性;专门为雪佛兰汽车设计的客服聊天机器人

✨ 1: StreamMultiDiffusion StreamMultiDiffusion是首个基于区域文本提示的实时文本到图像生成框架&#xff0c;实现了高速且互动的图像生成。 StreamMultiDiffusion 旨在结合加速推理技术和基于区域的文本提示控制&#xff0c;以克服之前解决方案中存在的速度慢和用户交互性…

约瑟夫问题新解法

前言 又碰到了约瑟夫问题&#xff0c;这样的题目本来用环形链表模拟的话就能做出来。然而&#xff0c;最近新学习了一种做法&#xff0c;实在是有点震惊到我了。无论是思路上&#xff0c;还是代码量上&#xff0c;都是那么的精彩。就想也震惊一下其他人。谁能想到原来模拟出来四…

C/C++程序设计实验报告综合作业 | 小小计算器

本文整理自博主本科大一《C/C程序设计》专业课的课内实验报告&#xff0c;适合C语言初学者们学习、练习。 编译器&#xff1a;gcc 10.3.0 ---- 注&#xff1a; 1.虽然课程名为C程序设计&#xff0c;但实际上当时校内该课的内容大部分其实都是C语言&#xff0c;C的元素最多可能只…

深度解析 Spring 源码:探寻Bean的生命周期

文章目录 一、 Bean生命周期概述二、Bean生命周期流程图三、Bean生命周期验证3.1 代码案例3.2 执行结果 四、Bean生命周期源码4.1 setBeanName()4.2 setBeanFactory()4.3 setApplicationContext()4.4 postProcessBeforeInitialization()4.5 afterPropertiesSet()4.6 postProces…

力扣刷题第1天:消失的数字

大家好啊&#xff0c;从今天开始将会和大家一起刷题&#xff0c;从今天开始小生也会开辟新的专栏。&#x1f61c;&#x1f61c;&#x1f61c; 目录 第一部分&#xff1a;题目描述 第二部分&#xff1a;题目分析 第三部分&#xff1a;解决方法 3.1 思路一&#xff1a;先排序…

十、多模态大语言模型(MLLM)

1 多模态大语言模型&#xff08;Multimodal Large Language Models&#xff09; 模态的定义 模态&#xff08;modal&#xff09;是事情经历和发生的方式&#xff0c;我们生活在一个由多种模态(Multimodal)信息构成的世界&#xff0c;包括视觉信息、听觉信息、文本信息、嗅觉信…

MySQL技能树学习——数据库组成

数据库组成&#xff1a; 数据库是一个组织和存储数据的系统&#xff0c;它由多个组件组成&#xff0c;这些组件共同工作以确保数据的安全、可靠和高效的存储和访问。数据库的主要组成部分包括&#xff1a; 数据库管理系统&#xff08;DBMS&#xff09;&#xff1a; 数据库管理系…

MySQL45讲(一)(40)

回顾binlog_formatstatement STATEMENT 记录SQL语句。日志文件小&#xff0c;节约IO&#xff0c;但是对一些系统函数不能准确复制或不能复制&#xff0c;如now()、uuid()等 在RR隔离级别下&#xff0c;binlog_formatstatement 如果执行insert select from 这条语句是对于一张…

OpenCV如何为等值线创建边界旋转框和椭圆(63)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇:OpenCV 为轮廓创建边界框和圆(62) 下一篇:OpenCV的图像矩(64) 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 cv::minAreaRect使用 OpenCV 函数 cv::fitEllipse cv::min…

Gradle 进阶学习 之 build.gradle 文件

build.gradle 是什么&#xff1f; 想象一下&#xff0c;你有一个大型的乐高项目&#xff0c;你需要一个清单来列出所有的乐高积木和它们如何组合在一起。在软件开发中&#xff0c;build.gradle 就是这个清单&#xff0c;它告诉计算机如何构建&#xff08;组合&#xff09;你的软…

这是一个简单的照明材料网站,后续还会更新

1、首页效果图 代码 <!DOCTYPE html> <html><head><meta charset"utf-8" /><title>爱德照明网站首页</title><style>/*外部样式*/charset "utf-8";*{margin: 0;padding: 0;box-sizing: border-box;}a{text-dec…