SAM 2图像分割模型微调指南:仅需60行代码

2024年08月14日 由 alex 发表 867 0

SAM2(Segment Anything 2)是 Meta 公司推出的一个新模型,旨在对图像中的任何内容进行分割,而不局限于特定的类别或领域。该模型的独特之处在于其训练数据的规模: 1100 万张图像和 110 亿个掩码。这种广泛的训练使 SAM2 成为训练新图像分割任务的强大起点。


你可能会问,既然 SAM 可以分割任何图像,为什么我们还要重新训练它呢?答案是,SAM 擅长常见对象,但在罕见或特定领域的任务上表现不佳。

不过,即使在 SAM 的结果不够理想的情况下,也可以通过在新数据上对其进行微调来显著提高模型的能力。在很多情况下,这比从头开始训练一个模型所需的训练数据更少,效果更好。


本文演示了如何在新数据上对 SAM2 进行微调,只需 60 行代码(不包括注释和导入)。


11


Segment Anything 的工作原理

SAM 的主要工作方式是获取图像和图像中的一个点,然后预测包含该点的段的掩码。这种方法无需人工干预即可实现完整的图像分割,而且对分割的类别或类型没有限制。


使用 SAM 进行完整图像分割的步骤:

  1. 在图像中选择一组点
  2. 使用 SAM 预测包含每个点的片段
  3. 将得到的片段合并为一张地图


虽然 SAM 也可以利用遮罩或边界框等其他输入,但这些主要与涉及人工输入的交互式分割相关。在本教程中,我们将专注于全自动分割,并只考虑单点输入。


下载 SAM2 并设置环境


请按照安装说明进行操作。


一般来说,你需要 Python >=3.11 和 PyTorch。


此外,我们还将使用 OpenCV,可以使用以下方法安装:


pip install opencv-python


下载预训练模型

你还需要从以下网址下载预训练模型:


https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#download-checkpoints


有多种模型可供选择,均与本教程兼容。我建议使用小型模型,它的训练速度最快。


下载训练数据

下一步是下载用于微调模型的数据集。在本教程中,我们将使用 LabPics1 数据集来分割材料和液体。你可以从以下网址下载数据集:


https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1


准备数据阅读器

我们首先需要编写的是数据阅读器。它将为网络读取和准备数据。


数据阅读器需要生成

  1. 图像
  2. 图像中所有片段的掩码。
  3. 每个遮罩内的随机点


让我们从加载依赖项开始:


import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


接下来,我们列出数据集中的所有图像:


data_dir=r"LabPicsV1//" # Path to LabPics1 dataset folderr"LabPicsV1//" # Path to LabPics1 dataset folder
data=[] # list of files in dataset
for ff, name in enumerate(os.listdir(data_dir+"Simple/Train/Image/")):  # go over all folder annotation
    data.append({"image":data_dir+"Simple/Train/Image/"+name,"annotation":data_dir+"Simple/Train/Instance/"+name[:-4]+".png"})


现在是加载训练批次的主函数。训练批次包括 一幅随机图像、属于这幅图像的所有分割蒙版,以及每个蒙版中的一个随机点:


def read_batch(data): # read random image and its annotaion from  the dataset (LabPics)
   #  select image
        ent  = data[np.random.randint(len(data))] # choose random entry
        Img = cv2.imread(ent["image"])[...,::-1]  # read image
        ann_map = cv2.imread(ent["annotation"]) # read annotation
   # resize image
        r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scalling factor
        Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
        ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST)
   # merge vessels and materials annotations
        mat_map = ann_map[:,:,0] # material annotation map
        ves_map = ann_map[:,:,2] # vessel  annotaion map
        mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map
   # Get binary masks and points
        inds = np.unique(mat_map)[1:] # load all indices
        points= []
        masks = [] 
        for ind in inds:
            mask=(mat_map == ind).astype(np.uint8) # make binary mask
            masks.append(mask)
            coords = np.argwhere(mask > 0) # get all coordinates in mask
            yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate
            points.append([[yx[1], yx[0]]])
        return Img,np.array(masks),np.array(points), np.ones([len(masks),1])


该功能的第一部分是随机选择一张图片并加载它:


ent  = data[np.random.randint(len(data))] # choose random entrylen(data))] # choose random entry
Img = cv2.imread(ent["image"])[...,::-1]  # read image
ann_map = cv2.imread(ent["annotation"]) # read annotation
Note that OpenCV reads images as BGR while SAM expects images as RGB, using […,::-1] to change the image from BGR to RGB.


请注意,OpenCV 以 BGR 格式读取图像,而 SAM 希望读取 RGB 图像。通过使用 [...,::-1],我们将图像从 BGR 变为 RGB。


SAM 希望图像大小不超过 1024,因此我们将调整图像和注释映射的大小。


r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scalling factormin([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scalling factor
Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST)


重要的一点是,在调整标注图 (ann_map) 的大小时,我们使用 INTER_NEAREST 模式(近邻)。在标注图中,每个像素值都是其所属图段的索引。因此,使用不会给地图带来新值的调整大小方法非常重要。


下一个区块是 LabPics1 数据集的特定格式。注释图 (ann_map) 包含一个通道的图像血管分割图和另一个通道的材料注释图。我们将把它们合并为一张地图。


  mat_map = ann_map[:,:,0] # material annotation map0] # material annotation map
  ves_map = ann_map[:,:,2] # vessel  annotaion map
  mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map


这样我们就得到了一个映射图 (mat_map),其中每个像素的值都是它所属片段的索引(例如:所有值为 3 的单元格都属于片段 3)。我们希望将其转换为一组二进制掩码(0/1),其中每个掩码对应一个不同的片段。此外,我们希望从每个掩码中提取一个点。


inds = np.unique(mat_map)[1:] # list of all indices in map1:] # list of all indices in map
points= [] # list of all points (one for each mask)
masks = [] # list of all masks
for ind in inds:
            mask = (mat_map == ind).astype(np.uint8) # make binary mask for index ind
            masks.append(mask)
            coords = np.argwhere(mask > 0) # get all coordinates in mask
            yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate
            points.append([[yx[1], yx[0]]])
return Img,np.array(masks),np.array(points), np.ones([len(masks),1])


就是这样!我们得到了图像 (Img)、与图像中的片段相对应的二进制掩码列表 (掩码),以及每个掩码中一个点的坐标 (点)。

12


加载 SAM 模型

现在让我们加载网络:


sam2_checkpoint = "sam2_hiera_small.pt" # path to model weight"sam2_hiera_small.pt" # path to model weight
model_cfg = "sam2_hiera_s.yaml" # model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # load model
predictor = SAM2ImagePredictor(sam2_model) # load net


任何分段 一般结构

在设置训练参数之前,我们需要了解 SAM 模型的基本结构。

SAM 由三部分组成:

1) 图像编码器;2) 提示编码器;3) 掩码解码器。

图像编码器负责处理图像并创建代表图像的嵌入。这部分由一个 VIT 变换器组成,是网络中最大的组成部分。我们通常不想对它进行训练,因为它已经给出了很好的表示,而训练将需要大量资源。

提示编码器处理网络的额外输入,在我们的例子中就是输入点。

掩码解码器接收图像编码器和提示编码器的输出,生成最终的分割掩码。一般来说,我们只想训练掩码解码器,或许还想训练提示编码器。这些部分都是轻量级的,可以通过适度的 GPU 进行快速微调。


设置训练参数:

我们可以通过设置来启用掩码解码器和提示编码器的训练:


predictor.model.sam_mask_decoder.train(True) # enable training of mask decoder True) # enable training of mask decoder 
predictor.model.sam_prompt_encoder.train(True) # enable training of prompt encoder


接下来,我们定义标准的 adamW 优化器:


optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)1e-5,weight_decay=4e-5)


我们还将使用混合精度训练,这是一种更节省内存的训练策略:


scaler = torch.cuda.amp.GradScaler() # set mixed precision# set mixed precision


主训练循环

现在让我们构建主训练循环。第一部分是读取和准备数据:


for itr in range(100000):
    with torch.cuda.amp.autocast(): # cast to mix precision
            image,mask,input_point, input_label = read_batch(data) # load data batch
            if mask.shape[0]==0: continue # ignore empty batches
            predictor.set_image(image) # apply SAM image encoder to the image


首先,我们对数据进行混合精度处理,以提高训练效率:


with torch.cuda.amp.autocast():


接下来,我们使用之前创建的阅读器函数来读取训练数据:


image,mask,input_point, input_label = read_batch(data)


我们将加载的图像通过图像编码器(网络的第一部分):


predictor.set_image(image)


接下来,我们使用网络提示编码器处理输入点:


  mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)None, mask_logits=None, normalize_coords=True)
  sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels),boxes=None,masks=None,)


请注意,在这一部分我们还可以输入方框或掩码,但我们不会使用这些选项。


现在,我们已经对提示(点)和图像进行了编码,最后就可以预测分割掩码了:


batched_mode = unnorm_coords.shape[0] > 1 # multi mask prediction0] > 1 # multi mask prediction
high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=True,repeat_image=batched_mode,high_res_features=high_res_features,)
prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])# Upscale the masks to the original image resolution


该代码的主要部分是 model.sam_mask_decoder,它运行网络的掩码解码器部分,并生成分割掩码(low_res_masks)及其分数(prd_scores)。


这些掩码的分辨率低于原始输入图像,并在后处理掩码函数中被调整为原始输入图像的大小。


这样我们就能得到网络的最终预测结果: prd_masks包含每个输入点的 3 个预测掩码,但我们只使用每个点的第一个掩码。prd_scores 包含网络认为每个掩码有多好的分数(或它对预测有多确定)。


损失函数


分割损失

有了网络预测结果,我们就可以计算损失了。首先,我们计算分割损失,即预测的掩码与地面真实掩码相比有多好。为此,我们使用标准的交叉熵损失。


首先,我们需要使用 sigmoid 函数将预测掩码 (prd_mask) 从对数转换为概率:


prd_mask = torch.sigmoid(prd_masks[:, 0])# Turn logit map to probability map0])# Turn logit map to probability map


接下来,我们将地面实况掩码转换为torch张量:


prd_mask = torch.sigmoid(prd_masks[:, 0])# Turn logit map to probability map0])# Turn logit map to probability map


最后,我们使用地面实况(gt_mask)和预测概率图(prd_mask)手动计算交叉熵损失(seg_loss):


seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean() # cross entropy loss 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean() # cross entropy loss 


(我们添加 0.0001 是为了防止对数函数因数值为零而爆炸)。


分数损失(可选)

除了掩码之外,网络还会预测每个预测掩码的得分。这部分的训练不那么重要,但也很有用。要训练这一部分,我们首先需要知道每个预测掩码的真实得分。也就是说,预测的掩码实际有多好。我们将使用交集大于联合(IOU)指标来比较 GT 掩膜和相应的预测掩膜。IOU 简单来说就是两个掩码的重叠部分除以两个掩码的合并面积。首先,我们计算预测掩膜和 GT 掩膜之间的交集(两者重叠的区域):


inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)0.5)).sum(1).sum(1)


我们使用阈值(prd_mask > 0.5)将预测掩码从概率掩码转换为二进制掩码。


接下来,我们用交集除以预测掩码和 gt 掩码的合并区域(联合),得到 IOU:


iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)


我们将把 IOU 作为每个面具的真实得分,并把预测得分与我们刚刚计算出的 IOU 之间的绝对差值作为得分损失。


score_loss = torch.abs(prd_scores[:, 0] - iou).mean()abs(prd_scores[:, 0] - iou).mean()


最后,我们合并了分割损失和分数损失(前者的权重更高):


loss = seg_loss+score_loss*0.05  # mix losses0.05  # mix losses


最后一步:反向传播和保存模型

一旦我们得到了损失,接下来的步骤都是完全标准的。我们使用之前创建的优化器来计算反向传播,并更新权重:


predictor.model.zero_grad() # empty gradient# empty gradient
scaler.scale(loss).backward()  # Backpropogate
scaler.step(optimizer)
scaler.update() # Mix precision


我们还希望在每1000个步骤后保存训练好的模型:


if itr00==0: torch.save(predictor.model.state_dict(), "model.torch") # save model 


由于我们已经计算了IOU,我们可以将其显示为移动平均值,以便看到模型预测随着时间的推移的改进情况:


if itr==0: mean_iou=0
mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
print("step)",itr, "Accuracy(IOU)=",mean_iou)


现在,我们已经在不到60行的代码中完成了Segment-Anything 2的训练/微调(不包括注释和导入)。大约经过25000个步骤,你应该会看到主要的改进。


该模型将保存为“model.torch”。


推理:加载和使用经过训练的模型:

现在我们已经对模型进行了微调,让我们使用它来分割一张图片。


我们将按照以下步骤进行操作:

  1. 加载我们刚刚训练的模型。
  2. 为模型提供一张图像和一系列随机点。对于每个点,模型将预测包含该点和一个分数的段落掩模。
  3. 将这些掩模拼接在一起,形成一个分割图。


首先,我们加载依赖项并将权重转换为float16,这样可以使模型运行速度更快(仅适用于推理):


# use bfloat16 for the entire script (memory efficient)
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()


接下来,我们加载一张示例图像和我们要分割的图像区域的掩模(下载图像/掩模):


image_path = r"sample_image.jpg" # path to imager"sample_image.jpg" # path to image
mask_path = r"sample_mask.png" # path to mask, the mask will define the image region to segment
def read_image(image_path, mask_path): # read and resize image and mask
        img = cv2.imread(image_path)[...,::-1]  # read image as rgb
        mask = cv2.imread(mask_path,0) # mask of the region we want to segment
        
        # Resize image to maximum size of 1024
        r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
        img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
        mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)),interpolation=cv2.INTER_NEAREST)
        return img, mask
image,mask = read_image(image_path, mask_path)


在我们要分割的区域内,随机选取30个样本点:


num_samples = 30 # number of points/segment to sample30 # number of points/segment to sample
def get_points(mask,num_points): # Sample points inside the input mask
        points=[]
        for i in range(num_points):
            coords = np.argwhere(mask > 0)
            yx = np.array(coords[np.random.randint(len(coords))])
            points.append([[yx[1], yx[0]]])
        return np.array(points)
input_points = get_points(mask,num_samples)


首先,加载标准的SAM模型(与训练中的模型相同):


# Load model you need to have pretrained model already made
sam2_checkpoint = "sam2_hiera_small.pt" 
model_cfg = "sam2_hiera_s.yaml" 
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)


接下来,加载我们刚刚训练的模型的权重(model.torch):


predictor.model.load_state_dict(torch.load("model.torch"))"model.torch"))


运行经过微调的模型,预测我们选择的每个点的掩模:


with torch.no_grad(): # prevent the net from caclulate gradient (more efficient inference)
        predictor.set_image(image) # image encoder
        masks, scores, logits = predictor.predict(  # prompt encoder + mask decoder
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0],1])
        )


现在我们有一系列预测的掩模及其得分。我们希望将它们拼接成一个一致的分割图。然而,许多掩模可能重叠,并且可能彼此不一致。由于我们是随机选择的点,所以很可能某些点会落在同一个段落内。


拼接的方法很简单,我们将根据它们的预测得分对预测的掩模进行排序:


np_masks = np.array(masks[:,0].cpu().numpy()) # convert from torch to numpy0].cpu().numpy()) # convert from torch to numpy
np_scores = scores[:,0].float().cpu().numpy() # convert from torch to numpy
shorted_masks = np_masks[np.argsort(np_scores)][::-1] # arrange mask according to score


现在让我们创建一个空的分割图和一个占用图:


seg_map = np.zeros_like(shorted_masks[0],dtype=np.uint8)0],dtype=np.uint8)
occupancy_mask = np.zeros_like(shorted_masks[0],dtype=bool)


接下来,我们逐个将掩模(从高分到低分)添加到分割图中。只有在要添加的掩模与先前添加的掩模保持一致时,我们才会添加该掩模。这意味着只有当我们要添加的掩模与已占用区域的重叠度小于15%时,我们才会添加它。


for i in range(shorted_masks.shape[0]):
    mask = shorted_masks[i]
    if (mask*occupancy_mask).sum()/mask.sum()>0.15: continue 
    mask[occupancy_mask]=0
    seg_map[mask]=i+1
    occupancy_mask[mask]=1


就是这样了。


seg_mask现在包含了预测的分割图,每个段落都有不同的值,背景为0。


我们可以使用以下代码将其转换成彩色图:


rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)0], seg_map.shape[1], 3), dtype=np.uint8)
for id_class in range(1,seg_map.max()+1):
    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]


并显示:


cv2.imshow("annotation",rgb_image)"annotation",rgb_image)
cv2.imshow("mix",(rgb_image/2+image/2).astype(np.uint8))
cv2.imshow("image",image)
cv2.waitKey()


13

文章来源:https://medium.com/@sagieppel/train-fine-tune-segment-anything-2-sam-2-in-60-lines-of-code-928dd29a63b3
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消