Masked Attention 在 LLM 训练中的作用与原理

大语言模型(LLM)训练过程中,Masked Attention(掩码注意力) 是一个关键机制,它决定了 模型如何在训练时只利用过去的信息,而不会看到未来的 token。这篇文章将帮助你理解 Masked Attention 的作用、实现方式,以及为什么它能确保当前 token 只依赖于过去的 token,而不会泄露未来的信息。

1. Masked Attention 在 LLM 训练中的作用

在 LLM 训练时,我们通常使用 自回归(Autoregressive) 方式来让模型学习文本的生成。例如,给定输入序列:

"The cat is very"

模型需要预测下一个 token:

"cute"

但是,为了保证模型的生成方式符合自然语言流向,每个 token 只能看到它之前的 token,不能看到未来的 token

Masked Attention 的作用就是:

  • 屏蔽未来的 token,使当前 token 只能关注之前的 token
  • 保证训练阶段的注意力机制符合推理时的因果(causal)生成方式
  • 防止信息泄露,让模型学会自回归生成文本

如果没有 Masked Attention,模型在训练时可以“偷看”未来的 token,导致它学到的规律无法泛化到推理阶段,从而影响文本生成的效果。

举例说明

假设输入是 "The cat is cute",模型按 token 级别计算注意力:

(1) 没有 Mask(BERT 方式)
TokenThecatiscute
The
cat
is
cute

每个 token 都能看到整个句子,适用于 BERT 这种双向模型。

(2) 有 Mask(GPT 方式)
TokenThecatiscute
The
cat
is
cute

每个 token 只能看到它自己及之前的 token,保证训练和推理时的生成顺序一致。

2. Masked Attention 的工作原理

 在标准的 自注意力(Self-Attention) 机制中,注意力分数是这样计算的:

A = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right)

其中:

  • Q, K, V  是 Query(查询)、Key(键)和 Value(值)矩阵

  • Q K^T 计算所有 token 之间的相似度

  • 如果不做 Masking,每个 token 都能看到所有的 token

而在 Masked Attention 中,我们会使用一个 上三角掩码(Upper Triangular Mask),使得未来的 token 不能影响当前 token:

S' = \frac{Q K^T}{\sqrt{d_k}} + \text{mask}

Mask 是一个 上三角矩阵,其中:

  • 未来 token 的位置填充 -\infty,确保 softmax 之后它们的注意力权重为 0

  • 只允许关注当前 token 及之前的 token

例如,假设有 4 个 token:

\begin{bmatrix} s_{1,1} & -\infty & -\infty & -\infty \\ s_{2,1} & s_{2,2} & -\infty & -\infty \\ s_{3,1} & s_{3,2} & s_{3,3} & -\infty \\ s_{4,1} & s_{4,2} & s_{4,3} & s_{4,4} \end{bmatrix}

经过 softmax 之后:

A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ \text{non-zero} & \text{non-zero} & 0 & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & \text{non-zero} \end{bmatrix}

最终,每个 token 只会关注它自己和它之前的 token,完全忽略未来的 token!

3. Masked Attention 计算下三角部分的值时,如何保证未来信息不会泄露?

换句话说,我们需要证明 Masked Attention 计算出的下三角部分的值(即历史 token 之间的注意力分数)不会受到未来 token 的影响

1. 问题重述

Masked Attention 的核心计算是:

\text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}} + \text{mask}) V

其中:

  • Q, K, V 是整个序列的矩阵。

  • QK^T计算的是所有 token 之间的注意力分数。

  • Mask 确保 softmax 后未来 token 的注意力分数变为 0。

这个问题可以分解成两个关键点:

  1. 未来 token 是否影响了下三角部分的 Q 或 K?

  2. 即使未来 token 参与了 Q, K 计算,为什么它们不会影响下三角的注意力分数?

2. 未来 token 是否影响了 Q 或 K?

我们先看 Transformer 计算 Q, K, V 的方式:

Q = X W_Q, \quad K = X W_K, \quad V = X W_V

这里:

  • X 是整个输入序列的表示。

  • W_Q, W_K, W_V是相同的投影矩阵,作用于所有 token。

由于 每个 token 的 Q, K, V 只取决于它自己,并不会在计算时使用未来 token 的信息,所以:

  • 计算第 i 个 token 的 Q_i, K_i, V_i时,并没有用到 X_{i+1}, X_{i+2}, \dots,所以未来 token 并不会影响当前 token 的 Q, K, V

结论 1未来 token 不会影响当前 token 的 Q 和 K。

3. Masked Attention 如何确保下三角部分不包含未来信息?

即使 Q, K 没有未来信息,我们仍然要证明 计算出的注意力分数不会受到未来信息影响

我们来看注意力计算:

\frac{Q K^T}{\sqrt{d_k}}

这是一个 所有 token 之间的相似度矩阵,即:

S = \begin{bmatrix} Q_1 \cdot K_1^T & Q_1 \cdot K_2^T & Q_1 \cdot K_3^T & Q_1 \cdot K_4^T \\ Q_2 \cdot K_1^T & Q_2 \cdot K_2^T & Q_2 \cdot K_3^T & Q_2 \cdot K_4^T \\ Q_3 \cdot K_1^T & Q_3 \cdot K_2^T & Q_3 \cdot K_3^T & Q_3 \cdot K_4^T \\ Q_4 \cdot K_1^T & Q_4 \cdot K_2^T & Q_4 \cdot K_3^T & Q_4 \cdot K_4^T \end{bmatrix}

然后,我们应用 因果 Mask(Causal Mask)

S' = S + \text{mask}

Mask 让右上角(未来 token 相关的部分)变成 -\infty

\begin{bmatrix} S_{1,1} & -\infty & -\infty & -\infty \\ S_{2,1} & S_{2,2} & -\infty & -\infty \\ S_{3,1} & S_{3,2} & S_{3,3} & -\infty \\ S_{4,1} & S_{4,2} & S_{4,3} & S_{4,4} \end{bmatrix}

然后计算 softmax:

A = \text{softmax}(S')

由于 e^{-\infty} = 0,所有未来 token 相关的注意力分数都变成 0

A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ \text{non-zero} & \text{non-zero} & 0 & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & \text{non-zero} \end{bmatrix}

最后,我们计算:

\text{Output} = A V

由于未来 token 的注意力权重是 0,它们的 V 在计算中被忽略。因此,下三角部分(历史 token 之间的注意力)完全不受未来 token 影响。

结论 2未来 token 的信息不会影响下三角部分的 Attention 计算。

4. 为什么 Masked Attention 能防止未来信息泄露?

你可能会问:

即使有 Mask,计算 Attention 之前,我们不是还是用到了整个序列的 Q, K, V 吗?未来 token 的 Q, K, V 不是已经算出来了吗?

的确,每个 token 的 Q, K, V 是独立计算的,但 Masked Attention 确保了:

  1. 计算 Q, K, V 时,每个 token 只依赖于它自己的输入

    • Q_i, K_i, V_i只来自 token i,不会用到未来的信息

    • 未来的 token 并不会影响当前 token 的 Q, K, V

  2. Masked Softmax 阻止了未来 token 的影响

    • 虽然 Q, K, V 都计算了,但 Masking 让未来 token 的注意力分数变为 0,确保计算出的 Attention 结果不包含未来信息。

最终,当前 token 只能看到过去的信息,未来的信息被完全屏蔽!

5. 训练时使用 Masked Attention 的必要性

Masked Attention 的一个关键作用是 让训练阶段和推理阶段保持一致

  • 训练时:模型学习如何根据 历史 token 预测 下一个 token,确保生成文本时符合自然语言流向。

  • 推理时:模型生成每个 token 后,仍然只能访问过去的 token,而不会看到未来的 token。

如果 训练时没有 Masked Attention,模型会学习到“作弊”策略,直接利用未来信息进行预测。但在推理时,模型无法“偷看”未来的信息,导致生成质量急剧下降。

6. 结论

Masked Attention 是 LLM 训练的核心机制之一,其作用在于:

  • 确保当前 token 只能访问过去的 token,不会泄露未来信息
  • 让训练阶段与推理阶段保持一致,避免模型在推理时“失效”
  • 利用因果 Mask 让 Transformer 具备自回归能力,学会按序生成文本

Masked Attention 本质上是 Transformer 训练过程中对信息流动的严格约束,它确保了 LLM 能够正确学习自回归生成任务,是大模型高质量文本生成的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/73957.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【自学笔记】PHP语言基础知识点总览-持续更新

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1. PHP 简介2. PHP 环境搭建3. 基本语法变量与常量数据类型运算符 4. 控制结构条件语句循环语句 5. 函数函数定义与调用作用域 6. 数组7. 字符串8. 表单处理9. 会话…

css选择最后结尾的元素DOM

前言 选中最后一个元素&#xff0c;实际使用非常频繁。 解决方案 使用 CSS 提供的选择器&#xff0c;即可完成。 如下代码示例&#xff0c;两种选择器均可实现。 <p>...</p>p:last-child{ background:#ff0000; }p:nth-last-child(1){background:#ff0000; }p&…

Axios 相关的面试题

在跟着视频教程学习项目的时候使用了axios发送请求&#xff0c;但是只是跟着把代码粘贴上去&#xff0c;一些语法规则根本不太清楚&#xff0c;但是根据之前的博客学习了fetch了之后&#xff0c;一看axios的介绍就明白了。所以就直接展示axios的面试题吧 本文主要内容&#xff…

瑞芯微RKRGA(librga)Buffer API 分析

一、Buffer API 简介 在瑞芯微官方的 librga 库的手册中&#xff0c;有两组配置 buffer 的API&#xff1a; importbuffer 方式&#xff1a; importbuffer_virtualaddr importbuffer_physicaladdr importbuffer_fd wrapbuffer 方式&#xff1a; wrapbuffer_virtualaddr wrapb…

C语言:多线程

多线程概述 定义 多线程是指在一个程序中可以同时运行多个不同的执行路径&#xff08;线程&#xff09;&#xff0c;这些线程可以并发或并行执行。并发是指多个线程在宏观上同时执行&#xff0c;但在微观上可能是交替执行的&#xff1b;并行则是指多个线程真正地同时执行&…

Linux线程池实现

1.线程池实现 全部代码&#xff1a;whb-helloworld/113 1.唤醒线程 一个是唤醒全部线程&#xff0c;一个是唤醒一个线程。 void WakeUpAllThread(){LockGuard lockguard(_mutex);if (_sleepernum)_cond.Broadcast();LOG(LogLevel::INFO) << "唤醒所有的休眠线程&q…

微信小程序逆向开发

一.wxapkg文件 如何查看微信小程序包文件&#xff1a; 回退一级 点击进入这个目录 这个就是我们小程序对应的文件 .wxapkg概述 .wxapkg是微信小程序的包文件格式&#xff0c;且其具有独特的结构和加密方式。它不仅包含了小程序的源代码&#xff0c;还包括了图像和其他资源文…

多输入多输出 | Matlab实现CPO-LSTM冠豪猪算法优化长短期记忆神经网络多输入多输出预测

多输入多输出 | Matlab实现CPO-LSTM冠豪猪算法优化长短期记忆神经网络多输入多输出预测 目录 多输入多输出 | Matlab实现CPO-LSTM冠豪猪算法优化长短期记忆神经网络多输入多输出预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 Matlab实现CPO-LSTM冠豪猪算法优化长短期…

视频编码器的抉择:x264、x265、libaom、vvenc 对比测试实验

264、x265、libaom、vvenc 对比测试实验 测试机器配置&#xff1a;Apple M1 Pro -16G编码器版本&#xff08;选择自己编译&#xff09;&#xff1a;所有源码都是当前最新更新的状态&#xff0c;此外各类编码具体的编译过程可参考我的相关系列博客。 编码器GitHubx264git clon…

【二刷代码随想录】双指针-数组相关题型、推荐习题

一、双指针-数组 相关题型与常用思路 1、单个数组 &#xff08;1&#xff09;原地移除元素类 如推荐习题中的&#xff08;1&#xff09;、&#xff08;2&#xff09;、&#xff08;3&#xff09;&#xff0c;都属于此类。引入双指针 pre、last &#xff0c;用 pre 指针表明数…

Level DB --- TableCache

TableCache 是Level DB 中重要的类&#xff0c;Level DB 中多层&#xff08;multi level&#xff09;&#xff0c;且每一层&#xff08;level&#xff09;有多个 key-value file&#xff0c;TableCache正是用来缓存多层以及多层中的file数据&#xff0c;更快速地检索。 table …

搜索-BFS

马上蓝桥杯了&#xff0c;最近刷了广搜&#xff0c;感觉挺有意思的&#xff0c;广搜题类型都差不多&#xff0c;模板也一样&#xff0c;大家写的时候可以直接套模板 这里给大家讲一个比较经典的广搜题-迷宫 题目问问能否走到 (n,m) 位置&#xff0c;假设最后一个点是我们的&…

智能预测维护:让设备“未卜先知”,减少宕机烦恼

智能预测维护:让设备“未卜先知”,减少宕机烦恼 1. 引言:设备维护的痛点与出路 在工业生产和自动化领域,设备故障一直是令人头疼的问题。设备一旦故障,轻则影响生产效率,重则造成严重损失,甚至带来安全隐患。传统的设备维护方式主要有两种: 被动维护(Reactive Maint…

安卓的布局方式

一、RelativeLayout 相对布局 特点&#xff1a;每个组件相对其他的某一个组件进行定位。 (一)主要属性 1、设置和父组件的对齐&#xff1a; alignParentTop &#xff1a; 设置为true&#xff0c;代表和父布局顶部对齐。 其他对齐只需要改变后面的Top为 Left、Right 或者Bottom&…

SSM中药分类管理系统

&#x1f345;点赞收藏关注 → 添加文档最下方联系方式咨询本源代码、数据库&#x1f345; 本人在Java毕业设计领域有多年的经验&#xff0c;陆续会更新更多优质的Java实战项目希望你能有所收获&#xff0c;少走一些弯路。&#x1f345;关注我不迷路&#x1f345; 项目视频 SS…

epoch、batch、batch size、step、iteration深度学习名词含义详细介绍

卷积神经网络训练中的三个核心概念&#xff1a;Epoch、Batch Size 和迭代次数 在深度学习中&#xff0c;理解一些基本的术语非常重要&#xff0c;这些术语对模型的训练过程、效率以及最终性能都有很大影响。以下是一些常见术语的含义介绍&#xff1a; 1. Epoch&#xff08;周…

React(七):Redux

Redux基本使用 纯函数&#xff1a;1.函数内部不能依赖函数外部变量&#xff1b;2.不能产生副作用&#xff0c;在函数内部改变函数外部的变量 React只帮我们解决了DOM的渲染过程&#xff0c;State还是要由我们自己来管理——redux可帮助我们进行管理 Redux三大特点 1.单一数…

《Android低内存设备性能优化实战:深度解析Dalvik虚拟机参数调优》

1. 痛点分析&#xff1a;低内存设备的性能困局 现象描述&#xff1a;大应用运行时频繁GC导致卡顿 根本原因&#xff1a;Dalvik默认内存参数与硬件资源不匹配 解决方向&#xff1a;动态调整堆内存参数以平衡性能与资源消耗 2. 核心调优参数全景解析 关键参数矩阵&#xff1…

STC89C52单片机学习——第38节: [17-2] 红外遥控红外遥控电机

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难&#xff0c;但我还是想去做&#xff01; 本文写于&#xff1a;2025.03.30 51单片机学习——第38节: [17-2] 红外遥控&红外遥控电机 前言开发板说明引用…

计算机组成原理————计算机运算方法精讲<1>原码表示法

第一部分:无符号数和有符号数的概念 1.无符号数 计算机中的数均存放在寄存器当中,通常称寄存器的位数为机器字长,所谓无符号数,就是指没有fu5号的数,在寄存器中的每一位均可用来存放数值,当存放有符号数时,需要留出位置存放符号,机器字长相同时,无符号数与有符号数所…