Llama改进之——均方根层归一化RMSNorm

引言

在学习完GPT2之后,从本文开始进入Llama模型系列。

本文介绍Llama模型的改进之RMSNorm(均方根层归一化)。它是由Root Mean Square Layer Normalization论文提出来的,可以参阅其论文笔记1

LayerNorm

层归一化(LayerNorm)对Transformer等模型来说非常重要,它可以帮助稳定训练并提升模型收敛性。LayerNorm针对一个样本所有特征计算均值和方差,然后使用这些来对样本进行归一化:
μ = 1 H ∑ i = 1 H x i , σ = 1 H ∑ i = 1 H ( x i − μ ) 2 , N ( x ) = x − μ σ , h = g ⊙ N ( x ) + b (1) \mu = \frac{1}{H}\sum_{i=1}^H x_i,\quad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^H (x_i - \mu)^2}, \quad N(\pmb x) = \frac{\pmb x-\mu}{\sigma},\quad \pmb h = \pmb g \,\odot N(\pmb x) + \pmb b \tag 1 μ=H1i=1Hxi,σ=H1i=1H(xiμ)2 ,N(x)=σxμ,h=gN(x)+b(1)
这里 x = ( x 1 , x 2 , ⋯ , x H ) \pmb x = (x_1,x_2,\cdots, x_H) x=(x1,x2,,xH)表示某个时间步LN层的输入向量表示,向量维度为 H H H h \pmb h h实LN层的输出; g , b \pmb g,\pmb b g,b实两个可学习的参数。

为什么层归一化有用?一些解释如下2

  1. 减少内部协变量偏移(Internal Covariate Shift): 内部协变量偏移是指在深度神经网络的训练过程中,每一层输入的分布会发生变化,导致网络的训练变得困难。层归一化通过对每一层的输入进行归一化处理,可以减少内部协变量偏移,使得每一层的输入分布更加稳定。
  2. 稳定化梯度: 层归一化有助于保持每一层输出的均值和方差稳定,从而使得梯度的传播更加稳定。这有助于减少梯度消失或梯度爆炸的问题,提高梯度在网络中的流动性,加快训练速度。
  3. 更好的参数初始化和学习率调整: 通过层归一化,每一层的输入分布被归一化到均值为0、方差为1的标准正态分布,这有助于更好地初始化网络参数和调整学习率。参数初始化与学习率调整的稳定性对模型的训练效果至关重要。
  4. 增强模型的泛化能力: 层归一化可以减少网络对训练数据分布的依赖,降低了过拟合的风险,从而提高模型的泛化能力。稳定的输入分布有助于模型更好地适应不同数据集和任务。

RMSNorm

虽然LayerNorm很好,但是它每次需要计算均值和方差。RMSNorm的思想就是移除(1)式中 μ \mu μ的计算部分1
x ˉ i = x i RMS ( x ) g i RMS ( x ) = 1 H ∑ i = 1 H x i 2 (2) \bar x_i = \frac{x_i }{ \text{RMS}(\pmb x)} g_i \quad \text{RMS}(\pmb x) =\sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2} \tag 2 xˉi=RMS(x)xigiRMS(x)=H1i=1Hxi2 (2)

同时在实现也可以移除平移偏置 b \pmb b b

单看(2)式的话,相当于仅使用 x \pmb x x的均方根来对输入进行归一化,它简化了层归一化的计算,变得更加高效,同时还有可能带来性能上的提升。

实现

RMSNorm的实现很简单:

import torch
import torch.nn as nn
from torch import Tensorclass RMSNorm(nn.Module):def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(hidden_size))def _norm(self, hidden_states: Tensor) -> Tensor:variance = hidden_states.pow(2).mean(-1, keepdim=True)return hidden_states * torch.rsqrt(variance + self.eps)def forward(self, hidden_states: Tensor) -> Tensor:return self.weight * self._norm(hidden_states.float()).type_as(hidden_states)

torch.rsqrttorch.sqrt的倒数;eps是一个很小的数,防止除零;hidden_states.float()确保了标准差计算的精确度和稳定性,然后在forward方法中,通过.type_as(hidden_states)将结果转换回原来的数据类型,以保持与输入张量相同的数据类型,使得归一化处理后的结果与输入数据类型一致。

下面通过一个简单的网络来测试一下:

import torch
import torch.nn as nn
from torch import Tensorclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.linear = nn.Linear(in_features=10, out_features=5)self.rmsnorm = RMSNorm(hidden_size=5)def forward(self, x):x = self.linear(x)x = self.rmsnorm(x)return xnet = SimpleNet()input_data = torch.randn(2, 10)  # 2个样本,每个样本包含10个特征output = net(input_data)print("Input Shape:", input_data.shape)
print("Output Shape:", output.shape)
Input Shape: torch.Size([2, 10])
Output Shape: torch.Size([2, 5])

参考


  1. [论文笔记]Root Mean Square Layer Normalization ↩︎ ↩︎

  2. 批归一化和层归一化 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/827668.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android—— log的记忆

一、关键log 1.Java的 backtrace(堆栈log) 上述是一个空指针异常,问题出现在sgtc.settings,所以属于客户UI问题。 2.WindowManager(管理屏幕上的窗口和视图层次结构) 3.ActivityManager(管理应用程序生命周期和任务栈) 4.wifi操作 (1) 连接wifi&#…

开源大模型Llama3,堪比GPT-4。手把手本地安装,纯小白可操作,不需要编程经验,国内可下载,可视化使用。

最近最劲爆科技动态,Meta开源Llama3模型,最强开源模型。 Llama3发布后,扎克伯格亲自给媒体表示“要超越所有人,做最领先AI”。 吴恩达等一众大佬表示祝贺。 在线体验地址:https://www.meta.ai/ 不过国内在线体验基本…

OpenWRT磁盘扩容(PVE虚拟机方案)

官方扩容指导文档 PVE给虚拟机磁盘扩容 给虚拟机磁盘扩容,选中OpenWRT的硬盘,随后选择调整大小 输入增量大小,即增加多少磁盘空间给硬盘。这里我选择增加4G 进入OpenWRT控制台界面安装一些linux常用查看磁盘的工具(也可以通过网…

EasyPoi实现简单的Excel导出、导入

EasyPoi实现Excel导出、导入 下面这种方式不需要模板&#xff0c;更加方便但是不能进行复杂的导出 一、引入依赖 <dependency><groupId>cn.afterturn</groupId><artifactId>easypoi-base</artifactId><version>4.4.0</version><…

2024商业地产五一劳动节健康大会朋克养生市集活动策划方案

2024商业地产五一劳动节健康大会朋克养生市集&#xff08;带薪健康 快乐打工主题&#xff09;活动策划方案 活动策划信息&#xff1a; 方案页码&#xff1a;53页 文件格式&#xff1a;PPT 方案简介&#xff1a; 打工不养生 赚钱养医生 期待已久的五一假期&#xff0c; …

mysql download 2024

好久没在官网下载 mysql server 安装包。今天想下载发现&#xff1a; 我访问mysql官网的速度好慢啊。mysql server 的下载页面在哪里啊&#xff0c;一下两下找不到。 最后&#xff0c;慢慢悠悠终于找到了下载页面&#xff0c;如下&#xff1a; https://dev.mysql.com/downlo…

什么是全局特征,什么又是局部特征

全局特征和局部特征是用来描述数据中信息的两种不同方式&#xff0c;特别是在图像处理、模式识别和机器学习领域中经常被提到。它们有助于理解和分析数据的不同层面&#xff1a; 全局特征&#xff08;Global Features&#xff09; 全局特征描述了整个数据集的整体属性。在图像…

windows和linux服务器等保测评加固方法

服务器加固是通过各种方法增强服务器安全性的过程。保护操作系统免受黑客、破解者和攻击者的侵害。网络安全防护的目标是保密性、完整性、可用性、可控制性、不可否认性。 一、window服务器等保加固 以win2012和win2008 为例&#xff1a; &#xff08;win2008&#xff09; …

【UML建模】用例图

1 参与者 参与者的概念&#xff1a; 指系统以外的、需要使用系统或与系统交互的外部实体 可以分为&#xff1a;人、外部设备、外部系统 参与者的图形符号&#xff1a; 例 3.1 在一个银行业务系统中&#xff0c;可能会有以下参与者 客户 &#xff1a;在银行业务系统中办理…

React【Day2】

React表单控制 受控绑定 概念&#xff1a;使用React组件的状态&#xff08;useState&#xff09;控制表单的状态 双向绑定 MVVM 报错记录&#xff1a; 错误代码&#xff1a; import { useState } from "react";const App () > {const [value, setValue] useS…

TLV61048非同步升压BOOST转换器输入电压2.6-5.5V输出电流4A输出电压最高15V

推荐原因&#xff1a; 输入电压较低&#xff0c;输出电流可达3.5A SOT23-6封装 批量价格约0.70元 TLV61048引脚 TLV61048引脚功能 7 详细说明 7.1 概述 TLV61048是一款非同步升压转换器&#xff0c;支持高达 15 V 的输出电压和输入范围从 2.61 V 到 5.5 V。该TLV61048集成了…

LlamaIndex 加 Ollama 实现 Agent

AI Agent 是 AIGC 落地实现的场景之一&#xff0c;与 RAG 不同&#xff0c;RAG 是对数据的扩充&#xff0c;是模型可以学习到新数据或者本地私有数据。AI Agent 是自己推理&#xff0c;自己做&#xff0c;例如你对 AI Agent 说我要知道今天上海的天气怎么样&#xff0c;由于 AI…

serverLess

第一步 安装依赖 npm install serverless-devs/s g 第二步 配置秘钥&#xff1a; 第三步 执行终端 执行命令 s config add 选择 alibaba cloud &#xff08;alibaba&#xff09; 把对应的ID secret填写&#xff0c;第三个别名可以随便写&#xff1a; serverLess 查看是…

Kettle的安装及简单使用

Kettle的安装及简单使用 文章目录 Kettle的安装及简单使用一、kettle概述二、kettle安装部署和使用Windows下安装&#xff08; 1 &#xff09;概述 案例 1 &#xff1a;MySQL to MySQL主界面&#xff1a;**在kettle中新建转换--->输入--->表输入-->表输入双击**建立连…

js进行数据移除性能比较(splice,map)

当使用 splice() 方法处理大量数据时&#xff0c;确实会遇到性能问题&#xff0c;因为它涉及到移动数组中的元素&#xff0c;导致操作的时间复杂度为 O(n)。对于大量数据&#xff0c;频繁的插入和删除可能会导致性能下降。 1、设置数组数据为10000&#xff0c;使用splice移除数…

软件项目经理需要具备这 11 个能力

当前软件开发技术更新换代越来越快&#xff0c;各种项目实施管理思想也日新月异&#xff0c;作为一个软件项目经理&#xff0c;需要具备这 11 种能力&#xff1a; 1. 项目管理能力 了解项目管理的基本原则和方法&#xff0c;包括制定项目计划、资源分配、风险管理、问题解决和…

Python练习03

题目 解题思路 Demo58 通过字符串切片来进行反转操作 def _reverse():"""这是一个反转整数的函数"""num input("请输入想要反转的整数")print(num[::-1]) 运行结果 Demo61 首先制作一个判断边长的函数&#xff0c;通过三角形两边…

4.23学习总结

一.NIO(一) (一).简介: NIO 是 Java SE 1.4 引入的一组新的 I/O 相关的 API&#xff0c;它提供了非阻塞式 I/O、选择器、通道、缓冲区等新的概念和机制。相比与传统的 I/O 多出的 N 不是单纯的 New&#xff0c;更多的是代表了 Non-blocking 非阻塞&#xff0c;NIO具有更高的并…

【Linux高性能服务器编程】两种高性能并发模式剖析——半同步/半异步模式

hello &#xff01;大家好呀&#xff01; 欢迎大家来到我的Linux高性能服务器编程系列之两种高性能并发模式介绍&#xff0c;在这篇文章中&#xff0c;你将会学习到高效的创建自己的高性能服务器&#xff0c;并且我会给出源码进行剖析&#xff0c;以及手绘UML图来帮助大家来理解…

《HCIP-openEuler实验指导手册》1.4 Apache MPM工作模式调整

MPM介绍 二、配置步骤 查看MPM当前工作模式 方法一&#xff1a; httpd -M | grep mpm方法二&#xff1a; 浏览器访问&#xff1a;http://IP:端口/server-status 方法三&#xff1a; cat /etc/httpd/conf.modules.d/00-mpm.conf查看 LoadModule mpm_event_module modules/mo…