YOLOv5推理代码解析

代码如下

import cv2
import numpy as np
import onnxruntime as ort
import time
import random# 画一个检测框
def plot_one_box(x, img, color=None, label=None, line_thickness=None):"""description: 在图像上绘制一个矩形框。param:x: 框的坐标 [x1, y1, x2, y2]img: 输入图像color: 矩形框的颜色,默认为随机颜色label: 框内显示的标签line_thickness: 矩形框的线条宽度return: 无返回值,直接在图像上绘制"""tl = (line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1)  # line/font thickness,计算线条或字体的粗细color = color or [random.randint(0, 255) for _ in range(3)]  # 如果没有提供颜色,随机生成颜色c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))  # 左上角和右下角的坐标cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)  # 绘制矩形框if label:  # 如果提供了标签,则绘制标签tf = max(tl - 1, 1)  # 字体的粗细t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]  # 获取标签的大小c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3  # 计算标签背景框的位置cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # 绘制标签背景框cv2.putText(img,label,(c1[0], c1[1] - 2),0,tl / 3,[225, 255, 255],thickness=tf,lineType=cv2.LINE_AA,)  # 绘制标签文本# 生成网格坐标
def _make_grid(nx, ny):"""description: 生成网格坐标,用于解码预测框位置。param:nx, ny: 网格的行数和列数return: 返回网格坐标"""xv, yv = np.meshgrid(np.arange(ny), np.arange(nx))  # 生成网格坐标return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32)  # 转换为需要的格式# 输出解码
def cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride):"""description: 对模型输出的坐标进行解码,转换为图像坐标。param:outs: 模型输出的框的偏移量nl: 输出层数量na: 每层的anchor数目model_w, model_h: 模型输入图像的尺寸anchor_grid: anchor的尺寸stride: 每个输出层的缩放步长return: 解码后的输出"""row_ind = 0grid = [np.zeros(1)] * nl  # 每个层对应一个网格for i in range(nl):h, w = int(model_w / stride[i]), int(model_h / stride[i])  # 计算该层特征图的高和宽length = int(na * h * w)  # 当前层的总框数if grid[i].shape[2:4] != (h, w):  # 如果网格的大小不匹配,则重新生成网格grid[i] = _make_grid(w, h)# 解码每个框的中心坐标和宽高outs[row_ind:row_ind + length, 0:2] = (outs[row_ind:row_ind + length, 0:2] * 2. - 0.5 + np.tile(grid[i], (na, 1))) * int(stride[i])outs[row_ind:row_ind + length, 2:4] = (outs[row_ind:row_ind + length, 2:4] * 2) ** 2 * np.repeat(anchor_grid[i], h * w, axis=0)  # 计算宽高row_ind += lengthreturn outs# 后处理,计算检测框
def post_process_opencv(outputs, model_h, model_w, img_h, img_w, thred_nms, thred_cond):"""description: 对模型输出的框进行后处理,得到最终的检测框。param:outputs: 模型输出的框model_h, model_w: 模型输入的高度和宽度img_h, img_w: 原图的高度和宽度thred_nms: 非极大值抑制的阈值thred_cond: 置信度阈值return: 返回处理后的框、置信度和类别"""conf = outputs[:, 4].tolist()  # 获取每个框的置信度c_x = outputs[:, 0] / model_w * img_w  # 计算中心点x坐标c_y = outputs[:, 1] / model_h * img_h  # 计算中心点y坐标w = outputs[:, 2] / model_w * img_w  # 计算框的宽度h = outputs[:, 3] / model_h * img_h  # 计算框的高度p_cls = outputs[:, 5:]  # 获取分类得分if len(p_cls.shape) == 1:  # 如果分类结果只有一维,增加一维p_cls = np.expand_dims(p_cls, 1)cls_id = np.argmax(p_cls, axis=1)  # 获取类别编号# 计算框的四个角坐标p_x1 = np.expand_dims(c_x - w / 2, -1)p_y1 = np.expand_dims(c_y - h / 2, -1)p_x2 = np.expand_dims(c_x + w / 2, -1)p_y2 = np.expand_dims(c_y + h / 2, -1)areas = np.concatenate((p_x1, p_y1, p_x2, p_y2), axis=-1)  # 合并成框的坐标areas = areas.tolist()  # 转为列表形式ids = cv2.dnn.NMSBoxes(areas, conf, thred_cond, thred_nms)  # 非极大值抑制if len(ids) > 0:  # 如果有框被保留return np.array(areas)[ids], np.array(conf)[ids], cls_id[ids]else:return [], [], []# 图像推理
def infer_img(img0, net, model_h, model_w, nl, na, stride, anchor_grid, thred_nms=0.4, thred_cond=0.5):"""description: 对输入图像进行推理,输出检测框。param:img0: 原始图像net: 加载的ONNX模型model_h, model_w: 模型的输入尺寸nl: 输出层数量na: 每层的anchor数量stride: 每层的缩放步长anchor_grid: 每层的anchor尺寸thred_nms: 非极大值抑制阈值thred_cond: 置信度阈值return: 检测框、置信度和类别"""# 图像预处理img = cv2.resize(img0, [model_w, model_h], interpolation=cv2.INTER_AREA)  # 将图像调整为模型输入大小img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # 转换为RGB格式img = img.astype(np.float32) / 255.0  # 归一化blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)  # 将图像转为模型输入格式# 模型推理outs = net.run(None, {net.get_inputs()[0].name: blob})[0].squeeze(axis=0)  # 推理并去掉batch维度# 输出坐标矫正outs = cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride)# 检测框计算img_h, img_w, _ = np.shape(img0)  # 获取原图的尺寸boxes, confs, ids = post_process_opencv(outs, model_h, model_w, img_h, img_w, thred_nms, thred_cond)return boxes, confs, idsif __name__ == "__main__":# 加载ONNX模型model_pb_path = "a.onnx"  # 模型文件路径so = ort.SessionOptions()net = ort.InferenceSession(model_pb_path, so)# 类别字典dic_labels = {0: 'jn', 1: 'pill_bag', 2: 'pill_ban', 3: 'yg', 4: 'ys', 5: 'kfy',6: 'pw', 7: 'yanyao_1', 8: 'yanyao_2', 9: 'paper_cup', 10: 'musai',11: 'carrot', 12: 'potato', 13: 'potato_s', 14: 'potato_black',15: 'cizhuan', 16: 'eluanshi_guang', 17: 'stone', 18: 'zhuankuai_bai',19: 'zhuankuai_red', 20: 'empty'}# 模型参数model_h = 320model_w = 320nl = 3na = 3stride = [8., 16., 32.]anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(nl, -1, 2)# 打开摄像头video = 1  # 摄像头设备编号,1表示默认摄像头cap = cv2.VideoCapture(video)  # 视频捕获对象flag_det = False  # 检测开关while True:success, img0 = cap.read()  # 读取每一帧if success:if flag_det:  # 如果检测开启t1 = time.time()  # 记录推理前的时间det_boxes, scores, ids = infer_img(img0, net, model_h, model_w, nl, na, stride, anchor_grid,thred_nms=0.4, thred_cond=0.5)  # 推理t2 = time.time()  # 记录推理后的时间# 绘制检测框和标签for box, score, id in zip(det_boxes, scores, ids):label = '%s:%.2f' % (dic_labels[id], score)plot_one_box(box.astype(np.int16), img0, color=(255, 0, 0), label=label, line_thickness=None)# 计算并显示FPSdelta_time = t2 - t1if delta_time > 0:str_FPS = "FPS: %.2f" % (1. / delta_time)else:str_FPS = "FPS: inf"cv2.putText(img0, str_FPS, (50, 50), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 255, 0), 3)cv2.imshow("video", img0)  # 显示图像key = cv2.waitKey(1) & 0xFF  # 等待键盘输入if key == ord('q'):  # 按q键退出breakelif key & 0xFF == ord('s'):  # 按s键切换检测开关flag_det = not flag_detprint(flag_det)cap.release()  # 释放视频捕获对象

代码中推理相关的函数逐行详细中文注释和解释,帮助你从初学者的角度完全理解代码。我们重点讲解这几个核心函数:


1. infer_img() 推理主函数

def infer_img(img0, net, model_h, model_w, nl, na, stride, anchor_grid, thred_nms=0.4, thred_cond=0.5):

这是图像推理的主函数,完成从原始图像到预测结果的所有操作。


第一步:图像预处理

img = cv2.resize(img0, [model_w, model_h], interpolation=cv2.INTER_AREA)
  • 将原始图像 img0 缩放成模型输入要求的大小(例如 320×320)。

  • cv2.INTER_AREA 是一种图像插值方式,适合缩小图像时使用。

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  • OpenCV 读取图像是 BGR 顺序,而深度学习模型通常使用 RGB,因此这里需要转换颜色通道。

img = img.astype(np.float32) / 255.0
  • 把图像的数据类型转为 float32,并将像素值从 [0, 255] 范围归一化到 [0, 1],符合模型输入要求。

blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
  • OpenCV图像的格式是 (H, W, C),而 PyTorch 模型(如YOLO)的输入是 (B, C, H, W)

  • np.transpose(img, (2, 0, 1)) 把通道 C 移到第一个维度

  • np.expand_dims(..., axis=0) 增加 batch 维度:变成 (1, 3, 320, 320)


第二步:模型推理

outs = net.run(None, {net.get_inputs()[0].name: blob})[0].squeeze(axis=0)
  • 用 ONNX Runtime 推理:输入是 blob

  • net.get_inputs()[0].name 得到模型输入的名字

  • squeeze(axis=0) 把 batch 维度去掉,形状变成 (N, 85),N 是预测框数量,85 是每个框的信息(x, y, w, h, conf, + 80类)


第三步:输出坐标解码

outs = cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride)
  • YOLO 的输出是相对 anchor + grid 编码的,需要转换为图像上的真实位置

  • cal_outputs() 就是做这个解码变换的函数(后面详细讲)


第四步:后处理,获取检测框信息

img_h, img_w, _ = np.shape(img0)
boxes, confs, ids = post_process_opencv(outs, model_h, model_w, img_h, img_w, thred_nms, thred_cond)
  • 将模型输出映射回原始图像尺寸

  • 使用置信度阈值和 NMS 非极大值抑制删除重复框

  • 得到最终的:

    • boxes: 框坐标

    • confs: 置信度

    • ids: 类别编号


2. cal_outputs() 坐标解码函数

def cal_outputs(outs, nl, na, model_w, model_h, anchor_grid, stride):

含义解释:

  • outs: 模型输出,形状大致是 (N, 85),前4列是框的位置

  • nl: YOLO使用的输出层数量(3个:大中小目标)

  • na: 每个特征层使用的 anchor 数(通常为 3)

  • anchor_grid: 每层 anchor 的宽高尺寸

  • stride: 每层特征图相对于原图的缩放倍数

grid = [np.zeros(1)] * nl
  • 每一层都要生成网格坐标 grid,初始化为占位

for i in range(nl):h, w = int(model_w / stride[i]), int(model_h / stride[i])
  • 计算第 i 层的特征图尺寸(如:320/8=40)

    length = int(na * h * w)
  • 该层有多少个预测框

    if grid[i].shape[2:4] != (h, w):grid[i] = _make_grid(w, h)
  • 如果还没有生成 grid,就调用 _make_grid() 创建形状为 (h*w, 2) 的网格点

    outs[row_ind:row_ind + length, 0:2] = ...outs[row_ind:row_ind + length, 2:4] = ...
  • 对该层的所有框做位置矫正(中心点解码 + 宽高缩放)

  • 用 grid 和 anchor 反算出真实坐标


3. post_process_opencv() 后处理函数

def post_process_opencv(outputs, model_h, model_w, img_h, img_w, thred_nms, thred_cond):

功能:

  • 将模型输出映射回原始图像尺寸

  • 提取类别信息

  • 使用 OpenCV 的 cv2.dnn.NMSBoxes() 进行非极大值抑制,保留重要框

步骤:

conf = outputs[:, 4].tolist()         # 提取每个框的置信度
c_x = outputs[:, 0] / model_w * img_w
c_y = outputs[:, 1] / model_h * img_h
w = outputs[:, 2] / model_w * img_w
h = outputs[:, 3] / model_h * img_h
  • 将中心点和尺寸从模型尺寸映射回原始图像尺寸

p_cls = outputs[:, 5:]
cls_id = np.argmax(p_cls, axis=1)
  • 取得每个框的类别分数最大值(即分类结果)

p_x1 = c_x - w/2
p_y1 = c_y - h/2
p_x2 = c_x + w/2
p_y2 = c_y + h/2
  • 把中心点转为左上角和右下角坐标 [x1, y1, x2, y2]

areas = np.concatenate((p_x1, p_y1, p_x2, p_y2), axis=-1)
ids = cv2.dnn.NMSBoxes(areas, conf, thred_cond, thred_nms)
  • 用 NMS 去除重叠预测框


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/80499.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CATIA高效工作指南——常规配置篇(二)

一、结构树(Specification Tree)操作技巧精讲 结构树是CATIA设计中记录模型历史与逻辑关系的核心模块,其高效管理直接影响设计效率。本节从基础操作到高级技巧进行系统梳理。 1.1 结构树激活与移动 ​​激活方式​​: ​​白线…

批量重命名bat

作为一名程序员,怎么可以自己一个个改文件名呢! Windows的批量重命名会自动加上括号和空格,看着很不爽,写一个bat处理吧!❥(ゝω・✿ฺ) 功能:将当前目录下的所有文件名里面当括号和空格都去掉。 用法&…

嵌入式软件开发常见warning之 warning: implicit declaration of function

文章目录 🧩 1. C 编译流程回顾(背景)📍 2. 出现 warning 的具体阶段:**编译阶段(Compilation)**🧬 2.1 词法分析(Lexical Analysis)🌲 2.2 语法分…

【人工智能-agent】--Dify中MCP工具存数据到MySQL

本文记录的工作如下: 自定义MCP工具,爬取我的钢铁网数据爬取的数据插值处理自定义MCP工具,把爬取到的数据(str)存入本地excel表格中自定义MCP工具,把爬取到的数据(str)存入本地MySQ…

Golang 应用的 CI/CD 与 K8S 自动化部署全流程指南

一、CI/CD 流程设计与工具选择 1. 技术栈选择 版本控制:Git(推荐 GitHub/GitLab)CI 工具:Jenkins/GitLab CI/GitHub Actions(本文以 GitHub Actions 为例)容器化:Docker Docker Compose制品库…

网络基础1(应用层、传输层)

目录 一、应用层 1.1 序列化和反序列化 1.2 HTTP协议 1.2.1 URL 1.2.2 HTTP协议格式 1.2.3 HTTP服务器示例 二、传输层 2.1 端口号 2.1.1 netstat 2.1.2 pidof 2.2 UDP协议 2.2.1 UDP的特点 2.2.2 基于UDP的应用层…

基于大模型预测的吉兰 - 巴雷综合征综合诊疗方案研究报告大纲

目录 一、引言(一)研究背景(二)研究目的与意义二、大模型预测吉兰 - 巴雷综合征的理论基础与技术架构(一)大模型原理概述(二)技术架构设计三、术前预测与手术方案制定(一)术前预测内容(二)手术方案制定依据与策略四、术中监测与麻醉方案调整(一)术中监测指标与数…

【言语】刷题2

front:刷题1 ⭐ 前对策的说理类 题干 新时代是转型关口,要创新和开放(前对策)创新和开放不能一蹴而就,但是对于现代化很重要 BC片面,排除 A虽然表达出了创新和开放很重要,体现了现代化&#xf…

Blueprints - Gameplay Message Subsystem

一些学习笔记归档; Gameplay Message是C插件,安装方式是把插件文件夹拷贝到Plugins中(没有的话需要新建该文件夹),然后再刷新源码,运行项目; 安装后还需要在插件中激活: 这样&#…

火山云网站搭建

使用火山引擎的 **火山云(Volcano Engine Cloud)** 搭建网站,主要涉及云服务器、存储、网络等核心云服务的配置。以下是搭建网站的基本步骤和关键点: --- ### **一、准备工作** 1. **注册火山引擎账号** - 访问火山引擎官网&…

嵌入式开发学习(第二阶段 C语言基础)

直到型循环的实现 特点:先执行,后判断,不管条件是否满足,至少执行一次。 **代表:**do…while,goto(已经淘汰,不推荐使用) do…while 语法: 循环变量; do {循环体; }…

Nginx +Nginx-http-flv-module 推流拉流

这两天为了利用云服务器实现 Nginx 进行OBS Rtmp推流,Flv拉流时发生了诸多情况,记录实现过程。 环境 OS:阿里云CentOS 7.9 64位Nginx:nginx-1.28.0Nginx-http-flv-module:nginx-http-flv-module-1.2.12 安装Nginx编…

射频ADRV9026驱动

参考: ADRV9026 & ADRV9029 Prototyping Platform User Guide [Analog Devices Wiki] 基于ADRV9026的四通道射频收发FMC子卡-CSDN博客 adrv9026 spi 接口验证代码-CSDN博客

使用本地部署的 LLaMA 3 模型进行中文对话生成

以下程序调用本地部署的 LLaMA3 模型进行多轮对话生成,通过 Hugging Face Transformers API 加载、预处理、生成并输出最终回答。 程序用的是 Chat 模型格式(如 LLaMA3 Instruct 模型),遵循 ChatML 模板,并使用 apply…

Oracle19c中的全局临时表

应用程序通常使用某种形式的临时数据存储来处理过于复杂而无法一次性完成的流程。通常,这些临时存储被定义为数据库表或 PL/SQL 表。从 Oracle 8i 开始,可以使用全局临时表将临时表的维护和管理委托给服务器。 一、临时表分类 Oracle 支持两种类型的临…

Windows 安装 Milvus

说明 操作系统:Window 中间件:docker desktop Milvus:Milvus Standalone(单机版) 安装 docker desktop 参考:Window、CentOs、Ubuntu 安装 docker-CSDN博客 安装 Milvus 参考链接:Run Mil…

24、DeepSeek-V3论文笔记

DeepSeek-V3论文笔记 **一、概述****二、核心架构与创新技术**0.汇总:1. **基础架构**2. **创新策略** 1.DeepSeekMoE无辅助损失负载均衡DeepSeekMoE基础架构无辅助损失负载均衡互补序列级辅助损失 2.多令牌预测(MTP)1.概念2、原理2.1BPD2.2M…

1.8 梯度

(知识体系演进逻辑树) 一元导数(1.5) │ ├─→ 多元偏导数(1.6核心突破) │ │ │ └─解决:多变量耦合时的单变量影响分析 │ │ │ ├─几何:坐标轴切片切线斜率…

274、H指数

题目 给你一个整数数组 citations ,其中 citations[i] 表示研究者的第 i 篇论文被引用的次数。计算并返回该研究者的 h 指数。 根据维基百科上 h 指数的定义:h 代表“高引用次数” ,一名科研人员的 h 指数 是指他(她&#xff09…

【C++11】异常

前言 上文我们学习到了C11中类的新功能【C11】类的新功能-CSDN博客 本文我们来学习C下一个新语法:异常 1.异常的概念 异常的处理机制允许程序在运行时就出现的问题进行相应的处理。异常可以使得我们将问题的发现和问题的解决分开,程序的一部分负…