第8讲、Multi-Head Attention 的核心机制与实现细节

🤔 为什么要有 Multi-Head Attention?

单个 Attention 机制虽然可以捕捉句子中不同词之间的关系,但它只能关注一种角度或模式

Multi-Head 的作用是:

多个头 = 多个视角同时观察序列的不同关系

例如:

  • 一个头可能专注主语和动词的关系;
  • 另一个头可能专注宾语和介词;
  • 还有的可能学习句法结构或时态变化。

这些头的表示最终会被拼接(concatenate)后再线性变换整合成更丰富的上下文表示。

🔍 技术深入:Multi-Head Attention 计算过程

Multi-Head Attention 的计算过程如下:

  1. 对输入 X 进行线性变换得到 Q、K、V 矩阵
  2. 将 Q、K、V 分割成 h 个头
  3. 每个头独立计算 Attention
  4. 拼接所有头的输出
  5. 最后进行一次线性变换
# 伪代码实现
def multi_head_attention(X, h=8):# 线性变换获得 Q, K, VQ = X @ W_q  # [batch_size, seq_len, d_model]K = X @ W_kV = X @ W_v# 分割成多头Q_heads = split_heads(Q, h)  # [batch_size, h, seq_len, d_k]K_heads = split_heads(K, h)V_heads = split_heads(V, h)# 每个头独立计算 attentionattn_outputs = []for i in range(h):attn_output = scaled_dot_product_attention(Q_heads[:, i], K_heads[:, i], V_heads[:, i])attn_outputs.append(attn_output)# 拼接所有头的输出concat_output = concatenate(attn_outputs)  # [batch_size, seq_len, d_model]# 最后的线性变换output = concat_output @ W_oreturn output

🧮 如何判断多少个头(h)?

Transformer 默认将 d_model(模型维度)均分给每个头。

设:

  • d_model = 512:模型的总嵌入维度
  • h = 8:头数

那么每个头的维度为:

d_k = d_model // h = 512 // 8 = 64

一般要求:

⚠️ d_model 必须能被 h 整除。

📊 参数计算

Multi-Head Attention 中的参数量:

  • 输入投影矩阵:3 × (d_model × d_model) = 3d_model²
  • 输出投影矩阵:d_model × d_model = d_model²

总参数量:4 × d_model²

例如,当 d_model = 512 时,参数量约为 100 万。


📌 头的数量怎么选?

头数 h每头维度 d_k适用情境
1全部基线,最弱(没多视角)
4中等小模型,如 tiny Transformer
864标准配置,如原始 Transformer
16更细粒度大模型中常见,如 BERT-large

实际训练中:

  • 小任务(toy 或翻译教学):用 2 或 4 个头就够了。
  • 真实 NLP 任务:建议使用 8 个头(Transformer-base 规范)。
  • 太多头而模型参数不足时,效果可能反而下降(每头维度太小)。

📈 头数与性能关系

研究表明,头数与模型性能并非简单的线性关系:

  • 头数过少:无法捕捉多种语言模式
  • 头数适中:性能最佳
  • 头数过多:每个头的维度变小,表达能力下降

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

🔬 实验发现

Michel et al. (2019) 的研究《Are Sixteen Heads Really Better than One?》发现:

  1. 在训练好的模型中,并非所有头都同等重要
  2. 大多数情况下,可以剪枝掉一部分头而不显著影响性能
  3. 不同层的头有不同的作用,底层头和顶层头往往更为重要

💡 Multi-Head Attention 的优势

  1. 并行计算:所有头可以并行计算,提高训练效率
  2. 多角度表示:捕捉不同类型的依赖关系
  3. 信息冗余:多头提供冗余信息,增强模型鲁棒性
  4. 注意力分散:防止单一头过度关注某些模式

🧠 总结一句话

Multi-Head 的本质是多角度捕捉词与词的关系,提升模型对上下文的理解能力。头数越多,观察角度越多,但每个头的维度会减小,需注意平衡。


📊 Attention 可视化

不同头学习到的注意力模式各不相同。以下是一个英语句子在 8 头注意力机制下的可视化示例:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看到:

  • 头1:关注相邻词的关系
  • 头2:捕捉主语-谓语关系
  • 头3:识别句法结构
  • 头4:连接相关实体
  • 其他头:各自专注于不同的语言特征

这种多角度的观察使得 Transformer 能够全面理解文本的语义和结构。


🖥️ Streamlit 交互式可视化案例

想要直观地理解 Multi-Head Attention?以下是一个使用 Streamlit 构建的交互式可视化案例,让你可以实时探索不同头的注意力模式:

import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel# 页面设置
st.set_page_config(page_title="Multi-Head Attention 可视化", layout="wide")
st.title("Multi-Head Attention 可视化工具")# 加载预训练模型
@st.cache_resource
def load_model():tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)return tokenizer, modeltokenizer, model = load_model()# 用户输入
user_input = st.text_area("请输入一段文本进行分析:", "Transformer是一种强大的神经网络架构,它使用了Multi-Head Attention机制。",height=100)# 处理文本
if user_input:# 分词并获取注意力权重inputs = tokenizer(user_input, return_tensors="pt")outputs = model(**inputs)# 获取所有层的注意力权重attentions = outputs.attentions  # tuple of tensors, one per layer# 选择层layer_idx = st.slider("选择Transformer层:", 0, len(attentions)-1, 0)# 获取选定层的注意力权重layer_attentions = attentions[layer_idx].detach().numpy()# 获取头数num_heads = layer_attentions.shape[1]# 选择头head_idx = st.slider("选择注意力头:", 0, num_heads-1, 0)# 获取选定头的注意力权重head_attention = layer_attentions[0, head_idx]# 获取标记tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])# 可视化fig, ax = plt.subplots(figsize=(10, 8))sns.heatmap(head_attention, xticklabels=tokens, yticklabels=tokens, cmap="YlGnBu", ax=ax)plt.title(f"第 {layer_idx+1} 层,第 {head_idx+1} 个头的注意力权重")st.pyplot(fig)# 显示注意力模式分析st.subheader("注意力模式分析")# 计算每个词的平均注意力avg_attention = head_attention.mean(axis=0)top_indices = np.argsort(avg_attention)[-3:][::-1]st.write("这个注意力头主要关注的词:")for idx in top_indices:st.write(f"- {tokens[idx]}: {avg_attention[idx]:.4f}")# 添加交互式功能if st.checkbox("显示所有头的对比"):st.subheader("所有头的注意力对比")# 为每个头创建一个小型热力图# 计算行列数以适应任意数量的头num_cols = 4num_rows = (num_heads + num_cols - 1) // num_cols  # 向上取整fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 3*num_rows))axes = axes.flatten()for h in range(num_heads):sns.heatmap(layer_attentions[0, h], xticklabels=[] if h < (num_heads-num_cols) else tokens, yticklabels=[] if h % num_cols != 0 else tokens, cmap="YlGnBu", ax=axes[h])axes[h].set_title(f"头 {h+1}")# 隐藏未使用的子图for h in range(num_heads, len(axes)):axes[h].axis('off')plt.tight_layout()st.pyplot(fig)# 添加解释st.markdown("""### 如何解读这个可视化:- 颜色越深表示注意力权重越高- 纵轴代表查询词(当前词)- 横轴代表键词(被关注的词)- 每个头学习不同的关注模式通过调整滑块,你可以探索不同层和不同头的注意力模式,观察模型如何理解文本中的关系。""")# 运行说明
st.sidebar.markdown("""
## 使用说明1. 在文本框中输入你想分析的文本
2. 使用滑块选择要查看的层和注意力头
3. 查看热力图了解词与词之间的注意力关系
4. 勾选"显示所有头的对比"可以同时查看所有头的模式这个工具帮助你直观理解 Multi-Head Attention 的工作原理和不同头的功能分工。
""")# 代码说明
with st.expander("查看完整代码实现"):st.code("""
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import BertTokenizer, BertModel# 页面设置
st.set_page_config(page_title="Multi-Head Attention 可视化", layout="wide")
st.title("Multi-Head Attention 可视化工具")# 加载预训练模型
@st.cache_resource
def load_model():tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)return tokenizer, modeltokenizer, model = load_model()# 用户输入和可视化逻辑
# ...此处省略,与上面代码相同
""")### 🚀 如何运行这个可视化工具1. 安装必要的依赖:
```bash
pip install streamlit torch transformers matplotlib seaborn
  1. 将上面的代码保存为 attention_viz.py

  2. 运行 Streamlit 应用:

streamlit run attention_viz.py


这个交互式工具让你可以:

  • 输入任意文本并查看注意力分布
  • 选择不同的 Transformer 层和注意力头
  • 直观对比不同头学习到的不同模式
  • 分析哪些词获得了最高的注意力权重

通过这个可视化工具,你可以亲自探索 Multi-Head Attention 的工作原理,加深对这一机制的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/83442.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

百度智能云千帆携手联想,共创MCP生态宇宙

5月7日&#xff0c;2025联想创新科技大会&#xff08;Tech World&#xff09;在上海世博中心举行&#xff0c;本届大会以“让AI成为创新生产力”为主题。会上&#xff0c;联想集团董事长兼CEO杨元庆展示了包括覆盖全场景的超级智能体矩阵&#xff0c;包括个人超级智能体、企业超…

【OpenCV】帧差法、级联分类器、透视变换

一、帧差法&#xff08;移动目标识别&#xff09;&#xff1a; 好处&#xff1a;开销小&#xff0c;不怎么消耗CPU的算力&#xff0c;对硬件要求不高&#xff0c;但只适合固定摄像头 1、优点 计算效率高&#xff0c;硬件要求 响应速度快&#xff0c;实时性强 直接利用连续帧…

数据库迁移的艺术:团队协作中的冲突预防与解决之道

title: 数据库迁移的艺术:团队协作中的冲突预防与解决之道 date: 2025/05/17 00:13:50 updated: 2025/05/17 00:13:50 author: cmdragon excerpt: 在团队协作中,数据库迁移脚本冲突是常见问题。通过Alembic工具,可以有效地管理和解决这些冲突。冲突预防的四原则包括功能分…

Linux常用命令43——bunzip2解压缩bz2文件

在使用Linux或macOS日常开发中&#xff0c;熟悉一些基本的命令有助于提高工作效率&#xff0c;bunzip2可解压缩.bz2格式的压缩文件。bunzip2实际上是bzip2的符号连接&#xff0c;执行bunzip2与bzip2 -d的效果相同。本篇学习记录bunzip2命令的基本使用。 首先查看帮助文档&#…

盲盒:拆开未知的惊喜,收藏生活的仪式感

一、什么是盲盒&#xff1f;—— 一场关于“未知”的浪漫冒险 盲盒&#xff0c;是一种充满神秘感的消费体验&#xff1a; &#x1f381; 盒中藏惊喜——每个盲盒外观相同&#xff0c;但内含随机商品&#xff0c;可能是普通款、稀有款&#xff0c;甚至是“隐藏款”&#xff1b;…

Android 中使用通知(Kotlin 版)

1. 前置条件 Android Studio&#xff1a;确保使用最新版本&#xff08;2023.3.1&#xff09;目标 API&#xff1a;最低 API 21&#xff0c;兼容 Android 8.0&#xff08;渠道&#xff09;和 13&#xff08;权限&#xff09;依赖库&#xff1a;使用 WorkManager 和 Notificatio…

使用大模型预测急性结石性疾病技术方案

目录 1. 数据预处理与特征工程伪代码 - 数据清洗与特征处理数据预处理流程图2. 大模型构建与训练伪代码 - 模型训练模型训练流程图3. 术前预测系统伪代码 - 术前风险评估术前预测流程图4. 术中实时调整系统伪代码 - 术中风险预警术中调整流程图5. 术后护理系统伪代码 - 并发症预…

每日Prompt:生成自拍照

提示词 帮我生成一张图片&#xff1a;图片风格为「人像摄影」&#xff0c;请你画一张及其平凡无奇的iPhone对镜自拍照&#xff0c;主角是穿着JK风格cos服的可爱女孩&#xff0c;在自己精心布置的可按风格的房间内的落地镜前用后置摄像头随手一拍的快照。照片开启了闪光灯&…

动态规划-64.最小路径和-力扣(LetCode)

一、题目解析 从左上角到右下角使得数字总和最小且只能向下或向右移动 二、算法原理 1.状态表示 我们需要求到达[i,j]位置时数字总和的最小值&#xff0c;所以dp[i][j]表示&#xff1a;到达[i,j]位置时&#xff0c;路径数字总和的最小值。 2.状态转移方程 到达[i,j]之前要先…

LeetCode LCR 010 和为 K 的子数组 (Java)

两种解法详解&#xff1a;暴力枚举与前缀和哈希表寻找和为k的子数组 在解决数组中和为k的连续子数组个数的问题时&#xff0c;我们可以采用不同的方法。本文将详细解析两种常见的解法&#xff1a;暴力枚举法和前缀和结合哈希表的方法&#xff0c;分析它们的思路、优缺点及适用…

OpenVLA (2) 机器人环境和环境数据

文章目录 [TOC](文章目录) 前言1 BridgeData V21.1 概述1.2 硬件环境 2 数据集2.1 场景与结构2.2 数据结构2.2.1 images02.2.2 obs_dict.pkl2.2.3 policy_out.pkl 3 close question3.1 英伟达环境3.2 LIBERO 环境更适合仿真3.3 4090 运行问题 前言 按照笔者之前的行业经验, 数…

深度学习(第3章——亚像素卷积和可形变卷积)

前言&#xff1a; 本章介绍了计算机识别超分领域和目标检测领域中常常使用的两种卷积变体&#xff0c;亚像素卷积&#xff08;Subpixel Convolution&#xff09;和可形变卷积&#xff08;Deformable Convolution&#xff09;&#xff0c;并给出对应pytorch的使用。 亚像素卷积…

大模型在腰椎间盘突出症预测与治疗方案制定中的应用研究

目录 一、引言 1.1 研究背景 1.2 研究目的与意义 二、腰椎间盘突出症概述 2.1 定义与病因 2.2 症状与诊断方法 2.3 治疗方法概述 三、大模型技术原理与应用基础 3.1 大模型的基本原理 3.2 大模型在医疗领域的应用现状 3.3 用于腰椎间盘突出症预测的可行性分析 四、…

Vue3学习(组合式API——ref模版引用与defineExpose编译宏函数)

目录 一、ref模版引用。 &#xff08;1&#xff09;基本介绍。 &#xff08;2&#xff09;核心基本步骤。(以获取DOM、组件为例) &#xff08;3&#xff09;案例&#xff1a;获取dom对象演示。 <1>需求&#xff1a;点击按钮&#xff0c;让输入框聚焦。 &#xff08;4&…

公链开发及其配套设施:钱包与区块链浏览器

公链开发及其配套设施&#xff1a;钱包与区块链浏览器的技术架构与生态实践 ——2025年区块链基础设施建设的核心逻辑与创新突破 一、公链开发&#xff1a;构建去中心化世界的基石 1. 技术架构设计的三重挑战 公链作为开放的区块链网络&#xff0c;需在性能、安全性与去中心…

Kotlin 作用域函数(let、run、with、apply、also)对比

Kotlin 的 作用域函数&#xff08;Scope Functions&#xff09; 是简化代码逻辑的重要工具&#xff0c;它们通过临时作用域为对象提供更简洁的操作方式。以下是 let、run、with、apply、also 的对比分析&#xff1a; 一、核心区别对比表 函数上下文对象引用返回值是否扩展函数…

14、Python时间表示:Unix时间戳、毫秒微秒精度与time模块实战

适合人群&#xff1a;零基础自学者 | 编程小白快速入门 阅读时长&#xff1a;约5分钟 文章目录 一、问题&#xff1a;计算机中的时间的表示、Unix时间点&#xff1f;1、例子1&#xff1a;计算机的“生日”&#xff1a;Unix时间点2、答案&#xff1a;&#xff08;1&#xff09;U…

AI日报 - 2024年5月17日

&#x1f31f; 今日概览 (60秒速览) ▎&#x1f916; 大模型前沿 | OpenAI推出自主编码代理Codex&#xff1b;Google DeepMind发布Gemini驱动的编码代理AlphaEvolve&#xff0c;能设计先进算法&#xff1b;Meta旗舰AI模型Llama 4 Behemoth发布推迟。 Codex能并行处理多任务&…

DriveMM:用于自动驾驶的一体化大型多模态模型——论文阅读

《DriveMM: All-in-One Large Multimodal Model for Autonomous Driving》2024年12月发表&#xff0c;来自中山大学深圳分校和美团的论文。 大型多模态模型&#xff08;LMM&#xff09;通过整合大型语言模型&#xff0c;在自动驾驶&#xff08;AD&#xff09;中表现出卓越的理解…

C++_STL_map与set

1. 关联式容器 在初阶阶段&#xff0c;我们已经接触过STL中的部分容器&#xff0c;比如&#xff1a;vector、list、deque、 forward_list(C11)等&#xff0c;这些容器统称为序列式容器&#xff0c;因为其底层为线性序列的数据结构&#xff0c;里面 存储的是元素本身。那什么是…