CVPR | CNN融合注意力机制,芜湖起飞!

**标题:**On the Integration of Self-Attention and Convolution
**论文链接:**https://arxiv.org/pdf/2111.14556
**代码链接:**https://github.com/LeapLabTHU/ACmix

创新点

1. 揭示卷积和自注意力的内在联系

文章通过重新分解卷积和自注意力模块的操作,发现它们在第一阶段(特征投影)都依赖于 1×1 卷积操作,并且这一阶段占据了大部分的计算复杂度(与通道数的平方成正比)。这一发现为整合两种模块提供了理论基础。

2. 提出 ACmix 模型

基于上述发现,作者提出了 ACmix 模型,它通过共享 1×1 卷积操作来同时实现卷积和自注意力的功能。具体来说:
**第一阶段:**输入特征通过 1×1 卷积投影,生成中间特征。
**第二阶段:**这些中间特征分别用于卷积路径(通过移位和聚合操作)和自注意力路径(计算注意力权重并聚合值)。最终,两条路径的输出通过可学习的权重加权求和,得到最终输出。

3. 改进的移位和聚合操作

文章还提出了一种改进的移位操作,通过使用 固定卷积核的分组卷积 来替代传统的张量移位操作。这种方法不仅提高了计算效率,还允许卷积核的可学习性,进一步增强了模型的灵活性。

4. 适应性路径权重

ACmix 引入了两个可学习的标量参数(α 和 β),用于动态调整卷积路径和自注意力路径的权重。这种设计不仅提高了模型的灵活性,还允许模型在不同深度上自适应地选择更适合的特征提取方式。实验表明,这种设计在模型的不同阶段表现出不同的偏好,例如在早期阶段更倾向于卷积,在后期阶段更倾向于自注意力。

整体结构

第一阶段:特征投影

在第一阶段,输入特征通过三个1×1卷积进行投影,分别生成查询(query)、键(key)和值(value)特征映射。这些特征映射随后被重塑为N块,形成一个包含3×N特征映射的中间特征集。

第二阶段:特征聚合

在第二阶段,中间特征集被分为两个路径进行处理:

  • **自注意力路径:**将中间特征集分为N组,每组包含三个特征映射(分别对应查询、键和值)。这些特征映射按照传统的多头自注意力机制进行处理,计算注意力权重并聚合值。
  • **卷积路径:**通过轻量级的全连接层生成k²个特征映射(k为卷积核大小)。这些特征映射通过移位和聚合操作,以类似传统卷积的方式处理输入特征,从局部感受野收集信息。

输出整合

最后,自注意力路径和卷积路径的输出通过两个可学习的标量参数(α和β)加权求和,得到最终的输出。

改进的移位和聚合操作

为了提高计算效率,ACmix模型采用了改进的移位操作,通过固定卷积核的分组卷积来替代传统的张量移位操作。这种方法不仅提高了计算效率,还允许卷积核的可学习性,进一步增强了模型的灵活性。

模型的灵活性和泛化能力

ACmix模型不仅适用于标准的自注意力机制,还可以与各种变体(如Patchwise Attention、Window Attention和Global Attention)结合使用。这种设计使得ACmix能够适应不同的任务需求,具有广泛的适用性。

消融实验

1. 结合两个路径的输出

消融实验探索了卷积和自注意力输出的不同组合方式对模型性能的影响。实验结果表明:

  • **卷积和自注意力的组合优于单一路径:**使用卷积和自注意力模块的组合始终优于仅使用单一路径(如仅卷积或仅自注意力)的模型。
  • **可学习参数的灵活性:**通过引入可学习的参数(如α和β)来动态调整卷积和自注意力路径的权重,ACmix能够根据网络中不同位置的需求自适应地调整路径强度,从而获得更高的灵活性和性能。

2. 组卷积核的选择

实验还对组卷积核的设计进行了验证,结果表明:

  • **用组卷积替代张量位移:**通过使用组卷积替代传统的张量位移操作,显著提高了模型的推理速度。
  • **可学习卷积核和初始化:**使用可学习的卷积核并结合精心设计的初始化方法,进一步增强了模型的灵活性,并有助于提升最终性能。

3. 不同路径的偏好

ACmix模型引入了两个可学习标量α和β,用于动态调整卷积和自注意力路径的权重。通过平行实验,观察到以下趋势:

  • **早期阶段偏好卷积:**在Transformer模型的早期阶段,卷积作为特征提取器表现更好。
  • **中间阶段混合使用:**在网络的中间阶段,模型倾向于混合使用两种路径,并逐渐增加对卷积的偏好。
  • **后期阶段偏好自注意力:**在网络的最后阶段,自注意力表现优于卷积。

4. 对模型性能的影响

这些消融实验结果表明,ACmix模型通过合理结合卷积和自注意力的优势,并优化计算路径,不仅在多个视觉任务上取得了显著的性能提升,还保持了较高的计算效率

ACmix模块的作用

1. 融合卷积和自注意力的优势

ACmix模块通过结合卷积的局部特征提取能力和自注意力的全局感知能力,实现了一种高效的特征融合策略。这种设计使得模型能够同时利用卷积的局部感受野特性和自注意力的灵活性。

2. 优化计算路径

ACmix通过优化计算路径和减少重复计算,提高了整体模块的计算效率。具体来说,它通过1×1卷积对输入特征图进行投影,生成中间特征,然后根据不同的范式(卷积和自注意力)分别重用和聚合这些中间特征。这种设计不仅减少了计算开销,还提升了模型性能。

3. 改进的位移与求和操作

在卷积路径中,ACmix采用深度可分离卷积(depthwise convolution)来替代低效的张量位移操作,从而提高了实际推理效率。

4. 动态调整路径权重

ACmix引入了两个可学习的标量参数(α和β),用于动态调整卷积和自注意力路径的权重。这种设计使得模型能够根据网络中不同位置的需求自适应地调整路径强度,从而获得更高的灵活性。

5. 广泛的应用潜力

ACmix在多个视觉任务(如图像分类、语义分割和目标检测)上均显示出优于单一机制(仅卷积或仅自注意力)的性能,展示了其广泛的应用潜力。

6. 实验验证

实验结果表明,ACmix在保持较低计算开销的同时,能够显著提升模型的性能。例如,在ImageNet分类任务中,ACmix模型在相同的FLOPs或参数数量下表现出色,并且在与竞争对手的基准比较中取得了持续的改进。此外,ACmix在ADE20K语义分割任务和COCO目标检测任务中也显示出明显的改进

代码实现

import torch
import torch.nn as nndef position(H, W, is_cuda=True):if is_cuda:loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)else:loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)return locdef stride(x, stride):b, c, h, w = x.shapereturn x[:, :, ::stride, ::stride]def init_rate_half(tensor):if tensor is not None:tensor.data.fill_(0.5)def init_rate_0(tensor):if tensor is not None:tensor.data.fill_(0.)class ACmix(nn.Module):def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):super(ACmix, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.head = headself.kernel_att = kernel_attself.kernel_conv = kernel_convself.stride = strideself.dilation = dilationself.rate1 = torch.nn.Parameter(torch.Tensor(1))self.rate2 = torch.nn.Parameter(torch.Tensor(1))self.head_dim = self.out_planes // self.headself.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)self.softmax = torch.nn.Softmax(dim=1)self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,stride=stride)self.reset_parameters()def reset_parameters(self):init_rate_half(self.rate1)init_rate_half(self.rate2)kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)for i in range(self.kernel_conv * self.kernel_conv):kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)self.dep_conv.bias = init_rate_0(self.dep_conv.bias)def forward(self, x):q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)scaling = float(self.head_dim) ** -0.5b, c, h, w = q.shapeh_out, w_out = h // self.stride, w // self.stride# ### att# ## positional encodingpe = self.conv_p(position(h, w, x.is_cuda))q_att = q.view(b * self.head, self.head_dim, h, w) * scalingk_att = k.view(b * self.head, self.head_dim, h, w)v_att = v.view(b * self.head, self.head_dim, h, w)if self.stride > 1:q_att = stride(q_att, self.stride)q_pe = stride(pe, self.stride)else:q_pe = peunfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,self.kernel_att * self.kernel_att, h_out,w_out) # b*head, head_dim, k_att^2, h_out, w_outunfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,w_out) # 1, head_dim, k_att^2, h_out, w_outatt = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)att = self.softmax(att)out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,h_out, w_out)out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)## convf_all = self.fc(torch.cat([q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),v.view(b, self.head, self.head_dim, h * w)], 1))f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])out_conv = self.dep_conv(f_conv)return self.rate1 * out_att + self.rate2 * out_conv#输入 B C H W, 输出 B C H W
if __name__ == '__main__':block = ACmix(in_planes=64, out_planes=64)input = torch.rand(3, 64, 32, 32)output = block(input)print(input.size(), output.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

module ‘matplotlib.cm‘ has no attribute ‘get_cmap‘

目录 解决方法1: 解决方法2,新版api改了: module matplotlib.cm has no attribute get_cmap 报错代码: cmap matplotlib.cm.get_cmap(Oranges) 解决方法1: pip install matplotlib3.7.3 解决方法2,新版…

使用Nuxt.js实现服务端渲染(SSR):提升SEO与性能的完整指南

使用Nuxt.js实现服务端渲染(SSR):提升SEO与性能的完整指南 使用Nuxt.js实现服务端渲染(SSR):提升SEO与性能的完整指南1. 服务端渲染(SSR)核心概念1.1 CSR vs SSR vs SSG1.2 SSR工作原…

解释 Java 中的反射机制和动态代理的原理?

反射机制是Java语言的一个特性,它允许程序在运行时检查和操作类、方法、字段等。 通过反射,我们可以在运行时获取类的信息,创建对象,调用方法和访问字段,即使这些信息在编译时是未知的。 反射的基本用法 import jav…

http状态码:504 Gateway Timeout(网关超时)的原有以及排查问题的思路

504 Gateway Timeout(网关超时) 是一种常见的HTTP错误状态码,表示服务器作为网关或代理时,未能及时从上游服务器收到响应。以下是它的原因和排查问题的思路: 1. 504错误的含义 定义:服务器作为网关或代理时…

Linux 安装 RabbitMQ

Linux下安装RabbitMQ 1 、获取安装包 # 地址 https://github.com/rabbitmq/erlang-rpm/releases/download/v21.3.8.9/erlang-21.3.8.9-1.el7.x86_64.rpm erlang-21.3.8.9-1.el7.x86_64.rpmsocat-1.7.3.2-1.el6.lux.x86_64.rpm# 地址 https://github.com/rabbitmq/rabbitmq-se…

LOCAL_PREBUILT_JNI_LIBS使用说明

LOCAL_PREBUILT_JNI_LIBS使用说明 使用LOCAL_PREBUILT_JNI_LIBS,可用于控制APK集成时,其相关so的集成方式。 比如,用于将APK中的so,抽取出来。 LOCAL_PREBUILT_JNI_LIBS : \lib/arm64-v8a/libNativeCore.so \lib/arm64-v8a/liba…

Java中的object类

1.Object类是什么? 🟪Object 是 Java 类库中的一个特殊类,也是所有类的父类(超类),位于类继承层次结构的顶端。也就是说,Java 允许把任何类型的对象赋给 Object 类型的变量。 🟦Java里面除了Object类,所有的…

uniapp小程序自定义中间凸起样式底部tabbar

我自己写的自定义的tabbar效果图 废话少说咱们直接上代码,一步一步来 第一步: 找到根目录下的 pages.json 文件,在 tabBar 中把 custom 设置为 true,默认值是 false。list 中设置自定义的相关信息, pagePath&#x…

四、GPIO中断实现按键功能

4.1 GPIO简介 输入输出(I/O)是一个非常重要的概念。I/O泛指所有类型的输入输出端口,包括单向的端口如逻辑门电路的输入输出管脚和双向的GPIO端口。而GPIO(General-Purpose Input/Output)则是一个常见的术语&#xff0c…

vscode+CMake+Debug实现 及权限不足等诸多问题汇总

环境说明 有空再补充 直接贴两个json tasks.json {"version": "2.0.0","tasks": [{"label": "cmake","type": "shell","command": "cmake","args": ["../"…

【Elasticsearch】post_filter

post_filter是 Elasticsearch 中的一种后置过滤机制,用于在查询执行完成后对结果进行过滤。以下是关于post_filter的详细介绍: 工作原理 • 查询后过滤:post_filter在查询执行完毕后对返回的文档集进行过滤。这意味着所有与查询匹配的文档都…

《数据可视化新高度:Graphy的AI协作变革》

在数据洪流奔涌的时代,企业面临的挑战不再仅仅是数据的收集,更在于如何高效地将数据转化为洞察,助力决策。Graphy作为一款前沿的数据可视化工具,凭借AI赋能的团队协作功能,为企业打开了数据协作新局面,重新…

Vue 2 与 Vue 3 的主要区别

Vue.js 是一个流行的前端框架,用于构建用户界面和单页应用。自从 Vue 2 发布以来,社区对其进行了广泛的应用和扩展,而 Vue 3 的发布则带来了许多重要的改进和新特性。 性能提升 Vue 3 在响应式系统上进行了重大的改进,采用了基于…

从零开始:用Qt开发一个功能强大的文本编辑器——WPS项目全解析

文章目录 引言项目功能介绍1. **文件操作**2. **文本编辑功能**3. **撤销与重做**4. **剪切、复制与粘贴**5. **文本查找与替换**6. **打印功能**7. **打印预览**8. **设置字体颜色**9. **设置字号**10. **设置字体**11. **左对齐**12. **右对齐**13. **居中对齐**14. **两侧对…

【IoCDI】_Spring的基本扫描机制

目录 1. 创建测试项目 2. 改变启动类所属包 3. 使用ComponentScan 4. Spring基本扫描机制 程序通过注解告诉Spring希望哪些bean被管理,但在仅使用Bean时已经发现,Spring需要根据五大类注解才能进一步扫描方法注解。 由此可见,Spring对注…

vue 引入百度地图和高德天气 都得获取权限

vue接入百度地图---获取ak https://blog.csdn.net/qq_57144407/article/details/143430661 vue接入高德天气, 需要授权----获取key https://www.jianshu.com/p/09ddd698eebe

通向AGI之路:人工通用智能的技术演进与人类未来

文章目录 引言:当机器开始思考一、AGI的本质定义与技术演进1.1 从专用到通用:智能形态的范式转移1.2 AGI发展路线图二、突破AGI的五大技术路径2.1 神经符号整合(Neuro-Symbolic AI)2.2 世界模型架构(World Models)2.3 具身认知理论(Embodied Cognition)三、AGI安全:价…

python中的命名规范

在python中,命名规范是编写清晰,可读性强代码的重要部分,遵循这些规范可以使代码更易于理解和维护。 Type命名约定命名例子函数(Function)小写单词,下划线分割单词function,delta_function方法&#xff08…

【工具变量】中国省级八批自由贸易试验区设立及自贸区设立数据(2024-2009年)

一、测算方式:参考C刊《中国软科学》任晓怡老师(2022)的做法,使用自由贸易试验区(Treat Post) 表征,Treat为个体不随时间变化的虚拟变量,如果该城市设立自由贸易试验区则赋值为1,反之赋值为0&am…

Java进阶总结——集合

Java进阶总结——集合 说明:对于以上的框架图有如下几点说明 1.所有集合类都位于java.util包下。Java的集合类主要由两个接口派生而出:Collection和Map,Collection和Map是Java集合框架的根接口,这两个接口又包含了一些子接口或实…