import  matplotlib. pyplot as  plt
import  numpy as  np
import  torch 
import  torchvision
from  d2l import  torch as  d2l
from  torch import  nn 
from  PIL import  Image
import  liliPytorch as  lp
from torch.utils.data import Dataset, DataLoader

# Show the original (un-augmented) sample image in its own figure.
plt.figure('cat')
img = Image.open('../limuPytorch/images/cat.jpg')
plt.imshow(img)


def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Apply an augmentation repeatedly to one image and show the results in a grid.

    Each call to ``aug`` is independent, so a random augmentation yields a
    different result in every grid cell.

    Args:
        img: the input image.
        aug: augmentation callable; takes an image and returns an augmented image.
        num_rows: number of rows of augmented images to display (default 2).
        num_cols: number of columns of augmented images to display (default 4).
        scale: display scale factor passed to ``d2l.show_images`` (default 1.5).
    """
    # One independent augmentation per grid cell.
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)
# Random flips.
apply(img, torchvision.transforms.RandomHorizontalFlip())
apply(img, torchvision.transforms.RandomVerticalFlip())

# Random crop (scale=(0.1, 1), ratio=(0.5, 2)) resized to 200x200.
shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2),
)
apply(img, shape_aug)

# Vary brightness only.
apply(img, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0, saturation=0, hue=0))
# Vary hue only.
apply(img, torchvision.transforms.ColorJitter(
    brightness=0, contrast=0, saturation=0, hue=0.5))

# Vary brightness, contrast, saturation and hue together.
color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

# Chain flip, color and shape augmentations into a single pipeline.
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)
# Preview the first 32 CIFAR-10 training images (downloaded into ../data
# if not already present) as a 4 x 8 grid.
all_images = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
)
d2l.show_images(
    [all_images[idx][0] for idx in range(32)],  # element 0 of each sample
    4,
    8,
    scale=0.8,
)
plt.show()
# Training pipeline: random horizontal flip, then convert to a tensor.
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
])
# Evaluation pipeline: no random augmentation, only tensor conversion.
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
def load_cifar10(is_train, augs, batch_size, num_workers=4):
    """Build a DataLoader over the CIFAR-10 training or test split.

    The dataset is stored under ``../data`` and downloaded if missing.

    Args:
        is_train: if True, load the training split and shuffle it;
            otherwise load the test split without shuffling.
        augs: transform applied to every image of the split.
        batch_size: number of samples per batch.
        num_workers: number of DataLoader worker processes (default 4,
            matching the previous hard-coded value).
            NOTE(review): on platforms that start workers via spawn
            (e.g. Windows), values > 0 need the calling script to be
            guarded by ``if __name__ == "__main__":``. Otherwise pass 0.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.
    """
    dataset = torchvision.datasets.CIFAR10(
        root="../data", train=is_train, transform=augs, download=True)
    # Shuffle only for training; evaluation order stays fixed.
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=is_train,
        num_workers=num_workers)


# NOTE(review): arguments presumably (num_classes=10, in_channels=3) for
# CIFAR-10's 10 classes and RGB input -- confirm against d2l.resnet18.
net = d2l.resnet18(10, 3)
# Training hyperparameters.
batch_size = 256
# NOTE(review): the effect of this learning rate depends on the optimizer
# used inside lp.train_ch6, which is not visible here -- confirm.
lr = 0.001
num_epochs = 10

# Augmented training data vs. un-augmented test data.
train_iter = load_cifar10(True, train_augs, batch_size)
test_iter = load_cifar10(False, test_augs, batch_size)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()