【深度学习总结_02】在自己的数据集微调SAM

【深度学习总结_02】在自己的数据集微调SAM

前言

SAM (Segment Anything Model)是Meta AI开发的一种分割模型。它被认为是计算机视觉的第一个基础模型。SAM是在包含数百万图像和数十亿mask的庞大数据语料库上进行训练的,这使得它非常强大。SAM能够为各种各样的图像生成准确的分割mask。

SAM通常在自然图像上表现优异,但是在特定领域,如医疗影响,遥感图像等,由于训练数据集缺乏这些数据,SAM的效果并不是理想。因此,在特定数据集上微调SAM是十分有必要的。

准备工作

(1)安装好segment anything:

git clone https://github.com/facebookresearch/segment-anything.git
cd segment-anything
python setup.py install

(2)安装lightning包,它是轻量级的PyTorch库,用于高性能人工智能研究的轻量级PyTorch包装器。本文基于它对SAM进行微调:

pip install lightning

使用的数据集下载地址:https://han-seg2023.grand-challenge.org/,它是一个多器官的医疗影像数据集,当然,你也可以使用自己的数据集

步骤

1、创建配置文件

该配置文件含有SAM的哪些部分需要训练,以及数据集的相关配置,如数据集位置,具体配置如下(在config.py文件中):

from box import Box
config = {"num_devices": 1,"batch_size": 6,"num_workers": 4,"num_epochs": 20,"save_interval": 2,"resume": None,"out_dir": "模型权重输出地址","opt": {"learning_rate": 8e-4,"weight_decay": 1e-4,"decay_factor": 10,"steps": [60000, 86666],"warmup_steps": 250,},"model": {"type": 'vit_b',"checkpoint": "SAM的权重地址","freeze": {"image_encoder": True,"prompt_encoder": True,"mask_decoder": True,},},"dataset": {"root_dir": "数据集的根目录","sample_num": 4,"target_size": 1024}
}
cfg = Box(config)

其中freeze部分决定SAM的哪些部分冷却不用训练,dataset则是数据集的相关配置,sample_num表示采样的point的数目,target_size则是输入SAM的图片大小。

这里使用了box这个包,可以通过如下命令安装:

pip install python-box

2、构建数据集

该部分负责在数据集加载的时候选择哪些数据进行训练,这里我选择器官mandible进行训练。

同时由于该数据是3D数据,对数据进行切片处理,将3D数据变成2D图像,该部分代码为:

class HaNDataset(Dataset):def __init__(self, cfg):super().__init__()self.gt_path = os.path.join(cfg.dataset.root_dir, "oar_3d")self.img_path = os.path.join(cfg.dataset.root_dir, "ct_3d")# 文件列表self.img_file_list = sorted(os.listdir(self.img_path))self.gt_file_list = sorted(os.listdir(self.gt_path))# 器官类别self.category = [7]self.cat2names = {7 : "mandible"}# 数据列表,含所有切片self.data_list = []for i in range(len(self.img_file_list)):img_file_path = os.path.join(self.img_path, self.img_file_list[i])gt_file_path = os.path.join(self.gt_path, self.gt_file_list[i])img_data = nib.load(img_file_path).get_fdata()gt_data = nib.load(gt_file_path).get_fdata()axial_num = img_data.shape[2]for a in range(axial_num):a_gt_data = gt_data[:, :, a]ps_gt_data = np.zeros_like(a_gt_data)for c in self.category:region = (a_gt_data == c)if np.sum(region) > 0:self.data_list.append([i, a, c])print(f"Data size is:{len(self.data_list)}")# 输入SAM的尺寸要是这个self.target_size = cfg.dataset.target_size# 正负样本点数目self.sample_point_num = cfg.dataset.sample_numdef __len__(self):return len(self.data_list)

由于HaN这个数据集的数据格式是nii文件,其数据的范围是0-2000,而图像的数据范围是0-255,因此需要将数据范围截断并重新映射。

输入SAM的图像大小应为1024*1024,因此需要将其resize成目标尺寸。

除此之外,由于HaN并没有提供box和point提示,因此还需要从mask中自动获得相应的提示。

这些部分的实现为(都在HaNDataset当中):

def convert_to_three_channels(self, image):# 创建一个具有相同尺寸的3通道图像数组three_channel_image 
= np.zeros((image.shape[
0
], image.shape[
1
], 
3
), dtype=np.uint8)# 将原始单通道图像复制到每个通道for i in range(3):three_channel_image[:, :, i] 
= imagereturn three_channel_image
def __getitem__(self, idx):data_id = self.data_list[idx]f_id = data_id[0]axial_id = data_id[1]category_id = data_id[2]name = self.cat2names[category_id]img_data_path = os.path.join(self.img_path, self.img_file_list[f_id])gt_data_path = os.path.join(self.gt_path, self.gt_file_list[f_id])# nii文件的数据范围是0-2000,和图像的范围不符img_data = nib.load(img_data_path).get_fdata()# 截断,对于ct图像img_data[img_data < (50 + 1024 - 200)] = (50 + 1024 - 200)img_data[img_data > (50 + 1024 + 200)] = (50 + 1024 + 200)img_data = (img_data - (50 + 1024 - 200)) / 400.0 * 255.0img_data = img_data[:, :, axial_id]img_data = self.convert_to_three_channels(img_data)all_gt_data = nib.load(gt_data_path).get_fdata()[:, :, axial_id]gt_data = np.zeros_like(all_gt_data)gt_data[all_gt_data == category_id] = 1# 将image和gt变为target sizeorg_size = gt_data.shapetransforms = train_transforms(self.target_size, org_size[0], org_size[1])augments = transforms(image=img_data, mask=gt_data)img_data, gt_data = augments['image'].to(torch.float32), augments['mask'].to(torch.int64)# 获得box,验证时max_pixel为0bbox_data = get_boxes_from_mask(gt_data, max_pixel=0)[0]# 获得point提示point_coords, point_labels = init_point_sampling(gt_data, self.sample_point_num)return {"org_size": torch.tensor(org_size),"category" : name,"image": img_data,"label" : gt_data,"bbox" : bbox_data,"point_coords": point_coords,"point_labels": point_labels}

获得box和point,以及resize图像的代码为:

def init_point_sampling(mask, get_point=1):if isinstance(mask, torch.Tensor):mask 
= mask.numpy()# Get coordinates of black/white pixelsfg_coords = np.argwhere(mask == 1)[:, ::-1]bg_coords = np.argwhere(mask == 0)[:, ::-1]fg_size = len(fg_coords)bg_size = len(bg_coords)if get_point == 1:if fg_size > 0:index = np.random.randint(fg_size)fg_coord = fg_coords[index]label = 1else:index = np.random.randint(bg_size)fg_coord = bg_coords[index]label = 0return torch.as_tensor([fg_coord.tolist()], dtype=torch.float), torch.as_tensor([label], dtype=torch.int)else:num_fg = get_point // 2num_bg = get_point - num_fgfg_indices = np.random.choice(fg_size, size=num_fg, replace=True)bg_indices = np.random.choice(bg_size, size=num_bg, replace=True)fg_coords = fg_coords[fg_indices]bg_coords = bg_coords[bg_indices]coords = np.concatenate([fg_coords, bg_coords], axis=0)labels = np.concatenate([np.ones(num_fg), np.zeros(num_bg)]).astype(int)indices = np.random.permutation(get_point)coords, labels = torch.as_tensor(coords[indices], dtype=torch.float), torch.as_tensor(labels[indices],dtype=torch.int)return coords, labels
def get_boxes_from_mask(mask, box_num=1, std=0.1, max_pixel=5):if isinstance(mask, torch.Tensor):mask = mask.numpy()label_img = label(mask)regions = regionprops(label_img)# Iterate through all regions and get the bounding box coordinatesboxes = [tuple(region.bbox) for region in regions]# If the generated number of boxes is greater than the number of categories,# sort them by region area and select the top n regionsif len(boxes) >= box_num:sorted_regions = sorted(regions, key=lambda x: x.area, reverse=True)[:box_num]boxes = [tuple(region.bbox) for region in sorted_regions]# If the generated number of boxes is less than the number of categories,# duplicate the existing boxeselif len(boxes) < box_num:num_duplicates = box_num - len(boxes)boxes += [boxes[i % len(boxes)] for i in range(num_duplicates)]# Perturb each bounding box with noisenoise_boxes = []for box in boxes:y0, x0, y1, x1 = boxwidth, height = abs(x1 - x0), abs(y1 - y0)# Calculate the standard deviation and maximum noise valuenoise_std = min(width, height) * stdmax_noise = min(max_pixel, int(noise_std * 5))# Add random noise to each coordinatetry:noise_x = np.random.randint(-max_noise, max_noise)except:noise_x = 0try:noise_y = np.random.randint(-max_noise, max_noise)except:noise_y = 0x0, y0 = x0 + noise_x, y0 + noise_yx1, y1 = x1 + noise_x, y1 + noise_ynoise_boxes.append((x0, y0, x1, y1))return torch.as_tensor(noise_boxes, dtype=torch.float)
def train_transforms(img_size, ori_h, ori_w):transforms = []transforms.append(A.Resize(int(img_size), int(img_size), interpolation=cv2.INTER_NEAREST))transforms.append(ToTensorV2(p=1.0))return A.Compose(transforms, p=1.)

3、构建SAM模型

因为我们已经安装好了segment anything,因此可以直接调用相关模块,然后组成一个生成mask的流程即可,该部分代码为:

import torch.nn as nn
import torch.nn.functional as F
from segment_anything import sam_model_registry
from segment_anything import SamPredictor
class Model(nn.Module):def __init__(self, cfg):super().__init__()self.cfg = cfgdef setup(self):self.model = sam_model_registry[self.cfg.model.type](checkpoint=self.cfg.model.checkpoint)self.model.train()if self.cfg.model.freeze.image_encoder:for name, param in self.model.image_encoder.named_parameters():param.requires_grad = Falseif self.cfg.model.freeze.prompt_encoder:for name, param in self.model.prompt_encoder.named_parameters():param.requires_grad = False# freeze mask decoder参数if self.cfg.model.freeze.mask_decoder:for name, param in self.model.mask_decoder.named_parameters():param.requires_grad = Falsedef forward(self, images, bboxes, org_size, point_coords = None, point_labels = None):_, _, H, W = images.shapeimage_embeddings = self.model.image_encoder(images)pred_masks = []ious = []# 还要添加points,输入格式(points coords, points label): #coords:B,N,2  labels:B,N# 一个batch一个batch处理for embedding, bbox, coord, label in zip(image_embeddings, bboxes, point_coords, point_labels):bbox = bbox.unsqueeze(0)coord = coord.unsqueeze(0)label = label.unsqueeze(0)point = (coord, label)sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=point,boxes=bbox,masks=None,)low_res_masks, iou_predictions = self.model.mask_decoder(image_embeddings=embedding.unsqueeze(0),image_pe=self.model.prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=False,)masks = F.interpolate(low_res_masks,(H, W),mode="bilinear",align_corners=False,)pred_masks.append(masks.squeeze(1))ious.append(iou_predictions)return pred_masks, iousdef get_predictor(self):return SamPredictor(self.model)

其中setup方法决定哪些参数需要进行训练,哪些不用。

4、使用数据进行训练

首先使用lightning进行配置:

import lightning as L
from config import cfg
fabric = L.Fabric(accelerator="auto",devices=cfg.num_devices,strategy="auto",loggers=[TensorBoardLogger(cfg.out_dir, name="lightning-sam")])
fabric.launch()
fabric.seed_everything(1337 + fabric.global_rank)

然后创建模型和加载数据集,代码为:

with fabric.device:model = Model(cfg)model.setup()
train_data = HaNDataset(cfg)
train_loader = DataLoader(train_data, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=True)
train_data = fabric._setup_dataloader(train_loader)

接着创建优化器,代码为:

def configure_opt(cfg: Box, model: Model):def lr_lambda(step):if step < cfg.opt.warmup_steps:return step / cfg.opt.warmup_stepselif step < cfg.opt.steps[0]:return 1.0elif step < cfg.opt.steps[1]:return 1 / cfg.opt.decay_factorelse:return 1 / (cfg.opt.decay_factor**2)optimizer 
= torch.optim.Adam(model.model.parameters(), lr=cfg.opt.learning_rate, weight_decay=cfg.opt.weight_decay)scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)return optimizer, scheduler
optimizer, scheduler = configure_opt(cfg, model)
model, optimizer = fabric.setup(model, optimizer)

最后遍历数据集进行训练,这里使用的损失函数有Focal loss,Dice loss和IoU loss,代码为:

def train_sam(cfg: Box,fabric: L.Fabric,model: Model,optimizer: _FabricOptimizer,scheduler: _FabricOptimizer,train_dataloader: DataLoader,
)
:"""The SAM training loop."""focal_loss 
= FocalLoss()dice_loss = DiceLoss()# 从上次中断的地方训练start_epoch = 1if cfg.resume:map_location = 'cuda:%d' % fabric.global_rankcheckpoint = torch.load(cfg.resume, map_location={'cuda:0': map_location})start_epoch = checkpoint['epoch']network = checkpoint['network']opt = checkpoint['optimizer']sche = checkpoint['scheduler']model.model.load_state_dict(network)optimizer.load_state_dict(opt)scheduler.load_state_dict(sche)fabric.print(f"resume from:{cfg.resume}")for epoch in range(start_epoch, cfg.num_epochs):batch_time = AverageMeter(name="batch_time")data_time = AverageMeter(name="data_time")focal_losses = AverageMeter(name="focal_losses")dice_losses = AverageMeter(name="dice_losses")iou_losses = AverageMeter(name="iou_losses")total_losses = AverageMeter(name="total_losses")end = time.time()# 保存模型if epoch % cfg.save_interval == 0:fabric.print(f"Saving checkpoint to {cfg.out_dir}")state_dict = model.model.state_dict()checkpoint = {'epoch': epoch,'network': state_dict,'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict()}# 多卡环境下只在rank=0的gpu上保存if fabric.global_rank == 0:torch.save(checkpoint, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-ckpt.pth"))for iter, data in enumerate(train_dataloader):data_time.update(time.time() - end)images = data["image"]gt_masks = data["label"]bboxes = data["bbox"]batch_size = images.shape[0]pred_masks, iou_predictions = model(images, bboxes, data["point_coords"], data["point_labels"])num_masks = sum(len(pred_mask) for pred_mask in pred_masks)loss_focal = torch.tensor(0., device=fabric.device)loss_dice = torch.tensor(0., device=fabric.device)loss_iou = torch.tensor(0., device=fabric.device)for pred_mask, gt_mask, iou_prediction in zip(pred_masks, gt_masks, iou_predictions):batch_iou = calc_iou(pred_mask, gt_mask)loss_focal += focal_loss(pred_mask, gt_mask, num_masks)loss_dice += dice_loss(pred_mask, gt_mask, num_masks)loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masksloss_total = 20. * loss_focal + loss_dice + loss_iouoptimizer.zero_grad()fabric.backward(loss_total)optimizer.step()scheduler.step()batch_time.update(time.time() - end)end = time.time()focal_losses.update(loss_focal.item(), batch_size)dice_losses.update(loss_dice.item(), batch_size)iou_losses.update(loss_iou.item(), batch_size)total_losses.update(loss_total.item(), batch_size)fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'f' | IoU Loss [{iou_losses.val:.4f} ({iou_losses.avg:.4f})]'f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')

通过以上步骤就可以对SAM进行微调了,如果是对mask decoder进行微调,显存占用大概在17G左右。

参考链接

lightning-sam

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/827148.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

超越OpenAI,谷歌重磅发布从大模型蒸馏的编码器Gecko

引言&#xff1a;介绍文本嵌入模型的重要性和挑战 文本嵌入模型在自然语言处理&#xff08;NLP&#xff09;领域扮演着至关重要的角色。它们将文本转换为密集的向量表示&#xff0c;使得语义相似的文本在嵌入空间中彼此靠近。这些嵌入被广泛应用于各种下游任务&#xff0c;包括…

VideoComposer: Compositional Video Synthesis with Motion Controllability

decompose videos into three distinct types of conditions: textual conditions, spatial conditions, temperal conditions 条件的内容&#xff1a; a. textual condition: coarse grained visual content and motions, 使用openclip vit-H/14的text encoder b. spatial co…

Splashtop 将在 NAB 展会上推出音视频剪辑增强功能

加利福尼亚州拉斯维加斯 Splashtop 在简化随处办公远程解决方案领域处于领先地位&#xff0c;在今年举行的 NAB 展会上将推出 Enterprise 解决方案的高级性能功能&#xff0c;均面向广播和媒体工作者而设计。 Splashtop Enterprise 经过优化&#xff0c;可为执行视频剪辑、唇…

Excel文件解析--超大Excel文件读写

使用POI写入 当我们想在Excel文件中写入100w条数据时&#xff0c;我们用普通的XSSFWorkbook对象写入时会发现&#xff0c;只有在将100w条数据全部加载入内存后才会用write()方法统一写入&#xff0c;这样效率很低&#xff0c;所以我们引入了SXSSFWorkbook进行超大Excel文件的读…

java开发之路——node.js安装

1. 安装node.js 最新Node.js安装详细教程及node.js配置 (1)默认的全局的安装路径和缓存路径 npm安装模块或库(可以统称为包)常用的两种命令形式&#xff1a; 本地安装(local)&#xff1a;npm install 名称全局安装(global)&#xff1a;npm install 名称 -g本地安装和全局安装…

【Leetcode】string类刷题

&#x1f525;个人主页&#xff1a;Quitecoder &#x1f525;专栏&#xff1a;Leetcode刷题 目录 1.仅反转字母2.字符串中第一个唯一字符3.验证回文串4.字符串相加5.反转字符串I I6.反转字符串中的单词III7.字符串相乘8.把字符串转换为整数 1.仅反转字母 题目链接&#xff1a;…

一篇文章带您了解面向对象(java)

1.简单理解面向过程编程和面向对象编程 面向过程编程&#xff1a;开发一个一个的方法&#xff0c;有数据需要处理&#xff0c;我们就可以调用方法来处理。 package com.web.quictstart;public class demo1 {public static void main(String[] args) {totalScore("张三&q…

mac上VMware fusion net模式无法正常使用的问题

更新时间&#xff1a;2024年04月22日21:39:04 1. 问题 环境&#xff1a; intel芯片的macbook pro VMware fusion 13.5.1 无法将“Ethernet0”连接到虚拟网络“/dev/vmnet8”。在这里显示这个之后&#xff0c;应该是vmnet8的网段发生了冲突&#xff0c;所以导致无法正常使用…

前端开发攻略---拖动归类,将元素拖拽到相应位置

1、演示 2、代码 <!DOCTYPE html><html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name"viewport" content"widthdevice-…

2024年Q1季度平板电视行业线上市场销售数据分析

Q1季度平板电视线上市场表现不如预期。 根据鲸参谋数据显示&#xff0c;2024年1月至3月线上电商平台&#xff08;京东天猫淘宝&#xff09;平板电视累计销量约360万件&#xff0c;环比下降12%&#xff0c;同比下降30%&#xff1b;累计销售额约99亿元&#xff0c;环比下降28%&a…

学习STM32第十七天

备份域详解 一、简介 在参考手册的电源控制章节&#xff0c;提到了备份域&#xff0c;BKPR是在RTC外设中用到&#xff0c;包含20个备份数据寄存器&#xff08;80字节&#xff09;&#xff0c;备份域包括4KB的备份SRAM&#xff0c;以32位、16位或8位模式寻址&#xff0c;在VBAT…

C++初阶学习第二弹——C++入门(下)

C入门&#xff08;上&#xff09;&#xff1a;C初阶学习第一弹——C入门&#xff08;上&#xff09;-CSDN博客 目录 一、引用 1.1 引用的实质 1.2 引用的用法 二、函数重载 三、内敛函数 四、auto关键字 五、总结 前言&#xff1a; 在上面一章我们已经讲解了C的一些基本…

Vue2进阶之Vue2高级用法

Vue2高级用法 mixin示例一示例二 plugin插件自定义指令vue-element-admin slot插槽filter过滤器 mixin 示例一 App.vue <template><div id"app"></div> </template><script> const mixin2{created(){console.log("mixin creat…

【Java网络编程】TCP通信(Socket 与 ServerSocket)和UDP通信的三种数据传输方式

目录 1、TCP通信 1.1、Socket 和 ServerSocket 1.3、TCP通信示例 2、UDP的三种通信&#xff08;数据传输&#xff09;方式 1、TCP通信 TCP通信协议是一种可靠的网络协议&#xff0c;它在通信的两端各建立一个Socket对象 通信之前要保证连接已经建立&#xff08;注意TCP是一…

【Interconnection Networks 互连网络】Torus 网络拓扑

1. Torus 网络拓扑2. Torus 网络拓扑结构References 1. Torus 网络拓扑 Torus 和 Mesh 网络拓扑&#xff0c;又可以称为 k-ary n-cubes&#xff0c;在规则的 n 维网格中包裹着 N k^n 个节点&#xff0c;每个维度都有 k 个节点&#xff0c;并且最近邻居之间有通道。k-ary n-c…

YOLOv9有效改进专栏汇总|未来更新卷积、主干、检测头注意力机制、特征融合方式等创新![2024/4/21]

​ 专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;助力高效涨点&#xff01;&#xff01;&#xff01; 专栏介绍 YOLOv9作为最新的YOLO系列模型&#xff0c;对于做目标检测的同学是必不可少的。本专栏将针对2024年最新推出的YOLOv9检测模型&#xff0…

《HCIP-openEuler实验指导手册》1.3Apache动态功能模块加载卸载练习

1.3.1 配置思路 mod_status 模块可以帮助管理员通过web界面监控Apache运行状态&#xff0c;通过LoadModule指令加载该模块&#xff0c;再配置相关权限&#xff0c;并开启ExtendedStatus后&#xff0c;即可使用该模块。 1.3.2 配置步骤 检查mod_status模块状态&#xff08;使…

net模块

建立TCP的链接 1 发送消息的服务 2 接收消息 2 建立http的链接让浏览器进行访问 import net from netconst html <h1>TCP</h1>const respinseHeaders [HTTP/1.1 200 OK,Content-Type:text/html,Content-Length: html.length,\r\n,html]const http net.create…

RK3568 学习笔记 : u-boot 通过 tftp 网络更新 u-boot自身

前言 开发板型号&#xff1a; 【正点原子】 的 RK3568 开发板 AtomPi-CA1 使用 虚拟机 ubuntu 20.04 收到单独 编译 RK3568 u-boot 使用 rockchip Linux 内核的设备树 【替换】 u-boot 下的 rk3568 开发板设备树文件&#xff0c;解决 u-boot 下千兆网卡设备能识别但是无法 Pi…

Spring(下)

接上篇&#xff0c;从第八个问题讲起 八.Spring工厂创建复杂对象 1.什么是复杂对象 简单对象就是可以直接new出来的&#xff0c;也就是直接调用构造方法创建 所以复杂对象就是不能直接通过调用构造方法创建。就比如JDBC中的Connection 2.三种方法 &#xff08;1&#xff…