'''
Note: load the Fashion-MNIST binary files offline, directly into PyTorch.

The four .gz archives are loaded straight into PyTorch for training:
    t10k-images-idx3-ubyte.gz
    t10k-labels-idx1-ubyte.gz
    train-images-idx3-ubyte.gz
    train-labels-idx1-ubyte.gz
'''
import os
import gzip

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as Data
from torchvision import datasets, transforms
from torch.autograd import Variable
import time

# Directory that holds the four gzip-compressed Fashion-MNIST files.
dataPath = 'E:/fashion_binary_gz/'
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Number of sample images drawn by the plotting demo under ``__main__``
# (the DataLoaders below use their own batch size).
batch_size = 4


def load_data(data_folder, data_name, label_name):
    """Read one gzip-compressed image/label file pair into NumPy arrays.

    The first 8 bytes of the label file and the first 16 bytes of the image
    file are skipped as header; every remaining byte is read as ``uint8``.

    Args:
        data_folder: Directory containing both files.
        data_name: File name of the gzip-compressed image data.
        label_name: File name of the gzip-compressed label data.

    Returns:
        Tuple ``(x_train, y_train)``: the images with shape
        ``(len(y_train), 28, 28)`` and the 1-D label array, both ``uint8``
        and writable.

    Raises:
        ValueError: If the image payload cannot be reshaped to
            ``(len(y_train), 28, 28)``.
    """
    # 'rb': the archives contain binary data.
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:
        # np.frombuffer over a bytes object yields a read-only view; copy so
        # callers (e.g. tensor conversion) receive a writable array.
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8).copy()

    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        x_train = (
            np.frombuffer(imgpath.read(), np.uint8, offset=16)
            .reshape(len(y_train), 28, 28)
            .copy()
        )

    return (x_train, y_train)


class DealDataset(Data.Dataset):
    """Dataset that reads an image/label file pair and serves single samples.

    Attributes (names kept as-is because other code reads them directly):
        train_set: Image array returned by ``load_data``.
        train_labels: Label array returned by ``load_data``.
        transform: Optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, folder, data_name, label_name, transform=None):
        """Load the arrays from ``folder`` and remember the transform.

        Args:
            folder: Directory containing both files.
            data_name: File name of the gzip-compressed image data.
            label_name: File name of the gzip-compressed label data.
            transform: Optional callable applied to every image on access.
        """
        super().__init__()
        # Original author's note: torch.load() could reportedly be used
        # instead, giving torch.Tensor results — not verified for these files.
        (train_set, train_labels) = load_data(folder, data_name, label_name)
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the label is a Python int."""
        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of images held by the dataset."""
        return len(self.train_set)


# Instantiating this class gives Dataset-typed data; the next step is simply
# to pass the instance to a DataLoader.
# Build the training and test datasets; ToTensor() is applied to each image
# when it is fetched.
trainDataset = DealDataset(
    dataPath,
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    transform=transforms.ToTensor(),
)
testDataset = DealDataset(
    dataPath,
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
    transform=transforms.ToTensor(),
)

# Wrap the training and test data in DataLoaders.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=100,  # a batch can be seen as a package; each one holds 100 images
    shuffle=False,
)
test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=100,
    shuffle=False,
)

if __name__ == '__main__':
    # trainDataset exposes the train_labels and train_set attributes,
    # both of which are ndarrays.
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')

    # train_loader exposes batch_size (int) and dataset (DealDataset); the
    # dataset in turn carries the train_labels / train_set ndarrays.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')

    # Use the built-in next(): DataLoader iterators implement __next__, and
    # the legacy .next() method is not available on current PyTorch releases.
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    images = images.numpy()

    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    # Plot the first `batch_size` images of the batch with their labels.
    fig = plt.figure(figsize=(25, 4))
    # add_subplot needs integer grid dimensions; ceiling division keeps a
    # two-row grid large enough even when batch_size is odd.
    n_cols = (batch_size + 1) // 2
    for idx in range(batch_size):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        ax.imshow(np.squeeze(images[idx]), cmap='gray')
        # labels[idx] is a 0-d tensor; convert it before indexing the list.
        ax.set_title(classes[int(labels[idx])])
    plt.show()
# Run results
# Displayed image