transformer-实现单层Decoder 层

Decoder Layer

  • 论文地址

    https://arxiv.org/pdf/1706.03762

解码器层结构

  • Transformer解码器层由三种核心组件构成:

    1. Masked多头自注意力:关注解码器序列当前位置之前的上下文(因果掩码)

    2. Encoder-Decoder多头注意力:关注编码器输出的相关上下文

    3. 前馈神经网络:进行非线性特征变换

      image-20250429162448704

    今天这里实现的是上图中蓝色框中的单层DecoderLayer,不包含 embedding和位置编码,以及最后的Linear和Softmax。

    主要处理流程:

    1. Decoder 的Masked自注意力
    2. Encoder-Decoder自注意力
    3. 前馈神经网络:进行非线性特征变换
    4. 残差连接 + 层归一化
    5. Dropout:最终输出前进行随机失活

数学表达

  • 解码器层计算过程分为三个阶段:

    1. Masked自注意力阶段

    MaskedAtt ( Q , K , V ) = LayerNorm ( MultiHead ( Q , K , V ) + R e s i d u a l ) \text{MaskedAtt}(Q,K,V) = \text{LayerNorm}(\text{MultiHead}(Q,K,V) + Residual) MaskedAtt(Q,K,V)=LayerNorm(MultiHead(Q,K,V)+Residual)

    1. Encoder-Decoder注意力阶段

    CrossAtt ( Q d e c , K e n c , V e n c ) = LayerNorm ( MultiHead ( Q d e c , K e n c , V e n c ) + R e s i d u a l ) \text{CrossAtt}(Q_{dec}, K_{enc}, V_{enc}) = \text{LayerNorm}(\text{MultiHead}(Q_{dec},K_{enc},V_{enc}) + Residual) CrossAtt(Qdec,Kenc,Venc)=LayerNorm(MultiHead(Qdec,Kenc,Venc)+Residual)

    1. 前馈网络阶段

    FFN ( x ) = LayerNorm ( ReLU ( x W 1 + b 1 ) W 2 + b 2 + x ) \text{FFN}(x) = \text{LayerNorm}(\text{ReLU}(xW_1 + b_1)W_2 + b_2 + x) FFN(x)=LayerNorm(ReLU(xW1+b1)W2+b2+x)

    其中:

    1. d_model 为模型维度
    2. Residual 为残差连接
    3. 下标dec来源于Decoder自己的输出,下标enc为Encoder的输出

代码实现

  • 实现单层

    其他层的实现

    层名链接
    PositionEncodinghttps://blog.csdn.net/hbkybkzw/article/details/147431820
    calculate_attentionhttps://blog.csdn.net/hbkybkzw/article/details/147462845
    MultiHeadAttentionhttps://blog.csdn.net/hbkybkzw/article/details/147490387
    FeedForwardhttps://blog.csdn.net/hbkybkzw/article/details/147515883
    LayerNormhttps://blog.csdn.net/hbkybkzw/article/details/147516529
    EncoderLayerhttps://blog.csdn.net/hbkybkzw/article/details/147591824

    下面统一在before.py中导入

  • 实现单层的DecoderLayer

    import torch 
    from torch import nnfrom before import PositionEncoding,calculate_attention,MultiHeadAttention,FeedForward,LayerNormclass DecoderLayer(nn.Module):def __init__(self, n_heads, d_model, ffn_hidden, dropout_prob=0.1):super(DecoderLayer, self).__init__()self.masked_att = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dropout_prob=dropout_prob)self.att = MultiHeadAttention(n_heads=n_heads, d_model=d_model, dropout_prob=dropout_prob)self.norms = nn.ModuleList([LayerNorm(d_model=d_model) for _ in range(3)])  # 三个归一化层self.ffn = FeedForward(d_model=d_model, ffn_hidden=ffn_hidden, dropout_prob=dropout_prob)self.dropout = nn.Dropout(dropout_prob)def forward(self, x, encoder_kv, dst_mask=None, src_dst_mask=None):# 第一阶段:Decoder 的Masked自注意力_x = xmask_att_out = self.masked_att(q=x, k=x, v=x, mask=dst_mask)mask_att_out = self.norms[0](mask_att_out + _x)  # 残差连接后归一化# 第二阶段:Encoder-Decoder注意力_x = mask_att_outatt_out = self.att(q=mask_att_out, k=encoder_kv, v=encoder_kv, mask=src_dst_mask)att_out = self.norms[1](att_out + _x)# 第三阶段:前馈网络_x = att_outffn_out = self.ffn(att_out)ffn_out = self.norms[2](ffn_out + _x)return self.dropout(ffn_out)
    
  • 注意力掩码机制

    掩码类型作用域功能描述
    dst_mask目标序列自注意力防止当前位置关注未来信息(因果掩码)
    src_dst_mask编码器-解码器注意力控制解码器查询对编码器键值对的访问权限
  • 参数说明

    参数名类型说明
    n_headsint注意力头数量
    d_modelint模型隐藏层维度
    ffn_hiddenint前馈网络中间层维度(通常4倍)
    dropout_probfloatDropout概率(默认0.1)

使用示例

  • 测试代码

    if __name__ == "__main__":# 实例化解码器层:8头,512维,前馈层2048,20% dropoutdecoder_layer = DecoderLayer(n_heads=8, d_model=512, ffn_hidden=2048, dropout_prob=0.2)# 模拟输入:batch_size=4,目标序列长度50,编码器输出长度80x = torch.randn(4, 50, 512)encoder_out = torch.randn(4, 80, 512)tgt_mask = Nonesrc_mask = Noneoutput = decoder_layer(x, encoder_out, dst_mask=tgt_mask, src_dst_mask=src_mask)print("输入形状:", x.shape)print("encode_kv 形状:", encoder_out.shape)print("输出形状:", output.shape)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/81014.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式每日硬核训练 Day 16:责任链模式(Chain of Responsibility Pattern)完整讲解与实战应用

🔄 回顾 Day 15:享元模式小结 在 Day 15 中,我们学习了享元模式(Flyweight Pattern): 通过共享对象,分离内部状态与外部状态,大量减少内存开销。适用于字符渲染、游戏场景、图标缓…

大数据开发环境的安装,配置(Hadoop)

1. 三台linux服务器的安装 1. 安装VMware VMware虚拟机软件是一个“虚拟PC”软件,它使你可以在一台机器上同时运行二个或更多Windows、DOS、LINUX系统。与“多启动”系统相比,VMWare采用了完全不同的概念。 我们可以通过VMware来安装我们的linux虚拟机…

多模态大语言模型arxiv论文略读(四十九)

When Do We Not Need Larger Vision Models? ➡️ 论文标题:When Do We Not Need Larger Vision Models? ➡️ 论文作者:Baifeng Shi, Ziyang Wu, Maolin Mao, Xin Wang, Trevor Darrell ➡️ 研究机构: UC Berkeley、Microsoft Research ➡️ 问题背…

【深度学习与大模型基础】第14章-分类任务与经典分类算法

Part 1:什么是分类任务? 1.1 分类就是“贴标签” 想象你有一堆水果,有苹果🍎、橘子🍊、香蕉🍌,你的任务是让机器学会自动判断一个新水果属于哪一类——这就是分类(Classification&…

LeetCode 2906 统计最大元素出现至少K次的子数组(滑动窗口)

给出一个示例: 输入:nums [1,3,2,3,3], k 2 输出:6 解释:包含元素 3 至少 2 次的子数组为:[1,3,2,3]、[1,3,2,3,3]、[3,2,3]、[3,2,3,3]、[2,3,3] 和 [3,3] 。该题也是一个比较简单的滑动窗口的题目,但是…

使用 Spring Boot 进行开发

✨ 使用 Spring Boot 进行开发 ✨ 📌 本节将深入介绍如何高效使用 Spring Boot,涵盖以下核心主题: 1️⃣ 🔧 构建系统 深入了解 Spring Boot 的项目结构和依赖管理 2️⃣ ⚙️ 自动配置 探索 Spring Boot 的自动化配置机制和原…

Qt的WindowFlags窗口怎么选?

Qt.Dialog: 指示窗口是一个对话框,这通常会改变窗口的默认按钮布局,并可能影响窗口框架的样式。Qt.Popup: 指示窗口是一个弹出式窗口(例如菜单或提示),它通常是临时的且没有任务栏按钮。Qt.Tool: 标识窗口作为一个工具…

Redis高可用架构全解析:主从复制、哨兵模式与集群实战指南

Redis高可用架构全解析:主从复制、哨兵模式与集群实战指南 引言 在分布式系统架构中,Redis作为高性能内存数据库的标杆,其高可用与扩展性设计始终是开发者关注的焦点。本文将深入剖析Redis的三大核心机制——主从复制、哨兵模式与集群架构&…

音视频之H.265/HEVC网络适配层

H.265/HEVC系列文章: 1、音视频之H.265/HEVC编码框架及编码视频格式 2、音视频之H.265码流分析及解析 3、音视频之H.265/HEVC预测编码 4、音视频之H.265/HEVC变换编码 5、音视频之H.265/HEVC量化 6、音视频之H.265/HEVC环路后处理 7、音视频之H.265/HEVC熵编…

element-plus(vue3)表单el-select下拉框的远程分页下拉触底关键字搜索实现

一、基础内核-自定义指令 1.背景 2.定义 3.使用 4.注意 当编辑时需要回显,此时由于分页导致可能匹配不到对应label文本显示,此时可以这样解决 二、升级使用-二次封装组件 三、核心代码 1.自定义指令 定义 ----------------selectLoadMoreDirective.…

大内存生产环境tomcat-jvm配置实践

话不多讲,奉上代码,分享经验,交流提高! 64G物理内存,8核CPU生产环境tomcat-jvm配置如下: JAVA_OPTS-server -XX:MaxMetaspaceSize4G -XX:ReservedCodeCacheSize2G -XX:UseG1GC -Xms48G -Xmx48G -XX:MaxGCPauseMilli…

C++函数模板基础

1 函数模板 1.1 基础介绍 函数模板是一种特殊的函数定义,它允许你创建通用的函数,这些函数可以处理多种不同的数据类型,而不需要为每种数据类型都编写一个单独的函数。 在 C++ 里,函数模板的格式包含模板声明与函数定义两部分,其基本格式如下: template <typename…

mangodb的数据库与集合命令,文档命令

MongoDB的下载安装与启动&#xff0c; 一、MongoDB下载安装 1. 官网下载 打开官网&#xff1a;https://www.mongodb.com/try/download/community选择&#xff1a; 版本&#xff08;Version&#xff09;&#xff1a;选最新版或者根据需要选旧版。平台&#xff08;OS&#xff0…

flink端到端数据一致性

这里有一个注意点&#xff0c;就是flink端的精准一次 1.barrier对齐精准和一次非对齐精准一次 对比​​ ​​维度​​​​Barrier 对齐的精准一次​​​​Barrier 非对齐的精准一次​​​​触发条件​​需等待所有输入流的 Barrier 对齐后才能触发检查点 收到第一个 Barrier …

4月29号

级别越大,字体越小. CSS样式控制: 例如把日期设为灰色字体

PHP代码-服务器下载文件页面编写

内部环境的服务资源下载页面有访问需求&#xff0c;给开发和产品人员编写一个简洁的下载页面提供资源下载。直接用nginxphp的形式去编写了&#xff0c;这里提供展示index.php文件代码如下&#xff1a; <?php // 配置常量 define(BASE_DIR, __DIR__); // 当前脚本所在目录作…

MySQL基础关键_001_认识

目 录 一、概述 1.数据库&#xff08;DB&#xff09;分类 &#xff08;1&#xff09;关系型数据库 &#xff08;2&#xff09;非关系型数据库 2.数据库管理系统&#xff08;DBMS&#xff09; 3.SQL &#xff08;1&#xff09;说明 &#xff08;2&#xff09;分类 二、…

Shell、Bash 执行方式及./ 执行对比详解

Shell、Bash 执行方式及./ 执行对比详解 在 Linux 和 UNIX 系统的使用过程中&#xff0c;Shell 脚本是实现自动化任务、系统管理的重要工具。而在执行 Shell 脚本时&#xff0c;我们常常会用到bash命令以及./的执行方式&#xff0c;这两种执行方式看似相似&#xff0c;实则存在…

P1494 [国家集训队] 小 Z 的袜子 Solution

Description 给定序列 a ( a 1 , a 2 , ⋯ , a n ) a(a_1,a_2,\cdots,a_n) a(a1​,a2​,⋯,an​)&#xff0c;有 q q q 次查询&#xff0c;每次查询给定 ( l , r ) (l,r) (l,r). 你需要求出 2 ∑ i ≤ i < j ≤ r [ a i a j ] ( r − l ) ( r − l 1 ) \dfrac{2\sum…

解决vue3 路由query传参刷新后数据丢失的问题

前言&#xff1a;在页面刷新的时候&#xff0c;路由query数据会被清空&#xff0c;网上很多方法说query传参可以实现&#xff0c;反正我是没有实现 思路&#xff1a;将数据保存到本地&#xff0c;通过 “ &#xff1f;” 进行判断是否有数据&#xff0c;页面销毁的时候删除本地…