Layer Normalization 算法 和 Batch Normalization 算法的 python实现

Layer Normalization 算法

import torch
from torch import nnclass LN(nn.Module):# 初始化def __init__(self, normalized_shape,  # 在哪个维度上做LNeps:float = 1e-5, # 防止分母为0elementwise_affine:bool = True):  # 是否使用可学习的缩放因子和偏移因子super(LN, self).__init__()# 需要对哪个维度的特征做LN, torch.size查看维度self.normalized_shape = normalized_shape  # [c,w*h]self.eps = epsself.elementwise_affine = elementwise_affine# 构造可训练的缩放因子和偏置if self.elementwise_affine:  self.gain = nn.Parameter(torch.ones(normalized_shape))  # [c,w*h]self.bias = nn.Parameter(torch.zeros(normalized_shape))  # [c,w*h]# 前向传播def forward(self, x: torch.Tensor): # [b,c,w*h]# 需要做LN的维度和输入特征图对应维度的shape相同assert self.normalized_shape == x.shape[-len(self.normalized_shape):]  # [-2:]# 需要做LN的维度索引dims = [-(i+1) for i in range(len(self.normalized_shape))]  # [b,c,w*h]维度上取[-1,-2]维度,即[c,w*h]# 计算特征图对应维度的均值和方差mean = x.mean(dim=dims, keepdims=True)  # [b,1,1]mean_x2 = (x**2).mean(dim=dims, keepdims=True)  # [b,1,1]var = mean_x2 - mean**2  # [b,c,1,1]x_norm = (x-mean) / torch.sqrt(var+self.eps)  # [b,c,w*h]# 线性变换if self.elementwise_affine:x_norm = self.gain * x_norm + self.bias  # [b,c,w*h]return x_norm# ------------------------------- #
# 验证
# ------------------------------- #if __name__ == '__main__':x = torch.linspace(0, 23, 24, dtype=torch.float32)  # 构造输入层x = x.reshape([2,3,2*2])  # [b,c,w*h]# 实例化ln = LN(x.shape[1:])# 前向传播x = ln(x)print(x.shape)

运行结果:

torch.Size([2, 3, 4])

Batch Normalization 算法

import torch
from torch import nnclass MyBN:def __init__(self, momentum=0.01, eps=1e-5, feat_dim=2):"""初始化参数值:param momentum: 动量,用于计算每个batch均值和方差的滑动均值:param eps: 防止分母为0:param feat_dim: 特征维度"""# 均值和方差的滑动均值self._running_mean = np.zeros(shape=(feat_dim, ))self._running_var = np.ones(shape=(feat_dim, ))# 更新self._running_xxx时的动量self._momentum = momentum# 防止分母计算为0self._eps = eps# 对应Batch Norm中需要更新的beta和gamma,采用pytorch文档中的初始化值self._beta = np.zeros(shape=(feat_dim, ))self._gamma = np.ones(shape=(feat_dim, ))def batch_norm(self, x):"""BN向传播:param x: 数据:return: BN输出"""if self.training:x_mean = x.mean(axis=0)x_var = x.var(axis=0)# 对应running_mean的更新公式self._running_mean = (1-self._momentum)*x_mean + self._momentum*self._running_meanself._running_var = (1-self._momentum)*x_var + self._momentum*self._running_var# 对应论文中计算BN的公式x_hat = (x-x_mean)/np.sqrt(x_var+self._eps)else:x_hat = (x-self._running_mean)/np.sqrt(self._running_var+self._eps)return self._gamma*x_hat + self._beta

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/745883.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【大模型系列】图片生成(DDPM/VAE/StableDiffusion/ControlNet/LoRA)

文章目录 1 DDPM(UC Berkeley, 2020)1.1 如何使用DDPM生成图片1.2 如何训练网络1.3 模型原理 2 VAE:Auto-Encoding Variational Bayes(2022,Kingma)2.1 如何利用VAE进行图像增广2.2 如何训练VAE网络2.3 VAE原理2.3.1 Auto-Encoder2.3.2 VAE编码器2.3.3 VAE解码器 3 …

一文解读ISO26262安全标准:术语(二)

一文解读ISO26262安全标准:术语(二) 本文继续补充一些标准中的术语,方便后续文章内容的有效理解。 分支覆盖率 branch coverage 控制流分支覆盖的比率. 100%分支覆盖率意味着100%语句覆盖率,比如,一个if语句…

速盾cdn:服务器ip能直接加cdn吗

速盾CDN(Content Delivery Network)是一种通过在全球各地部署服务器节点,将网站或应用的内容分发到离用户最近的节点上,从而加速网站的访问速度和提升用户体验的技术方案。那么,服务器IP能否直接与CDN配合使用呢&#…

【UE5】持枪状态站立移动的动画混合空间

项目资源文末百度网盘自取 创建角色在持枪状态站立移动的动画混合空间 在BlendSpace文件夹中单击右键选择动画(Animation)中的混合空间(Blend Space) 选择SK_Female_Skeleton 命名为BS_RifleStand 打开 水平轴表示角色的方向,命名为Direction,方…

CASIA-HWDB手写体数据集gnt生成为png格式

👑一、数据集获取 1.1 官方链接获取gnt文件 http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.ziphttp://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip 1.2 百度网盘获取gnt文件 链接:https://pan.baidu.com/s/1pKa…

c++11 标准模板(STL)(std::locale)(五)用此 locale 的 collate 刻面以字典序比较两个字符串

用以封装文化差异的多态刻面的集合 std::locale 类型对象是不可变平面的不可变索引集。 C 输入/输出库的每个流对象与一个 std::locale 对象关联,并用其平面分析及格式化所有数据。另外, locale 对象与每个 std::basic_regex 对象关联。 locale 对象亦可…

Redis 的并发竞争问题是什么?如何解决这个问题?了解 Redis 事务的 CAS 方案吗?

目录 一、面试官心理分析 二、面试题剖析 一、面试官心理分析 这个也是线上非常常见的一个问题,就是多客户端同时并发写一个key,可能本来应该先到的数据后到了,导致数据版本错了;或者是多客户端同时获取一个 key,修改值之后再写回…

KKVIEW: 远程控制软件哪个好用

远程控制软件哪个好用 随着科技的发展和工作方式的改变,远程控制软件越来越受到人们的关注和需求。无论是在家中远程办公,还是技术支持人员为远程用户提供帮助,选择一款高效稳定的远程控制软件至关重要。在众多选择中,有几款远程…

51-30 World Model | 自动驾驶的世界模型:综述

24年3月,澳门大学和夏威夷大学联合发布的工作,World Models for Autonomous Driving: An Initial Survey。花时间反复看了几遍,刚开始觉得世界模型没用,空洞无序,根本不可能部署到实车上,后面逐渐相信&…

idea 导入项目

idea 导入项目并运行 导入设置设置 jdk查看maven 设置 导入 在项目首页 或者 file 选择 open, 然后选择项目根路径 设置 设置 jdk 查看maven 设置

[Python学习]变量存储逻辑和垃圾回收机制(GC)

一、引子 首先,我们从两个例子入手垃圾回收机制: a 1000 b 2000 a b a 100 b 200 a b 这两段代码的功能都是“把b变量的值赋值给a变量”,但是在Python的底层逻辑上,这两段代码的实现过程确是有所不同的。 过程:第一段代码…

基于java实用的音乐软件微信小程序的设计与实现【附项目源码】分享

基于实用的音乐软件微信小程序的设计与实现: 源码地址:https://download.csdn.net/download/weixin_43894652/88842586 一、引言 随着移动互联网的普及和微信小程序的兴起,音乐类小程序成为了用户随时随地享受音乐的重要工具。本需求文档旨在详细阐述一…

基于单片机的大棚温湿度控制系统设计

摘要:现阶段我国的科学技术方面得到了快速的发展,各项社会事业的发展也都进行了智能化技术的应用,农业事业智能化发展在现如今时代发展进程中变得越来越重要了,如果能够实现对大鹏的温度和湿度进行有效且稳定的控制,能够实现现代的农业大棚高水平的发展,这对于我国整体的…

Python 实现一个简单的中文分词处理?

在Python中,实现一个简单的中文分词处理,我们可以采用基于规则的方法,比如最大匹配法、最小匹配法、双向匹配法等。但更常见且效果更好的是使用现有的分词库,如jieba分词。   以下是使用jieba分词库进行中文分词的简单示例: 安装jieba 首先,你需要安装jieba库。如果你…

【图解物联网】第零章 前言

前言 一、本博文的写作背景 这个寒假(准确的说应该是上个学期),作者通过厚脸皮以及社牛的性格,抱住了一位老师的大腿,并且通过寒假期间突击补习,成功得到老师的赏识,得以进组进一步学习各…

c++简单使用

取消同步流是为了解决C有时遇到空格或回车&#xff08;不到\0&#xff09;就会停下的问题 #include<bits/stdc.h> using namespace std; int main() {//取消同步流ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);int a, b;cin >> a>> b;cout << …

拦截器和过滤器(原理区别)

目录 一、拦截器 拦截器是什么 拦截器的使用 拦截器的实现 导入依赖 实现HandlerInterceptor接口 注册拦截器 拦截器的生命周期 拦截器的执行顺序 拦截器的生命周期 多个拦截器的执行流程 拦截器的实际使用 拦截器实现日志记录 实现接口幂等性校验 拦截器的性能…

Python import 跟 Java import 有什么区别?

你好&#xff0c;我是 shengjk1&#xff0c;多年大厂经验&#xff0c;努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注&#xff01;你会有如下收益&#xff1a; 了解大厂经验拥有和大厂相匹配的技术等 希望看什么&#xff0c;评论或者私信告诉我&#xff01; 文章目录 一…

STL——map set

文章将解决一下几个问题&#xff1a; 1.什么是set 2.什么是map 3.set应用场景 4.map应用场景 序列式容器和关联式容器 数据结构有序列式容器和关联式容器&#xff0c;序列式容器一般有vector,list,deque…&#xff0c;但关联式容器中就有map&#xff0c;关联式容器也是用来存…

23.2 微服务基础实战

23.2 微服务基础实战 课程安排1. **************************************************************************************** 课程安排 1. ****************************************************************************************