例子2
检测画框,并且合并
按照框大小排序,然后融合重叠的框



例子1
检测画框,并且合并
按照分数排序,然后融合重叠的框
缺点:重叠的低分框会被直接丢弃(丢失框)
例子1
检测画框,并且合并
按照分数排序,然后融合重叠的框
缺点:重叠的低分框会被直接丢弃(丢失框)



import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches
import time # 添加时间模块#################################### For Image ####################################
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor# 记录总开始时间
# Wall-clock start of the whole run; read again by the final timing summary.
total_start_time = time.time()

# ---- Model loading ---------------------------------------------------------
model_load_start_time = time.time()
model = build_sam3_image_model(
    checkpoint_path="/home/r9000k/v2_project/sam/sam3/assets/model/sam3.pt"
)
processor = Sam3Processor(model, confidence_threshold=0.5)
model_load_end_time = time.time()
model_load_time = model_load_end_time - model_load_start_time
print(f"模型加载时间: {model_load_time:.3f} 秒")

# ---- Single-image detection ------------------------------------------------
detection_start_time = time.time()

# Image to segment. Another sample noted by the author:
#   "testimage/微信图片_20251120225838_38.jpg"
image_path = "3.jpg"
image = Image.open(image_path)
inference_state = processor.set_image(image)

# Prompt the model with text. Other prompts noted by the author:
# "road and playground", "car", "people", "bicycle".
output = processor.set_text_prompt(state=inference_state, prompt="building")

# Masks, bounding boxes and confidence scores of the detections.
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

detection_end_time = time.time()
detection_time = detection_end_time - detection_start_time
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"原始检测到 {len(masks)} 个分割结果")
print(f"掩码形状: {masks.shape}")


def calculate_iou(box1, box2):
    """Return how strongly two boxes overlap, relative to the smaller box.

    Despite its name this is NOT the classic intersection-over-union. The
    result is ``max(intersection / area(box1), intersection / area(box2))``,
    so a small box fully contained in a large one scores 1.0.

    Args:
        box1: Box as ``(x1, y1, x2, y2)``.
        box2: Box as ``(x1, y1, x2, y2)``.

    Returns:
        Overlap ratio; 0.0 when the boxes do not intersect. A box whose area
        is not positive contributes 0.0 instead of causing a division by
        zero.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Intersection rectangle; width/height collapse to 0 when disjoint.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)

    # Guard each division separately: one degenerate (zero-area) box must
    # not crash the comparison against a valid one.
    ratio_a = inter_area / area_a if area_a > 0 else 0.0
    ratio_b = inter_area / area_b if area_b > 0 else 0.0
    return max(ratio_b, ratio_a)


def calculate_mask_overlap(mask1, mask2):
    """Return the fraction of ``mask1``'s foreground that ``mask2`` also covers.

    Both arguments are tensors exposing ``.cpu().numpy()``; any non-zero
    value counts as foreground. Returns 0.0 when ``mask1`` is empty.
    """
    first = mask1.cpu().numpy().squeeze().astype(bool)
    second = mask2.cpu().numpy().squeeze().astype(bool)

    first_area = np.sum(first)
    if first_area == 0:
        return 0.0
    return np.sum(np.logical_and(first, second)) / first_area


def fuse_overlapping_masks(masks, boxes, scores, iou_threshold=0.5, overlap_threshold=0.6):
    """Drop detections whose box overlaps a higher-scoring detection.

    Detections are visited from highest to lowest score. Each surviving
    detection suppresses every later detection whose ``calculate_iou`` value
    against it exceeds ``iou_threshold``. Suppressed detections are discarded,
    not merged into the survivor.

    Args:
        masks: Mask tensor, documented by the author as ``[N, 1, H, W]``.
        boxes: Box tensor, documented by the author as ``[N, 4]``.
        scores: Score tensor, documented by the author as ``[N]``.
        iou_threshold: Overlap ratio above which the lower-scoring detection
            is suppressed.
        overlap_threshold: Kept for interface compatibility; currently
            unused because the mask-overlap check is disabled.

    Returns:
        ``(final_masks, final_boxes, final_scores)`` holding the kept
        detections in descending score order. Empty input is returned as-is.
    """
    if len(masks) == 0:
        return masks, boxes, scores

    boxes_np = boxes.detach().cpu().numpy()
    scores_np = scores.detach().cpu().numpy()

    # Highest score first (ascending argsort, reversed).
    sorted_indices = np.argsort(scores_np)[::-1]
    boxes_sorted = boxes_np[sorted_indices]
    scores_sorted = scores_np[sorted_indices]

    keep_indices = []
    suppressed = set()
    total = len(boxes_sorted)

    for i in range(total):
        print('=====================',i)
        if i in suppressed:
            print('1 跳过',i)
            continue
        keep_indices.append(i)

        for j in range(i + 1, total):
            if j in suppressed:
                # Log the candidate that is skipped (j), not the anchor (i).
                print('2 跳过',j)
                continue

            iou = calculate_iou(boxes_sorted[i], boxes_sorted[j])
            if iou > iou_threshold:
                # The mask-level overlap test (calculate_mask_overlap against
                # overlap_threshold) is disabled; box overlap alone decides.
                suppressed.add(j)
                print(f"融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})"," iou:",iou)
            else:
                print(f"xxxxxx融合检测结果: 索引 {sorted_indices[i]} (得分: {scores_sorted[i]:.3f}) 覆盖索引 {sorted_indices[j]} (得分: {scores_sorted[j]:.3f})"," iou:",iou)

    # Map positions in the sorted order back to indices of the input tensors.
    # Plain Python ints index tensors on any device.
    final_indices = [int(sorted_indices[i]) for i in keep_indices]
    final_masks = masks[final_indices]
    final_boxes = boxes[final_indices]
    final_scores = scores[final_indices]

    print(f"融合后剩余 {len(final_masks)} 个分割结果 (减少了 {len(masks) - len(final_masks)} 个)")
    return final_masks, final_boxes, final_scores


# Apply the fusion step.
print("\n开始融合重叠的检测结果...")
fusion_start_time = time.time()

# Suppress detections that overlap a higher-scoring one.
fused_masks, fused_boxes, fused_scores = fuse_overlapping_masks(
    masks,
    boxes,
    scores,
    iou_threshold=0.5,  # tunable: box-overlap ratio that triggers suppression
    # Tunable, but currently without effect: the mask-overlap check inside
    # fuse_overlapping_masks is commented out.
    overlap_threshold=0.6,
)
fusion_time = time.time() - fusion_start_time
print(f"融合完成时间: {fusion_time:.3f} 秒")


def _load_label_font(size):
    """Return the first loadable TrueType font, else Pillow's default font."""
    # SimHei is tried first because it can render Chinese text.
    for font_name in ("SimHei.ttf", "Arial.ttf"):
        try:
            return ImageFont.truetype(font_name, size)
        except (OSError, ImportError):
            # Font file missing, or FreeType support unavailable.
            continue
    return ImageFont.load_default()


def _measure_text(draw, text, font):
    """Return ``(width, height)`` of ``text`` on old and new Pillow versions."""
    try:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Older Pillow: no textbbox, or textbbox rejects this font type.
        return draw.textsize(text, font=font)
    return right - left, bottom - top


def overlay_masks_with_info(image, masks, boxes, scores, fusion_mode=False):
    """Draw masks, bounding boxes and ID/score labels onto a copy of ``image``.

    Args:
        image: PIL image; it is converted to RGB and the input is not modified.
        masks: Mask tensor, documented by the author as ``[N, 1, H, W]``;
            axis 1 is squeezed away here. Values are cast to uint8 and scaled
            by 128 to form the alpha channel, so 0/1 masks are expected.
        boxes: Box tensor ``[N, 4]`` as ``[x1, y1, x2, y2]``.
        scores: Score tensor ``[N]``.
        fusion_mode: When true, use the "viridis" colormap and the
            "Fused-ID" label prefix; otherwise "rainbow" and "ID".

    Returns:
        A new RGB PIL image with all overlays drawn.
    """
    image = image.convert("RGB")
    font = _load_label_font(20)

    masks_np = masks.cpu().numpy().astype(np.uint8).squeeze(1)  # [N, H, W]
    boxes_np = boxes.cpu().numpy()
    scores_np = scores.cpu().numpy()
    n_masks = masks_np.shape[0]

    # plt.get_cmap works on current Matplotlib; plt.cm.get_cmap was removed.
    cmap = plt.get_cmap("viridis" if fusion_mode else "rainbow", n_masks)
    label_prefix = "Fused-ID" if fusion_mode else "ID"

    for i, (mask, box, score) in enumerate(zip(masks_np, boxes_np, scores_np)):
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        # Alpha is 128 (about 50% opacity) inside the mask, 0 elsewhere.
        alpha = Image.fromarray((mask * 128).astype(np.uint8))
        overlay = Image.new("RGBA", image.size, color + (128,))
        overlay.putalpha(alpha)
        image = Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")
        draw = ImageDraw.Draw(image)

        # Clamp the box to the image bounds before drawing it.
        x1, y1, x2, y2 = box
        x1 = max(0, min(x1, image.width))
        y1 = max(0, min(y1, image.height))
        x2 = max(0, min(x2, image.width))
        y2 = max(0, min(y2, image.height))
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"{label_prefix}:{i} Score:{score:.3f}"
        text_width, text_height = _measure_text(draw, text, font)

        # The label sits just above the box, clamped to the top edge.
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        draw.rectangle(
            [text_x, text_y, text_x + text_width + 10, text_y + text_height + 5],
            fill=color,
        )
        draw.text((text_x + 5, text_y + 2), text, fill="white", font=font)

    return image


# Start timing the visualisation stage.
visualization_start_time = time.time()

# Render the raw detections and the fused detections on the same source image.
original_image = Image.open(image_path)
result_image_original = overlay_masks_with_info(
    original_image, masks, boxes, scores, fusion_mode=False
)
result_image_fused = overlay_masks_with_info(
    original_image, fused_masks, fused_boxes, fused_scores, fusion_mode=True
)

# Save both renderings.
output_path_original = "segmentation_result_original.png"
output_path_fused = "segmentation_result_fused.png"
result_image_original.save(output_path_original)
result_image_fused.save(output_path_fused)

visualization_end_time = time.time()
visualization_time = visualization_end_time - visualization_start_time
print(f"可视化时间: {visualization_time:.3f} 秒")
print(f"原始分割结果已保存到: {output_path_original}")
print(f"融合后分割结果已保存到: {output_path_fused}")

# Best effort: prefer a font that can render the Chinese titles below.
try:
    plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    plt.rcParams['axes.unicode_minus'] = False
except (KeyError, ValueError):
    # Matplotlib rejected the setting; keep its default fonts.
    pass

# Side-by-side comparison figure.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Raw detections.
ax1.imshow(result_image_original)
ax1.set_title(f"原始结果: 检测到 {len(masks)} 个分割结果", fontsize=14)
ax1.axis('off')

# Detections remaining after fusion.
ax2.imshow(result_image_fused)
ax2.set_title(f"融合后结果: 剩余 {len(fused_masks)} 个分割结果", fontsize=14)
ax2.axis('off')

plt.tight_layout()
plt.savefig("segmentation_comparison.png", bbox_inches='tight', dpi=300, facecolor='white')
# NOTE: on interactive backends plt.show() may block until the window is
# closed; that wait is included in total_time below.
plt.show()

total_end_time = time.time()
total_time = total_end_time - total_start_time

# Timing summary.
print("\n" + "="*50)
print("运行时间统计:")
print("="*50)
print(f"模型加载时间: {model_load_time:.3f} 秒")
print(f"检测单张时间: {detection_time:.3f} 秒")
print(f"融合处理时间: {fusion_time:.3f} 秒")
print(f"可视化时间: {visualization_time:.3f} 秒")
print("-"*50)
print(f"总运行时间: {total_time:.3f} 秒")
print("="*50)


def save_mask(masks_to_save, boxes_to_save, scores_to_save, prefix="mask", source_path=None):
    """Save one annotated image per mask.

    For every detection the source image is loaded fresh, then the mask is
    blended in, the box and an "ID / Score" label are drawn, and the result
    is written to ``{prefix}_with_info_{i:02d}.png`` in the working directory.

    Args:
        masks_to_save: Iterable of mask tensors; each is squeezed to 2-D.
        boxes_to_save: Iterable of box tensors ``[x1, y1, x2, y2]``.
        scores_to_save: Iterable of scores, formatted with ``:.3f``.
        prefix: Output filename prefix, also used in the log messages.
        source_path: Image to draw on. Defaults to the module-level
            ``image_path`` so existing four-argument calls keep working.
    """
    if source_path is None:
        source_path = image_path

    print(f"\n保存{prefix}的单个掩码...")

    # Font and colormap do not change between masks, so load them once.
    single_font = None
    for font_name in ("SimHei.ttf", "Arial.ttf"):
        try:
            single_font = ImageFont.truetype(font_name, 24)
            break
        except (OSError, ImportError):
            # Font file missing, or FreeType support unavailable.
            continue
    if single_font is None:
        single_font = ImageFont.load_default()

    # plt.get_cmap works on current Matplotlib; plt.cm.get_cmap was removed.
    cmap = plt.get_cmap("viridis", len(masks_to_save))

    # Read the source once; every mask is drawn on its own copy.
    with Image.open(source_path) as source_image:
        background = source_image.convert("RGB")

    for i, (mask, box, score) in enumerate(zip(masks_to_save, boxes_to_save, scores_to_save)):
        mask_np = mask.cpu().numpy().squeeze().astype(np.uint8)
        color = tuple(int(c * 255) for c in cmap(i)[:3])

        # Alpha is 128 (about 50% opacity) inside the mask, 0 elsewhere.
        alpha = Image.fromarray((mask_np * 128).astype(np.uint8))
        overlay = Image.new("RGBA", background.size, color + (128,))
        overlay.putalpha(alpha)
        base_image = Image.alpha_composite(background.convert("RGBA"), overlay).convert("RGB")
        single_draw = ImageDraw.Draw(base_image)

        # Bounding box and label.
        x1, y1, x2, y2 = box.cpu().numpy()
        single_draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"ID:{i} Score:{score:.3f}"
        try:
            left, top, right, bottom = single_draw.textbbox((0, 0), text, font=single_font)
            text_width = right - left
            text_height = bottom - top
        except (AttributeError, ValueError):
            # Older Pillow: no textbbox, or textbbox rejects this font type.
            text_width, text_height = single_draw.textsize(text, font=single_font)

        # The label sits just above the box, clamped to the top edge.
        text_x = x1
        text_y = max(0, y1 - text_height - 5)
        single_draw.rectangle(
            [text_x, text_y, text_x + text_width + 10, text_y + text_height + 5],
            fill=color,
        )
        single_draw.text((text_x + 5, text_y + 2), text, fill="white", font=single_font)

        base_image.save(f"{prefix}_with_info_{i:02d}.png")
        print(f"保存{prefix} {i:02d}.png (得分: {score:.3f})")


# Optional: export every original / fused mask as its own image.
# save_mask(masks, boxes, scores, "original_mask")
# save_mask(fused_masks, fused_boxes, fused_scores, "fused_mask")
print("所有处理完成!")