#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>

using std::cout;
using std::endl;

/// Fully connected building block: owns a Linear layer and a BatchNorm1d that
/// normalises the Linear layer's outputs, plus a ReLU non-linearity.
/// The exact order in which ReLU and batch norm are applied is defined by
/// forward(). Wrapped into the `LinearBnRelu` holder via TORCH_MODULE below.
class LinearBnReluImpl : public torch::nn::Module {
public:
    /// @param input_features  number of input features fed to the Linear layer
    /// @param out_features    number of Linear outputs (also the BatchNorm1d width)
    LinearBnReluImpl(int input_features, int out_features);

    /// Runs the block on `x` and returns the activated features.
    torch::Tensor forward(torch::Tensor x);

private:
    // Submodule holders start empty and are registered in the constructor.
    torch::nn::Linear ln{nullptr};
    torch::nn::BatchNorm1d bn{nullptr};
};
TORCH_MODULE ( LinearBnRelu) ; inline  torch:: nn:: Conv2dOptions conv_options ( int64_t  in_planes,  int64_t  out_planes,  int64_t  kernel_size, int64_t  stride =  1 ,  int64_t  padding =  0 ,  bool  with_bias =  false 
) 
{ torch:: nn:: Conv2dOptions conv_options =  torch:: nn:: Conv2dOptions ( in_planes,  out_planes,  kernel_size) ; conv_options. stride ( stride) ; conv_options. padding ( padding) ; conv_options. bias ( with_bias) ; return  conv_options; 
} class  ConvReluBnImpl  :  public  torch:: nn:: Module { 
private : torch:: nn:: Conv2d conv{  nullptr  } ; torch:: nn:: BatchNorm2d bn{  nullptr  } ; public : ConvReluBnImpl ( int  input_channel,  int  output_channel,  int  kernel_size,  int  stride,  int  padding= 1 ) ; torch:: Tensor forward ( torch:: Tensor x) ; 
} ; 
TORCH_MODULE ( ConvReluBn) ; class  MLP  :  public  torch:: nn:: Module { 
private : int  mid_features[ 3 ]  =  {  32 ,  64 ,  128  } ; LinearBnRelu ln1{  nullptr  } ; LinearBnRelu ln2{  nullptr  } ; LinearBnRelu ln3{  nullptr  } ; torch:: nn:: Linear out_ln{  nullptr  } ; public : MLP ( int  in_features,  int  out_features) ; torch:: Tensor forward ( torch:: Tensor x) ; 
} ; class  plainCNN  :  public  torch:: nn:: Module { 
private : int  mid_channels[ 3 ] {  32 , 64 , 128  } ; ConvReluBn conv1{  nullptr  } ; ConvReluBn down1{  nullptr  } ; ConvReluBn conv2{  nullptr  } ; ConvReluBn down2{  nullptr  } ; ConvReluBn conv3{  nullptr  } ; ConvReluBn down3{  nullptr  } ; torch:: nn:: Conv2d out_conv{  nullptr  } ; public : plainCNN ( int  in_channels,  int  out_channels) ; torch:: Tensor forward ( torch:: Tensor x) ; 
} ; int  main ( ) 
{ plainCNN c ( 3 ,  2 ) ; auto  x =  torch:: rand ( {  1 , 3 , 224 , 224  } ,  torch:: kFloat) ; auto  a =  c. forward ( x) ; cout << "[in Main]: " <<  a. sizes ( )  <<  endl; return  0 ; 
} LinearBnReluImpl :: LinearBnReluImpl ( int  input_features,  int  out_features) 
{ ln =  register_module ( "ln" ,  torch:: nn:: Linear ( torch:: nn:: LinearOptions ( input_features,  out_features) ) ) ; bn =  register_module ( "bn" ,  torch:: nn:: BatchNorm1d ( out_features) ) ; 
} torch:: Tensor LinearBnReluImpl :: forward ( torch:: Tensor x) 
{ x =  torch:: relu ( ln-> forward ( x) ) ; x =  bn ( x) ; return  x; 
} ConvReluBnImpl :: ConvReluBnImpl ( int  input_channel,  int  output_channel,  int  kernel_size,  int  stride,  int  padding) 
{ conv =  register_module ( "conv" ,  torch:: nn:: Conv2d ( conv_options ( input_channel,  output_channel,  kernel_size,  stride,  padding) ) ) ; bn =  register_module ( "bn" ,  torch:: nn:: BatchNorm2d ( output_channel) ) ; 
} torch:: Tensor ConvReluBnImpl :: forward ( torch:: Tensor x) 
{ x =  torch:: relu ( conv-> forward ( x) ) ; x =  bn ( x) ; return  x; 
} MLP :: MLP ( int  in_features,  int  out_features) 
{ ln1 =  LinearBnRelu ( in_features,  mid_features[ 0 ] ) ; ln2 =  LinearBnRelu ( mid_features[ 0 ] ,  mid_features[ 1 ] ) ; ln3 =  LinearBnRelu ( mid_features[ 1 ] ,  mid_features[ 2 ] ) ; out_ln =  torch:: nn:: Linear ( mid_features[ 2 ] ,  out_features) ; ln1 =  register_module ( "ln1" ,  ln1) ; ln2 =  register_module ( "ln2" ,  ln2) ; ln3 =  register_module ( "ln3" ,  ln3) ; out_ln =  register_module ( "out_ln" ,  out_ln) ; 
} torch:: Tensor MLP :: forward ( torch:: Tensor x) 
{ x =  ln1-> forward ( x) ; x =  ln2-> forward ( x) ; x =  ln3-> forward ( x) ; x =  out_ln-> forward ( x) ; return  x; 
} plainCNN:: plainCNN ( int  in_channels,  int  out_channels) 
{ conv1 =  ConvReluBn ( in_channels,  mid_channels[ 0 ] ,  3 ,  1 ) ; down1 =  ConvReluBn ( mid_channels[ 0 ] ,  mid_channels[ 0 ] ,  3 ,  2 ) ; conv2 =  ConvReluBn ( mid_channels[ 0 ] ,  mid_channels[ 1 ] ,  3 , 1 ) ; down2 =  ConvReluBn ( mid_channels[ 1 ] ,  mid_channels[ 1 ] ,  3 ,  2 ) ; conv3 =  ConvReluBn ( mid_channels[ 1 ] ,  mid_channels[ 2 ] ,  3 , 1 ) ; down3 =  ConvReluBn ( mid_channels[ 2 ] ,  mid_channels[ 2 ] ,  3 ,  2 ) ; out_conv =  torch:: nn:: Conv2d ( conv_options ( mid_channels[ 2 ] ,  out_channels,  3 ) ) ; conv1 =  register_module ( "conv1" ,  conv1) ; down1 =  register_module ( "down1" ,  down1) ; conv2 =  register_module ( "conv2" ,  conv2) ; down2 =  register_module ( "down2" ,  down2) ; conv3 =  register_module ( "conv3" ,  conv3) ; down3 =  register_module ( "down3" ,  down3) ; out_conv =  register_module ( "out_conv" ,  out_conv) ; 
} torch:: Tensor plainCNN:: forward ( torch:: Tensor x) 
{ x =  conv1-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; x =  down1-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; x =  conv2-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; x =  down2-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; x =  conv3-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; x =  down3-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; x =  out_conv-> forward ( x) ; cout <<  x. sizes ( )  <<  endl; return  x; 
}