Flash Attention原理

news/2025/9/17 16:30:53/文章来源:https://www.cnblogs.com/gongzb/p/19097036

提出问题

Transformer 结构已成为自然语言处理和图像分类等应用中最常用的架构。尽管 Transformer 在规模上不断增大和加深,但处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有二次方的时间和内存复杂度。这导致在处理长序列时速度变慢且内存需求巨大。因此,我们需要一些优化算法来提高注意力模块的计算速度和内存利用率。

解决方案

 

Forward

Standard Attention Implementation

在注意力的一般实现中,对$$\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$$三个输入执行以下算法得到输出$$\mathbf{O}$$,其中softmax行级别执行。
$$\begin{equation} \mathbf{S}=\mathbf{Q K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P} \mathbf{V} \in \mathbb{R}^{N \times d}, \end{equation}$$
在这个算法中,$$\mathbf{S}, \mathbf{P$$矩阵都是很大,需要在HBM中实例化来进行存储,这样就会带来很多HBM的访问次数,最终体现到算法时间端到端较长的延迟。

FlashAttention Implementation(Tiling)

理论基础

在传统算法中,一种方式是将Mask和SoftMax部分融合,以减少访存次数。然而,FlashAttention则更加激进,它将从输入$$\mathbf{Q}, \mathbf{K}, \mathbf{V$$到输出$$\mathbf{O$$的整个过程进行融合,以避免$$\mathbf{S}, \mathbf{P}$$矩阵的存储开销,实现端到端的延迟缩减。然而,由于输入的长度$$N$$通常很长,无法完全将完整的$$\mathbf{Q}, \mathbf{K}, \mathbf{V},\mathbf{O}$$及中间计算结果存储在SRAM中。因此,需要依赖HBM进行访存操作,与原始计算延迟相比没有太大差异,甚至会变慢(没具体测)。
为了让计算过程的结果完全在SRAM中,摆脱对HBM的依赖,可以采用分片操作,每次进行部分计算,确保这些计算结果能在SRAM内进行交互,待得到对应的结果后再进行输出。
这个过程中,有一点需要注意的是,之前对于softmax的计算是以行为单位的,如下所示:
$$\begin{equation} m(x):=\max _i x_i, \quad f(x):=\left[\begin{array}{lll} e^{x_1-m(x)} & \ldots & e^{x_B-m(x)} \end{array}\right], \quad \ell(x):=\sum_i f(x)_i, \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} \end{equation}$$
当我们将输入进行分片后,无法对完整的行数据执行Softmax操作。这是因为Softmax函数在计算时需要考虑整个行的数据。然而,我们可以通过如下所示方法来获得与完整行Softmax相同的结果,而无需使用近似操作。
$$\begin{equation} \begin{aligned} & m(x)=m\left(\left[x^{(1)} x^{(2)}\right]\right)=\max \left(m\left(x^{(1)}\right), m\left(x^{(2)}\right)\right), \quad f(x)=\left[\begin{array}{ll} e^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right) & e^{m\left(x^{(2)}\right)-m(x)} f\left(x^{(2)}\right) \end{array}\right], \\ & \ell(x)=\ell\left(\left[x^{(1)} x^{(2)}\right]\right)=e^{m\left(x^{(1)}\right)-m(x)} \ell\left(x^{(1)}\right)+e^{m\left(x^{(2)}\right)-m(x)} \ell\left(x^{(2)}\right), \quad \operatorname{softmax}(x)=\frac{f(x)}{\ell(x)} . \end{aligned} \end{equation} $$

代码实现

@triton.jit
def _fwd_kernel(Q, K, V, sm_scale,L, M,Out,stride_qz, stride_qh, stride_qm, stride_qk,stride_kz, stride_kh, stride_kn, stride_kk,stride_vz, stride_vh, stride_vk, stride_vn,stride_oz, stride_oh, stride_om, stride_on,Z, H, N_CTX,BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,BLOCK_N: tl.constexpr,
):start_m = tl.program_id(0)off_hz = tl.program_id(1)# initialize offsetsoffs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)offs_n = tl.arange(0, BLOCK_N)offs_d = tl.arange(0, BLOCK_DMODEL)off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qkoff_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kkoff_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk# Initialize pointers to Q, K, Vq_ptrs = Q + off_qk_ptrs = K + off_kv_ptrs = V + off_v# initialize pointer to m and lm_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)# load q: it will stay in SRAM throughoutq = tl.load(q_ptrs)# loop over k, v and update accumulatorfor start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):# -- compute qk ----k = tl.load(k_ptrs)qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)qk += tl.dot(q, k)qk *= sm_scaleqk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))# compute new mm_curr = tl.maximum(tl.max(qk, 1), m_prev)# correct old ll_prev *= tl.exp(m_prev - m_curr)# attention weightsp = tl.exp(qk - m_curr[:, None])l_curr = tl.sum(p, 1) + l_prev# rescale operands of matmulsl_rcp = 1. / l_currp *= l_rcp[:, None]acc *= (l_prev * l_rcp)[:, None]# update accp = p.to(Q.dtype.element_ty)v = tl.load(v_ptrs)acc += tl.dot(p, v)# update m_i and l_il_prev = l_currm_prev = m_curr# update pointersk_ptrs += BLOCK_N * stride_knv_ptrs += BLOCK_N * stride_vk# rematerialize offsets to save registersstart_m = tl.program_id(0)offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)# write back l and ml_ptrs = L + off_hz * N_CTX + offs_mm_ptrs = M + off_hz * N_CTX + offs_mtl.store(l_ptrs, l_prev)tl.store(m_ptrs, m_prev)# initialize pointers to outputoffs_n = tl.arange(0, BLOCK_DMODEL)off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_onout_ptrs = Out + off_otl.store(out_ptrs, acc)

IO Complexity Analysis

Standard Attention

对于标准注意力实现,初期我们需要把输入$$\mathbf{Q}, \mathbf{K}, \mathbf{V$$从HBM中读取,并计算完毕后把输出$$\mathbf{O}$$写入到HBM中。
第一步把$$\mathbf{Q}, \mathbf{K}$$读取出来计算出$$\mathbf{S}=\mathbf{Q K}^{\top}$$,然后把$$\mathbf{S}$$存回去,内存访问复杂度$$\Theta\left(N d+N^2\right$$。
第二步把$$\mathbf{S}$$读取出来计算出$$\mathbf{P}=\operatorname{softmax}(\mathbf{S}$$,然后把$$\mathbf{P$$存回去,内存访问复杂度$$\Theta\left(N^2\right)$$。
第三步把$$\mathbf{V}, \mathbf{P}$$读取出来计算出$$\mathbf{O}=\mathbf{P} \mathbf{V$$,然后计算出结果$$\mathbf{O}$$,内存访问复杂度$$\Theta\left(N d+N^2\right)$$。
综上所述,整体的内存访问复杂度为$$\Theta\left(N d+N^2\right)$$。

FlashAttention

对于FlashAttention,我们设置一个分块大小$$B_c$$来把$$\mathbf{K}, \mathbf{V}$$分成$$T_$$块,对于$$\mathbf{Q}, \mathbf{O}$$的每一块都要把$$\mathbf{K}, \mathbf{V }$$部分的全部元素Load一遍,这样则有FlashAttention的内存访问复杂度为$$\Theta\left(N d+N d T_c\right)$$=$$\Theta\left(N d T_c\right)$$.
在这里,我们需要两个分块大小,$$\mathbf{Q}, \mathbf{O}$$的分块大小$$B_r$$,$$\mathbf{K}, \mathbf{V}$$的分块大小$$B_c$$,我们设定SRAM的大小为$$M$$,为了能把分块后的$$\mathbf{K}, \mathbf{V} \in \mathbb{R}^{B_c \times d}$$放进SRAM,那么则有一下限制:$$\begin{equation} B_c d=O(M) \Leftrightarrow B_c=O\left(\frac{M}{d}\right) \end{equation}$$
相应的,$$\mathbf{Q}, \mathbf{O} \in \mathbb{R}^{B_r \times d}$$有如下限制:$$\begin{equation} B_r d=O(M) \Leftrightarrow B_r=O\left(\frac{M}{d}\right) \end{equation}$$
最终,还有一个中间态$$\mathbf{S}=\mathbf{Q K}^{\top} \in \mathbb{R}^{B_r \times B_c}$$需要存储,则有如下限制:$$\begin{equation} B_r B_c=O(M) \end{equation}$$
综上,限制如下
$$\begin{equation} B_c=\Theta\left(\frac{M}{d}\right), \quad B_r=\Theta\left(\min \left(\frac{M}{d}, \frac{M}{B_c}\right)\right)=\Theta\left(\min \left(\frac{M}{d}, d\right)\right) \end{equation}$$
进而推出
$$\begin{equation} T_c=\frac{N}{B_c}=\Theta\left(\frac{N d}{M}\right) \end{equation}$$
那么在$$M = \Theta (Nd$$的前提下,则有FlashAttention的HBM内存访问复杂度为:
$$\begin{equation} \Theta\left(N d T_c\right)=\Theta\left(\frac{N^2 d^2}{M}\right) = \Theta\left({N d}\right) \end{equation}$$
在语言建模中,通常有$$d \lll N$$,则有$$\Theta_{stand} \left(N d+N^2\right) >  \Theta_{flash} \left(N d\right)$$。这样,在前向的过程中,我们采用分块计算的方式,避免了$$\mathbf{S}, \mathbf{P}$$矩阵的存储开销,整体的运算都在SRAM内进行,降低了HBM访问次数,大大提升了计算的速度,减少了对存储的消耗。

Backward

理论基础

在上面前向的时候我们为了减少HBM访存次数,降低内存消耗量,我们并没有对$$\mathbf{S}, \mathbf{P}$$矩阵进行存储,而这个在反向传播计算梯度的时候确实需要的一个信息。之前有通过Gradient checkpointing的方式来实现梯度实现在前向的时候更加节省内存。
我们这里则采用重新计算的方式来计算对应的梯度。在上面前向计算的时候我们不会存储$$\mathbf{S}, \mathbf{P}$$矩阵,但是我们会存储对应的指数项之和$$L$$来进行梯度的计算。
我们在反向的过程中最重要的事情就是就是Loss函数$$\ph$$对$$\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{O}$$对应的梯度。
$$\mathbf{O$$对应的梯度最好计算$$\mathbf{dO} = \frac{\partial \phi}{\partial \mathbf{O}}$$,其中$$\mathbf{O$$是现成的。
$$\mathbf{V}$$对应的梯度也很好计算,由于$$\mathbf{O} = \mathbf{P}\mathbf{V} $$,根据链式求导法则和矩阵求导法则则有$$\mathbf{d V}=\mathbf{P}^T \mathbf{d} \mathbf{O}$$,更详细如下所示:
$$\begin{equation} d v_j=\sum_i P_{i j} d o_i=\sum_i \frac{e^{q_i^T k_j}}{L_i} d o_i \end{equation}$$
$$\mathbf{Q}, \mathbf{K}$$对应的梯度算起来就比较复杂一点。这两个经过的计算逻辑步骤更多,我们可以一步一步的来进行计算。我们可以先计算$$\mathbf{dP}, \mathbf{dS}$$。由于$$\mathbf{O} = \mathbf{P}\mathbf{V} $$,则有$$\mathbf{dP}$$如下表示
$$\begin{equation} d P_{i j}=d o_i^T v_j \end{equation}$$
由于$$P_{i:}=\operatorname{softmax}\left(S_{i:}\right)$$,根据上述定理则有:
$$\begin{equation} d S_{i:}=\left(\operatorname{diag}\left(P_{i:}\right)-P_{i:} P_{i:}^T\right) d P_{i:}=P_{i:} \circ d P_{i:}-\left(P_{i:}^T d P_{i:}\right) P_{i:} \end{equation}$$
接下来我们定义如下表示:
$$\begin{equation} D_i=P_{i:}^T d P_{i:}=\sum \frac{e^{q_i \kappa_j}}{L_i} d o_i^T v_j=d o_i^T \sum \frac{e^{q_i \kappa_j}}{L_i} v_j=d o_i^T o_i \end{equation}$$
根据上述定义简化12式则有如下表示:
$$\begin{equation} d S_{i:}=P_{i:} \circ d P_{i:}-D_i P_{i:} \end{equation}$$
相应的$$\mathbf{dS}$$可表示为如下形式:
$$\begin{equation} d S_{i j}=P_{i j} d P_{i j}-D_i P_{i j}=P_{i j}\left(d P_{i j}-D_i\right) \end{equation}$$
又因为$$S_{i j}=q_i^T k_j$$,结合上述推导利用链式求导法则$$\mathbf{Q}, \mathbf{K}$$对应的梯度有如下表示:
$$\begin{equation} d q_i=\sum_j d S_{i j} k_j=\sum_j P_{i j}\left(d P_{i j}-D_i\right) k_j=\sum_j \frac{e^{q_i^T k_j}}{L_i}\left(d o_i^T v_j-D_i\right) k_j \end{equation}$$
$$\begin{equation} d k_j=\sum_i d S_{i j} q_i=\sum_i P_{i j}\left(d P_{i j}-D_i\right) q_i=\sum_i \frac{e^{q_i^T k_j}}{L_i}\left(d o_i^T v_j-D_i\right) q_i \end{equation}$$
至此,我们得到了一个完整的包含前向和反向的,降低了HBM访问次数的,新的Attention算子。

Block-Sparse

相比于上面的全量计算,块稀疏的FlashAttention需要额外提供一个Mask矩阵$$\tilde{\mathbf{M}} \in\{0,1\}^{N \times N$$用于将一些元素置零来保证块稀疏加速计算。本文对于块稀疏的一个计算只是一个简单的尝试,没有进行太深入的探索,所以这里我们先一笔带过,后面我们可以讲一篇对FlashAttention进行块稀疏优化的工作SCFA.
$$\begin{equation} \mathbf{S}=\mathbf{Q} \mathbf{K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}\left(\mathbf{S} \odot \mathbb{1}_{\tilde{\mathbf{M}}}\right) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d} \end{equation}$$

实验验证

 
通过实验验证发现,FlashAttention在速度和内存占用方面都表现出明显的优势,并取得了良好的效果。目前,FlashAttention已经经过广泛验证, torch2.0中已提供flashattention的实现。正如标题《Fast and Memory-Efficient Exact Attention with IO-Awareness》所示,FlashAttention的优点在于充分考虑了在计算任务中IO的重要性,并通过分块计算的方式开发了一种快速、节省显存、精确无近似的注意力实现方法。这使得我们更便于训练具有更长上下文的Transformer模型,并且为后续注意力算法的优化提供了一个基准。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/906683.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MSMQ 跨服务器读写队列的“消息队列系统的访问被拒绝”的解决方案

转:http://m.blog.csdn.net/blog/2000killer/8904852 在服务器上创建的Queue开发者的 机器只能写数据而不能读数据。.net给出的错误是“对消息队列系统的访问被拒绝”,也就是说拒绝访问队列没有相关权限,我给Everyone和 ANONYMOUS LOGON赋予全部权限都无法解决(正常情况下可…

Linux时间同步---NTP时间同步方案

1.方案背景: 在分布式系统或多服务器集群中,必须建立统一的时间同步机制。服务器间的时间不一致会破坏各类依赖时间交互逻辑,例如导致日志时序混乱、事务顺序错乱、证书验证失败等,从而引发一系列难以排查的不可预知故障。 2.NTP同步网络拓扑图: 3.同步方案 可提前咨询医…

java预习

课前问题列表 1.什么样的方法应该用static修饰?不用static修饰的方法往往具有什么特性?Student的getName应该用static修饰吗?适合用 static 修饰的方法: 工具类方法(如Math.abs())、工厂方法、不需要访问实例变量 / 方法的方法、单例模式的获取实例方法等,这类方法通常与…

B/S体系结构风格

三层B/S风格-概述 》浏览器/服务器(B/S)风格就是上述三层应用结构的一种实现方式,其具体结构为:浏览器/Web服务器/数据库服务器。 》B/S体系结构主要是利用不断成熟的WWW浏览技术,结合浏览的多种脚本语言,用通用浏览器就实现了原来需要复杂的专用软件才能实现的强大功能,…

The 2024 CCPC Online Contest 7/12 L/B/K/D/J/E/C

Problem L. 网络预选赛 签到,直接模拟即可点击查看代码 #include<bits/stdc++.h> using namespace std; int main(){int n,m;cin>>n>>m;vector<string>a(n);for(int i=0;i<n;i++){cin>>a[i];}int sum=0;for(int i=0;i<n-1;i++){for(int j…

在joule里面使用agent 功能

test: Dev: 1: structure 2: 本博客为非营利性个人原创,除部分有明确署名的作品外,所刊登的所有作品的著作权均为本人所拥有,本人保留所有法定权利。违者必究

Feign动态URL配置

方式一、亲测可用,缺点是每个类都需要单独配置@FeignClient(value = "my-biz", url = "${external.my.biz_url}", configuration = FeignHeaderInterceptor.class) public interface MyBizFeign {}@Data @Component @RefreshScope @ConfigurationProperti…

自动化部署工具 Jenkins 的安装与配置

Jenkins 是一个开源的自动化部署工具,广泛用于持续集成(CI)和持续交付(CD)流程。它支持自动化构建、测试和部署应用程序。以下是 Jenkins 的安装与配置的详细教程。1. 安装 Jenkins 以下步骤适用于 Linux 系统(以 Ubuntu 和 CentOS 为例),并包含安装必要的依赖环境。1.…

pip 搭建源

搭建本地pip源主要可通过pypiserver、pip2pi或bandersnatch等工具实现,具体步骤如下: 工具选择与安装‌pypiserver‌:轻量级方案,适合快速搭建小型私有源,通过pip install pypiserver安装 ‌pip2pi‌:适合按需构建源,从requirements.txt生成索引,需配合pip install pip…

qoj10093 Jump the Frog

题意 给出 \(n\) 个由 O 和 ~ 组成的字符串 \(s_i\),还有 \(m\) 个额外字符串,第 \(n+i\) 个字符串 \(s_{n+i}\) 由第 \(s_x\) 和 \(s_y\) \((x,y<n+i)\) 个字符串拼接得到,即 \(s_{n+i}=s_x+s_y\)。你需要对这 \(n+m\) 个字符串解决以下问题: 有一只青蛙从字符串的起点…

new 和make

func NewCase() {// 通过new , 可以创建任意类型,并返回指针mpPtr := new(map[string]*user)if *mpPtr == nil { // 通过 * 获取指针内容fmt.Println("map 为空")}// sliceslicePtr := new([]user)if *slicePtr == nil {fmt.Println("slice 为空")}*sliceP…

微信商户绑定微信公众号、小程序

https://pay.weixin.qq.com/index.php/extend/merchant_appid/mapay_platform/account_manage版权木有,侵权不究,欢迎转载

Ceres 常用 LossFunction 对比

Ceres 常用 LossFunction 对比

测试开发全日制学徒班火热报名中|跟着名企大咖做真实项目,结业即上岗

测试开发全日制学徒班,采用系统化教学的全日制线下课程,通过「人工智能测试+自动化测试+Python开发+测试开发」四维能力培养体系,配备行业资深专家导师陪跑服务(私教1v1答疑+周末私教陪跑服务),全程采用企业级学徒制培养模式。 学员将参与真实企业级项目开发,完成测试全…

唯创知音AI语音交互芯片与模组介绍

AI语音交互已经成为智能产品的基础配置,比如常见的AI玩具、智能家居、带AI功能的蓝牙音响,还有汽车的智能车机和智能穿戴设备等。唯创知音顺应市场趋势推出了WT2606A系列的AI语音交互芯片,和WT3000A M06、WT3000A M07、WT3000A M08三款AI语音交互模组。WT2606A AI语音交互芯…

k3s 高可用集群部署(内置 etcd + VIP + keepalived)

k3s 高可用集群部署(内置 etcd + VIP + keepalived) 一、节点规划master 节点:10.0.0.40、10.0.0.51、10.0.0.53 worker 节点:10.0.0.52、10.0.0.54 VIP(高可用入口):10.0.0.41二、离线包准备下载 k3s 安装脚本、二进制、镜像包 导入镜像到本地或 Harbor 打包所有安装文…

问HashMap底层原理?

HashMap是基于数组+链表+红黑树的哈希表。用于存储键值对。 1.哈希计算和扰动处理,也就是Hash方法 每一个Object都有一个 .hashCode 方法。(哈希计算)在对hashmap进行插入和查询时,先调用key键的key.hashCode()方法获取一个未处理的int哈希值,在底层代码中该值被复制给变量…

用 Go 重写 adbkit:原理、架构与搭建实践

用 Go 重写 adbkit:原理、架构与搭建实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", monospace !impor…

C语言环境搭建之Linux子系统使用vscode连接子系统

安装准备工作查看当前系统版本确保高于16215.0开启WSL Windows Subsystem for Linux(简称WSL)是一个为在Windows 10上能够原生运行Linux二进制可执行文件(ELF格式)的兼容层。安装步骤微软商城Microsoft Store安装Ubuntu(本人安装的版本是22.04)点击等待安装完成输入用户名跟…

移远AT指令笔记

# 测试 AT - 测试AT指令功能是否正常# 模块相关 ATI - 查询模块信息 AT+CGMI - 查询模块制造商标识 AT+CGMM - 查询模块型号 AT+CGMR - 查询模块固件版本号# 网络相关 AT+QCCID - 查询集成电路卡识别码(ICCID) AT+GSN …