问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

SAM+YOLOv8实现图像批量分割提取分割掩码

创作时间:
作者:
@小白创作中心

SAM+YOLOv8实现图像批量分割提取分割掩码

引用
CSDN
1.
https://blog.csdn.net/zzzyyy8/article/details/143220175

Meta的FAIR实验室发布的Segment Anything Model (SAM)是一种先进的图像分割模型,它通过提示工程来适应不同的下游分割问题。本文将介绍如何使用SAM和YOLOv8实现图像的批量分割提取分割掩码。

0. 概述

Meta的FAIR实验室发布的Segment Anything Model (SAM),是一种最先进的图像分割模型,旨在改变计算机视觉领域。

SAM 基于对自然语言处理 (NLP) 产生重大影响的基础模型(Foundation Model)。它专注于提示分割任务 (promptable segmentation tasks),使用提示工程来适应不同的下游分割问题。

模型

SAM 的架构包含三个组件,它们协同工作以返回有效的分割掩膜

  • 一种图像编码器,用于生成一次性图像嵌入。
  • 嵌入提示的提示编码器。
  • 一个轻量级掩膜解码器,结合了提示和图像编码器的嵌入。

SAM结构

图像编码器

在最高级别,图像编码器 (掩膜自编码器,MAE,预训练视觉Transformer,ViT)生成一次性图像嵌入,并且可以在提示模型之前应用。

提示编码器

提示编码器将背景点、掩膜、边界框或文本实时编码为嵌入向量。该研究考虑了两组提示:稀疏(点、框、文本)和密集(掩膜)。

点和框由位置编码表示,并为每种提示类型添加学习的嵌入。自由格式的文本提示由CLIP的现成文本编码器表示。密集提示(如掩膜)通过卷积嵌入,并通过图像嵌入按元素求和。

掩膜解码器

轻量级掩膜解码器根据图像和提示编码器的嵌入来预测分割掩膜。它将图像嵌入、提示嵌入和输出token映射到掩码。所有嵌入均由解码器块更新,解码器块在两个方向(从提示到图像嵌入再返回)使用即时自注意力和交叉注意力。

掩膜带有标注并用于更新模型权重。这种布局增强了数据集,并允许模型随着时间的推移进行学习和改进,使其高效且灵活。

本文主要根据xml标签信息作为提示,将目标检测框框选的内容进行实例分割,提取分割掩码。

SAM可以基于全局来分割,根据提示点来分割,也可以根据边界框来分割,这篇文章主要实现用YOLOv8推理的边界框作为提示,用sam进行目标的分割。如下图所示。

1. 安装环境

安装ultralytics、cv库、pustil。

pip install -U ultralytics    
pip install opencv-python
pip install psutil

2. 代码实现

实现将边界框转换为分割掩模。此代码实现图像批量分割,依次读取要分割的图像信息以及xml标签信息,将所有的目标对象都分割出来并保存在文件夹中。这个脚本写在yolov8项目里新建py文件里面即可。在执行代码时候会自动下载sam的权重文件,这里推荐下载比较大的,个人测试,sam_l.pt推理出来比轻量化的那些好点。下载估计需要1.6G。

其中,image_folder保存原始的图片。

xml_folder保存原始图片对应的xml标签信息,也就是你的提示框信息。

其余的output_folder,mask_output_folder,cropped_output_folder均是代码输出结果的保存路径。

import os
import xml.etree.ElementTree as ET
from ultralytics import SAM
from PIL import Image
import numpy as np

# 定义输入输出路径
image_folder = r'D:\testdata\image'  # 输入图片文件夹
xml_folder = r'D:\testdata\annotations'  # 对应的XML文件夹,根据该提示框进行分割
output_folder = r'D:\testdata\output\samimage'  # 推理结果输出文件夹
mask_output_folder = r'D:\testdata\output\masks'  # 掩码保存文件夹
cropped_output_folder = r'D:\testdata\output\cropped'  # 提取区域保存文件夹

# 如果输出文件夹不存在,创建文件夹
os.makedirs(output_folder, exist_ok=True)
os.makedirs(mask_output_folder, exist_ok=True)
os.makedirs(cropped_output_folder, exist_ok=True)

# 加载SAM模型
model = SAM("sam_l.pt")

# 从XML文件中提取bboxes
def get_bboxes_from_xml(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    bboxes = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)
        bboxes.append([xmin, ymin, xmax, ymax])
    return bboxes

# 对每张图片进行推理
for image_name in os.listdir(image_folder):
    if image_name.endswith('.jpg'):  # 确保是图片文件
        image_path = os.path.join(image_folder, image_name)
        xml_name = image_name.replace('.jpg', '.xml')
        xml_path = os.path.join(xml_folder, xml_name)
        if os.path.exists(xml_path):
            # 获取图片对应的bboxes
            bboxes = get_bboxes_from_xml(xml_path)
            # 进行推理
            results = model(image_path, bboxes=bboxes)
            # 加载原始图片
            img = Image.open(image_path)
            # 保存推理结果的图片(带标注)
            result_image_path = os.path.join(output_folder, f"result_{image_name}")
            result_image = results[0].plot()  # 获取结果
            result_image_pil = Image.fromarray(result_image)  # 将结果转换为 PIL 格式
            result_image_pil.save(result_image_path)  # 保存图片
            print(f"Processed and saved inference result: {result_image_path}")
            # 遍历推理结果,处理所有的掩码
            for i, mask_data in enumerate(results[0].masks.data):
                mask = mask_data.cpu().numpy()  # 提取每个掩码
                mask_output_path = os.path.join(mask_output_folder,
                                                f"{image_name.replace('.jpg', f'_mask_{i}.png')}")
                # 保存mask
                Image.fromarray((mask * 255).astype(np.uint8)).save(mask_output_path)
                # 提取mask对应的图像区域
                mask_array = np.array(mask)
                cropped_image = Image.fromarray(np.array(img) * mask_array[:, :, None])  # 将掩码区域应用到原图
                # 保存提取的区域图像
                cropped_image_output_path = os.path.join(cropped_output_folder,
                                                         f"{image_name.replace('.jpg', f'_cropped_{i}.png')}")
                cropped_image.save(cropped_image_output_path)
                print(f"Processed and saved mask: {mask_output_path}")
                print(f"Processed and saved cropped region: {cropped_image_output_path}")
        else:
            print(f"XML file not found for image: {image_name}")
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号