【现代深度学习技术】注意力机制05:多头注意力

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、模型
    • 二、实现
    • 小结


  在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。

  为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的 h h h组不同的线性投影(linear projections)来变换查询、键和值。然后,这 h h h组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h h h个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)。对于 h h h个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。图1展示了使用全连接层来实现可学习的线性变换的多头注意力。

在这里插入图片描述

图1 多头注意力:多个头连结然后线性变换

一、模型

  在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q} qRdq、键 k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k} kRdk和值 v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v} vRdv,每个注意力头 h i \mathbf{h}_i hi i = 1 , … , h i = 1, \ldots, h i=1,,h)的计算方法为:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v (1) \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v} \tag{1} hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv(1) 其中,可学习的参数包括 W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} Wi(q)Rpq×dq W i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} Wi(k)Rpk×dk W i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} Wi(v)Rpv×dv,以及代表注意力汇聚的函数 f f f f f f可以是注意力评分函数中的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times h p_v} WoRpo×hpv
W o [ h 1 ⋮ h h ] ∈ R p o (2) \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o} \tag{2} Wo h1hh Rpo(2)

  基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

二、实现

  在实现过程中通常选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我们设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pqh=pkh=pvh=po,则可以并行计算 h h h个头。在下面的实现中, p o p_o po是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

  为了能够使多个头并行计算,上面的MultiHeadAttention类将使用下面定义的两个转置函数。具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

  下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()

在这里插入图片描述

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

在这里插入图片描述

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/80476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot 集成滑块验证码AJ-Captcha行为验证码 Redis分布式 接口限流 防爬虫

介绍 滑块验证码比传统的字符验证码更加直观和用户友好,能够很好防止爬虫获取数据。 AJ-Captcha行为验证码,包含滑动拼图、文字点选两种方式,UI支持弹出和嵌入两种方式。后端提供Java实现,前端提供了php、angular、html、vue、u…

边缘网关(边缘计算)

边缘网关是边缘计算架构中的关键组件,充当连接终端设备(如传感器、IoT设备)与云端或核心网络的桥梁。它在数据源头附近进行实时处理、分析和过滤,显著提升效率并降低延迟。 核心功能 协议转换 ○ 支持多种通信协议(如…

OpenCV定位地板上的书

任务目标是将下面的图片中的书本找出来: 使用到的技术包括:转灰度图、提取颜色分量、二值化、形态学、轮廓提取等。 我们尝试先把图片转为灰度图,然后二值化,看看效果: 可以看到,二值化后,书的…

机器学习第一讲:机器学习本质:让机器通过数据自动寻找规律

机器学习第一讲:机器学习本质:让机器通过数据自动寻找规律 资料取自《零基础学机器学习》。 查看总目录:学习大纲 关于DeepSeek本地部署指南可以看下我之前写的文章:DeepSeek R1本地与线上满血版部署:超详细手把手指…

修改图像分辨率

在这个教程中,您将学习如何使用Python和深度学习技术来调整图像的分辨率。我们将从基础的图像处理技术开始,逐步深入到使用预训练的深度学习模型进行图像超分辨率处理。 一、常规修改方法 1. 安装Pillow库 首先,你需要确保你的Python环境中…

jsAPI

环境准备 1 安装nvm nvm 即 (node version manager),好处是方便切换 node.js 版本 安装注意事项 要卸载掉现有的 nodejs提示选择 nvm 和 nodejs 目录时,一定要避免目录中出现空格选用【以管理员身份运行】cmd 程序来执行 nvm 命令首次运行前设置好国…

SCDN是什么?

SCDN是安全内容分发网络的简称,它在传统内容分发网络(CDN)的基础上,集成了安全防护能力,旨在同时提升内容传输速度和网络安全性。 SCDN的核心功能有: DDoS防御:识别并抵御大规模分布式拒绝服务…

Qt/C++开发监控GB28181系统/实时视频预览/视频点播/rtp解包解码显示

一、前言 通过gb28181做实时视频预览,也就是视频点播功能,是最重要的功能了,绝对是整个系统排第一重要的,这就是核心功能,什么设备注册、获取通道等都是为了实时预览做准备的,当然这个功能也是最难的&…

找银子 题解(c++)

题目 思路 首先,这道题乍一看,应该可以用搜索来做。 但是,搜索会不会超时间限制呢? 为了防止时间超限,我们可以换一种做法。 先创立两个二维数组,一个是输入的数组a,一个是数组b。 假设 i 行 j 列的数…

子集树算法文档

1.算法概述 子集树是一种 回溯算法,用于生成一个集合的所有子集。给定一个数组 arr,该算法递归地遍历所有可能的子集,并通过一个辅助数组 x 标记当前元素是否被选中。 2.算法特点 时间复杂度:O(2n)(因为一个包含 n 个…

HTTP/1.1 host虚拟主机详解

一、核心需求:为什么需要虚拟主机? 在互联网上,我们常常希望在一台物理服务器(它通常只有一个公网 IP 地址)上运行多个独立的网站,每个网站都有自己独特的域名(例如 www.a-site.com​, www.b-s…

amass:深入攻击面映射和资产发现工具!全参数详细教程!Kali Linux教程!

简介 OWASP Amass 项目使用开源信息收集和主动侦察技术执行攻击面网络映射和外部资产发现。 此软件包包含一个工具,可帮助信息安全专业人员使用开源信息收集和主动侦察技术执行攻击面网络映射并执行外部资产发现。 使用的信息收集技术 技术数据来源APIs&#xf…

Spring Web MVC响应

返回静态页面 第一步 创建html时,要注意创建的路径,要在static下面 第二步 把需要写的内容写到body内 第三步 直接访问路径就可以 返回数据ResponseBody RestController Controller ResponseBody Controller:返回视图 ResponseBody&…

‌鸿蒙PC正式发布:国产操作系统实现全场景生态突破

鸿蒙PC正式发布:国产操作系统实现全场景生态突破‌ 2025年5月8日,华为在深圳举办发布会,正式推出搭载鸿蒙操作系统的个人电脑(PC),标志着国产操作系统在核心技术与生态布局上实现历史性跨越。此次发布的鸿蒙…

【计算机视觉】OpenCV实战项目:Text-Extraction-Table-Image:基于OpenCV与OCR的表格图像文本提取系统深度解析

Text-Extraction-Table-Image:基于OpenCV与OCR的表格图像文本提取系统深度解析 1. 项目概述2. 技术原理与算法设计2.1 图像预处理流水线2.2 表格结构检测算法2.3 OCR优化策略 3. 实战部署指南3.1 环境配置3.2 核心代码解析3.3 执行流程示例 4. 常见问题与解决方案4.…

Redis BigKey 问题是什么

BigKey 问题是什么 BigKey 的具体表现是 redis 中的 key 对应的 value 很大,占用的 redis 空间比较大,本质上是大 value 问题。 BigKey怎么找 redis-cli --bigkeysscanBig Key 产生的原因 1.redis数据结构使用不恰当 2.未及时清理垃圾数据 3.对业务预…

go-gin

前置 gin是go的一个web框架,我们简单介绍一下gin的使用 导入gin :"github.com/gin-gonic/gin" 我们使用import导入gin的包 简单示例: package mainimport ("github.com/gin-gonic/gin" )func main() {r : gin.Default(…

C# NX二次开发:判断两个体是否干涉和获取系统日志的UFUN函数

大家好,今天要讲关于如何判断两个体是否干涉和获取系统日志的UFUN函数。 (1)UF_MODL_check_interference:这个函数的定义为根据单个目标体检查每个指定的工具体是否有干扰。 Defined in: uf_modl.h Overview Checks each sp…

如何解决 Linux 系统文件描述符耗尽的问题

在Linux系统中,文件描述符(File Descriptor, FD)是操作系统管理打开文件、套接字、管道等资源的抽象标识。当进程或系统耗尽文件描述符时,会导致服务崩溃、连接失败等严重问题。以下是详细的排查和解决方案: --- ###…

LVGL简易计算器实战

文章目录 📁 文件结构建议🔹 eval.h 表达式求值头文件🔹 eval.c 表达式求值实现文件(带详细注释)🔹 ui.h 界面头文件🔹 ui.c 界面实现文件🔹 main.c 主函数入口✅ 总结 项目效果&…