[machine learning] Transformer - Attention (二)

本文介绍带训练参数的self-attention,即在transformer中使用的self-attention。

首先引入三个可训练的参数矩阵Wq, Wk, Wv,这三个矩阵用来将词向量投射(project)到query, key, value三个向量上。下面我们再定义几个变量:

import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2

d_in是输入维度,d_out是输出维度,在transformer模型中d_in和d_out通常是相同的,但是为了更好地理解其中的计算过程,这里我们选择不同的d_in=3和d_out=2。

下面我们初始化三个参数矩阵:

torch.manual_seed(123)W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

这里为了介绍方便,我们设置requires_grad=False, 如果是训练的话,我们要设置requires_grad=True

下面我们计算第二个单词的query, key, value:

# @ 是矩阵乘法
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
# tensor([0.4306, 1.4551])

可以看到,经过参数矩阵project之后,query, key, value的维度已经转变成2。

下面我们通过计算第二个单词的上下文向量来介绍整个过程。

keys_2 = keys[1] # Python starts index at 0
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)
# tensor(1.8524)attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)
# tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])d_k = keys.shape[1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)
# tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
# tensor([0.3061, 0.8210])

self-attention又叫做scaled-dot product attention,正是因为在求attention weight的时候对attention score做了除以向量维度的平方根来做缩放(scale)。

为什么要除以向量维度的平方根?

除以向量维度的平方根主要是避免太小的梯度以提升训练性能。对于像GPT这样的大型LLM模型,它的向量维度通常会超过上千,那么向量和向量之间的点乘结果就会非常大。而我们知道,对于Softmax函数,如果输入值很大或者很小的话,它是非常平缓的,非常平缓也就意味着梯度很小以至于接近于0。这就导致训练中的反向传播时,会出现非常小的梯度,从而引发梯度消失问题,极大地降低了模型学习的速度,引发训练停滞。下面是Softmax的函数图像:
在这里插入图片描述

怎么理解query, key, value?

query, key, value是借用数据库信息提取领域的概念。query用来搜索信息,key用来存储信息,value用来提取信息。
query:类似于数据库中的查找。它代表着当前模型关心的单词。query用来探测输入序列中的其他单词,去决定当前单词和其他单词的相关性。
key:类似于数据库中的索引。attention机制中,每个单词都有自己的key,这些key用来和query匹配。
value:类似于数据库中key-value对中的值。它代表着输入序列中单词的实际内容或者单词的实际表示。当query探测发现和某些key最相关,它就会提取跟这些key关联的value。

下面实现一个self-attention类:

import torch.nn as nnclass SelfAttention_v1(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.W_query = nn.Parameter(torch.rand(d_in, d_out))self.W_key   = nn.Parameter(torch.rand(d_in, d_out))self.W_value = nn.Parameter(torch.rand(d_in, d_out))def forward(self, x):keys = x @ self.W_keyqueries = x @ self.W_queryvalues = x @ self.W_valueattn_scores = queries @ keys.T # omegaattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)context_vec = attn_weights @ valuesreturn context_vectorch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))# 输出:
# tensor([[0.2996, 0.8053],
#        [0.3061, 0.8210],
#        [0.3058, 0.8203],
#        [0.2948, 0.7939],
#        [0.2927, 0.7891],
#        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

可以看到,输出结果的第二行和上面单独计算第二个单词的上下文向量是一致的。
我们可以使用nn.Linear来改进上面的self-attention类。相比于手动实现nn.Parameter(torch.rand(...))nn.Linear会优化权重的初始化,因此会使模型训练更稳定有效。

class SelfAttention_v2(nn.Module):def __init__(self, d_in, d_out, qkv_bias=False):super().__init__()self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)def forward(self, x):keys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.Tattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)context_vec = attn_weights @ valuesreturn context_vectorch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))# 输出:
# tensor([[-0.0739,  0.0713],
#        [-0.0748,  0.0703],
#        [-0.0749,  0.0702],
#        [-0.0760,  0.0685],
#        [-0.0763,  0.0679],
#        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

由于SelfAttention_v1SelfAttention_v2使用了不同的权重初始化方法,因此它们的输出结果是不一样的。



参考资料:
《Build a Large Language Model from Scratch》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/81555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

施磊老师rpc(三)

文章目录 mprpc框架项目动态库编译框架生成动态库框架初始化函数-文件读取1. 为什么要传入 argc, argv2. 读取参数逻辑3. 配置文件设计 init部分实现 mprpc配置文件加载(一)配置文件加载类成员变量主要方法**src/include/mprpcconfig.h** 配置文件**bin/test.conf** 实现配置文…

文献分享:通过简单的生物偶联策略将肽双特异性抗体(pBsAbs)应用于免疫治疗

背景 双特异性抗体是将单克隆抗体的两个不同抗原结合位点融合成一个单一实体的人工分子。它们已经成为一种很有前景的下一代抗癌治疗方法。尽管双特异性抗体的应用令人着迷&#xff0c;但双特异性抗体的设计和生产仍然繁琐而富有挑战性&#xff0c;导致研发过程漫长&#xff0…

二、shell脚本--变量与数据类型

1. 变量的定义与使用 定义变量&#xff1a;简单直接 在 Shell 里定义变量相当容易&#xff1a; 基本格式: variable_namevalue关键点 ❗&#xff1a;赋值号 的两边绝对不能有空格&#xff01;这绝对是初学者最容易踩的坑之一 &#x1f628;&#xff0c;务必留意&#xff01…

java_Lambda表达式

1、背景 lambda表达式是Java SE 8中一个重要的新特性。lambda表达式允许你通过表达式来代替功能接口。lambda表达式就和方法一样样&#xff0c;它提供了一个正常的参数列表和一个使用这些参数的主体&#xff08;body&#xff0c;可以是一个表达式和一个代码块&#xff09;。La…

给QCustomPlot添加一个QScrollBar滚动条、限制缩放范围、自动设置大小和右边栏垂直缩放

实现效果 实现思路 从QCustomPlot类派生一个类,进行个性化设置,在轴矩形的上边设置Margin,放一个滚动条,设置滚动条的样式 常量定义 #define NQSCRB 1000构造函数初始化 // 设置QScrollBar的样式// 顶部空--5,左侧空--6

实验-组合电路设计1-全加器和加法器(数字逻辑)

目录 一、实验内容 二、实验步骤 2.1 全加器的设计 2.2 加法器的设计 三、调试过程 3.1 全加器调试过程 2.加法器的调试过程 四、实验使用环境 五、实验小结和思考 一、实验内容 a) 介绍 在这次实验中&#xff0c;你将熟悉 Logisim 的操作流程&#xff0c;并且学习…

Linux进程控制与替换详解

进程创建 fork函数初识 在linux中fork函数是非常重要的函数,它从已存在进程中创建⼀个新进程。新进程为子进程,而原进程为父进程。 进程调用fork,当控制转移到内核中的fork代码后,内核做: • 分配新的内存块和内核数据结构给子进程 • 将父进程部分数据结构内容拷贝至…

Vue3学习笔记2——路由守卫

路由守卫 全局 router.beforeEach((to, from, next) > {})router.afterEach((to, from, next) > {}) 组件内守卫 beforeRouteEnter((to, from, next) > {})beforeRouteUpdate((to, from, next) > {})beforeRouteLeave((to, from, next) > {}) 路由独享 be…

AI与无人零售:如何通过智能化技术提升消费者体验和运营效率?

引言&#xff1a;无人零售不只是无人值守 你走进一家无人便利店&#xff0c;没有迎宾、没有收银员&#xff0c;甚至没有一个人在场&#xff0c;但你刚拿起商品&#xff0c;货架旁的摄像头就悄悄“看懂”了你的动作&#xff0c;系统已经在后台为你记账。你以为只是没人管&#x…

如何在3dMax中使用UVW展开修改器?

UVW展开(Unwrap UVW)修改器是3dmax中的一个强大工具,允许对纹理如何应用于3D模型进行精确控制。 与更简单的UVW Map修改器不同,Unwrap UVW修改器提供了高级选项,用于手动编辑纹理映射,对于详细和复杂的模型来说是必不可少的。 在本文中,我们将探讨增强您对Unwrap UVW修…

【Linux】进程优先级与进程切换理解

&#x1f31f;&#x1f31f;作者主页&#xff1a;ephemerals__ &#x1f31f;&#x1f31f;所属专栏&#xff1a;Linux 目录 前言 一、进程优先级 1. 什么是进程优先级 2. 为什么有进程优先级 3. 进程优先级的作用 4. Linux进程优先级的本质 5. 修改进程优先级 二、进…

【Hive入门】Hive高级特性:事务表与ACID特性详解

目录 1 Hive事务概述 2 ACID特性详解 3 Hive事务表的配置与启用 3.1 启用Hive事务支持 3.2 创建事务表 4 Hive事务操作流程 5 并发控制与隔离级别 5.1 Hive的锁机制 5.2 隔离级别 6 Hive事务的限制与优化 6.1 主要限制 6.2 性能优化建议 7 事务表操作示例 7.1 基本…

二叉树算法精解(Java 实现):从遍历到高阶应用

引言 二叉树&#xff08;Binary Tree&#xff09;作为算法领域的核心数据结构&#xff0c;在搜索、排序、数据库索引、编译器语法树构建等众多场景中都有着广泛应用。无论是初学者夯实算法基础&#xff0c;还是求职者备战技术面试&#xff0c;掌握二叉树相关算法都是不可或缺的…

ES6入门---第二单元 模块二:关于数组新增

一、扩展运算符。。。 1、可以把ul li转变为数组 <script>window.onloadfunction (){let aLi document.querySelectorAll(ul li);let arrLi [...aLi];arrLi.pop();arrLi.push(asfasdf);console.log(arrLi);};</script> </head> <body><ul><…

Nature正刊:新型折纸启发手性超材料,实现多模式独立驱动,变形超50%!

机械超材料是一种结构化的宏观结构&#xff0c;其几何排列方式具有独特的几何结构&#xff0c;从而具有独特的力学性能和变形模式。超材料的宏观特性取决于中观尺度晶胞的具体形状、尺寸和几何取向。经典的结构化晶胞&#xff0c;例如以拉伸为主的八面体桁架单元和以弯曲为主的…

Servlet(二)

软件架构 1. C/S 客户端/服务器端 2. B/S 浏览器/服务器端&#xff1a; 客户端零维护&#xff0c;开发快 资源分类 1. 静态资源 所有用户看到相同的部分&#xff0c;如&#xff1a;html,css,js 2. 动态资源 用户访问相同资源后得到的结果可能不一致&#xff0c;如&#xff1a;s…

循环缓冲区

# 循环缓冲区 说明 所谓消费&#xff0c;就是数据读取并删除。 循环缓冲区这个数据结构与生产者-消费者问题高度适配。 生产者-产生数据&#xff0c;消费者-处理数据&#xff0c;二者速度不一致&#xff0c;因此需要循环缓冲区。 显然&#xff0c;产生的数据要追加到循环缓…

嵌入式硬件篇---STM32 系列单片机型号命名规则

文章目录 前言一、STM32 型号命名规则二、具体型号解析1. STM32F103C8T6F103:C:8:T6:典型应用2. STM32F103RCT6F103:R:C:T6:典型应用三、命名规则扩展1. 引脚数与封装代码2. Flash 容量代码3. 温度范围代码四、快速识别技巧性能定位:F1/F4后缀差异硬件设计参考:引脚数…

MySQL 中日期相减的完整指南

MySQL 中日期相减的完整指南 在 MySQL 中&#xff0c;日期相减有几种不同的方法&#xff0c;具体取决于你想要得到的结果类型&#xff08;天数差、时间差等&#xff09;。 1. 使用 DATEDIFF() 函数&#xff08;返回天数差&#xff09; SELECT DATEDIFF(2023-05-15, 2023-05-…

传奇各版本迭代时间及内容变化,屠龙/嗜魂法杖/逍遥扇第一次出现的时间和版本

​【早期经典版本】 1.10 三英雄传说&#xff1a;2001 年 9 月 28 日热血传奇正式开启公测&#xff0c;这是传奇的第一个版本。游戏中白天与黑夜和现实同步&#xff0c;升级慢&#xff0c;怪物爆率低&#xff0c;玩家需要靠捡垃圾卖金币维持游戏开销&#xff0c;遇到高级别法师…