四. 以Annoy算法建树的方式聚类清洗图像数据集,一次建树,无限次聚类搜索,提升聚类搜索效率。(附完整代码)

文章内容结构:

一. 先介绍什么是Annoy算法。
二. 用Annoy算法建树的完整代码。
三. 用Annoy建树后的树特征匹配聚类归类图像。

一. 先介绍什么是Annoy算法

下面的文章链接将Annoy算法讲解的很详细,这里就不再做过多原理的分析了,想详细了解的可以看看这篇文章内容。

https://zhuanlan.zhihu.com/p/148819536

总的来说:

(1)通过多次递归迭代,建立一个二叉树,以二叉树的方式,提升数据聚类和搜索速度,但会损失一些精度。

(2)建树过程相对比较耗时,但建树只需要一次,部署到线上或者其他设备上,能无数次聚类搜索。(类似于人脸识别的人脸底库)

(注: 这里全部是个人经验,能提升样本标注和清洗效率,不是标准的数据处理方式,希望对您有帮助。)

--------

二. 用Annoy算法建树的完整代码

对底库聚类建树,生成Annoy树特征文件。 

下面参数说明:

最佳聚类类别数量, 是根据《三.以聚类的方式清洗图像数据集,找到最佳聚类类别数 (图像特征提取+Kmeans聚类)》获取得到
BEST_NUM_CLUSTERS = 2501图像特征提取后的向量维度,是pt或者onnx模型输出的类别数
FEATURE_DIM = 190推断图像尺寸,是根据训练pt模型时,输入的图像尺寸大小
CLASSIFY_SIZE = 224  

以下是正式的代码:

import os
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
import shutil
from sklearn.cluster import KMeans
from sklearn.preprocessing import Normalizer
from  tqdm import tqdm
import math
import matplotlib.pyplot as plt# 图像预处理函数
def preprocess_image(image_path):roi_frame= cv2.imread(image_path)width = roi_frame.shape[1]height = roi_frame.shape[0]if (width != CLASSIFY_SIZE) or (height != CLASSIFY_SIZE) :if width > height:# 将图像逆时针旋转90度roi_frame = cv2.rotate(roi_frame, cv2.ROTATE_90_COUNTERCLOCKWISE)new_height = CLASSIFY_SIZEnew_width = int(roi_frame.shape[1] * (CLASSIFY_SIZE / roi_frame.shape[0]))roi_frame = cv2.resize(roi_frame, (new_width, new_height))# 计算上下左右漂移量y_offset = (CLASSIFY_SIZE - roi_frame.shape[0]) // 2x_offset = (CLASSIFY_SIZE - roi_frame.shape[1]) // 2gray_image = np.full((CLASSIFY_SIZE, CLASSIFY_SIZE, 3), 128, dtype=np.uint8)# 将调整大小后的目标图像放置到灰度图上gray_image[y_offset:y_offset + roi_frame.shape[0], x_offset:x_offset + roi_frame.shape[1]] = roi_frame# # 显示结果# cv2.imshow("gray_image", gray_image)# cv2.waitKey(1)# 将图像转为 rgbgray_image =  cv2.cvtColor(gray_image, cv2.COLOR_BGR2RGB)else:gray_image = cv2.cvtColor(roi_frame, cv2.COLOR_BGR2RGB)img_np = np.array(gray_image).transpose(2, 0, 1).astype(np.float32)# 假设模型需要[0,1]归一化img_np = img_np / 255.0# 均值 方差mean = np.array([0.485, 0.456, 0.406],dtype=np.float32).reshape(3, 1, 1)std = np.array([0.229, 0.224, 0.225],dtype=np.float32).reshape(3, 1, 1)img_np= (img_np - mean)/stdreturn np.expand_dims(img_np, axis=0)# 卸载 onnxruntime
# 安装  pip install onnxruntime-gpu
def get_onnx_providers():# 检查是否安装了GPU版本的ONNX Runtimeall_provider = ort.get_available_providers()if "CUDAExecutionProvider" in all_provider:providers = [("CUDAExecutionProvider", {"device_id": 0,"arena_extend_strategy": "kNextPowerOfTwo","gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # 限制GPU内存使用为2GB"cudnn_conv_algo_search": "EXHAUSTIVE","do_copy_in_default_stream": True,}),"CPUExecutionProvider"]print("检测到NVIDIA GPU,使用CUDA加速")return providerselse:print("未检测到NVIDIA GPU,使用CPU")return ["CPUExecutionProvider"]if __name__ =="__main__":root_path =  "/home/xxx/Download"# ONNX模型路径MODEL_PATH = os.path.join(root_path, "08以图搜图_找相似度/98_weights/classify_modified_model_224.onnx")# 图像文件夹路径IMAGE_DIR = os.path.join(root_path, "08以图搜图_找相似度/99_test_datasets/8_bcd已验收/8")# 分类结果输出路径OUTPUT_DIR = os.path.join(root_path, "08以图搜图_找相似度/99_test_datasets/8_bcd已验收/8_kmeans_besk_k_classify")# 保存ann建树文件路径ANNOY_PATH = "08以图搜图_找相似度/01kmeans和DBscan/kmeans/annoy_cls.ann"# 最佳聚类类别数量(用kmeans和inner找到的)BEST_NUM_CLUSTERS = 2501# 图像特征提取后的向量维度FEATURE_DIM = 190  # 根据自己的模型输出维度修改# 推断图像尺寸CLASSIFY_SIZE = 224# 手动划分分类数量# NUM_CLUSTERS = 3000# 创建输出文件夹os.makedirs(OUTPUT_DIR, exist_ok=True)print("ONNX Runtime版本:", ort.__version__)print("可用执行器:", ort.get_available_providers())#   可用执行器: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']# 加载ONNX模型(动态获取输入/输出名称)ort_session = ort.InferenceSession(MODEL_PATH,providers=get_onnx_providers())# 确保输出名称正确input_name = ort_session.get_inputs()[0].nameoutput_name = ort_session.get_outputs()[0].namefrom annoy import AnnoyIndext = AnnoyIndex(FEATURE_DIM, metric="angular")  # FEATURE_DIM是图像特征提取后的向量维度# 提取特征向量features = []image_paths = []print("====开始对所有图像推理, 提取特征====")for index, filename in tqdm(enumerate(os.listdir(IMAGE_DIR))):if filename.lower().endswith((".png", ".jpg", ".jpeg")):path = os.path.join(IMAGE_DIR, filename)try:# 前处理input_tensor = preprocess_image(path)# 推断feature = ort_session.run([output_name], {input_name: input_tensor})[0]# 确保特征展平为1D,  190维度features.append(feature.reshape(-1))image_paths.append(path)# 增加到Annoy树t.add_item(index, feature.reshape(-1))except Exception as e:print(f"Error processing {filename}: {str(e)}")t.build(BEST_NUM_CLUSTERS)    # 根据kmeans聚类找到最佳的聚类类别数量t.save(ANNOY_PATH)print("+++++提取特征结束+++++")print("+++++Annoy建树结束+++++++++")

生成建树annoy_cls.ann文件。

三. 用Annoy建树后的树特征匹配聚类归类图像

使用流程:

(1)加载ann建树文件

(2)提取单张A图像特征

(3)单张A图像特征与ann建树文件的特征进行比对,找到ann建树文件里面的与A图像特征相似的TOP_K的底库图像,拷贝走或者移动走。

import os
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
import shutil
from sklearn.cluster import KMeans
from sklearn.preprocessing import Normalizer
from  tqdm import tqdm
import math
import matplotlib.pyplot as plt# 图像预处理函数
def preprocess_image(image_path):roi_frame= cv2.imread(image_path)width = roi_frame.shape[1]height = roi_frame.shape[0]if (width != CLASSIFY_SIZE) or (height != CLASSIFY_SIZE) :if width > height:# 将图像逆时针旋转90度roi_frame = cv2.rotate(roi_frame, cv2.ROTATE_90_COUNTERCLOCKWISE)new_height = CLASSIFY_SIZEnew_width = int(roi_frame.shape[1] * (CLASSIFY_SIZE / roi_frame.shape[0]))roi_frame = cv2.resize(roi_frame, (new_width, new_height))# 计算上下左右漂移量y_offset = (CLASSIFY_SIZE - roi_frame.shape[0]) // 2x_offset = (CLASSIFY_SIZE - roi_frame.shape[1]) // 2gray_image = np.full((CLASSIFY_SIZE, CLASSIFY_SIZE, 3), 128, dtype=np.uint8)# 将调整大小后的目标图像放置到灰度图上gray_image[y_offset:y_offset + roi_frame.shape[0], x_offset:x_offset + roi_frame.shape[1]] = roi_frame# # 显示结果# cv2.imshow("gray_image", gray_image)# cv2.waitKey(1)# 将图像转为 rgbgray_image =  cv2.cvtColor(gray_image, cv2.COLOR_BGR2RGB)else:gray_image = cv2.cvtColor(roi_frame, cv2.COLOR_BGR2RGB)img_np = np.array(gray_image).transpose(2, 0, 1).astype(np.float32)# 假设模型需要[0,1]归一化img_np = img_np / 255.0# 均值 方差mean = np.array([0.485, 0.456, 0.406],dtype=np.float32).reshape(3, 1, 1)std = np.array([0.229, 0.224, 0.225],dtype=np.float32).reshape(3, 1, 1)img_np= (img_np - mean)/stdreturn np.expand_dims(img_np, axis=0)# todo
# 卸载 onnxruntime
# 安装  pip install onnxruntime-gpu
def get_onnx_providers():# 检查是否安装了GPU版本的ONNX Runtimeall_provider = ort.get_available_providers()if "CUDAExecutionProvider" in all_provider:providers = [("CUDAExecutionProvider", {"device_id": 0,"arena_extend_strategy": "kNextPowerOfTwo","gpu_mem_limit": 6 * 1024 * 1024 * 1024,  # 限制GPU内存使用为2GB"cudnn_conv_algo_search": "EXHAUSTIVE","do_copy_in_default_stream": True,}),"CPUExecutionProvider"]print("检测到NVIDIA GPU,使用CUDA加速")return providerselse:print("未检测到NVIDIA GPU,使用CPU")return ["CPUExecutionProvider"]if __name__ =="__main__":root_path =  "/home/xxx/Download"# ONNX模型路径MODEL_PATH = os.path.join(root_path, "08以图搜图_找相似度/98_weights/classify_modified_model_224.onnx")# 图像文件夹路径IMAGE_DIR = os.path.join(root_path, "08以图搜图_找相似度/99_test_datasets/8_bcd已验收/8")# 分类结果输出路径OUTPUT_DIR = os.path.join(root_path, "08以图搜图_找相似度/99_test_datasets/8_bcd已验收/8_kmeans_besk_k_classify")# 保存annoy建树路径ANNOY_PATH = os.path.join(root_path, "08以图搜图_找相似度/01kmeans和DBscan/kmeans/annoy_cls.ann")# 最佳聚类类别数量BEST_NUM_CLUSTERS = 2501# 图像特征提取后的向量维度FEATURE_DIM = 190# 推断图像尺寸CLASSIFY_SIZE = 224# 取top10TOP_K = 10# 手动划分分类数量# NUM_CLUSTERS = 3000# 创建输出文件夹os.makedirs(OUTPUT_DIR, exist_ok=True)print("ONNX Runtime版本:", ort.__version__)print("可用执行器:", ort.get_available_providers())#   可用执行器: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'AzureExecutionProvider', 'CPUExecutionProvider']# 加载ONNX模型(动态获取输入/输出名称)ort_session = ort.InferenceSession(MODEL_PATH,providers=get_onnx_providers())# 确保输出名称正确input_name = ort_session.get_inputs()[0].nameoutput_name = ort_session.get_outputs()[0].namefrom annoy import AnnoyIndexAnnoy_ = AnnoyIndex(FEATURE_DIM, metric="angular")  # FEATURE_DIM是图像特征提取后的向量维度Annoy_.load(ANNOY_PATH) # 提取特征向量features = []image_paths = []# 获取所有图像路径for _, filename in tqdm(enumerate(os.listdir(IMAGE_DIR))):if filename.lower().endswith((".png", ".jpg", ".jpeg")):path = os.path.join(IMAGE_DIR, filename)image_paths.append(path)print("====开始对所有图像推理, 提取特征, 根据创建的树进行聚类====")for _, filename in tqdm(enumerate(os.listdir(IMAGE_DIR))):if filename.lower().endswith((".png", ".jpg", ".jpeg")):path = os.path.join(IMAGE_DIR, filename)try:# 前处理input_tensor = preprocess_image(path)# 推断feature = ort_session.run([output_name], {input_name: input_tensor})[0]# 确保特征展平为1D,  190维度features.append(feature.reshape(-1))# image_paths.append(path)# 取top10的相似图像similar_img_indices, similar_img_distances=Annoy_.get_nns_by_vector(feature.reshape(-1), TOP_K, include_distances=True)print("similar_img_index:", similar_img_indices)print("similar_img_distance:", similar_img_distances)shutil.copy(path, os.path.join(OUTPUT_DIR,"11"))#  移动相似图像到输出目录for idx in similar_img_indices:similar_image_path = image_paths[idx]# shutil.move(similar_image_path, OUTPUT_DIR)shutil.copy(similar_image_path, OUTPUT_DIR)except Exception as e:print(f"Error processing {filename}: {str(e)}")print("+++++提取特征结束+++++")print("+++++根据Annoy数特征聚类归类图像结束+++++++++")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902858.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是电容?

什么是电容? 电荷与电压的比值就是电容量C。电容单位为法拉(F)。1法拉电容器在电压为1V时储存的电荷量为1库伦(C)。图1.1中的球体表面电压与储存的电荷Q关联。电压V等于。Q/V等于。如果球体位于电介质媒介中,电压V降低倍,Q/V等于。在电介质媒…

Linux服务器上mysql8.0+数据库优化

1.配置文件路径 /etc/my.cnf # CentOS/RHEL /etc/mysql/my.cnf # Debian/Ubuntu /etc/mysql/mysql.conf.d/mysqld.cnf # Ubuntu/Debian检查当前配置文件 sudo grep -v "^#" /etc/mysql/mysql.conf.d/mysqld.cnf | grep -v "^$&q…

MQTT学习资源

MQTT入门:强烈推荐

第十二章 Python语言-大数据分析PySpark(终)

目录 一. PySpark前言介绍 二.基础准备 三.数据输入 四.数据计算 1.数据计算-map方法 2.数据计算-flatMap算子 3.数据计算-reduceByKey方法 4.数据计算-filter方法 5.数据计算-distinct方法 6.数据计算-sortBy方法 五.数据输出 1.输出Python对象 (1&am…

【XR手柄交互】Unity 中使用 InputActions 实现手柄控制详解(基于 OpenXR + Unity新输入系统(Input Actions))

摘要: 本文主要介绍如何使用 Input Actions(Unity 新输入系统) OpenXR 来实现 VR手柄控制(监听ABXY按钮、摇杆、抓握等操作)。 🎮 Unity 中使用 InputActions 实现手柄控制详解(基于 OpenXR 新…

java实现网格交易回测

以下是一个基于Java实现的简单网格交易回测程序框架,以证券ETF(512880)为例。代码包含历史数据加载、网格策略逻辑和基础统计指标: import java.io.BufferedReader; import java.io.FileReader; import java.text.ParseException…

探秘 3D 展厅之卓越优势,解锁沉浸式体验新境界

(一)打破时空枷锁,全球触达​ 3D 展厅的首要优势便是打破了时空限制。在传统展厅中,观众需要亲临现场,且必须在展厅开放的特定时间内参观。而 3D 展厅依托互联网,让观众无论身处世界哪个角落,只…

第十二届蓝桥杯 2021 C/C++组 直线

目录 题目: 题目描述: 题目链接: 思路: 核心思路: 两点确定一条直线: 思路详解: 代码: 第一种方式代码详解: 第二种方式代码详解: 题目:…

微信小程序蓝牙连接打印机打印单据完整Demo【蓝牙小票打印】

文章目录 一、准备工作1. 硬件准备2. 开发环境 二、小程序配置1. 修改app.json 三、完整代码实现1. pages/index/index.wxml2. pages/index/index.wxss3. pages/index/index.js 四、ESC/POS指令说明五、测试流程六、常见问题解决七、进一步优化建议 下面我将提供一个完整的微信…

ubuntu opencv 安装

1.ubuntu opencv 安装 在Ubuntu系统中安装OpenCV,可以通过多种方式进行,以下是一种常用的安装方法,包括从源代码编译安装。请注意,安装步骤可能会因OpenCV的版本和Ubuntu系统的具体版本而略有不同。 一、安装准备 更新系统&…

【C++】class静态常量

Usage: static const T 1 background static const成员属于类,而不是类的实例,所以它们的初始化需要在类外进行(或者在C17之后可以用inline初始化)。 使用中可能遇到的情况: 在头文件中声明一个static const成员,然后在多个cpp…

Java 安全:如何防止 DDoS 攻击?

一、DDoS 攻击简介 DDoS(分布式拒绝服务)攻击是一种常见的网络攻击手段,攻击者通过控制大量的僵尸主机向目标服务器发送海量请求,致使服务器资源耗尽,无法正常响应合法用户请求。在 Java 应用开发中,了解 …

统计文件中单词出现的次数并累计

# 统计单词出现次数 fileopen("E:\Dasktape/python_test.txt","r",encoding"UTF-8") f1file.read() # 读取文件 countf1.count("is") # 统计文件中is 单词出现的次数 print(f"此文件中单词is出现了{count}次")# 2.判断单词出…

C语言实现贪心算法

一、贪心算法核心思想 特征:在每一步选择中都采取当前状态下最优(局部最优)的选择,从而希望导致全局最优解 适用场景:需要满足贪心选择性质和最优子结构性质 二、经典贪心算法示例 1. 活动选择问题 目标&#xff1a…

《一文读懂Transformers库:开启自然语言处理新世界的大门》

《一文读懂Transformers库:开启自然语言处理新世界的大门》 GitHub - huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX. HF-Mirror Hello! Transformers快速入门 pip install transformers -i https:/…

Vue里面elementUi-aside 和el-main不垂直排列

先说解决方法 main.js少导包 import element-ui/lib/theme-chalk/index.css; //加入此行即可 问题复现 排查了一个小时终于找出来问题了,建议导包去看官方的文档,作者就是因为看了别人的导包流程导致的问题 导包官网地址Element UI导包快速入门

MYSQL 常用字符串函数 和 时间函数详解

一、字符串函数 1、​CONCAT(str1, str2, …) 拼接多个字符串。 SELECT CONCAT(Hello, , World); -- 输出 Hello World2、SUBSTRING(str, start, length)​​ 或 ​SUBSTR() 截取字符串。 SELECT SUBSTRING(MySQL, 3, 2); -- 输出 SQ3、LENGTH(str)​​ 与 ​CHAR_LENGTH…

Python-Agent调用多个Server-FastAPI版本

Python-Agent调用多个Server-FastAPI版本 Agent调用多个McpServer进行工具调用 1-核心知识点 fastAPI的快速使用agent调用多个server 2-思路整理 1)先把每个子服务搭建起来2)再暴露一个Agent 3-参考网址 VSCode配置Python开发环境:https:/…

Drools+自定义规则库

文章目录 前言一、创建规则库二、SpringBootDrools程序1.Maven依赖2.application.yml3.Mapper.xml4.Drools配置类5.Service6.Contoller7.测试接口 前言 公司的技术方案想搭建Drools自定义规则库配合大模型进行数据的校验。本篇用来记录使用SpringBoot配合Drools开发Demo程序。…

潮了 低配电脑6G显存生成60秒AI视频 本地部署/一键包/云算力部署/批量生成

最近发现了一个让人眼前一亮的工具——FramePack,它能用一块普通的6GB显存笔记本GPU,生成60秒电影级的高清视频画面,效果堪称炸裂!那么我们就把他本地部署起来玩一玩、下载离线一键整合包,或者是用云算力快速上手。接下…