YOLO旋转目标检测之ONNX模型推理

YOLO旋转检测相较于目标检测而言,其只是最后的输出层网络发生了改变,一个最明显的区别便是:目标检测的检测框是xywh,而旋转检测则为xywha,其中,这个a代表angle,即旋转角度,其余的基本相同。
在这里插入图片描述

pt模型推理

这里我们在模型训练完成后,即可进行推理操作,这里我们首先使用默认的模型格式,即pt格式

from ultralytics import YOLO
import cv2
import numpy as np
# 加载模型
model = YOLO("best.pt")  # 加载训练好的旋转框检测模型
# 预测图像
results = model("1.jpg")  # 预测图像
# 可视化参数配置
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.6
thickness = 2
colors = [(0,255,0), (255,0,0), (0,0,255)]  # 不同类别的颜色
# 遍历每个检测结果
for result in results:# 获取原始图像并转换为OpenCV格式img = result.orig_img.copy()# 遍历每个旋转框for polygon, cls, conf in zip(result.obb.xyxyxyxy, result.obb.cls, result.obb.conf):# 将坐标转换为整数类型pts = polygon.cpu().numpy().reshape(-1, 2).astype(int)# 绘制多边形边界框cv2.polylines(img, [pts], isClosed=True,color=colors[int(cls)%len(colors)],thickness=thickness)# 构建标签文本label = f"{result.names[int(cls)]} {conf:.2f}"# 计算文本位置(取第一个点上方)text_origin = (pts[0][0], pts[0][1] - 10 if pts[0][1] > 20 else pts[0][1] + 20)# 绘制文本背景(text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)cv2.rectangle(img,(text_origin[0], text_origin[1] - text_h - 5),(text_origin[0] + text_w, text_origin[1] + 5),colors[int(cls)%len(colors)],-1)  # 填充矩形# 绘制文本cv2.putText(img, label,(text_origin[0], text_origin[1]),font, font_scale,(255,255,255),  # 白色文字thickness)cv2.imwrite("result.jpg", img)

从结果来看,输出的结果的维度为(4,7)其中4代表4个结果,7则是对应的内容,根据拆分结果来看,分别是xywhr以及class_id(类别编号)以及scores(置信度),同时需要注意的是,使用pt的推理结果中,其自动执行了将xywhr转换为xyxyxyxy的操作,这方便我们直接使用opencv中的rectangle方法进行绘图操作。

在这里插入图片描述

ONNX模型推理

ultralytics中提供了将pt文件转换为onnxtflite等多种格式的方法,ONNX(Open Neural Network Exchange)是一种开放的文件格式,用于表示机器学习模型。它使得不同的人工智能框架能够互相交换模型,从而提高了模型的可移植性和互操作性。通过ONNX,开发者可以在一个框架中训练模型,然后将该模型迁移到另一个支持ONNX的框架中进行推理,而无需重新训练或大幅修改模型。

使用ONNX模型进行推理的代码如下:其主要包含数据预处理、模型加载、模型推理三个步骤:

def load_model(weights):"""加载ONNX模型并返回会话对象。:param weights: 模型权重文件路径:return: ONNX运行会话对象"""session = ort.InferenceSession(weights, providers=['CPUExecutionProvider'])logging.info(f"模型加载成功: {weights}")return sessiondef run_inference(session, image_bytes, imgsz=(640, 640)):"""对输入图像进行预处理,然后使用ONNX模型执行推理。:param session: ONNX运行会话对象:param image_bytes: 输入图像的字节数据:param imgsz: 模型输入的尺寸:return: 推理结果、缩放比例、填充尺寸"""im0 = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)  # 解码图像字节数据if im0 is None:raise ValueError("无法从image_bytes解码图像")img, ratio, (dw, dh) = letterbox(im0, new_shape=imgsz)  # 调整图像尺寸img = img.transpose((2, 0, 1))[::-1]  # 调整通道顺序由(640,640,3)变为(3,640,640)img = np.ascontiguousarray(img)img = img[np.newaxis, ...].astype(np.float32) / 255.0  # 归一化处理input_name = session.get_inputs()[0].nameresult = session.run(None, {input_name: img})  # 执行模型推理return result[0], ratio, (dw, dh)def process_images_in_folder(folder_path, model_weights, output_folder, conf_threshold, iou_threshold, imgsz):"""批量处理文件夹中的图像,执行推理、解析和可视化,保存结果。:param folder_path: 输入图像文件夹路径:param model_weights: ONNX模型权重文件路径:param output_folder: 输出结果文件夹路径:param conf_threshold: 置信度阈值:param iou_threshold: IoU 阈值,用于旋转NMS:param imgsz: 模型输入大小"""session = load_model(weights=model_weights)  # 加载ONNX模型if not os.path.exists(output_folder):os.makedirs(output_folder)  # 如果输出文件夹不存在,则创建for filename in os.listdir(folder_path):if filename.endswith(('.jpg', '.png', '.jpeg')):  # 处理图片文件image_path = os.path.join(folder_path, filename)with open(image_path, 'rb') as f:image_bytes = f.read()print("image_path:", image_path)raw_output, ratio, dwdh = run_inference(session=session, image_bytes=image_bytes, imgsz=imgsz)  # 执行推理# 主函数:加载参数
if __name__ == "__main__":folder_path = r"images"  # 输入图像文件夹路径model_weights = r"best.onnx"  # ONNX模型路径output_folder = "results"  # 输出结果文件夹conf_threshold = 0.5  # 置信度阈值iou_threshold = 0.5  # IoU阈值,用于旋转NMSimgsz = (640, 640)  # 模型输入大小process_images_in_folder(folder_path, model_weights, output_folder, conf_threshold, iou_threshold, imgsz)  # 执行批量处理

推理出的结果如下:raw_output(1,7,8400), ratio((0.15873015873015872, 0.15873015873015872)为缩放比例, dwdh(0.0, 80.0)是填充尺度。

随后,便是结果解析了,即后处理过程,如下:
这里需要注意的是,输出结果为(8400,7)其中,0-3xywh4scores5class_id6angle

import os
import cv2
import numpy as np
import onnxruntime as ort
import logging"""
YOLO11 旋转目标检测OBB
1、ONNX模型推理、可视化
2、ONNX输出格式: x_center, y_center, width, height, class1_confidence, ..., classN_confidence, angle
3、支持不同尺寸图片输入、支持旋转NMS过滤重复框、支持ProbIoU旋转IOU计算
"""def letterbox(img, new_shape=(640, 640), color=(0, 0, 0), auto=False, scale_fill=False, scale_up=False, stride=32):"""将图像调整为指定尺寸,同时保持长宽比,添加填充以适应目标输入形状。:param img: 输入图像:param new_shape: 目标尺寸:param color: 填充颜色:param auto: 是否自动调整填充为步幅的整数倍:param scale_fill: 是否强制缩放以完全填充目标尺寸:param scale_up: 是否允许放大图像:param stride: 步幅,用于自动调整填充:return: 调整后的图像、缩放比例、填充尺寸(dw, dh)"""shape = img.shape[:2]if isinstance(new_shape, int):new_shape = (new_shape, new_shape)r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])  # 计算缩放比例if not scale_up:r = min(r, 1.0)ratio = r, rnew_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]if auto:dw, dh = np.mod(dw, stride), np.mod(dh, stride)elif scale_fill:dw, dh = 0.0, 0.0new_unpad = (new_shape[1], new_shape[0])ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]dw /= 2  # 填充均分dh /= 2if shape[::-1] != new_unpad:img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))left, right = int(round(dw - 0.1)), int(round(dw + 0.1))img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)return img, ratio, (dw, dh)def _get_covariance_matrix(obb):"""计算旋转边界框的协方差矩阵。:param obb: 旋转边界框 (Oriented Bounding Box),包含中心坐标、宽、高和旋转角度:return: 协方差矩阵的三个元素 a, b, c"""widths = obb[..., 2] / 2heights = obb[..., 3] / 2angles = obb[..., 4]cos_angle = np.cos(angles)sin_angle = np.sin(angles)a = (widths * cos_angle)**2 + (heights * sin_angle)**2b = (widths * sin_angle)**2 + (heights * cos_angle)**2c = widths * cos_angle * heights * sin_anglereturn a, b, cdef batch_probiou(obb1, obb2, eps=1e-7):"""计算旋转边界框之间的 ProbIoU。:param obb1: 第一个旋转边界框集合:param obb2: 第二个旋转边界框集合:param eps: 防止除零的极小值:return: 两个旋转边界框之间的 ProbIoU"""x1, y1 = obb1[..., 0], obb1[..., 1]x2, y2 = obb2[..., 0], obb2[..., 1]a1, b1, c1 = _get_covariance_matrix(obb1)a2, b2, c2 = _get_covariance_matrix(obb2)t1 = ((a1[:, None] + a2) * (y1[:, None] - y2)**2 + (b1[:, None] + b2) * (x1[:, None] - x2)**2) / ((a1[:, None] + a2) * (b1[:, None] + b2) - (c1[:, None] + c2)**2 + eps) * 0.25t2 = ((c1[:, None] + c2) * (x2 - x1[:, None]) * (y1[:, None] - y2)) / ((a1[:, None] + a2) * (b1[:, None] + b2) - (c1[:, None] + c2)**2 + eps) * 0.5t3 = np.log(((a1[:, None] + a2) * (b1[:, None] + b2) - (c1[:, None] + c2)**2) /(4 * np.sqrt((a1 * b1 - c1**2)[:, None] * (a2 * b2 - c2**2)) + eps) + eps) * 0.5bd = np.clip(t1 + t2 + t3, eps, 100.0)hd = np.sqrt(1.0 - np.exp(-bd) + eps)return 1 - hddef rotated_nms_with_probiou(boxes, scores, iou_threshold=0.5):"""使用 ProbIoU 执行旋转边界框的非极大值抑制(NMS)。:param boxes: 旋转边界框的集合:param scores: 每个边界框的置信度得分:param iou_threshold: IoU 阈值,用于确定是否抑制框:return: 保留的边界框索引列表"""order = scores.argsort()[::-1]  # 根据置信度得分降序排序keep = []while len(order) > 0:i = order[0]keep.append(i)if len(order) == 1:breakremaining_boxes = boxes[order[1:]]iou_values = batch_probiou(boxes[i:i+1], remaining_boxes).squeeze(0)mask = iou_values < iou_threshold  # 保留 IoU 小于阈值的框order = order[1:][mask]return keepdef parse_onnx_output(output, ratio, dwdh, conf_threshold=0.5, iou_threshold=0.5):"""解析ONNX模型的输出,提取旋转边界框坐标、置信度和类别信息,并应用旋转NMS。:param output: ONNX模型的输出,包含预测的边界框信息:param ratio: 缩放比例,用于将坐标还原到原始尺度:param dwdh: 填充的宽高,用于调整边界框的中心点坐标:param conf_threshold: 置信度阈值,过滤低于该阈值的检测框:param iou_threshold: IoU 阈值,用于旋转边界框的非极大值抑制(NMS):return: 符合条件的旋转边界框的检测结果"""boxes, scores, classes, detections = [], [], [], []num_detections = output.shape[2]  # 获取检测的边界框数量num_classes = output.shape[1] - 6  # 计算类别数量# 逐个解析每个检测结果for i in range(num_detections):detection = output[0, :, i]x_center, y_center, width, height = detection[0], detection[1], detection[2], detection[3]  # 提取边界框的中心坐标和宽高angle = detection[-1]  # 提取旋转角度if num_classes > 0:class_confidences = detection[4:4 + num_classes]  # 获取类别置信度if class_confidences.size == 0:continueclass_id = np.argmax(class_confidences)  # 获取置信度最高的类别索引confidence = class_confidences[class_id]  # 获取对应的置信度else:confidence = detection[4]  # 如果没有类别信息,直接使用置信度值class_id = 0  # 默认类别为 0if confidence > conf_threshold:  # 过滤掉低置信度的检测结果x_center = (x_center - dwdh[0]) / ratio[0]  # 还原中心点 x 坐标y_center = (y_center - dwdh[1]) / ratio[1]  # 还原中心点 y 坐标width /= ratio[0]  # 还原宽度height /= ratio[1]  # 还原高度boxes.append([x_center, y_center, width, height, angle])  # 将边界框信息加入列表scores.append(confidence)  # 将置信度加入列表classes.append(class_id)  # 将类别加入列表if not boxes:return []# 转换为 NumPy 数组boxes = np.array(boxes)scores = np.array(scores)classes = np.array(classes)# 应用旋转 NMSkeep_indices = rotated_nms_with_probiou(boxes, scores, iou_threshold=iou_threshold)# 构建最终检测结果for idx in keep_indices:x_center, y_center, width, height, angle = boxes[idx]  # 获取保留的边界框信息confidence = scores[idx]  # 获取对应的置信度class_id = classes[idx]  # 获取类别obb_corners = calculate_obb_corners(x_center, y_center, width, height, angle)  # 计算旋转边界框的四个角点detections.append({"position": obb_corners,  # 旋转边界框的角点坐标"confidence": float(confidence),  # 置信度"class_id": int(class_id),  # 类别 ID"angle": float(angle)  # 旋转角度})return detectionsdef calculate_obb_corners(x_center, y_center, width, height, angle):"""根据旋转角度计算旋转边界框的四个角点。:param x_center: 边界框中心的 x 坐标:param y_center: 边界框中心的 y 坐标:param width: 边界框的宽度:param height: 边界框的高度:param angle: 旋转角度:return: 旋转边界框的四个角点坐标"""cos_angle = np.cos(angle)  # 计算旋转角度的余弦值sin_angle = np.sin(angle)  # 计算旋转角度的正弦值dx = width / 2  # 计算宽度的一半dy = height / 2  # 计算高度的一半# 计算旋转边界框的四个角点坐标corners = [(int(x_center + cos_angle * dx - sin_angle * dy), int(y_center + sin_angle * dx + cos_angle * dy)),(int(x_center - cos_angle * dx - sin_angle * dy), int(y_center - sin_angle * dx + cos_angle * dy)),(int(x_center - cos_angle * dx + sin_angle * dy), int(y_center - sin_angle * dx - cos_angle * dy)),(int(x_center + cos_angle * dx + sin_angle * dy), int(y_center + sin_angle * dx - cos_angle * dy)),]return corners  # 返回角点坐标def save_detections(image, detections, output_path):"""在图像上绘制旋转边界框检测结果并保存。:param image: 原始图像:param detections: 检测结果列表:param output_path: 保存路径"""for det in detections:corners = det['position']  # 获取旋转边界框的四个角点confidence = det['confidence']  # 获取置信度class_id = det['class_id']  # 获取类别ID# 绘制边界框的四条边for j in range(4):pt1 = corners[j]pt2 = corners[(j + 1) % 4]cv2.line(image, pt1, pt2, (0, 0, 255), 2)# 在边界框上方显示类别和置信度cv2.putText(image, f'Class: {class_id}, Conf: {confidence:.2f}',(corners[0][0], corners[0][1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 3)cv2.imwrite(output_path, image)  # 保存绘制后的图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/81349.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

架构进阶:深入学习企业总体架构规划(Oracle 战略专家培训课件)【附全文阅读】

本文主要讨论了企业总体技术架构规划的重要性与实施建议。针对Oracle战略专家培训课件中的内容&#xff0c;文章强调了行业面临的挑战及现状分析、总体技术架构探讨、SOA集成解决方案讨论与问题解答等方面。文章指出&#xff0c;为了消除信息孤岛、强化应用系统&#xff0c;需要…

llamafactory-cli webui启动报错TypeError: argument of type ‘bool‘ is not iterable

一、问题 在阿里云NoteBook上启动llamafactory-cli webui报错TypeError: argument of type ‘bool’ is not iterable This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run gradio deploy from the terminal in the working directory t…

Gas 优化不足、升级机制缺陷问题

以下是针对智能合约中 Gas 优化不足 与 升级机制缺陷 的技术风险分析与解决方案: 一、Gas 优化不足 1. 核心问题 Gas 优化不足会导致合约执行成本过高,直接影响用户体验和协议可行性,尤其在交易高峰期可能引发链上拥堵或交易失败。 2. 常见风险点 冗余计算与存储操作 例如…

使用xlwings计算合并单元格的求和

有如下一个excel表 表内有合并单元格&#xff0c;现在需要求和&#xff0c;不能直接下拉填充公式怎么办&#xff1f; 通常的办法是先取消合并单元格&#xff0c;计算后&#xff0c;再次合并单元格&#xff0c;比较繁琐。 在此&#xff0c;尝试使用python和xlwings运行直接给出…

[创业之路-354]:农业文明到智能纪元:四次工业革命下的人类迁徙与价值重构

农业文明到智能纪元&#xff1a;四次工业革命下的人类迁徙与价值重构 从游牧到定居&#xff0c;从蒸汽轰鸣到算法洪流&#xff0c;人类文明的每一次跨越都伴随着生产关系的剧烈震荡。四次工业革命的浪潮不仅重塑了物质世界的生产方式&#xff0c;更将人类推向了身份认同与存在…

LeetCode 2302.统计得分小于 K 的子数组数目:滑动窗口(不需要前缀和)

【LetMeFly】2302.统计得分小于 K 的子数组数目&#xff1a;滑动窗口&#xff08;不需要前缀和&#xff09; 力扣题目链接&#xff1a;https://leetcode.cn/problems/count-subarrays-with-score-less-than-k/ 一个数组的 分数 定义为数组之和 乘以 数组的长度。 比方说&…

kafka学习笔记(四、生产者(客户端)深入研究(二)——消费者协调器与_consumer_offsets剖析)

1.消费者协调器和组协调器 如果消费者客户端中配置了多个分配策略&#xff0c;则多消费者的分区分配交由消费者协调器和组协调器来完成&#xff0c;他们之间使用一套组协调协议进行交互。 1.1.在均衡原理 将全部消费者分成多个子集&#xff0c;每个消费者组的子集在服务中对…

快速将FastAPI接口转为模型上下文协议(MCP)!

fastapi_mcp 是一个用于将 FastAPI 端点暴露为模型上下文协议&#xff08;Model Context Protocol, MCP&#xff09;工具的库&#xff0c;并且支持认证功能。 环境macbook&#xff0c;python3.13 pip install fastapi uvicorn fastapi-mcp 代码 from fastapi import FastAPI, …

实验数据的转换

最近做实验需要把x轴y轴z轴的数据处理一下&#xff0c;总结一下解决的方法&#xff1a; 源文件为两个txt文档&#xff0c;分别为x轴和y轴&#xff0c;如下&#xff1a; 最终需要达到的效果是如下&#xff1a; 就是需要把各个矩阵的数据整理好放在同一个txt文档里。 步骤① …

第Y3周:yolov5s.yaml文件解读

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 本次任务&#xff1a;将yolov5s网络模型中的第4层的C3x2修改为C3x1&#xff0c;第6层的C3x3修改为C3x2。 首先输出原来的网络结构&#xff1a; from n pa…

Ansible安装配置

一、前提 服务器操作系统均为centos7.9 主机ipmaster(Ansible管理端)172.25.192.2node1172.25.192.10node2172.25.192.3 更新/etc/hosts文件 二、安装 master节点&#xff1a; 1. 安装epel源 yum install -y epel-release 2. 安装Ansible yum install -y ansible A…

MySQL中ROW_NUMBER() OVER的用法以及使用场景

使用语法 ROW_NUMBER() OVER ([PARTITION BY partition_column1, partition_column2, ...]ORDER BY sort_column1 [ASC|DESC], sort_column2 [ASC|DESC], ... )PARTITION BY&#xff1a;将数据按指定列分组&#xff0c;每组内单独生成行号。ORDER BY&#xff1a;决定组内行号的…

【人工智能】释放本地AI潜能:LM Studio用户脚本自动化DeepSeek的实战指南

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 随着大型语言模型(LLM)的快速发展,DeepSeek以其高效的性能和开源特性成为开发者关注的焦点。LM Studio作为一款强大的本地AI模型管理工具…

笔试强训:Day3

一、牛牛冲钻五&#xff08;模拟&#xff09; 登录—专业IT笔试面试备考平台_牛客网 #include<iostream> using namespace std; int main(){int t,n,k;string s;cin>>t;while(t--){cin>>n>>k>>s;int ret0;//统计加了多少星for(int i0;i<n;i)…

语音识别质量的跟踪

背景 这个项目是用来生成结构化的电子病历的。数据的来源是医生的录音。中间有一大堆的处理&#xff0c;语音识别&#xff0c;关键字匹配&#xff0c;结构化处理&#xff0c;病历编辑......。最多的时候给上百家医院服务。 语音识别质量的跟踪 一、0225医院的训练后的情况分…

人工智能搜索时代的SEO:关键趋势与优化策略

随着人工智能&#xff08;AI&#xff09;技术的飞速发展&#xff0c;搜索引擎的运作方式正在经历前所未有的变革。2025年&#xff0c;AI驱动的搜索&#xff08;如谷歌的AI概览、ChatGPT搜索和必应的AI增强功能&#xff09;不仅改变了用户获取信息的方式&#xff0c;还为SEO从业…

Node.js心得笔记

npm init 可用npm 来调试node项目 浏览器中的顶级对象时window <ref *1> Object [global] { global: [Circular *1], clearImmediate: [Function: clearImmediate], setImmediate: [Function: setImmediate] { [Symbol(nodejs.util.promisify.custom)]: [Getter] }, cl…

计算机网络01-网站数据传输过程

局域网&#xff1a; 覆盖范围小&#xff0c;自己花钱买设备&#xff0c;宽带固定&#xff0c;自己维护&#xff0c;&#xff0c;一般长度不超过100米&#xff0c;&#xff0c;&#xff0c;带宽也比较固定&#xff0c;&#xff0c;&#xff0c;10M&#xff0c;&#xff0c;&…

Mysql常用函数解析

字符串函数 CONCAT(str1, str2, …) 将多个字符串连接成一个字符串。 SELECT CONCAT(Hello, , World); -- 输出: Hello World​​SUBSTRING(str, start, length) 截取字符串的子串&#xff08;起始位置从1开始&#xff09;。 SELECT SUBSTRING(MySQL, 3, 2); -- 输出: SQ…

SpringMVC 前后端数据交互 中文乱码

ajax 前台传入数据&#xff0c;但是后台接收到的数据中文乱码 首先我们分析一下原因&#xff1a;我们调用接口的时候传入的中文&#xff0c;是没有乱码的 此时我们看一下Java后台接口对应的编码&#xff1a; 默认情况&#xff1a;Servlet容器&#xff08;如Tomcat&#xff09;默…