基于PyTorch的语义分割模型训练框架:动态绘制性能指标曲线
创作时间:
作者:
@小白创作中心
基于PyTorch的语义分割模型训练框架:动态绘制性能指标曲线
引用
CSDN
1.
https://m.blog.csdn.net/2201_76033400/article/details/144760925
本文将分享一个基于PyTorch的语义分割训练框架的实现,涵盖从数据加载、训练逻辑、验证指标计算到性能指标曲线绘制的完整过程。重点介绍如何动态绘制性能指标(如mIoU、Recall、Precision、F1 Score)及其随训练过程的变化曲线,同时解读核心训练脚本train.py的设计思想。
项目结构概览
整个项目的结构如下:
├── data/ # 数据存储目录
├── model/ # 模型定义目录
│ └── model.py # 自定义模型定义
├── utils/ # 工具函数目录
│ ├── data_loading.py # 数据加载工具
│ ├── metrics.py # 评价指标计算
│ ├── plot.py # 可视化工具
│ ├── dice_score.py # Dice 损失计算工具
├── train.py # 核心训练脚本
└── evaluate.py # 模型评估脚本
train.py 主要功能
train.py是整个项目的核心训练脚本,功能模块包括:
数据加载
在train.py中,我们定义了数据加载器,用于加载训练集和验证集,并支持数据增广与批量加载。
from torch.utils.data import DataLoader
from utils.data_loading import CarvanaDataset, BasicDataset
# 数据目录
dir_img = Path('/root/task/data/train/imgs')
dir_mask = Path('/root/task/data/train/masks')
valid_img = Path('/root/task/data/valid/imgs')
valid_mask = Path('/root/task/data/valid/masks')
# 数据加载器定义
train_dataset = CarvanaDataset(dir_img, dir_mask, img_scale=0.5)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
val_dataset = CarvanaDataset(valid_img, valid_mask, img_scale=0.5)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True)
模型初始化
模型部分调用自定义网络(存放于model/model.py中),支持多类别分类。
from model.model import self_net
# 初始化模型,这里用的是简单的Unet模型
model = self_net(n_channels=3, n_classes=4, bilinear=True) #参数可自己调整
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
损失函数与优化器
使用交叉熵损失和Dice损失的组合来优化语义分割任务,并配置了学习率调度器。
import torch.nn as nn
from torch import optim
from utils.dice_score import dice_loss
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器与学习率调度器
optimizer = optim.RMSprop(model.parameters(), lr=1e-6, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)
核心训练与验证逻辑
训练逻辑
训练逻辑通过PyTorch的标准流程实现,同时每个epoch后动态计算并保存mIoU、Recall、Precision等指标。
from tqdm import tqdm
from utils.metrics import Evaluator
def train_model(model, train_loader, val_loader, device, epochs, optimizer, criterion, amp=False):
best_miou = -float('inf') # 初始化最优 mIoU
train_miou_logs, val_miou_logs = [], [] # 用于绘制曲线的日志
for epoch in range(epochs):
model.train()
epoch_loss = 0
evaluator = Evaluator(num_class=model.n_classes)
with tqdm(total=len(train_loader), desc=f'Epoch {epoch+1}/{epochs}', unit='img') as pbar:
for batch in train_loader:
images, masks = batch['image'].to(device), batch['mask'].to(device)
optimizer.zero_grad()
with torch.autocast(device_type='cuda', enabled=amp):
preds = model(images)
loss = criterion(preds, masks)
loss += dice_loss(
F.softmax(preds, dim=1).float(),
F.one_hot(masks, model.n_classes).permute(0, 3, 1, 2).float(),
multiclass=True
)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# 更新评估指标
preds = preds.argmax(dim=1).cpu().numpy()
masks = masks.cpu().numpy()
evaluator.add_batch(masks, preds)
pbar.set_postfix(loss=loss.item())
pbar.update()
# 计算训练集的 mIoU
IoU_per_class, train_miou = evaluator.Mean_Intersection_over_Union()
train_miou_logs.append(train_miou)
# 验证模型
val_miou = validate_model(model, val_loader, evaluator, device, criterion)
val_miou_logs.append(val_miou)
# 保存最佳模型
if val_miou > best_miou:
best_miou = val_miou
torch.save(model.state_dict(), 'best_model.pth')
print(f"Epoch {epoch+1}: Train mIoU: {train_miou:.4f}, Validation mIoU: {val_miou:.4f}")
验证逻辑
验证阶段主要计算验证集的指标(mIoU、Recall、Precision等),并动态绘制性能曲线。
def validate_model(model, val_loader, evaluator, device, criterion):
model.eval()
evaluator.reset()
with torch.no_grad():
for batch in val_loader:
images, masks = batch['image'].to(device), batch['mask'].to(device)
preds = model(images)
evaluator.add_batch(masks.cpu().numpy(), preds.argmax(dim=1).cpu().numpy())
_, mIoU = evaluator.Mean_Intersection_over_Union()
return mIoU
性能指标曲线绘制
通过在每轮训练和验证后,动态调用以下绘图函数,将曲线保存到本地:
from utils.plot import plot_miou
# 在训练后调用绘图函数
plot_miou(train_miou_logs, val_miou_logs, save_path='/root/task/plots')
绘制效果
以下是训练120轮过程中保存的曲线图示例:
mIoU 曲线:展示训练与验证的 mIoU 随 epoch 的变化趋势。
损失曲线:用于评估模型是否过拟合。
混淆矩阵:显示各类别的分类结果。
还有recall和precision曲线就不一一展示了
运行脚本
完整训练脚本可以通过以下命令运行:
python train.py --epochs 50 --batch-size 8 --learning-rate 1e-6
总结
本文分享了语义分割训练框架的核心实现,重点展示了如何在训练中动态绘制mIoU、Recall、Precision和F1 Score等性能指标曲线。这些功能可以帮助开发者快速评估模型性能并优化训练过程。
如果你对完整代码感兴趣,欢迎留言交流!
热门推荐
李白笔下的峨眉山:绝美秋景再现
李白《峨眉山月歌》中的月亮意象解读
从峨眉山月到艺术歌曲:李白《峨眉山月歌》的现代演绎
李白笔下的峨眉山:从古诗意境到现代美景
成都必打卡:龙抄手&陈麻婆豆腐
成都必打卡:宽窄巷子、武侯祠、锦里,像本地人一样玩转三大景点
成都三日游攻略:武侯祠、熊猫基地、锦里古街深度游
喀什巴旦木:从丝路明珠到产业新星
莎车巴旦木:致富路上的甜蜜果实
科技赋能品质,喀什巴旦木如何守护“舌尖安全”?
2025春节县域旅游爆火!十大热门目的地全攻略
春节旅游旺季:景区招商秘籍大揭秘!
春节申遗成功首个非遗年,近7000项活动让年味更浓
色彩搭配在室内设计中的应用
秋冬打卡:边城茶峒最美拍摄点推荐
郭璞和杨筠松:中国古代风水大师的传奇人生
三僚国学文化研究院推荐:掌握装修设计中的风水原则
冬季健康居家风水指南
中南大学湘雅三医院分享优秀护理带教经验
护士心理健康问题:现状、原因与对策
虚拟现实技术革新护理教育:从模拟训练到临床实践
佛陀的智慧:人生如梦,如何从无常中找到宁静?
如何选择适合供奉的水果
FaceTime和iMessage:让iPhone沟通更顺畅!
南安十大景点全攻略:千年古刹、英雄雕像、湖光山色,尽显闽南魅力
泉州一日游最佳路线,泉州旅游攻略一日游怎么安排
赵本山春晚小品盘点:21个经典作品,你最爱哪个?
孕期能吃麻辣香锅吗?这些注意事项请收好
秋冬肠胃不适?枯草杆菌来帮忙!
川味风暴来袭:自制麻辣香锅秘笈