各类神经网络学习:(十)注意力机制(第2/4集),pytorch 中的多维注意力机制、自注意力机制、掩码自注意力机制、多头注意力机制

上一篇下一篇
注意力机制(第1/4集)待编写

一、pytorch 中的多维注意力机制:

N L P NLP NLP 领域内,上述三个参数都是 向量 , 在 p y t o r c h pytorch pytorch 中参数向量会组成 矩阵 ,方便代码编写。

①结构图

注意力机制结构图如下:

在这里插入图片描述

②计算公式详解

计算注意力分数的方式有很多,目前最常用的就是点乘。具体如下:

当向量 q u e r y \large query query k e y \large key key 长度相同时,即 q 、 k i ∈ R ( 1 × d ) q、k_i∈R^{(1×d)} qkiR(1×d) ,则有:注意力分数 s ( q , k i ) = < q , k i > d k \large s(q,k_i)=\frac{<q,k_i>}{\sqrt{d_k}} s(q,ki)=dk <q,ki> ,符号 < q , k i > <q,k_i> <q,ki> 表示点乘/内积运算(向量点乘,结果为标量)。其中 d k d_k dk k i k_i ki 向量的长度(为什么要在原注意力分数底下除以 d k \sqrt{d_k} dk 后面会详解)。

当向量组成矩阵时,假设 Q ∈ R ( n × d ) Q∈R^{(n×d)} QR(n×d) K ∈ R ( m × d ) K∈R^{(m×d)} KR(m×d) V ∈ R ( m × v ) V∈R^{(m×v)} VR(m×v) 。每个矩阵都是由参数行向量堆叠组成。则有:

F ( Q ) = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q ⋅ K T d k ) ⋅ V \Large F(Q)=Attention(Q,K,V)=softmax(\frac{Q·K^T}{\sqrt{d_k}})·V F(Q)=Attention(Q,K,V)=softmax(dk QKT)V
其中 Q K T d ∈ R ( n × m ) \large \frac{QK^T}{\sqrt{d}}∈R^{(n×m)} d QKTR(n×m) 是注意力分数,, s o f t m a x ( Q K T d k ) ∈ R ( n × m ) \large softmax(\frac{QK^T}{\sqrt{d_k}})∈R^{(n×m)} softmax(dk QKT)R(n×m) 是注意力权重, F ( Q ) ∈ R ( n × v ) \large F(Q)∈R^{(n×v)} F(Q)R(n×v) 是输出。

这是一种并行化矩阵计算形式,将所有的 q q q 组合成一个矩阵 Q Q Q k k k v v v 类似,都被组合成了矩阵 K K K V V V 。其详细过程如下:

已知 Q ∈ R ( n × d ) Q∈R^{(n×d)} QR(n×d) K ∈ R ( m × d ) K∈R^{(m×d)} KR(m×d) V ∈ R ( m × v ) V∈R^{(m×v)} VR(m×v) ,该尺寸表示有 n n n q q q m m m k k k m m m v v v 。则:

Q × K T = [ [ ⋯ q 1 ⋯ ] [ ⋯ q 2 ⋯ ] ⋮ [ ⋯ q n ⋯ ] ] ● [ [ ⋮ k 1 ⋮ ⋮ ] [ ⋮ k 2 ⋮ ⋮ ] ⋯ [ ⋮ k m ⋮ ⋮ ] ] = [ q 1 ⋅ k 1 q 1 ⋅ k 2 ⋯ q 1 ⋅ k m q 2 ⋅ k 1 q 2 ⋅ k 2 ⋯ q 2 ⋅ k m ⋮ ⋮ ⋱ ⋮ q n ⋅ k 1 q n ⋅ k 2 ⋯ q n ⋅ k m ] Q \times K^T =\\ \begin{bmatrix} \begin{bmatrix} \cdots & q_1 & \cdots \end{bmatrix} \\ \begin{bmatrix} \cdots & q_2 & \cdots \end{bmatrix} \\ \vdots \\ \begin{bmatrix} \cdots & q_n & \cdots \end{bmatrix} \end{bmatrix} ● \begin{bmatrix} \begin{bmatrix} \vdots \\ k_1 \\ \vdots \\ \vdots \end{bmatrix} & \begin{bmatrix} \vdots \\ k_2 \\ \vdots \\ \vdots \end{bmatrix} & \cdots & \begin{bmatrix} \vdots \\ k_m \\ \vdots \\ \vdots \end{bmatrix} \end{bmatrix}= \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \cdots & q_1 \cdot k_m \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \cdots & q_2 \cdot k_m \\ \vdots & \vdots & \ddots & \vdots \\ q_n \cdot k_1 & q_n \cdot k_2 & \cdots & q_n \cdot k_m \end{bmatrix} Q×KT= [q1][q2][qn] k1 k2 km = q1k1q2k1qnk1q1k2q2k2qnk2q1kmq2kmqnkm

上述运算可以得到每个小 q q q m m m 个小 k k k 的注意力分数,再经过放缩(除以 d k \sqrt{d_k} dk )和 s o f t m a x softmax softmax 函数后得到每个小 q q q m m m 个小 k k k 的注意力权重矩阵,其尺寸为 n × m n×m n×m ,最终和 V V V 相乘,得到 F ( Q ) F(Q) F(Q) ,其尺寸为 n × v n×v n×v ,对应着 n n n q q q v a l u e value value

③公式细节解释

  1. 第一点:

    使用点乘来计算注意力分数的意义:矩阵点乘 Q ⋅ K T Q·K^T QKT 就意味着做点积/内积,(在注意力机制中,点积通常等同于内积,在数学上点积是内积的特例),内积可直接衡量两个向量的方向对齐程度。若两个向量方向一致(夹角为 0 ° 0° ),则内积最大;方向相反(夹角为 180 ° 180° 180° ),则内积最小。点乘不仅包含方向信息,还隐含向量长度的乘积。例如,若两个长向量方向一致,内积值会显著高于短向量,可能更强调其相关性。

  2. 第二点:

    上述公式中, s o f t m a x softmax softmax 里对注意力分数还除以了 d k \sqrt{d_k} dk ,是因为:由于 s o f t m a x softmax softmax 函数的计算公式用到了 e e e 的次方,当两个数之间的倍数很大时,比如说 99 和 1 ,经过求 e e e 的次方运算之后,差别会指数倍增加,这样求出来的概率会很离谱,不是0.99和0.01,而是0.99999999和0.0000000001(很多9和很多0)。让其中每个元素除以 d k \sqrt{d_k} dk 之后,会降低倍数增加的程度(更数学性的解释可以看 00 预训练语言模型的前世今生(全文 24854 个词) - B站-水论文的程序猿 - 博客园 这篇博客中的有关注意力机制的讲解)。其功能类似于防止梯度消失。

  3. 第三点:

    一般来说,在 t r a n s f o r m e r transformer transformer 里, K = V K=V K=V 。当然 K ≠ V K≠V K=V 也可以,不过两者之间一定是有对应关系,能组成键值对的。

二、自注意力机制(Self-Attention)

当上述的三个参数都由一个另外的共同参数 经过不同的线性变换 生成时(即三者同源),就是自注意力机制。其值体现为 Q ≈ K ≈ V Q≈K≈V QKV

这三个矩阵是在同一个矩阵 X X X 上乘以不同的系数矩阵 W Q 、 W K 、 W V W_Q、W_K、W_V WQWKWV 得到的,因此自注意力机制可以说是在计算 X X X 内部各个 x i x_i xi 之间的相关性。其后续步骤和注意力机制一样。(为什么叫自注意力机制,估计是因为这里是计算自己内部之间的相关性吧)

注意】:最终生成的新的 v a l u e value value 其实依然是小 x x x 的向量表示,只不过这个新向量蕴含了其他的小 x x x 的信息。

具体公式如下:
Q = W Q ⋅ X , K = W K ⋅ X , V = W V ⋅ X F ( Q ) = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q ⋅ K T d k ) ⋅ V \large Q=W_Q·X,~~~~K=W_K·X,~~~~V=W_V·X\\ \Large F(Q)=Attention(Q,K,V)=softmax(\frac{Q·K^T}{\sqrt{d_k}})·V Q=WQX,    K=WKX,    V=WVXF(Q)=Attention(Q,K,V)=softmax(dk QKT)V
N L P NLP NLP 中,可以举一个小例子理解一下(矩阵内数值即为注意力权重):

在这里插入图片描述

上图中,每一个单词就是一个小 q q q ,单词用向量表示。(有个误区:不是说自注意力机制中,小 q q q 和自己的注意力分数就是最大的,这个要看具体语义需求)

其他变种:交叉注意力机制( Q Q Q V V V 不同源, K K K V V V 同源)。

三、掩码自注意力机制(Masked Self-Attention)

N L P NLP NLP 里,在训练过程中,比如说我想训练模型生成:“The cat is cute” 这样一个句子,并且计算其自注意力权重,这个时候 “The cat is cute” 就是已知的 label 。但是句子是一个一个单词生成的( The → The cat → The cat is → The cat is cute),第一个生成 The ,第二个生成 cat … 在没有完全生成之前,都是不能提前告诉模型后面的答案。已知句子总长度为 4 4 4 ,那么注意力权重的个数依次是 1 → 2 → 3 → 4 。如下图所示:

在这里插入图片描述

注意了,这里的生成是指训练时的生成,掩码机制只在训练时使用,因为训练时机器知道有位置信息的句子(句子的长度也已知晓),为了防止窥探到下一个字就要掩码。但在实际使用模型时(测试时),是没有参考答案的,所以不需要掩码!

其实还有其他作用,诸如:避免填充干扰等,后面在 transformer 里会详解。

四、多头注意力机制(Multi-Head Self-Attention)

本质上就是: X X X 做完三次线性变换得到 Q 、 K 、 V Q、K、V QKV之后,将 Q 、 K 、 V Q、K、V QKV分割成 8 8 8 块进行注意力计算,最后将这 8 8 8 个结果拼接,然后线性变换,使其维度和 X X X 一致。(并不是直接对 X 进行切分,也不是对 X 进行重复线性变换)

意义:原论文其实也说不清楚这样做的意义,反正给人一种能学到更细致的语义信息的感觉(深度学习就是这样~~)。

流程图如下:

在这里插入图片描述

第一步:

输入序列 X X X 首先经过三次独立的线性变换,生成查询( Q u e r y Query Query)、键( K e y Key Key)、值( V a l u e Value Value)矩阵:

Q = W Q ⋅ X Q=W_Q·X Q=WQX K = W K ⋅ X K=W_K·X K=WKX V = W V ⋅ X V=W_V·X V=WVX 。其中, W Q 、 W K 、 W V W_Q、W_K、W_V WQWKWV 是可学习的权重矩阵。

第二步:

Q 、 K 、 V Q、K、V QKV 矩阵沿特征维度平均分割为多个头。一般头数均为 8 8 8(即 h = 8 h=8 h=8),假设 Q 、 K 、 V Q、K、V QKV 的特征维度为 M M M ,则分割之后每个头的特征维度为 M / 8 M/8 M/8

第三步:

每个头各自并行计算注意力并得到各自的输出(先点积,再缩放,再做 s o f t m a x softmax softmax ,再乘以 v a l u e value value )【每个头学习不同子空间的语义关系】

第四步:

合并多头输出,将所有头的输出拼接为完整维度,再通过一次线性变换整合信息:

O u t p u t = C o n c a t ( h e a d 1 , … , h e a d h ) ⋅ W O Output=Concat(head_1,…,head_h)⋅W_O Output=Concat(head1,,headh)WO 。其中 W O W_O WO 是最后的线性层的投影矩阵。

值得一提的是:针对 “将 Q 、 K 、 V Q、K、V QKV分割成 8 8 8 块” 这个步骤,《Attention Is All You Need》论文原文说的是: linearly project h times ,意思就是将 Q 、 K 、 V Q、K、V QKV通过线性层将其变换为 8 8 8 个新的特征维度为 M / 8 M/8 M/8 Q ′ 、 K ′ 、 V ′ Q^{'}、K^{'}、V^{'} QKV 。不过这在数学上等效于直接分割成 8 8 8 块,并且后者在算法实现上能提高效率。代码如下:

Q = torch.randn(batch_size, seq_len, h*d_k)
Q = Q.view(batch_size, seq_len, h, d_k)  # 分割为 h 个头

.view() 函数的作用是变换尺寸,将原来的三维张量,变成四维张量( h 个三维张量),元素的值不变,元素的总数也不变,其效果等于切割。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/76397.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app初学

文章目录 1. pages.json 页面路由2. 图标3. 全局 CSS4. 首页4.1 整体框架4.2 完整代码4.3 轮播图 swiper4.3.1 image 4.4 公告4.4.1 uni-icons 4.5 分类 uni-row、uni-col4.6 商品列表 小程序开发网址&#xff1a; 注册小程序账号 微信开发者工具下载 uniapp 官网 HbuilderX 下…

VBA将Word文档内容逐行写入Excel

如果你需要将Word文档的内容导入Excel工作表来进行数据加工&#xff0c;使用下面的代码可以实现&#xff1a; Sub ImportWordToExcel()Dim wordApp As Word.ApplicationDim wordDoc As Word.DocumentDim excelSheet As WorksheetDim filePath As VariantDim i As LongDim para…

MySQL运行一段时间后磁盘出现100%读写

MySQL运行一段时间后磁盘出现100%读写的情况&#xff0c;可能是由多种原因导致的&#xff0c;以下是一些常见原因及解决方法&#xff1a; 可能的原因 1. 磁盘I/O压力过大[^0^]&#xff1a;数据量过大&#xff0c;数据库查询和写入操作消耗大量I/O资源。索引效率低&#xff0c…

【RabbitMQ】延迟队列

1.概述 延迟队列其实就是队列里的消息是希望在指定时间到了以后或之前取出和处理&#xff0c;简单来说&#xff0c;延时队列就是用来存放需要在指定时间被处理的元素的队列。 延时队列的使用场景&#xff1a; 1.订单在十分钟之内未支付则自动取消 2.新创建的店铺&#xff0c;…

Linux笔记之Ubuntu系统设置自动登录tty1界面

Ubuntu22.04系统 编辑getty配置文件 vim /etc/systemd/system/gettytty1.service.d/override.conf如果该目录或者文件不存在&#xff0c;进行创建。 在override.conf文件中进行编辑&#xff1a; [Service] ExecStart ExecStart-/sbin/agetty --autologin yourusername --no…

C++程序诗篇的灵动赋形:多态

文章目录 1.什么是多态&#xff1f;2.多态的语法实现2.1 虚函数2.2 多态的构成2.3 虚函数的重写2.3.1 协变2.3.2 析构函数的重写 2.4 override 和 final 3.抽象类4.多态原理4.1 虚函数表4.2 多态原理实现4.3 动态绑定与静态绑定 5.继承和多态常见的面试问题希望读者们多多三连支…

算法训练之动态规划(三)

♥♥♥~~~~~~欢迎光临知星小度博客空间~~~~~~♥♥♥ ♥♥♥零星地变得优秀~也能拼凑出星河~♥♥♥ ♥♥♥我们一起努力成为更好的自己~♥♥♥ ♥♥♥如果这一篇博客对你有帮助~别忘了点赞分享哦~♥♥♥ ♥♥♥如果有什么问题可以评论区留言或者私信我哦~♥♥♥ ✨✨✨✨✨✨ 个…

$_GET变量

$_GET 是一个超级全局变量&#xff0c;在 PHP 中用于收集通过 URL 查询字符串传递的参数。它是一个关联数组&#xff0c;包含了所有通过 HTTP GET 方法发送到当前脚本的变量。 预定义的 $_GET 变量用于收集来自 method"get" 的表单中的值。 从带有 GET 方法的表单发…

jQuery多库共存

在现代Web开发中&#xff0c;项目往往需要集成多种JavaScript库或框架来满足不同的功能需求。然而&#xff0c;当多个库同时使用时&#xff0c;可能会出现命名冲突、功能覆盖等问题。幸运的是&#xff0c;jQuery提供了一些机制来确保其可以与其他库和谐共存。本文将探讨如何实现…

MySQL 中的聚簇索引和非聚簇索引有什么区别?

MySQL 中的聚簇索引和非聚簇索引有什么区别&#xff1f; 1. 从不同存储引擎去考虑 在MySIAM存储引擎中&#xff0c;索引和数据是分开存储的&#xff0c;包括主键索引在内的所有索引都是“非聚簇”的&#xff0c;每个索引的叶子节点存储的是数据记录的物理地址&#xff08;指针…

Java从入门到“放弃”(精通)之旅——启航①

&#x1f31f;Java从入门到“放弃 ”精通之旅&#x1f680; 今天我将要带大家一起探索神奇的Java世界&#xff01;希望能帮助到同样初学Java的你~ (๑•̀ㅂ•́)و✧ &#x1f525; Java是什么&#xff1f;为什么这么火&#xff1f; Java不仅仅是一门编程语言&#xff0c;更…

三相电为什么没零线也能通电

要理解三相电为什么没零线也能通电&#xff0c;就要从发电的原理说起 1、弧形磁铁中加入电枢&#xff0c;旋转切割磁感线会产生电流 随着电枢旋转的角度变化&#xff0c;电枢垂直切割磁感线 电枢垂直切割磁感线&#xff0c;此时会产生最大电压 当转到与磁感线平行时&#xf…

文件上传做题记录

1&#xff0c;[SWPUCTF 2021 新生赛]easyupload2.0 直接上传php 再试一下phtml 用蚁剑连发现连不上 那就只要命令执行了 2&#xff0c;[SWPUCTF 2021 新生赛]easyupload1.0 当然&#xff0c;直接上传一个php是不行的 phtml也不行&#xff0c;看下是不是前端验证&#xff0c;…

【Pandas】pandas DataFrame head

Pandas2.2 DataFrame Indexing, iteration 方法描述DataFrame.head([n])用于返回 DataFrame 的前几行 pandas.DataFrame.head pandas.DataFrame.head 是一个方法&#xff0c;用于返回 DataFrame 的前几行。这个方法非常有用&#xff0c;特别是在需要快速查看 DataFrame 的前…

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(1):承上启下,继续上路

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(1):承上启下,继续上路 1、前言(1)情况说明(2)工程师的信仰2、知识点(1)普通形(ふつうけい)と思います(2)辞書形ことができます(3)Vたことがあります。(4)Vた とき & Vる とき3、单词(1)日语单词(2…

码率自适应(ABR)相关论文阅读简报

标题&#xff1a;Quality Enhanced Multimedia Content Delivery for Mobile Cloud with Deep Reinforcement Learning 作者&#xff1a;Muhammad Saleem , Yasir Saleem, H. M. Shahzad Asif, and M. Saleem Mian 单位: 巴基斯坦拉合尔54890工程技术大学计算机科学与工程系 …

汇编语言:指令详解

零、前置知识 1、数据类型修饰符 名称解释byte一个字节&#xff0c;8bitword单字&#xff0c;占2个字节&#xff0c;16bitdword双字&#xff0c;占4个字节&#xff0c;32bitqword四字&#xff0c;占8个字节&#xff0c;64bit 2、关键词解释 ptr&#xff1a;它代表 pointer&a…

蓝桥杯c ++笔记(含算法 贪心+动态规划+dp+进制转化+便利等)

蓝桥杯 #include <iostream> #include <vector> #include <algorithm> #include <string> using namespace std; //常使用的头文件动态规划 小蓝在黑板上连续写下从 11 到 20232023 之间所有的整数&#xff0c;得到了一个数字序列&#xff1a; S12345…

【C++算法】54.链表_合并 K 个升序链表

文章目录 题目链接&#xff1a;题目描述&#xff1a;解法C 算法代码&#xff1a; 题目链接&#xff1a; 23. 合并 K 个升序链表 题目描述&#xff1a; 解法 解法一&#xff1a;暴力解法 每个链表的平均长度为n&#xff0c;有k个链表&#xff0c;时间复杂度O(nk^2) 合并两个有序…

Java中的注解技术讲解

Java中的注解&#xff08;Annotation&#xff09;是一种在代码中嵌入元数据的机制&#xff0c;不直接参与业务逻辑&#xff0c;而是为编译器、开发工具以及运行时提供额外的信息和指导。下面我们将由浅入深地讲解Java注解的概念、实现原理、各种应用场景&#xff0c;并通过代码…