transformer中的build_attention_mask

build_attention_mask 方法的作用是构建一个因果注意力掩码,用于屏蔽 Transformer 模型中的未来位置。

因果注意力掩码的工作原理

因果注意力掩码通过将未来位置的注意力权重设置为负无穷大,从而确保这些位置的注意力得分在 softmax 计算中接近于零。具体来说,这个掩码矩阵是一个上三角矩阵,其中上三角部分(不包括对角线)被设置为负无穷大。这样,当计算第 𝑖个位置的注意力分数时,只会考虑位置
0 到 𝑖的内容,而忽略位置 𝑖+1及之后的位置,这对于生成任务(如语言模型)非常重要。下面是这个方法的实现和详细解释:

详细解释

  1. 掩码矩阵的构建

    • 掩码矩阵 mask 的形状为 ([context_length, context_length])。
    • mask.fill_(float("-inf")) 将矩阵的所有元素初始化为负无穷大。
    • mask.triu_(1) 将上三角部分(不包括对角线)设置为零。
  2. 应用掩码

    • 在计算注意力分数时,这个掩码会被添加到注意力分数矩阵中。
    • 由于被掩盖的部分被设置为负无穷大,它们在 softmax 计算中会得到接近零的权重,从而 effectively 被忽略。

示例

假设 context_length 为 5,那么生成的掩码 mask 将如下所示:

tensor([[  0., -inf, -inf, -inf, -inf],[  0.,   0., -inf, -inf, -inf],[  0.,   0.,   0., -inf, -inf],[  0.,   0.,   0.,   0., -inf],[  0.,   0.,   0.,   0.,   0.]])

计算注意力分数时的效果

在计算注意力分数时,假设我们有以下示例:

  • 输入序列:[x_0, x_1, x_2, x_3, x_4]
  • 注意力权重矩阵形状:[context_length, context_length]

计算第 ( i ) 个位置的注意力分数时,将使用掩码矩阵对注意力权重进行修正:

未加掩码时的注意力权重矩阵:
[[a00, a01, a02, a03, a04],[a10, a11, a12, a13, a14],[a20, a21, a22, a23, a24],[a30, a31, a32, a33, a34],[a40, a41, a42, a43, a44]]加上掩码后的注意力权重矩阵:
[[a00, -inf, -inf, -inf, -inf],[a10,  a11, -inf, -inf, -inf],[a20,  a21,  a22, -inf, -inf],[a30,  a31,  a32,  a33, -inf],[a40,  a41,  a42,  a43,  a44]]

为什么可以确保位置 ( i ) 只能关注位置 ( 0 ) 到 ( i )

  • 对于位置 ( 0 ),只有 ( a00 ) 会被保留,其他的都被设置为负无穷大。
  • 对于位置 ( 1 ),只有 ( a10 ) 和 ( a11 ) 会被保留,其他的都被设置为负无穷大。
  • 对于位置 ( 2 ),只有 ( a20 ), ( a21 ), ( a22 ) 会被保留,其他的都被设置为负无穷大。
  • 对于位置 ( 3 ),只有 ( a30 ), ( a31 ), ( a32 ), ( a33 ) 会被保留,其他的都被设置为负无穷大。
  • 对于位置 ( 4 ),所有的 ( a40, a41, a42, a43, a44 ) 都会被保留,因为它已经是最后一个位置。

这种掩码方式确保了模型在生成第 ( i ) 个位置的输出时,不会看到第 ( i+1 ) 及之后的位置的输入。

代码示例

以下是一个简单的代码示例,展示了如何使用 build_attention_mask 生成掩码并应用到注意力机制中:

import torchclass ExampleModel:def __init__(self, context_length):self.context_length = context_lengthdef build_attention_mask(self):mask = torch.empty(self.context_length, self.context_length)mask.fill_(float("-inf"))mask.triu_(1)return mask# 假设 context_length 为 5
context_length = 5
model = ExampleModel(context_length)
attention_mask = model.build_attention_mask()print(attention_mask)# 示例的注意力权重矩阵
attention_weights = torch.randn(context_length, context_length)# 加上掩码后的注意力权重矩阵
masked_attention_weights = attention_weights + attention_mask
print(masked_attention_weights)

总结

build_attention_mask 方法通过生成一个上三角掩码矩阵,确保了每个位置 ( i ) 只能关注位置 ( 0 ) 到 ( i ),而不能关注位置 ( i+1 ) 及之后的位置。这个机制通过在注意力分数计算中设置负无穷大,使得这些位置在 softmax 计算中得到接近零的权重,从而 effectively 被忽略。这对于生成任务(如语言模型)非常重要,确保模型在生成时只依赖已生成的部分,而不会看到未来的输入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/32662.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sqlalchemy event监听

在 SQLAlchemy 中,event 系统允许你监听数据库引擎、会话、映射类等对象上的事件,并在这些事件发生时执行自定义的代码。这对于在 SQL 语句执行前后、对象加载、对象刷新等时刻执行特定的逻辑非常有用。 要使用 SQLAlchemy 的 event 系统,你…

爬虫经典案例之爬取豆瓣电影Top250(方法一)

简介:主要使用bs4、request、pandas等模块,实现数据的爬取和存储。 目前存在一点小问题,就是个别电影的导演、演员、上映年份和地区等信息与大部分电影的这些信息的格式有细微差别,导致正则表达式无法正常匹配到个别电影的信息&am…

解析cJSON数组

json串: { "list":[ "hello","world" ] } 代码 : int func(char *sn) { int ret 0; cJSON *root, *list; FILE *fp fopen("a.txt", "r"); if(!fp) { printf("open s…

并发的概念

并发是指在同一时间间隔内同时执行多个任务或处理多个事件的能力或现象。在计算机科学中,特别是在多任务处理系统中,"并发"通常用于描述系统能够在同一时间段内处理多个任务或操作的能力。 并发并不意味着同时执行多个任务,而是通…

艺体培训机构管理系统的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,教师管理,学员管理,活动管理,课程管理,选课信息管理 前台账户功能包括:系统首页,个人中心,论…

【深度C++】之“类与结构体”

0. 抽象数据类型 类(class) 和结构体(struct) 都是C中的自定义数据类型,是使用C实现面向对象编程思想的起点。 类的基本思想是数据抽象(data abstraction) 和封装(encapsulation&a…

【会议征稿,ACM出版】2024年图像处理、智能控制与计算机工程国际学术会议(IPICE 2024,8月9-11)

2024年图像处理、智能控制与计算机工程国际学术会议(IPICE 2024)将于2024年8月9-11日在中国福州举行。本届会议由阳光学院、福建省空间信息感知与智能处理重点实验室、空间数据挖掘与应用福建省高校工程研究中心联合主办。 会议主要围绕图像处理、智能控…

分布式定时任务系列10:XXL-job源码分析之路由策略

传送门 分布式定时任务系列1:XXL-job安装 分布式定时任务系列2:XXL-job使用 分布式定时任务系列3:任务执行引擎设计 分布式定时任务系列4:任务执行引擎设计续 分布式定时任务系列5:XXL-job中blockingQueue的应用 …

Go语言的诞生背景

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

Linux操作系统处理器调度基本准则和实现

1,基本概念 在多道程序系统中,进程的数量往往多于处理机的个数,进程争用处理机的情况就在所难免。处理机调度是对处理机进行分配,就是从就绪队列中,按照一定的算法(公平、低效)选择一个进程并将…

mysql学习——SQL中的DDL和DML

SQL中的DDL和DML DDL数据库操作:表操作 DML添加数据修改数据删除数据 学习黑马MySQL课程,记录笔记,用于复习。 DDL DDL:Data Definition Language,数据定义语言,用来定义数据库对象(数据库,表&…

【CSS】简单实用的calc()函数

calc() 是 CSS 中的一个功能,允许你在属性值中进行基础的数学计算。这是非常有用的,特别是当你需要在不同的上下文或视口大小中动态调整尺寸或位置时。 以下是一些 calc() 函数的简单实用示例: 动态宽度: 假设你希望一个元素的…

C语言入门课程学习笔记8:变量的作用域递归函数宏定义交换变量

C语言入门课程学习笔记8 第36课 - 变量的作用域与生命期(上)第37课 - 变量的作用域与生命期(下)实验—局部变量的作用域实验-变量的生命期 第38课 - 函数专题练习第39课 - 递归函数简介实验小结 第40课 - C 语言中的宏定义实验小结…

基于Java的学生成绩管理系统

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:Java技术,B/S结构 工具:MyEclipse,MySQL 系统展示 首页 个人中…

基于YOLOv5+pyqt5的跌倒检测系统(含pyqt页面、训练好的模型)

简介 跌倒是老年人和身体不便者常见的意外事故,及时检测和处理跌倒事件对于保障他们的安全至关重要。为了提高对跌倒事件的监控效率,我们开发了一种基于YOLOv5目标检测模型的跌倒检测系统。本报告将详细介绍该系统的实际应用与实现,包括系统…

虚拟机IP地址频繁变化的解决方法

勾八动态分配IP,让我在学习redis集群的时候,配置很多的IP地址,但是由于以下原因导致我IP频繁变动,报错让我烦恼!!!! 为什么虚拟机的IP地址会频繁变化? 虚拟机IP地址频繁…

终极解决方案,传统极速方案,下载软件的双雄对决!

在数字资源日益丰富的今天,下载管理器成为了我们日常生活中不可或缺的工具。市场上两款备受欢迎的下载管理软件——Internet Download Manager(IDM)和迅雷11,它们以各自的特色和优势,满足了不同用户群体的需求。 软件…

Less与Sass的区别

1. 功能和工具: Sass:提供了更多的功能和内置方法,如条件语句、循环、数学函数等。Sass 也支持更复杂的操作和逻辑构建。 Less:功能也很强大,但相比之下,Sass 在功能上更为丰富和成熟。 2、编译环境&…

uniapp使用伪元素实现气泡

uniapp使用伪元素实现气泡 背景实现思路代码实现尾巴 背景 气泡效果在开发中使用是非常常见的,使用场景有提示框,对话框等等,今天我们使用css来实现气泡效果。老规矩,先看下效果图: 实现思路 其实实现这个气泡框的…

自动驾驶规划中使用 OSQP 进行二次规划 代码原理详细解读

目录 1 问题描述 什么是稀疏矩阵 CSC 形式 QP Path Planning 问题 1. Cost function 1.1 The first term: 1.2 The second term: 1.3 The thrid term: 1.4 The forth term: 对 Qx 矩阵公式的验证 整体 Q 矩阵(就是 P 矩阵,二次项的权重矩阵&…