Mosaic:面向超长序列的多GPU注意力分片方案

Transformer的"二次方注意力瓶颈"的问题是老生常谈了。这个瓶颈到底卡在哪实际工程里怎么绕过去?本文从一个具体问题出发,介绍Mosaic这套多轴注意力分片方案的设计思路。

注意力的内存困境

注意力机制的计算公式:

Attention(Q, K, V) = softmax(QKᵀ / √d) × V

问题出在QKᵀ这个矩阵上,它的形状是

(序列长度 × 序列长度)

拿150,000个token的序列算一下:

Memory = 150,000² × 4 bytes = 90 billion bytes ≈ 84 GB

这只是注意力权重本身的开销,而且还是单层、单头。A100的显存上限是80GB,放不下就是放不下。

现有方案的局限

FlashAttention它通过分块计算,不需要把完整的注意力矩阵实例化出来,内存复杂度从O(n²)降到O(n)。单卡场景下效果很好,但问题是整个序列还是得塞进同一张GPU。

Ring Attention换了个思路:把序列切片分到多张GPU上,每张卡持有一部分Q,K和V在GPU之间像传令牌一样轮转,一维序列处理起来是很不错的。

但是多维怎么办?

比如处理表格数据的Transformer,输入张量形状是

(batch, rows, features, embed)

。模型需要在不同维度上做注意力:features维度只有5个token,rows维度却有150,000个。前者单卡轻松搞定,后者则必须分片。

现有的库都没法干净地处理这种多轴场景。手写的话,每个轴要单独写分片逻辑,进程组管理、张量reshape全得自己来。代码会变得很脏。

Mosaic的设计

Mosaic本质上是个协调层,负责把不同的注意力轴路由到合适的计算后端:

import mosaic # Small axis: run locally feature_attn = mosaic.MultiAxisAttention( embed_dim=96, num_heads=4, attention_axis=2, # features dimension backend="local" # no communication needed ) # Large axis: shard across GPUs row_attn = mosaic.MultiAxisAttention( embed_dim=96, num_heads=4, attention_axis=1, # rows dimension backend="ring" # ring attention across GPUs )

底层Mosaic会自动处理轴的置换、QKV投影前的reshape、后端分发、以及计算完成后张量形状的还原。模型代码保持清晰,分布式的复杂性被封装掉了。

Ring Attention的工作机制

核心思想其实很直接:不需要同时持有全部的K和V。可以分批计算注意力分数,逐步累积,最后再做归一化。

比如说4张GPU的情况下流程是这样的:

Initial state: GPU 0: Q₀, K₀, V₀ GPU 1: Q₁, K₁, V₁ GPU 2: Q₂, K₂, V₂ GPU 3: Q₃, K₃, V₃ Step 1: Each GPU computes attention with its local K, V GPU 0: score₀₀ = Q₀ @ K₀ᵀ ... Step 2: Pass K, V to the next GPU in the ring GPU 0 receives K₃, V₃ from GPU 3 GPU 0 sends K₀, V₀ to GPU 1 Step 3: Compute attention with received K, V GPU 0: score₀₃ = Q₀ @ K₃ᵀ Accumulate with score₀₀ Repeat for all chunks... Final: Each GPU has complete attention output for its Q chunk

单卡内存占用变成O(n²/p),p是GPU数量。8张卡的话内存需求直接砍到1/8。150k序列从84GB降到约10GB每卡。

Mesh2D:更激进的分片

序列特别长的时候Ring Attention的线性分片可能还不够,这时候可以用Mesh2D把Q和K都切分了:

4 GPUs arranged in 2×2 mesh: K₀ K₁ ┌──────┬──────┐ Q₀ │GPU 0 │GPU 1 │ ├──────┼──────┤ Q₁ │GPU 2 │GPU 3 │ └──────┴──────┘ Each GPU computes one tile of QKᵀ

内存复杂度降到O(n²/p²)。64张卡组成8×8网格时,每卡内存需求下降64倍。

attn=mosaic.MultiAxisAttention( embed_dim=128, num_heads=8, attention_axis=1, backend="mesh2d", mesh_shape=(8, 8) )

感知集群拓扑的组合策略

在实际部署环境里,不同GPU之间的通信带宽差异很大。节点内GPU走NVLink能到900 GB/s,跨节点通过InfiniBand通常只有200 GB/s左右。

ComposedAttention

就是针对这种拓扑特征设计的:

# 4 nodes × 8 GPUs = 32 total composed = mosaic.ComposedAttention( mesh_shape=(4, 8), # (nodes, gpus_per_node) head_parallel=True, # Split heads across nodes (slow link) seq_parallel="ring" # Ring within nodes (fast link) )

需要更精细控制的话,可以用

HierarchicalAttention

hier = mosaic.HierarchicalAttention( intra_node_size=8, intra_node_strategy="local", # Compute locally within node inter_node_strategy="ring" # Ring between node leaders )

重通信走快链路轻通信才跨节点。

实现细节

整个库大约800行Python,核心代码如下:

class MultiAxisAttention(nn.Module): def forward(self, x): # 1. Move attention axis to seq position x, inv_perm = self._permute_to_seq(x) # 2. Flatten batch dims, project QKV x = x.view(-1, seq_len, embed_dim) qkv = self.qkv_proj(x).view(batch, seq, 3, heads, head_dim) q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) # 3. Dispatch to backend out = self._attn_fn(q, k, v) # local, ring, or mesh2d # 4. Project output, restore shape out = self.out_proj(out.transpose(1, 2).reshape(...)) return out.permute(inv_perm)

后端封装了现有的成熟实现:

local

后端调用

F.scaled_dot_product_attention

(也就是FlashAttention),

ring

后端用ring-flash-attn库的

ring_flash_attn_func

mesh2d

是自定义的all-gather加SDPA,所有的底层都跑的是FlashAttention内核。

所有后端统一用FlashAttention的融合GEMM+softmax实现。后端函数在初始化时就绑定好,前向传播不做分支判断。张量操作尽量用

x.view()

而不是

x.reshape()

,保持内存连续性。集合通信的目标张量预分配好,避免

torch.cat

的开销。模块级别做导入不在每次前向传播时产生import开销。

快速上手

安装:

pip install git+https://github.com/stprnvsh/mosaic.git # With ring attention support pip install flash-attn ring-flash-attn

单节点启动:

torchrun --nproc_per_node=4 train.py

多节点的话:

# Node 0 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \ --master_addr=192.168.1.100 --master_port=29500 train.py # Node 1 torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \ --master_addr=192.168.1.100 --master_port=29500 train.py

训练脚本示例:

import mosaic import torch.distributed as dist dist.init_process_group("nccl") ctx = mosaic.init(sp_size=dist.get_world_size()) model = MyModel().to(ctx.device) # Data is pre-sharded: each GPU has seq_total / world_size tokens x_local = load_my_shard() out = model(x_local) # Communication handled by Mosaic

总结

最后,Mosaic不会自动并行化模型(这个用nnScaler),不管数据并行(PyTorch DDP/FSDP的事),也不处理模型分片(交给FSDP或Megatron)。

Mosaic专注于一件事:多轴注意力的分片路由,这套方案最初是给nanoTabPFN做的,一个表格数据Transformer。

这个模型要同时在rows(150k个)和features(5个)两个维度做注意力。标准Ring Attention对维度语义没有感知,它只认序列这个概念,分不清rows和features的区别。

所以Mosaic需求很明确:小轴本地算,大轴分布式算,轴的路由逻辑不能侵入模型代码,有兴趣的可以试试。

https://avoid.overfit.cn/post/791e0f30540e4d289a43d01d383e8ab2

作者:Pranav Sateesh

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1125141.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2026马斯克《Moonshots》播客独家访谈全记录与深度解析

一、访谈基础信息全景 本次访谈是马斯克2026年首份重磅公开对话,录制于2025年12月22日美国得克萨斯州奥斯汀特斯拉超级工厂(Cybertruck与Optimus机器人核心生产基地),于2026年1月6日通过《Moonshots》播客正式上线,全时长近3小时。访谈由“零重力”公司创始人、奇点大学执…

操作系统期末复习——第4章:文件系统

目录第4章:文件系统概述4.1 文件4.1.1 文件命名4.1.2 文件结构4.1.3 文件类型4.1.4 文件访问4.2 目录4.2.1 一级目录系统4.2.2 二级目录系统4.2.3 层次目录系统4.2.4 路径名4.3文件系统的实现4.3.1 文件系统布局4.3.2 文件与磁盘4.3.3 ⭐文件的实现4.3.4 ⭐目录的实…

GESP Python 编程一级教材之 10 掌握变量的创建及使用(教程含历年试题解析)

系列文章 《GESP系列教程之 什么是GESP?》 《GESP 认证标准之 Python 编程一级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程二级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程三级标准(考试大纲与要求含考试真题)》 《GESP …

微信小程序 PHP_uniapp的社区团购系统_1g4y216z

微信小程序社区团购系统概述 该系统基于PHP和UniApp技术栈开发,整合微信小程序前端与PHP后端,实现社区团购的完整业务流程。前端采用UniApp跨平台框架,兼容多端运行;后端使用PHP构建高效的数据接口,支持商品管理、订单…

GESP Python 编程一级教材之 11 掌握输入输出语句 input 和 print(教程含历年试题解析)

系列文章 《GESP系列教程之 什么是GESP?》 《GESP 认证标准之 Python 编程一级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程二级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程三级标准(考试大纲与要求含考试真题)》 《GESP …

6.1 Elasticsearch-Lucene 索引文件结构:tim、tip、doc、pos、pay

6.1 Elasticsearch-Lucene 索引文件结构:tim、tip、doc、pos、pay Elasticsearch 的搜索性能之所以能在 PB 级别数据量下仍保持毫秒级响应,核心依赖是 Lucene 的倒排索引文件格式。一个分片(shard)本质上就是 Lucene 的一个索引目…

GESP Python 编程一级教材之 12 神奇的画笔turtle绘图,掌握图形库 turtle 的主要功能,使用 turtle 进行绘图(教程含历年试题解析)

系列文章 《GESP系列教程之 什么是GESP?》 《GESP 认证标准之 Python 编程一级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程二级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程三级标准(考试大纲与要求含考试真题)》 《GESP …

微信小程序 PHP_uniapp的社区老人服务管理系统_lz9wo71q

微信小程序 PHP_uniapp 社区老人服务管理系统摘要 该系统基于微信小程序和 PHP_uniapp 技术栈开发,旨在为社区老年人提供便捷的线上服务管理平台。通过整合社区资源,实现服务需求对接、健康监测、活动组织等功能,提升老年人生活质量。 技术架…

GESP Python 编程一级教材之 13 掌握模块的导入方法(教程含历年试题解析)

系列文章 《GESP系列教程之 什么是GESP?》 《GESP 认证标准之 Python 编程一级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程二级标准(考试大纲与要求含考试真题)》 《GESP 认证标准之 Python 编程三级标准(考试大纲与要求含考试真题)》 《GESP …

玫瑰克隆AI工具:深耕小红书生态的爆款创作赋能利器

玫瑰克隆AI工具的核心定位,是专为小红书内容生态打造的“爆款逻辑拆解原创内容赋能”AI辅助创作系统。它区别于泛用型AI文案工具,深耕小红书平台规则、用户偏好与流量机制,以技术驱动破解创作者的核心痛点,助力不同层级创作者从“…

论文复现:PMSM速度伺服系统的强化学习与最优控制

论文复现:PMSM速度伺服系统的强化学习与最优控制 以下是基于论文提出的控制策略的复现代码,包括模型建立、控制器设计、强化学习算法实现以及仿真验证。代码将分为以下几个部分: 系统建模与参数定义 快速电流环PI控制器 模型降阶与慢速子系统 最优速度环设计与LQR问题 强化…

爆火!9款AI论文工具实测,PaperNex维普一把过!

深夜,你的论文进度条还卡在10%?导师的夺命连环催即将到来,知网维普的查重高墙横亘在前。别慌,这篇2024年最新的“急救指南”,将为你揭秘9款实测有效的AI论文神器,特别是能让你在最后关头“一把过”的王牌工…

多智能体实战指南:9种模式打造高效AI应用

想要构建一个智能体应用,最重要的是什么?可能很多人首先会想到要选择一个性能强大的大模型。 这个回答没错,毕竟当前的LLM Based Agent哪能缺少LLM的支撑。但事实却是,很多基于先进大模型构建的智能体没能体现出应用效果&#xff…

微信小程序 PHP_uniapp的音乐播放器排行榜系统的设计与实现_5h11g380

微信小程序音乐播放器排行榜系统设计与实现 该系统基于微信小程序平台,采用PHP后端与Uniapp前端框架开发,实现了一个功能完善的音乐播放器排行榜系统。系统设计分为前端展示、后端数据处理和数据库管理三大模块。 前端采用Uniapp跨平台框架开发&#xff…

收藏必备!国产最强大模型GLM-4-Plus评测:打破国外垄断,三大场景解决程序员痛点!

本文介绍了智谱AI推出的GLM-4-Plus大模型,该模型在SuperBench评测中排名第三,打破了国外模型垄断前三的局面。文章详细展示了GLM-4-Plus如何帮助程序员解决代码编写、理解和错误排查三大痛点,介绍了其强大的文件分析功能,并讲解了…

收藏!80%的人正在浪费大模型革命!这份产品经理转型指南请务必收藏

文章揭示大模型产品领域现状:真正的"神级产品经理"尚未出现,而80%的人正在用错误方式转型。作者剖析了四种典型错误:传统C端思维套用ChatBot、迷信专家call而非实际建模、高管只做PMO不学模型、O2O老兵只关注KPI和投流。强调大模型…

微信小程序 PHP_uniapp校园外卖跑腿骑手在线接单系统 _f8zv38dg

系统概述 微信小程序 PHP_uniapp校园外卖跑腿骑单系统是一款基于Uniapp框架和PHP后端开发的校园生活服务应用,旨在为学生和骑手提供高效的外卖配送与跑腿服务。系统支持多端兼容(微信小程序、H5、App),涵盖用户下单、骑手接单、订…

从零到 AI 产品经理:3 个必备技能缩短你的转型路径

不废话,直接上排期表:三天看“大盘”,把大模型这个行业的生态位、AI 产品经理的价值机会和类型搞清楚一星期“吃透”大模型底层原理:不学算法,但是必须懂模型怎么作业、应用方式和能力边界30 天每周跑通一个项目&#…

【Agent实战】Anthropic Skills、MCP与LangGraph的工程实践

摘要 随着大语言模型(LLM)应用从简单的Chatbot向自主智能体(Autonomous Agents)演进,如何管理复杂的任务上下文、标准化的工具调用以及确定性的业务流程,成为了系统设计的核心挑战。Anthropic 推出的 Skills 规范,结合 Model Context Protocol (MCP) 与 Function Calli…

传统PM转型大模型产品:避开90%人踩过的认知误区“ 解析

近期聊了不少希望转型大模型的PM,发现90%的人踩坑的真相是:用错方法论,把时间砸在错误方向。 一个很普遍的转型误区:技术思维陷阱**,觉得大模型高大上,先要了解算法底层逻辑才能入局**。然而2个转型实际案…