【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE

在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。

1. 传统 Transformer 的位置编码

1.1 正弦余弦编码

在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 代表词语在序列中的位置。
  • i 代表编码向量的维度索引。
  • d_model 是模型的维度大小。

这种编码方式的主要特点是:

  • 绝对位置编码: 为每个位置生成唯一的向量。
  • 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
  • 维度变化: 编码向量的每个维度上的频率都不同。

1.2 代码示例 (PyTorch)

import torch
import mathdef positional_encoding(pos, d_model):pe = torch.zeros(1, d_model)for i in range(0, d_model, 2):pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))return pe# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])

1.3 缺点

传统的正弦余弦位置编码虽然有效,但也有其局限性:

  • 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
  • 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。

2. LLaMA 的 Rotary Embedding (RoPE)

LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。

2.1 RoPE 的核心公式

RoPE 的核心公式如下:

RoPE(q, k, pos) = rotate(q, pos, Θ)

其中:

  • qk 分别代表查询向量和键向量。
  • pos 是两个向量之间的相对位置。
  • Θ 是一个旋转矩阵,根据 pos 和预定义的频率生成。
  • rotate(q, pos, Θ) 表示将 q 旋转 Θ 角度后的结果。

更具体来说,对于维度为 d 的向量 q,RoPE 将其分为 d/2 对 (q0, q1), (q2, q3) …, (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R 的定义是:

R(pos) =  [[cos(pos * θ_0), -sin(pos * θ_0)],[sin(pos * θ_0),  cos(pos * θ_0)]]  [[cos(pos * θ_1), -sin(pos * θ_1)],[sin(pos * θ_1),  cos(pos * θ_1)]]...[[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],[sin(pos * θ_d/2-1),  cos(pos * θ_d/2-1)]]

其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。

将旋转矩阵应用于向量 q ,就是:
q_rotated = R(pos) * q

2.2 LLaMA 源码实现

下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):

import torch
import mathdef precompute_freqs(dim, end, theta=10000.0):freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(end)freqs = torch.outer(t, freqs)return torch.cat((freqs, freqs), dim=1)def apply_rotary_emb(xq, xk, freqs):xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))xq_rotated = xq_complex * freqs_complexxk_rotated = xk_complex * freqs_complexreturn xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated  = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)

代码解释

  1. precompute_freqs(dim, end, theta):
    • 此函数用于预计算旋转矩阵中使用的频率。
    • dim: 表示词向量维度。
    • end: 表示最大序列长度。
    • 返回包含所有位置的频率列表。
  2. apply_rotary_emb(xq, xk, freqs):
    • 函数将旋转操作应用于查询向量 xq 和键向量 xk
    • 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
    • 使用 torch.roll() 函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。
    • 使用复数乘法完成旋转,通过 .real 属性取出旋转后的实部,并将类型转换回原始类型

2.3 RoPE 的优势

与传统的正弦余弦位置编码相比,RoPE 具有以下优势:

  1. 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
  2. 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
  3. 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
  4. 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。

3. 总结

本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。

4. 参考资料

  • RoFormer: Enhanced Transformer with Rotary Position Embedding
  • Attention is All You Need

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/69093.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【科研】 -- 医学图像处理方向,常用期刊链接

文章目录 文章目录 医学图像处理方向,期刊链接医学图像处理方向,会议 医学图像处理方向,期刊链接 Sicence https://www.science.org/ Nature https://www.nature.com/ Nature Communications https://www.nature.com/ncomms/ Nature Met…

The specified Gradle distribution ‘gradle-bin.zip‘ does not exist.

The specified Gradle distribution ‘https://services.gradle.org/distributions/gradle-bin.zip’ does not exist. distributionUrl不存在,关联不上,下载不了,那就匹配一个能下载的 distributionUrlhttps://services.gradle.org/distrib…

从零开始实现一个双向循环链表:C语言实战

文章目录 1链表的再次介绍2为什么选择双向循环链表?3代码实现:从初始化到销毁1. 定义链表节点2. 初始化链表3. 插入和删除节点4. 链表的其他操作5. 打印链表和判断链表是否为空6. 销毁链表 4测试代码5链表种类介绍6链表与顺序表的区别7存储金字塔L0: 寄存…

Cesium点集中获取点的id,使用viewer.value.entities.getById报错的解决方法

错误代码: viewer.value.entities.getById(pickedObject.id) 报错: 可以正常获取movement.position但是一直出现如下报错,无法获得航点的id,通过断点定位为 viewer.value.entities.getById(pickedObject.id)导致的报错 解决方…

java 进阶教程_Java进阶教程 第2版

第2版前言 第1版前言 语言基础篇 第1章 Java语言概述 1.1 Java语言简介 1.1.1 Java语言的发展历程 1.1.2 Java的版本历史 1.1.3 Java语言与C/C 1.1.4 Java的特点 1.2 JDK和Java开发环境及工作原理 1.2.1 JDK 1.2.2 Java开发环境 1.2.3 Java工作原理 1.…

ARM Linux Qt使用JSON-RPC实现前后台分离

文章目录 1、前言2、解决方案2.1、JSON-RPC2.2、Qt中应用JSON-RPC的框架图2.3、优点2.4、JSON-RPC 1.0 协议规范 3、程序示例3.1、Linux C(只例举RPC Server相关程序)3.2、Qt程序(只例举RPC Client相关程序) 4、编译程序4.1、交叉…

PyQt6/PySide6 的 QMainWindow 类

QMainWindow 是 PyQt6 或 PySide6 库中一个非常重要的类,它提供了一个主窗口应用程序的框架,该框架可以包含菜单栏、工具栏、状态栏以及中心部件等。QMainWindow 为 GUI 应用程序提供了基本的结构和布局管理功能,非常适合用来创建复杂的用户界…

9. k8s二进制集群之kube-controller-manager部署

同样在部署主机上创建证书请求文件(为之后的证书生成做准备)根据上面的证书文件创建证书(结果会在当前目录下产生kube-controller-manager证书)创建kube-controller-manager服务配置文件创建kube-controller-manager服务启动文件同步kube-controller-manager证书到对应mast…

教程 | i.MX RT1180 ECAT_digital_io DEMO 搭建(一)

本文介绍 i.MX RT1180 EtherCAT digital io DEMO 搭建,Master 使用 TwinCAT ,由于步骤较多,分为上下两篇,本文为第一篇,主要介绍使用 TwinCAT 控制前的一些准备。 原厂 SDK 提供了 evkmimxrt1180_ecat_examples_digit…

ubuntu22.40安装及配置静态ip解决重启后配置失效

遇到这种错误,断网安装即可! 在Ubuntu中配置静态IP地址的步骤如下。根据你使用的Ubuntu版本(如 Netplan 或传统的 ifupdown),配置方法有所不同。以下是基于 Netplan 的配置方法(适用于Ubuntu 17.10及更高版…

服务端渲染技术

一.JSP 1.jsp介绍,全称是java Server Pages ,java服务器页面,就是服务端渲染技术,html只能为用户提供静态数据,而JSP技术允许在页面中嵌套 java代码,jsp技术基于Servlet,Servlet很难对数据进行排版,而jsp就可以,可以理解为jsp就是对Servlet的包装. 2.jsp程序本质是java程序,无…

[23] cuda应用之 nppi 实现图像缩放

[23] cuda应用之 nppi 实现图像缩放 NPP(NVIDIA Performance Primitives)是一个由 NVIDIA 提供的库,专门用于加速图像和信号处理任务。NPP 提供了许多高效的图像处理函数,包括图像缩放。使用 NPP 实现图像缩放可以充分利用 GPU 的…

【产品经理学习案例——AI翻译棒出海业务】

前言: 本文主要讲述了硬件产品在出海过程中,翻译质量、翻译速度和本地化落地策略是硬件产品规划需要考虑的核心因素。针对不同国家,需要优化翻译质量和算法,关注市场需求和文化差异,以便更好地满足当地用户的需求。同…

CH340G上传程序到ESP8266-01(S)模块

文章目录 概要ESP8266模块外形尺寸模块原理图模块引脚功能 CH340G模块外形及其引脚模块引脚功能USB TO TTL引脚 程序上传接线Arduino IDE 安装ESP8266开发板Arduino IDE 开发板上传失败上传成功 正常工作 概要 使用USB TO TTL(CH340G)将Arduino将程序上传…

1.4 Go 数组

一、数组 1、简介 数组是切片的基础 数组是一个固定长度、由相同类型元素组成的集合。在 Go 语言中,数组的长度是类型的一部分,因此 [5]int 和 [10]int 是两种不同的类型。数组的大小在声明时确定,且不可更改。 简单来说,数组…

AI推理性能之王-Groq公司开发的LPU芯片

Groq公司开发的LPU(Language Processing Unit,语言处理单元)芯片是一种专为加速大规模语言模型(LLM)和其他自然语言处理任务而设计的新型AI处理器。以下是对其技术特点、性能优势及市场影响的深度介绍: 技…

C#中的委托(Delegate)

什么是委托? 首先,我们要知道C#是一种强类型的编程语言,强类型的编程语言的特性,是所有的东西都是特定的类型 委托是一种存储函数的引用类型,就像我们定义的一个 string str 一样,这个 str 变量就是 string 类型. 因为C#中没有函数类型,但是可以定义一个委托类型,把这个函数…

rk3506 sd卡启动

1 修改系统配置文件,打开ext4 #SDMMC RK_ROOTFS_TYPE"ext4" RK_ROOTFS_INSTALL_MODULESy RK_WIFIBT_CHIP"AIC8800" # RK_ROOTFS_LOG_GUARDIAN is not set RK_UBOOT_CFG_FRAGMENTS"rk3506_tb" RK_UBOOT_SPLy RK_KERNEL_CFG"rk3506_defconfi…

2025春招,深度思考MyBatis面试题

大家好,我是V哥,2025年的春招马上就是到来,正在准备求职的朋友过完年,也该收收心,好好思考一下自己哪些技术点还需要补一补了,今天 V 哥要跟大家聊的是MyBatis框架的问题,站在一个高级程序员的角…

Docker 安装详细教程(适用于CentOS 7 系统)

目录 步骤如下: 1. 卸载旧版 Docker 2. 配置 Docker 的 YUM 仓库 3. 安装 Docker 4. 启动 Docker 并验证安装 5. 配置 Docker 镜像加速 总结 前言 Docker 分为 CE 和 EE 两大版本。CE即社区版(免费,支持周期7个月)&#xf…