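# Training mode (pytorch_bn.train(mode=True)): BN normalizes with the mean and
# variance of the current batch. If track_running_stats=True at that point,
# running_mean and running_var are updated, but running_mean and running_var
# are not used for normalization during training.
# Eval mode (pytorch_bn.train(mode=False)): BN normalizes with its stored
# running_mean and running_var. Whether track_running_stats is True or False,
# running_mean and running_var are not updated.
'''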
Author: Chae Luv
Date: 2022-08-17 22:40:13
LastEditors: Chae Luv
LastEditTime: 2022-08-17 23:15:22
FilePath: /re-record-audio-watermark/10-base_model/test_bn.py
Description: Copyright (c) 2022 by Chae Luv/USTC, All Rights Reserved. 
''' 
import  torch
import torch.nn as nn


def create_inputs(batch_size=8, num_channels=3, height=20, width=20):
    """Return a random NCHW batch drawn from a standard normal distribution.

    The defaults reproduce the (8, 3, 20, 20) batch this script has always
    used; ``num_channels`` must match ``num_features`` of the BatchNorm
    layer under test.
    """
    return torch.randn(batch_size, num_channels, height, width)


def simulated_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    """Re-implement the BatchNorm forward pass with plain tensor operations.

    Args:
        x: input of shape (N, C, *); statistics are per channel (dim 1),
            reduced over every other dimension.
        bn_weight: per-channel scale of shape (C,), or None to skip scaling
            (as for a BatchNorm built with ``affine=False``).
        bn_bias: per-channel shift of shape (C,), or None to skip shifting.
        eps: value added to the variance before the square root.
        mean_val: per-channel mean of shape (C,). When None, the mean of the
            current batch is used (training-mode behaviour).
        var_val: per-channel variance of shape (C,). When None, the biased
            variance of the current batch is used (training-mode behaviour).

    Returns:
        Tuple ``(mean_val, var_val, out)``: the statistics actually used
        (the passed-in tensors themselves when provided) and the normalized
        output, which has the same shape as ``x``.
    """
    # Reduce over the batch dimension and all trailing (spatial) dimensions.
    reduce_dims = [0] + list(range(2, x.dim()))
    if mean_val is None:
        mean_val = x.mean(reduce_dims)
    if var_val is None:
        # Normalization uses the biased (population) variance.
        var_val = x.var(reduce_dims, unbiased=False)
    # Shape (1, C, 1, ..., 1) so per-channel vectors broadcast against x.
    channel_shape = [1, -1] + [1] * (x.dim() - 2)
    x = x - mean_val.reshape(channel_shape)
    x = x / torch.sqrt(var_val.reshape(channel_shape) + eps)
    if bn_weight is not None:
        x = x * bn_weight.reshape(channel_shape)
    if bn_bias is not None:
        x = x + bn_bias.reshape(channel_shape)
    return mean_val, var_val, x


# momentum=None: when running stats are updated, PyTorch uses a cumulative
# moving average instead of an exponential one.
pytorch_bn = nn.BatchNorm2d(num_features=3, momentum=None)
# Snapshot the module's own running statistics before any forward pass so the
# checks below compare against what the layer actually stores.
running_mean = pytorch_bn.running_mean.clone()
running_var = pytorch_bn.running_var.clone()

# Part 1: eval mode -- normalization uses the stored running statistics.
pytorch_bn.train(mode=False)
# Fixed seed so the run (and every assertion below) is reproducible.
torch.manual_seed(0)
test_input = create_inputs()
print(f'pytorch_bn running_mean is  {pytorch_bn.running_mean} ')
print(f'pytorch_bn running_var is  {pytorch_bn.running_var} ')
bn_outputs = pytorch_bn(test_input)
print(f'Now pytorch_bn running_mean is  {pytorch_bn.running_mean} ')
print(f'Now pytorch_bn running_var is  {pytorch_bn.running_var} ')
# The eval-mode forward pass must leave the running statistics untouched.
assert torch.equal(pytorch_bn.running_mean, running_mean)
assert torch.equal(pytorch_bn.running_var, running_var)
_, _, simulated_outputs = simulated_bn_forward(
    test_input, pytorch_bn.weight, pytorch_bn.bias, pytorch_bn.eps,
    running_mean, running_var)
# Small atol: float32 outputs near zero leave the default atol (1e-8) little
# room for implementation-dependent rounding.
assert torch.allclose(simulated_outputs, bn_outputs, atol=1e-6)

# Part 2: training mode with track_running_stats=False -- normalization uses
# the current batch's statistics and the buffers must not change.
pytorch_bn.train(mode=True)
pytorch_bn.track_running_stats = False
bn_outputs_notrack = pytorch_bn(test_input)
assert torch.equal(pytorch_bn.running_mean, running_mean)
assert torch.equal(pytorch_bn.running_var, running_var)
_, _, simulated_outputs_notrack = simulated_bn_forward(
    test_input, pytorch_bn.weight, pytorch_bn.bias, pytorch_bn.eps)
# Report the largest absolute deviation; a signed sum lets errors cancel out.
print((simulated_outputs_notrack - bn_outputs_notrack).abs().max())
assert torch.allclose(simulated_outputs_notrack, bn_outputs_notrack, atol=1e-6)
# Batch statistics differ from the stored running statistics, so the two
# modes must produce different outputs for the same input.
assert not torch.allclose(bn_outputs, bn_outputs_notrack)