李沐45_SSD实现——自学笔记

主体思路:
1.生成一堆锚框
2.根据真实标签为每个锚框打标(类别、偏移、mask)
3.模型为每个锚框做一个预测(类别、偏移)
4.计算上述二者的差异损失,以更新模型weights

先读取一张图像。 它的高度和宽度分别为561和728像素。

%matplotlib inline
import torch
from d2l import torch as d2limg = d2l.plt.imread('catdog.jpg')
h, w = img.shape[:2]
h, w
(561, 728)

display_anchors函数定义如下。 我们在特征图(fmap)上生成锚框(anchors),每个单位(像素)作为锚框的中心。 由于锚框中的(x,y)轴坐标值(anchors)已经被除以特征图(fmap)的宽度和高度,因此这些值介于0和1之间,表示特征图中锚框的相对位置。

def display_anchors(fmap_w, fmap_h, s):d2l.set_figsize()# 前两个维度上的值不影响输出fmap = torch.zeros((1, 10, fmap_h, fmap_w))anchors = d2l.multibox_prior(fmap, sizes=s, ratios=[1, 2, 0.5])bbox_scale = torch.tensor((w, h, w, h))d2l.show_bboxes(d2l.plt.imshow(img).axes,anchors[0] * bbox_scale)

锚框的尺度设置为0.15,特征图的高度和宽度设置为4。图像上4行和4列的锚框的中心是均匀分布的。

display_anchors(fmap_w=4, fmap_h=4, s=[0.15])

在这里插入图片描述

将特征图的高度和宽度减小一半,然后使用较大的锚框来检测较大的目标。 当尺度设置为0.4时,一些锚框将彼此重叠。

display_anchors(fmap_w=2, fmap_h=2, s=[0.4])

在这里插入图片描述

进一步将特征图的高度和宽度减小一半,然后将锚框的尺度增加到0.8。 此时,锚框的中心即是图像的中心

display_anchors(fmap_w=1, fmap_h=1, s=[0.8])

在这里插入图片描述

SSD的实现 单发多框检测

定义了这样一个类别预测层,通过参数num_anchors和num_classes分别指定了a
和q。 该图层使用填充为1的3X3的卷积层。此卷积层的输入和输出的宽度和高度保持不变。

%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2ldef cls_predictor(num_inputs, num_anchors, num_classes):return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1),kernel_size=3, padding=1)

边界框预测层

每个锚框预测4个偏移量,而不是q+1个类别。

def bbox_predictor(num_inputs, num_anchors):return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)

连接多尺度的预测

def forward(x, block):return block(x)Y1 = forward(torch.zeros((2, 8, 20, 20)), cls_predictor(8, 5, 10))
Y2 = forward(torch.zeros((2, 16, 10, 10)), cls_predictor(16, 3, 10))
Y1.shape, Y2.shape
(torch.Size([2, 55, 20, 20]), torch.Size([2, 33, 10, 10]))

将通道维移到最后一维。 因为不同尺度下批量大小仍保持不变,我们可以将预测结果转成二维的(批量大小,高X宽X通道数)的格式,以方便之后在维度1上的连结

def flatten_pred(pred):return torch.flatten(pred.permute(0, 2, 3, 1), start_dim=1)def concat_preds(preds):return torch.cat([flatten_pred(p) for p in preds], dim=1)

尽管Y1和Y2在通道数、高度和宽度方面具有不同的大小,我们仍然可以在同一个小批量的两个不同尺度上连接这两个预测输出。

concat_preds([Y1, Y2]).shape
torch.Size([2, 25300])

高和宽减半块

高和宽减半块down_sample_blk,该模块将输入特征图的高度和宽度减半,可以扩大每个单元在其输出特征图中的感受野。

def down_sample_blk(in_channels, out_channels):blk = []for _ in range(2):blk.append(nn.Conv2d(in_channels, out_channels,kernel_size=3, padding=1))blk.append(nn.BatchNorm2d(out_channels))blk.append(nn.ReLU())in_channels = out_channelsblk.append(nn.MaxPool2d(2))return nn.Sequential(*blk)

在以下示例中,我们构建的高和宽减半块会更改输入通道的数量,并将输入特征图的高度和宽度减半。

forward(torch.zeros((2, 3, 20, 20)), down_sample_blk(3, 10)).shape
torch.Size([2, 10, 10, 10])

基本网络块

基本网络块用于从输入图像中抽取特征,输出特征形状32X32

def base_net():blk = []num_filters = [3, 16, 32, 64]for i in range(len(num_filters) - 1):blk.append(down_sample_blk(num_filters[i], num_filters[i+1]))return nn.Sequential(*blk)forward(torch.zeros((2, 3, 256, 256)), base_net()).shape
torch.Size([2, 64, 32, 32])

完整模型

完整的单发多框检测模型由五个模块组成。每个块生成的特征图既用于生成锚框,又用于预测这些锚框的类别和偏移量。在这五个模块中,第一个是基本网络块,第二个到第四个是高和宽减半块,最后一个模块使用全局最大池将高度和宽度都降到1。

def get_blk(i):if i == 0:blk = base_net()elif i == 1:blk = down_sample_blk(64, 128)elif i == 4:blk = nn.AdaptiveMaxPool2d((1,1))else:blk = down_sample_blk(128, 128)return blk

每个块定义前向传播。与图像分类任务不同,此处的输出包括:CNN特征图Y;在当前尺度下根据Y生成的锚框;预测的这些锚框的类别和偏移量(基于Y)。

def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):Y = blk(X)anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)cls_preds = cls_predictor(Y)bbox_preds = bbox_predictor(Y)return (Y, anchors, cls_preds, bbox_preds)
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],[0.88, 0.961]]
ratios = [[1, 2, 0.5]] * 5
num_anchors = len(sizes[0]) + len(ratios[0]) - 1

完整的模型TinySSD

class TinySSD(nn.Module):def __init__(self, num_classes, **kwargs):super(TinySSD, self).__init__(**kwargs)self.num_classes = num_classesidx_to_in_channels = [64, 128, 128, 128, 128]for i in range(5):# 即赋值语句self.blk_i=get_blk(i)setattr(self, f'blk_{i}', get_blk(i))setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i],num_anchors, num_classes))setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i],num_anchors))def forward(self, X):anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5for i in range(5):# getattr(self,'blk_%d'%i)即访问self.blk_iX, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))anchors = torch.cat(anchors, dim=1)cls_preds = concat_preds(cls_preds)cls_preds = cls_preds.reshape(cls_preds.shape[0], -1, self.num_classes + 1)bbox_preds = concat_preds(bbox_preds)return anchors, cls_preds, bbox_preds

创建一个模型实例,然后使用它对一个256X256像素的小批量图像X执行前向传播.第一个模块输出特征图的形状为32X32。 回想一下,第二到第四个模块为高和宽减半块,第五个模块为全局汇聚层。 由于以特征图的每个单元为中心有4个锚框生成,因此在所有五个尺度下,每个图像总共生成5444

net = TinySSD(num_classes=1)
X = torch.zeros((32, 3, 256, 256))
anchors, cls_preds, bbox_preds = net(X)print('output anchors:', anchors.shape)
print('output class preds:', cls_preds.shape)
print('output bbox preds:', bbox_preds.shape)
output anchors: torch.Size([1, 5444, 4])
output class preds: torch.Size([32, 5444, 2])
output bbox preds: torch.Size([32, 21776])

训练模型

读取数据集和初始化

batch_size = 32
train_iter, _ = d2l.load_data_bananas(batch_size)
Downloading ../data/banana-detection.zip from http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip...
read 1000 training examples
read 100 validation examples

香蕉检测数据集中,目标的类别数为1。 定义好模型后,我们需要初始化其参数并定义优化算法。

device, net = d2l.try_gpu(), TinySSD(num_classes=1)
trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)

定义损失和平均函数

使用L1范数损失,即预测值和真实值之差的绝对值。 掩码变量bbox_masks令负类锚框和填充锚框不参与损失的计算。 最后,我们将锚框类别和偏移量的损失相加,以获得模型的最终损失函数。

cls_loss = nn.CrossEntropyLoss(reduction='none')
bbox_loss = nn.L1Loss(reduction='none')def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]cls = cls_loss(cls_preds.reshape(-1, num_classes),cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)bbox = bbox_loss(bbox_preds * bbox_masks,bbox_labels * bbox_masks).mean(dim=1)return cls + bbox

沿用准确率平均结果,平均绝对误差来评价预测结果

def cls_eval(cls_preds, cls_labels):# 由于类别预测结果放在最后一维,argmax需要指定最后一维。return float((cls_preds.argmax(dim=-1).type(cls_labels.dtype) == cls_labels).sum())def bbox_eval(bbox_preds, bbox_labels, bbox_masks):return float((torch.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())

训练模型

需要在模型的前向传播过程中生成多尺度锚框(anchors),并预测其类别(cls_preds)和偏移量(bbox_preds)。 然后,我们根据标签信息Y为生成的锚框标记类别(cls_labels)和偏移量(bbox_labels)。 最后,我们根据类别和偏移量的预测和标注值计算损失函数。

num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['class error', 'bbox mae'])
net = net.to(device)
for epoch in range(num_epochs):# 训练精确度的和,训练精确度的和中的示例数# 绝对误差的和,绝对误差的和中的示例数metric = d2l.Accumulator(4)net.train()for features, target in train_iter:timer.start()trainer.zero_grad()X, Y = features.to(device), target.to(device)# 生成多尺度的锚框,为每个锚框预测类别和偏移量anchors, cls_preds, bbox_preds = net(X)# 为每个锚框标注类别和偏移量bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)# 根据类别和偏移量的预测和标注值计算损失函数l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,bbox_masks)l.mean().backward()trainer.step()metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),bbox_eval(bbox_preds, bbox_labels, bbox_masks),bbox_labels.numel())cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]animator.add(epoch + 1, (cls_err, bbox_mae))
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on 'f'{str(device)}')
class err 3.32e-03, bbox mae 3.24e-03
4187.2 examples/sec on cuda:0

在这里插入图片描述

预测目标

将图像中感兴趣的目标检测出来,读取并调整测试图像的大小,然后将其转成卷积层需要的四维格式。

X = torchvision.io.read_image('banana.jpg').unsqueeze(0).float()
img = X.squeeze(0).permute(1, 2, 0).long()

使用下面的multibox_detection函数,我们可以根据锚框及其预测偏移量得到预测边界框。然后,通过非极大值抑制来移除相似的预测边界框

def predict(X):net.eval()anchors, cls_preds, bbox_preds = net(X.to(device))cls_probs = F.softmax(cls_preds, dim=2).permute(0, 2, 1)output = d2l.multibox_detection(cls_probs, bbox_preds, anchors)idx = [i for i, row in enumerate(output[0]) if row[0] != -1]return output[0, idx]output = predict(X)

筛选所有置信度不低于0.9的边界框,做为最终输出。

def display(img, output, threshold):d2l.set_figsize((5, 5))fig = d2l.plt.imshow(img)for row in output:score = float(row[1])if score < threshold:continueh, w = img.shape[0:2]bbox = [row[2:6] * torch.tensor((w, h, w, h), device=row.device)]d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')display(img, output.cpu(), threshold=0.9)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/824331.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Photoshop 2024 (ps) v25.6中文 强大的图像处理软件 mac/win

Photoshop 2024 for Mac是一款强大的图像处理软件&#xff0c;专为Mac用户设计。它继承了Adobe Photoshop一贯的优秀功能&#xff0c;并进一步提升了性能和稳定性。 Mac版Photoshop 2024 (ps)v25.6中文激活版下载 win版Photoshop 2024 (ps)v25.6直装版下载 无论是专业的设计师还…

EI Scopus双检索 | 2024年清洁能源与智能电网国际会议(CCESG 2024)

会议简介 Brief Introduction 2024年清洁能源与智能电网国际会议(CCESG 2024) 会议时间&#xff1a;2024年 11月27-29日 召开地点&#xff1a;澳大利亚悉尼 大会官网&#xff1a;CCESG 2024-2024 International Joint Conference on Clean Energy and Smart Grid 由CoreShare科…

m4p转换mp3格式怎么转?3个Mac端应用~

M4P文件格式的诞生伴随着苹果公司引入FairPlay版权管理系统&#xff0c;该系统旨在保护音频的内容。M4P因此而生&#xff0c;成为受到FairPlay系统保护的音频格式&#xff0c;常见于苹果设备的iTunes等平台。 MP3文件格式的多个优点 MP3格式的优点显而易见。首先&#xff0c;其…

k8s之etcd

1.特点&#xff1a; etcd 是云原生架构中重要的基础组件。有如下特点&#xff1a; 简单&#xff1a;安装配置简单&#xff0c;而且提供了 HTTP API 进行交互&#xff0c;使用也很简单键值对存储&#xff1a;将数据存储在分层组织的目录中&#xff0c;如同在标准文件系统中监…

vscode msvc qt环境搭建

自己整了好久都没把环境搞好&#xff0c;后来发现已经有大佬搞好了插件&#xff0c;完全不需要自己整理。 下载如下插件&#xff1a; 第二个qt插件就可以自动帮我们生成工程了。 可惜目前似乎支持win&#xff0c;另外就是debug模式运行后会报qwindowsd.dll插件找不到的错误&a…

【简单讲解下如何用爬虫玩转石墨文档】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

链表OJ - 6(链表分割)

题目描述&#xff08;来源&#xff09; 现有一链表的头指针 ListNode* pHead&#xff0c;给一定值x&#xff0c;编写一段代码将所有小于x的结点排在其余结点之前&#xff0c;且不能改变原来的数据顺序&#xff0c;返回重新排列后的链表的头指针。 思路 创建两个链表&#xff0c…

ChatGPT:引领未来的语言模型革命?

一、引言 随着人工智能技术的不断发展&#xff0c;Chat GPT作为一种自然语言处理技术&#xff0c;已经逐渐渗透到各个领域&#xff0c;具有广泛的应用前景。本文将从多个角度探讨Chat GPT的应用领域及其未来发展趋势。 ChatGPT的语言处理能力超越了以往任何一款人工智能产品。…

Docker一键快速私有化部署(Ollama+Openwebui) +AI大模型(gemma,llama2,qwen)20240417更新

几行命令教你私有化部署自己的AI大模型&#xff0c;每个人都可以有自己的GTP 第一步&#xff1a;安装Docker(如果已经有了可以直接跳第二步) ####下载安装Docker wget https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo -O/etc/yum.repos.d/docker-ce.repo##…

STM32 USB虚拟串口

电路原理图 usb部分 晶振部分 usb与单片机连接 配置信息 sys配置信息 rcc配置信息 usb配置信息 虚拟串口配置信息 时钟配置信息 项目配置信息 代码 包含文件 主函数代码 实验效果 修改接收波特率依然可以正常接收&#xff0c;也就是说单片机可以自动适应上位机的波特率设置。…

4.17作业

#include "double_link_list.h" node_p create_double_link_list() //创建双向链表 {node_p H(node_p)malloc(sizeof(node));if(HNULL){printf("空间申请失败\n");return NULL;}H->data0;H->priNULL;H->nextNULL;return H; } node_p create_node…

BUUCTF——[GXYCTF2019]BabyUpload

BUUCTF——[GXYCTF2019]BabyUpload 1.上传嘛&#xff0c;直接丢正常的jpg文件进服务器 2.发现可以正常上传&#xff0c;并且回显出来啦文件上传的路径 /var/www/html/upload/7df22610744ec51e9cb7a8a8eb674374/1111.jpg 3.尝试上传一句话木马 <?php eval($POST[123456]…

HDFS详解(Hadoop)

Hadoop 分布式文件系统&#xff08;Hadoop Distributed File System&#xff0c;HDFS&#xff09;是 Apache Hadoop 生态系统的核心组件之一&#xff0c;它是设计用于存储大规模数据集并运行在廉价硬件上的分布式文件系统。 1. 分布式存储&#xff1a; HDFS 将文件分割成若干块…

「 网络安全常用术语解读 」漏洞利用交换VEX详解

漏洞利用交换&#xff08;Vulnerability Exploitability eXchange&#xff0c;简称VEX&#xff09;是一个信息安全领域的标准&#xff0c;旨在提供关于软件漏洞及其潜在利用的实时信息。根据美国政府发布的用例(PDF)&#xff0c;由美国政府开发的漏洞利用交换(VEX)使供应商和用…

工业电脑在ESOP工作站行业应用

ESOP工作站行业应用 项目背景 E-SOP是实现作业指导书电子化&#xff0c;并统一管理和集中控制的一套管理信息平台。信迈科技的ESOP终端是一款体积小巧功能齐全的高性价比工业电脑&#xff0c;上层通过网络与MES系统连接&#xff0c;下层连接显示器展示作业指导书。ESOP控制终…

基于开源IM即时通讯框架MobileIMSDK:RainbowChat v11.5版已发布

关于MobileIMSDK MobileIMSDK 是一套专门为移动端开发的开源IM即时通讯框架&#xff0c;超轻量级、高度提炼&#xff0c;一套API优雅支持UDP 、TCP 、WebSocket 三种协议&#xff0c;支持iOS、Android、H5、小程序、Uniapp、标准Java平台&#xff0c;服务端基于Netty编写。 工…

朗思-我的家园正式上线:朗思科技Agent工具软件--人人拥有“Ai-机器人”

4月16日&#xff0c;朗思科技正式发布"朗思-我的家园"。朗思科技是国内领先的Ai Agent智能自动化工具软件产品及方案的提供商&#xff0c;始终坚持自主研发&#xff0c;全面支持国产信创&#xff0c;不断加快产品创新迭代。基于技术领先性和战略前瞻性&#xff0c;其…

【小白学机器学习13】一文理解假设检验的反证法,H0如何设计的,什么时候用左侧检验和右侧检验,等各种关于假设检验的基础知识

目录 前言&#xff1a; 目标 1 什么叫 假设检验 1.1 假设检验的定义 1.1.1 来自百度百科 1.1.2 维基百科 1.2 假设检验的最底层逻辑&#xff1a;是反证法思想 1.3 假设检验的底层构造&#xff1a;小概率反证法思想 2 什么叫反证法 2.1 反证法的概念 2.1.1 来自百度…

MFC下CPictureCtrl控件基于鼠标左键坐标的直线绘图

本文仅供学习交流&#xff0c;严禁用于商业用途&#xff0c;如本文涉及侵权请及时联系本人将于及时删除 目录 1.创建自定义类CMyPictureCtrl 2.布局Dlg 3.实验代码 4.运行结果 在基于对话框的MFC应用程序中&#xff0c;通过鼠标操作获取坐标并在CPictureCtrl控件中使用Lin…

通过Idea部署Tomcat服务器

1.在idea中创建项目 有maven构建工具就创建maven&#xff0c;没有就正常创建一个普通的java程序 创建普通java项目 2.添加框架 3.配置 Tomcat 注意&#xff1a;创建web项目后我们需要配置tomcat才能运行&#xff0c;下面我们来进行配置。 4.添加部署 回到服务器 5.完善配置 6…