【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)

【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)


文章目录

  • 【深度学习中的注意力机制6】11种主流注意力机制112个创新研究paper+代码——加性注意力(Additive Attention)
  • 1. 加性注意力的起源与提出
  • 2. 加性注意力的原理
  • 3. 发展
  • 4. 代码实现
  • 5. 代码逐句解释


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 加性注意力的起源与提出

加性注意力(Additive Attention)是由Bahdanau et al. 在其2015年关于机器翻译的论文中提出的。这一注意力机制被应用于神经机器翻译(NMT)模型中,旨在提高翻译任务中序列对序列(Seq2Seq)模型的性能,尤其是解决长距离依赖问题。传统的Seq2Seq模型仅依赖于编码器的最终隐藏状态来生成翻译,这在处理长文本时容易丢失输入的细节信息。加性注意力通过在解码过程中对编码器隐藏状态进行加权求和,显著提升了模型性能

加性注意力是一种较早提出的注意力机制,与随后流行的点积注意力不同,加性注意力通过一个可学习的网络计算注意力分数,而不是直接计算向量之间的点积。加性注意力的提出标志着注意力机制在深度学习领域中的广泛应用,尤其是在处理长序列数据时的应用。

2. 加性注意力的原理

加性注意力的核心思想是通过学习一个函数来计算查询(Query)和键(Key)之间的相似性,然后根据相似性对值(Value)进行加权。

具体步骤如下:

1) 输入:

  • Query:解码器中的当前隐藏状态。
  • Key 和 Value:编码器中的隐藏状态(通常是一系列时间步的隐藏状态序列)。

2) 计算注意力分数: 通过将Query和Key进行非线性变换,再经过加性函数求得注意力分数。这个过程使用了一个可学习的权重矩阵,将查询和键分别映射到一个共同的表示空间,计算它们的相似性。

3) softmax归一化: 将上述得到的注意力分数通过softmax函数进行归一化,得到注意力权重。

4) 加权求和: 使用得到的注意力权重对值(Value)进行加权求和,生成最终的加权上下文向量。

公式如下:
在这里插入图片描述
这里, W q W_q Wq W k W_k Wk是可学习的权重矩阵, e i j e_{ij} eij 是注意力分数, v j v_j vj是Value。

3. 发展

加性注意力是最早被提出的注意力机制之一,并在神经机器翻译中取得了显著的成果。后来,随着注意力机制的发展,点积注意力(如Transformer中的缩放点积注意力)因其更高效的计算方式而逐渐取代了加性注意力。然而,加性注意力仍然在某些场景中被使用,尤其是在需要更细致的相似性计算的任务中。

在性能方面,加性注意力与点积注意力的主要区别在于计算复杂度。加性注意力通过一个可学习的神经网络计算注意力分数,计算复杂度为 O ( d ) O(d) O(d),而点积注意力直接计算点积,复杂度为 O ( d 2 ) O(d^2) O(d2),这使得加性注意力在某些场景下具有优势。

4. 代码实现

下面是一个使用加性注意力机制的简化实现,基于PyTorch框架。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass AdditiveAttention(nn.Module):def __init__(self, query_dim, key_dim, hidden_dim):super(AdditiveAttention, self).__init__()# 定义线性层,用于将查询和键映射到同一空间self.query_layer = nn.Linear(query_dim, hidden_dim)self.key_layer = nn.Linear(key_dim, hidden_dim)# 定义一个线性层,用于计算注意力分数self.energy_layer = nn.Linear(hidden_dim, 1)def forward(self, query, keys, values):# query: [batch_size, query_dim]# keys: [batch_size, seq_len, key_dim]# values: [batch_size, seq_len, value_dim]# 计算查询和键的投影query_proj = self.query_layer(query)  # [batch_size, hidden_dim]keys_proj = self.key_layer(keys)  # [batch_size, seq_len, hidden_dim]# 将查询扩展到和键的时间步相同的维度query_proj = query_proj.unsqueeze(1).expand_as(keys_proj)  # [batch_size, seq_len, hidden_dim]# 计算 e_ij = tanh(W_q q + W_k k)energy = torch.tanh(query_proj + keys_proj)  # [batch_size, seq_len, hidden_dim]# 计算注意力分数,并去掉最后一维attention_scores = self.energy_layer(energy).squeeze(-1)  # [batch_size, seq_len]# 通过softmax得到注意力权重attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]# 加权求和值context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)  # [batch_size, value_dim]return context, attention_weights# 测试加性注意力
batch_size = 2
query_dim = 5
key_dim = 5
value_dim = 6
seq_len = 10
hidden_dim = 20# 随机生成查询、键和值
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_len, key_dim)
values = torch.randn(batch_size, seq_len, value_dim)# 实例化加性注意力
additive_attention = AdditiveAttention(query_dim, key_dim, hidden_dim)# 前向传播
context, attention_weights = additive_attention(query, keys, values)print("上下文向量:", context)
print("注意力权重:", attention_weights)

5. 代码逐句解释

1. 导入库:

import torch
import torch.nn as nn
import torch.nn.functional as F

导入PyTorch库,其中torch用于张量操作,nn包含神经网络模块,F提供常用函数如softmax。

2. 定义加性注意力类:

class AdditiveAttention(nn.Module):def __init__(self, query_dim, key_dim, hidden_dim):super(AdditiveAttention, self).__init__()# 定义线性层,用于将查询和键投影到同一维度self.query_layer = nn.Linear(query_dim, hidden_dim)self.key_layer = nn.Linear(key_dim, hidden_dim)# 定义计算注意力能量的线性层self.energy_layer = nn.Linear(hidden_dim, 1)

这里定义了AdditiveAttention类,继承自nn.Modulequery_layerkey_layer分别是将查询和键投影到同一维度的线性层,energy_layer用于计算注意力能量分数。

3. 前向传播函数:

def forward(self, query, keys, values):query_proj = self.query_layer(query)  # [batch_size, hidden_dim]keys_proj = self.key_layer(keys)  # [batch_size, seq_len, hidden_dim]# 扩展查询的维度,使其与键对齐query_proj = query_proj.unsqueeze(1).expand_as(keys_proj)# 计算注意力能量:e_ij = tanh(W_q q + W_k k)energy = torch.tanh(query_proj + keys_proj)  # [batch_size, seq_len, hidden_dim]# 通过线性层计算注意力分数,并去掉最后一维attention_scores = self.energy_layer(energy).squeeze(-1)  # [batch_size, seq_len]# 使用softmax归一化得到注意力权重attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]# 计算上下文向量,通过加权求和值context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)  # [batch_size, value_dim]return context, attention_weights
  • forward函数负责计算加性注意力的前向传播过程。首先,将查询和键分别通过线性层映射到相同的维度。
  • 然后,计算注意力能量,并使用softmax进行归一化,得到注意力权重。
  • 最后,使用这些注意力权重对值进行加权求和,生成上下文向量。
    4. 测试模型:
# 测试加性注意力
query = torch.randn(batch_size, query_dim)
keys = torch.randn(batch_size, seq_len, key_dim)
values = torch.randn(batch_size, seq_len, value_dim)# 实例化加性注意力
additive_attention = AdditiveAttention(query_dim, key_dim, hidden_dim)# 前向传播
context, attention_weights = additive_attention(query, keys, values)print("上下文向量:", context)
print("注意力权重:", attention_weights)

在这里,使用随机生成的张量querykeysvalues来测试加性注意力的输出。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/57293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

修复Oracle MySQL Server 安全漏洞(CVE-2023-0464)

@[TOC](修复Oracle MySQL Server 安全漏洞(CVE-2023-0464)) 对于MySQL的漏洞问题,建议通过防火墙来限制远程访问本地3306端口的方式来处理。如果必须要升级,那么涉及到的具体兼容性问题,新版本安装后会导致的业务异常。 所以,建议采用增加防火墙策略的方式,不建议对mysql进…

计算PSNR, SSIM, VAMF工具

计算PSNR, SSIM, VAMF GitHub - fifonik/FFMetrics: Visualizes Video Quality Metrics (PSNR, SSIM & VMAF) calculated by ffmpeg.exe 绘制码率图 GitHub - fifonik/FFBitrateViewer: Visualizes video bitrate received by ffprobe.exe 视频对比 https://github.com/…

什么是全局污染?怎么避免全局污染?

全局污染(Global Pollution)是指在编程过程中,过度使用全局变量或对象导致命名冲突、代码可维护性下降及潜在错误增加的问题。在 JavaScript 等动态语言中,尤其需要关注全局污染的风险。 全局污染的影响 1. 命名冲突 3. 意外修改…

【C#】调用本机AI大模型流式返回

【python】AI Navigator的使用及搭建本机大模型_anaconda ai navigator-CSDN博客 【Python】AI Navigator对话流式输出_python ai流式返回-CSDN博客 前两章节我们讲解了使用AI Navigator软件搭建本机大模型,并使用python对大模型api进行调用,使其流式返…

Python Flask 框架下的 API 接口开发与封装示例

API(Application Programming Interface)接口的开发和封装是构建软件系统的重要环节。以下是关于 API 接口开发和封装的详细步骤: 一、需求分析 在开发 API 接口之前,首先需要明确接口的功能需求。这包括确定接口要提供哪些数据…

“智能科研写作:结合AI与ChatGPT提升SCI论文和基金申请质量“

基于AI辅助下的高效高质量SCI论文撰写及投稿实践 科学研究的核心在于将复杂的思想和实验成果通过严谨的写作有效地传递给学术界和工业界。对于研究生、青年学者及科研人员,如何高效撰写和发表SCI论文,成为提升学术水平和科研成果的重要环节。系统掌握从…

ProteinMPNN中DecLayer类介绍

PositionWiseFeedForward 类的代码 class PositionWiseFeedForward(nn.Module):def __init__(self, num_hidden, num_ff):super(PositionWiseFeedForward, self).__init__()self.W_in = nn.Linear(num_hidden, num_ff, bias=True)self.W_out = nn.Linear(num_ff, num_hidden, …

SAP_FICO模块-资产减值功能对折旧和残值的影响

一、业务背景 由于财务同事没注意,用总账给资产多做了一笔凭证,导致该资产金额虚增,每个月的折旧金额也虚增;现在财务的需求是怎么操作可以进行资产减值,并且减少每个月计提的折旧; 二、实现方式 通过事务码…

linux CentOs7 安装 FastDFS

CentOs7 安装 FastDFS 1. 安装依赖 yum install gcc libevent libevent-devel -y#进入src目录 cd /usr/local/src2. 安装 libfastcommon 库 libfastcommon 库是 FastDFS 文件系统运行需要的公共 C 语言函数库 # 下载 wget https://github.com/happyfish100/libfastcommon/a…

使用梧桐数据库进行销售趋势分析和预测

在当今竞争激烈的商业环境中,企业需要深入了解销售数据,以便做出明智的决策。销售趋势分析和预测是帮助企业把握市场动态、优化库存管理、制定营销策略的重要工具。本文将介绍如何使用SQL来创建销售数据库的表结构,插入示例数据,并…

6.2024.10.22

2024.10.22 2024.10.22 2024.10.22 今天没怎么学习嵌入式的,找时间补上今天学习的空缺。

qt EventFilter用途详解

一、概述 EventFilter是QObject类的一个事件过滤器,当使用installEventFilter方法为某个对象安装事件过滤器时,该对象的eventFilter函数就会被调用。通过重写eventFilter方法,开发者可以在事件处理过程中进行拦截和处理,实现对事…

go 语言 Gin Web 框架的实现原理探究

Gin 是一个用 Go (Golang) 编写的 Web 框架,性能极优,具有快速、支持中间件、crash处理、json验证、路由组、错误管理、内存渲染、可扩展性等特点。 官网地址:https://gin-gonic.com/ 源码地址:https://github.com/gin-gonic/gi…

Shell重定向输入输出

我的后端学习大纲 我的Linux学习大纲 重定向介绍 标准输入介绍 从键盘读取用户输入的数据,然后再把数据拿到Shell程序中使用; 标准输出介绍 Shell程序产生的数据,这些数据一般都是呈现到显示器上供用户浏览查看; 默认输入输出文件 每个…

重新认识Linux下的硬链接和软链接

文章目录 前言1、软链接?1.1 工作原理1.2 特点 2、硬链接2.1 工作原理2.2 特点 3、 总结 前言 让自己永远保持一颗好奇心 今天无意间听别人提到了硬链接和软链接,起初我想这么基础的知识我肯定是知道的,毕竟大学接触Linux到现在工作了那么多…

ubuntu20.04 opencv4.0 /usr/local/lib/libgflags.a(gflags.cc.o): relocation报错解决

在一个只有ubuntu20.04的docker环境中配置opencv4.0.0, 什么库都没有,都要重新安装, 其他的问题在网上都找到了解决方案,唯独这个问题比较棘手: [ 86%] Linking CXX executable …/…/bin/opencv_annotation /usr/bin/ld: /usr/lo…

前OpenAI首席技术官为新AI初创公司筹资;我国发布首个应用临床眼科大模型 “伏羲慧眼”|AI日报

文章推荐 2024人工智能报告.zip |一文迅速了解今年的AI界都发生了什么? 今日热点 据报道,前OpenAI首席技术官Mira Murati正在为一家新的AI初创公司筹集资金 据路透社报道,上个月宣布离职的OpenAI首席技术官Mira Murati正在为一…

栈和队列(一)

栈和队列的定义和特点 栈和队列是一种特殊的线性表,只能在表的端点进行操作 栈的定义和特点 这就是栈的结构,是一个特殊的线性表,只能在栈顶(或者说是表尾)进行操作。其中top为栈顶,base为栈底 栈s的存储…

C语言结构体数组 java静动数组及问题

1. (1)先声明,后定义:如上一天 //(2).声明时直接定义 #define N 5 typedef struct student { int num; int score; }STU; int main(void) { STU class3[N] { {10,90},{14,70},{8,95} }; …

全面解析:集成AWS、云原生和监控的开源运维管理平台

在当今复杂的IT环境中,寻找一个能够同时支持AWS、云原生技术(如Kubernetes)和全面监控功能的开源运维管理平台,已成为许多组织的迫切需求。本文将深入探讨几个有潜力满足这些需求的开源解决方案,分析它们的优势、局限性…