传统的大语言模型采用的训练目标是 Next-Token Prediction (NTP),即在位置 t 上预测下一个 token (t+1)。
而 Multi-Token Prediction (MTP) 的核心思想在于:
- 不仅预测下一个 token,而是能够同时预测多个未来的 token。
- 这种方式可以显著提升推理效率。例如,当 n=4(一次预测 4 个 token)时,推理速度可实现约 3 倍的加速。
DeepSeek-V3 借鉴了 Meta FAIR 团队论文 Better & Faster Large Language Models via Multi-token Prediction 中的思路,但在实现上有明显不同:它并不是直接并行预测多个 token,而是保持完整的因果链,以逐层递进的方式预测未来 token。
本文将重点介绍 DeepSeek-V3 中 MTP 的实现。在此之前,我们先回顾一下 Meta FAIR 团队提出的 MTP 思路。
1. MTP 方法
1.1 NTP (Next-token Prediction)
- 传统语言模型的训练目标:给定历史上下文 $x_{1:t}$,预测下一个 token $x_{t+1}$。
- 损失函数是标准的交叉熵:$$ L_1 = -\sum_t \log P_\theta(x_{t+1} | x_{1:t}) $$
- 这种方式虽然简单有效,但只考虑一步预测,容易陷入局部模式学习。
下图是 NTP 示意图,我们以 Qwen2.5-32B 为例,词表大小为 152064,hidden size 为 $d_{model}$=5120 ,num heads 为 40,Transformer block 的层数为 64,假设输入序列长度为 2048。