- FlashAttention最基础的方案来自使用高速的share memory来加速Softmax操作,实现Softmax的tiling方案。(Q,K,V之间的乘法可由gemm实现。)
 ![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-F2WMd8tb-1691511319949)(https://github.com/Dao-AILab/flash-attention/blob/main/assets/flashattn_banner.jpg#pic_center)]](https://img-blog.csdnimg.cn/6dfd3cfcd08d49beb7c75e9cf9e411c0.png) 
左侧为GPU各部分的访问速度比较
- FlashAttention使用平铺来防止大型实体化𝑁 ×𝑁 注意力矩阵(虚线框)在(相对)慢的GPU HBM上。
中间为实现过程
- softmax的计算公式
  
 注:我也比较好奇,softmax公式怎么好像变得复杂了?我在参考文献60中找到了答案:
 不幸的是,在所表示的数字范围有限的实际硬件上,算法1的第3行(求分母的时候)可能由于指数而上溢或下溢。得到这这种安全形式的改写。
- 作者提出的分解方法
  
右侧为融合核函数和pytorch实现的速度比较
CG
-  https://github.com/Dao-AILab/flash-attention 
-  Jax上继承了Numpy计算加速,XLA加速,JIT编译,自动微分等,以下代码不用自己实现cuda函数Implementation of Flash Attention in Jax 
-  cuda实现 https://github.com/lucidrains/flash-cosine-sim-attention/tree/main 
-  https://github.com/jundaf2/INT8-Flash-Attention-FMHA-Quantization 
-  https://github.com/kyegomez/FlashAttention20Triton 
-  https://github.com/Lightning-AI/lit-llama 
-  Add Flash-Attention to Huggingface Models https://github.com/conceptofmind/flash-gpt 
-  https://www.zhihu.com/question/611236756/answer/3136806315 
