关键点检测--使用YOLOv8对Leeds Sports Pose(LSP)关键点检测

目录

  • 1. Leeds Sports Pose数据集下载
  • 2. 数据集处理
    • 2.1 获取标签
    • 2.2 将图像文件和标签文件处理成YOLO能使用的格式
  • 3. 用YOLOv8进行训练
    • 3.1 训练
    • 3.2 预测


1. Leeds Sports Pose数据集下载

从kaggle官网下载这个数据集,地址为link,下载好的数据集文件如下:
| |——archive
|——images 这个文件夹中放了2000张运动图像
|——visualized 这个文件夹放的是带关键点的图像
|——joints.mat 存放的是关键点标签
|——README.txt 存放对数据集的介绍

2. 数据集处理

2.1 获取标签

原图像的关键点标签以joints.mat格式存放,用python脚本对其进行解析,这里注意每张图像的大小不同,因此不同图像的宽高不同要注意。 解析代码如下,注意更换里面的图像路径。

import os
import cv2
import numpy as np
import scipy.io as sio# 加载 joints.mat 文件
mat_path = "joints.mat"
data = sio.loadmat(mat_path)# 查看所有键
print(data.keys())# 获取 joints 数据 (假设其结构为 (3, num_keypoints, num_samples))
joints = data['joints']# 路径配置
output_dir = r"LSP\labels"
image_dir = r"LSP\images"  # 图像目录路径
os.makedirs(output_dir, exist_ok=True)# 解析 joints 数据
num_samples = joints.shape[2]for idx in range(num_samples):# 构建图像路径image_path = os.path.join(image_dir, f"im{idx + 1:04d}.jpg")# 检查图像是否存在if not os.path.exists(image_path):print(f"图像 {image_path} 不存在,跳过该样本。")continue# 获取当前样本的关键点数据keypoints = joints[:, :, idx]x_coords = keypoints[0, :]y_coords = keypoints[1, :]visibility = keypoints[2, :]# 读取图像获取实际尺寸img = cv2.imread(image_path)if img is None:print(f"无法读取图像 {image_path},跳过该样本。")continueimg_height, img_width = img.shape[:2]# 归一化关键点坐标x_coords /= img_widthy_coords /= img_height# 计算边界框x_min, x_max = np.min(x_coords), np.max(x_coords)y_min, y_max = np.min(y_coords), np.max(y_coords)bbox_width = x_max - x_minbbox_height = y_max - y_minx_center = x_min + bbox_width / 2y_center = y_min + bbox_height / 2# 构建标签行:class_id x_center y_center width height x1 y1 x2 y2 ...label = f"0 {x_center:.6f} {y_center:.6f} {bbox_width:.6f} {bbox_height:.6f} "label += " ".join([f"{x:.6f} {y:.6f}" for x, y in zip(x_coords, y_coords)])label += "\n"# 保存标签文件label_path = os.path.join(output_dir, f"im{idx + 1:04d}.txt")with open(label_path, "w") as f:f.write(label)print(f"所有关键点数据已转换为 YOLOv8 格式,保存至 {output_dir}")

获取的关键点标签如下图:
在这里插入图片描述
对应的图像如下图:
在这里插入图片描述

2.2 将图像文件和标签文件处理成YOLO能使用的格式

我是按照8比2将数据集分成训练集和测试集,划分代码如下:

import os
import random
import shutil# 数据集路径images_path = r"LSP\images"
labels_path = r"LSP\labels"# 输出路径
train_img_dir = os.path.join(images_path, "train")
val_img_dir = os.path.join(images_path, "val")
train_lbl_dir = os.path.join(labels_path, "train")
val_lbl_dir = os.path.join(labels_path, "val")# 创建输出文件夹
os.makedirs(train_img_dir, exist_ok=True)
os.makedirs(val_img_dir, exist_ok=True)
os.makedirs(train_lbl_dir, exist_ok=True)
os.makedirs(val_lbl_dir, exist_ok=True)# 获取所有图像文件名(不带扩展名)
image_files = sorted([f.split('.')[0] for f in os.listdir(images_path) if f.endswith('.jpg')])# 设置划分比例
train_ratio = 0.8
val_ratio = 0.2# 随机打乱文件列表
random.shuffle(image_files)# 划分数据集
train_count = int(len(image_files) * train_ratio)
train_files = image_files[:train_count]
val_files = image_files[train_count:]def move_files(file_list, img_dir, lbl_dir):for filename in file_list:img_src = os.path.join(images_path, f"{filename}.jpg")lbl_src = os.path.join(labels_path, f"{filename}.txt")img_dst = os.path.join(img_dir, f"{filename}.jpg")lbl_dst = os.path.join(lbl_dir, f"{filename}.txt")if os.path.exists(img_src) and os.path.exists(lbl_src):shutil.move(img_src, img_dst)shutil.move(lbl_src, lbl_dst)# 移动文件
move_files(train_files, train_img_dir, train_lbl_dir)
move_files(val_files, val_img_dir, val_lbl_dir)print(f"数据集划分完成:\n  训练集:{len(train_files)} 张\n  验证集:{len(val_files)} 张")

划分后的文件格式图:
| | |——LSP 文件名
| |——images
|——train 训练数据
|——val 测试数据
| |——val
|——train 训练标签
|——val 测试标签
在这里插入图片描述
数据集的yaml格式如下:
在这里插入图片描述

3. 用YOLOv8进行训练

3.1 训练

下面是我进行训练的代码

# Ultralytics YOLO 🚀, AGPL-3.0 licensefrom copy import copyfrom ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_resultsclass PoseTrainer(yolo.detect.DetectionTrainer):"""A class extending the DetectionTrainer class for training based on a pose model.Example:```pythonfrom ultralytics.models.yolo.pose import PoseTrainerargs = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)trainer = PoseTrainer(overrides=args)trainer.train()```"""def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):"""Initialize a PoseTrainer object with specified configurations and overrides."""if overrides is None:overrides = {}overrides["task"] = "pose"super().__init__(cfg, overrides, _callbacks)if isinstance(self.args.device, str) and self.args.device.lower() == "mps":LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. ""See https://github.com/ultralytics/ultralytics/issues/4031.")def get_model(self, cfg=None, weights=None, verbose=True):"""Get pose estimation model with specified configuration and weights."""model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)if weights:model.load(weights)return modeldef set_model_attributes(self):"""Sets keypoints shape attribute of PoseModel."""super().set_model_attributes()self.model.kpt_shape = self.data["kpt_shape"]def get_validator(self):"""Returns an instance of the PoseValidator class for validation."""self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks)def plot_training_samples(self, batch, ni):"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""images = batch["img"]kpts = batch["keypoints"]cls = batch["cls"].squeeze(-1)bboxes = batch["bboxes"]paths = batch["im_file"]batch_idx = batch["batch_idx"]plot_images(images,batch_idx,cls,bboxes,kpts=kpts,paths=paths,fname=self.save_dir / f"train_batch{ni}.jpg",on_plot=self.on_plot,)def plot_metrics(self):"""Plots training/val metrics."""plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.pngif __name__ == "__main__":args = dict(model=r'E:\postgraduate\bolt_lossen_project\ultralytics-8.1.0\yolov8n-pose.pt', epochs= 100, mode='train')trains = PoseTrainer(overrides=args)trains.train()#predictor.predict_cli()

3.2 预测

预测代码:

# Ultralytics YOLO 🚀, AGPL-3.0 licensefrom ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
import cv2class PosePredictor(DetectionPredictor):"""A class extending the DetectionPredictor class for prediction based on a pose model.Example:```pythonfrom ultralytics.utils import ASSETSfrom ultralytics.models.yolo.pose import PosePredictorargs = dict(model='yolov8n-pose.pt', source=ASSETS)predictor = PosePredictor(overrides=args)predictor.predict_cli()```"""def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):"""Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""super().__init__(cfg, overrides, _callbacks)self.args.task = "pose"if isinstance(self.args.device, str) and self.args.device.lower() == "mps":LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. ""See https://github.com/ultralytics/ultralytics/issues/4031.")def postprocess(self, preds, img, orig_imgs):"""Return detection results for a given input image or list of images."""preds = ops.non_max_suppression(preds,self.args.conf,self.args.iou,agnostic=self.args.agnostic_nms,max_det=self.args.max_det,classes=self.args.classes,nc=len(self.model.names),)if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a listorig_imgs = ops.convert_torch2numpy_batch(orig_imgs)results = []for i, pred in enumerate(preds):orig_img = orig_imgs[i]pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)img_path = self.batch[0][i]results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))return resultsif __name__ == "__main__":img = r'E:\postgraduate\bolt_lossen_project\ultralytics-8.1.0\archive\LSP\images\train\im0089.jpg'args = dict(model=r'E:\postgraduate\bolt_lossen_project\ultralytics-8.1.0\runs\pose\train\weights\best.pt', source=img, task = 'pose')predictor = PosePredictor(overrides=args)predictor.predict_cli()

预测结果图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/79253.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

20250508在WIN10下使用移远的4G模块EC200A-CN直接上网

1、在WIN10/11下安装驱动程序:Quectel_Windows_USB_DriverA_Customer_V1.1.13.zip 2、使用移远的专用串口工具:QCOM_V1.8.2.7z QCOM_V1.8.2_win64.exe 3、配置串口UART42/COM42【移远会自动生成连续三个串口,最小的那一个】 AT命令&#xf…

第J7周:ResNeXt解析

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: PyCharm 框 架: Tensorflow (二)具体…

C++之类和对象:初始化列表,static成员,友元,const成员 ……

目录 const成员函数: 前置和后置重载: 取地址及const取地址操作符重载: 初始化列表: explicit关键字: static成员: 友元: 友元函数: 友元类: 内部类&#xff1a…

uni-app 中的条件编译与跨端兼容

uni-app 为了实现一套代码编译到多个平台(包括小程序,App,H5 等),引入了条件编译机制。 通过条件编译,我们可以针对不同的平台编写特定的代码,从而实现跨端兼容。 一、条件编译的作用 平台差异…

Linux平台下SSH 协议克隆Github远程仓库并配置密钥

目录 注意:先提前配置好SSH密钥,然后再git clone 1. 检查现有 SSH 密钥 2. 生成新的 SSH 密钥 3. 将 SSH 密钥添加到 ssh-agent 4. 将公钥添加到 GitHub 5. 测试 SSH 连接 6. 配置 Git 使用 SSH 注意:先提前配置好SSH密钥,然…

[C++] 大数减/除法

目录 高精度博客 - 前两讲高精度减法高精度除法高精度系列函数完整版 高精度博客 - 前两讲 讲次名称链接高精加法[C] 高精度加法(作用 模板 例题)高精乘法[C] 高精度乘法 高精度减法 void subBIG(int x[], int y[], int z[]){z[0] max(x[0], y[0]);for(int i 1; i < …

视频添加字幕脚本分享

脚本简介 这是一个给视频添加字幕的脚本&#xff0c;可以方便的在指定的位置给视频添加不同大小、字体、颜色的文本字幕&#xff0c;添加方式可以直接修改脚本中的文本信息&#xff0c;或者可以提前编辑好.srt字幕文件。脚本执行环境&#xff1a;windowsmingwffmpeg。本方法仅…

ubuntu nobel + qt5.15.2 设置qss语法识别正确

问题展示 解决步骤 首选项里面的高亮怎么编辑选择都没用。如果已经有generic-highlighter和css.xml&#xff0c;直接修改css.xml文件最直接&#xff01; 在generic-highlighter目录下找到css.xml文件&#xff0c;位置是&#xff1a;/opt/Qt/Tools/QtCreator/share/qtcreator/…

洛谷P7528 [USACO21OPEN] Portals G

P7528 [USACO21OPEN] Portals G luogu题目传送门 题目描述 Bessie 位于一个由 N N N 个编号为 1 … N 1\dots N 1…N 的结点以及 2 N 2N 2N 个编号为 1 ⋯ 2 N 1\cdots 2N 1⋯2N 的传送门所组成的网络中。每个传送门连接两个不同的结点 u u u 和 v v v&#xff08; u …

C++STL——priority_queue

优先队列 前言优先队列仿函数头文件 前言 本篇主要讲解优先队列及其底层实现。 优先队列 优先队列的本质就是个堆&#xff0c;其与queue一样&#xff0c;都是容器适配器&#xff0c;不过优先队列是默认为vector实现的。priority_queue的接口优先队列默认为大根堆。 仿函数 …

助力你的Neovim!轻松管理开发工具的魔法包管理器来了!

在现代编程环境中&#xff0c;Neovim 已经成为许多开发者的编辑器选择。而针对 Neovim 的各种插件与功能扩展&#xff0c;则是提升开发体验的重要手段。今天我们要介绍的就是一个强大而便捷的开源项目——mason.nvim&#xff0c;一个旨在简化和优化 Neovim 使用体验的便携式包管…

Java-Lambda 表达式

Lambda 表达式是 Java 8 引入的一项重要特性&#xff0c;它提供了一种简洁的方式来表示匿名函数。Lambda 表达式主要用于简化函数式接口的实现&#xff0c;使代码更加简洁和易读。以下是关于 Lambda 表达式的详细阐述&#xff1a; 1. Lambda 表达式的基本语法 Lambda 表达式的…

05 mysql之DDL

一、SQL的四个分类 我们通常可以将 SQL 分为四类&#xff0c;分别是&#xff1a; DDL&#xff08;数据定义语言&#xff09;、DML&#xff08;数据操作语言&#xff09;、 DCL&#xff08;数据控制语言&#xff09;和 TCL&#xff08;事务控制语言&#xff09;。 DDL 用于创建…

1 2 3 4 5顺序插入,形成一个红黑树

红黑树的特性与优点 红黑树是一种自平衡的二叉搜索树&#xff0c;通过额外的颜色标记和平衡性约束&#xff0c;确保树的高度始终保持在 O(log n)。其核心特性如下&#xff1a; 每个节点要么是红色&#xff0c;要么是黑色。根节点和叶子节点&#xff08;NIL节点&#xff09;是…

微服务6大拆分原则

微服务6大拆分原则 微服务拆分是指将一个大型应用程序拆分成独立服务的过程&#xff0c;在微服务拆分时&#xff0c;需要考虑以下6大微服务拆分原则 一、单一职责原则 微服务单一职责原则&#xff0c;是指每个微服务应该专注于解决一个明确定义的业务领域或功能&#xff0c;…

java: Compilation failed: internal java compiler error 报错解决方案

java: Compilation failed: internal java compiler error 报错解决方案 如下图所示&#xff1a; 在编译的时候提示 java: Compilation failed: internal java compiler error 原因&#xff1a;内部 java 编译错误,一般是编译版本不匹配。 问题解决 项目中有以下设置JDK版本…

介绍一下ReentrantLock 跟 Synchronized 区别

ReentrantLock 跟 Synchronized 区别 面试回答&#xff1a; 相同点&#xff1a; synchronized 和 ReentrantLock 都是用来保护资源线程安全的。 都可以保证可见性。 synchronized 和 ReentrantLock 都拥有可重入的特点。 从基本语义和概念上说 synchronized: Java 内建的…

第7次课 栈A

课堂学习 栈&#xff08;stack&#xff09; 是一种遵循先入后出逻辑的线性数据结构。 我们可以将栈类比为桌面上的一摞盘子&#xff0c;如果想取出底部的盘子&#xff0c;则需要先将上面的盘子依次移走。我们将盘子替换为各种类型的元素&#xff08;如整数、字符、对象等&…

ts装饰器

TypeScript 装饰器是一种特殊类型的声明&#xff0c;能够被附加到类声明、方法、访问符、属性或参数上。它本质上是一个函数&#xff0c;会在运行时被调用&#xff0c;并且被装饰的声明信息会作为参数传递给装饰器函数。 装饰器的分类 类装饰器 类装饰器作用于类构造函数&…

【金仓数据库征文】政府项目数据库迁移:从MySQL 5.7到KingbaseES的蜕变之路

摘要&#xff1a;本文详细阐述了政府项目中将 MySQL 5.7 数据库迁移至 KingbaseES 的全过程&#xff0c;涵盖迁移前的环境评估、数据梳理和工具准备&#xff0c;迁移实战中的数据源与目标库连接配置、迁移任务详细设定、执行迁移与过程监控&#xff0c;以及迁移后的质量验证、系…