BatchNorm speeds up model convergence and alleviates the vanishing-gradient problem, which makes it one of the most commonly used techniques in deep learning.
I recently studied how BatchNorm works in detail, so I wanted to implement it by hand to deepen my understanding.
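As a quick recap of the principle: during training, BatchNorm normalizes each channel with the mini-batch mean and variance and then applies a learnable scale and shift (gamma and beta); at inference time, running (moving-average) estimates of the mean and variance are used instead:

$$\hat{x} = \frac{x - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$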
The code is as follows:
import torch
import torch.nn as nn


class MyBatchNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Learnable scale and shift, gamma and beta
        # (note: nn.BatchNorm initializes these to ones and zeros rather than randn)
        self.gamma = nn.Parameter(data=torch.randn(dim))
        self.beta = nn.Parameter(data=torch.randn(dim))
        # Running mean and variance, used at inference time
        self.mean_whole = torch.zeros(dim)
        self.var_whole = torch.zeros(dim)
        # Moving-average coefficient (equivalent to momentum = 0.01)
        self.lba = 0.99
        # Small constant to avoid division by zero
        self.eps = 1e-7

    def forward(self, x):
        # For a 4D input (b, c, h, w), flatten the spatial dims to (b, c, h*w)
        orig_shape = x.shape
        if x.dim() == 4:
            x = x.reshape(x.size(0), x.size(1), -1)
        assert x.dim() == 3
        if self.training:
            # Per-channel mean and variance over the batch and spatial dims
            # (b, c, d) -> (1, c, 1)
            mean_batch = torch.mean(x, dim=[0, 2], keepdim=True)
            var_batch = torch.var(x, dim=[0, 2], keepdim=True, unbiased=False)
            # Update the running statistics with a moving average; detach so no
            # autograd graph is kept, reshape to (c,) to match the buffers, and
            # apply the n/(n-1) correction so the running variance is unbiased
            n = x.numel() / x.size(1)
            self.mean_whole = self.lba * self.mean_whole + (1 - self.lba) * mean_batch.detach().reshape(-1)
            self.var_whole = self.lba * self.var_whole + (1 - self.lba) * var_batch.detach().reshape(-1) * n / (n - 1)
            # Normalize with the mini-batch statistics
            x = (x - mean_batch) / torch.sqrt(var_batch + self.eps)
        else:
            # Normalize with the running statistics
            x = (x - self.mean_whole[None, ..., None]) / torch.sqrt(self.var_whole[None, ..., None] + self.eps)
        # Scale and shift
        x = x * self.gamma[None, ..., None] + self.beta[None, ..., None]
        # Restore the original shape for 4D inputs
        return x.reshape(orig_shape)


x = torch.randn((2, 3, 4))
batch_norm = MyBatchNorm(dim=3)
batch_norm = batch_norm.train()
b = batch_norm(x)
print(b.shape)
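As a quick sanity check (this part is my own addition, not from the referenced posts), the training-mode output can be compared against PyTorch's built-in nn.BatchNorm1d. A minimal sketch:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn((2, 3, 4))

mine = MyBatchNorm(dim=3).train()
# Use the same eps, and copy the affine parameters over, because MyBatchNorm
# initializes gamma/beta randomly while BatchNorm1d starts from ones/zeros
ref = nn.BatchNorm1d(num_features=3, eps=mine.eps).train()
with torch.no_grad():
    ref.weight.copy_(mine.gamma)
    ref.bias.copy_(mine.beta)

# Training-mode outputs should match up to floating-point tolerance
print(torch.allclose(mine(x), ref(x), atol=1e-5))  # expected: True

In eval mode the two modules would drift apart unless the running statistics are also synchronized, since nn.BatchNorm1d uses momentum = 0.1 by default while this implementation effectively uses 0.01.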
References:
1. Principle: https://zhuanlan.zhihu.com/p/34879333
2. Code: https://zhuanlan.zhihu.com/p/337732517