Transformer实战-系列教程18:DETR 源码解读5(BackboneBase类/Backbone类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

7、BackboneBase类

位置:models/backbone.py/BackboneBase类

7.1 构造函数

class BackboneBase(nn.Module):def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):super().__init__()for name, parameter in backbone.named_parameters():if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:parameter.requires_grad_(False)if return_interm_layers:return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}else:return_layers = {'layer4': "0"}self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)self.num_channels = num_channels
  1. 定义一个继承nn.Module的类
  2. 构造函数,传入4个参数:
    • backbone:一个nn.Module对象,代表用于特征提取的骨架网络
    • train_backbone:是否训练backbone
    • num_channels:backbone通道数
    • return_interm_layers:是否返回backbone的中间层输出
  3. 初始化
  4. 遍历backbone的所有参数,named_parameters()方法返回网络中所有参数的迭代器,包括参数的名称和值
  5. 如果train_backbone设置为False,且不训练layer2layer3layer4,也就是说如果train_backbone为False,backbone的所有层的所有参数都不需要训练,即所有层都被冻住
  6. 不需要训练的参数的requires_grad属性设置为False
  7. 根据return_interm_layers的值
  8. 选择性地设置return_layers字典
  9. 一个层对应一个值
  10. 这个字典定义了哪些层的输出将被返回
  11. 创建IntermediateLayerGetter实例,它封装了backbone,根据return_layers字典决定返回哪些层的输出,IntermediateLayerGetter来自torchvision
  12. num_channels

7.2 前向传播

    def forward(self, tensor_list: NestedTensor):xs = self.body(tensor_list.tensors)out: Dict[str, NestedTensor] = {}for name, x in xs.items():m = tensor_list.maskassert m is not Nonemask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]out[name] = NestedTensor(x, mask)return out
  1. 前向传播函数,接收NestedTensor对象作为输入
  2. xs ,获取指定层的输出
  3. out,初始化一个字典,存储每个返回层的输出及其对应的新掩码
  4. 遍历xsitems
  5. 获取mask
  6. 确认mask存在
  7. 计算新的掩码
  8. 将输出和新掩码封装为NestedTensor对象
  9. 返回out字典

8、Backbone类

8.1 Backbone类

位置:models/backbone.py/Backbone类

class Backbone(BackboneBase):def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):backbone = getattr(torchvision.models, name)(replace_stride_with_dilation=[False, False, dilation],pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)num_channels = 512 if name in ('resnet18', 'resnet34') else 2048super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
  1. 定义一个继承BackboneBase的类
  2. 初始化方法,接受四个参数:
    • name:字符串,指定要使用的ResNet模型的名称(如resnet50resnet101等)
    • train_backbone:布尔值,指示是否训练backbone
    • return_interm_layers:布尔值,指示是否返回backbone的中间层输出
    • dilation:布尔值,指示在网络的最后几层是否应用空洞卷积(dilation)以增加感受野
  3. 通过torchvision.models动态获取指定名称的ResNet模型
  4. replace_stride_with_dilation,最后一个stage应用空洞卷积
  5. pretrained,根据is_main_process()的返回值决定是否加载预训练权重,norm_layer设置为FrozenBatchNorm2d,在backbone中使用冻结的批归一化
  6. 根据ResNet模型的不同,设置不同的输出通道数
  7. 调用基类BackboneBase的初始化方法,传递创建的backbone实例和其他参数

这个Backbone类通过提供对ResNet模型的封装,允许用户灵活地选择不同的配置,例如是否训练Backbone、是否返回中间层输出以及是否在网络后段应用空洞卷积。同时,通过使用冻结的批量归一化层,可以在不调整BN层参数的情况下,利用预训练的模型进行特征提取

8.2 build_backbone()函数

位置:models/backbone.py/build_backbone()函数

本项目的backbone,主要是调用resnet,用来提取图像特征,进而构建图像序列做Transformer的输入,backbone的构建主要通过这个函数来实现:

def build_backbone(args):position_embedding = build_position_encoding(args)train_backbone = args.lr_backbone > 0return_interm_layers = args.masksbackbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)model = Joiner(backbone, position_embedding)model.num_channels = backbone.num_channelsreturn model

这段代码定义了一个名为build_backbone的函数,用于根据提供的参数构建一个含有位置编码的骨架网络模型。以下是对这段代码的逐行解释:

  1. 函数build_backbone,接收命令行参数
  2. position_embedding ,调用build_position_encoding,函数构建位置编码
  3. 通过lr_backbone(backbone的学习率)是否大于0来决定是否训练backbone
  4. args.masks指示是否需要骨架网络返回中间层的输出
  5. 通过Backbone类构建backbone
  6. 通过Joiner类传入backbone和位置编码,建立backbone模型

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/683174.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

我的NPI项目之嵌入式总线系列(一) -- SPI 接口

如我的NPI项目之Android 安全系列 -- 外挂SE集成(SPI)接口-CSDN博客 提到SPI的接口,基本的电气特性已经给出。这边文章就针对协议部分进行详细解析。从协议网找到了原文:SPI protocol 还有wilipedia SPI 主要涉及一下几个方面&a…

bat 定时收缩sqlserver2017

如果你希望使用批处理(.bat)文件来定时收缩SQL Server的数据库,你可以编写一个脚本来执行这个任务。但首先,需要注意的是,定期收缩数据库通常不是一个好的做法,因为它可能会对性能产生负面影响,…

全闭环直播推流桌面分享远控系统

直播推流涉及多协议,多端技术栈和知识点,,想要做好并不容易,经过几年时间的迭代,终于小有成就,聚集了媒体服务器,实时会议sfu,远控kvm等功能。可以做一个音视频应用的瑞士小军刀。主…

详解Vue文件结构+实现一个简单案例

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

测试开发-2-概念篇

文章目录 衡量软件测试结果的依据—需求1.需求的概念2.从软件测试人员角度看需求3.为什么需求对软件测试人员如此重要4.如何才可以深入理解被测试软件的需求5.测试用例的概念6.软件错误(BUG)的概念7.开发模型和测试模型8.软件的生命周期9.瀑布模型&#…

MATLAB知识点:randperm函数(★★★★★)将一个数字序列进行随机打乱

​讲解视频:可以在bilibili搜索《MATLAB教程新手入门篇——数学建模清风主讲》。​ MATLAB教程新手入门篇(数学建模清风主讲,适合零基础同学观看)_哔哩哔哩_bilibili 节选自第3章:课后习题讲解中拓展的函数 在讲解第…

Codeforces Round 923 - A.B.C.D

文章目录 A. Make it WhiteB. Following the StringC.Choose the Different Ones!D. Find the Different Ones! A. Make it White #include<bits/stdc.h>using namespace std;void solve() {int n;cin >> n;string s; cin >> s;int flag 0;int x 0, y -1…

django中admin页面汉化

在Django中&#xff0c;将admin界面汉化为中文需要进行一些配置和翻译文件的添加。下面是一个基本的步骤指南&#xff0c;帮助你实现Django admin的汉化&#xff1a; 一&#xff1a;安装并配置Django: 如果你还没有安装Django&#xff0c;首先通过pip安装它&#xff1a; pip…

【开源训练数据集1】神经语言程式(NLP)项目的15 个开源训练数据集

一个聊天机器人需要大量的训练数据,以便在无需人工干预的情况下快速解决用户的询问。然而,聊天机器人开发的主要瓶颈是获取现实的、面向任务的对话数据来训练这些基于机器学习的系统。 我们整理了训练聊天机器人所需的对话数据集,包括问答数据、客户支持数据、对话数据和多…

ESP32学习(1)——环境搭建

使用的ESP32板子如下图所示 它可以用Arduino 软件&#xff0c;基于C语言开发。但是&#xff0c;在这里&#xff0c;我是用Thonny软件&#xff0c;基于micro_python对其进行开发。 1.安装Thonny Thonny的软件安装包&#xff0c;可以去它官网上下载。Thonny, Python IDE for begi…

【MySQL】学习外键约束处理员工数据

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-g4glZPIY0IKhiTfe {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

【原理解密】多角度、多尺度、多目标的边缘模板匹配

学习《OpenCV应用开发&#xff1a;入门、进阶与工程化实践》一书 做真正的OpenCV开发者&#xff0c;从入门到入职&#xff0c;一步到位&#xff01; 边缘模板匹配的基本原理 OpenCV中自带的模板匹配算法&#xff0c;完全是像素基本的模板匹配&#xff0c;特别容易受到光照影…

【剪映】如何使用曲线变速?

如何使用曲线变速 进入视频编辑界面后&#xff0c;选中视频&#xff0c;点击下方工具栏的-【变速】-【曲线变速】&#xff0c;进入后可以看到七个预设变速&#xff0c;其中后六个为系统自带预设变速&#xff0c;每个预设变速效果不同&#xff0c;直接点击这六个预设&#xff0c…

Python一些可能用的到的函数系列124 GlobalFunc

说明 GlobalFunc是算网的下一代核心数据处理基础。 算网是一个分布式网络&#xff0c;为了能够实现真的分布式计算&#xff08;加快大规模任务执行效率&#xff09;&#xff0c;以及能够在很长的时间内维护不同版本的计算方法&#xff0c;需要这样一个对象/服务来支撑。Globa…

如何使用python在三天内制作出一个赛车游戏

制作一个赛车游戏是一个复杂的过程&#xff0c;涉及多个方面&#xff0c;如游戏设计、图形渲染、物理引擎、用户输入处理等。在三天内完成这个任务可能非常具有挑战性&#xff0c;特别是如果你是初学者。但如果你有基本的Python编程知识和一些游戏开发经验&#xff0c;以下是一…

尚硅谷最新Node.js 学习笔记(三)

目录 六、Node.js 模块化 6.1、介绍 什么是模块化与模块&#xff1f; 什么是模块化项目&#xff1f; 模块化好处 6.2、模块暴露数据 模块初体验 暴露数据 6.3、导入&#xff08;引入&#xff09;模块 6.4、导入模块的基本流程 6.5、CommonJS规范 七、包管理工具 7…

Win 10 如何升级 Win 11

方法一&#xff1a; 设置->Windows 更新->检查更新 然后会有许多要下载更新的&#xff0c;期间会要求多次重启&#xff0c;每次重启完之后再检查更新&#xff0c;直到显示是最新&#xff0c;然后一般会有一个Win11的入口&#xff0c;点这里就可以了。 我很久之前升的&…

Java-数组遍历

for循环遍历 具体描述 假设有一个数组nums,设置初始条件i0,即从数组的第一个开始,循环结束条件为i<nums.length,即数组中所有元素的数量&#xff0c;设置更新条件i,即依次遍历完数组中所有元素 实例&#xff1a; public class demo04 {public static void main(String[]…

2019年全年回顾

本文于2020年Q1完成&#xff0c;发布在个人博客网站上。 最近几年处于动荡之中&#xff0c;比较忙碌&#xff0c;好几年没有写年度总结了。 现在2020年Q1马上结束&#xff0c;先把19年的总结补了。 年度大事记 1月 启动项目迁移工作。 深圳团队的人员释放&#xff0c;在南京…

java数据结构前置知识以及认识泛型

目录 什么是集合框架 容器 时间复杂度 空间复杂度 包装类 装箱 拆箱 引出泛型 泛型类的使用 类型推导 泛型如何编译的 泛型的上界 泛型方法静态泛型方法以及泛型上界 什么是集合框架 Java 集合框架 Java Collection Framework &#xff0c;又被称为容器 containe…