文章目录
- 1、Omni-dimensional Dynamic Convolution
- 2、代码实现
paper:OMNI-DIMENSIONAL DYNAMIC CONVOLUTION
Code:https://github.com/OSVAI/ODConv
1、Omni-dimensional Dynamic Convolution
论文首先分析了现有动态卷积的局限性,论文指出现有的动态卷积方法(如 CondConv 和 DyConv)都是通过学习一个线性组合的多个卷积核及其输入依赖的注意力来提升轻量级 CNN 的准确率,但存在以下局限性:
- 注意力机制设计粗糙: 它们仅关注卷积核数量的动态性,而忽略了卷积核的空间大小、输入通道数和输出通道数这三个维度,导致无法充分利用卷积核空间。
- 模型参数量增加: 动态卷积会显著增加模型参数量,从而影响模型大小和效率。
为了缓解这些存在的问题,论文提出一种 全维动态卷积(Omni-dimensional Dynamic Convolution),其核心原理是 多维注意力机制。ODConv 通过引入多维度注意力机制,来并行策略学习卷积核空间四个维度(空间大小、输入通道数、输出通道数和卷积核数量)上的四种注意力,四种注意力分别是:
- 位置注意力: 为每个卷积核在 k×k 空间位置上的每个卷积参数(每个滤波器)分配不同的注意力权重。
- 通道注意力: 为每个卷积核的 cin 个输入通道分配不同的注意力权重。
- 滤波器注意力: 为每个卷积核的 cout 个滤波器分配不同的注意力权重。
- 核注意力: 为整个卷积核分配一个注意力权重。
对于一个输入特征 X ,ODConv的实现过程主要包括以下方面:
-
SE 类型的注意力模块:
特征压缩: 首先将输入特征通过通道方向的全局平均池化操作压缩成一个特征向量,其长度等于输入通道数。
特征降维: 接下来,使用一个全连接层将特征向量映射到一个更低维的空间,降维比例 r 根据实验结果进行设置。
多头注意力分支: 共有四个头部,每个头部负责计算一个维度上的注意力:
(1)空间注意力分支: 输出大小为 k×k,负责计算位置注意力 αsi。
(2)通道注意力分支: 输出大小为 cin×1,负责计算通道注意力 αci。
(3)滤波器注意力分支: 输出大小为 cout×1,负责计算滤波器注意力 αfi。
(4)核注意力分支: 输出大小为 n×1,负责计算核注意力 αwi。
注意力归一化: 每个头部使用 Softmax 或 Sigmoid 函数将输出归一化,得到相应的注意力权重。 -
并行计算:四个注意力分支并行计算各自的注意力权重。
-
逐步应用:
位置注意力: 将 αsi 乘以卷积核的每个 k×k 空间位置上的卷积参数,实现位置维度的注意力机制。
通道注意力: 将 αci 乘以卷积核的每个滤波器的 cin 个输入通道,实现通道维度的注意力机制。
滤波器注意力: 将 αfi 乘以卷积核的 cout 个滤波器,实现滤波器维度的注意力机制。
核注意力: 将 αwi 乘以整个卷积核,实现核维度的注意力机制。 -
输出:将经过四种注意力调整的卷积核与输入特征 x 进行卷积操作,得到最终的输出特征 y。
Omni-dimensional Dynamic Convolution 结构图:
2、代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autogradclass Attention(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):super(Attention, self).__init__()attention_channel = max(int(in_planes * reduction), min_channel)self.kernel_size = kernel_sizeself.kernel_num = kernel_numself.temperature = 1.0self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)self.bn = nn.BatchNorm2d(attention_channel)self.relu = nn.ReLU(inplace=True)self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)self.func_channel = self.get_channel_attentionif in_planes == groups and in_planes == out_planes: # depth-wise convolutionself.func_filter = self.skipelse:self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)self.func_filter = self.get_filter_attentionif kernel_size == 1: # point-wise convolutionself.func_spatial = self.skipelse:self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)self.func_spatial = self.get_spatial_attentionif kernel_num == 1:self.func_kernel = self.skipelse:self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)self.func_kernel = self.get_kernel_attentionself._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)if isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def update_temperature(self, temperature):self.temperature = temperature@staticmethoddef skip(_):return 1.0def get_channel_attention(self, x):channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return channel_attentiondef get_filter_attention(self, x):filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return filter_attentiondef get_spatial_attention(self, x):spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)spatial_attention = torch.sigmoid(spatial_attention / self.temperature)return spatial_attentiondef get_kernel_attention(self, x):kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)return kernel_attentiondef forward(self, x):x = self.avgpool(x)x = self.fc(x)x = self.bn(x)x = self.relu(x)return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)class ODConv2d(nn.Module):""" kernel_size = 1 or 3 """def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,reduction=0.0625, kernel_num=4):super(ODConv2d, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.kernel_size = kernel_sizeself.stride = strideself.padding = paddingself.dilation = dilationself.groups = groupsself.kernel_num = kernel_numself.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,reduction=reduction, kernel_num=kernel_num)self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),requires_grad=True)self._initialize_weights()if self.kernel_size == 1 and self.kernel_num == 1:self._forward_impl = self._forward_impl_pw1xelse:self._forward_impl = self._forward_impl_commondef _initialize_weights(self):for i in range(self.kernel_num):nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')def update_temperature(self, temperature):self.attention.update_temperature(temperature)def _forward_impl_common(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)batch_size, in_planes, height, width = x.size()x = x * channel_attentionx = x.reshape(1, -1, height, width)aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)aggregate_weight = torch.sum(aggregate_weight, dim=1).view([-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups * batch_size)output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))output = output * filter_attentionreturn outputdef _forward_impl_pw1x(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)x = x * channel_attentionoutput = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups)output = output * filter_attentionreturn outputdef forward(self, x):return self._forward_impl(x)if __name__ == '__main__':x = torch.randn(4, 512, 7, 7).cuda()model = ODConv2d(512, 512).cuda()out = model(x)print(out.shape)
本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。