"""
Convert CIFAR-10 to PNG files.

CIFAR-10 consists of 32x32 colour images in 10 classes with 6000 images per
class: 50000 training images (split evenly into 5 batches) and 10000 test
images (1000 chosen from each class).

Each image is written as ``<split dir>/<class name>/<n>.png``.
"""
import os
import pickle

import numpy as np
from imageio import imwrite
# Root folder that holds the extracted dataset and receives the PNG output.
# Set the CIFAR10_BASE_DIR environment variable to use another location
# without editing the script; the original hard-coded path is the default.
base_dir = os.environ.get('CIFAR10_BASE_DIR', r'H:\DataStore')
# Folder containing the pickled batch files (data_batch_*, test_batch, batches.meta).
data_dir = os.path.join(base_dir, 'cifar-10-batches-py')
# Output roots for each split; one sub-folder per class name is created inside.
train_dir = os.path.join(base_dir, 'cifar-10-train-png')
test_dir = os.path.join(base_dir, 'cifar-10-test-png')

# Which splits to export when the file is run as a script.
Train = False
Test = True
def unpickle(file_path):
    """Load and return the object pickled in the file at ``file_path``.

    ``encoding='bytes'`` makes Python 2 ``str`` data come back as ``bytes``.
    This is why callers index the result with ``bytes`` keys such as
    ``b'labels'``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')
def create_dir(dir_path):
    """Create ``dir_path`` and any missing parents; do nothing if it exists.

    ``exist_ok=True`` replaces a separate ``isdir`` check, so there is no gap
    between checking and creating. A non-directory already at ``dir_path``
    still raises ``FileExistsError``.
    """
    os.makedirs(dir_path, exist_ok=True)


def get_label_names():
    """Return the class names stored under ``b'label_names'`` in batches.meta.

    The names are ``bytes`` (``save_images`` decodes them). A name's position
    in the list is the integer label used in the batch files.
    """
    meta = unpickle(os.path.join(data_dir, 'batches.meta'))
    return meta[b'label_names']


def save_images(i, obj, class_num, label_names, dir_path):
    """Write row ``i`` of an unpickled batch to ``<dir_path>/<label>/<n>.png``.

    Args:
        i: Row index into ``obj[b'data']`` and ``obj[b'labels']``.
        obj: Unpickled batch dict with ``b'data'`` and ``b'labels'`` entries.
        class_num: Per-class counters, incremented in place. ``n`` is the
            1-based running count for the row's class, so reuse one list for
            all batches of a split to keep file names unique.
        label_names: Class names as ``bytes``, indexed by label.
        dir_path: Output root for the split; class sub-folders are created.
    """
    # Each row is stored channel-major; reshape to (3, 32, 32) and move the
    # channel axis last to get the (32, 32, 3) layout the image writer expects.
    img = np.reshape(obj[b'data'][i], (3, 32, 32)).transpose(1, 2, 0)
    label_idx = obj[b'labels'][i]
    label_name: str = label_names[label_idx].decode()
    class_dir = os.path.join(dir_path, label_name)
    create_dir(class_dir)
    class_num[label_idx] += 1
    image_path = os.path.join(class_dir, str(class_num[label_idx]) + '.png')
    imwrite(image_path, img)


def _export_split(batch_paths, label_names, out_dir):
    """Write every image from the given batch files to PNGs under ``out_dir``.

    All batches share one counter list, so file names stay unique per split.
    """
    class_num = [0] * len(label_names)
    for batch_path in batch_paths:
        print("{} is loading...".format(batch_path))
        batch = unpickle(batch_path)
        # Use the real row count instead of assuming 10000 rows per batch.
        for j in range(len(batch[b'labels'])):
            save_images(j, batch, class_num, label_names, out_dir)


def main():
    """Export the splits enabled by the module-level ``Train``/``Test`` flags."""
    label_names = get_label_names()
    if Train:
        train_batches = [os.path.join(data_dir, 'data_batch_' + str(i))
                         for i in range(1, 6)]
        _export_split(train_batches, label_names, train_dir)
        print('train loaded')
    if Test:
        _export_split([os.path.join(data_dir, 'test_batch')],
                      label_names, test_dir)
        print('test loaded')


if __name__ == '__main__':
    main()