[数据处理] 3. 数据集读取

👋 你好!这里有实用干货与深度分享✨✨ 若有帮助,欢迎:​
👍 点赞 | ⭐ 收藏 | 💬 评论 | ➕ 关注 ,解锁更多精彩!​
📁 收藏专栏即可第一时间获取最新推送🔔。​
📖后续我将持续带来更多优质内容,期待与你一同探索知识,携手前行,共同进步🚀。​



人工智能

数据集读取

本文使用PyTorch框架,介绍PyTorch中数据读取的相关知识。

本文目标:

  1. 了解PyTorch中数据读取的基本概念
  2. 了解PyTorch中集成的开源数据集的读取方法
  3. 了解PyTorch中自定义数据集的读取方法
  4. 了解PyTorch中数据读取的流程

一、数据的准备

使用开源数据集或者自己采集数据后进行数据标注。

PyTorch中数据读取的基本概念

PyTorch中数据读取的基本概念是DatasetDataLoader

Dataset是一个抽象类,用于表示数据集。它包含了数据集的长度、索引、数据获取等方法。

DataLoader是一个类,用于将数据集按批次加载到模型中。它包含了数据读取、数据转换、数据打乱等方法。

实现数据集读取的步骤:

  1. 继承Dataset类,实现__len____getitem__方法
  2. 使用DataLoader类,将数据集按批次加载到模型中

示例代码:

import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index], self.labels[index]data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)for batch_data, batch_labels in dataloader:print(batch_data.shape, batch_labels.shape)

PyTorch中集成的开源数据集的读取方法

使用开源数据MNIST作为示范。

数据集链接:MNIST数据集

PyTorch中以及集成了很多开源数据集,我们可以直接使用。MNIST也包括在其中。

只需要使用PyTorch中的torchvision.datasets模块即可。

示例代码:

  1. 引入必要的库:
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
  1. 加载数据集:
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)

参数说明:

  • root:数据集保存的路径
  • train:是否为训练集
  • download:是否下载数据集
  1. 查看数据集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
  1. 可视化数据集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
  1. 数据加载:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break

参数说明:

  • batch_size:批次大小
  • shuffle:是否打乱数据,训练集一般需要打乱数据,测试集一般不需要打乱数据

其实,真实的训练过程只需要步骤1、2、5即可,3、4步骤是为了验证数据集是否正确。

二、PyTorch中自定义数据集的读取方法

自定义数据集的读取方法是指,我们自己定义一个数据集,然后使用PyTorch中的DatasetDataLoader类来读取数据集。因为不是所有的数据集都在PyTorch中集成了,当我们有拥有(自己标注或下载)一个新的数据集时,就需要自己定义数据集的读取方法。

这时候需要将数据集以一定的规则保存起来,然后使用PyTorch中的DatasetDataLoader类来读取数据集。

示例代码:

  1. 引入必要的库:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
  1. 定义数据集类:
class MyDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.data_list = os.listdir(data_dir)def __len__(self):return len(self.data_list)def __getitem__(self, index):data_path = os.path.join(self.data_dir, self.data_list[index])data = np.load(data_path)label = data['label']if self.transform is not None:data = self.transform(data)return data, label

参数说明:

  • data_dir:数据集保存的路径
  • transform:数据转换函数,可选。1. 用于数据增强,一般的数据增强方法有:随机裁剪、随机旋转、随机翻转、随机缩放等。2. 也可以用于数据预处理,如归一化、标准化等。
  1. 定义数据转换函数:
def transform(data):data = data['data']data = data.astype(np.float32)data = data / 255.0data = torch.from_numpy(data)return data
  1. 加载数据集:
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)
  1. 查看数据集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
  1. 可视化数据集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
  1. 数据加载:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
  1. 数据增强:
from torchvision import transformstransform = transforms.Compose([transforms.RandomCrop(28),  # 随机裁剪,裁剪大小为28x28transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomVerticalFlip(),  # 随机垂直翻转transforms.RandomRotation(10),  # 随机旋转transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # 随机仿射变换transforms.ToTensor()  # 转换为张量
])
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break

DataLoader核心参数详解

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None,num_workers=0, collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)

关键参数解析

  • num_workers:数据预加载进程数(建议设为CPU核心数的70-80%)
  • pin_memory:启用CUDA锁页内存加速GPU传输
  • prefetch_factor:每个worker预加载的batch数(PyTorch 1.7+)

数据加载性能优化公式

理论最大吞吐量
T h r o u g h p u t = min ⁡ ( B a t c h S i z e × n u m _ w o r k e r s D a t a L o a d T i m e , G P U C o m p u t e T i m e − 1 ) Throughput = \min\left(\frac{BatchSize \times num\_workers}{DataLoadTime}, GPUComputeTime^{-1}\right) Throughput=min(DataLoadTimeBatchSize×num_workers,GPUComputeTime1)

三、拓展:多模态数据加载示例

class MultiModalDataset(Dataset):def __init__(self, img_dir, text_path):self.img_dir = img_dirself.text_data = pd.read_csv(text_path)self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')def __getitem__(self, idx):# 图像处理img_path = os.path.join(self.img_dir, self.text_data.iloc[idx]['image_id'])image = Image.open(img_path).convert('RGB')image = transforms.ToTensor()(image)# 文本处理text = self.text_data.iloc[idx]['description']inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=128)return {'image': image,'input_ids': torch.tensor(inputs['input_ids']),'attention_mask': torch.tensor(inputs['attention_mask'])}

四、总结

本文介绍了PyTorch中数据读取的基本概念、集成的开源数据集的读取方法、自定义数据集的读取方法和数据读取的流程。

数据读取是深度学习训练的重要环节,数据读取的流程是:

  1. 定义数据集类
  2. 定义数据转换函数、数据增强函数
  3. 加载数据集



📌 感谢阅读!若文章对你有用,别吝啬互动~​
👍 点个赞 | ⭐ 收藏备用 | 💬 留下你的想法 ,关注我,更多干货持续更新!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/904529.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IIS配置SSL

打开iis 如果搜不到iis,要先开 再搜就打得开了 cmd中找到本机ip 用http访问本机ip 把原本的http绑定删了 再用http访问本机ip就不行了 只能用https访问了

RabbitMQ的交换机

一、三种交换机模式 核心区别对比​​ ​​特性​​​​广播模式(Fanout)​​​​路由模式(Direct)​​​​主题模式(Topic)​​​​路由规则​​无条件复制到所有绑定队列精确匹配 Routing Key通配符匹配…

(2025,AR,NAR,GAN,Diffusion,模型对比,数据集,评估指标,性能对比)文本到图像的生成和编辑:综述

【本文为我在去年完成的综述,因某些原因未能及时投稿,但本文仍能为想要全面了解文本到图像的生成和编辑的学习者提供可靠的参考。目前本文已投稿 ACM Computing Surveys。 完整内容可在如下链接获取,或在 Q 群群文件获取。 中文版为论文初稿&…

MCU怎么运行深度学习模型

Gitee仓库 git clone https://gitee.com/banana-peel-x/freedom-learn.git项目场景: 解决面试时遗留的问题,面试官提了两个问题:1.单片机能跑深度学习的模型吗? 2.为什么FreeRTOS要采用SVC去触发第一个任务,只用Pend…

多模态学习(一)——从 Image-Text Pair 到 Instruction-Following 格式

前言 在多模态任务中(例如图像问答、图像描述等),为了使用指令微调(Instruction Tuning)提升多模态大模型的能力,我们需要构建成千上万条**指令跟随(instruction-following)**格式的…

MySQL基础关键_011_视图

目 录 一、说明 二、操作 1.创建视图 2.创建可替换视图 3.修改视图 4.删除视图 5.对视图内容的增、删、改 (1)增 (2)改 (3)删 一、说明 只能将 DQL 语句创建为视图;作用: …

『深夜_MySQL』数据库操作 字符集与检验规则

2.库的操作 2.1 创建数据库 语法: CREATE DATABASE [IF NOT EXISTS] db_name [create_specification [,create_specification]….]create_spcification:[DEFAULT] CHARACTER SET charset_nam[DEFAULT] COLLATE collation_name说明: 大写的表示关键字 …

Spark jdbc写入崖山等国产数据库失败问题

随着互联网、信息产业的大发展、以及地缘政治的变化,网络安全风险日益增长,网络安全关乎国家安全。因此很多的企业,开始了国产替代的脚步,从服务器芯片,操作系统,到数据库,中间件,逐步实现信息技术自主可控,规避外部技术制裁和风险。 就数据库而言,目前很多的国产数据…

数字化转型-4A架构之应用架构

系列文章 数字化转型-4A架构(业务架构、应用架构、数据架构、技术架构)数字化转型-4A架构之业务架构 前言 应用架构AA(Application Architecture)是规划支撑业务的核心系统与功能模块,实现端到端协同。 一、什么是应…

格雷狼优化算法`GWO 通过模拟和优化一个信号处理问题来最大化特定频率下的功率

这段代码是一个Python程序,它使用了多个科学计算库,包括`random`、`numpy`、`matplotlib.pyplot`、`scipy.signal`和`scipy.signal.windows`。程序的主要目的是通过模拟和优化一个信号处理问题来最大化特定频率下的功率。 4. **定义类`class_model`**: - 这个类包含了信号…

中级网络工程师知识点1

1.1000BASE-CX:铜缆,最大传输距离为25米 1000BASE-LX:传输距离可达3000米 1000BASE-ZX:超过10km 2.RSA加密算法的安全性依赖于大整数分解问题的困难性 3.网络信息系统的可靠性测度包括有效性,康毁性,生存性 4.VLAN技术所依据的协议是IEEE802.1q IEEE802.15标准是针…

2025年五一数学建模A题【支路车流量推测】原创论文讲解

大家好呀,从发布赛题一直到现在,总算完成了2025年五一数学建模A题【支路车流量推测】完整的成品论文。 给大家看一下目录吧: 摘 要: 一、问题重述 二.问题分析 2.1问题一 2.2问题二 2.3问题三 2.4问题四 2.5 …

性能优化实践:渲染性能优化

性能优化实践:渲染性能优化 在Flutter应用开发中,渲染性能直接影响用户体验。本文将从渲染流程分析入手,深入探讨Flutter渲染性能优化的关键技术和最佳实践。 一、Flutter渲染流程解析 1.1 渲染流水线 Flutter的渲染流水线主要包含以下几…

linux基础学习--linux磁盘与文件管理系统

linux磁盘与文件管理系统 1.认识linux系统 1.1 磁盘组成与分区的复习 首先了解磁盘的物理组成,主要有: 圆形的碟片(主要记录数据的部分)。机械手臂,与在机械手臂上的磁头(可擦写碟片上的内容)。主轴马达,可以转动碟片,让机械手臂的磁头在碟片上读写数据。 数据存储…

DIFY教程第五弹:科研论文翻译与SEO翻译应用

科研论文翻译 我可以在工作流案例中结合聊天大模型来实现翻译工具的功能,具体的设计如下 在开始节点中接收一个输入信息 content 然后在 LLM 模型中我们需要配置一个 CHAT 模型,这里选择了 DeepSeek-R1 64K 的聊天模型,注意需要在这里设置下…

【Redis】哨兵机制和集群

🔥个人主页: 中草药 🔥专栏:【中间件】企业级中间件剖析 一、哨兵机制 Redis的主从复制模式下,一旦主节点由于故障不能提供服务,需要人工的进行主从切换,同时需要大量的客户端需要被通知切换到…

注意力机制(Attention)

1. 注意力认知和应用 AM: Attention Mechanism,注意力机制。 根据眼球注视的方向,采集显著特征部位数据: 注意力示意图: 注意力机制是一种让模型根据任务需求动态地关注输入数据中重要部分的机制。通过注意力机制&…

解锁 AI 生产力:Google 四大免费工具全面解析20250507

🚀 解锁 AI 生产力:Google 四大免费工具全面解析 在人工智能迅猛发展的今天,Google 推出的多款免费工具正在悄然改变我们的学习、工作和创作方式。本文将深入解析四款代表性产品:NotebookLM、Google AI Studio、Google Colab 和 …

知识图谱:AI大脑中的“超级地图”如何炼成?

人类看到“苹果”一词,会瞬间联想到“iPhone”“乔布斯”“牛顿”,甚至“维生素C”——这种思维跳跃的背后,是大脑将概念连结成网的能力。而AI要模仿这种能力,需要一张动态的“数字地图”来存储和链接知识,这就是​知识…

Win11 24H2首个热补丁下周推送!更新无需重启

快科技5月7 日消息,微软宣布,Windows 11 24H2的首个热补丁更新将于下周通过Patch Tuesday发布,将为管理员带来更高效的安全更新部署方式,同时减少设备停机时间。 为帮助IT管理员顺利过渡到热补丁模式,微软还提供了丰富…