import  torch
from  torch import  nn
def pool2d(X, pool_size, mode='max'):
    """Apply 2-D max or average pooling to a single-channel matrix.

    The window slides with stride 1 and no padding, so the output has shape
    ``(X.shape[0] - p_h + 1, X.shape[1] - p_w + 1)``.

    Args:
        X: 2-D tensor to pool (indexed as ``X[row, col]``).
        pool_size: ``(p_h, p_w)`` height and width of the pooling window.
        mode: ``'max'`` for max pooling or ``'avg'`` for average pooling.

    Returns:
        A new float tensor holding the pooled values.

    Raises:
        ValueError: if ``mode`` is neither ``'max'`` nor ``'avg'``.
    """
    # Validate up front: an unknown mode used to fall through both branches
    # and silently return an all-zero tensor.
    if mode not in ('max', 'avg'):
        raise ValueError(f"mode must be 'max' or 'avg', got {mode!r}")
    p_h, p_w = pool_size
    Y = torch.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            window = X[i:i + p_h, j:j + p_w]
            Y[i, j] = window.max() if mode == 'max' else window.mean()
    return Y


X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
# Max pooling is the default mode; spell it out for symmetry with the next call.
print(pool2d(X, (2, 2), mode='max'))
# Output:
# tensor([[4., 5.],[7., 8.]])

# Average pooling over the same 2x2 window.
print(pool2d(X, (2, 2), mode='avg'))
# Output:
# tensor([[2., 3.],[5., 6.]])
# A 4x4 grid reshaped to (1, 1, 4, 4), the 4-D (N, C, H, W) layout that
# nn.MaxPool2d consumes.
X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))

# 3x3 window; stride defaults to the kernel size, so only one window fits.
pool2d = nn.MaxPool2d(3)
print(pool2d(X))
# Output:
# tensor([[[[10.]]]])

# Padding of 1 on every side plus stride 2 yields a 2x2 output.
pool2d = nn.MaxPool2d(3, padding=1, stride=2)
print(pool2d(X))
# Output:
# tensor([[[[ 5.,  7.],[13., 15.]]]])

# Rectangular 2x3 window with per-dimension padding and stride.
pool2d = nn.MaxPool2d((2, 3), padding=(0, 1), stride=(2, 3))
print(pool2d(X))
# Output:
# tensor([[[[ 5.,  7.],[13., 15.]]]])
# Stack a second channel (X + 1) along dim 1 to get a two-channel input.
X = torch.cat((X, X + 1), 1)
print(X)
print(X.shape)
# Output:
# tensor([[[[ 0.,  1.,  2.,  3.],[ 4.,  5.,  6.,  7.],[ 8.,  9., 10., 11.],[12., 13., 14., 15.]],
#          [[ 1.,  2.,  3.,  4.],[ 5.,  6.,  7.,  8.],[ 9., 10., 11., 12.],[13., 14., 15., 16.]]]])
# torch.Size([1, 2, 4, 4])

# Pooling runs over each channel separately, so the output keeps both channels.
pool2d = nn.MaxPool2d(3, padding=1, stride=2)
print(pool2d(X))
# Output:
# tensor([[[[ 5.,  7.],[13., 15.]],[[ 6.,  8.],[14., 16.]]]])