【现代深度学习技术】注意力机制04:Bahdanau注意力

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、模型
    • 二、定义注意力解码器
    • 三、训练
    • 小结


  序列到序列学习(seq2seq)中探讨了机器翻译问题:通过设计一个基于两个循环神经网络的编码器-解码器架构,用于序列到序列学习。具体来说,循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量,然后循环神经网络解码器根据生成的词元和上下文变量按词元生成输出(目标)序列词元。然而,即使并非所有输入(源)词元都对解码某个词元都有用,在每个解码步骤中仍使用编码相同的上下文变量。有什么方法能改变上下文变量呢?

  我们试着找到灵感:在为给定文本序列生成手写的挑战中,Graves设计了一种可微注意力模型,将文本字符与更长的笔迹对齐,其中对齐方式仅向一个方向移动。受学习对齐想法的启发,Bahdanau等人提出了一个没有严格单向对齐限制的可微注意力模型。在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。这是通过将上下文变量视为注意力集中的输出来实现的。

一、模型

  下面描述的Bahdanau注意力模型将遵循序列到序列学习(seq2seq)中的相同符号表达。这个新的基于注意力的模型与序列到序列学习(seq2seq)中的模型相同,只不过其中式(3)中的上下文变量 c \mathbf{c} c在任何解码时间步 t ′ t' t都会被 c t ′ \mathbf{c}_{t'} ct替换。假设输入序列中有 T T T个词元,解码时间步 t ′ t' t的上下文变量是注意力集中的输出:
c t ′ = ∑ t = 1 T α ( s t ′ − 1 , h t ) h t (1) \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t \tag{1} ct=t=1Tα(st1,ht)ht(1) 其中,时间步 t ′ − 1 t' - 1 t1时的解码器隐状态 s t ′ − 1 \mathbf{s}_{t' - 1} st1是查询,编码器隐状态 h t \mathbf{h}_t ht既是键,也是值,注意力权重 α \alpha α是使用加性注意力打分函数计算的。

  与循环神经网络编码器-解码器架构略有不同,图1描述了Bahdanau注意力的架构。

在这里插入图片描述

图1 一个带有Bahdanau注意力的循环神经网络编码器-解码器模型

import torch
from torch import nn
from d2l import torch as d2l

二、定义注意力解码器

  下面看看如何定义Bahdanau注意力,实现循环神经网络编码器-解码器。其实,我们只需重新定义解码器即可。为了更方便地显示学习的注意力权重,以下AttentionDecoder类定义了带有注意力机制解码器的基本接口。

#@save
class AttentionDecoder(d2l.Decoder):"""带有注意力机制解码器的基本接口"""def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError

  接下来,让我们在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  3. 编码器有效长度(排除在注意力池中填充词元)。

  在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,# num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# 输出X的形状为(num_steps,batch_size,embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:# query的形状为(batch_size,1,num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)# context的形状为(batch_size,1,num_hiddens)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# 在特征维度上连结x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 将x变形为(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)# 全连接层变换后,outputs的形状为# (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

  接下来,使用包含7个时间步的4个序列输入的小批量测试Bahdanau注意力解码器。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

在这里插入图片描述

三、训练

  与序列到序列学习(seq2seq)类似,我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器,并对这个模型进行机器翻译训练。由于新增的注意力机制,训练要序列到序列学习(seq2seq)比没有注意力机制的慢得多。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述
在这里插入图片描述

  模型训练后,我们用它将几个英语句子翻译成法语并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

在这里插入图片描述

attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))

  训练结束后,下面通过可视化注意力权重会发现,每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

小结

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/80200.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

爬虫学习————开始

🌿自动化的思想 任何领域的发展原因————“不断追求生产方式的改革,即使得付出与耗费精力越来愈少,而收获最大化”。由此,创造出方法和设备来提升效率。 如新闻的5W原则直接让思考过程规范化、流程化。或者前端框架/后端轮子的…

每天五分钟机器学习:KTT条件

本文重点 在前面的课程中,我们学习了拉格朗日乘数法求解等式约束下函数极值,如果约束不是等式而是不等式呢?此时就需要KTT条件出手了,KTT条件是拉格朗日乘数法的推广。KTT条件不仅统一了等式约束与不等式约束的优化问题求解范式,KTT条件给出了这类问题取得极值的一阶必要…

leetcode0829. 连续整数求和-hard

1 题目: 连续整数求和 官方标定难度:难 给定一个正整数 n,返回 连续正整数满足所有数字之和为 n 的组数 。 示例 1: 输入: n 5 输出: 2 解释: 5 2 3,共有两组连续整数([5],[2,3])求和后为 5。 示例 2: 输入: n 9 输出: …

window 显示驱动开发-线性伸缩空间段

线性伸缩空间段类似于线性内存空间段。 但是,伸缩空间段只是地址空间,不能容纳位。 若要保存位,必须分配系统内存页,并且必须重定向地址空间范围以引用这些页面。 内核模式显示微型端口驱动程序(KMD)必须实…

Cadence 高速系统设计流程及工具使用三

5.8 约束规则的应用 5.8.1 层次化约束关系 在应用约束规则之前,我们首先要了解这些约束规则是如何作用在 Cadence 设计对象上的。Cadence 中对设计对象的划分和概念,如表 5-11 所示。 在 Cadence 系统中,把设计对象按层次进行了划分&#…

ScaleTransition 是 Flutter 中的一个动画组件,用于实现缩放动画效果。

ScaleTransition 是 Flutter 中的一个动画组件,用于实现缩放动画效果。它允许你对子组件进行动态的缩放变换,从而实现平滑的动画效果。ScaleTransition 通常与 AnimationController 和 Tween 一起使用,以控制动画的开始、结束和过渡效果。 基…

深入解析:如何基于开源p-net快速开发Profinet从站服务

一、Profinet协议与软协议栈技术解析 1.1 工业通信的"高速公路" Profinet作为工业以太网协议三巨头之一,采用IEEE 802.3标准实现实时通信,具有: 实时分级:支持RT(实时)和IRT(等时实时)通信模式拓扑灵活:支持星型、树型、环型等多种网络结构对象模型:基于…

m个n维向量组中m,n的含义与空间的关系

向量的维度与空间的关系&#xff1a; 一个向量的维度由其分量个数决定&#xff0c;例如 ( n ) 个分量的向量属于 Rn空间 。 向量组张成空间的维度&#xff1a; 当向量组有 ( m ) 个线性无关的 ( n ) 维向量时&#xff1a; 若 ( m < n )&#xff1a; 这些向量张成的是 Rn中的…

excel大表导入数据库

前文介绍了数据量较小的excel表导入数据库的方法&#xff0c;在数据量较大的情况下就不太适合了&#xff0c;一个是因为mysql命令的执行串长度有限制&#xff0c;二是node-xlsx这个模块加载excel文件是整个文件全部加载到内存&#xff0c;在excel文件较大和可用内存受限的场景就…

Python 爬虫基础入门教程(超详细)

一、什么是爬虫&#xff1f; 网络爬虫&#xff08;Web Crawler&#xff09;&#xff0c;又称网页蜘蛛&#xff0c;是一种自动抓取互联网信息的程序。爬虫会模拟人的浏览行为&#xff0c;向网站发送请求&#xff0c;然后获取网页内容并提取有用的数据。 二、Python爬虫的基本原…

Spring Security 深度解析:打造坚不可摧的用户认证与授权系统

Spring Security 深度解析&#xff1a;打造坚不可摧的用户认证与授权系统 一、引言 在当今数字化时代&#xff0c;构建安全可靠的用户认证与授权系统是软件开发中的关键任务。Spring Security 作为一款功能强大的 Java 安全框架&#xff0c;为开发者提供了全面的解决方案。本…

【物联网】基于树莓派的物联网开发【1】——初识树莓派

使用背景 物联网开发从0到1研究&#xff0c;以树莓派为基础 场景介绍 系统学习Linux、Python、WEB全栈、各种传感器和硬件 接下来程序猫将带领大家进军物联网世界&#xff0c;从0开始入门研究树莓派。 认识树莓派 正面图示&#xff1a; 1&#xff1a;树莓派简介 树莓派…

第21节:深度学习基础-激活函数比较(ReLU, Sigmoid, Tanh)

1. 引言 在深度学习领域,激活函数是神经网络中至关重要的组成部分 它决定了神经元是否应该被激活以及如何将输入信号转换为输出信号 激活函数为神经网络引入了非线性因素,使其能够学习并执行复杂的任务 没有激活函数,无论神经网络有多少层,都只能表示线性变换,极大地限…

Fiori学习专题三十:Routing and Navigation

实际上我们的页面是会有多个的&#xff0c;并且可以在多个页面之间跳转&#xff0c;这节课就学习如何在不同页面之间实现跳转。 1.修改配置文件manifest.json&#xff0c;加入routing&#xff0c;包含三个部分&#xff0c;config,routes,targets; config &#xff1a; routerC…

【HarmonyOS NEXT+AI】问答05:ArkTS和仓颉编程语言怎么选?

在“HarmonyOS NEXTAI大模型打造智能助手APP(仓颉版)”课程里面&#xff0c;有学员提到了这样一个问题&#xff1a; 鸿蒙的主推开发语言不是ArkTS吗&#xff0c;本课程为什么使用的是仓颉编程语言&#xff1f; 这里就这位同学的问题&#xff0c;统一做下回复&#xff0c;以方便…

Booth Encoding vs. Non-Booth Multipliers —— 穿透 DC 架构看乘法器的底层博弈

目录 &#x1f9ed; 前言 &#x1f331; 1. Non-Booth 乘法器的实现原理&#xff08;也叫常规乘法器&#xff09; &#x1f527; 构建方式 ✍️ 例子&#xff1a;4x4 Non-Booth 乘法器示意 &#x1f9f1; 硬件结构 ✅ 特点总结 ⚡ 2. Booth Encoding&#xff08;布斯编码…

GET请求如何传复杂数组参数

背景 有个历史项目&#xff0c;是GET请求&#xff0c;但是很多请求还是复杂参数&#xff0c;比如&#xff1a;参数是数组&#xff0c;且数组中每一个元素都是复杂的对象&#xff0c;这个时候怎么传参数呢&#xff1f; 看之前请求直接是拼接在url后面 类似&items%5B0%5D.…

iOS App 安全性探索:源码保护、混淆方案与逆向防护日常

iOS App 安全性探索&#xff1a;源码保护、混淆方案与逆向防护日常 在 iOS 开发者的日常工作中&#xff0c;我们总是关注功能的完整性、性能的优化和UI的细节&#xff0c;但常常忽视了另一个越来越重要的问题&#xff1a;发布后的应用安全。 尤其是对于中小团队或独立开发者&…

A* (AStar) 寻路

//调用工具类获取路线 let route AStarSearch.getRoute(start_point, end_point, this.mapFloor.map_point); map_point 是所有可走点的集合 import { _decorator, Component, Node, Prefab, instantiate, v3, Vec2 } from cc; import { oops } from "../../../../../e…

深度解析动态IP业务核心场景:从技术演进到行业实践

引言&#xff1a;动态IP的技术演进与行业价值 在数字化转型加速的今天&#xff0c;IP地址已从单纯的网络标识演变为支撑数字经济的核心基础设施。动态IP作为灵活高效的地址分配方案&#xff0c;正突破传统认知边界&#xff0c;在网络安全防护、数据价值挖掘、全球业务拓展等领…