深度学习分类回归(衣帽数据集)

一、步骤

1 加载数据集fashion_minst

2 搭建class NeuralNetwork模型

3 设置损失函数,优化器

4 编写评估函数

5 编写训练函数

6 开始训练

7 绘制损失,准确率曲线

二、代码

导包,打印版本号:

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as Fprint(sys.version_info)
for module in mpl, np, pd, sklearn, torch:print(module.__name__, module.__version__)device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

torch的运算过程都是张量,也叫算子(tensor)

torchvision的包可以提供数据集,图片就是datasets:

这里下载到data目录,如果已有数据则不会下载。这段代码可以实现数据向tensor的转换:

做预处理的时候把图片变成tensor,啥都没写的时候就不会转换成tensor 

from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import transforms# 定义数据集的变换
transform = transforms.Compose([
])
# fashion_mnist图像分类数据集,衣服分类,60000张训练图片,10000张测试图片
train_ds = datasets.FashionMNIST(root="data",train=True,download=True,transform=transform
)test_ds = datasets.FashionMNIST(root="data",train=False,download=True,transform=transform
)# torchvision 数据集里没有提供训练集和验证集的划分
# 当然也可以用 torch.utils.data.Dataset 实现人为划分
type(train_ds[0]) # 元组,第一个元素是图片,第二个元素是标签

如果使用了数据类型变换:

img_tensor, label = train_ds[0]
img_tensor.shape  #img这时是一个tensor,shape=(1, 28, 28)

在PyTorch中,DataLoader是一个迭代器,它封装了数据的加载和预处理过程,使得在训练机器学习模型时可以方便地批量加载数据。DataLoader主要负责以下几个方面:

  1. 批量加载数据DataLoader可以将数据集(Dataset)切分为更小的批次(batch),每次迭代提供一小批量数据,而不是单个数据点。这有助于模型学习数据中的统计依赖性,并且可以更高效地利用GPU等硬件的并行计算能力。

  2. 数据打乱:默认情况下,DataLoader会在每个epoch(训练周期)开始时打乱数据的顺序。这有助于模型训练时避免陷入局部最优解,并且可以提高模型的泛化能力。

  3. 多线程数据加载DataLoader支持多线程(通过参数num_workers)来并行地加载数据,这可以显著减少训练过程中的等待时间,尤其是在处理大规模数据集时。

  4. 数据预处理DataLoader可以与transforms结合使用,对加载的数据进行预处理,如归一化、标准化、数据增强等操作。

  5. 内存管理DataLoader负责管理数据的内存使用,确保在训练过程中不会耗尽内存资源。

  6. 易用性DataLoader提供了一个简单的接口,可以很容易地集成到训练循环中。

# 从数据集到dataloader
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True) #batch_size分批,shuffle洗牌
val_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

这里每32个样本就会算一次平均损失,更新一次w。

定义模型:继承nn.Module

class NeuralNetwork(nn.Module):def __init__(self):super().__init__() # 继承父类的初始化方法,子类有父类的属性self.flatten = nn.Flatten()  # 展平层self.linear_relu_stack = nn.Sequential(nn.Linear(784, 300),  # in_features=784, out_features=300, 784是输入特征数,300是输出特征数nn.ReLU(), # 激活函数nn.Linear(300, 100),#隐藏层神经元数100nn.ReLU(), # 激活函数nn.Linear(100, 10),#输出层神经元数10 )def forward(self, x): # 前向计算,前向传播# x.shape [batch size, 1, 28, 28],1是通道数x = self.flatten(x)  # print(f'x.shape--{x.shape}')# 展平后 x.shape [batch size, 784]logits = self.linear_relu_stack(x)# logits.shape [batch size, 10]return logits #没有经过softmax,称为logitsmodel = NeuralNetwork()

model的结构:第一层是展平层,然后激活,然后隐藏层,激活,输出层


 在训练之前需要测试一下模型能不能用,所以我们随机一个或者从样本拿一个,同尺寸就行:

#为了查看模型运算的tensor尺寸
x = torch.randn(32, 1, 28, 28)
print(x.shape)
logits = model(x) # 把x输入到模型中,得到logits
print(logits.shape)

 然后开始训练,pytorch的训练需要自行实现,包括定义损失函数、优化器、训练步,训练

# 1. 定义损失函数 采用交叉熵损失
loss_fct = nn.CrossEntropyLoss() #内部先做softmax,然后计算交叉熵
# 2. 定义优化器 采用SGD
# Optimizers specified in the torch.optim package,随机梯度下降
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
from sklearn.metrics import accuracy_score # sk里面有一个算子,可以计算准确率@torch.no_grad() # 装饰器,禁止反向传播,节省内存,就是不求导的意思
def evaluating(model, dataloader, loss_fct): # 评估函数,评估也要做一次向前计算,不需要求梯度loss_list = [] # 记录损失pred_list = [] # 记录预测label_list = [] # 记录标签for datas, labels in dataloader:#10000/32=312datas = datas.to(device) # 转到GPUlabels = labels.to(device) # 转到GPU 这两行代码torch必写,把tensor放到GPU上# 前向计算logits = model(datas)  # 进行前向计算loss = loss_fct(logits, labels)         # 验证集损失,loss尺寸是一个数值loss_list.append(loss.item()) # 记录损失,item是把tensor转换为数值preds = logits.argmax(axis=-1)    # 验证集预测,argmax返回最大值索引,-1就是最后一个维度print(f'评估中的preds.shape--{preds.shape}')pred_list.extend(preds.cpu().numpy().tolist())#将PyTorch张量转换为NumPy数组。只有当张量在CPU上时,这个转换才是合法的# print(preds.cpu().numpy().tolist())label_list.extend(labels.cpu().numpy().tolist())acc = accuracy_score(label_list, pred_list) # 计算准确率return np.mean(loss_list), acc
# 训练
def training(model, train_loader, val_loader, epoch, loss_fct, optimizer, eval_step=500):#参数分别是模型,训练集,验证集,训练epoch,损失函数,优化器,评估步数(500评估一次)record_dict = { # 记录字典,用于记录训练过程中的信息"train": [],"val": []}global_step = 0 # 全局步数,记录训练的步数model.train() # 进入训练模式,模型可以切换模式#tqdm是一个进度条库with tqdm(total=epoch * len(train_loader)) as pbar: # 进度条 加入epoch等于10,就是所有样本搞10次,不断地把样本带进去学习,1875*10,60000/32=1875for epoch_id in range(epoch): # 训练epoch次# trainingfor datas, labels in train_loader: #执行次数是60000/32=1875datas = datas.to(device) #datas尺寸是[batch_size,1,28,28]labels = labels.to(device) #labels尺寸是[batch_size]# 梯度清空optimizer.zero_grad() # 每次训练前都要把梯度清空,不然会累加# 模型前向计算logits = model(datas)# 计算损失loss = loss_fct(logits, labels)# 梯度回传,loss.backward()会计算梯度,loss对模型参数求导loss.backward()# 调整优化器,包括学习率的变动等,优化器的学习率会随着训练的进行而减小,更新w,boptimizer.step() #梯度是计算并存储在模型参数的 .grad 属性中,优化器使用这些存储的梯度来更新模型参数preds = logits.argmax(axis=-1) # 训练集预测acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())   # 计算准确率,numpy可以,每个step都算一次loss = loss.cpu().item() # 损失转到CPU,item()取值,一个数值# tensor如果只有一个值(标量),一维是向量,二维是矩阵,可以用item()取出值,如果有多个值,则需要用tolist()转为列表# record# recordrecord_dict["train"].append({"loss": loss, "acc": acc, "step": global_step}) # 记录训练集信息,每一步的损失,准确率,步数# evaluatingif global_step % eval_step == 0:model.eval() # 进入评估模式,不会求梯度val_loss, val_acc = evaluating(model, val_loader, loss_fct)record_dict["val"].append({"loss": val_loss, "acc": val_acc, "step": global_step})model.train() # 进入训练模式# udate stepglobal_step += 1 # 全局步数加1pbar.update(1) # 更新进度条pbar.set_postfix({"epoch": epoch_id}) # 设置进度条显示信息return record_dictepoch = 20 #改为40
model = model.to(device)
record = training(model, train_loader, val_loader, epoch, loss_fct, optimizer, eval_step=1000)
#画线要注意的是损失是不一定在零到1之间的
def plot_learning_curves(record_dict, sample_step=1000):# build DataFrametrain_df = pd.DataFrame(record_dict["train"]).set_index("step").iloc[::sample_step]val_df = pd.DataFrame(record_dict["val"]).set_index("step")last_step = train_df.index[-1] # 最后一步的步数# print(train_df.columns)print(train_df['acc'])print(val_df['acc'])# plotfig_num = len(train_df.columns) # 画几张图,分别是损失和准确率fig, axs = plt.subplots(1, fig_num, figsize=(5 * fig_num, 5))for idx, item in enumerate(train_df.columns):# print(train_df[item].values)axs[idx].plot(train_df.index, train_df[item], label=f"train_{item}")axs[idx].plot(val_df.index, val_df[item], label=f"val_{item}")axs[idx].grid() # 显示网格axs[idx].legend() # 显示图例axs[idx].set_xticks(range(0, train_df.index[-1], 5000)) # 设置x轴刻度axs[idx].set_xticklabels(map(lambda x: f"{int(x/1000)}k", range(0, last_step, 5000))) # 设置x轴标签axs[idx].set_xlabel("step")plt.show()plot_learning_curves(record)  #横坐标是 steps

# dataload for evaluatingmodel.eval() # 进入评估模式
loss, acc = evaluating(model, val_loader, loss_fct)
print(f"loss:     {loss:.4f}\naccuracy: {acc:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/897454.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【leetcode hot 100 19】删除链表的第N个节点

解法一:将ListNode放入ArrayList中,要删除的元素为num list.size()-n。如果num 0则将头节点删除;否则利用num-1个元素的next删除第num个元素。 /*** Definition for singly-linked list.* public class ListNode {* int val;* Lis…

【iOS逆向与安全】sms短信转发插件与上传服务器开发

一、目标 一步步分析并编写一个短信自动转发的deb插件 二、工具 mac系统已越狱iOS设备:脱壳及frida调试IDA Pro:静态分析测试设备:iphone6s-ios14.1.1三、步骤 1、守护进程 ​ 守护进程(daemon)是一类在后台运行的特殊进程,用于执行特定的系统任务。例如:推送服务、人…

Midjourney绘图参数详解:从基础到高级的全面指南

引言 Midjourney作为当前最受欢迎的AI绘图工具之一,其强大的参数系统为用户提供了丰富的创作可能性。本文将深入解析Midjourney的各项参数,帮助开发者更好地掌握这一工具,提升创作效率和质量。 一、基本参数配置 1. 图像比例调整 使用--ar…

音频进阶学习十九——逆系统(简单进行回声消除)

文章目录 前言一、可逆系统1.定义2.解卷积3.逆系统恢复原始信号过程4.逆系统与原系统的零极点关系 二、使用逆系统去除回声获取原信号的频谱原系统和逆系统幅频响应和相频响应使用逆系统恢复原始信号整体代码如下 总结 前言 在上一篇音频进阶学习十八——幅频响应相同系统、全…

vue3 使用sass变量

1. 在<style>中使用scss定义的变量和css变量 1. 在/style/variables.scss文件中定义scss变量 // scss变量 $menuText: #bfcbd9; $menuActiveText: #409eff; $menuBg: #304156; // css变量 :root {--el-menu-active-color: $menuActiveText; // 活动菜单项的文本颜色--el…

gbase8s rss集群通信流程

什么是rss RSS是一种将数据从主服务器复制到备服务器的方法 实例级别的复制 (所有启用日志记录功能的数据库) 基于逻辑日志的复制技术&#xff0c;需要传输大量的逻辑日志,数据库需启用日志模式 通过网络持续将数据复制到备节点 如果主服务器发生故障&#xff0c;那么备用服务…

熵与交叉熵详解

前言 本文隶属于专栏《机器学习数学通关指南》&#xff0c;该专栏为笔者原创&#xff0c;引用请注明来源&#xff0c;不足和错误之处请在评论区帮忙指出&#xff0c;谢谢&#xff01; 本专栏目录结构和参考文献请见《机器学习数学通关指南》 ima 知识库 知识库广场搜索&#…

程序化广告行业(3/89):深度剖析行业知识与数据处理实践

程序化广告行业&#xff08;3/89&#xff09;&#xff1a;深度剖析行业知识与数据处理实践 大家好&#xff01;一直以来&#xff0c;我都希望能和各位技术爱好者一起在学习的道路上共同进步&#xff0c;分享知识、交流经验。今天&#xff0c;咱们聚焦在程序化广告这个充满挑战…

探索在生成扩散模型中基于RAG增强生成的实现与未来

概述 像 Stable Diffusion、Flux 这样的生成扩散模型&#xff0c;以及 Hunyuan 等视频模型&#xff0c;都依赖于在单一、资源密集型的训练过程中通过固定数据集获取的知识。任何在训练之后引入的概念——被称为 知识截止——除非通过 微调 或外部适应技术&#xff08;如 低秩适…

DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加列宽调整功能,示例Table14基础固定表头示例

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…

取反符号~

取反符号 ~ 用于对整数进行按位取反操作。它会将二进制表示中的每一位取反&#xff0c;即 0 变 1&#xff0c;1 变 0。 示例 a 5 # 二进制表示为 0000 0101 b ~a # 按位取反&#xff0c;结果为 1111 1010&#xff08;补码表示&#xff09; print(b) # 输出 -6解释 5 的二…

论文阅读分享——UMDF(AAAI-24)

概述 题目&#xff1a;A Unified Self-Distillation Framework for Multimodal Sentiment Analysis with Uncertain Missing Modalities 发表&#xff1a;The Thirty-Eighth AAAI Conference on Artificial Intelligence (AAAI-24) 年份&#xff1a;2024 Github&#xff1a;暂…

WBC已形成“东亚-美洲双中心”格局·棒球1号位

世界棒球经典赛&#xff08;WBC&#xff09;作为全球最高水平的国家队棒球赛事&#xff0c;参赛队伍按实力、地域和历史表现可分为多个“阵营”。以下是基于历届赛事&#xff08;截至2023年&#xff09;的阵营划分及代表性队伍分析&#xff1a; 第一阵营&#xff1a;传统豪强&a…

django中路由配置规则的详细说明

在 Django 中,路由配置是将 URL 映射到视图函数或类视图的关键步骤,它决定了用户请求的 URL 会触发哪个视图进行处理。以下将详细介绍 Django 中路由配置的规则、高级使用方法以及多个应用配置的规则。 基本路由配置规则 1. 项目级路由配置 在 Django 项目中,根路由配置文…

【报错】微信小程序预览报错”60001“

1.问题描述 我在微信开发者工具写小程序时&#xff0c;使用http://localhost:8080是可以请求成功的&#xff0c;数据全都可以无报错&#xff0c;但是点击【预览】&#xff0c;用手机扫描二维码浏览时&#xff0c;发现前端图片无返回且报错60001&#xff08;打开开发者模式查看日…

栅格裁剪(Python)

在地理数据处理中&#xff0c;矢量裁剪栅格是一个非常重要的操作&#xff0c;它可以帮助我们提取感兴趣的区域并获得更精确的分析结果。其重要性包括&#xff1a; 区域限定&#xff1a;地球科学研究通常需要关注特定的地理区域。通过矢量裁剪栅格&#xff0c;我们可以将栅格数…

【无人机路径规划】基于麻雀搜索算法(SSA)的无人机路径规划(Matlab)

效果一览 代码获取私信博主基于麻雀搜索算法&#xff08;SSA&#xff09;的无人机路径规划&#xff08;Matlab&#xff09; 一、算法背景与核心思想 麻雀搜索算法&#xff08;Sparrow Search Algorithm, SSA&#xff09;是一种受麻雀群体觅食行为启发的元启发式算法&#xff0…

MySQL数据库安装及基础用法

安装数据库 第一步&#xff1a;下载并解压mysql-8.4.3-winx64文件夹 链接: https://pan.baidu.com/s/1lD6XNNSMhPF29I2_HBAvXw?pwd8888 提取码: 8888 第二步&#xff1a;打开文件中的my.ini文件 [mysqld]# 设置3306端口port3306# 自定义设置mysql的安装目录&#xff0c;即解…

软件工程:软件开发之需求分析

物有本末&#xff0c;事有终始。知所先后&#xff0c;则近道矣。对软件开发而言&#xff0c;软件需求乃重中之重。必先之事重千钧&#xff0c;不可或缺如日辰。 汽车行业由于有方法论和各种标准约束&#xff0c;对软件开发有严苛的要求。ASPICE指导如何审核软件开发&#xff0…

正则表达式,idea,插件anyrule

​​​​package lx;import java.util.regex.Pattern;public class lxx {public static void main(String[] args) {//正则表达式//写一个电话号码的正则表达式String regex "1[3-9]\\d{9}";//第一个数字是1&#xff0c;第二个数字是3-9&#xff0c;后面跟着9个数字…