YOLOv8改进 | 2023 | FocusedLinearAttention实现有效涨点

论文地址:官方论文地址

代码地址:官方代码地址

一、本文介绍

本文给大家带来的改进机制是Focused Linear Attention(聚焦线性注意力)是一种用于视觉Transformer模型的注意力机制(但是其也可以用在我们的YOLO系列当中从而提高检测精度),旨在提高效率和表现力。其解决了两个在传统线性注意力方法中存在的问题:聚焦能力和特征多样性。这种方法通过一个高效的映射函数和秩恢复模块来提高计算效率和性能,使其在处理视觉任务时更加高效和有效。简言之,Focused Linear Attention是对传统线性注意力方法的一种重要改进,提高了模型的聚焦能力和特征表达的多样性。通过本文你能够了解到:Focused Linear Attention的基本原理和框架,能够在你自己的网络结构中进行添加(需要注意的是一个FLAGFLOPs从8.9涨到了9.1)。

 专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

实验效果对比:放在了第三章,有对比试验供大家参考

目录

一、本文介绍

二、Focused Linear Attention的机制原理

2.1 Softmax和线性注意力机制的对比

2.2 Focused Linear Attention的提出

2.3 效果对比

三、实验效果对比

四、FocusedLinearAttention代码

五、添加Focused Linear Attention到模型中

5.1 Focused Linear Attention的添加教程

5.2 Focused Linear Attention的yaml文件和训练截图

5.2.1 Focused Linear Attention的yaml文件

5.2.2 Focused Linear Attention的训练过程截图 

六、全文总结 


二、Focused Linear Attention的机制原理

2.1 Softmax和线性注意力机制的对比

上面的图片是关于比较Softmax注意力和线性注意力的差异。在这张图中,Q、K、V 分别代表查询、键和值矩阵,它们的维度为 R N×d。这里提到的几个关键点包括:

1. Softmax注意力:它需要计算查询和键之间的成对相似度,导致计算复杂度为 O(N^2 d)。这种方法在计算上是昂贵的,特别是当处理大规模数据时。

2. 线性注意力:通过适当的近似手段,线性注意力可以解耦Softmax操作,并通过先计算K^{T}V来改变计算顺序,从而将复杂度降低到 O(Nd^{^{2}})。由于在现代视觉Transformer设计中通道维度 d 通常小于标记数 N(例如,在DeiT中d=64, N=196,在Swin Transformer中d=32, N=49),线性注意力模块实际上降低了总体计算成本。

此处提出了线性注意力机制的优势(为了后面提出论文提到的注意力机制在线性注意力机制上的优化):线性注意力模块因此能够在节省计算成本的同时,享受更大的接收域和更高的吞吐量的好处。

总结:这张图片可能是在说明线性注意力如何在保持注意力机制核心功能的同时,提高计算效率,尤其是在处理大规模数据集时的优势。这种方法对于改善视觉Transformer的性能和效率具有重要意义(我下面会出将其用在RT-DETR的模型上看看效果)

2.2 Focused Linear Attention的提出

线性注意力的限制和改进: 尽管线性注意力降低了复杂度,但现有的线性注意力方法仍存在性能下降的问题,并可能因映射函数带来额外的计算开销。为了解决这些问题,作者提出了一个新颖的聚焦线性注意力(Focused Linear Attention)模块。该模块通过简单的映射函数调整查询和键的特征方向,使注意力权重更加明显。此外,还通过深度卷积(DWC)应用于原始注意力矩阵的秩恢复模块来增加特征多样性。

Focused Linear Attention(聚焦线性注意力)是一种用于视觉Transformer模型的注意力机制(但是其也可以用在我们的YOLO系列当中从而提高检测精度),旨在提高效率和表现力。它解决了传统线性注意力方法的两个主要问题:

1. 聚焦能力: 以往的线性注意力缺乏足够的聚焦能力,导致模型难以有效地关注重要特征。Focused Linear Attention通过改进的机制增强了这种聚焦能力。

2. 特征多样性: 传统方法在特征表达上缺乏多样性,影响了模型的表现力。Focused Linear Attention通过特殊的设计来增加特征的多样性和丰富性。

这种方法通过一个高效的映射函数和秩恢复模块来提高计算效率和性能,使其在处理视觉任务时更加高效和有效。

总结:Focused Linear Attention是对传统线性注意力方法的一种重要改进,提高了模型的聚焦能力和特征表达的多样性。

2.3 效果对比

上面的图片显示了多个视觉Transformer模型的性能和计算复杂度的比较。图中分为四个部分:

1. PVT: 对比了不同版本的PVT(Pyramid Vision Transformer),DeiT(Data-efficient Image Transformer),以及T2T(Tokens-to-Token ViT)的Top-1准确率和计算量(FLOPs)。

2. PVT v2: 类似地,展示了PVT v2、ConvNext、DAT(Deformable Attention Transformer)的性能对比。

3. Swin: 对比了Swin Transformer、CvT(Convolutional vision Transformer),以及CoTNet(Contextual Transformer Network)的模型。

4. CSwin: 展示了CSwin Transformer、MViTv2、CoAtNet的性能对比。

在每个图中,还包括了作者提出的FLatten版本的Transformer模型(标记为“Ours”),其在每个分类中都显示了相对较高的准确率或者在相似的FLOPs计算量下具有竞争力的准确率。

右侧的表格详细列出了不同模型的分辨率(Reso)、参数数量(#Params)、计算量(Flops)和Top-1准确率。表中突出了FLatten版本的Transformer模型在Top-1准确率上相对于原始模型的提升(括号中的百分点)。

个人总结:这张图片展示了通过改进的线性注意力模块,即FLatten模型,在保持或稍微增加计算量的前提下,提高了Transformer架构的图像识别准确率。

三、实验效果对比

实验效果图如下所示-> 

因为资源有限我发的文章都要做对比实验所以本次实验我只用了一百张图片检测的是安全帽训练了一百个epoch,该结果只能展示出该机制有效,但是并不能产生决定性结果,因为具体的效果还要看你的数据集和实验环境所影响。

 

四、FocusedLinearAttention代码

在场的FocusedLinearAttention代码是用于Transformer的想要将其用于YOLO上是需要进行很大改动的,所以我这里进行了挺多的改动的,创作不易而且免费给大家看,所以如果能够帮助到大家希望大家能给点个赞和关注支持一下。

import torch.nn as nn
import torch
from einops import rearrangeclass FocusedLinearAttention(nn.Module):def __init__(self, dim, num_patches=64, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1,focusing_factor=3.0, kernel_size=5):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.sr_ratio = sr_ratioif sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)self.norm = nn.LayerNorm(dim)self.focusing_factor = focusing_factorself.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,groups=head_dim, padding=kernel_size // 2)self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))# self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches // (sr_ratio * sr_ratio), dim)))def forward(self, x):B, C, H, W = x.shape  # 输入为四维:[批次大小, 通道数, 高度, 宽度]dtype, device = x.dtype, x.device# 调整输入以匹配原始模块的预期格式x = rearrange(x, 'b c h w -> b (h w) c')q = self.q(x)if self.sr_ratio > 1:x_ = x.permute(0, 2, 1).reshape(B, C, H, W)x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)x_ = self.norm(x_)kv = self.kv(x_).reshape(B, -1, 2, C).permute(2, 0, 1, 3)else:kv = self.kv(x).reshape(B, -1, 2, C).permute(2, 0, 1, 3)k, v = kv[0], kv[1]N = H * W  # 序列长度# 重新生成位置编码positional_encoding = nn.Parameter(torch.zeros(size=(1, N, self.dim), device=device))k = k + positional_encodingfocusing_factor = self.focusing_factorkernel_function = nn.ReLU()scale = nn.Softplus()(self.scale)q = kernel_function(q) + 1e-6k = kernel_function(k) + 1e-6q = q / scalek = k / scaleq_norm = q.norm(dim=-1, keepdim=True)k_norm = k.norm(dim=-1, keepdim=True)q = q ** focusing_factork = k ** focusing_factorq = (q / q.norm(dim=-1, keepdim=True)) * q_normk = (k / k.norm(dim=-1, keepdim=True)) * k_normbool = Falseif dtype == torch.float16:q = q.float()k = k.float()v = v.float()bool = Trueq, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)if i * j * (c + d) > c * d * (i + j):kv = torch.einsum("b j c, b j d -> b c d", k, v)x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)else:qk = torch.einsum("b i c, b j c -> b i j", q, k)x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)if self.sr_ratio > 1:v = nn.functional.interpolate(v.permute(0, 2, 1), size=x.shape[1], mode='linear').permute(0, 2, 1)if bool:v = v.to(torch.float16)x = x.to(torch.float16)num = int(v.shape[1] ** 0.5)feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")x = x + feature_mapx = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)x = self.proj(x)x = self.proj_drop(x)x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)return x

五、添加Focused Linear Attention到模型中

5.1 Focused Linear Attention的添加教程

添加教程这里不再重复介绍、因为专栏内容有许多,添加过程又需要截特别图片会导致文章大家读者也不通顺如果你已经会添加注意力机制了,可以跳过本章节,如果你还不会,大家可以看我下面的文章,里面详细的介绍了拿到一个任意机制(C2f、Conv、Bottleneck、Loss、DetectHead)如何添加到你的网络结构中去。

注意:本文的注意力机制是有参数的!!!

这个注意力机制也可以放在C2f和Bottleneck中进行使用可以即插即用,个人觉得放在Bottleneck中效果比较好。

添加教程->YOLOv8改进 | 如何在网络结构中添加注意力机制、C2f、卷积、Neck、检测头

需要注意的是本文的task.py配置的代码如下(你现在不知道其是干什么用的可以看添加教程)-> 

from .modules.FocusLinearAttention import FocusedLinearAttention as FLAttention
        elif m is FLAttention:args = [ch[f], *args]

5.2 Focused Linear Attention的yaml文件和训练截图

5.2.1 Focused Linear Attention的yaml文件

下面的是放在Neck部分的截图,参数我以及设定好了,无需进行传入会根据模型输入自动计算,帮助大家省了一些事。

下面的是放在C2f中的yaml配置。 

 

5.2.2 Focused Linear Attention的训练过程截图 

下面是我添加了Focused Linear Attention的训练截图。

下面的是将FLAttention机制我添加到了C2f和Bottleneck。

下面的是我将FLAttention放在Neck中的截图。 

六、全文总结 

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/169082.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ Boost 异步网络编程基础

Boost库为C提供了强大的支持,尤其在多线程和网络编程方面。其中,Boost.Asio库是一个基于前摄器设计模式的库,用于实现高并发和网络相关的开发。Boost.Asio核心类是io_service,它相当于前摄模式下的Proactor角色。所有的IO操作都需…

leetCode 100. 相同的树 和 leetCode 101. 对称二叉树 和 110. 平衡二叉树 和 199. 二叉树的右视图

1.leetCode 100. 相同的树 C代码: class Solution { public:bool isSameTree(TreeNode* p, TreeNode* q) {if(p nullptr || q nullptr) return pq;return p->val q->val && isSameTree(p->left,q->left) && isSameTree(p->righ…

详解Java中的异常体系机构(throw,throws,try-catch,finally,自定义异常)

目录 一.异常的概念 二.异常的体系结构 三.异常的处理 异常处理思路 LBYL:Look Before You Leap EAFP: Its Easier to Ask Forgiveness than Permission 异常抛出throw 异常的捕获 提醒声明throws try-catch捕获处理 finally的作用 四.自定义异常类 一.异…

openEuler20.03学习01-创建虚拟机

赶个时髦,开始学习openEuler 20.03 (LTS-SP3) 操作系统iso下载地址:https://repo.openeuler.openatom.cn/openEuler-20.03-LTS-SP3/ISO/x86_64/openEuler-20.03-LTS-SP3-x86_64-dvd.iso 公司有现成的vmware环境,创建虚拟机i测试&#xff0c…

Java视频直播技术架构详解

引言 随着互联网的不断发展,视频直播技术成为在线娱乐和沟通的重要组成部分。在众多的视频直播平台中,Java作为一种强大而灵活的编程语言,被广泛应用于构建稳定、高效的视频直播系统。本文将深入探讨Java视频直播技术的架构,包括…

EM@常见平面曲线的方程的不同表示方式

文章目录 abstract常见曲线的不同形式小结:一览表分析圆锥曲线的极坐标方程非标准位置的圆锥曲线参数方程应用比较 refs abstract 常见平面曲线的方程的不同表示方式 常见曲线的不同形式 下面以平面曲线为对象讨论参数方程通常是对普通方程的补充和增强,曲线的普通方程(直角…

【pandas】数据透视表【pivot_table】

pivot_table pandas的pivot_table函数是一个非常有用的工具,用于创建一个数据透视表,这是一种用于数据总结和分析的表格形式。 以下是pivot_table的基本语法: pandas.pivot_table(data, valuesNone, indexNone, columnsNone, aggfuncmean,…

[JVM] 字节二面~简述垃圾回收以及类加载过程,别说八股文,我想看到你自己的理解

GC 的三种收集方法:标记清除、标记整理、复制算法的原理与特点,分别用在什么地方,如果让你优化收集方法,有什么思路? ● 标记清除: 先标记,标记完毕之后再清除,效率不高&#xff0c…

基于opencv+ImageAI+tensorflow的智能动漫人物识别系统——深度学习算法应用(含python、JS、模型源码)+数据集(三)

目录 前言总体设计系统整体结构图系统流程图 运行环境爬虫模型训练实际应用 模块实现1. 数据准备1)爬虫下载原始图片2)手动筛选图片 2. 数据处理1)切割得到人物脸部2)重新命名处理后的图片3)添加到数据集 3. 模型训练及…

系列五、Spring整合MyBatis不忽略mapper接口同目录的xxxMapper.xml

一、概述 默认情况下maven要求我们将xml配置、properties配置等都放在resources目录下,如果我们强行将其放在java目录,即将xxxMapper.xml和xxxMapper接口放在同一个目录下,那么默认情况下maven打包时会将这个xxxMapper.xml文件忽略掉&#xf…

C++中const有什么作用

const用于定义常量:const定义的常量编译器可以对其进行数据静态类型安全检查。const修饰函数形式参数:当输入参数为用户自定义的类型和抽象数据类型时,应该将值传递改为const &传递,可以提高效率。 void fun(A a); void fun(…

十大排序之归并排序(详解)

文章目录 🐒个人主页🏅算法思维框架📖前言: 🎀归并排序 时间复杂度O(n*logn)🎇1. 算法步骤思想🎇2、动画演示🎇3.代码实现 🐒个人主页 🏅算法思维框架 &#…

GraphQL—构建多服务架构的数据层

简介 作为 Facebook 在 2015 年推出的查询语言,GraphQL 能够对 API 中的数据提供一套易于理解的完整描述,使得客户端能够更加准确的获得它需要的数据 现在的web系统大多是基于restful的,我们知道,REST强调以资源来划分系统&#x…

老HIS面临的问题总结

在从业的10余年时间,从事pb开发和教学多年,应朋友的要求,写一篇关于老his的问题,今天终于得空书写。老his自1995年立项至今已走过20余年,目前仍有上千家医院在使用,可以说它在医疗信息化水平的提升和行业人…

Python基础入门例程64-NP64 输出前三同学的成绩(元组)

最近的博文: Python基础入门例程63-NP63 修改报名名单(元组)-CSDN博客 Python基础入门例程62-NP62 运动会双人项目(元组)-CSDN博客 Python基础入门例程61-NP61 牛牛的矩阵相加(循环语句)-CSDN博客 目录 最近的博文: 描述

lvm 扩容根分区失败记录

lvm 扩容根分区失败记录 1、问题描述2、错误描述3、解决方法重启系统进入grub界面,选择kernel 2.x 启动系统。然后同样的resize2fs命令扩容成功。 1、问题描述 根分区不足。 系统有2个内核版本,一个是kernel 2.x,另一个是kernel 4.x。 这次l…

C语言剔除相关数(ZZULIOJ1204:剔除相关数)

题目描述 一个数与另一个数如果含有相同数字和个数的字符&#xff0c;则称两数相关。现有一堆乱七八糟的整数&#xff0c;里面可能充满了彼此相关的数&#xff0c;请你用一下手段&#xff0c;自动地将其剔除。 输入&#xff1a;多实例测试。每组数据包含一个n(n<1000)&#…

知行之桥EDI系统HTTP签名验证

本文简要概述如何在知行之桥EDI系统中使用 HTTP 签名身份验证&#xff0c;并将使用 CyberSource 作为该集成的示例。 API 概述 首字母缩略词 API 代表“应用程序编程接口”。这听起来可能很复杂&#xff0c;但真正归结为 API 是一种允许两个不同实体相互通信的软件。自开发以…

CSS 属性列表

CSS属性列表 序号 属性类别 属性 描述 1 动画属性 keyframes 定义一个动画,keyframes定义的动画名称用来被animation-name所使用。 2 animation 复合属性。检索或设置对象所应用的动画特效。 3 animation-name 检索或设置对象所应用的动画名称 ,必须与规则keyfra…

2023.11.25-电商项目建设业务学习1-指标,业务流程,核销

目录 1.指标分类(原子指标,派生指标,衍生指标) 2.一些业务名词 3.四大业务流程-销售需求 3.1-线上线下销售 3.2线上线下退款 4.四大业务流程-会员业务 5.四大业务流程-供应链业务 6.四大业务流程-商城业务 7.核销主题需求分析 1.指标分类(原子指标,派生指标,衍生指标) 原…