说明：直接离线加载 CIFAR-10 数据集到 PyTorch
'''
Load the six CIFAR-10 (python version) batch files directly into PyTorch:
data_batch_1, data_batch_2, data_batch_3, data_batch_4, data_batch_5, test_batch
'''
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data as Data
import torchvision
from torch.autograd import Variable
from torchvision import transforms

# Load the CIFAR-10 data
def load_CIFAR_batch(filename):
    """Load a single pickled CIFAR-10 batch file.

    Args:
        filename: Path to one batch file such as ``data_batch_1`` or
            ``test_batch``.

    Returns:
        ``(X, Y)``: ``X`` holds the images reshaped to ``(N, 32, 32, 3)``
        (channels last, dtype unchanged from the file) and ``Y`` is a 1-D
        ndarray of the ``N`` labels.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at trusted dataset files.
    with open(filename, 'rb') as f:
        # encoding='latin1' is what lets Python 3 decode byte strings / numpy
        # arrays inside pickles written by Python 2.
        datadict = pickle.load(f, encoding='latin1')
    X = datadict['data']
    Y = datadict['labels']
    # Each stored row is channel-first (3 x 32 x 32); transpose to HWC.
    # -1 infers the number of images instead of assuming 10000 per batch.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    Y = np.array(Y)
    return X, Y


def _load_CIFAR10_train(ROOT):
    """Load and concatenate ``data_batch_1`` .. ``data_batch_5`` under ``ROOT``."""
    xs = []
    ys = []
    for b in range(1, 6):
        X, Y = load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % b))
        xs.append(X)
        ys.append(Y)
    return np.concatenate(xs), np.concatenate(ys)


def _load_CIFAR10_test(ROOT):
    """Load the single ``test_batch`` file under ``ROOT``."""
    return load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))


def load_CIFAR10(ROOT):
    """Load every CIFAR-10 batch found under ``ROOT``.

    Returns:
        ``(Xtrain, Ytrain, Xtest, Ytest)``; the training arrays are the five
        training batches concatenated along the first axis.
    """
    Xtrain, Ytrain = _load_CIFAR10_train(ROOT)
    Xtest, Ytest = _load_CIFAR10_test(ROOT)
    return Xtrain, Ytrain, Xtest, Ytest


class DealDataset(Data.Dataset):
    """Map-style Dataset over the offline CIFAR-10 batch files.

    Args:
        root: Directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``.
        train: If True, serve the training batches (kept in ``train_set`` /
            ``train_labels``); otherwise serve ``test_batch`` (kept in
            ``test_set`` / ``test_labels``). All are numpy ndarrays.
        transform: Optional callable applied to each HWC image in
            ``__getitem__`` (e.g. ``transforms.ToTensor()``).
    """

    def __init__(self, root, train=True, transform=None):
        # Read only the files of the requested split, so a training dataset
        # never touches test_batch and vice versa.
        if train:
            self.train_set, self.train_labels = _load_CIFAR10_train(root)
        else:
            self.test_set, self.test_labels = _load_CIFAR10_test(root)
        self.transform = transform
        self.train = train

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index``; ``target`` is a Python int."""
        if self.train:
            img, target = self.train_set[index], int(self.train_labels[index])
        else:
            img, target = self.test_set[index], int(self.test_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of images in the selected split."""
        if self.train:
            return len(self.train_set)
        return len(self.test_set)


# Directory holding the extracted CIFAR-10 python batches. The CIFAR10_ROOT
# environment variable overrides the hard-coded default.
root = os.environ.get('CIFAR10_ROOT', r'E:\cifar-10-python\cifar-10-batches-py')
batch_size = 8

# Instantiate the Dataset class; the resulting Dataset objects are then handed
# to DataLoader. ToTensor converts each HWC image into a CHW float tensor.
trainDataset = DealDataset(root, train=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, transform=transforms.ToTensor())

# Loaders for the training and test data.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=batch_size,  # one batch is a "package" of batch_size images
    shuffle=False,
)
test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=batch_size,
    shuffle=False,
)

if __name__ == '__main__':
    # trainDataset exposes train_labels and train_set, both numpy ndarrays.
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')

    # train_loader exposes batch_size (int) and dataset (the DealDataset
    # above, which in turn carries train_labels / train_set ndarrays).
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')

    # Visualization option 1: OpenCV (kept for reference, not executed).
    # images, labels = next(iter(train_loader))
    # img = torchvision.utils.make_grid(images, nrow=10)
    # img = img.numpy().transpose(1, 2, 0)
    # # OpenCV expects BGR but img is RGB, so reverse the channel axis.
    # cv2.imshow('img', img[:, :, ::-1])
    # cv2.waitKey(0)

    # Visualization option 2: matplotlib.
    dataiter = iter(train_loader)
    # Current PyTorch DataLoader iterators have no .next() method (a Python 2
    # leftover that was removed); the builtin next() works on all versions.
    images, labels = next(dataiter)
    images = images.numpy()
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
    n_images = len(images)
    # add_subplot requires integer grid sizes; round up so an odd number of
    # images still fits in two rows.
    n_cols = (n_images + 1) // 2
    fig = plt.figure(figsize=(4, 4))
    for idx in range(n_images):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        # Tensors are CHW after ToTensor; imshow expects HWC.
        ax.imshow(images[idx].transpose(1, 2, 0))
        ax.set_title(classes[int(labels[idx])])
    plt.show()
运行结果
显示图