PETR和位置编码

PETR和位置编码

petr检测网络中有2种类型的位置编码。
正弦编码和petr论文提出的3D Position Embedding。transformer模块输入除了qkv,还有query_pos和key_pos。这里重点记录下query_pos和key_pos的生成

  • query pos的生成
    先定义reference_points, shape为(n_query, 3),编码部分有两部分构成,经过pos2posemb3d编码(sin编码)后,再用FFN(query_embed)编码一次后用作transformer的query_pos. 至于为什么多了个一次FFN编码,GPT这么解释的:

这种两步编码的设计实际上是将固定的位置编码(pos2posemb3d)和可学习的位置编码(query_embedding)相结合,既保留了位置的几何信息,又允许模型学习任务相关的位置表示。这种设计在3D视觉任务中特别有效,因为它既考虑了空间的周期性特征,又保持了位置编码的可学习性。

  1. pos2posemb3d(sin编码)
    标准的正弦编码
def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):scale = 2 * math.pipos = pos * scale # map pos from [-1, 1] to [-2pi, 2pi]dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)pos_x = pos[..., 0, None] / dim_tpos_y = pos[..., 1, None] / dim_tpos_z = pos[..., 2, None] / dim_tpos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)return posemb
  1. query_embed(FFN)
    self.query_embedding = nn.Sequential(nn.Linear(self.embed_dims*3//2, self.embed_dims),nn.ReLU(),nn.Linear(self.embed_dims, self.embed_dims),)
    
  • key_pos的生成
    对于二维目标检测来说,对像素位置做编码就行了(sin_embed), 如下图的backbone这个分支,对于三维目标检测,petr对每个像素还做了三维位置编码(coords_position_embeding), 下图最下面一个分支。
    最终给transformer的key_pos = 3d位置编码+ 2d像素位置编码
    在这里插入图片描述

    1. 像素的3d位置编码
      根据图像尺寸定义一个视锥空间(coords),每个点用(u, v, d)表示,结合相机内参,可以将其转为世界坐标系下的点(coords3d),在用position_encoder(卷积)处理得到位置编码。
      在这里插入图片描述
    def position_embeding(self, img_feats, img_metas, masks=None):eps = 1e-5pad_h, pad_w, _ = img_metas[0]['pad_shape'][0]B, N, C, H, W = img_feats[self.position_level].shapecoords_h = torch.arange(H, device=img_feats[0].device).float() * pad_h / Hcoords_w = torch.arange(W, device=img_feats[0].device).float() * pad_w / Wif self.LID:index  = torch.arange(start=0, end=self.depth_num, step=1, device=img_feats[0].device).float()index_1 = index + 1bin_size = (self.position_range[3] - self.depth_start) / (self.depth_num * (1 + self.depth_num))coords_d = self.depth_start + bin_size * index * index_1else:index  = torch.arange(start=0, end=self.depth_num, step=1, device=img_feats[0].device).float()bin_size = (self.position_range[3] - self.depth_start) / self.depth_numcoords_d = self.depth_start + bin_size * indexD = coords_d.shape[0]coords = torch.stack(torch.meshgrid([coords_w, coords_h, coords_d])).permute(1, 2, 3, 0) # W, H, D, 3coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1)coords[..., :2] = coords[..., :2] * torch.maximum(coords[..., 2:3], torch.ones_like(coords[..., 2:3])*eps)img2lidars = []for img_meta in img_metas:img2lidar = []for i in range(len(img_meta['lidar2img'])):img2lidar.append(np.linalg.inv(img_meta['lidar2img'][i]))img2lidars.append(np.asarray(img2lidar))img2lidars = np.asarray(img2lidars)img2lidars = coords.new_tensor(img2lidars) # (B, N, 4, 4)coords = coords.view(1, 1, W, H, D, 4, 1).repeat(B, N, 1, 1, 1, 1, 1)img2lidars = img2lidars.view(B, N, 1, 1, 1, 4, 4).repeat(1, 1, W, H, D, 1, 1)coords3d = torch.matmul(img2lidars, coords).squeeze(-1)[..., :3]coords3d[..., 0:1] = (coords3d[..., 0:1] - self.position_range[0]) / (self.position_range[3] - self.position_range[0])coords3d[..., 1:2] = (coords3d[..., 1:2] - self.position_range[1]) / (self.position_range[4] - self.position_range[1])coords3d[..., 2:3] = (coords3d[..., 2:3] - self.position_range[2]) / (self.position_range[5] - self.position_range[2])coords_mask = (coords3d > 1.0) | (coords3d < 0.0)coords_mask = coords_mask.flatten(-2).sum(-1) > (D * 0.5)coords_mask = masks | coords_mask.permute(0, 1, 3, 2)coords3d = coords3d.permute(0, 1, 4, 5, 3, 2).contiguous().view(B*N, -1, H, W)coords3d = inverse_sigmoid(coords3d)coords_position_embeding = self.position_encoder(coords3d) # position_encoder:conv+relu+convreturn coords_position_embeding.view(B, N, self.embed_dims, H, W), coords_mask
    
    1. 像素的2d正弦编码
      通过图像的宽高,可以对每个像素坐标生成位置编码
    #SinePositionalEncoding3D
    def forward(self, mask):        """Forward function for `SinePositionalEncoding`.Args:mask (Tensor): ByteTensor mask. Non-zero values representingignored positions, while zero values means valid positionsfor this image. Shape [bs, h, w].Returns:pos (Tensor): Returned position embedding with shape[bs, num_feats*2, h, w]."""# For convenience of exporting to ONNX, it's required to convert# `masks` from bool to int.mask = mask.to(torch.int)not_mask = 1 - mask  # logical_notn_embed = not_mask.cumsum(1, dtype=torch.float32)y_embed = not_mask.cumsum(2, dtype=torch.float32)x_embed = not_mask.cumsum(3, dtype=torch.float32)if self.normalize:n_embed = (n_embed + self.offset) / \(n_embed[:, -1:, :, :] + self.eps) * self.scaley_embed = (y_embed + self.offset) / \(y_embed[:, :, -1:, :] + self.eps) * self.scalex_embed = (x_embed + self.offset) / \(x_embed[:, :, :, -1:] + self.eps) * self.scaledim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device)dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)pos_n = n_embed[:, :, :, :, None] / dim_tpos_x = x_embed[:, :, :, :, None] / dim_tpos_y = y_embed[:, :, :, :, None] / dim_t# use `view` instead of `flatten` for dynamically exporting to ONNXB, N, H, W = mask.size()pos_n = torch.stack((pos_n[:, :, :, :, 0::2].sin(), pos_n[:, :, :, :, 1::2].cos()),dim=4).view(B, N, H, W, -1)pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()),dim=4).view(B, N, H, W, -1)pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()),dim=4).view(B, N, H, W, -1)pos = torch.cat((pos_n, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3)return posdef __repr__(self):"""str: a string that describes the module"""repr_str = self.__class__.__name__repr_str += f'(num_feats={self.num_feats}, 'repr_str += f'temperature={self.temperature}, 'repr_str += f'normalize={self.normalize}, 'repr_str += f'scale={self.scale}, 'repr_str += f'eps={self.eps})'return repr_str

顺便记录下未使用的可学习编码

    @POSITIONAL_ENCODING.register_module()class LearnedPositionalEncoding3D(BaseModule):"""Position embedding with learnable embedding weights.Args:num_feats (int): The feature dimension for each positionalong x-axis or y-axis. The final returned dimension foreach position is 2 times of this value.row_num_embed (int, optional): The dictionary size of row embeddings.Default 50.col_num_embed (int, optional): The dictionary size of col embeddings.Default 50.init_cfg (dict or list[dict], optional): Initialization config dict."""def __init__(self,num_feats,row_num_embed=50,col_num_embed=50,init_cfg=dict(type='Uniform', layer='Embedding')):super(LearnedPositionalEncoding3D, self).__init__(init_cfg)self.row_embed = nn.Embedding(row_num_embed, num_feats)self.col_embed = nn.Embedding(col_num_embed, num_feats)self.num_feats = num_featsself.row_num_embed = row_num_embedself.col_num_embed = col_num_embeddef forward(self, mask):"""Forward function for `LearnedPositionalEncoding`.Args:mask (Tensor): ByteTensor mask. Non-zero values representingignored positions, while zero values means valid positionsfor this image. Shape [bs, h, w].Returns:pos (Tensor): Returned position embedding with shape[bs, num_feats*2, h, w]."""h, w = mask.shape[-2:]x = torch.arange(w, device=mask.device)y = torch.arange(h, device=mask.device)x_embed = self.col_embed(x)y_embed = self.row_embed(y)pos = torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)),dim=-1).permute(2, 0,1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)return pos

参考链接:
https://blog.csdn.net/qq_16137569/article/details/123576866

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/79243.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu搭建 Nginx以及Keepalived 实现 主备

目录 前言1. 基本知识2. Keepalived3. 脚本配置4. Nginx前言 🤟 找工作,来万码优才:👉 #小程序://万码优才/r6rqmzDaXpYkJZF 爬虫神器,无代码爬取,就来:bright.cn Java基本知识: java框架 零基础从入门到精通的学习路线 附开源项目面经等(超全)【Java项目】实战CRU…

文章记单词 | 第56篇(六级)

一&#xff0c;单词释义 interview /ˈɪntəvjuː/&#xff1a; 名词&#xff1a;面试&#xff1b;采访&#xff1b;面谈动词&#xff1a;对… 进行面试&#xff1b;采访&#xff1b;接见 radioactive /ˌreɪdiəʊˈktɪv/&#xff1a;形容词&#xff1a;放射性的&#xff…

MATLAB函数调用全解析:从入门到精通

在MATLAB编程中&#xff0c;函数是代码复用的核心单元。本文将全面解析MATLAB中各类函数的调用方法&#xff0c;包括内置函数、自定义函数、匿名函数等&#xff0c;帮助提升代码效率&#xff01; 一、MATLAB函数概述 MATLAB函数分为以下类型&#xff1a; 内置函数&#xff1a…

哈希表笔记(二)redis

Redis哈希表实现分析 这份代码是Redis核心数据结构之一的字典(dict)实现&#xff0c;本质上是一个哈希表的实现。Redis的字典结构被广泛用于各种内部数据结构&#xff0c;包括Redis数据库本身和哈希键类型。 核心特点 双表设计&#xff1a;每个字典包含两个哈希表&#xff0…

PDF嵌入隐藏的文字

所需依赖 <dependency><groupId>com.itextpdf</groupId><artifactId>itext-core</artifactId><version>9.0.0</version><type>pom</type> </dependency>源码 /*** PDF工具*/ public class PdfUtils {/*** 在 PD…

RAG工程-基于LangChain 实现 Advanced RAG(预检索-查询优化)(下)

Multi-Query 多路召回 多路召回流程图 多路召回策略利用大语言模型&#xff08;LLM&#xff09;对原始查询进行拓展&#xff0c;生成多个与原始查询相关的问题&#xff0c;再将原始查询和生成的所有相关问题一同发送给检索系统进行检索。它适用于用户查询比较宽泛、模糊或者需要…

【业务领域】PCIE协议理解

PCIE协议理解 提示&#xff1a;这里可以添加系列文章的所有文章的目录&#xff0c;目录需要自己手动添加 PCIE学习理解。 文章目录 PCIE协议理解[TOC](文章目录) 前言零、PCIE掌握点&#xff1f;一、PCIE是什么&#xff1f;二、PCIE协议总结物理层切速 链路层事务层6.2 TLP的路…

Jupyter notebook快捷键

文章目录 Jupyter notebook键盘模式快捷键&#xff08;常用的已加粗&#xff09; Jupyter notebook键盘模式 命令模式&#xff1a;键盘输入运行程序命令&#xff1b;这时单元格框线为蓝色 编辑模式&#xff1a;允许你往单元格中键入代码或文本&#xff1b;这时单元格框线是绿色…

Unity图片导入设置

&#x1f3c6; 个人愚见&#xff0c;没事写写笔记 &#x1f3c6;《博客内容》&#xff1a;Unity3D开发内容 &#x1f3c6;&#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 &#x1f50e;Unity支持的图片格式 ☀️BMP:是Windows操作系统的标准图像文件格式&#xff0c;特点是…

Spark-小练试刀

任务1&#xff1a;HDFS上有三份文件&#xff0c;分别为student.txt&#xff08;学生信息表&#xff09;result_bigdata.txt&#xff08;大数据基础成绩表&#xff09;&#xff0c; result_math.txt&#xff08;数学成绩表&#xff09;。 加载student.txt为名称为student的RDD…

内存安全的攻防战:工具链与语言特性的协同突围

一、内存安全&#xff1a;C 开发者永恒的达摩克利斯之剑 在操作系统内核、游戏引擎、金融交易系统等对稳定性要求苛刻的领域&#xff0c;内存安全问题始终是 C 开发者的核心挑战。缓冲区溢出、悬空指针、双重释放等经典漏洞&#xff0c;每年在全球范围内造成数千亿美元的损失。…

OceanBase数据库-学习笔记1-概论

多租户概念 集群和分布式 随着互联网、物联网和大数据技术的发展&#xff0c;数据量呈指数级增长&#xff0c;单机数据库难以存储和处理如此庞大的数据。现代应用通常需要支持大量用户同时访问&#xff0c;单机数据库在高并发场景下容易成为性能瓶颈。单点故障是单机数据库的…

计算机网络——键入网址到网页显示,期间发生了什么?

浏览器做的第一步工作是解析 URL&#xff0c;分清协议是http还是https&#xff0c;主机名&#xff0c;路径名&#xff0c;然后生成http消息&#xff0c;之后委托操作系统将消息发送给 Web 服务器。在发送之前&#xff0c;还需要先去查询dns&#xff0c;首先是查询缓存浏览器缓存…

Qwen3本地化部署,准备工作:SGLang

文章目录 SGLang安装deepseek运行Qwen3-30B-A3B官网:https://github.com/sgl-project/sglang SGLang SGLang 是一个面向大语言模型和视觉语言模型的高效服务框架。它通过协同设计后端运行时和前端编程语言,使模型交互更快速且具备更高可控性。核心特性包括: 1. 快速后端运…

全面接入!Qwen3现已上线千帆

百度智能云千帆正式上线通义千问团队开源的最新一代Qwen3系列模型&#xff0c;包括旗舰级MoE模型Qwen3-235B-A22B、轻量级MoE模型Qwen3-30B-A3B。千帆大模型平台开源模型进一步扩充&#xff0c;以多维开放的模型服务、全栈模型开发、应用开发工具链、多模态数据治理及安全的能力…

蓝桥杯Python(B)省赛回忆

Q&#xff1a;为什么我要写这篇博客&#xff1f; A&#xff1a;在蓝桥杯软件类竞赛&#xff08;Python B组&#xff09;的备赛过程中我在网上搜索关于蓝桥杯的资料&#xff0c;感谢你们提供的参赛经历&#xff0c;对我的备赛起到了整体调整的帮助&#xff0c;让我知道如何以更…

数据转储(go)

​ 随着时间推移&#xff0c;数据库中的数据量不断累积&#xff0c;可能导致查询性能下降、存储压力增加等问题。数据转储作为一种有效的数据管理策略&#xff0c;能够将历史数据从生产数据库中转移到其他存储介质&#xff0c;从而减轻数据库负担&#xff0c;提高系统性能&…

Git Stash 详解

Git Stash 详解 在使用 Git 进行版本控制时&#xff0c;经常会遇到需要临时保存当前工作状态的情况。git stash 命令就是为此设计的&#xff0c;它允许你将未提交的更改暂存起来&#xff0c;在处理其他任务或分支后&#xff0c;再恢复这些更改。 目录 基本概念常用命令示例和…

Windows下Dify安装及使用

Dify安装及使用 Dify 是开源的 LLM 应用开发平台。提供从 Agent 构建到 AI workflow 编排、RAG 检索、模型管理等能力&#xff0c;轻松构建和运营生成式 AI 原生应用。比 LangChain 更易用。 前置条件 windows下安装了docker环境-Windows11安装Docker-CSDN博客 下载 Git下载…

Clang-Tidy协助C++编译期检查

文章目录 在Visual Studio中启用clang-tidyClang-tidy 常用的检查项readability-inconsistent-declaration-parameter-namemisc-static-assert 例子 C/C语言是一门编译型语言&#xff0c;比起python,javascript 这些&#xff0c;有很多BUG可以在编译期被排除掉&#xff0c;当然…