缩放点积注意力

Scaled Dot-Product Attention

  • 论文地址

    https://arxiv.org/pdf/1706.03762

注意力机制介绍

  • 缩放点积注意力是Transformer模型的核心组件,用于计算序列中不同位置之间的关联程度。其核心思想是通过查询向量(query)和键向量(key)的点积来获取注意力分数,再通过缩放和归一化处理,最后与值向量(value)加权求和得到最终表示。

    image-20250423201641471

数学公式

  • 缩放点积注意力的计算过程可分为三个关键步骤:

    1. 点积计算与缩放:通过矩阵乘法计算查询向量与键向量的相似度,并使用 d k \sqrt{d_k} dk 缩放
    2. 掩码处理(可选):对需要忽略的位置施加极大负值掩码
    3. Softmax归一化:将注意力分数转换为概率分布
    4. 加权求和:用注意力权重对值向量进行加权

    公式表达为:
    Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V Attention(Q,K,V)=softmax(dk QKT)V
    其中:

    • Q ∈ R s e q _ l e n × d _ k Q \in \mathbb{R}^{seq\_len \times d\_k} QRseq_len×d_k:查询矩阵
    • K ∈ R s e q _ l e n × d _ k K \in \mathbb{R}^{seq\_len \times d\_k} KRseq_len×d_k:键矩阵
    • V ∈ R s e q _ l e n × d _ k V \in \mathbb{R}^{seq\_len \times d\_k} VRseq_len×d_k:值矩阵

    s e q _ l e n seq\_len seq_len 为序列长度, d _ k d\_k d_k 为embedding的维度。

代码实现

  • 计算注意力分数

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    import torchdef calculate_attention(query, key, value, mask=None):"""计算缩放点积注意力分数参数说明:query: [batch_size, n_heads, seq_len, d_k]key:   [batch_size, n_heads, seq_len, d_k] value: [batch_size, n_heads, seq_len, d_k]mask:  [batch_size, seq_len, seq_len](可选)"""d_k = key.shape[-1]key_transpose = key.transpose(-2, -1)  # 转置最后两个维度# 计算缩放点积 [batch, h, seq_len, seq_len]att_scaled = torch.matmul(query, key_transpose) / d_k ** 0.5# 掩码处理(解码器自注意力使用)if mask is not None:att_scaled = att_scaled.masked_fill_(mask=mask, value=-1e9)# Softmax归一化att_softmax = torch.softmax(att_scaled, dim=-1)# 加权求和 [batch, h, seq_len, d_k]return torch.matmul(att_softmax, value)
    
  • 相关解释

    1. 输入张量 query, key, value的形状

      如果是直接计算的话,那么shape是 [batch_size, seq_len, d_model]

      当然为了学习更多的表征,一般都是多头注意力,这时候shape则是[batch_size, n_heads, seq_len, d_k]

      其中

      • batch_size:批量

      • n_heads:注意力头的数量

      • seq_len: 序列的长度

      • d_model: embedding维度

      • d_k: d_k = d_model / n_heads

    2. 代码中的shape转变

      • key_transpose :key的转置矩阵

        由 key 转置了最后两个维度,维度从 [batch_size, n_heads, seq_len, d_k] 转变为 [batch_size, n_heads, d_k, seq_len]

      • **att_scaled **:缩放点积

        由 query 和 key 通过矩阵相乘得到

        [batch_size, n_heads, seq_len, d_k] @ [batch_size, n_heads, d_k, seq_len] --> [batch_size, n_heads, seq_len, seq_len]

      • att_score: 注意力分数

        由两个矩阵相乘得到

        [batch_size, n_heads, seq_len, seq_len] @ [batch_size, n_heads, seq_len, d_k] --> [batch_size, n_heads, seq_len, d_k]


使用示例

  • 测试代码

    if __name__ == "__main__":# 模拟输入:batch_size=2, 8个注意力头,序列长度512,d_k=64x = torch.ones((2, 8, 512, 64))# 计算注意力(未使用掩码)att_score = calculate_attention(x, x, x)print("输出形状:", att_score.shape)  # torch.Size([2, 8, 512, 64])print("注意力分数示例:\n", att_score[0,0,:3,:3])
    

    在实际使用中通常会将此实现封装为nn.Module并与位置编码、残差连接等组件配合使用,构建完整的Transformer层。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77912.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

可吸收聚合物:医疗科技与绿色未来的交汇点

可吸收聚合物(Biodegradable Polymers)作为生物医学工程的核心材料,正引领一场从“金属/塑料植入物”到“智能降解材料”的范式转移。根据QYResearch(恒州博智)预测,2031年全球可吸收聚合物市场销售额将突破…

房地产项目绩效考核管理制度与绩效提升

房地产项目绩效考核管理制度的核心目的是通过合理的绩效考核机制,提升项目的整体运作效率,并鼓励项目团队成员的积极性。该制度适用于所有房地产项目部工作人员,涵盖了项目经理和项目成员的考核。考核的主要内容包括项目经理和项目部成员的工…

【算法笔记】动态规划基础(一):dp思想、基础线性dp

目录 前言动态规划的精髓什么叫“状态”动态规划的概念动态规划的三要素动态规划的框架无后效性dfs -> 记忆化搜索 -> dp暴力写法记忆化搜索写法记忆化搜索优化了什么?怎么转化成dp?dp写法 dp其实也是图论首先先说结论:状态DAG是怎样的…

pytorch 51 GroundingDINO模型导出tensorrt并使用c++进行部署,53ms一张图

本专栏博客第49篇文章分享了将 GroundingDINO模型导出onnx并使用c++进行部署,并尝试将onnx模型转换为trt模型,fp16进行推理,可以发现推理速度提升了一倍。为此对GroundingDINO的trt推理进行调研,发现 在GroundingDINO-TensorRT-and-ONNX-Inference项目中分享了模型导出onnx…

一个关于相对速度的假想的故事-6

既然已经知道了速度是不能叠加的,同时也知道这个叠加是怎么做到的,那么,我们实际上就知道了光速的来源,也就是这里的虚数单位的来源: 而它的来源则是, 但这是两个速度的比率,而光速则是一个速度…

深度学习激活函数与损失函数全解析:从Sigmoid到交叉熵的数学原理与实践应用

目录 前言一、sigmoid 及导数求导二、tanh 三、ReLU 四、Leaky Relu五、 Prelu六、Softmax七、ELU八、极大似然估计与交叉熵损失函数8.1 极大似然估计与交叉熵损失函数算法理论8.1.1 伯努利分布8.1.2 二项分布8.1.3 极大似然估计总结 前言 书接上文 PaddlePaddle线性回归详解…

Python内置函数---breakpoint()

用于在代码执行过程中动态设置断点,暂停程序并进入调试模式。 1. 基本语法与功能 breakpoint(*args, kwargs) - 参数:接受任意数量的位置参数和关键字参数,但通常无需传递(默认调用pdb.set_trace())。 - 功能&#x…

从零手写 RPC-version1

一、 前置知识 1. 反射 获取字节码的三种方式 Class.forName("全类名") (全类名,即包名类名)类名.class对象.getClass() (任意对象都可调用,因为该方法来自Object类) 获取成员方法 Method getMethod(St…

ARINC818协议(六)

上图中,红色虚线上面为我们常用的simple mode简单模式,下面和上面的结合在一起,就形成了extended mode扩展模式。 ARINC818协议 container header容器头 ancillary data辅助数据 视频流 ADVB帧映射 FHCP传输协议 R_CTRL:路由控制routing ctr…

PyCharm 链接 Podman Desktop 的 podman-machine-default Linux 虚拟环境

#工作记录 PyCharm Community 连接到Podman Desktop 的 podman-machine-default Linux 虚拟环境详细步骤 1. 准备工作 确保我们已在 Windows 系统中正确安装并启动了 Podman Desktop。 我们将通过 Podman Desktop 提供的名为 podman-machine-default 的 Fedora Linux 41 WSL…

小白自学python第一天

学习python的第一天 一、常用的值类型(先来粗略认识一下~) 类型说明数字(number)包含整型(int)、浮点型(float)、复数(complex)、布尔(boolean&…

初阶数据结构--排序算法(全解析!!!)

排序 1. 排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些些关键字的大小,递增或递减的排列起来的操作。 2. 常见的排序算法 3. 实现常见的排序算法 以下排序算法均是以排升序为示例。 3.1 插入排序 基本思想:…

Android studio开发——room功能实现用户之间消息的发送

文章目录 1. Flask-SocketIO 后端代码后端代码 2. Android Studio Java 客户端代码客户端代码 3. 代码说明 SocketIO基础 1. Flask-SocketIO 后端代码 后端代码 from flask import Flask, request from flask_socketio import SocketIO, emit import uuidapp Flask(__name_…

4.LinkedList的模拟实现:

LinkedList的底层是一个不带头的双向链表。 不带头双向链表中的每一个节点有三个域:值域,上一个节点的域,下一个节点的域。 不带头双向链表的实现: public class Mylinkdelist{//定义一个内部类(节点)stat…

Sentinel数据S2_SR_HARMONIZED连续云掩膜+中位数合成

在GEE中实现时,发现简单的QA60是无法去云的,最近S2地表反射率数据集又进行了更新,原有的属性集也进行了变化,现在的SR数据集名称是“S2_SR_HARMONIZED”。那么: 要想得到研究区无云的图像,可以参考执行以下…

理解计算机系统_网络编程(1)

前言 以<深入理解计算机系统>(以下称“本书”)内容为基础&#xff0c;对程序的整个过程进行梳理。本书内容对整个计算机系统做了系统性导引,每部分内容都是单独的一门课.学习深度根据自己需要来定 引入 网络是计算机科学中非常重要的部分,笔者过去看过相关的内…

【2025】Datawhale AI春训营-RNA结构预测(AI+创新药)-Task2笔记

【2025】Datawhale AI春训营-RNA结构预测&#xff08;AI创新药&#xff09;-Task2笔记 本文对Task2提供的进阶代码进行理解。 任务描述 Task2的任务仍然是基于给定的RNA三维骨架结构&#xff0c;生成一个或多个RNA序列&#xff0c;使得这些序列能够折叠并尽可能接近给定的目…

vim 命令复习

命令模式下的命令及快捷键 # dd删除光所在行的内容 # ndd从光标所在行开始向下删除n行 # yy复制光标所在行的内容 # nyy复制光标所在行向下n行的内容 # p将复制的内容粘贴到光标所在行以下&#xff08;小写&#xff09; # P将复制的内容粘贴到光标所在行以上&#xff08;大写&…

哪些心电图表现无缘事业编体检呢?

根据《公务员录用体检通用标准》心血管系统条款及事业单位体检实施细则&#xff0c;心电图不合格主要涉及以下类型及处置方案&#xff1a; 一、心律失常类 早搏&#xff1a;包括房性早搏、室性早搏和交界性早搏。如果每分钟早搏次数较多&#xff08;如超过5次&#xff09;&…

Linux学习——UDP

编程的整体框架 bind&#xff1a;绑定服务器&#xff1a;TCP地址和端口号 receivefrom()&#xff1a;阻塞等待客户端数据 sendto():指定服务器的IP地址和端口号&#xff0c;要发送的数据 无连接尽力传输&#xff0c;UDP:是不可靠传输 实时的音视频传输&#x…