机器学习与深度学习4:数据集处理Dataset,DataLoader,batch_size

        深度学习中,我们能看到别人的代码中都有一个继承Dataset类的数据集处理过程,这也是深度学习处理数据集的的基础,下面介绍这个数据集的定义和使用:

1、数据集加载

1.1 通用的定义

Bach:表示每次喂给模型的数据

Epoch:表示训练一次完整数据集数据的过程

解释:当一个数据集的大小为10时,设定batch大小为5,那么这个数据就会分为2份,每份大小为5,依次投入到模型中进行训练。训练完所有数据后,就叫做一次迭代,称为epoch

1.2 继承Dataset类

我们继承Dataset类需要实现它的三个方法,代码在文末,与Dataloader代码一起。

init:载入数据

getitem:返回指定位置数据

len:返回数据长度

固定用法如下:

import numpy as np
import torch
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self):#载入数据passdef __getitem__(self, item):#返回相应位置的数据passdef __len__(self):#返回数据长度pass

 例如我们有数据集为手写数字识别数据,文件目录如下:

        在pytorch当然最简单的是用内置的MNIST函数,这里不使用该方法,使用Dataset类写一下。

载入数据:由于数据量太大,因此我们载入每个数据的索引,也就是数据的路径

返回相应位置的数据:实现给出index,能返回相应位置的数据。

返回数据长度:返回所有数据的个数。

1.3 代码实现

灰度图转换(任选其一)

任选其一都可以实现,将原始图片转为灰度图:

transforms.Grayscale(num_output_channels=1)#transform实现转换
Image.open(image_path).convert("L")        #image库转换灰度图

因此可以写出Dataset类加载代码 :

transform = transforms.Compose([#transforms.Grayscale(num_output_channels=1),  # 转换为单通道灰度图transforms.ToTensor()  # 转换为张量
])
class MyDataset(Dataset):def __init__(self):# 载入数据self.images = []self.labels = []for i in range(10):pathX =os.path.join('../mnist_images/train',str(i))imageNameList = os.listdir(pathX)image = []for filename in imageNameList:imagePath = os.path.join('../mnist_images/train',str(i),filename)image.append(imagePath)label = [i] * len(image)#label = [i for _ in range(len(image))]列表推导式self.images.extend(image)self.labels.extend(label)def __getitem__(self, item):#返回相应位置的数据image = Image.open(self.images[item]).convert("L")#image = Image.open(self.images[item])return transform(image),torch.tensor(self.labels[item])#返回一个元组def __len__(self):#返回数据长度return len(self.images)

1.4 Dataloader批量加载 

        使用Dataset函数处理数据集后,就需要使用Dataloader,它的使用很简单,只有一行:

DataLoader(oneDataset, batch_size=32, shuffle=True, drop_last = False,num_works = 8)

        其中oneDateset表示输入的Dataset对象下面是对其中一些参数的解释:

batach_size 表示一个Batch的大小

shuffle 表示是否打乱数据

drop_last 表示是否舍弃最后数据,若为True那么会舍弃Datasize对batch_size不能整除的部分,也就是如果数据量为10,batch_size为3的话,最后一个数据会被舍弃,如果drop_last为False的话,最后一个数会被保留。也就是最后一个batch_size的大小为1。

num_works 表示使用多少进程加载数据,num_works = 0表示使用主进程加载数据,num_works > 0表示使用多少个子进程加载数据。

        DataLoader返回为一个张量形状为[batch_size, channels, height, width] batch_size表示批量大小,可以是任意正整数,训练模型时,模型输入对该参数batch_size无要求限制,但是后面的三个特征维度[channels, height, width]必须跟模型model定义的输入层数据维度一致。

1.5完整代码:

import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transformstransform = transforms.Compose([#transforms.Grayscale(num_output_channels=1),  # 转换为单通道灰度图transforms.ToTensor()  # 转换为张量
])
class MyDataset(Dataset):def __init__(self):# 载入数据self.images = []self.labels = []for i in range(10):pathX =os.path.join('../mnist_images/train',str(i))imageNameList = os.listdir(pathX)image = []for filename in imageNameList:imagePath = os.path.join('../mnist_images/train',str(i),filename)image.append(imagePath)label = [i] * len(image)#label = [i for _ in range(len(image))]列表推导式self.images.extend(image)self.labels.extend(label)def __getitem__(self, item):#返回相应位置的数据image = Image.open(self.images[item]).convert("L")#image = Image.open(self.images[item])return transform(image),torch.tensor(self.labels[item])#返回一个元组def __len__(self):#返回数据长度return len(self.images)
def getDataloder():oneDataset = MyDataset()return DataLoader(oneDataset, batch_size=32, shuffle=True)
if __name__ == '__main__':dataloader = getDataloder()for images, labels in dataloader:print("Batch shape:", images.shape)  # 输出批次形状print("Labels:", labels)  # 输出标签#print(images[0][0][18])break  # 只打印第一个批次

二、 文件下载

文件项目是一个完整的简单神经网络训练手写数字识别,打包下载在这里:点击下载项目

        最后:实现手写数字识别数据集加载方法最简单的是使用pytorch内置MNIST函数实现,仅有一行代码实现上述功能,本文不采用该方法,通过自行实现理解数据集加载原理。

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74042.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL数据库和表的操作之SQL语句

🎯 本文专栏:MySQL深入浅出 🚀 作者主页:小度爱学习 MySQL数据库和表的操作 关系型数据库,都是遵循SQL语法进行数据查询和管理的。 SQL语句 什么是sql SQL:结构化查询语言(Structured Query Language)&…

ubuntu开发mcu环境

# 编辑 vim或者vscode # 编译 arm-none-eabi # 烧写 openocd 若是默认安装,会在/usr/share/openocd/scripts/{interface,target} 有配置接口和目标版配置 示例: openocd -f interface/stlink-v2.cfg -f target/stm32f1x.cfg 启动后,会…

Windows模仿Mac大小写切换, 中英文切换

CapsLock 功能优化脚本部署指南 部署步骤 第一步:安装 AutoHotkey v2 访问 AutoHotkey v2 官网下载并安装最新版本安装时勾选 "Add Compile Script to context menus" 第二步:部署脚本 直接运行 (调试推荐) 新建文本文件,粘贴…

Selenium Web自动化如何快速又准确的定位元素路径,强调一遍是元素路径

如果文章对你有用,请给个赞! 匹配的ChromeDriver和浏览器版本是更好完成自动化的基础,可以从这里去下载驱动程序: 最全ChromeDriver下载含win linux mac 最新版本134.0.6998.165 持续更新..._chromedriver 134-CSDN博客 如果你问…

CSRF vs SSRF详解

一、CSRF&#xff08;跨站请求伪造&#xff09;攻击全解 攻击原理示意图 受害者浏览器 ├── 已登录银行网站&#xff08;Cookie存活&#xff09; └── 访问恶意网站执行&#xff1a;<img src"http://bank.com/transfer?tohacker&amount1000000">核心…

Python PDF解析利器:pdfplumber | AI应用开发

Python PDF解析利器&#xff1a;pdfplumber全面指南 1. 简介与安装 1.1 pdfplumber概述 pdfplumber是一个Python库&#xff0c;专门用于从PDF文件中提取文本、表格和其他信息。相比其他PDF处理库&#xff0c;pdfplumber提供了更直观的API和更精确的文本定位能力。 主要特点…

niuhe 插件教程 - 配置 MCP让AI更聪明

niuhe 插件官方教程已经上线, 请访问: http://niuhe.zuxing.net niuhe 连接 MCP 介绍 API 文档的未来&#xff1a;MCP&#xff0c;让协作像聊天一样简单. MCP 是 Model Context Protocol(模型上下文协议)的缩写&#xff0c;是 2024 年 11 月 Claude 的公司 Anthropic 推出并开…

26考研——排序_插入排序(8)

408答疑 文章目录 二、插入排序基本概念插入排序方法直接插入排序算法描述示例性能分析 折半插入排序改进点算法步骤性能分析 希尔排序相关概念示例分析希尔排序的效率效率分析空间复杂度时间复杂度 九、参考资料鲍鱼科技课件26王道考研书 二、插入排序 基本概念 定义&#x…

精华贴分享|从不同的交易理论来理解头肩形态,殊途同归

本文来源于量化小论坛策略分享会板块精华帖&#xff0c;作者为孙小迪&#xff0c;发布于2025年2月17日。 以下为精华帖正文&#xff1a; 01 前言 学习了一段时间交易后&#xff0c;我发现在几百年的历史中&#xff0c;不同门派的交易理论对同一种市场特征的称呼不一样&#x…

leetcode437.路径总和|||

对于根结点来说&#xff0c;可以选择当前结点为路径也可以不选择&#xff0c;但是一旦选择当前结点为路径那么后续都必须要选择结点作为路径&#xff0c;不然路径不连续是不合法的&#xff0c;所以这里分开出来两个方法进行递归 由于力扣最后一个用例解答错误&#xff0c;分析…

北斗导航 | 改进奇偶矢量法的接收机自主完好性监测算法原理,公式,应用,RAIM算法研究综述,matlab代码

改进奇偶矢量法的接收机自主完好性监测算法研究 摘要 接收机自主完好性监测(RAIM)是保障全球导航卫星系统(GNSS)安全性的核心技术。针对传统奇偶矢量法在噪声敏感性、多故障隔离能力上的缺陷,本文提出一种基于加权奇偶空间与动态阈值的改进算法。通过引入观测值权重矩阵重…

深度神经网络全解析:原理、结构与方法对比

深度神经网络全解析&#xff1a;原理、结构与方法对比 1. 引言 随着人工智能的发展&#xff0c;深度神经网络&#xff08;Deep Neural Network&#xff0c;DNN&#xff09;已经成为图像识别、自然语言处理、语音识别、自动驾驶等领域的核心技术。相比传统机器学习方法&#x…

经典论文解读系列:MapReduce 论文精读总结:简化大规模集群上的数据处理

&#x1f9e0; MapReduce 论文解读总结&#xff1a;简化大规模集群上的数据处理 原文标题&#xff1a;MapReduce: Simplified Data Processing on Large Clusters 作者&#xff1a;Jeffrey Dean & Sanjay Ghemawat 发表时间&#xff1a;2004 年 发表机构&#xff1a;Google…

通过Appium理解MCP架构

MCP即Model Context Protocol&#xff08;模型上下文协议&#xff09;&#xff0c;是由Anthropic公司于2024年11月26日推出的开放标准框架&#xff0c;旨在为大型语言模型与外部数据源、工具及系统建立标准化交互协议&#xff0c;以打破AI与数据之间的连接壁垒。 MCP架构与Appi…

网页版五子棋项目的问题处理

文章目录 config.WebSocketConfig将键值对加⼊OnlineUserManager中线程安全、锁ObjectMapper来处理json针对多开情况的判定处理连接关闭、异常&#xff08;玩家中途退出&#xff09;后的不合理操作游戏大厅数据更新 config.WebSocketConfig 把MatchAPI注册进去 • 在addHandle…

【初探数据结构】归并排序与计数排序的序曲

&#x1f4ac; 欢迎讨论&#xff1a;在阅读过程中有任何疑问&#xff0c;欢迎在评论区留言&#xff0c;我们一起交流学习&#xff01; &#x1f44d; 点赞、收藏与分享&#xff1a;如果你觉得这篇文章对你有帮助&#xff0c;记得点赞、收藏&#xff0c;并分享给更多对数据结构感…

算法刷题记录——LeetCode篇(8.7) [第761~770题](持续更新)

更新时间&#xff1a;2025-03-30 算法题解目录汇总&#xff1a;算法刷题记录——题解目录汇总技术博客总目录&#xff1a;计算机技术系列博客——目录页 优先整理热门100及面试150&#xff0c;不定期持续更新&#xff0c;欢迎关注&#xff01; 763. 划分字母区间 给你一个字…

Pod 网络与 CNI 的作用

在 Kubernetes 中&#xff0c;Pod 网络 是实现容器间通信的核心机制&#xff0c;每个 Pod 拥有独立的 IP 地址&#xff0c;可直接跨节点通信。CNI&#xff08;Container Network Interface&#xff09; 是 Kubernetes 的网络插件标准&#xff0c;负责为 Pod 分配 IP、配置网络规…

使用keepalived结合tomcat和nginx搭建三主热备架构

角色主机名软件IP地址用户client172.25.250.90keepalivedVIP172.25.250.100keepalivedVIP172.25.250.101keepalivedVIP172.25.250.102masterserverAkeepalived, nginx172.25.250.30backupserverBkeepalived, nginx172.25.250.31backupserverCkeepalived, nginx172.25.250.32web…

STRUCTBERT:将语言结构融入预训练以提升深度语言理解

【摘要】最近&#xff0c;预训练语言模型BERT&#xff08;及其经过稳健优化的版本RoBERTa&#xff09;在自然语言理解&#xff08;NLU&#xff09;领域引起了广泛关注&#xff0c;并在情感分类、自然语言推理、语义文本相似度和问答等各种NLU任务中达到了最先进的准确率。受到E…