Transformer实战-系列教程19:DETR 源码解读6(Transformer类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

9、Transformer类

位置:models/transformer.py/Transformer类

9.1 _reset_parameters()函数

    def _reset_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)

这个辅助函数,遍历当前被调用的模型中所有的参数,当这个参数的维度大于1时,会对当前参数使用Xavier均匀初始化方法进行初始化

9.2 构造函数

class Transformer(nn.Module):def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False,return_intermediate_dec=False):super().__init__()encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)encoder_norm = nn.LayerNorm(d_model) if normalize_before else Noneself.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec)self._reset_parameters()self.d_model = d_modelself.nhead = nhead
  1. 定义类,继承PyTorch的nn.Module
  2. 构造函数,传入模型维度、多头注意力的头数、编码器层数、解码器层数、前馈网络的维度、dropout比率、激活函数类型、是否在层之前进行归一化处理、是否返回解码器的中间层输出
  3. 初始化
  4. encoder_layer ,使用TransformerEncoderLayer类创建一个编码器层
  5. encoder_norm ,根据normalize_before的值决定是否创建一个层归一化层
  6. encoder ,使用TransformerEncoder类调用编码器层、编码器层数、归一化等创建编码器
  7. decoder_layer ,使用TransformerDecoderLayer类创建一个解码器层
  8. decoder_norm ,创建一个层归一化层
  9. decoder ,使用TransformerDecoder类调用解码器层、解码器层数、归一化等创建解码器
  10. d_model,模型维度
  11. nhead,多头注意力的头数

9.3 前向传播

    def forward(self, src, mask, query_embed, pos_embed):bs, c, h, w = src.shapesrc = src.flatten(2).permute(2, 0, 1)pos_embed = pos_embed.flatten(2).permute(2, 0, 1)query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)mask = mask.flatten(1)tgt = torch.zeros_like(query_embed)memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,pos=pos_embed, query_pos=query_embed)return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
  1. 前向传播函数,传入数据源、掩码、Q向量、位置编码等
  2. bs, c, h, w,[2, 256, 24, 24],获取数据源的维度,batch、通道数、长、宽,每次传入的数据的长宽可能不同
  3. src,torch.Size([576, 2, 256]),维度从[batch_size, channels, height, width]转换为[height*width, batch_size, channels]
  4. pos_embed ,torch.Size([576, 2, 256]),位置编码进行同样操作
  5. query_embed ,torch.Size([100, 2, 256]),将Q向量扩展并重复,以匹配batch_size
  6. mask ,torch.Size([2, 576]),展平减少一个维度
  7. tgt ,torch.Size([100, 2, 256]),初始化目标序列tgt为与Q向量维度相同但全为0的Tensor,这在解码器的自回归预测中用作初始输入
  8. memory ,torch.Size([576, 2, 256]),调用编码器传入数据源、掩码、位置编码得到输出
  9. hs ,torch.Size([6, 100, 2, 256]),调用解码器传入初始化目标序列、编码器输出、掩码、位置编码、Q向量等得到输出
  10. return ,hs:torch.Size([6, 2, 100, 256]),memory:torch.Size([2, 256, 24, 24]),将编码器输出和解码器的输出的维度进行调整后返回

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/682410.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PS | 15个快捷键演示

01 前言 工具:Adobe Photoshop 2021 安装:无 网上自查 02 快捷键表 Ctrl T自由变换[减小画笔大小]增加画笔大小Shift [降低画笔硬度Shift ]增加笔刷硬度D默认前景/背景颜色X切换前景/背景颜色Ctrl J通过复制新建图层Ctrl Shift J通过剪切新建图层Esc取…

树莓派4B(Raspberry Pi 4B)使用docker搭建单机版nacos [基于docker-compose]

树莓派4B(Raspberry Pi 4B)使用docker搭建单机版nacos [基于docker-compose] 镜像仓库提供的基于arm64架构的nacos镜像很少,我选用的是centralx/nacos-server ,它是基于nacos 2.0.4开发的。 ⚠️ 本文基于docker-compose记述构建单…

使用 WPF + Chrome 内核实现高稳定性的在线客服系统复合应用程序

对于在线客服与营销系统,客服端指的是后台提供服务的客服或营销人员,他们使用客服程序在后台观察网站的被访情况,开展营销活动或提供客户服务。在本篇文章中,我将详细介绍如何通过 WPF Chrome 内核的方式实现复合客服端应用程序。…

#Z2294. 打印树的直径

Description 给你一棵树,树上有N个点,编号从0到N-1 请找出任意一条树的直径,并输出直径上的点,输出顺序为从直径的某个端点走向另一个端点 Format Input 第一行一个整数 n; 之后 n-1 行每行两个整数 u,v&#xf…

位运算总结(Java)

目录 位运算概述 位运算符 位运算的优先级 位运算常见应用 1. 给定一个数n,判断其二进制表示中的第x位是0还是1 2. 将数n的二进制表示中的第x位修改为1 3. 将数n的二进制表示中的第x位修改为0 4. 位图 例题:判断字符是否唯一 5. 提取数n的二进制…

《区块链公链数据分析简易速速上手小册》第5章:高级数据分析技术(2024 最新版)

文章目录 5.1 跨链交易分析5.1.1 基础知识5.1.2 重点案例:分析以太坊到 BSC 的跨链交易理论步骤和工具准备Python 代码示例构思步骤1: 设置环境和获取合约信息步骤2: 分析以太坊上的锁定交易步骤3: 跟踪BSC上的铸币交易 结论 5.1.3 拓展案例 1:使用 Pyth…

OCP的operator——(2)OLM

文章目录 了解OperatorOperator Lifecycle Manager(OLM)OLM概念和资源OLM是什么OLM资源Cluster service version(CSV)Catalog source定制catalog source的image模板目录健康需求 SubscriptionInstall planOperator groupOperator …

SQL世界之命令语句Ⅴ

目录 一、SQL CREATE INDEX 语句 1.SQL CREATE INDEX 语句 2.SQL CREATE INDEX 语法 3.SQL CREATE UNIQUE INDEX 语法 4.SQL CREATE INDEX 实例 二、SQL 撤销索引、表以及数据库 1.SQL DROP INDEX 语句 2.SQL DROP TABLE 语句 3.SQL DROP DATABASE 语句 4.SQL TRUNCA…

文件压缩炸弹,想到有点后怕

今天了解到一个概念,压缩炸弹。 参考: https://juejin.cn/post/7289667869557178404 https://www.zhihu.com/zvideo/1329374649210302464 什么是压缩炸弹 压缩炸弹(也称为压缩文件炸弹、炸弹文件)是一种特殊的文件,它…

ACTable开源框架的使用及异常

介绍 ACTable是对Mybatis做的增强功能,支持SpringBoot以及传统的SpringMvc架构,配置简单,使用方便。主要是自动生成数据库表,直接修改java代码,数据库就会对应的变化,省去在调整数据库表的问题&#xff0c…

笔试刷题(持续更新)| Leetcode 45,1190

45. 跳跃游戏 题目链接: 45. 跳跃游戏 II - 力扣(LeetCode) 这道题思路不难记,遍历数组每个位置,更新下一次的范围,当当前位置已经在当前范围之外时,步数一定得加一,当前范围更新成…

蓝桥杯官网填空题(质数拆分)

问题描述 将 2022 拆分成不同的质数的和,请问最多拆分成几个? 答案提交 本题为一道结果填空的题,只需要算出结果后,在代码中使用输出语句将结果输出即可。 运行限制 import java.util.Scanner;public class Main {static int …

Pandas Series 的学习笔记

Pandas Series 的学习笔记 0. Pandas 简介1. Series 学习1-1. 创建 Series1-2. 索引1-3. 选择数据1-4. 修改 Series1-5. Series 的操作 2. 结论 0. Pandas 简介 想象一下,你有一张超级大的餐桌,上面放满了各种各样的食物。Pandas 就像是这张餐桌&#x…

面试:大数据和深度学习之间的关系是什么?

大数据与深度学习之间存在着紧密的相互关系,它们在当今技术发展中相辅相成。 大数据的定义与特点:大数据指的是规模(数据量)、多样性(数据类型)和速度(数据生成及处理速度)都超出了传统数据处理软件和硬件能力范围的数据集。它具有四个主要特点,通常被称…

【Java】零基础蓝桥杯算法学习——二分查找

算法模板一: // 数组arr的区间[0,left-1]满足arr[i]<k,[left,n-1]满足arr[i]>k;Scanner scan new Scanner(System.in);int[] arr {1,2,3,4,5};int left 0,right arr.length-1;int k scan.nextInt();while(left<right) {//leftright时退出循环int mid (leftrigh…

leetcode(双指针)11.盛最多水的容器(C++详细解释)DAY9

文章目录 1.题目示例提示 2.解答思路3.实现代码结果 4.总结 1.题目 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回…

【Django】Django项目部署

项目部署 1 基本概念 项目部署是指在软件开发完毕后&#xff0c;将开发机器上运行的软件实际安装到服务器上进行长期运行。 在安装机器上安装和配置同版本的环境[python&#xff0c;数据库等] django项目迁移 scp /home/euansu/Code/Python/website euansuxx.xx.xx.xx:/home…

Rust的Match语句:强大的控制流运算符

在Rust中&#xff0c;match语句是一种强大的控制流运算符&#xff0c;用于比较一个值与一系列模式&#xff0c;并执行与第一个匹配的模式对应的代码块。它提供了一种清晰而灵活的方式来处理多个条件&#xff0c;使得代码更加可读、易于理解。 Match语句的基本使用 首先&#…

2月14作业

21.C 22.D 23.B 5先出栈表示1&#xff0c;2&#xff0c;3&#xff0c;4已经入栈了&#xff0c;5出后4出&#xff0c;但之后想出1得先让3&#xff0c;2先后出栈&#xff0c;所以 B 不可能 24.10&#xff0c;12&#xff0c;120 25.2&#xff0c;5 26.可能会出现段错误…

js importmap

在html文件中使用npm 下载的包&#xff0c;比如vue&#xff0c;在使用import引入的时候会报错 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-widt…