[machine learning] Transformer - Attention (一)

Attention是Transformer的核心,本系列先通过介绍Attention来学习Transformer。本文先介绍简单版的Attention。

在Attention出现之前,通常使用recurrent neural networds (RNNs)来处理长序列数据。模型架构上,又通常使用encoder-decoder的结构。
以机器翻译为例,当输入文本序列一个一个进入encoder时,encoder也在一步一步地更新它的hidden state(即隐藏层的值)。通过这种方式,encoder在最后一次更新完hidden state后,尽可能多地把整个输入文本序列的含义捕捉存储到最终的hidden state中。decoder以encoder的最终hidden state作为输入,一次一个字地开始翻译。同样,decoder也是一步一步地更新它的hidden state,每一次更新后的hidden state都包含了预测下一个字的必要上下文信息。下面的图就是整个流程:
在这里插入图片描述
整个流程的关键点是encoder将整个文本序列处理成最终的hidden state(memory cell)。decoder以encoder最终的hidden state为输入,生成输出。encoder-decoder RNN架构最大的问题和局限性是在decoding阶段,RNN无法直接访问encoder中比较靠前的hidden state。结果decoder只能基于包含了所有相关信息的最终hidden state。当遇到复杂的前后依赖距离跨度大的长句时,这个问题会导致上下文信息的丢失。

Transformer的提出解决了RNN的缺陷并逐渐取代了RNN在NLP领域的位置。而Transformer的核心正是self-attention。self-attention是这样一套机制,当计算一个序列的表示时,它允许序列中每个位置上的元素去关注序列中其他位置(包括它本身)的元素。self-attention中的"self"是指这套机制能够通过比较当前位置元素与其他所有位置元素的相关性来计算出关注度权重(attention weights)。它评估学习了输入自己本身不同部分之间的联系和依赖关系。

有了上面的基本概念之后,下面先看一个没有训练权重的simple self-attention。

假设现在有一个6个单词的序列,每个单词的embedding维度是3。

import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)

现在我们想要计算第二个单词跟其他所有单词的attention weight。首先我们先计算attention score:

query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
# 输出:
# tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

这里我们使用的点乘(dot product)。点乘是将两个向量element-wise地相乘然后相加。点乘结果越大表明两个向量越相关。这部分过程如下图所示:
在这里插入图片描述
接下来,我们对attention score做归一化求attention weight。

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())
# 输出:
# Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
# Sum: tensor(1.0000)

实际中,主要是用PyTorch自带的softmax函数来做这个归一化:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
# 输出:
# Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
# Sum: tensor(1.)

在这里插入图片描述

有了attention weight之后,就是最后一步,求第二个单词的上下文向量。

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)
# 输出:
# tensor([0.4419, 0.6515, 0.5683])

在这里插入图片描述

上面是计算第二个单词的上下文向量,下面我们一次性地求所有单词的上下文向量:

# @ 是矩阵乘法,inputs.T 是转置操作
attn_scores = inputs @ inputs.T
print(attn_scores)
# tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
# [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
# [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
# [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
# [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
# [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])# 注意在第一维上做softmax
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)
# 可以看到每一行加起来都是1
# tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
# [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
# [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
# [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
# [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
# [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
# tensor([[0.4421, 0.5931, 0.5790],
# [0.4419, 0.6515, 0.5683],
# [0.4431, 0.6496, 0.5671],
# [0.4304, 0.6298, 0.5510],
# [0.4671, 0.5910, 0.5266],
# [0.4177, 0.6503, 0.5645]])

本文介绍了一种简单版的self-attention,即通过两个词向量点乘求出两个词向量的相关性。下一篇文章,我们将介绍带可训练参数的self-attention。

参考资料:
《Build a Large Language Model from scratch》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/81862.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 输入控件事件使用示例

一 前端 <EditTextandroid:id="@+id/editTextText2"android:layout_width="match_parent"android:layout_height="wrap_content"android:ems="10"android:inputType="text"android:text="Name" />二 后台代…

【向量数据库】用披萨点餐解释向量数据库:一个美味的技术类比

文章目录 前言场景设定&#xff1a;披萨特征向量化顾客到来&#xff1a;生成查询向量相似度计算实战1. 欧氏距离计算&#xff08;值越小越相似&#xff09;2. 余弦相似度计算&#xff08;值越大越相似&#xff09; 关键发现&#xff1a;度量选择影响结果现实启示结语 前言 想象…

人工智能和机器学习在包装仿真中的应用与价值

引言 随着包装成为消费品关键的差异化因素&#xff0c;对智能设计、可持续性和高性能的要求比以往任何时候都更高 。为了满足这些复杂的期望&#xff0c;公司越来越多地采用先进的仿真方法&#xff0c;而现在人工智能 (AI) 和机器学习 (ML) 又极大地增强了这些方法 。本文探讨…

【人工智能】深入探索Python中的自然语言理解:实现实体识别系统

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 自然语言理解(NLU)是人工智能(AI)领域中的重要研究方向之一,其目标是让计算机理解和处理人类语言。在NLU的众多应用中,实体识别(Nam…

个人健康中枢的多元化AI硬件革新与精准健康路径探析

在医疗信息化领域,个人健康中枢正经历着一场由硬件技术革新驱动的深刻变革。随着可穿戴设备、传感器技术和人工智能算法的快速发展,新一代健康监测硬件能够采集前所未有的多维度生物数据,并通过智能分析提供精准的健康建议。本文将深入探讨构成个人健康中枢的最新硬件技术,…

深入了解Linux系统—— 进程切换和调度

前言&#xff1a; 了解了进程的状态和进程的优先级&#xff0c;我们现在来看进程是如何被CPU调度执行的。 在单CPU的系统在&#xff0c;程序是并发执行的&#xff1b;也就是说在一段时间呢&#xff0c;进程是轮番执行的&#xff1b; 这也是说一个进程在运行时不会一直占用CPU直…

阿里云服务迁移实战: 06-切换DNS

概述 按前面的步骤&#xff0c;所有服务迁移完毕之后&#xff0c;最后就剩下 DNS 解析修改了。 修改解析 在域名解析处&#xff0c;修改域名的解析地址即可。 如果 IP 已经过户到了新账号&#xff0c;则不需要修改解析。 何确保业务稳定 域名解析更换时&#xff0c;由于 D…

uni-app 中封装全局音频播放器

在开发移动应用时&#xff0c;音频播放功能是一个常见的需求。无论是背景音乐、音效还是语音消息&#xff0c;音频播放都需要一个稳定且易于管理的解决方案。在 uni-app 中&#xff0c;虽然原生提供了 uni.createInnerAudioContext 方法用于音频播放&#xff0c;但直接使用它可…

golang常用库之-标准库text/template

文章目录 golang常用库之-标准库text/template背景什么是text/templatetext/template库的使用 golang常用库之-标准库text/template 背景 在许多编程场景中&#xff0c;我们经常需要把数据按照某种格式进行输出&#xff0c;比如生成HTML页面&#xff0c;或者生成配置文件。这…

Linux btop 使用教程

简介 btop 是一个基于终端的现代系统资源监控器&#xff0c;具有美观的图形界面、响应快、功能丰富等特点。它支持查看 CPU、内存、磁盘、网络、进程&#xff0c;并可以方便地筛选和管理进程。 功能总览 启动命令&#xff1a; btop界面分为以下几部分&#xff1a; CPU 区域…

Vue3调度器错误解析,完美解决Unhandled error during execution of scheduler flush.

目录 Vue3调度器错误解析&#xff0c;完美解决Unhandled error during execution of scheduler flush. 一、问题现象与本质 二、七大高频错误场景与解决方案 1、Setup初始化陷阱 2、模板中的"幽灵属性" 3、异步操作的"定时炸弹" 4、组件嵌套黑洞 5…

使用DeepSeek定制Python小游戏——以“俄罗斯方块”为例

前言 本来想再发几个小游戏后在整理一下流程的&#xff0c;但是今天试了一下这个俄罗斯方块的游戏结果发现本来修改的好好的的&#xff0c;结果后面越改越乱&#xff0c;前面的版本也没保存&#xff0c;根据AI修改他是在几个版本改来改去&#xff0c;想着要求还是不能这么高。…

Kotlin带接收者的Lambda介绍和应用(封装DialogFragment)

先来看一个具体应用&#xff1a;假设我们有一个App&#xff0c;App中有一个退出应用的按钮&#xff0c;点击该按钮后并不是立即退出&#xff0c;而是先弹出一个对话框&#xff0c;询问用户是否确定要退出&#xff0c;用户点了确定再退出&#xff0c;点取消则不退出&#xff0c;…

ES6/ES11知识点 续一

模板字符串 在 ECMAScript&#xff08;ES&#xff09;中&#xff0c;模板字符串&#xff08;Template Literals&#xff09;是一种非常强大的字符串表示方式&#xff0c;它为我们提供了比传统字符串更灵活的功能&#xff0c;尤其是在处理动态内容时。模板字符串通过反引号&…

【C++】智能指针RALL实现shared_ptr

个人主页 &#xff1a; zxctscl 专栏 【C】、 【C语言】、 【Linux】、 【数据结构】、 【算法】 如有转载请先通知 文章目录 1. 为什么需要智能指针&#xff1f;2. 内存泄漏2.1 什么是内存泄漏&#xff0c;内存泄漏的危害2.2 内存泄漏分类&#xff08;了解&#xff09;2.3 如何…

ROS2 开发踩坑记录(持续更新...)

1. 从find_package(xxx REQUIRED)说起&#xff0c;如何引用其他package(包&#xff09; 查看包的安装位置和include路径详细文件列表 例如&#xff0c;xxx包名为pluginlib # 查看 pluginlib 的安装位置 dpkg -L ros-${ROS_DISTRO}-pluginlib | grep include 这条指令的目的是…

系统思考:困惑源于内心假设

不要怀疑&#xff0c;你的困惑来自你的假设。 你是否曾经陷入过无解的困境&#xff0c;觉得外部环境太复杂&#xff0c;自己的处境无法突破&#xff1f;很多时候&#xff0c;答案并不在于外部的局势&#xff0c;而是来自我们内心深处的假设——那些我们理所当然、从未质疑过的…

GitHub修炼法则:第一次提交代码教学(Liunx系统)

前言 github是广大程序员们必须要掌握的一个技能&#xff0c;万事开头难&#xff0c;如果成功提交了第一次代码&#xff0c;那么后来就会简单很多。网上的相关资料往往都不是从第一次开始&#xff0c;导致很多新手们会在过程中遇到很多权限认证相关的问题&#xff0c;进而被卡…

沥青路面裂缝的目标检测与图像分类任务

文章题目是《A grid‐based classification and box‐based detection fusion model for asphalt pavement crack》 于2023年发表在《Computer‐Aided Civil and Infrastructure Engineering》 论文采用了一种基于网格分类和基于框的检测&#xff08;GCBD&#xff09;&#xff…

【Flask】ORM模型以及数据库迁移的两种方法(flask-migrate、Alembic)

ORM模型 在Flask中&#xff0c;ORM&#xff08;Object-Relational Mapping&#xff0c;对象关系映射&#xff09;模型是指使用面向对象的方式来操作数据库的编程技术。它允许开发者使用Python类和对象来操作数据库&#xff0c;而不需要直接编写SQL语句。 核心概念 1. ORM模型…