YOLOv5、YOLOv8改进:C3STR(Swin Transformer)

目录

1.介绍

2. YOLOv5、YOLOv8改进

2.1 common.py配置

2.2 yolo.py配置

2.3 yaml配置文件


1.介绍

视觉领域正在见证从 CNN 到 Transformers 的建模转变,纯 Transformer 架构在主要视频识别基准测试中达到了最高准确度。这些视频模型都建立在 Transformer 层之上,Transformer 层在空间和时间维度上全局连接块。在本文中,我们提倡视频 Transformer 中的局部归纳偏差,与以前的方法相比,即使使用时空分解,也可以在全局范围内计算自注意力,从而实现更好的速度-精度权衡。所提出的视频架构的局部性是通过调整为图像域设计的 Swin Transformer 实现的,同时继续利用预训练图像模型的力量。我们的方法在广泛的视频识别基准测试中实现了最先进的准确度,包括动作识别(Kinetics-400 上的 84.9 top-1 准确度和 Kinetics-600 上的 85.9 top-1 准确度,减少了约 20× 预训练数据和小模型尺寸的 3 倍)和时间建模(SomethingSomething v2 上的 69.6 top-1 准确率)。

论文地址Swin-Transformer论文下载论文地址

 

该论文介绍了一种名为 Swin Transformer 的新视觉 Transformer,它能够作为计算机视觉的通用主干。将 Transformer 从语言适应到视觉的挑战来自两个领域之间的差异,例如视觉实体的规模变化很大,以及与文本中的单词相比,图像中像素的高分辨率。为了解决这些差异,我们提出了一种分层 Transformer,其表示是用移位窗口计算的。移位窗口方案通过将 self-attention 计算限制在不重叠的本地窗口上,同时还允许跨窗口连接,从而带来更高的效率。这种分层架构具有在各种尺度上建模的灵活性,并且具有相对于图像大小的线性计算复杂度。Swin Transformer 的这些特性使其与广泛的视觉任务兼容,包括图像分类(ImageNet-1K 上 86.4 top-1 准确度)和密集预测任务,例如对象检测(COCO 测试上 58.7 box AP 和 51.1 mask AP dev)和语义分割(ADE20K val 为 53.5 mIoU)。它的性能大大超过了之前的 state-of-the-art,在 COCO 上 +2.7 box AP 和 +2.6 mask AP,在 ADE20K 上 +3.2 mIoU,展示了基于 Transformer 的模型作为视觉骨干的潜力。代码和模型将在 它的性能大大超过了之前的 state-of-the-art,在 COCO 上 +2.7 box AP 和 +2.6 mask AP,在 ADE20K 上 +3.2 mIoU,展示了基于 Transformer 的模型作为视觉骨干的潜力。代码和模型将在 它的性能大大超过了之前的 state-of-the-art,在 COCO 上 +2.7 box AP 和 +2.6 mask AP,在 ADE20K 上 +3.2 mIoU,展示了基于 Transformer 的模型作为视觉骨干的潜力。

面临问题:
作者提出了将Swin Transformer缩放到30亿个参数的技术 ,并使其能够使用高达1536×1536分辨率的图像进行训练。在很多方面达到了SOTA。

目前,视觉模型尚未像NLP语言模型那样被广泛探索,部分原因是训练和应用中的以下差异:

(1)视觉模型通常在规模上面临不稳定性问题;

(2)许多下游视觉任务需要高分辨率图像,如何有效地将低分辨率预训练的模型转换为高分辨率模型尚未被有效探索,也就是跨窗口分辨率迁移模型时性能下降。

(3)当图像分辨率较高时,GPU显存消耗也是一个问题。

解决思路:
为了解决这些问题,作者提出了几种技术,并在本文中以Swin Transformer进行了说明:

(1)提高大视觉模型稳定性的后归一化(post normalization) 技术和缩放余弦注意力(scaled cosine attention)方法,以提高大型视觉模型的稳定性;

(2)一种对数间隔连续位置偏差技术(log-spaced continuous position bias technique) ,用于有效地将在低分辨率图像中预训练的模型转换为其高分辨率对应模型。

(3)分享节约GPU内存消耗方法,使得训练大分辨率模型可行;


2. YOLOv5、YOLOv8改进

2.1 common.py配置

在./models/common.py文件中增加以下模块,直接复制即可

class SwinTransformerBlock(nn.Module):def __init__(self, c1, c2, num_heads, num_layers, window_size=8):super().__init__()self.conv = Noneif c1 != c2:self.conv = Conv(c1, c2)# remove input_resolutionself.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])def forward(self, x):if self.conv is not None:x = self.conv(x)x = self.blocks(x)return x
class WindowAttention(nn.Module):def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size  # Wh, Wwself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)nn.init.normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask=None):B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)# print(attn.dtype, v.dtype)try:x = (attn @ v).transpose(1, 2).reshape(B_, N, C)except:#print(attn.dtype, v.dtype)x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xclass SwinTransformerLayer(nn.Module):def __init__(self, dim, num_heads, window_size=8, shift_size=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.SiLU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratio# if min(self.input_resolution) <= self.window_size:#     # if window size is larger than input resolution, we don't partition windows#     self.shift_size = 0#     self.window_size = min(self.input_resolution)assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)def create_mask(self, H, W):# calculate attention mask for SW-MSAimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))return attn_maskdef forward(self, x):# reshape x[b c h w] to x[b l c]_, _, H_, W_ = x.shapePadding = Falseif min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:Padding = True# print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')pad_r = (self.window_size - W_ % self.window_size) % self.window_sizepad_b = (self.window_size - H_ % self.window_size) % self.window_sizex = F.pad(x, (0, pad_r, 0, pad_b))# print('2', x.shape)B, C, H, W = x.shapeL = H * Wx = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c# create mask from init to forwardif self.shift_size > 0:attn_mask = self.create_mask(H, W).to(x.device)else:attn_mask = Noneshortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h wif Padding:x = x[:, :, :H_, :W_]  # reverse paddingreturn xclass C3STR(C3):# C3 module with SwinTransformerBlock()def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):super().__init__(c1, c2, c2, n, shortcut, g, e)c_ = int(c2 * e)num_heads = c_ // 32self.m = SwinTransformerBlock(c_, c_, num_heads, n)

2.2 yolo.py配置

不需要

2.3 yaml配置文件

增加以下yolov5_swin_transfomrer.yaml文件

代码
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone by yoloair
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3STR, [256]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3STR, [512]],  # 9 <--- ST2CSPB() Transformer module[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

修改完成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/86336.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Grafana离线安装部署以及插件安装

Grafana是一个可视化面板&#xff08;Dashboard&#xff09;&#xff0c;有着非常漂亮的图表和布局展示&#xff0c;功能齐全的度量仪表盘和图形编辑器&#xff0c;支持Graphite、zabbix、InfluxDB、Prometheus和OpenTSDB作为数据源。Grafana主要特性&#xff1a;灵活丰富的图形…

js逆向-某税务网站chinatax分析

目录 一、如图网站二、研究登陆页反爬参数1、datagram参数2、请求接口关系 三、研究详情页反爬参数1、urlyzm与ruuid与x-b3-spanid参数2、los28199参数3、lzkqow23819参数4、jmbw参数 四、最终结果 一、如图网站 二、研究登陆页反爬参数 1、datagram参数 很多接口使用到的dat…

1796_通过vmware打开VirtualBox虚拟机文件

全部学习汇总&#xff1a; GitHub - GreyZhang/toolbox: 常用的工具使用查询&#xff0c;非教程&#xff0c;仅作为自我参考&#xff01; 首先讲vdi格式转换成vmdk格式&#xff0c;以我自己的环境下的信息&#xff0c;处理如下&#xff1a; VBoxManage clonehd "LinuxMin…

【PowerShell】系统安装PowerShell的Core版本,最新版本为7.1

当前以下操作系统支持PowerShell 7.1 版本的安装,非Windows 系统支持的版本和要求有一定的限制。 Windows 8.1/10 (including ARM64)Windows Server 2012 R2, 2016, 2019, and Semi-Annual Channel (SAC)Ubuntu 16.04/18.04/20.04 (including ARM64)Ubuntu 19.10 (via Snap pac…

ESP-IDF学习——1.环境安装与hello-world

ESP-IDF学习——1.环境安装与hello-world 0.前言一、环境搭建1.官方IDE工具2.vscode图形化配置 二、示例工程三、自定义工程四、点灯五、总结 0.前言 最近在学习freertos&#xff0c;但由于买的书还没到&#xff0c;所以先捣鼓捣鼓ESP-IDF&#xff0c;因为这个比Arduino更接近底…

200行C++代码写一个Qt俄罗斯方块小游戏

小小演示一下&#xff1a; 大体思路&#xff1a; 其实很早就想写一个俄罗斯方块了&#xff0c;但是一想到那么多方块还要变形&#xff0c;还要判断落地什么的就脑壳疼。直到现在才写出来。 俄罗斯方块这个小游戏的小难点其实就一个&#xff0c;就是方块的变形&#xff0c;看似…

如何将本地的项目上传到Git

一、GitHub or GitLab or Gitee创建一个新的仓库 二、仓库路径创建成功后&#xff0c;将本地项目上传到git 1. 进入本地项目所在文件夹位置&#xff0c;右击 2.出现git命令框 输入git init 在当前项目的目录中生成本地的git管理&#xff08;会发现在当前目录下多了一个.git文件…

转转闲鱼交易猫链接源码 支持二维码收款

最新仿二手闲置链接源码 后台一键生成链接&#xff0c;后台管理教程&#xff1a;解压源码&#xff0c;修改数据库config/Congig 不会可以看源码里有教程 下载程序&#xff1a;https://pan.baidu.com/s/16lN3gvRIZm7pqhvVMYYecQ?pwd6zw3

30.链表练习题(1)(王道2023数据结构2.3.7节1-15题)

【前面使用的所有链表的定义在第29节】 试题1&#xff1a; 设计一个递归算法&#xff0c;删除不带头结点的单链表L中所有值为x的结点。 首先来看非递归算法&#xff0c;暴力遍历&#xff1a; int Del(LinkList &L,ElemType x){ //此函数实现删除链表中为x的元素LNode *…

科技云报道:分布式存储红海中,看天翼云HBlock如何突围?

科技云报道原创。 过去十年&#xff0c;随着技术的颠覆性创新和新应用场景的大量涌现&#xff0c;企业IT架构出现了稳态和敏态的混合化趋势。 在持续产生海量数据的同时&#xff0c;这些新应用、新场景在基础设施层也普遍基于敏态的分布式架构构建&#xff0c;从而对存储技术…

安卓将图片分割或者拉伸或者旋转到指定尺寸并保存到本地

直接上代码吧:你们要用的话可以按照想法改 package com.demo.util;import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.graphics.Matrix; import android.os.Environment; import android.util.Log;import java.io.File; import java.io.…

代码随想录算法训练营 动态规划part11

一、买卖股票的最佳时机III 123. 买卖股票的最佳时机 III - 力扣&#xff08;LeetCode&#xff09; 请选一个喜欢的吧/(ㄒoㄒ)/~~123. 买卖股票的最佳时机 III - 力扣&#xff08;LeetCode&#xff09; class Solution {public int maxProfit(int[] prices) {if(pricesnul…

SpringBoot项目(百度AI整合)——如何在Springboot中使用语音文件识别 ffmpeg的安装和使用

前言 前言&#xff1a;在实际使用中&#xff0c;经常要参考官方的案例&#xff0c;但有时候因为工具的不一样&#xff0c;比如idea 和 eclipse&#xff0c;普通项目和spring项目等的差别&#xff1b;还有时候因为水平有限&#xff0c;难以在散布于官方的各个文档读懂&#xff…

stable diffusion model训练遇到的问题【No module named ‘triton‘】

一天早晨过来&#xff0c;发现昨天还能跑的diffusion代码&#xff0c;突然出现了【No module named ‘triton’】的问题&#xff0c;导致本就不富裕的显存和优化速度雪上加霜&#xff0c;因此好好探究了解决方案。 首先是原因&#xff0c;由于早晨过来发现【电脑重启】导致了【…

【owt】vs2022 + v141 : 查看WINDOWSSDKDIR

confmfc改为vs2022 + v141 构建 去掉这几个boost库,一样可以链接ok libboost_system-vc141-mt-sgd-x32-1_67.lib libboost_date_time-vc141-mt-sgd-x32-1_67.lib libboost_random-vc141-mt-sgd-x32-1_67.libSDK不在2022或者2017 里面? WINDOWSSDKDIR 在哪里? ##

LuatOS-SOC接口文档(air780E)--camera - codec - 多媒体-编解码

常量 常量 类型 解释 codec.MP3 number MP3格式 codec.WAV number WAV格式 codec.AMR number AMR-NB格式&#xff0c;一般意义上的AMR codec.AMR_WB number AMR-WB格式 codec.create(type, isDecoder) 创建编解码用的codec 参数 传入值类型 解释 int 多媒…

VSCode开发go手记

断点调试&#xff1a; 安装delve&#xff08;windows&#xff09;&#xff1a; go get -u github.com/go-delve/delve/cmd/dlv 设置 launch.json 配置文件&#xff1a; ctrlshiftp 输入 Debug: Open launch.json 打开 launch.json 文件&#xff0c;如果第一次打开,会新建一…

为什么使用命令行

一提到Linux&#xff0c;许多人都会说到“自由”&#xff0c;但他们也许并不知道“自由”的真正涵义。“自由”是指一台没有任何秘密的计算机&#xff0c;并且你可以决定你的计算机能做什么。“自由”是一种权力&#xff0c;但是在过去的二三十年里&#xff0c;这种基本的权力正…

[论文笔记]RE2

引言 今天带来论文Simple and Effective Text Matching with Richer Alignment Features的笔记,论文标题为基于更丰富特征对齐结构的简单高效文本匹配模型。 这篇工作是2019年发表的,在Bert出来之后发表的,在四个著名的文本匹配任务(SNLI,SciTail,QQP,WikiQA)上取得了SOTA…

分块压缩算法及例程

分块压缩算法是一种数据压缩方法&#xff0c;它将输入数据划分为不同的块&#xff0c;并对每个块进行独立的压缩。这种算法通常用于处理大型文件或流式数据&#xff0c;可以提高压缩和解压缩的效率。 以下是一个基本的分块压缩算法的示例&#xff1a; 将输入数据分成固定大小的…