大模型都在用的:旋转位置编码

写在前面

        这篇文章提到了绝对位置编码和相对位置编码,但是他们都有局限性,比如绝对位置编码不能直接表征token的相对位置关系;相对位置编码过于复杂,影响效率。于是诞生了一种用绝对位置编码的方式实现相对位置编码的编码方式——旋转位置编码(Rotary Position Embedding, RoPE),兼顾效率和相对位置关系。

        RoPE的核心思想是通过旋转的方式将位置信息编码到每个维度,从而使得模型能够捕捉到序列中元素的相对位置信息。现在已经在很多大模型证明了其有效性,比如ChatGLM、LLaMA等。

一、RoPE的优点

1.真正的旋转位置编码

        Transformer的原版位置编码也使用了三角函数,但它生成的是每个位置的绝对编码,三角函数的主要用途是生成具有可区分性的周期性模式,也没有应用旋转变换的概念,因此属于绝对位置编码。同时原版的编码使用加法,在多层传递后导致位置信息的稀释,如下图 (没想到这张图也有被当做反面典型的时候吧):

        RoPE不是简单的加法,而是通过复数乘法实现旋转变换,这种旋转是将位置信息融入到token表示中的关键机制。RoPE在实现过程中通过乘法操作融入位置信息,与模型中的Q和K深度融合,将旋转操作真正植入Attention机制内部,强化了位置编码信息的作用

2.更好的相对位置信息编码

        注意力机制通过计算Embedding的内积来确定它们之间的关系强度。

        使用RoPE时,两个位置的编码通过旋转变换后的内积,自然地包含了它们之间的相对位置信息。这是因为旋转操作保持了内积的性质,使得内积计算不仅反映了token的内容相似性,还反映了它们的位置关系。

3.更适用于多维输入

        这点很有意思,传统的Transformer位置编码主要针对一维序列,如文本序列。然而,在某些任务中,输入可能是二维或更高维的数据,如图像或视频数据。旋转位置编码可以更灵活地应用于多维输入数据,通过对不同维度的位置信息进行编码,使得模型能够更好地理解多维数据中的位置关系。

4. 更善于处理长序列

        RoPE可以减少位置信息的损失。在深层网络中,RoPE通过乘法操作融入位置信息,乘法操作有助于在深层网络中保持位置信息的完整性。在处理一个长文本时,RoPE通过在每一层的自注意力计算中使用旋转变换,确保了位置信息能够被有效保留和利用,即使是在模型的较深层次。

二、公式

        既然旋转的位置编码有这么多优点,那怎么实现位置编码的旋转呢,其实网上有很多介绍的文章。大概意思就是复数可以通过乘以e的幂来旋转角度,其中幂就是角度,再结合欧拉公式推出三角函数的表达,大致流程如下。

        欧拉公式:

e^{i\theta }=cos\theta +i\cdot sin\theta        (1)

        复数旋转角度θ:

(x+y\cdot i)e^{i\theta }                (2)

        将(1)带入(2):

(x+y\cdot i)e^{i\theta }=(xcos\theta -ysin\theta )+i(xsin\theta +ycos\theta )        (3)

        这块东西苏剑林老师已经从数学的角度进行过很深入的推导,这里的融合式部分,我就不班门弄斧了。我今天提供一种朴素的思考过程,从代码实现的角度思考如何进行旋转

        众所周知,一维向量是不能旋转的,那我们就旋转一个[2,d]的二维向量q,并且设x=q[0],y=q[1]即:

x=[q_0,q_1...,q_{d/2-1}],y=[q_{d/2},q_{d/2+1}...,q_{d-1}]        (4)

        要旋转q很容易,乘以旋转矩阵就可以了,如果我们要旋转角度θ:

R(\theta )=[x , y]\cdot \begin{bmatrix} cos(\theta ) & -sin(\theta )\\ sin(\theta ) & cos(\theta ) \end{bmatrix}                (5)

        展开之后,结果如下:

\begin{bmatrix} q_0 \\ ... \\ q_{d/2-1}\\ q_{d/2}\\ ...\\ q_{d-1} \end{bmatrix} \bigotimes \begin{bmatrix} cos\theta _0 \\ ... \\ cos\theta _{d-2}\\ cos\theta _{0}\\ ...\\ cos\theta _{d-2} \end{bmatrix} + \begin{bmatrix} -q_{d/2} \\ ... \\ -q_{d-1}\\ q_{0}\\ ...\\ q_{d/2-1} \end{bmatrix} \bigotimes \begin{bmatrix} cos\theta _0 \\ ... \\ cos\theta _{d-2}\\ cos\theta _{0}\\ ...\\ cos\theta _{d-2} \end{bmatrix}        (6)

        上面的\theta = \frac{pos}{10000^{\frac{2i}{d_{model}}}},很眼熟吧,就是沿用了transformer的机制,这里有详细的介绍。

        而且大家看到字母q也大概能猜到,这就是Attention中的Q,同样的操作也可以对K使用。经过上述操作,其实已经以旋转的方式将位置编码融合到Attention机制内部。

        下面就是根据式子(6)的代码实现了。这里提前说一句,ChatGLM的Q和K的形状都是[b,1,32,64],其中b是token_ids的长度;32是multi-head的个数;64将被拆成两部分,每部分32,也就是上面的x,y,下面开始代码实现部分。

三、代码实现

        我们以ChatGLM的代码为例,展示一下RoPE的使用,以下代码都在modeling_chatglm.py文件中,一条训练数据:

{"context": "你好", "target": "你好,我是大白话"}

1.字符串转换成token_ids

[ 5,  74874, 130001, 130004,  5,  74874, 6,  65806,  63850, 95351, 130005]

2.计算position_ids

        根据上面的token_ids计算出position_ids:

[[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],[0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]]

        解释一下position_ids:第一行表示序列中每个元素的全局位置,第一个“2”表明context结束了,target要开始了,后面所有的2都是target部分;第二行则细化到更具体的局部位置,从1开始表征整个target的内容,这样用两个维度的编码很优雅的体现了context和target,这种层次化处理对于理解上下文非常重要。

        代码如下:

    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):"""根据token_ids生成position_ids:param input_ids: 这里是[[ 5, 74874, 130001, 130004, 5, 74874, 6, 65806, 63850, 95351, 130005]]:param mask_positions: 2 输出的第1维mask掉几位,即这一位及其前面都是0,后面是1,2...:param device::param use_gmasks::return: [[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],[0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]]"""batch_size, seq_length = input_ids.shapeif use_gmasks is None:use_gmasks = [False] * batch_sizecontext_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]if self.position_encoding_2d:# 会走这一分支position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)for i, context_length in enumerate(context_lengths):position_ids[i, context_length:] = mask_positions[i]block_position_ids = [torch.cat((torch.zeros(context_length, dtype=torch.long, device=device),torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1)) for context_length in context_lengths]block_position_ids = torch.stack(block_position_ids, dim=0)position_ids = torch.stack((position_ids, block_position_ids), dim=1)else:position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)for i, context_length in enumerate(context_lengths):if not use_gmasks[i]:position_ids[i, context_length:] = mask_positions[i]return position_ids

3.角度序列Embedding

        接下来,将position_ids转换成角度序列Embedding,下表中每个格的公式为

\theta_i = m\cdot \frac{1}{10000^\frac{2\cdot i}{d}}

        其中m是position_ids中元素的数值;i是编码的索引,ChatGLM使用两个0-31拼接;d是维度,hidden_size // (num_attention_heads * 2)=46:

        第一部分:position_ids=[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2],每个值编码成长度64的角度序列:

m | i01310131
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
1m=1, i=0m=1, i=1m=1, i=31m=1, i=0m=1, i=1m=1, i=31
2m=2, i=0m=2, i=1m=2, i=31m=2, i=0m=2, i=1m=2, i=31
...
2m=2, i=0m=2, i=1m=2, i=31m=2, i=0m=2, i=1m=2, i=31

        第二部分:block_position_ids=[0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]

m | i01310131
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
0m=0, i=0m=0, i=1...m=0, i=31m=0, i=0m=0, i=1...m=0, i=31
1m=1, i=0m=1, i=1m=1, i=31m=1, i=0m=1, i=1m=1, i=31
...
8m=8, i=0m=8, i=1m=8, i=31m=8, i=0m=8, i=1m=8, i=31

代码如下:

class RotaryEmbedding(torch.nn.Module):def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,error_msgs):passdef __init__(self, dim, base=10000, precision=torch.half, learnable=False):"""根据position_ids计算旋转角度的Embedding:param dim: 这里hidden_size // (num_attention_heads * 2)=46,其中hidden_size=4096 num_attention_heads=32:param base::param precision::param learnable:"""super().__init__()# 初始化“频率”,可以理解为position_id每增加1,增加的角度,是Embedding形式的。inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))inv_freq = inv_freq.half()self.learnable = learnableif learnable:self.inv_freq = torch.nn.Parameter(inv_freq)self.max_seq_len_cached = Noneelse:self.register_buffer('inv_freq', inv_freq)self.max_seq_len_cached = Noneself.cos_cached = Noneself.sin_cached = Noneself.precision = precisiondef forward(self, x, seq_dim=1, seq_len=None):if seq_len is None:seq_len = x.shape[seq_dim]if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):self.max_seq_len_cached = None if self.learnable else seq_len# 1.对position_ids去重并正序排列得到t,如:[[0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]] --> t=[[0, 1, 2]]t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)# 2.t与初始化好的“频率”做外积,得到每个position_id的角度,是Embeddingfreqs = torch.einsum('i,j->ij', t, self.inv_freq)# 3.每个Embedding重复叠加一次emb = torch.cat((freqs, freqs), dim=-1).to(x.device)if self.precision == torch.bfloat16:emb = emb.float()# 4.算cos和sin,并增加维度cos_cached = emb.cos()[:, None, :]sin_cached = emb.sin()[:, None, :]if self.precision == torch.bfloat16:cos_cached = cos_cached.bfloat16()sin_cached = sin_cached.bfloat16()if self.learnable:return cos_cached, sin_cachedself.cos_cached, self.sin_cached = cos_cached, sin_cachedreturn self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]# 类似于查表,根据每个position_id获取相应的Embeddingcos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)......

4.截取拼接Q和K

        这一步对Q或者K做截断,并将第二段取反拼在第一段的前面,拼接成公式第二项的q部分。

上述3、4流程示意图:

代码如下:

def rotate_half(x):x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]return torch.cat((-x2, x1), dim=x1.ndim - 1)  

5.旋转位置编码融合

        将旋转位置编码融合到Q和K中,计算第一部分的cos(\theta1)和sin(\theta1),并与输入的Q1、K1做乘法融合;计算第二部分的cos(\theta1)和sin(\theta1),并与输入的Q1、K1做乘法融合,最后将Q和K分别拼接,组成融合了旋转位置编码的新Q和K。整体流程图如下,其中rotary_pos_emb是上图,也就是步骤3、4:

代码如下:

def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]# 类似于查表,根据每个position_id获取相应的Embeddingcos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)# 执行旋转位置编码与QK的融合q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)return q, k# 整体流程如下
# 1.拆分出Q1、Q2、K1、K2
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
# 2.计算旋转Embedding
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \position_ids[:, 1, :].transpose(0, 1).contiguous()
# 3.旋转位置编码融合
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
# 4.将拆分出的Q1、Q2、K1、K2合并成新的Q、K
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))

        位置编码对于Transformer的重要性毋庸置疑,旋转位置编码也确实解决了一些问题。最有意思的就是它是一个二维编码,将旋转信息通过乘法操作融入Attention机制内部,强化了位置编码信息,现在已经有很多开源大模型都使用了旋转位置编码,可见其效果不俗。

        旋转位置编码就介绍到这里,关注不迷路(#^.^#)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/828865.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习day1

一、人工智能三大概念 人工智能三大概念 人工智能(AI)、机器学习(ML)和深度学习(DL) 人工智能:人工智能是研究计算代理的合成和分析的领域。人工智能是使用计算机来模拟,而不是人类…

关于Android中的限定符

很多对于Android不了解或是刚接触Android的初学者来说,对于Android开发中出现的例如layout-large或者drawable-xxhdpi这样的文件夹赶到困惑,这这文件夹到底有什么用?什么时候用?这里简单的说一下。 其实,在上面例子中&…

基于OpenCV的人脸签到系统

效果图 目录文件 camerathread.h 功能实现全写在.h里了 class CameraThread : public QThread {Q_OBJECT public:CameraThread(){//打开序号为0的摄像头m_cap.open(0);if (!m_cap.isOpened()) {qDebug() << "Error: Cannot open camera";}//判断是否有文件,人脸…

iframe实现pdf预览,并使用pdf.js修改内嵌标题,解决乱码问题

项目中遇到文件预览功能,并且需要可以打印文件.下插件对于内网来说有点麻烦,正好iframe预览比较简单,且自带下载打印等功能按钮. 问题在于左上方的文件名乱码,网上找了一圈没有看到解决的,要么就是要收费要会员(ztmgs),要么直接说这东西改不了. 使用: 1.引入 PDF.js 库&…

Spring Boot集成Redisson实现延迟队列

项目场景&#xff1a; 在电商、支付等领域&#xff0c;往往会有这样的场景&#xff0c;用户下单后放弃支付了&#xff0c;那这笔订单会在指定的时间段后进行关闭操作&#xff0c;细心的你一定发现了像某宝、某东都有这样的逻辑&#xff0c;而且时间很准确&#xff0c;误差在1s内…

与AI对话:探索最佳国内可用的ChatGPT网站

与AI对话&#xff1a;探索最佳国内可用的ChatGPT网站 &#x1f310; 链接&#xff1a; GPTGod 点击可注册 &#x1f3f7;️ 标签&#xff1a; GPT-4 支持API 支持绘图 Claude &#x1f4dd; 简介&#xff1a;GPTGod 是一个功能全面的平台&#xff0c;提供GPT-4的强大功能&…

JavaEE——Spring Boot + jwt

目录 什么是Spring Boot jwt&#xff1f; 如何实现Spring Boot jwt&#xff1a; 1. 添加依赖 2、创建JWT工具类 3. 定义认证逻辑 4. 添加过滤器 5、 http请求测试 什么是Spring Boot jwt&#xff1f; Spring Boot和JWT&#xff08;JSON Web Token&#xff09;是一对常…

苍穹外卖学习

并不包含全部视频内容&#xff0c;大部分都按照操作文档来手搓代码&#xff0c;资料&#xff0c;代码都上传git。 〇、实际代码 0.1 Result封装 package com.sky.result;import lombok.Data;import java.io.Serializable;/*** 后端统一返回结果* param <T>*/ Data pub…

软考 系统架构设计师系列知识点之软件可靠性基础知识(5)

接前一篇文章&#xff1a;软考 系统架构设计师系列知识点之软件可靠性基础知识&#xff08;4&#xff09; 所属章节&#xff1a; 第9章. 软件可靠性基础知识 第1节 软件可靠性基本概念 9.1.3 可靠性目标 前文定量分析软件的可靠性时&#xff0c;使用失效强度来表示软件缺陷对…

20232937文兆宇 2023-2024-2 《网络攻防实践》实践七报告

20232937文兆宇 2023-2024-2 《网络攻防实践》实践七报告 1.实践内容 &#xff08;1&#xff09;使用Metasploit进行Linux远程渗透攻击 任务&#xff1a;使用Metasploit渗透测试软件&#xff0c;攻击Linux靶机上的Samba服务Usermap_script安全漏洞&#xff0c;获取目标Linux…

机器学习day3

一、距离度量 1.欧氏距离 2.曼哈顿距离 3.切比雪夫距离 4.闵可夫斯基距离 二、特征与处理 1.数据归一化 数据归一化是一种将数据按比例缩放&#xff0c;使之落入一个小的特定区间的过程。 代码实战 运行结果 2.数据标准化 数据标准化是将数据按照其均值和标准差进行缩放的过…

2024新版计算机网络视频教程65集完整版(视频+配套资料)

今日学计算机网络&#xff0c;众生皆叹难理解。 却见老师神乎其技&#xff0c;网络通畅如云烟。 协议层次纷繁复杂&#xff0c;ARP、IP、TCP、UDP。 路由器交换机相连&#xff0c;数据包穿梭无限。 网络安全重于泰山&#xff0c;防火墙、加密都来添。 恶意攻击时刻存在&#xf…

Visual Studio Code使用

目录 1.python的调试 2.c的运行 方法1&#xff1a; 方法2&#xff1a; 3.c的调试 3.1调试方法一&#xff1a;先生成执行文件&#xff0c;再调试 3.2调试方法二&#xff1a;同时生成执行文件&#xff0c;调试 4.tasks.json 与launch.json文件的参考 4.1C生成执行文件tas…

AI视频教程下载:用ChatGPT和 MERN 堆栈构建 SAAS 项目

这是一个关于 掌握ChatGPT 开发应用的全面课程&#xff0c;它将带领你进入 AI 驱动的 SAAS 项目的沉浸式世界。该课程旨在使你具备使用动态的 MERN 堆栈和无缝的 Stripe 集成来构建强大的 SAAS 平台所需的技能。 你将探索打造智能解决方案的艺术&#xff0c;深入研究 ChatGPT 的…

使用R语言进行简单的主成分分析(PCA)

主成分分析&#xff08;PCA&#xff09;是一种广泛使用的数据降维技术&#xff0c;它可以帮助我们识别数据中最重要的特征并简化复杂度&#xff0c;同时尽量保留原始数据的关键信息。在这篇文章中&#xff0c;我们将通过一个具体的例子&#xff0c;使用R语言实现PCA&#xff0c…

主成分分析(PCA):揭秘数据的隐藏结构

在数据分析的世界里&#xff0c;我们经常面临着处理高维数据的挑战。随着维度的增加&#xff0c;数据处理、可视化以及解释的难度也随之增加&#xff0c;这就是所谓的“维度的诅咒”。主成分分析&#xff08;PCA&#xff09;是一种强大的统计工具&#xff0c;用于减少数据的维度…

Maven的仓库、周期和插件

一、简介 随着各公司的Java项目入库方式由老的Ant改为Maven后&#xff0c;相信大家对Maven已经有了个基本的熟悉。但是在实际的使用、入库过程中&#xff0c;笔者发现挺多人对Maven的一些基本知识还缺乏了解&#xff0c;因此在此处跟大家简单地聊下Maven的相关内容&#xff0c…

基于STM32单片机的天然气与温湿度检测报警系统设计

基于STM32单片机的天然气与温湿度检测报警系统设计 一、引言 随着科技的发展和安全生产意识的提高&#xff0c;对于地下矿井等封闭环境中的天然气泄漏和温湿度变化的监控变得尤为重要。本文设计了一种基于STM32单片机的天然气与温湿度检测报警系统&#xff0c;旨在实时监控环…

OpenCV实现霍夫变换

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV 如何实现边缘检测器 下一篇 :OpenCV 实现霍夫圆变换 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 HoughLines()和 HoughLinesP()检测图像中的线条。…

Error opening file a bytes-like object is required,not ‘NoneType‘

错误显示&#xff0c;打开的是一个无效路径的文件 查看json文件内容&#xff0c;索引的路径与json文件保存的路径不同 方法&#xff1a;使用python脚本统一修改json文件路径 import json import os import argparse import cv2 from tqdm import tqdm import numpy as np impo…