YOLOv8改进 - 注意力机制 | HaloNet 局部自注意力网络通过分块与扩展感受野实现高效空间交互建模

前言

本文介绍了局部自注意力机制及其在YOLOv8中的结合应用。自注意力机制有潜力提升计算机视觉系统性能,为此我们提出扩展方法并结合高效实现方式,开发了HaloNets模型家族。局部自注意力通过关注输入数据局部区域捕捉特征关系,具有计算效率高、增强局部特征捕捉等优势。我们将HaloAttention代码引入指定目录,在ultralytics/nn/tasks.py中注册,配置yolov8 - HaloAttention.yaml文件,最后通过实验脚本和结果验证了改进模型的有效性。

文章目录: YOLOv8改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv8改进专栏

文章目录

  • 前言
  • 介绍
    • 摘要
  • 文章链接
  • 基本原理
      • 1. **基本原理**
      • 2. **计算过程**
      • 3. **优势**
      • 4. **应用**
      • 5. **局限性**
  • 核心代码
  • 引入代码
  • 注册
    • 步骤1:
    • 步骤2
  • 配置yolov8_HaloAttention.yaml
  • 实验
    • 脚本
    • 结果

介绍

摘要

自注意力机制因其与参数无关的感受野扩展能力以及基于内容的交互方式,被认为有潜力提升计算机视觉系统的性能,这与卷积的参数依赖型感受野扩展和与内容无关的交互方式形成了鲜明对比。最近的研究表明,与基线卷积模型(如 ResNet-50)相比,自注意力模型在精度-参数权衡方面取得了令人鼓舞的改进。

在这项工作中,我们旨在开发不仅能超越经典基线模型,还能超越高性能卷积模型的自注意力模型。我们提出了两种自注意力的扩展方法,并结合一种更高效的自注意力实现方式,提升了这些模型的速度、内存使用效率和准确性。基于这些改进,我们开发了一个新的自注意力模型家族,称为HaloNets,其在参数受限的 ImageNet 分类基准测试中达到了最先进的精度。

在初步的迁移学习实验中,我们发现 HaloNet 模型的表现优于体积更大的模型,并且在推理性能方面表现更佳。在更具挑战性的任务中,例如目标检测和实例分割,我们提出的简单局部自注意力与卷积混合模型相较于非常强的基线模型也表现出改进。这些结果进一步证明了自注意力模型在传统上由卷积模型主导的领域中的有效性。

文章链接

论文地址:论文地址

代码地址:代码地址

基本原理

局部自注意力(Local Self-Attention)是一种自注意力机制的变体,主要用于处理图像和其他高维数据。它通过关注输入数据的局部区域来捕捉特征之间的关系,具有以下几个关键特点和优势:

1.基本原理

局部自注意力的核心思想是计算输入数据中每个位置与其邻近位置之间的关系,而不是像全局自注意力那样考虑所有位置。这种方法通过限制注意力的范围来减少计算复杂度,同时仍然能够有效地捕捉局部特征。

2.计算过程

在局部自注意力中,对于输入的每个位置,模型只计算该位置与其周围一定范围内的其他位置的注意力权重。这通常涉及以下步骤:

  • 定义局部窗口:为每个位置定义一个局部窗口,窗口的大小可以根据任务需求进行调整。
  • 计算注意力权重:在局部窗口内,计算每个位置的注意力权重,通常使用点积或其他相似性度量。
  • 加权求和:使用计算得到的注意力权重对局部窗口内的特征进行加权求和,生成新的特征表示。

3.优势

  • 计算效率:局部自注意力显著降低了计算复杂度,因为它只关注局部区域,而不是整个输入。这使得在处理高分辨率图像时,模型能够更高效地运行。
  • 增强局部特征捕捉:局部自注意力能够更好地捕捉图像中的局部特征,如边缘、纹理等,这对于图像分类、目标检测等任务非常重要。
  • 灵活性:局部自注意力可以与卷积神经网络(CNN)等其他网络结构结合使用,形成混合模型,进一步提升性能。

4.应用

局部自注意力在多个计算机视觉任务中得到了广泛应用,包括:

  • 图像分类:通过捕捉局部特征来提高分类准确性。
  • 目标检测:在检测过程中,局部自注意力可以帮助模型更好地理解对象的局部结构。
  • 实例分割:在分割任务中,局部自注意力能够有效地处理不同对象之间的关系。

5.局限性

尽管局部自注意力具有许多优势,但它也存在一些局限性:

  • 信息丢失:由于只关注局部区域,可能会丢失全局上下文信息,影响模型的整体性能。
  • 窗口大小选择:局部窗口的大小需要根据具体任务进行调整,过小可能导致信息不足,过大则可能增加计算负担。

核心代码

classHaloAttention(nn.Module):def__init__(self,*,dim,block_size,halo_size,dim_head=64,heads=8):super().__init__()asserthalo_size>0,'halo size must be greater than 0'self.dim=dim self.heads=heads self.scale=dim_head**-0.5self.block_size=block_size self.halo_size=halo_size inner_dim=dim_head*heads self.rel_pos_emb=RelPosEmb(block_size=block_size,rel_size=block_size+(halo_size*2),dim_head=dim_head)self.to_q=nn.Linear(dim,inner_dim,bias=False)self.to_kv=nn.Linear(dim,inner_dim*2,bias=False)self.to_out=nn.Linear(inner_dim,dim)defforward(self,x):b,c,h,w,block,halo,heads,device=*x.shape,self.block_size,self.halo_size,self.heads,x.deviceasserth%block==0andw%block==0,'fmap dimensions must be divisible by the block size'assertc==self.dim,f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'# get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key valuesq_inp=rearrange(x,'b c (h p1) (w p2) -> (b h w) (p1 p2) c',p1=block,p2=block)kv_inp=F.unfold(x,kernel_size=block+halo*2,stride=block,padding=halo)kv_inp=rearrange(kv_inp,'b (c j) i -> (b i) j c',c=c)# derive queries, keys, valuesq=self.to_q(q_inp)k,v=self.to_kv(kv_inp).chunk(2,dim=-1)# split headsq,k,v=map(lambdat:rearrange(t,'b n (h d) -> (b h) n d',h=heads),(q,k,v))# scaleq*=self.scale# attentionsim=einsum('b i d, b j d -> b i j',q,k)# add relative positional biassim+=self.rel_pos_emb(q)# mask out padding (in the paper, they claim to not need masks, but what about padding?)mask=torch.ones(1,1,h,w,device=device)mask=F.unfold(mask,kernel_size=block+(halo*2),stride=block,padding=halo)mask=repeat(mask,'() j i -> (b i h) () j',b=b,h=heads)mask=mask.bool()max_neg_value=-torch.finfo(sim.dtype).maxsim.masked_fill_(mask,max_neg_value)# attentionattn=sim.softmax(dim=-1)# aggregateout=einsum('b i j, b j d -> b i d',attn,v)# merge and combine headsout=rearrange(out,'(b h) n d -> b n (h d)',h=heads)out=self.to_out(out)# merge blocks back to original feature mapout=rearrange(out,'(b h w) (p1 p2) c -> b c (h p1) (w p2)',b=b,h=(h//block),w=(w//block),p1=block,p2=block)returnout

引入代码

在根目录下的ultralytics/nn/目录,新建一个attention目录,然后新建一个以HaloAttention为文件名的py文件, 把代码拷贝进去。

importtorchfromtorchimportnn,einsumimporttorch.nn.functionalasFfromeinopsimportrearrange,repeat# 用于高效张量操作的库defto(x):"""生成与输入张量相同设备和数据类型的配置字典"""return{"device":x.device,"dtype":x.dtype}defpair(x):"""将输入转换为元组形式,用于统一处理尺寸参数"""return(x,x)ifnotisinstance(x,tuple)elsexdefexpand_dim(t,dim,k):"""扩展张量维度:在指定维度上扩展k倍"""t=t.unsqueeze(dim=dim)# 增加新维度expand_shape=[-1]*len(t.shape)expand_shape[dim]=kreturnt.expand(*expand_shape)# 扩展指定维度defrel_to_abs(x):"""将相对位置编码转换为绝对位置编码"""b,l,m=x.shape# batch_size, 序列长度, 相对位置数r=(m+1)//2# 计算有效半径# 添加列填充以对齐维度col_pad=torch.zeros((b,l,1),**to(x))x=torch.cat((x,col_pad),dim=2)# 重组张量并填充flat_x=rearrange(x,"b l c -> b (l c)")flat_pad=torch.zeros((b,m-l),**to(x))flat_x_padded=torch.cat((flat_x,flat_pad),dim=1)# 最终形状调整final_x=flat_x_padded.reshape(b,l+1,m)returnfinal_x[:,:l,-r:]# 截取有效区域defrelative_logits_1d(q,rel_k):"""计算一维相对位置注意力分数"""b,h,w,_=q.shape# batch, height, width, dimr=(rel_k.shape[0]+1)//2# 相对位置半径# 爱因斯坦求和计算注意力分数logits=einsum("b x y d, r d -> b x y r",q,rel_k)logits=rearrange(logits,"b x y r -> (b x) y r")# 转换为绝对位置并重组logits=rel_to_abs(logits)logits=logits.reshape(b,h,w,r)returnexpand_dim(logits,dim=2,k=r)# 扩展维度用于后续计算classRelPosEmb(nn.Module):"""相对位置编码模块"""def__init__(self,block_size,rel_size,dim_head):super().__init__()self.block_size=block_size scale=dim_head**-0.5# 缩放因子# 初始化相对位置参数self.rel_height=nn.Parameter(torch.randn(rel_size*2-1,dim_head)*scale)self.rel_width=nn.Parameter(torch.randn(rel_size*2-1,dim_head)*scale)defforward(self,q):"""前向传播:计算宽度和高度方向的相对位置偏置"""block=self.block_size q=rearrange(q,"b (x y) c -> b x y c",x=block)# 宽度方向相对位置rel_logits_w=relative_logits_1d(q,self.rel_width)rel_logits_w=rearrange(rel_logits_w,"b x i y j-> b (x y) (i j)")# 高度方向相对位置(转置处理)q=rearrange(q,"b x y d -> b y x d")rel_logits_h=relative_logits_1d(q,self.rel_height)rel_logits_h=rearrange(rel_logits_h,"b x i y j -> b (y x) (j i)")returnrel_logits_w+rel_logits_h# 合并位置偏置classHaloAttention(nn.Module):"""Halo注意力机制模块"""def__init__(self,dim,block_size,halo_size,dim_head=64,heads=8):super().__init__()asserthalo_size>0,"halo size必须大于0"# 参数初始化self.dim=dim# 输入维度self.heads=heads# 注意力头数self.scale=dim_head**-0.5# 缩放因子self.block_size=block_size# 块大小self.halo_size=halo_size# 扩展区域大小# 网络层定义self.rel_pos_emb=RelPosEmb(block_size,block_size+2*halo_size,dim_head)self.to_q=nn.Linear(dim,dim_head*heads,bias=False)# 查询变换self.to_kv=nn.Linear(dim,2*dim_head*heads,bias=False)# 键值变换self.to_out=nn.Linear(dim_head*heads,dim)# 输出变换defforward(self,x):"""前向传播过程"""b,c,h,w,block,halo,heads,device=(*x.shape,self.block_size,self.halo_size,self.heads,x.device)# 输入验证asserth%block==0andw%block==0,"特征图尺寸必须能被块大小整除"assertc==self.dim,f"输入通道数{c}与设定维度{self.dim}不符"# 重组输入为块结构q_inp=rearrange(x,"b c (h p1) (w p2) -> (b h w) (p1 p2) c",p1=block,p2=block)# 使用unfold提取带halo区域的键值对kv_inp=F.unfold(x,kernel_size=block+2*halo,stride=block,padding=halo)kv_inp=rearrange(kv_inp,"b (c j) i -> (b i) j c",c=c)# 生成查询、键、值q=self.to_q(q_inp)k,v=self.to_kv(kv_inp).chunk(2,dim=-1)# 多头注意力拆分q,k,v=map(lambdat:rearrange(t,"b n (h d) -> (b h) n d",h=heads),(q,k,v))q*=self.scale# 缩放查询向量# 计算注意力分数sim=einsum("b i d, b j d -> b i j",q,k)sim+=self.rel_pos_emb(q)# 添加相对位置偏置# 创建掩码处理填充区域mask=torch.ones(1,1,h,w,device=device)mask=F.unfold(mask,kernel_size=block+2*halo,stride=block,padding=halo)mask=repeat(mask,"() j i -> (b i h) () j",b=b,h=heads).bool()# 应用掩码max_neg_value=-torch.finfo(sim.dtype).maxsim.masked_fill_(mask,max_neg_value)# 计算注意力权重并聚合值向量attn=sim.softmax(dim=-1)out=einsum("b i j, b j d -> b i d",attn,v)# 重组输出特征图out=rearrange(out,"(b h) n d -> b n (h d)",h=heads)out=self.to_out(out)out=rearrange(out,"(b h w) (p1 p2) c -> b c (h p1) (w p2)",b=b,h=h//block,w=w//block,p1=block,p2=block,)returnout

注册

ultralytics/nn/tasks.py中进行如下操作:

步骤1:

fromultralytics.nn.attention.HaloAttentionimportHaloAttention

步骤2

修改def parse_model(d, ch, verbose=True):

elifmin{HaloAttention}:args=[ch[f],*args[0:]]

配置yolov8_HaloAttention.yaml

ultralytics/cfg/models/v8/yolov8_HaloAttention.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parametersnc:80# number of classesscales:# model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n:[0.33,0.25,1024]s:[0.33,0.50,1024]m:[0.67,0.75,768]l:[1.00,1.00,512]x:[1.00,1.25,512]# YOLOv8.0n backbonebackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,Conv,[128,3,2]]# 1-P2/4-[-1,3,C2f,[128,True]]-[-1,1,Conv,[256,3,2]]# 3-P3/8-[-1,6,C2f,[256,True]]-[-1,1,Conv,[512,3,2]]# 5-P4/16-[-1,6,C2f,[512,True]]-[-1,1,Conv,[1024,3,2]]# 7-P5/32-[-1,1,HaloAttention,[2,1]]-[-1,1,SPPF,[1024,5]]# 9# YOLOv8.0n headhead:-[-1,1,nn.Upsample,[None,2,"nearest"]]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,3,C2f,[512]]# 12-[-1,1,nn.Upsample,[None,2,"nearest"]]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,3,C2f,[256]]# 15 (P3/8-small)-[-1,1,Conv,[256,3,2]]-[[-1,12],1,Concat,[1]]# cat head P4-[-1,3,C2f,[512]]# 18 (P4/16-medium)-[-1,1,Conv,[512,3,2]]-[[-1,9],1,Concat,[1]]# cat head P5-[-1,3,C2f,[1024]]# 21 (P5/32-large)-[[15,18,21],1,Detect,[nc]]# Detect(P3, P4, P5)

实验

脚本

importwarnings warnings.filterwarnings('ignore')fromultralyticsimportYOLOif__name__=='__main__':model=YOLO('/root/ultralytics/ultralytics/cfg/models/v8/yolov8-HaloAttention.yaml')# 修改为自己的数据集地址model.train(data='coco128.yaml',cache=False,imgsz=640,epochs=10,single_cls=False,# 是否是单类别检测batch=8,close_mosaic=10,workers=0,optimizer='SGD',amp=True,project='runs/train',name='HaloAttention',)

结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1165347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8改进 - 注意力机制 | MCA (Multidimensional Collaborative Attention) 多维协作注意力通过三分支结构增强通道与空间特征协同建模

前言 本文介绍了多维协作注意力(MCA)及其在YOLOv8中的结合应用。现有注意力机制方法存在忽略维度建模或计算负担重的问题,为此提出MCA,其通过三分支架构同时推断通道、高度和宽度维度注意力,几乎无额外开销。MCA关键在…

深度学习毕设选题推荐:基于python-pytorch卷神经网络训练识别舌头是否健康

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

深度学习毕设选题推荐:基于python机器学习-pytorch-CNN训练识别服装服饰

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

7个数据安全策略保证YashanDB的安全执行

随着数据库系统在企业核心业务中的广泛应用,数据安全问题日渐成为数据库设计与运维的重要考量点。如何确保数据库在高性能和高可用的前提下,实现数据的保密性、完整性及可用性,是保障企业信息系统稳定运行的基础。YashanDB作为一款面向大规模…

7个为什么选择YashanDB的理由,助力企业决策

在当今快速变化的商业环境中,企业依赖于高效的决策支持系统,而数据库技术作为企业的核心支撑,直接影响着数据管理和决策效率。如何在众多数据库产品中选择合适的解决方案,是决策者面临的重要问题。YashanDB凭借其独特的技术优势和…

代码混淆的AI优化:安全性与性能平衡

代码混淆的AI优化:安全性与性能平衡 关键词:代码混淆、AI优化、安全性、性能平衡、代码保护 摘要:本文深入探讨了代码混淆的AI优化这一前沿话题,旨在实现代码安全性与性能之间的平衡。首先介绍了代码混淆和AI优化的背景知识,包括目的、预期读者和文档结构。接着阐述了核心…

7个影响YashanDB数据库安全性的因素

在现代应用程序开发中,数据库作为数据存储和管理的核心组成部分,其安全性至关重要。YashanDB作为一个高性能的数据库系统,面对各种潜在的安全隐患,采取了多种机制来加以保护。理解影响数据库安全性的各个因素,不仅对于…

深度学习毕设项目推荐-基于python-pytorch训练识别舌头是否健康

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

Java毕设项目推荐-基于Web的校运动会管理系统设计与实现基于SpringBoot的民运会赛务管理系统的设计与实现【附源码+文档,调试定制服务】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

计算机深度学习毕设实战-基于机器学习 python-pytorch训练识别舌头是否健康

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

Java毕设项目推荐-基于java的车辆违章信息管理系统的设计与实现基于JavaEE的车辆违章信息管理系统的设计与实现【附源码+文档,调试定制服务】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

7种常见的YashanDB数据库故障及处理办法

在现代数据库管理系统中,数据库故障的发生是不可避免的。特别是对于复杂的分布式架构和共享存储体系,故障可能会影响到整体系统的可用性和数据的完整性。了解常见的数据库故障及其处理办法,不仅能够提高系统的稳定性,还能够减少业…

手把手教你:提示工程架构师完成提示工程系统持续部署

手把手教你:提示工程架构师完成提示工程系统持续部署 一、引言:为什么提示工程需要“持续部署”? 1. 一个让所有提示工程师头疼的场景 上周深夜,我收到客户支持团队的紧急消息:“线上AI客服的回复突然变得生硬&#xf…

深度学习毕设项目推荐-基于python深度学习的道路车辆内有无佩戴安全带识别

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

深度学习毕设项目:基于python-pytorch机器学习 训练识别舌头是否健康

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

深度学习毕设项目推荐-基于python-pytorch-CNN训练识别服装服饰

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

8个步骤快速部署YashanDB数据库环境

在现代数据库技术领域,应用对数据库性能、一致性及高可用性提出了严格需求。数据库系统的部署涉及多种技术挑战,包括数据存储优化、事务一致性保障、资源高效调度及容灾能力建设。YashanDB以其丰富的存储结构支持、多样的部署形态和一体化的高可用设计&a…

深度学习计算机毕设之基于python-pytorch训练识别舌头是否健康卷神经网络

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

Rust unsafe 一文全功能解析

Rust unsafe 一文全功能解析 在 Rust 生态中,“安全”是贯穿始终的核心标签——编译器通过严格的所有权规则、借用检查器等机制,从根源上规避空指针、悬垂引用、数据竞争等内存安全问题。但现实开发中,部分场景需要突破安全规则的限制&#x…

【毕业设计】基于python-pytorch深度学习训练识别舌头是否健康

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…