(即插即用模块-Convolution部分) 一、(ICLR 2022) ODConv 全维动态卷积

在这里插入图片描述

文章目录

  • 1、Omni-dimensional Dynamic Convolution
  • 2、代码实现

paper:OMNI-DIMENSIONAL DYNAMIC CONVOLUTION

Code:https://github.com/OSVAI/ODConv


1、Omni-dimensional Dynamic Convolution

论文首先分析了现有动态卷积的局限性,论文指出现有的动态卷积方法(如 CondConv 和 DyConv)都是通过学习一个线性组合的多个卷积核及其输入依赖的注意力来提升轻量级 CNN 的准确率,但存在以下局限性:

  1. 注意力机制设计粗糙: 它们仅关注卷积核数量的动态性,而忽略了卷积核的空间大小、输入通道数和输出通道数这三个维度,导致无法充分利用卷积核空间。
  2. 模型参数量增加: 动态卷积会显著增加模型参数量,从而影响模型大小和效率。

为了缓解这些存在的问题,论文提出一种 全维动态卷积(Omni-dimensional Dynamic Convolution),其核心原理是 多维注意力机制。ODConv 通过引入多维度注意力机制,来并行策略学习卷积核空间四个维度(空间大小、输入通道数、输出通道数和卷积核数量)上的四种注意力,四种注意力分别是:

  1. 位置注意力: 为每个卷积核在 k×k 空间位置上的每个卷积参数(每个滤波器)分配不同的注意力权重。
  2. 通道注意力: 为每个卷积核的 cin 个输入通道分配不同的注意力权重。
  3. 滤波器注意力: 为每个卷积核的 cout 个滤波器分配不同的注意力权重。
  4. 核注意力: 为整个卷积核分配一个注意力权重。

对于一个输入特征 X ,ODConv的实现过程主要包括以下方面:

  1. SE 类型的注意力模块

    特征压缩: 首先将输入特征通过通道方向的全局平均池化操作压缩成一个特征向量,其长度等于输入通道数。
    特征降维: 接下来,使用一个全连接层将特征向量映射到一个更低维的空间,降维比例 r 根据实验结果进行设置。
    多头注意力分支: 共有四个头部,每个头部负责计算一个维度上的注意力:
    (1)空间注意力分支: 输出大小为 k×k,负责计算位置注意力 αsi。
    (2)通道注意力分支: 输出大小为 cin×1,负责计算通道注意力 αci。
    (3)滤波器注意力分支: 输出大小为 cout×1,负责计算滤波器注意力 αfi。
    (4)核注意力分支: 输出大小为 n×1,负责计算核注意力 αwi。
    注意力归一化: 每个头部使用 Softmax 或 Sigmoid 函数将输出归一化,得到相应的注意力权重。

  2. 并行计算:四个注意力分支并行计算各自的注意力权重。

  3. 逐步应用

    位置注意力: 将 αsi 乘以卷积核的每个 k×k 空间位置上的卷积参数,实现位置维度的注意力机制。
    通道注意力: 将 αci 乘以卷积核的每个滤波器的 cin 个输入通道,实现通道维度的注意力机制。
    滤波器注意力: 将 αfi 乘以卷积核的 cout 个滤波器,实现滤波器维度的注意力机制。
    核注意力: 将 αwi 乘以整个卷积核,实现核维度的注意力机制。

  4. 输出:将经过四种注意力调整的卷积核与输入特征 x 进行卷积操作,得到最终的输出特征 y。


Omni-dimensional Dynamic Convolution 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autogradclass Attention(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):super(Attention, self).__init__()attention_channel = max(int(in_planes * reduction), min_channel)self.kernel_size = kernel_sizeself.kernel_num = kernel_numself.temperature = 1.0self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)self.bn = nn.BatchNorm2d(attention_channel)self.relu = nn.ReLU(inplace=True)self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)self.func_channel = self.get_channel_attentionif in_planes == groups and in_planes == out_planes:  # depth-wise convolutionself.func_filter = self.skipelse:self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)self.func_filter = self.get_filter_attentionif kernel_size == 1:  # point-wise convolutionself.func_spatial = self.skipelse:self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)self.func_spatial = self.get_spatial_attentionif kernel_num == 1:self.func_kernel = self.skipelse:self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)self.func_kernel = self.get_kernel_attentionself._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)if isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def update_temperature(self, temperature):self.temperature = temperature@staticmethoddef skip(_):return 1.0def get_channel_attention(self, x):channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return channel_attentiondef get_filter_attention(self, x):filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return filter_attentiondef get_spatial_attention(self, x):spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)spatial_attention = torch.sigmoid(spatial_attention / self.temperature)return spatial_attentiondef get_kernel_attention(self, x):kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)return kernel_attentiondef forward(self, x):x = self.avgpool(x)x = self.fc(x)x = self.bn(x)x = self.relu(x)return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)class ODConv2d(nn.Module):""" kernel_size = 1 or 3 """def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,reduction=0.0625, kernel_num=4):super(ODConv2d, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.kernel_size = kernel_sizeself.stride = strideself.padding = paddingself.dilation = dilationself.groups = groupsself.kernel_num = kernel_numself.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,reduction=reduction, kernel_num=kernel_num)self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),requires_grad=True)self._initialize_weights()if self.kernel_size == 1 and self.kernel_num == 1:self._forward_impl = self._forward_impl_pw1xelse:self._forward_impl = self._forward_impl_commondef _initialize_weights(self):for i in range(self.kernel_num):nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')def update_temperature(self, temperature):self.attention.update_temperature(temperature)def _forward_impl_common(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)batch_size, in_planes, height, width = x.size()x = x * channel_attentionx = x.reshape(1, -1, height, width)aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)aggregate_weight = torch.sum(aggregate_weight, dim=1).view([-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups * batch_size)output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))output = output * filter_attentionreturn outputdef _forward_impl_pw1x(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)x = x * channel_attentionoutput = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups)output = output * filter_attentionreturn outputdef forward(self, x):return self._forward_impl(x)if __name__ == '__main__':x = torch.randn(4, 512, 7, 7).cuda()model = ODConv2d(512, 512).cuda()out = model(x)print(out.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习Python基础(2)

二 数据处理 一般来说PyTorch中深度学习训练的流程是这样的: 1. 创建Dateset 2. Dataset传递给DataLoader 3. DataLoader迭代产生训练数据提供给模型 对应的一般都会有这三部分代码 # 创建Dateset(可以自定义) dataset face_dataset # Dataset部分自定义过的…

[2024.11.25-12.1] 一周科技速报

2024 世界传感器大会在郑州开幕 时间:12月1日至2日。 会议内容:大会以 “感知世界 智创未来” 为主题,由 “一会两赛一峰会” 组成。开幕式上发布了 “郑州宣言”,倡导行业携手打造合作共赢的产业新生态,还首发了《2…

(超详细图文详情)Navicat 配置连接 Oracle

1、下载依赖文件 Oracle官网下载直链:https://www.oracle.com/database/technologies/instant-client/winx64-64-downloads.html 夸克网盘下载(oracle19c版本):https://pan.quark.cn/s/5061e690debc 官网下载选择对应 Oracle 版…

jdk各个版本介绍

Java Development Kit(JDK)是Java平台的核心组件,它包含了Java编程语言、Java虚拟机(JVM)、Java类库以及用于编译、调试和运行Java应用程序的工具。 JDK 1.0-1.4(经典时代) • JDK 1.0&#xff…

基于 Python 的自动化框架示例

以下是一个基于Python的自动化测试代码框架示例,包含了 app_lib(库模块,用于存放通用功能相关代码)、app_test(测试用例相关模块)、config(配置文件及配置读取相关部分)等模块&#…

二分法篇——于上下边界的扭转压缩间,窥见正解辉映之光(1)

前言 二分法,这一看似简单却又充满哲理的算法,犹如一道精巧的数学之门,带领我们在问题的迷雾中找到清晰的道路。它的名字虽简单,却深藏着智慧的光辉。在科学的浩瀚星空中,二分法如一颗璀璨的星辰,指引着我们…

基于 FFmpeg/Scrcpy 框架构建的一款高性能的安卓设备投屏管理工具-供大家学习研究参考

支持的投屏方式有:USB,WIFIADB,OTG,投屏之前需要开启开发者选项里面的USB调试。 主要功能有: 1.支持单个或多个设备投屏。 2.支持键鼠操控。 3.支持文字输入。 4.支持共享剪切板(可复制粘贴电脑端文字到手机端,也可导出手机剪切板到电脑端)。 5.支持视频图片上传,可单…

【Go底层】time包Ticker定时器原理

目录 1、背景2、go版本3、源码解释【1】Ticker结构【2】NewTicker函数解释 4、代码示例5、总结 1、背景 说到定时器我们一般想到的库是cron,但是对于一些简单的定时任务场景,标准库time包下提供的定时器就足够我们使用,本篇文章我们就来研究…

Docker 部署Nginx 数据卷挂载 配置文件挂载

启动容器 docker run -d --name nginx \-v /etc/local/nginx/dist:/usr/share/nginx/html \-p 80:80 \--restart always \nginx宿主机站点 /etc/local/nginx/dist 容器内html /usr/share/nginx/html 复制配置文件到主机 docker cp nginx:/etc/nginx/nginx.conf /etc/local/n…

【论文笔记】A Token-level Contrastive Framework for Sign Language Translation

🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 基本信息 标题: A Token-level Contrastiv…

ROS2教程 - 3 HelloWorld

更好的阅读体验:https://www.foooor.com 3 HelloWorld 下面从 HelloWorld 开始,讲解 ROS2 的开发。 ROS 开发主要使用 C 或 Python 实现,如果要实现的功能,对性能有要求,可以使用 C 实现,如果对性能没有…

洛谷 B3626 跳跃机器人 C语言 记忆化搜索

题目: https://www.luogu.com.cn/problem/B3626 题目描述 地上有一排格子,共 n 个位置。机器猫站在第一个格子上,需要取第 n 个格子里的东西。 机器猫当然不愿意自己跑过去,所以机器猫从口袋里掏出了一个机器人!这…

【AI】Sklearn

长期更新,建议关注、收藏、点赞。 友情链接: AI中的数学_线代微积分概率论最优化 Python numpy_pandas_matplotlib_spicy 建议路线:机器学习->深度学习->强化学习 目录 预处理模型选择分类实例: 二分类比赛 网格搜索实例&…

⭐️ GitHub Star 数量前十的工作流项目

文章开始前,我们先做个小调查:在日常工作中,你会使用自动化工作流工具吗?🙋 事实上,工作流工具已经变成了提升效率的关键。其实在此之前我们已经写过一篇博客,跟大家分享五个好用的工作流工具。…

Tree搜索二叉树、map和set_数据结构

数据结构专栏 如烟花般绚烂却又稍纵即逝的个人主页 本章讲述数据结构中搜索二叉树与HashMap的学习,感谢大家的支持!欢迎大家踊跃评论,感谢大佬们的支持! 目录 搜索二叉树的概念二叉树搜索模拟实现搜索二叉树查找搜索二叉树插入搜索二叉树删除…

Swift实现高效链表排序:一步步解读

文章目录 前言摘要问题描述题解解题思路Swift 实现代码代码分析示例测试与结果 时间复杂度空间复杂度总结关于我们 前言 本题由于没有合适答案为以往遗留问题,最近有时间将以往遗留问题一一完善。 148. 排序链表 不积跬步,无以至千里;不积小流…

【开篇】.NET开源 ORM 框架 SqlSugar 系列

.NET开源 ORM 框架 SqlSugar 系列 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列【Code First】.NET开源 ORM 框架 SqlSugar 系列【数据事务…

qt QAnimationDriver详解

1、概述 QAnimationDriver是Qt框架中提供的一个类,它主要用于自定义动画帧的时间控制和更新。通过继承和实现QAnimationDriver,开发者可以精确控制动画的时间步长和更新逻辑,从而实现丰富和灵活的动画效果。QAnimationDriver与QAbstractAnim…

何时在 SQL 中使用 CHAR、VARCHAR 和 VARCHAR(MAX)

在管理数据库表时,考虑 CHAR、VARCHAR 和 VARCHAR(MAX) 是必不可少的。此外,使用正确的工具(例如dbForge Studio for SQL Server) ,与数据库相关的任务都会变得更加容易。它是针对 SQL Server 专业人员的强大的一体化解…

20241127 给typecho文章编辑附件 添加视频 图片预览

Typecho在写文章时,如果一次性上传太多张图片可能分不清哪张,因为附件没有略缩图,无法实时阅览图片,给文章插入图片时很不方便。 编辑admin/file-upload.php 大约十八行的位置 一个while 循环里面,这是在进行html元素更新操作,在合…