__len__ 和 __getitem__ 是 PyTorch torch.utils.data.Dataset 抽象类要求必须实现的两个方法,是 PyTorch 数据加载体系的「基石」——
__len__回答:数据集一共有多少个样本?__getitem__回答:给定一个索引 idx,如何获取对应的单个样本?
这两个方法配合 DataLoader,就能实现批量加载、打乱、多线程读取等功能,是 PyTorch 处理数据的标准范式。
__len__ 方法:返回数据集的总样本数
-
告诉
DataLoader「这个数据集有多少个样本」,是DataLoader计算批次、判断迭代终止的依据; -
支持 Python 内置的
len()函数:执行len(dataset)时,本质就是调用dataset.__len__()。 -
self.data是你在__init__中构造的时序样本列表(每个元素是(x, y)样本对);len(self.data)就是数据集的总样本数,比如你生成的仿真数据最终构造了 1899 个样本,__len__就返回 1899; -
实际用途:
DataLoader会用这个数值计算「一个 epoch 要迭代多少个批次」(总样本数 / 批次大小),比如总样本 1899、批次 32,一个 epoch 就迭代 59 批(1899//32=59,最后一批不足 32 个)。
__getitem__ 方法:根据索引获取单个样本
-
是数据集的「样本读取接口」:给定索引
idx,返回对应的单个样本(输入 + 标签); -
DataLoader批量加载数据时,本质是循环调用__getitem__(idx)获取单个样本,再堆叠成批次(batch); -
支持 Python 下标访问:执行
dataset[0]时,本质就是调用dataset.__getitem__(0)。 -
idx:是DataLoader传入的索引(0、1、2... 直到len(dataset)-1); -
self.data[idx]:取出第idx个时序样本(比如idx=0时,取出第一个(x, y)对,x 是 50 步 ×3 特征的历史数据,y 是对应的预测目标); -
转换为张量并移到指定设备:将 numpy 数组转为 PyTorch 张量,适配模型训练;
-
返回值:必须是「输入张量 + 标签张量」的格式,是模型训练时的基本数据单元。
总结
方法 核心作用 实际用途 缺失后果 __len__返回数据集总样本数 DataLoader 计算批次、支持 len (dataset) 无法计算迭代次数、len () 报错 __getitem__根据索引返回单个样本(输入 + 标签) DataLoader 批量加载样本、支持 dataset [idx] 无法获取样本、训练直接中断 简单来说:
__len__定义了数据集的「规模」,__getitem__定义了数据集的「读取规则」;- 这两个方法是 PyTorch 数据加载的「最小实现要求」,所有自定义
Dataset都必须实现它们,才能和DataLoader配合完成批量训练。