import  torch
from  torch import  nn
from d2l import torch as d2l


def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with kernel ``K``.

    Only "valid" positions are produced (no padding, stride 1):

        Y[i, j] = sum(X[i:i + h, j:j + w] * K)

    Args:
        X: 2-D input tensor of shape (H, W).
        K: 2-D kernel tensor of shape (h, w).

    Returns:
        Tensor of shape (H - h + 1, W - w + 1), allocated on ``X``'s device.
        Its dtype is the promoted floating dtype of ``X`` and ``K``. If that
        promotion is not floating point, it is torch's default float dtype.
    """
    h, w = K.shape
    # Keep float64/float16 inputs at their own precision instead of silently
    # truncating into a default-dtype buffer. Non-float inputs keep the
    # original behaviour of a default-float output.
    dtype = torch.result_type(X, K)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1),
                    dtype=dtype, device=X.device)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Element-wise product of the window under the kernel, then sum.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y


X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
print(corr2d(X, K))
# Expected output:
# tensor([[19., 25.],
#         [37., 43.]])
class Conv2D(nn.Module):
    """Minimal single-channel 2-D convolution layer built on ``corr2d``.

    Parameters:
        weight: learnable kernel of shape ``kernel_size``, initialised with
            ``torch.rand`` (uniform on [0, 1)).
        bias: learnable scalar bias of shape (1,), initialised to zero.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Cross-correlate ``x`` with the kernel and add the bias.

        ``corr2d`` indexes its input as 2-D (H, W), so ``x`` is presumably a
        single unbatched channel.
        """
        return corr2d(x, self.weight) + self.bias
# Toy 6x8 "image": all ones except a vertical band of zeros in columns 2..5,
# which creates two vertical edges (1 -> 0 and 0 -> 1).
X = torch.ones(6, 8)
X[:, 2:6] = 0
print(X)
# Printed:
# tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 0., 1., 1.]])

# With K = [[1, -1]], each output is X[i, j] - X[i, j + 1]: +1 at a 1 -> 0
# transition, -1 at a 0 -> 1 transition, 0 where neighbours are equal.
K = torch.tensor([[1.0, -1.0]])
Y = corr2d(X, K)
print(Y)
# Printed:
# tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
#         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
#         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
#         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
#         [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
#         [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])

# Transposing X turns the edges horizontal, so a 2x1 kernel (the transpose
# of K) is used to compare vertically adjacent rows instead.
Z = torch.tensor([[1.0], [-1.0]])
print(corr2d(X.t(), Z))
# Recover the edge kernel from (X, Y) by gradient descent, using a bias-free
# nn.Conv2d with 1 input channel, 1 output channel and a 1x2 kernel.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)

# nn.Conv2d expects (batch, channel, height, width): prepend two singleton
# dimensions to the 2-D X and Y.
X = X.reshape((1, 1, *X.shape))
Y = Y.reshape((1, 1, *Y.shape))
lr = 3e-2  # learning rate

for i in range(10):
    Y_hat = conv2d(X)
    l = (Y_hat - Y) ** 2
    loss = l.sum()
    conv2d.zero_grad()
    loss.backward()
    # Plain SGD step; no_grad() keeps the in-place update out of the autograd
    # graph (preferred over mutating ``.data``).
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {loss:.3f}')

print(conv2d.weight.detach().reshape(1, 2))
# Example output (exact numbers depend on nn.Conv2d's random initialisation):
# epoch 2, loss 3.004
# epoch 4, loss 0.793
# epoch 6, loss 0.251
# epoch 8, loss 0.091
# epoch 10, loss 0.035
# tensor([[ 1.0120, -0.9741]])