问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

深度学习模型对海陆分割遥感图像数据集进行语义分割

创作时间:
作者:
@小白创作中心

深度学习模型对海陆分割遥感图像数据集进行语义分割

引用
CSDN
1.
https://blog.csdn.net/2401_88441190/article/details/146094367

本文将详细介绍如何使用深度学习模型(UNet、DeepLabV3)对海陆分割遥感图像进行语义分割。通过本文,读者将学习到数据集准备、模型训练和结果可视化等关键步骤,并获得完整的代码示例。

数据集准备

假设你的数据集结构如下:

sea_land_segmentation/
├── images/
│   ├── train/
│   │   ├── img1.jpg
│   │   └── ...
│   ├── val/
│   │   ├── img1.jpg
│   │   └── ...
│   └── test/
│       ├── img1.jpg
│       └── ...
└── masks/
    ├── train/
    │   ├── img1.png
    │   └── ...
    ├── val/
    │   ├── img1.png
    │   └── ...
    └── test/
        ├── img1.png
        └── ...
data_sea_land.yaml

data_sea_land.yaml 文件内容示例:

train: ./sea_land_segmentation/images/train/
val: ./sea_land_segmentation/images/val/

nc: 2  # 类别数量:Sea 和 Land
names: ['Sea', 'Land']

确保每个掩码图像(mask)是单通道的PNG图像,其中像素值代表类别ID(例如0代表Sea,1代表Land)。

安装依赖库

确保安装了必要的库:

pip install segmentation-models-pytorch albumentations opencv-python-headless torch torchvision

自定义数据集类

编写一个自定义的数据集类来读取图像和对应的掩码。

import os
import cv2
import numpy as np
from torch.utils.data import Dataset

class SeaLandDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in os.listdir(images_dir)]
        self.masks_fps = [os.path.join(masks_dir, mask_id) for mask_id in os.listdir(images_dir)]
        self.transform = transform

    def __getitem__(self, i):
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)
        
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

    def __len__(self):
        return len(self.images_fps)

模型定义与训练

选择一个合适的模型架构(例如UNet或DeepLabV3),并配置优化器、损失函数和评估指标。

训练脚本

import torch
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# 加载数据集
train_dataset = SeaLandDataset('./sea_land_segmentation/images/train/', './sea_land_segmentation/masks/train/')
valid_dataset = SeaLandDataset('./sea_land_segmentation/images/val/', './sea_land_segmentation/masks/val/')

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=4)

# 定义模型
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['Sea', 'Land']
ACTIVATION = 'sigmoid'  # 二分类问题可以使用sigmoid激活函数

model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# 设置损失函数和优化器
loss = smp.utils.losses.CrossEntropyLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device='cuda',
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device='cuda',
    verbose=True,
)

max_score = 0

for i in range(0, 40):  # 训练周期数
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

    if i == 25:
        optimizer.param_groups[0]['lr'] /= 10
        print('Decrease decoder learning rate to 1e-5!')

推理与结果可视化

训练完成后,我们可以利用训练好的模型对新图片进行预测,并将结果可视化。

推理脚本

import matplotlib.pyplot as plt

best_model = torch.load('./best_model.pth')

def visualize(image, gt_mask, pr_mask):
    figure, ax = plt.subplots(1, 3, figsize=(10, 10))

    ax[0].imshow(image)
    ax[0].set_title("Image")
    ax[1].imshow(gt_mask, cmap='gray')
    ax[1].set_title("Ground Truth Mask")
    ax[2].imshow(pr_mask, cmap='gray')
    ax[2].set_title("Predicted Mask")

    plt.show()

test_dataset = SeaLandDataset('./sea_land_segmentation/images/test/', './sea_land_segmentation/masks/test/')
for i in range(5):  # 可视化前5个测试样本
    image, gt_mask = test_dataset[i]
    x_tensor = torch.from_numpy(image).to('cuda').permute(2, 0, 1).unsqueeze(0).float() / 255.
    pr_mask = best_model.predict(x_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy()
    visualize(image, gt_mask, pr_mask)

运行步骤总结

  1. 数据集准备 :确认数据集已按要求组织好。
  2. 安装依赖 :通过提供的命令安装所需的Python库。
  3. 自定义数据集类 :创建自定义数据集类以加载图像和掩码。
  4. 模型定义与训练 :选择合适的模型架构,配置优化器、损失函数,并开始训练过程。
  5. 推理与可视化 :使用训练好的模型进行推理,并可视化结果。

以上提供了完整的流程——使用Segmentation Models for PyTorch对海陆分割遥感图像进行语义分割任务。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号