【深度学习】计算机视觉(18)——从应用到设计

文章目录

  • 1 不同的注意力机制
    • 1.1 自注意力
    • 1.2 多头注意力
    • 1.3 交叉注意力
      • 1.3.1 基础
      • 1.3.2 进阶

1 不同的注意力机制

在学习的过程中,发现有很多计算注意力的方法,例如行/列注意力、交叉注意力等,如果对注意力机制本身不是特别实现,很难进行自己的网络设计。

1.1 自注意力

在这里插入图片描述

又拿出这张快被我盘包浆的图。假设输入序列的维度为(batch_size, seq_len, d_model),通过线性变换矩阵 W Q , W K , W V ∈ R d m o d e l × d m o d e l W^Q, W^K, W^V ∈ \mathbb{R}^{d_{model}×d_{model}} WQ,WK,WVRdmodel×dmodel生成 Q Q Q/ K K K/ V V V,形状为(batch_size, seq_len, d_model)。注意到, Q ⋅ K T Q·K^T QKT再通过Softmax操作得到了Attention Map,是注意力权重矩阵(后续用 A A A表示)。通过之前的学习可以知道,注意力权重矩阵A的格式为二维矩阵,形状为(batch_size, n,n),其中 n n n是输入序列的长度(即token数量)。假设输入序列长度为3,每个token的长度为4:
在这里插入图片描述
那么 A A A中红色格子表示第二个token与第三个token的关联,即 A [ i ] [ j ] A[i][j] A[i][j]每个元素表示输入序列中第 i i i个序列对第 j j j个序列的注意力权重。这里要注意,是否 A A A是一个以对角线为对称轴的对称矩阵呢?虽然 Q ⋅ K T Q·K^T QKT是对称的,但是经过Softmax后,每一行都会转换为概率分布,这样“位置3对位置2的影响”与“位置2对位置3的影响”就不同了。
接下来要计算 A ⋅ V A·V AV,表示每个位置综合其他位置的加权求和。
在这里插入图片描述

1.2 多头注意力

若使用多头注意力,只是列的长度发生改变,被均分成头的数量。假设输入序列的维度为(batch_size, seq_len, d_model),通过线性变换矩阵 W Q , W K , W V ∈ R d m o d e l × d k W^Q, W^K, W^V ∈ \mathbb{R}^{d_{model}×d_{k}} WQ,WK,WVRdmodel×dk生成 Q Q Q/ K K K/ V V V,形状为(batch_size, seq_len, d_k),其中 d k = d m o d e l h d_k=\frac{d_model}{h} dk=hdmodel(h为多头注意力头数)。

在多头注意力(Multi-Head Attention)中, A A A的格式会扩展为四维张量:(batch_size, num_heads, n, n),batch_size表示样本批次大小,num_heads表示注意力头数,n表示序列长度。

1.3 交叉注意力

1.3.1 基础

标准的自注意力机制中, Q Q Q/ K K K/ V V V通常由同一个输入矩阵 x x x通过不同的线性变换生成。自注意力机制关注于单一输入序列内部元素之间的关系,通过同源输入捕捉序列内部依赖关系。

交叉注意力(Cross-Attention)则关注于两个不同输入序列之间的相互作用。 Q Q Q K K K可以分布来自不同的输入序列,常见于编码器-解码器架构。

在Transformer模型中,CrossAttention通常用于编码器和解码器之间的交互。编码器负责将输入序列编码为一系列特征向量,而解码器则根据这些特征向量逐步生成输出序列。为了使解码器能够更有效地利用编码器的信息,CrossAttention层被引入其中。解码器的每个位置会生成一个查询向量(query),该向量用于在编码器的所有位置进行注意力权重计算。编码器的每个位置则生成一组键向量(keys)和值向量(values)。通过计算查询向量与键向量的相似度,并经过softmax函数归一化后,得到注意力权重。最后,注意力权重与值向量相乘并求和,得到编码器调整后的输出,供解码器使用。

Q Q Q是来自解码器的当前状态(例如翻译任务中的目标语言词, K K K V V V是来自编码器的输出(例如源语言的特征)。
Softmax仅要求 Q Q Q K K K的维度匹配,并不限制来源。

假设 Q Q Q的输入形状为(batch_size, seq_len_q, d_model),seq_len_q为目标序列长度; K K K/ V V V的输入形状为(batch_size, seq_len_kv, d_model),seq_len_kv为源序列长度,输出的形状为(batch_size, h, seq_len_q, seq_len_kv)

最终输出的注意力权重矩阵 A A A作用于 V V V矩阵,生成融合跨序列信息的输出 O u t p u t = A ⋅ V Output=A·V Output=AV

Q Q Q由解码器的自注意力层输出生成,解码器在生成目标序列的每一步时,会将已生成的部分序列通过掩码自注意力层处理,生成当前步的上下文表示,这一表示作为 Q Q Q的输入。

以机器翻译为例,将“我喜欢狼”由中文翻译成英文。每次生成一个词,假设当前已经生成了"I"、“LIKE”,接下来要进行后面的词的翻译,如下图。 x 1 ′ x1' x1是已经生成的上下文表示,由解码器的自注意力层输出。 x 2 ′ x2' x2是源序列“我喜欢狼”经编码器输出的特征向量。
在这里插入图片描述
图示红色问号表示待生成的词语,当生成第三个目标词时,原矩阵新增一行,该行表示问号词对源序列所有三个词的关注权重,而该行的初始值是基于已生成的词的嵌入向量和位置编码生成。对于A,表示接下来要生成的词与源序列的相关度,比如红色阴影部分表示问号词与“我”的语义依赖强度。

​强对齐​​:目标词与源词存在直接翻译关系(如"WOLVES"→"狼"),对应权重接近1。
​​弱对齐​​:目标词依赖源序列的上下文(如生成冠词 “THE” 时可能关注源序列的主语位置)。
​​零权重​​:源词与当前目标词无关(如生成英文标点时,权重集中于源序列的句尾词)。

1.3.2 进阶

编码器的 K K K V V V在推理时是固定不变的,但解码器的 Q Q Q随着目标序列生成动态扩展。例如,生成“WOLVES”时, Q 3 Q_3 Q3需与编码器的 K K K计算相似度,而历史 Q Q Q和编码器的 K K K可能已经被缓存,然后只需要计算 ∑ j = 0 4 A 3 ⋅ V j \sum_{j=0}^{4}A_3·V_j j=04A3Vj即可。

训练阶段,不需要考虑生成词的先后顺序,模型并行处理整个目标序列而非逐词生成,此时所有目标词的 Q Q Q必须同时计算,以利用GPU的并行计算能力加速训练。同时,需要通过反向传播更新所有Q的权重矩阵 W Q W_Q WQ,这要求通过计算完整的 Q Q Q矩阵计算所有注意力权重 A A A,才能正确更新权重矩阵 W Q W_Q WQ,如果仅计算 Q 3 ⋅ K T Q_3·K^T Q3KT,将导致 W Q W_Q WQ的梯度无法涵盖历史位置的语义关联。

每个注意力头的 W Q W_Q WQ矩阵是固定维度的(d_model×d_k),将每行向量从原来指定的特征向量长度转换为分多头之后的特征向量长度,无论目标序列长度如何,所有 Q Q Q向量均通过同一组 W Q W_Q WQ进行投影。这种设计使得模型能够处理任意长度的目标序列,但要求所有 Q Q Q的投影逻辑一致。

解码器隐状态 H i H_i Hi指的是解码器在第 i i i步生成的动态时序表示,包含目标序列的生成进度(如已生成的词数 i i i)、上下文语义、目标序列内部依赖,计算公式为 s i = g ( s i − 1 , y i − 1 , c ) s_i=g(s_{i-1}, y_{i-1}, c) si=g(si1,yi1,c),其中 g g g是解码器的更新函数, c c c是编码器的上下文向量。


参考来源:
AIGC
深入理解CrossAttention:交叉注意力机制的奥秘
【深度学习】Cross-Attention(交叉注意力)机制详解与应用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/905265.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

洛谷 P1955 [NOI2015] 程序自动分析

【题目链接】 洛谷 P1955 [NOI2015] 程序自动分析 【题目考点】 1. 并查集 2. 离散化 【解题思路】 多组数据问题,对于每组数据,有多个 x i x j x_ix_j xi​xj​或 x i ≠ x j x_i \neq x_j xi​xj​的约束条件。 所有相等的变量构成一个集合&…

[Java] 输入输出方法+猜数字游戏

目录 1. 输入输出方法 1.1 输入方法 1.2 输出方法 2. 猜数字游戏 1. 输入输出方法 Java中输入和输出是属于Scanner类里面的方法,如果要使用这两种方法需要引用Scanner类。 import java.util.Scanner; java.util 是Java里面的一个包,里面包含一些工…

zst-2001 上午题-历年真题 UML(13个内容)

UML基础 UML - 第1题 ad UML - 第2题 依赖是暂时使用对象,关联是长期连接 依赖:依夜情 关联:天长地久 组合:组一辈子乐队 聚合:好聚好散 bd UML - 第3题 adc UML - 第4题 bad UML - 第5题 d UML…

WebFlux vs WebMVC vs Servlet 对比

WebFlux vs WebMVC vs Servlet 技术对比 WebFlux、WebMVC 和 Servlet 是 Java Web 开发中三种不同的技术架构,它们在编程模型、并发模型和适用场景上有显著区别。以下是它们的核心对比: 核心区别总览 特性ServletSpring WebMVCSpring WebFlux编程模型…

htmlUnit和Selenium的区别以及使用BrowserMobProxy捕获网络请求

1. Selenium:浏览器自动化之王 核心定位: 跨平台、跨语言的浏览器操控框架,通过驱动真实浏览器实现像素级用户行为模拟。 技术架构: 核心特性: 支持所有主流浏览器(含移动端模拟) 精…

SSRF相关

SSRF(Server Side Request Forgery,服务器端请求伪造),攻击者以服务器的身份发送一条构造好的请求给服务器所在地内网进行探测或攻击。 产生原理: 服务器端提供了能从其他服务器应用获取数据的功能,如从指定url获取网页内容、加载指定地址的图…

SaaS备份的必要性:厂商之外的数据保护策略

在当今数字化时代,企业对SaaS(软件即服务)应用的依赖程度不断攀升。SaaS应用为企业提供了便捷的生产力工具,然而,这也使得数据安全面临诸多挑战,如意外删除、勒索软件攻击以及供应商故障等。因此&#xff0…

【Python 基础语法】

Python 基础语法是编程的基石,以下从核心要素到实用技巧进行系统梳理: 一、代码结构规范 缩进规则 使用4个空格缩进(PEP 8标准)缩进定义代码块(如函数、循环、条件语句) def greet(name):if name: # 正确缩…

利用“Flower”实现联邦机器学习的实战指南

一个很尴尬的现状就是我们用于训练 AI 模型的数据快要用完了。所以我们在大量的使用合成数据! 据估计,目前公开可用的高质量训练标记大约有 40 万亿到 90 万亿个,其中流行的 FineWeb 数据集包含 15 万亿个标记,仅限于英语。 作为…

自动化测试与功能测试详解

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 什么是自动化测试? 自动化测试是指利用软件测试工具自动实现全部或部分测试,它是软件测试的一个重要组成 部分,能完成许多手工测试无…

MySQL全量,增量备份与恢复

目录 一.MySQL数据库备份概述 1.数据备份的重要性 2.数据库备份类型 3.常见的备份方法 二:数据库完全备份操作 1.物理冷备份与恢复 2.mysqldump 备份与恢复 3.MySQL增量备份与恢复 3.1MySQL增量恢复 3.2MySQL备份案例 三:定制企业备份策略思路…

Ubuntu 安装 Nginx

Nginx 是一个高性能的 Web 服务器和反向代理服务器,同时也可以用作负载均衡器和 HTTP 缓存。 Nginx 的主要用途 用途说明Web服务器提供网页服务,处理用户的 HTTP 请求,返回 HTML、CSS、JS、图片等静态资源。反向代理服务器将用户请求转发到…

人工智能 机器学习期末考试题

自测试卷2 一、选择题 1.下面哪个属性不是NumPy中数组的属性( )。 A.ndim B.size C.shape D.add 2.一个简单的Series是由( )的数据组成的。 A.两…

使用阿里云CLI调用OpenAPI

介绍使用阿里云CLI调用OpenAPI的具体操作流程,包括安装、配置凭证、生成并调用命令等步骤。 方案概览 使用阿里云CLI调用OpenAPI,大致分为四个步骤: 安装阿里云CLI:根据您使用设备的操作系统,选择并安装相应的版本。…

K8S Svc Port-forward 访问方式

在 Kubernetes 中,kubectl port-forward 是一种 本地与集群内资源(Pod/Service)建立临时网络隧道 的访问方式,无需暴露服务到公网,适合开发调试、临时访问等场景。以下是详细使用方法及注意事项: 1. 基础用…

23、DeepSeek-V2论文笔记

DeepSeek-V2 1、背景2、KV缓存优化2.0 KV缓存(Cache)的核心原理2.1 KV缓存优化2.2 性能对比2.3 架构2.4多头注意力 (MHA)2.5 多头潜在注意力 (MLA)2.5.1 低秩键值联合压缩 (Low-Rank Key-Value …

MySQL OCP试题解析(2)

试题如下图所示: 一、题目背景还原 假设存在以下MySQL用户权限配置: -- 创建本地会计用户CREATE USER accountinglocalhost IDENTIFIED BY acc_123;-- 创建匿名代理用户(用户名为空,允许任意主机)CREATE USER % IDENTI…

深度学习Y7周:YOLOv8训练自己数据集

🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 一、配置环境 1.官网下载源码 2.安装需要环境 二、准备好自己的数据 目录结构: 主目录 data images(存放图片) annotati…

英伟达Blackwell架构重构未来:AI算力革命背后的技术逻辑与产业变革

——从芯片暴力美学到分布式智能体网络,解析英伟达如何定义AI基础设施新范式 开篇:当算力成为“新石油”,英伟达的“炼油厂”如何升级? 2025年3月,英伟达GTC大会上,黄仁勋身披标志性皮衣,宣布了…

CurrentHashMap的整体系统介绍及Java内存模型(JVM)介绍

当我们提到ConurrentHashMap时,先想到的就是HashMap不是线程安全的: 在多个线程共同操作HashMap时,会出现一个数据不一致的问题。 ConcurrentHashMap是HashMap的线程安全版本。 它通过在相应的方法上加锁,来保证多线程情况下的…