NLP学习记录十:多头注意力

一、单头注意力

单头注意力的大致流程如下:

① 查询编码向量、键编码向量和值编码向量分别经过自己的全连接层(Wq、Wk、Wv)后得到查询Q、键K和值V;

② 查询Q和键K经过注意力评分函数(如:缩放点积运算)得到值权重矩阵;

③ 权重矩阵与值向量相乘,得到输出结果。

 图1 单头注意力模型

 

二、多头注意力 

2.1 使用多头注意力的意义      

        看了一些对多头注意力机制解释的视频,我自己的浅显理解是:在实践中,我们会希望查询Q能够从给定内容中尽可能多地匹配到与自己相关的语义信息,从而得到更准确的预测输出。而多头注意力将查询、键和值分成不同的子空间表示(representation subspaces)(有点类似于子特征?),使得匹配过程更加细化。

2.2 代码实现

        也许直接看代码能更快地理解这个过程:

import torch
from torch import nn
from attentionScore import DotProductAttention
# 多头注意力模型
class MultiHeadAttention(nn.Module):def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)# queries:(batch_size,查询的个数,query_size)# keys:(batch_size,“键-值”对的个数,key_size)# values:(batch_size,“键-值”对的个数,value_size)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)queries = self.W_q(queries)keys = self.W_k(keys)values = self.W_v(values)# 经过变换后,输出的queries,keys,values的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)queries = transpose_qkv(queries, self.num_heads)keys = transpose_qkv(keys, self.num_heads)values = transpose_qkv(values, self.num_heads)# valid_lens的形状:(batch_size,)或(batch_size,查询的个数)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)
# 为了多注意力头的并行计算而变换形状
def transpose_qkv(X, num_heads):# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])
# 逆转transpose_qkv函数的操作
def transpose_output(X, num_heads):X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

        可以发现,前面的处理流程和单头注意力的第①步是一样的,都是使用全连接层计算查询Q、键K、值V。但在进行点积运算之前,模型使用transpose_qkv函数对QKV进行了切割变换,下图可以帮助理解这个过程:

图2 transpose_qkv函数处理Q

图3 transpose_qkv函数处理K 

        这个过程就像是把一个整体划分为了很多小的子空间。一个不知道恰不恰当的比喻,就像是把“父母”这个词拆分成了“长辈”、“养育者”、“监护人”、“爸妈”多重含义。

        对切割变换后的QK进行缩放点积运算,过程如下图所示:

 图4 对切割变换后的Q和K进行缩放点积运算

        transpose_output后的输出结果:

图5 对值加权结果进行transpose_output变换后 

        对比单头注意力的值加权输出,原来的每个查询Q匹配到了更多的value:

图6 多头注意力与单头注意力的值加权结果对比

        整个过程就像是把一个父需求分割成不同的子需求,子需求单独与不同的子特征进行匹配,最后使得每个父需求获得了更多的语义信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/70849.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法-二叉树篇08-完全二叉树的节点个数

完全二叉树的节点个数 力扣题目链接 题目描述 给你一棵 完全二叉树 的根节点 root ,求出该树的节点个数。 完全二叉树 的定义如下:在完全二叉树中,除了最底层节点可能没填满外,其余每层节点数都达到最大值,并且最下…

【原创工具】同文件夹PDF文件合并 By怜渠客

【原创工具】同文件夹PDF文件合并 By怜渠客 原贴:可批量合并多个文件夹内的pdf工具 - 吾爱破解 - 52pojie.cn 他这个存在一些问题,并非是软件内自主实现的PDF合并,而是调用的pdftk这一工具,但楼主并没有提供pdftk,而…

微软云和金山云和k8有什么区别

Kubernetes(K8s)和微软云(Microsoft Cloud)是两种不同的技术,分别用于不同的目的。Kubernetes是一个开源的容器编排系统,用于自动化部署、扩展和管理容器化应用程序,而微软云是一个提供多种云服…

libGL.so.1: cannot open shared object file: No such file or directory-linux022

in <module> from PyQt5.QtGui import QPixmap, QFont, QIcon ImportError: libGL.so.1: cannot open shared object file: No such file or directory 这个错误信息表示XXXX 在启动时遇到问题&#xff0c;缺少 libGL.so.1 文件。libGL.so.1 是与 OpenGL 图形库相关的共…

渗透测试【seacms V9】

搭建seacms环境 我选择在虚拟机中用宝塔搭建环境 将在官网选择的下载下来的文件解压后拖入宝塔面板的文件中 创建网站 添加站点 搭建完成seacmsV9 找到一个报错口 代码分析 <?php set_time_limit(0); error_reporting(0); $verMsg V6.x UTF8; $s_lang utf-8; $dfDbn…

论文阅读笔记:Continual Forgetting for Pre-trained Vision Models

论文阅读笔记&#xff1a;Continual Forgetting for Pre-trained Vision Models 1 背景2 创新点3 方法4 模块4.1 问题设置4.2 LoRA4.3 概述4.4 GS-LoRA4.5 损失函数 5 效果6 结论 1 背景 出于隐私和安全考虑&#xff0c;如今从预先训练的视觉模型中删除不需要的信息的需求越来…

车载DoIP诊断框架 --- 连接 DoIP ECU/车辆的故障排除

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活,除了生存温饱问题之外,没有什么过多的欲望,表面看起来很高冷,内心热情,如果你身…

【爬虫基础】第二部分 爬虫基础理论 P1/3

上节内容回顾&#xff1a;【爬虫基础】第一部分 网络通讯 P1/3-CSDN博客 【爬虫基础】第一部分 网络通讯-Socket套接字 P2/3-CSDN博客 【爬虫基础】第一部分 网络通讯-编程 P3/3-CSDN博客 爬虫相关文档&#xff0c;希望互相学习&#xff0c;共同进步 风123456789&#xff…

Compose 手势处理,增进交互体验

Compose 手势处理&#xff0c;增进交互体验 概述常用手势处理Modifierclickable()combinedClickable()draggable()swipeable()transformable()scrollable()nestedScrollNestedScrollConnectionNestedScrollDispatcher 定制手势处理使用 PointerInput ModifierPointerInputScope…

ue5 3dcesium中从本地配置文件读取路3dtilles的路径

关卡蓝图中获得3dtiles的引用 拉出设置url 设置路径 至于设置的路径从哪里来 可以使用varest读取文件里的接送字符串 path中配置地址 path变量的值为: Data/VillageStartMapConfig.json此地址代表content的地下的data文件夹里的config.json文件 {"FilePath": &quo…

音视频入门基础:RTP专题(12)——RTP中的NAL Unit Type简介

一、引言 RTP封装H.264时&#xff0c;RTP对NALU Header的nal_unit_type附加了扩展含义。 由《音视频入门基础&#xff1a;H.264专题&#xff08;4&#xff09;——NALU Header&#xff1a;forbidden_zero_bit、nal_ref_idc、nal_unit_type简介》可以知道&#xff0c;nal_unit…

搜索赋能:大型语言模型的知识增强与智能提升

引言 近年来&#xff0c;大型语言模型&#xff08;LLM&#xff09;取得了显著的进展&#xff0c;并在各个领域展现出强大的能力。然而&#xff0c;LLM也存在一些局限性&#xff0c;尤其是在知识库方面。由于训练数据的局限性&#xff0c;LLM无法获取最新的知识&#xff0c;也无…

EX_25/2/24

写一个三角形类&#xff0c;拥有私有成员 a,b,c 三条边 写好构造函数初始化 abc 以及 abc 的set get 接口 再写一个等腰三角形类&#xff0c;继承自三角形类 1&#xff1a;写好构造函数&#xff0c;初始化三条边 2&#xff1a;要求无论如何&#xff0c;等腰三角形类对象&#x…

nv docker image 下载与使用命令备忘

1&#xff0c;系统需求 Requirements for GPU Simulation GPU Architectures Volta, Turing, Ampere, Ada, Hopper NVIDIA GPU with Compute Capability 7.0 CUDA 11.x (Driver 470.57.02), 12.x (Driver 525.60.13) Supported Systems CPU architectures x86_64, ARM…

学习记录:初次学习使用transformers进行大模型微调

初次使用transformers进行大模型微调 环境&#xff1a; 电脑配置&#xff1a; 笔记本电脑&#xff1a;I5&#xff08;6核12线程&#xff09; 16G RTX3070&#xff08;8G显存&#xff09; 需要自行解决科学上网 Python环境&#xff1a; python版本:3.8.8 大模型&#xff1a…

【Java学习】Object类与接口

面向对象系列五 一、引用 1.自调传自与this类型 2.类变量引用 3.重写时的发生 二、Object类 1.toString 2.equals 3.hashCode 4.clone 三、排序规则接口 1.Comparable 2.Comparator 一、引用 1.自调传自与this类型 似复刻变量调用里面的非静态方法时&#xff0c;都…

OpenEuler学习笔记(三十五):搭建代码托管服务器

以下是主流的代码托管软件分类及推荐&#xff0c;涵盖自托管和云端方案&#xff0c;您可根据团队规模、功能需求及资源情况选择&#xff1a; 一、自托管代码托管平台&#xff08;可私有部署&#xff09; 1. GitLab 简介: 功能全面的 DevOps 平台&#xff0c;支持代码托管、C…

Vscode无法加载文件,因为在此系统上禁止运行脚本

1.在 vscode 终端执行 get-ExecutionPolicy 如果返回是Restricted&#xff0c;说明是禁止状态。 2.在 vscode 终端执行set-ExecutionPolicy RemoteSigned 爆红说明没有设置成功 3.在 vscode 终端执行Set-ExecutionPolicy -Scope CurrentUser RemoteSigned 然后成功后你再在终…

Transformer 架构 理解

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; Transformer 架构&#xff1a;encoder/decoder 内部细节。 的介绍&#xff0c;说明 Transformer 架构相比当时主流的 RNN/CNN 架构的创新之处&#xff1a; 在 transformer 之前&#xff0c;最先进的架构…

事务的4个特性和4个隔离级别

事务的4个特性和4个隔离级别 1. 什么是事务2. 事务的ACID特性2.1 原子性2.2 一致性2.3 持久性2.4 隔离性 3. 事务的创建4. 事务并发时出现的问题4.1 DIRTY READ 脏读4.2 NON - REPEATABLR READ 不可重复读4.3 PHANTOM READ 幻读 5. 事务的隔离级别5.1 READ UNCOMMITTED 读未提交…