👋 你好!这里有实用干货与深度分享✨✨ 若有帮助,欢迎:
👍 点赞 | ⭐ 收藏 | 💬 评论 | ➕ 关注 ,解锁更多精彩!
📁 收藏专栏即可第一时间获取最新推送🔔。
📖后续我将持续带来更多优质内容,期待与你一同探索知识,携手前行,共同进步🚀。
数据集读取
本文使用PyTorch框架,介绍PyTorch中数据读取的相关知识。
本文目标:
- 了解PyTorch中数据读取的基本概念
- 了解PyTorch中集成的开源数据集的读取方法
- 了解PyTorch中自定义数据集的读取方法
- 了解PyTorch中数据读取的流程
一、数据的准备
使用开源数据集或者自己采集数据后进行数据标注。
PyTorch中数据读取的基本概念
PyTorch中数据读取的基本概念是Dataset
和DataLoader
。
Dataset
是一个抽象类,用于表示数据集。它包含了数据集的长度、索引、数据获取等方法。
DataLoader
是一个类,用于将数据集按批次加载到模型中。它包含了数据读取、数据转换、数据打乱等方法。
实现数据集读取的步骤:
- 继承
Dataset
类,实现__len__
和__getitem__
方法 - 使用
DataLoader
类,将数据集按批次加载到模型中
示例代码:
import torch
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index], self.labels[index]data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)for batch_data, batch_labels in dataloader:print(batch_data.shape, batch_labels.shape)
PyTorch中集成的开源数据集的读取方法
使用开源数据MNIST作为示范。
数据集链接:MNIST数据集
PyTorch中以及集成了很多开源数据集,我们可以直接使用。MNIST也包括在其中。
只需要使用PyTorch中的torchvision.datasets
模块即可。
示例代码:
- 引入必要的库:
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
- 加载数据集:
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)
参数说明:
root
:数据集保存的路径train
:是否为训练集download
:是否下载数据集
- 查看数据集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
- 可视化数据集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
- 数据加载:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
参数说明:
batch_size
:批次大小shuffle
:是否打乱数据,训练集一般需要打乱数据,测试集一般不需要打乱数据
其实,真实的训练过程只需要步骤1、2、5即可,3、4步骤是为了验证数据集是否正确。
二、PyTorch中自定义数据集的读取方法
自定义数据集的读取方法是指,我们自己定义一个数据集,然后使用PyTorch中的Dataset
和DataLoader
类来读取数据集。因为不是所有的数据集都在PyTorch中集成了,当我们有拥有(自己标注或下载)一个新的数据集时,就需要自己定义数据集的读取方法。
这时候需要将数据集以一定的规则保存起来,然后使用PyTorch中的Dataset
和DataLoader
类来读取数据集。
示例代码:
- 引入必要的库:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
- 定义数据集类:
class MyDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.data_list = os.listdir(data_dir)def __len__(self):return len(self.data_list)def __getitem__(self, index):data_path = os.path.join(self.data_dir, self.data_list[index])data = np.load(data_path)label = data['label']if self.transform is not None:data = self.transform(data)return data, label
参数说明:
data_dir
:数据集保存的路径transform
:数据转换函数,可选。1. 用于数据增强,一般的数据增强方法有:随机裁剪、随机旋转、随机翻转、随机缩放等。2. 也可以用于数据预处理,如归一化、标准化等。
- 定义数据转换函数:
def transform(data):data = data['data']data = data.astype(np.float32)data = data / 255.0data = torch.from_numpy(data)return data
- 加载数据集:
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)
- 查看数据集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
- 可视化数据集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
- 数据加载:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
- 数据增强:
from torchvision import transformstransform = transforms.Compose([transforms.RandomCrop(28), # 随机裁剪,裁剪大小为28x28transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.RandomVerticalFlip(), # 随机垂直翻转transforms.RandomRotation(10), # 随机旋转transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)), # 随机仿射变换transforms.ToTensor() # 转换为张量
])
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:print(batch_data.shape, batch_labels.shape)break
DataLoader核心参数详解
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None,num_workers=0, collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None,
)
关键参数解析:
num_workers
:数据预加载进程数(建议设为CPU核心数的70-80%)pin_memory
:启用CUDA锁页内存加速GPU传输prefetch_factor
:每个worker预加载的batch数(PyTorch 1.7+)
数据加载性能优化公式
理论最大吞吐量:
T h r o u g h p u t = min ( B a t c h S i z e × n u m _ w o r k e r s D a t a L o a d T i m e , G P U C o m p u t e T i m e − 1 ) Throughput = \min\left(\frac{BatchSize \times num\_workers}{DataLoadTime}, GPUComputeTime^{-1}\right) Throughput=min(DataLoadTimeBatchSize×num_workers,GPUComputeTime−1)
三、拓展:多模态数据加载示例
class MultiModalDataset(Dataset):def __init__(self, img_dir, text_path):self.img_dir = img_dirself.text_data = pd.read_csv(text_path)self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')def __getitem__(self, idx):# 图像处理img_path = os.path.join(self.img_dir, self.text_data.iloc[idx]['image_id'])image = Image.open(img_path).convert('RGB')image = transforms.ToTensor()(image)# 文本处理text = self.text_data.iloc[idx]['description']inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=128)return {'image': image,'input_ids': torch.tensor(inputs['input_ids']),'attention_mask': torch.tensor(inputs['attention_mask'])}
四、总结
本文介绍了PyTorch中数据读取的基本概念、集成的开源数据集的读取方法、自定义数据集的读取方法和数据读取的流程。
数据读取是深度学习训练的重要环节,数据读取的流程是:
- 定义数据集类
- 定义数据转换函数、数据增强函数
- 加载数据集
📌 感谢阅读!若文章对你有用,别吝啬互动~
👍 点个赞 | ⭐ 收藏备用 | 💬 留下你的想法 ,关注我,更多干货持续更新!