基于transformer的解码decode目标检测框架(修改DETR源码)

提示:transformer结构的目标检测解码器,包含loss计算,附有源码

文章目录

  • 前言
  • 一、main函数代码解读
    • 1、整体结构认识
    • 2、main函数代码解读
    • 3、源码链接
  • 二、decode模块代码解读
    • 1、decoded的TransformerDec模块代码解读
    • 2、decoded的TransformerDecoder模块代码解读
    • 3、decoded的DecoderLayer模块代码解读
  • 三、decode模块训练demo代码解读
    • 1、解码数据输入格式
    • 2、解码训练demo代码解读
  • 四、decode模块预测demo代码解读
    • 1、预测数据输入格式
    • 2、解码预测demo代码解读
  • 五、losses模块代码解读
    • 1、matcher初始化
    • 2、二分匹配matcher代码解读
    • 3、num_classes参数解读
    • 4、losses的demo代码解读


前言

最近重温DETR模型,越发感觉detr模型结构精妙之处,不同于anchor base 与anchor free设计,直接利用100框给出预测结果,使用可学习learn query深度查找,使用二分匹配方式训练模型。为此,我基于detr源码提取解码decode、loss计算等系列模块,并重构、修改、整合一套解码与loss实现的框架,该框架可适用任何backbone特征提取接我框架,实现完整训练与预测,我也有相应demo指导使用我的框架。那么,接下来,我将完整介绍该框架源码。同时,我将此源码进行开源,并上传github中,供读者参考。


一、main函数代码解读

1、整体结构认识

在介绍main函数代码前,我先说下整体框架结构,该框架包含2个文件夹,一个losses文件夹,用于处理loss计算,一个是obj_det文件,用于transformer解码模块,该模块源码修改于detr模型,也包含main.py,该文件是整体解码与loss计算demo示意代码,如下图。

在这里插入图片描述

2、main函数代码解读

该代码实际是我随机创造了标签target数据与backbone特征提取数据及位置编码数据,使其能正常运行的demo,其代码如下:

import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterionif __name__ == '__main__':Model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)num_classes = 4   #  类别+1matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)  # 二分匹配不同任务分配的权重losses = ['labels', 'boxes', 'cardinality']  # 计算loss的任务weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}  # 为dert最后一个设置权重criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)# 下面使用iter,我构造了虚拟模型编码数据与数据加载标签数据src = torch.rand((391, 2, 256))pos_embed = torch.ones((391, 1, 256))# 创造真实target数据target1 = {'boxes':torch.rand((5,4)),'labels':torch.tensor([1,3,2,1,2])}target2 = {'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}target = [target1, target2]res = Model(src, pos_embed)losses = criterion(res, target)print(losses)

如下图:

在这里插入图片描述

3、源码链接

源码链接:点击这里

二、decode模块代码解读

该模块主要是使用transform方式对backbone提取特征的解码,主要使用learn query等相关trike与transform解码方式内容。
我主要介绍TransformerDec、TransformerDecoder、DecoderLayer模块,为依次被包含关系,或说成后者是前者组成部分。

1、decoded的TransformerDec模块代码解读

该类大意是包含了learn query嵌入、解码transform模块调用、head头预测logit与boxes等内容,是实现解码与预测内容,该模块参数或解释已有注释,读者可自行查看,其代码如下:

class TransformerDec(nn.Module):'''d_model=512, 使用多少维度表示,实际为编码输出表达维度nhead=8, 有多少个头num_queries=100, 目标查询数量,可学习querynum_decoder_layers=6, 解码循环层数dim_feedforward=2048, 类似FFN的2个nn.Linear变化dropout=0.1,activation="relu",normalize_before=False,解码结构使用2种方式,默认False使用post解码结构output_intermediate_dec=False, 若为True保存中间层解码结果(即:每个解码层结果保存),若False只保存最后一次结果,训练为True,推理为Falsenum_classes: num_classes数量与数据格式有关,若类别id=1表示第一类,则num_classes=实际类别数+1,若id=0表示第一个,则num_classes=实际类别数额外说明,coco类别id是1开始的,假如有三个类,名称为[dog,cat,pig],batch=2,那么参数num_classes=4,表示3个类+1个背景,模型输出src_logits=[2,100,5]会多出一个预测,target_classes设置为[2,100],其值为4(该值就是背景,而有类别值为123),那么target_classes中没有值为0,我理解模型不对0类做任何操作,是个无效值,模型只对1234进行loss计算,然4为背景会比较多,作者使用权重0.1避免其背景过度影响。forward return: 返回字典,包含{'pred_logits':[],  # 为列表,格式为[b,100,num_classes+2]'pred_boxes':[],  # 为列表,格式为[b,100,4]'aux_outputs'[{},...] # 为列表,元素为字典,每个字典为{'pred_logits':[],'pred_boxes':[]},格式与上相同}'''def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):super().__init__()self.num_queries = num_queriesself.query_embed = nn.Embedding(num_queries, d_model)  # 与编码输出表达维度一致self.output_intermediate_dec = output_intermediate_decdecoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/128270.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《现代C++语言核心特性解析》笔记(一)

一、新基础类型(C11~C20) C基础类型回顾一览表 1. 整数类型 long long 我们知道long通常表示一个32位整型,而long long则是用来表示一个64位的整型。不得不说,这种命名方式简单粗暴。不仅写法冗余,而且表…

MuLogin浏览器如何在一台设备上安全登录和管理多个LinkedIn账户?

一、LinkedIn多个账户的用处 LinkedIn作为世界上最大的专业人士社交平台,具有许多有用的功能,对于个人和企业来说都非常重要。以下是多个LinkedIn账户的一些典型用途: 1. 分行业账户:如果您在不同的行业从事职业活动&#xff0c…

【多线程面试题二十二】、 说说你对读写锁的了解

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:说说你对读写锁的了解 …

华纳云:centos系统中怎么查看cpu信息?

在CentOS系统中,我们可以使用一些命令来查看CPU的详细信息。下面介绍几个常用的命令: 1. lscpu lscpu命令可以显示CPU的架构、型号、核心数、线程数、频率等信息。 # lscpu 执行以上命令后,会输出类似以下内容: 2. cat /proc/…

【C/C++】继承中同名成员处理方式

一、继承中同名成员处理方式 问题:当子类与父类出现同名的成员,如何通过子类对象,访问到子类或父类中同名的数据呢? 访问子类同名成员 直接访问即可访问父类同名成员 需要加作用域 示例: class Base { public:Base…

Pinia和Vuex有什么区别?

Pinia和Vuex有什么区别? 在Vue.js应用程序中,状态管理是确保组件之间数据共享和通信的关键部分。两个主要的状态管理库是Pinia和Vuex。Pinia是一个相对较新的库,而Vuex则是Vue.js社区中使用最广泛的状态管理工具之一。 1. 简介 1.1 Vuex …

金蝶云星空表单插件获取日期控件判空处理(代码示例)

文章目录 金蝶云星空表单插件获取日期控件判空处理C#实现 金蝶云星空表单插件获取日期控件判空处理 C#实现 DateTime? deliveryDate (DateTime?)this.View.Model.GetValue("FApproveDate");//审核日期long leadtime 20;//天数if (!deliveryDate.IsNullOrEmpty()…

在IDEA运行spark程序(搭建Spark开发环境)

建议大家写在Linux上搭建好Hadoop的完全分布式集群环境和Spark集群环境,以下在IDEA中搭建的环境仅仅是在window系统上进行spark程序的开发学习,在window系统上可以不用安装hadoop和spark,spark程序可以通过pom.xml的文件配置,添加…

Java 的高性能缓存库-caffeine!

在项目中用到的除了分布式缓存,还有本地缓存,例如:Guava、Encache,使用本地缓存能够很大程度上提升程序性能,本地缓存是直接从本地内存中读取,没有网络开销。 今天给大家介绍一个高性能的 Java 缓存库 – …

C++设计模式_26_设计模式总结

本篇为C++设计模式的总结课,此篇再回到原帮助大家梳理一下。 文章目录 1. 一个目标2. 两种手段3. 八大原则4. 重构技法5. 从封装变化角度对模式分类6. C++对象模型7. 关注变化点和稳定点8. 什么时候不用模式9. 经验之谈10. 设计模式成长之路1. 一个目标 管理变化,提高复用!…

认证授权--越权访问测试用例

漏洞名称 越权访问 漏洞描述 越权访问,这类漏洞是指应用在检查授权(Authorization)时存在纰漏,使得攻击者在获得低权限用户帐后后,可以利用一些方式绕过权限检查,访问或者操作到原本无权访问的高权限功能…

Angular组件生命周期详解

当 Angular 实例化组件类 并渲染组件视图及其子视图时,组件实例的生命周期就开始了。生命周期一直伴随着变更检测,Angular 会检查数据绑定属性何时发生变化,并按需更新视图和组件实例。当 Angular 销毁组件实例并从 DOM 中移除它渲染的模板时…

Ubuntu 使用 nginx 搭建 https 文件服务器

Ubuntu 使用 nginx 搭建 https 文件服务器 搭建步骤安装 nginx生成证书修改 config重启 nginx 搭建步骤 安装 nginx生成证书修改 config重启 nginx 安装 nginx apt 安装: sudo apt-get install nginx生成证书 使用 openssl 生成证书: 到对应的路径…

智能化的宠物喂食器解决方案

随着经济条件的不断改善,越来越多的家庭开始追求生活的便捷享受,于是喂食器开始走进千家万户,喂食器主要由储存食物的蓄食箱和传送食物的滑道构成,在外部框架的支撑下,一台喂食器才能正常进行工作,而宠物喂…

Day 50 动态规划 part16

Day 50 动态规划 part16 解题理解58372 2道题目 583. 两个字符串的删除操作 72. 编辑距离 解题理解 583 dp[i][j]:以i-1为结尾的字符串word1,和以j-1位结尾的字符串word2,想要达到相等,所需要删除元素的最少次数。 当word1[i -…

数据仓库与数据挖掘

1.数据挖掘的概念 数据挖掘(Data mining),又译为资料探勘、数据采矿。它是数据库知识发现(Knowledge-Discovery in Databases,KDD)中的一个步骤。 数据挖掘一般是指从大量的数据中通过算法搜索隐藏于其中…

案例精选|聚铭综合日志分析系统夯实徐州公交集团网络环境基础

徐州市公共交通集团有限公司成立于1960年,现隶属徐州市交通控股集团有限公司,下辖7家运营公司,1家站务公司,8家直属单位及13个职能部室。运营车辆2364辆,线路177条,线路长度3560公里,日发送班次…

元素的水平居中和垂直几种方案

总结一下各种元素的水平居中和垂直居中方案。 水平居中: 1.行内元素水平居中 text-align: center 定义行内内容(例如文字)如何相对它的块父元素对齐;不仅可以让文字水平居中,还可以让行内元素水平居中 注意:给行内…

云尘-Node1 js代码

继续做题 拿到就是基本扫一下 nmap -sP 172.25.0.0/24 nmap -sV -sS -p- -v 172.25.0.13 然后顺便fscan扫一下咯 nmap: fscan: 还以为直接getshell了 老演员了 其实只是302跳转 所以我们无视 只有一个站 直接看就行了 扫出来了两个目录 但是没办法 都是要跳转 说明还是需要…

@Tag和@Operation标签失效问题。SpringDoc 2.2.0(OpenApi 3)和Spring Boot 3.1.1集成

问题 Tag和Operation标签失效 但是Schema标签有效 pom依赖 <!-- 接口文档--><!--引入openapi支持--><dependency><groupId>org.springdoc</groupId><artifactId>springdoc-openapi-starter-webmvc-ui</artifactId><vers…