说明:MNIST直接离线加载二进制文件到pytorch
'''
直接以下4个文件读入数据到pytorch中t10k-images-idx3-ubyte.gzt10k-labels-idx1-ubyte.gztrain-images-idx3-ubyte.gztrain-labels-idx1-ubyte.gz'''
import os
import numpy as np
import gzipimport torch.utils.data as Data
from torchvision import transformsimport timedataPath = 'E:/MNIST/binary_file'def load_data(data_folder, data_name, label_name):"""data_folder: 文件目录data_name: 数据文件名label_name:标签数据文件名"""with gzip.open(os.path.join(data_folder,label_name), 'rb') as lbpath: # rb表示的是读取二进制数据y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(data_folder,data_name), 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)return (x_train, y_train)class DealDataset(Data.Dataset):"""读取数据、初始化数据"""def __init__(self, folder, data_name, label_name,transform=None):(train_set, train_labels) = load_data(folder, data_name, label_name) # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式self.train_set = train_setself.train_labels = train_labelsself.transform = transformdef __getitem__(self, index):img, target = self.train_set[index], int(self.train_labels[index])if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):return len(self.train_set)# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset(dataPath, "train-images-idx3-ubyte.gz","train-labels-idx1-ubyte.gz",transform=transforms.ToTensor())
testDataset = DealDataset(dataPath, "t10k-images-idx3-ubyte.gz","t10k-labels-idx1-ubyte.gz",transform=transforms.ToTensor())# 训练数据和测试数据的装载
train_loader = Data.DataLoader(dataset=trainDataset,batch_size=100, # 一个批次可以认为是一个包,每个包中含有100张图片shuffle=False,
)test_loader = Data.DataLoader(dataset=testDataset,batch_size=100,shuffle=False,
)if __name__ == '__main__':# 这里trainDataset包含:train_labels, train_set等属性; 数据类型均为ndarrayprint(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')# 这里train_loader包含:batch_size、dataset等属性,数据类型分别为int,DealDataset# dataset中又包含train_labels, train_set等属性; 数据类型均为ndarrayprint(f'train_loader.batch_size: {train_loader.batch_size}\n')print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')
运行结果