从代码学习深度学习 - 多头注意力 PyTorch 版

文章目录

  • 前言
  • 一、多头注意力机制介绍
    • 1.1 工作原理
    • 1.2 优势
    • 1.3 代码实现概述
  • 二、代码解析
    • 2.1 导入依赖
      • 序列掩码函数
    • 2.2 掩码 Softmax 函数
    • 2.3 缩放点积注意力
    • 2.4 张量转换函数
    • 2.5 多头注意力模块
    • 2.6 测试代码
  • 总结


前言

在深度学习领域,注意力机制(Attention Mechanism)是自然语言处理(NLP)和计算机视觉(CV)等任务中的核心组件之一。特别是多头注意力(Multi-Head Attention),作为 Transformer 模型的基础,极大地提升了模型对复杂依赖关系的捕捉能力。本文通过分析一个完整的 PyTorch 实现,带你深入理解多头注意力的原理和代码实现。我们将从代码入手,逐步解析每个函数和类的功能,结合文字说明,让你不仅能运行代码,还能理解其背后的设计逻辑。无论你是初学者还是有一定经验的开发者,这篇博客都将帮助你更直观地掌握多头注意力机制。
完整代码:下载链接


一、多头注意力机制介绍

多头注意力(Multi-Head Attention)是 Transformer 模型的核心组件之一,广泛应用于自然语言处理(NLP)、计算机视觉(CV)等领域。它通过并行运行多个注意力头(Attention Heads),允许模型同时关注输入序列中的不同部分,从而捕捉更丰富的语义和上下文依赖关系。相比单一的注意力机制,多头注意力极大地增强了模型的表达能力,能够处理复杂的模式和长距离依赖。

1.1 工作原理

多头注意力的核心思想是将输入的查询(Queries)、键(Keys)和值(Values)通过线性变换映射到多个子空间,每个子空间由一个独立的注意力头处理。具体步骤如下:

  1. 线性变换:对输入的查询、键和值分别应用线性层,将其映射到隐藏维度(num_hiddens),并分割为多个头的表示。
  2. 缩放点积注意力:每个注意力头独立计算缩放点积注意力(Scaled Dot-Product Attention),即通过查询和键的点积计算注意力分数,再与值加权求和。
  3. 并行计算:多个注意力头并行运行,每个头关注输入的不同方面,生成各自的输出。
  4. 合并与变换:将所有头的输出拼接起来,并通过一个线性层融合,得到最终的多头注意力输出。

这种设计允许模型在不同子空间中学习不同的特征,例如在 NLP 任务中,一个头可能关注句法结构,另一个头可能关注语义关系。
在这里插入图片描述

1.2 优势

  • 多样性:多头机制使模型能够从多个角度理解输入,捕捉多样化的模式。
  • 并行性:多头计算可以高效并行化,提升计算效率。
  • 稳定性:通过缩放点积(除以特征维度的平方根),缓解了高维点积导致的数值不稳定问题。

1.3 代码实现概述

在本文的实现中,我们使用 PyTorch 构建了一个完整的多头注意力模块,包含以下关键部分:

  • 序列掩码:处理变长序列,屏蔽无效位置。
  • 缩放点积注意力:实现单个注意力头的计算逻辑。
  • 张量转换:通过 transpose_qkvtranspose_output 函数实现多头分割与合并。
  • 多头注意力类:整合所有组件,完成并行计算和输出融合。

接下来的代码解析将详细展示这些部分的实现,帮助你从代码层面深入理解多头注意力的每一步计算逻辑。

二、代码解析

以下是代码的完整实现和详细解析,代码按照 Jupyter Notebook(在最开始给出了完整代码下载链接) 的结构组织,并附上文字说明,帮助你理解每个部分的逻辑。

2.1 导入依赖

首先,我们导入必要的 Python 包,包括数学运算库 math 和 PyTorch 的核心模块 torchnn

# 导入包
import math
import torch
from torch import nn
  • math:用于计算缩放点积注意力中的归一化因子(即特征维度的平方根)。
  • torch:PyTorch 的核心库,提供张量运算和自动求导功能。
  • nn:PyTorch 的神经网络模块,包含 nn.Modulenn.Linear 等工具,用于构建神经网络层。

序列掩码函数

在处理序列数据(如句子)时,不同序列的长度可能不同,我们需要通过掩码(Mask)来屏蔽无效位置,防止模型关注这些填充区域。以下是 sequence_mask 函数的实现:

def sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项,使超出有效长度的位置被设置为指定值参数:X: 输入张量,形状 (batch_size, 最大序列长度, 特征维度) 或 (batch_size, 最大序列长度)valid_len: 有效长度张量,形状 (batch_size,),表示每个序列的有效长度value: 屏蔽值,标量,默认值为 0,用于填充无效位置返回:输出张量,形状与输入 X 相同,无效位置被设置为 value"""maxlen = X.size(1)  # 最大序列长度,标量# 创建掩码,形状 (1, 最大序列长度),与 valid_len 比较生成布尔张量,形状 (batch_size, 最大序列长度)mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]# 将掩码取反后,X 的无效位置被设置为 valueX[~mask] = valuereturn X

解析

  • 输入
    • X:输入张量,通常是序列数据,可能包含填充(padding)部分。
    • valid_len:每个样本的有效长度,例如 [3, 2] 表示第一个样本有 3 个有效 token,第二个样本有 2 个。
    • value:用于填充无效位置的值,默认为 0。
  • 逻辑
    • maxlen 获取序列的最大长度(即张量的第二维)。
    • torch.arange(maxlen) 创建一个从 0 到 maxlen-1 的序列,形状为 (1, maxlen)
    • 通过广播机制,与 valid_len(形状 (batch_size, 1))比较,生成布尔掩码 mask,形状为 (batch_size, maxlen)
    • mask 表示哪些位置是有效的(True),哪些是无效的(False)。
    • 使用 ~mask 选择无效位置,将其值设置为 value
  • 输出:修改后的张量 X,无效位置被设置为 value,形状不变。

作用:该函数用于在注意力计算中屏蔽填充区域,确保模型只关注有效 token。

2.2 掩码 Softmax 函数

在注意力机制中,我们需要对注意力分数应用 Softmax 操作,将其转换为概率分布。但由于序列长度不同,需要屏蔽无效位置的贡献。以下是 masked_softmax 函数的实现:

import torch
import torch.nn.functional as Fdef masked_softmax(X, valid_lens):"""通过在最后一个轴上掩蔽元素来执行softmax操作,忽略无效位置参数:X: 输入张量,形状 (batch_size, 查询个数, 键-值对个数),3D张量valid_lens: 有效长度张量,形状 (batch_size,) 或 (batch_size, 查询个数),1D或2D张量,表示每个序列的有效长度,即每个查询可以参考的有效键值对长度返回:输出张量,形状 (batch_size, 查询个数, 键-值对个数),softmax后的注意力权重"""if valid_lens is None:# 如果没有有效长度,直接在最后一个轴上应用softmaxreturn F.softmax(X, dim=-1)shape 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/76794.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学术版 GPT 网页

学术版 GPT 网页 1. 学术版 GPT 网页非盈利版References https://academic.chatwithpaper.org/ 1. 学术版 GPT 网页非盈利版 arXiv 全文翻译&#xff0c;免费且无需登录。 更换模型 System prompt: Serve me as a writing and programming assistant. 界面外观 References …

MarkDown 输出表格的方法

MarkDown用来输出表格很简单&#xff0c;比Word手搓表格简单多了&#xff0c;而且方便修改。 MarkDown代码&#xff1a; |A|B|C|D| |:-|-:|:-:|-| |1|b|c|d| |2|b|c|d| |3|b|c|d| |4|b|c|d| |5|b|c|d|显示效果&#xff1a; ABCD1bcd2bcd3bcd4bcd5bcd A列强制左对齐&#xf…

MetaGPT深度解析:重塑AI协作开发的智能体框架实践指南

一、框架架构与技术突破 1.1 系统架构设计 graph TBA[自然语言需求] --> B(需求解析引擎)B --> C{角色路由系统}C --> D[产品经理Agent]C --> E[架构师Agent]C --> F[工程师Agent]D --> G[PRD文档]E --> H[架构图]F --> I[代码文件]G --> J[知识共…

自用:在使用SpringBoot做学生信息管理系统时遇到的问题

1、在做完查询测试时&#xff0c;一直报出404找不到错误&#xff0c;原因是没有为各个层的实现类添加注解 2、改完之后发现测试没有数据&#xff0c;是因为我写的返回值类型为空&#xff0c;应该返回一个List< Student > 3、我没有想到要写Result实体类&#xff0c;因为不…

SQLite + Redis = Redka

Redka 是一个基于 SQLite 实现的 Redis 替代产品&#xff0c;实现了 Redis 的核心功能&#xff0c;并且完全兼容 Redis API。它可以用于轻量级缓存、嵌入式系统、快速原型开发以及需要事务 ACID 特性的键值操作等场景。 功能特性 Redka 的主要特点包括&#xff1a; 使用 SQLi…

202529 | RocketMQ 简介 + 安装 + 集群搭建 + 消费模式 + 消费者组

RocketMQ简介 RocketMQ 简介 Apache RocketMQ 是一款开源的 分布式消息中间件&#xff08;Message Queue, MQ&#xff09;&#xff0c;由阿里巴巴团队研发并捐赠给 Apache 基金会&#xff0c;现已成为顶级项目。它专为 高吞吐、低延迟、高可靠 的分布式场景设计&#xff0c;广…

Go语言--语法基础4--基本数据类型--整数类型

整型是所有编程语言里最基础的数据类型。 Go 语言支持如下所示的这些整型类型。 需要注意的是&#xff0c; int 和 int32 在 Go 语言里被认为是两种不同的类型&#xff0c;编译器也不会帮你自动做类型转换&#xff0c; 比如以下的例子会有编译错误&#xff1a; var value2 in…

竞拍商城:电商创新的博弈场与未来趋势

竞拍商城&#xff1a;电商创新的博弈场与未来趋势 在传统电商趋于同质化的今天&#xff0c;竞拍商城凭借其独特的交易机制和用户激励模式&#xff0c;成为电商领域的新宠。通过结合拍卖的博弈属性与电商的便捷性&#xff0c;竞拍商城不仅重塑了消费体验&#xff0c;更催生了全…

Linux : 多线程互斥

目录 一 前言 二 线程互斥 三 Mutex互斥量 1. 定义一个锁&#xff08;造锁&#xff09; 2. 初始化锁 3. 上锁 4. 解锁 5. 摧毁锁 四 锁的使用 五 锁的宏初始化 六 锁的原理 1.如何看待锁&#xff1f; 2. 如何理解加锁和解锁的本质 七 c封装互斥锁 八 可重入…

论文阅读笔记——Reactive Diffusion Policy

RDP 论文 通过 AR 提供实时触觉/力反馈&#xff1b;慢速扩散策略&#xff0c;用于预测低频潜在空间中的高层动作分块&#xff1b;快速非对称分词器实现闭环反馈控制。 ACT、 π 0 \pi_0 π0​ 采取了动作分块&#xff0c;在动作分块执行期间处于开环状态&#xff0c;无法及时响…

swagger 注释说明

一、接口注释核心字段 在 Go 的路由处理函数&#xff08;Handler&#xff09;上方添加注释&#xff0c;支持以下常用注解&#xff1a; 注解名称用途说明示例格式Summary接口简要描述Summary 创建用户Description接口详细说明Description 通过用户名和邮箱创建新用户Tags接口分…

STM32 HAL库 OLED驱动实现

一、概述 1.1 OLED 显示屏简介 OLED&#xff08;Organic Light - Emitting Diode&#xff09;即有机发光二极管&#xff0c;与传统的 LCD 显示屏相比&#xff0c;OLED 具有自发光、视角广、响应速度快、对比度高、功耗低等优点。在嵌入式系统中&#xff0c;OLED 显示屏常被用…

Web开发-JavaEE应用动态接口代理原生反序列化危险Invoke重写方法利用链

知识点&#xff1a; 1、安全开发-JavaEE-动态代理&序列化&反序列化 2、安全开发-JavaEE-readObject&toString方法 一、演示案例-WEB开发-JavaEE-动态代理 动态代理 代理模式Java当中最常用的设计模式之一。其特征是代理类与委托类有同样的接口&#xff0c;代理类…

K8s是常用命令和解释

K8s高频命令 获取资源信息&#xff0c;如获取 Pod、Service、Deployment等资源状态信息 kubectl get创建资源如创建Pod、Service、Deployment等资源 kubectl create删除资源&#xff0c;如删除Pod、Service、Deployment等资源 kubectl delete 应用配置文件&#xff0c;如引用D…

【模态分解】EMD-经验模态分解

算法配置页面&#xff0c;也可以一键导出结果数据 报表自定义绘制 获取和下载【PHM学习软件PHM源码】的方式 获取方式&#xff1a;Docshttps://jcn362s9p4t8.feishu.cn/wiki/A0NXwPxY3ie1cGkOy08cru6vnvc

TDengine 语言连接器(Go)

简介 driver-go 是 TDengine 的官方 Go 语言连接器&#xff0c;实现了 Go 语言 database/sql 包的接口。Go 开发人员可以通过它开发存取 TDengine 集群数据的应用软件。 Go 版本兼容性 支持 Go 1.14 及以上版本。 支持的平台 原生连接支持的平台和 TDengine 客户端驱动支持…

链接世界:计算机网络的核心与前沿

计算机网络引言 在数字化时代&#xff0c;计算机网络已经成为我们日常生活和工作中不可或缺的基础设施。从简单的局域网&#xff08;LAN&#xff09;到全球互联网&#xff0c;计算机网络将数以亿计的设备连接在一起&#xff0c;推动了信息交换、资源共享以及全球化的进程。 什…

AI agents系列之全面介绍

随着大型语言模型(LLMs)的出现,人工智能(AI)取得了巨大的飞跃。这些强大的系统彻底改变了自然语言处理,但当它们与代理能力结合时,才真正释放出潜力——能够自主地推理、规划和行动。这就是LLM代理大显身手的地方,它们代表了我们与AI交互以及利用AI的方式的范式转变。 …

如何使用AI辅助开发CSS3 - 通义灵码功能全解析

一、引言 CSS3 作为最新的 CSS 标准&#xff0c;引入了众多新特性&#xff0c;如弹性布局、网格布局等&#xff0c;极大地丰富了网页样式的设计能力。然而&#xff0c;CSS3 的样式规则繁多&#xff0c;记忆所有规则对于开发者来说几乎是不可能的任务。在实际开发中&#xff0c…

复刻系列-星穹铁道 3.2 版本先行展示页

复刻星穹铁道 3.2 版本先行展示页 0. 视频 手搓&#xff5e;星穹铁道&#xff5e;展示页&#xff5e;&#xff5e;&#xff5e; 1. 基本信息 作者: 啊是特嗷桃系列: 复刻系列官方的网站: 《崩坏&#xff1a;星穹铁道》3.2版本「走过安眠地的花丛」专题展示页现已上线复刻的网…