我自己的原文哦~ https://blog.51cto.com/whaosoft/13905076
#Executor-Workers架构
图解Vllm V1系列2
本文详细介绍了vllm v1的Executor-Workers架构,包括Executor的四种类型(mp、ray、uni、external_launcher)及其适用场景,以及通过图解展示了Executor与Workers之间的数据传输机制和整体架构,帮助读者理解vllm v1在分布式推理中的核心设计和运作方式。
前文:图解Vllm V1系列1:整体流程
在前文中,我们讨论了 vllm v1 在 offline batching / online serving 这两种场景下的整体运作流程,以offline batching为例:
整体上来看:
vllm v1将请求的pre-process和输出结果的post-process与实际的推理过程拆分在2个不同的进程中(process0, process1)。
Client负责请求的pre-process和输出结果的post-process,EngineCore负责实际的推理过程,不同进程间使用ZMQ来通信数据。
对于offline batching和online serving来说,它们会选取不同类型的Client进行运作,但是它们的EngineCore部分运作基本是一致的,如上图所示。
通过这样的进程拆分,在更好实现cpu和gpu运作的overlap的同时,也将各种模型复杂的前置和后置处理模块化,统一交给processor和output_processor进行管理。
本文我们来关注上图中的Executor部分,也就是管控模型分布式推理的核心部分,我们关注的是它的整体架构和初始化过程,而它实际执行推理的细节,我们留到后续文章细说。
一、Executor的类型
在vllm中,Executor一共有4种类型,由配置参数--distributed-executor-backend决定,相关的代码和文档参见:
代码:
- 决定executor的类型:https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/config.py#L1465
- 根据executor的类型,import具体的executor:https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/v1/executor/abstract.py#L25
文档:
- https://docs.vllm.ai/en/stable/serving/engine_args.html
- https://docs.vllm.ai/en/stable/serving/distributed_serving.html#distributed-inference-and-serving
对于--distributed-executor-backend,默认情况下为None,你当然也可以手动指定。在默认情况下,vllm会根据你的分布式配置(world_size)和你所使用的平台特征(cuda、neuron、是否安装/初始化/配置了ray等)来自动决定--distributed-executor-backend的取值。
我们来简单介绍下这4种类型的Executor。
(1)mp:MultiprocExecutor
- 适用场景:单机多卡。当单机的卡数满足分布式配置,且没有正在运行的ray pg时,默认使用mp
- 在mp的情况下,Executor成为一个主进程,其下的若干个workers构成了它的子进程们
(2)ray:RayDistributedExecutor
- 适用场景:多机多卡
- 在ray的情况下,Executor成为一个ray driver process,其下管控着若干worker process
(3)uni:UniProcExecutor
- 适用场景:单卡或 Neuron环境
(4)external_launcher:ExecutorWithExternalLauncher
- 适用场景:想要用自定义的外部工具(例如Slurm)来做分布式管理
注意:以上的“适用场景”描述的是一般情况,更具体的细节,请参见“决定executor类型”的相关代码。
在本文中,我们将以mp: MultiProcExecutor进行讲解,并假设这里的分布式配置仅用了tp,其余的细节留给用户自行阅读。
二、Executor -> Workers
2.1 整体架构-官网版
我们先来看下官方给出的Executor-Workers架构图。
https://blog.vllm.ai/2025/01/27/v1-alpha-release.html
上图右侧刻画了V1的架构:
- Scheduler和Executor都位于EngineCoreProc所在的进程上。如本文第一章offline batching的流程图所示,Scheduler决定单次调度步骤中要送去推理的请求,并将这些请求发送给Executor。
- 一个Executor下管理着若干workers,每个workers位于独立的进程上,可以理解成一个workers占据着一张卡
- Executor负责把请求broadcast到各个workers上
- 各个workers接收到请求,负责执行实际的推理过程,并将推理结果返回给Executor。
相比于V0,V1这种架构设计的优势在于:在V0中,worker0既要负责调度和数据分发、又要负责实际的推理计算。如此一来,各个workers间存在负载不均的问题,而worker0将成为性能的瓶颈。而V1通过拆分【调度】和【计算】过程解决了这个问题。
2.2 整体架构-细节版
现在我们已经通过vllm官方的示例图,初步了解了V1下Executor-Workers的架构,现在我们扩展这张架构图,来看更多的细节,为了画图简明,这里我们只展示了其中1个worker,没有画出全部workers:
上图展示的是使用MultiprocExecutor下的架构,如前文所说,该类型的Executor常被用于单机多卡的推理场景,我们按照从上到下,从左到右的顺序来解读上图。
1、MultiprocExecutor和Scheduler都位于EngineCoreProc所在的进程中。Scheduler负责决定单次调度步骤中,要送去推理的reqs,并将这些reqs传递给MultiprocExecutor。
2、在MultiprocExecutor上,我们将创建一个rpc_broadcast_mq队列:
- 该队列存储着Executor要broadcast给各个workers的【小数据(<=10MB)】,而【大数据(>10MB)】则不会进入此队列,而是通过zmq socket进行传输
- 每条数据可以被粗糙理解成是(method, data)的形式,data = 数据本身,method=你期望worker上调用什么样的方法来处理这条数据。
- 针对这个队列、以及大小数据的传输细节,我们将在本文第三部分详细介绍。
3、在MultiProcExecutor上,通过make_worker_process创建子进程:
- 每个进程占据一张卡,每个进程上运行着一个worker实例
- 在创建每个子进程时,我们会将rpc_broadcast_mq_handler,也就是输入队列的句柄也传递给子进程,这里你可以粗糙将“handler(句柄)”理解成是一个“地址”,有了这个地址,每个子进程才知道要去哪里找到并【连接】这个队列,以此读到队列中的数据。相关细节我们同样在后文给出。
4、每个Worker实例又由以下几部分组成:
- WorkerWrapper,每个Worker实例都有属于自己的WorkerWrapper。你可以将它形象理解成是一个worker的manager,它负责管控一个worker的生命周期(创建->销毁)、所占资源、扩展功能(例如在rlhf场景下的一些功能)等等。
- Worker,真正的Worker实例,它上面维护着两个重要队列:
- rpc_broadcast_mq:正如3中所说,单个worker通过rpc_broadcast_mq_handler这个句柄,连接上了Executor的rpc_broadcast_mq队列,这样它就能从这个队列中读取(method, data)数据。注意,这里说的是【连接】而不是创建,为了强调这一点,图中单worker上的该队列用虚线表示。
- worker_response_mq:单个worker【创建】的、用于存放这个worker上推理输出结果的队列。同样也会产出worker_response_mq_handler这个句柄。后续这个句柄将通过zmq socket传送给Executor,让Executor可以连接上这个队列,这样Executor就可以获得每个worker的输出结果。
- ModelRunner,一个worker实例下维护着一个ModelRunner实例,这个实例上维护着模型权重分片(model weights sharding)、这块卡上的kv_caches、attn_backend等一系列的具体模型信息,它将最终负责模型权重的加载(load_model),并执行实际的推理过程。
5、连接Executor和Worker的ZMQ sockets:
- Executor和Worker分属不同的进程,这里依然采用ZMQ sockets做进程间的通信。
- 这里其实创建了多个不同socket(为了表达简便,我统一画成ZMQ sockets),每个socket会用于不同内容的通信,例如:
- ready_socket:worker进程向Executor发送ready信号 + worker_broadcast_mq_handler
- local_socket:如前文所说,**除了使用上述的2个队列做Executor->Worker间的输入输出通信外,我们还会直接使用local_socket做输入输出通信。前者用于单机内快速通信较小的数据(<=10MB),后者用于通信大数据(>10MB)**。我们会在后文细说这一点。
- 等等
**6、worker_busy_loop()**:
在worker上启动busy loop,持续监听Executor发送的数据、做推理、并将推理结果持续返回给Executor。这样一来,这个worker就无限运转起来了,除非收到用户信号,显式终止掉这个worker,否则这个busy loop不会停止。
到此为止,我们简单总结一下在Executor->Workers的初始化环节都做了什么事:
- 首先,按照上图所示,创建了Executor->Workers架构,特别注意上述2个输入输出队列的初始化和连接。
- 对于每个Worker,我们通过init_device(),将它绑到指定的卡上,并对它做分布式环境初始化(即制定它的分布式通信group)
- 对于每个worker,我们通过load_model(),当ModelRunner实际去加载这个worker所要的模型分片
- 在每个worker上启动run_busy_loop(),让worker持续不断地运转起来。 更多的细节,请大家自行阅读源码。
接下来,我们着重来讨论这rpc_broadcast_mq和worker_response_mq这两个输入输出队列。
三、Executor与Worker间的数据传输机制
我们先快速回顾一下上文的内容:
(1)在我们的例子中,Executor的具体类型是MultiprocExecutor,它一般适用于单机多卡推理。
(2)Executor和Worker分属不同的进程,Executor需要把输入数据broadcast到Worker上,Worker需要把推理的输出结果返回给Executor。
(3)对于小数据(<=10MB),vllm使用rpc_broadcast_mq和worker_response_mq来做数据传输,这两个队列的本质是ShmRingBuffer(环形共享缓存),其中Shm即我们熟知的shared_memory,而ring是使用环形的方式往shm中读写数据(看不懂也没关系,我们马上来说细节)。
(4)对于大数据(>10MB),vllm使用zmq socket来做数据传输。
为什么要设计2种不同的进程间通信机制,来分别处理【小数据】和【大数据】呢?这里简单说几个我能想到的原因:
(1)首先,通过shm的方式读写数据时,不同的进程都从同一块共享内存(shm)上直接读取,这样数据不需要从一个进程的地址空间复制到另一个进程的的地址空间,也就是可以实现数据的“零拷贝访问”
(2)其次,通过shm的方式读写数据时,可以避免网络协议栈和数据重复写入的开销,可以实现更高效、更快的数据访问。
(3)那么,既然shm这么好,为什么只让【小数据】使用它,而让【大数据】走zmq socket呢?这是因为shm是一块固定的内存大小,一旦预分配好,就不能被改变了。在实际使用场景中,可能需要传输的数据量本身就不大,只是会偶发出现一些【大数据】传输的情况,因此我们没必要预留更大的shm空间,来应对这些只是偶发情况,这样会造成内存的浪费。所以我们额外使用zmq socket来处理这些偶发情况。
3.1 ShmRingBuffer(共享环形缓存)
我们先来看小数据传输的实现机制,相关代码参见:
https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/distributed/device_communicators/shm_broadcast.py#L44
假设我们现在只使用tp,即一个Executor下有若干tp workers,那么:
- Scheduler生成单次调度的结果,将这些结果传递给Executor,我们称单次调度的结果为一个chunk
- Executor拿到单次调度的结果,写入rpc_broadcast_mq(本质是ShmRingBuffer)中
- 这些tp workers都需要从rpc_broadcast_mq读取这个chunk(每个tp worker的输入是相同的)
- 各个tp workers执行推理,并将推理结果写入各自维护的worker_broadcast_mq(本质是ShmRingBuffer)中。
- Scheduler继续生成单次调度结果(chunk),重复以上步骤。
- 不难发现,对于一个chunk,我们总有1个writer,和1个或若干个readers,例如:
- rpc_broadcast_mq的chunk中,writer = Executor,readers = tp workers
- worker_broadcast_mq的chunk中,writer = 某个tp worker,reader = Executor
现在,让我们将ShmRingBuffer想象成是一个存储站:
- 这个存储站中有max_chunks个柜子,每个柜子用于存储一块数据(chunk),max_chunk默认值为10
- 每个柜子的最大数据存储量为max_chunk_bytes,该值当前默认为10MB
- 每个柜子上有一个面板(metadata),这个面板上有1 + n_reader个指示灯。其中1这个指示灯代表written_flag,即用于指示writer是否把chunk塞进了柜子(写入完毕),n_reader个指示灯代表reader_flags,分别表示这些readers是否已经将这个chunk读取完毕。
- 由此可知,对于一个柜子,只有当writer写入完毕后,readers才可以去读。只有当所有readers都读取完毕后,这个柜子里的chunk才可以被“废弃”,也就是这个柜子才可以重新回到“可写入”的状态,让writer写入新数据。
- Scheduler在做一轮又一轮的调度,产出一个又一个的chunk,那么这些chunk就按照顺序,依次装入这些柜子中,当这10个柜子的数据都被轮番用过以后,下一次再来新chunk前,就从0号柜开始复用起(当然要按照上条所说的,检查该柜子是否达到可复用状态),这种环形使用的方式,称之为“ring”。
有了以上这个形象的理解,现在我们再回过头来看vllm代码中的这部分注释,就不难读懂了,而关于代码的更多细节,请大家自行阅读源码:
3.2 zmq socket
正如前文所说,在Executor和Worker间做大数据(>10MB)的传输时,可以使用zmq socket,这块就是传统的zmq socket构建流程了,没有太多可说的。这里我们想结合worker.worker_busy_loop()(也就是一个worker持续读取输入、进行推理、写入输出)的过程,来具体看一下shm和zmq socket是如何配合运作的。
worker_busy_loop()入口:
- https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/v1/executor/multiproc_executor.py#L362
从代码中我们发现一件有趣的事:这里好像只从shm上读取数据,并没有通过zmq socket呀!不要紧,我们现在深入rpc_broadcast_mq.dequeue()中一探究竟。
rpc_broadcast_mq.dequeue():
- https://github.com/vllm-project/vllm/blob/refs/tags/v0.8.2/vllm/distributed/device_communicators/shm_broadcast.py#L462
整体上说:
- 当一个chunk过来时,我们会先检查它的大小:
- 如果<=10MB,则装入shm对应的柜子中
- 如果 > 10MB,则先在shm对应的柜子中记上一笔(buf[0]==1),然后再通过zmq sokcet去send这份数据
- 以上过程不在所提供的代码截图中,需要大家自行找相关代码阅读
- 接着,我们会根据这个柜子的标志buf[0]是否为1,来检测对应的chunk是装在柜子里,还是通过zmq socket发送了。如果是前者,那么直接从shm读;如果是后者,那么就通过zmq socket做receive。
- 最后,如果当前chunk是大数据,虽然它不会装在对应的柜子里,但我们也会认为这个柜子已经被使用过。这样后一个chunk来的时候,它不会使用当前的柜子,而是使用下一个可用的柜子。
你可能会发现,在上述dequeue的代码中,存在2种zmq socket:local_socket和remote_socket,前者用于单机内的通信(writer和readers在一台node内),后者用于多机间的通信(writer和readers不在一台内)。由于我们当前都是以MultiprocExecutor这种单机场景为例的,所以我们提到的zmq socket都是指local_socket。
好,关于Executor->Worker架构的介绍就到这里了,大家可以配合本文,自行阅读源码,获取更多细节。在这个系列后面的文章中,我们还会看到更多Executor-> Worker配合运行的更多例子。
#T2I-R1
文生图也能这样玩?T2I-R1:把R1的推理范式应用到文生图任务!
文生图进入R1时刻:港中文MMLab发布T2I-R1。
论文:https://arxiv.org/pdf/2505.00703
代码:https://github.com/CaraJ7/T2I-R1
最近的大语言模型(LLMs)如OpenAI o1和DeepSeek-R1,已经在数学和编程等领域展示了相当强的推理能力。通过强化学习(RL),这些模型在提供答案之前使用全面的思维链(CoT)逐步分析问题,显著提高了输出准确性。最近也有工作将这种形式拓展到图片理解的多模态大模型中(LMMs)中。然而,这种CoT推理策略如何应用于自回归的图片生成领域仍然处于探索阶段,我们之前的工作Image Generation with CoT(https://github.com/ZiyuGuo99/Image-Generation-CoT)对这一领域有过首次初步的尝试。
与图片理解不同,图片生成任务需要跨模态的文本与图片的对齐以及细粒度的视觉细节的生成。为此,我们提出了适用于图片生成的两个不同层次的CoT推理:
Semantic-CoT
Semantic-CoT 是对于要生成的图像的文本推理,在图像生成之前进行。
负责设计图像的全局结构,例如每个对象的外观和位置。
优化Semantic-CoT可以在图片Token的生成之前显式地对于Prompt进行规划和推理,使生成更容易。
Token-CoT
- Token-CoT是图片Token的逐块的生成过程。这个过程可以被视为一种CoT形式,因为它同样是在离散空间中基于所有先前的Token输出后续的Token,与文本CoT类似。
- Token-CoT更专注于底层的细节,比如像素的生成和维持相邻Patch之间的视觉连贯性。
- 优化Token-CoT可以提高生成图片的质量以及Prompt与生成图片之间的对齐。
然而,尽管认识到这两个层次的CoT,一个关键问题仍然存在:我们怎么能协调与融合它们? 当前主流的自回归图片生成模型如VAR完全基于生成目标进行训练,缺乏Semantic-CoT推理所需的显式文本理解。虽然引入一个专门用于提示解释的独立模型(例如LLM)在技术上是可行的,但这种方法会显著增加计算成本、复杂性和部署的困难。最近,出现了一种将视觉理解和生成合并到单一模型中的趋势。在LMMs的基础上,这些统一LMMs(ULMs)不仅可以理解视觉输入,还可以从文本提示生成图像。然而,它们的两种能力仍然是解耦的,通常在两个独立阶段进行预训练,没有明确证据表明理解能力可以使生成受益。鉴于这些潜力和问题,我们从一个ULM(Janus-Pro)开始,增强它以将Semantic-CoT以及Token-CoT统一到一个框架中用于文本生成图像:
我们提出了BiCoT-GRPO,一种使用强化学习的方法来联合优化ULM的两个层次的CoT:
我们首先指示ULM基于Image Prompt来想象和规划图像来获得Semantic-CoT。然后,我们将Image Prompt和Semantic-CoT重新输入ULM来生成图片以获得Token-CoT。我们对于一个Image Prompt生成多组Semantic-CoT和Token-CoT,对于得到的图像计算组内的相对奖励,从而使用GRPO的方法来在一个训练迭代内,同时优化两个层次的CoT。
与图片的理解任务不同,理解任务有明确定义的奖励规则,图像生成中不存在这样的标准化的规则。为此,我们提出使用多个不同的视觉专家模型的集成来作为奖励模型。这种奖励设计有两个关键的目的:
- 它从多个维度评估生成的图像以确保可靠的质量评估
- 作为一种正则化方法来防止ULM过拟合到某个单一的奖励模型
根据我们提出的方法,我们获得了T2I-R1,这是第一个基于强化学习的推理增强的文生图模型。根据T2I-R1生成的图片,我们发现我们的方法使模型能够通过推理Image Prompt背后的真实意图来生成更符合人类期望的结果,并在处理不寻常场景时展现出增强的鲁棒性。
同时,定量的实验结果也表明了我们方法的有效性。T2I-R1在T2I-CompBench和WISE的Benchmark上分别比baseline模型提高了13%和19%的性能,在多个子任务上甚至超越了之前最先进的模型FLUX.1。
#One Step Diffusion Via ShortCut Models论文解读
AIGC新手,内容理解如有不对请多多指正。
原文:One Step Diffusion via Shortcut Models
github:GitHub - kvfrans/shortcut-models
摘要
为了缓解目前diffusion架构+flow matching生成速度慢且训练阶段复杂的问题提出了一个叫shortcut model的模型,整个训练过程采用单一网络、单一训练阶段。condition包括当前噪声强度,还取决于stepsize(为了在去噪过程中直接跳过),这个方法在作者的实验中比蒸馏的方法好,并且在推理的时候可以改变step budgets。
之前SD3采直流匹配训练,但是仍然需要28步,这篇论文在首页放了一张效果图,效果看起来很惊艳。
初步效果图
整个网络设置是端到端且只需要一次训练就可以完成一个one-step模型,不像之前的关于蒸馏的工作(参考Progressive Distillation for Fast Sampling of Diffusion Models、http://arxiv.org/abs/2211.12039、Relational Diffusion Distillation for Efficient Image Generation,这三个工作都基于教师-学生来蒸馏,通过多阶段的训练来逐渐折半DDIM的采样步数)。
前置知识
Flow-matching
流匹配的内容在网络上的很多博客都有讲解,这部分就简单带过一下。
流匹配实际上就是通过学习将噪声转化为数据的常微分方程(ODE)来解决生成模型问题,在直流匹配中,整个模型就把真实图像的概率分布和噪声的概率分布之间的路径当作一条直线进行传输,在给定 x0 和 x1 的情况下,速度 vt 是完全确定的。但是,如果只给定 xt,就会有多个可信的配对(x0、x1),因此速度会有不同的值,这就使得 vt 成为一个随机变量。Flow-matching模型就是用来估计预期值在xt条件下的vt是多少,然后vt是xt处所有可信速度的平均值。最后可以通过对随机采样的噪声 x0 和数据 x1 对的经验速度进行回归来优化流量模型。
例子就是直流匹配
这个速度vt就是直接用xt对t求导,就得到了x1-x0,然后整个模型优化就靠下面这个损失函数。
直流匹配的损失函数
实际上就是用回归的损失去尽量让预测的速度能够符合直流匹配定义的速度。
然后去噪过程就是从流量模型中采样,首先从正态分布中采样一个噪声点 x0。然后根据学习到流模型从x0到x1迭代更新该点,整个过程可通过在较小的离散时间间隔内进行欧拉采样来近似实现,因为是直线传输。
为什么提出ShortCut models?
作者通过一个实验去研究了完美训练的ODE在步数减少之后的缺陷,具体来说就是步长有限的情况下,还是很难做到能够将噪声分布确定性地映射到我们需要的数据分布。
作者做的实验,这个图示还是很清晰的,仅给定 xt,vt虽然是根据直流的路线去学习的,但是学习得到的vt是存在固有的不确定性的,vt是指向数据点平均值的,直到缩减到一步,一步的话所有的vt几乎是指向一个点,并不能对应原始数据分布,多样性完全崩塌了
流量匹配学习预测从 xt 到数据的平均方向,因此跟随预测的步长越大,将跳转到多个数据点的平均值。在 t=0 时,模型接收纯噪声输入,并且(x0,x1)在训练过程中随机配对,因此 t=0 时的预测速度指向数据集平均值。因此,即使在流量匹配目标的最优状态下,对于任何多模式数据分布,一步生成都会失败。这段是作者原话,感觉说得蛮清晰,就不加个人理解了。
ShortCut Models
insight:可以训练一个支持不同sampling budgets的一个模型,以时间步长t和步长d为作为条件。那么就顺势提出了下面这个公式。
shortcut models的核心公式
这个s就是输入Xt,t,d之后的出来的捷径,得到这个路径之后就可以直接让Xt从这个s出发跳步得到Xt+d,OK,那么整个model的训练目标就很明确了,就是通过shortcut model去学习这个s,条件是Xt,t,d。其实整个公式就是直流匹配的跳步模式,当d≈0的时候,就是flow-matching的训练模式,s就直接退化成了v。
那么要学的东西出来了,用什么去约束呢?第一种方法当然就是用小步长去接近flow-matching的forward过程,但是这样做的话训练成本也还是很高,尤其是对直接端到端训练来说,并且小步长实际上对flow-matching的改进不是很大。第二种就是本文用的方法,直接用shortcut model自己的性质,就是一个s步等于两个s/2步。也就是以下公式。
shortcut等价模型
初步看这个公式可能会疑惑为什么会除以2,请注意,上一个公式在s求出来之后还需要乘d,所以s其实不是最终路程,最终的路程是s*d,而整个式子左边的步长为2d,路程相同的情况下,两边同时除以2d才得出来右边等式的1/2系数。
d>0的时候就直接用这个公式,d=0就直接用流匹配去训练。整体流程如下
shortcut对flow-matching的优化
其实就是将flow-matching当作连续的一条线,shortcut直接输入了步长,然后网络获得步长之后直接去获得应道到路径上的哪个路径点,就是上面图左边曲线的黄色部分。整个训练过程把flow-matching综合起来构成了下面的损失函数:
总体损失函数
上述目标学习的是从噪声到数据的映射,在任何步长序列下查询时都是一致的,包括直接在单步中查询。目标中的流量匹配部分将捷径模型建立在小步长的基础上,以匹配经验速度样本。这就确保了捷径模型在多步长查询时具有基础生成能力,这与等效的流量匹配模型完全相同。第二部分的话,通过串联两个较小shortcut的序列,为较大步长构建适当的目标。这样,生成能力就从多步到少步再到一步。综合目标可通过单一模型和单一端到端训练运行进行联合训练。
训练细节
名词定义:经验目标就是对应损失函数第一项需要的目标,一致性目标就是对应损失函数第二项所需要的目标
当 d → 0 时,s等同于vt。因此可以使用flow-matching的损失来训练d=0时的捷径模型,即随机抽样 (x0, x1) 对并拟合vt的期望值。这个项可以看作是小步s的基础,以匹配数据去噪ODE,然后对t ∼ U (0, 1) 进行均匀采样。为了限制复合误差,并且限制引导路径的总长度。因此,我们选择了一种二元递归模型,即用两条捷径来构建一条两倍大的捷径。
然后确定一个步数 M 来表示逼近 ODE 的最小时间单位;在实验中使用了 128 步。根据 d∈ (1/128, 1/64 ... 1/2, 1),这将产生 log2(128) + 1 = 8 种可能的捷径长度。在每个训练步骤中,我们对 xt、t 和随机 d < 1 进行采样,然后使用shortcut连续进行两步。然后将这两步的并集作为目标,并且在2d处训练模型。
将 1-k 个经验目标与k个一致性目标的比例结合起来,构建一个训练批次。k=1/4是合理的。其实这部分也很好理解,因为这个端到端模型实际上就是需要先训练一个flow-matching较好的模型,然后第二项只是在flow-matching的基础上进行优化,如果flow-matching训练得不好,后一项自然训练不好,因为s_target是需要从flow-matching模型中采样的,后一项只能在d=0训练的基础模型上去拟合这个模型,本质上shortcut还是一个教师-学生的思路,但是不同于之前教师和学生都是模型,shortcut将教师-学生拆分为两个损失函数去训练同一个模型,从而实现了端到端。
CFG设定:评估 d = 0 时的捷径模型时使用 CFG,而在其他情况下则放弃 CFG。CFG 在捷径模型中的一个局限性是,必须在训练前指定 CFG 比例。
EMA:用EMA去从d=0的模型上生成d=1的一致性目标,本质上就是平滑一下误差。
其他就是一些网络设置,这里就不一一阐述了,有兴趣可以自己查看一下原论文。
实验结果
FID-50K分数评估
FID-50K分数评估
可以看到在端到端的训练框架中,shortcut models的FID-50k是SOTA,但是相对于PD的蒸馏方式来说,在一步蒸馏中效果还是有待提高。
对ShortCut提出需解决问题的验证
FID下降趋势
在文章开投我们就提到了这篇论文的insight,他是为了缓解flow-matching在步数极低的情况下的崩塌而提出了,这个实验也证明了这一点,在1步模型中,Shortcut的表现完全暴打直接用flow-matching训练的diffusion(但实际上这个对比没有什么特别大意义,flow-matching确实就不适合一步训练,这个问题SD3当时也提出来了)。
作者在后续甚至验证了shortcut在其他领域的鲁棒性,确实是一项非常完善的工作,有其他领域的读者可以去看下原文。
总结
shortcut models确实提供了一个直接在flow-matching上蒸馏的好办法,但是训练过程中的参数设定个人感觉还是靠多种尝试,例如K的选取或许会较大程度影响shortcut models的发挥。反观多阶段的训练方法,至少多阶段确保了一个训练得较为完善的教师模型能够作为参考,而shortcut models如果参数设置不对,flow-mathcing的基础模型可能会不够完善,进而倒是损失第二项会出现较大程度的累计误差。
其次,作者本人也提到了,虽然shortcut能够抑制flow-matching直接在1步训练上的崩溃,但是在步数太低的时候仍然和多步采样存在较大的性能差距(不过1步能做到这个程度已经很好了。。。)。
总的来说,这篇论文的工作很完善,也是一个比较新颖的减少采样步数的方案,但是本质上也是蒸馏的一种,并且端到端的训练相比于多阶段的训练确实更依靠经验,一不注意就会训练失败。