深度学习:大模型微调之SAM模型微调实战
深度学习:大模型微调之SAM模型微调实战
随着Meta发布的Segment Anything Model(SAM)在计算机视觉领域掀起波澜,本文将深入探讨如何对这个强大的分割模型进行微调,以适应特定应用场景。通过详细的技术解析和代码示例,读者将掌握从数据准备到模型训练的完整流程,从而让SAM在特定任务中发挥更大效能。
什么是 Segment Anything 模型 (SAM)?
Segment Anything 模型(SAM)是由Meta AI开发的分割模型,被认为是计算机视觉领域的首个基础模型。它在包含数百万张图像和数十亿个掩码的庞大数据语料库上进行了训练,展现出强大的分割能力。SAM能够为各种图像生成准确的分割掩码,并支持多模式提示,包括点、边界框和文本提示。
该模型由三个主要组件构成:图像编码器、提示编码器和掩码解码器。图像编码器负责生成图像的嵌入表示,提示编码器则处理用户提供的提示信息,而轻量级的掩码解码器则根据这些嵌入预测最终的分割掩码。
什么是模型微调?
模型微调是指在预训练模型的基础上,通过向其展示特定用例数据来优化模型性能的过程。与从头开始训练相比,微调利用了预训练模型已经学习到的特征,只需要调整权重和偏差以适应特定任务,从而节省大量计算资源。
为什么要微调模型?
微调模型的主要目的是在预训练模型未充分覆盖的特定场景中提升性能。例如,一个在水平视角图像上训练的模型可能无法很好地处理垂直视角的卫星图像。通过微调,我们可以利用预训练模型已经掌握的基础分割能力,进一步优化其在特定任务上的表现。
如何微调分段任何模型[使用代码]
背景与架构
为了微调SAM,我们需要关注其轻量级的掩码解码器部分。由于掩码解码器相对较小,因此更容易、更快地进行微调,同时占用较少的内存资源。我们不能直接使用SamPredictor.predict
函数,因为它会阻止梯度计算。因此,我们需要手动调用相关函数并启用掩码解码器的梯度计算。
创建自定义数据集
为了微调SAM,我们需要准备三类数据:
- 用于分割的图像
- 分割的地面真实掩码
- 提供给模型的提示
我们选择印章验证数据集作为示例,因为:
- 这类数据在SAM的训练集中可能未充分覆盖
- 数据集包含精确的地面真实掩码,便于计算损失
- 数据集中的边界框可以作为有效的提示信息
输入数据预处理
我们需要将输入图像从NumPy数组转换为PyTorch张量,并进行适当的预处理。具体步骤包括:
- 使用
utils.transform.ResizeLongestSide
调整图像大小 - 将图像转换为PyTorch张量
- 应用SAM模型的预处理方法
训练设置
首先加载预训练的SAM模型权重:
sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
设置Adam优化器并指定需要微调的参数:
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())
选择损失函数,例如均方误差:
loss_fn = torch.nn.MSELoss()
训练循环
在训练循环中,我们需要:
- 使用
torch.no_grad()
上下文管理器计算图像和提示的嵌入 - 生成低分辨率的掩码
- 将掩码放大回原始图像大小
- 计算损失并更新模型参数
关键代码如下:
with torch.no_grad():
image_embedding = sam_model.image_encoder(input_image)
with torch.no_grad():
sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
points=None,
boxes=box_torch,
masks=None,
)
low_res_masks, iou_predictions = sam_model.mask_decoder(
image_embeddings=image_embedding,
image_pe=sam_model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
from torch.nn.functional import threshold, normalize
binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)
loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
保存检查点并从中启动模型
训练完成后,可以保存模型的状态字典:
torch.save(model.state_dict(), PATH)
在需要时,可以通过加载状态字典来恢复模型:
model.load_state_dict(torch.load(PATH))
针对下游应用程序的微调
虽然目前SAM模型本身不提供开箱即用的微调功能,但可以通过与Encord平台集成来实现一键式微调。这种集成方式可以自动设置超参数,简化微调流程。
结论
通过本文的介绍,读者应该掌握了如何对SAM模型进行微调的基本方法。对于希望快速应用SAM模型的用户,可以考虑使用Encord平台提供的开箱即用微调功能,无需编写代码即可完成模型优化。