问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

手动训练YOLO模型后如何保存ultralytics库兼容的pt

创作时间:
作者:
@小白创作中心

手动训练YOLO模型后如何保存ultralytics库兼容的pt

引用
CSDN
1.
https://m.blog.csdn.net/qq_38069830/article/details/142301247

手动训练YOLO模型后如何保存ultralytics库兼容的pt

项目场景

最近在搭建AI训练平台的时候遇到一个问题:平台是使用Ray构建分布式训练框架,但是Ray的训练过程只能接收Pytorch的模型对象,所以我获取了YOLO的Pytorch模型对象放到Ray中做手动执行训练过程。

问题:

在训练完后使用torch.save()方法导出pt模型文件后,无法使用ultralytics的YOLO方法再次载入。

示例代码:

from ultralytics import YOLO
model = YOLO("model_epoch_1.pt")	# 这里是使用torch.save()导出的模型文件

使用ultralytics载入的时候会显示,ultralytics希望你用YOLO对象中的save()方法去保存模型:

WARNING ⚠️ The file 'model_epoch_1 (1).pt' appears to be improperly saved or formatted. For optimal results, use model.save('filename.pt') to correctly save YOLO models.

如果不能用ultralytics重新载入的话,那后续部署模型的时候,就很难使用ultralytics已经集成好的很多方法,所以必须找到方法用ultralytics兼容的格式载入。

解决方案

尝试了种种方法无果后,只能去查看ultralytics库中save方法的源码实现。

## ultralytics/engine/model.py
class Model(nn.Module):
    def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None:
        """
        Saves the current model state to a file.
        This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
        the date, Ultralytics version, license information, and a link to the documentation.
        Args:
            filename (Union[str, Path]): The name of the file to save the model to.
            use_dill (bool): Whether to try using dill for serialization if available.
        Raises:
            AssertionError: If the model is not a PyTorch model.
        Examples:
            >>> model = Model('yolov8n.pt')
            >>> model.save('my_model.pt')
        """
        self._check_is_pytorch_model()
        from copy import deepcopy
        from datetime import datetime
        from ultralytics import __version__
        updates = {
            "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 License (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        }
        torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill)

下面是ultralytics中save()方法,可以看到ultralytics也是使用torch.save()保存模型文件的,但是做了一些特殊处理:

  • 使用深拷贝去复制一个model对象,保证对model做处理的时候不会影响到原对象。
  • 后面就是保存当前的时间、ultralytics的版本信息、开源协议、ultralytics的文档地址。
  • 最后就是将上面这些信息和模型的权重一起以键值对的形式保存在pt文件中。

既然知道了ultralytics的保存方法之后,只需要按照ultralytics要求的格式去保存model文件

def yolo_save(model, path):
    # 创建更新字典
    updates = {
        "model": deepcopy(model).half() if isinstance(model, nn.Module) else model,
        "date": datetime.now().isoformat(),
        "version": __version__,  # 使用从 ultralytics 导入的 __version__
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }
    # 将模型的状态字典与更新字典合并并保存
    torch.save({**model.state_dict(), **updates}, path, use_dill=True)
num_epochs = 10
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(inputs)
    
    # 提取预测张量
    predictions = outputs[0]  # 假设第一个输出是预测张量
    
    # 确保目标格式与预测格式匹配
    # 这里假设 predictions 的形状为 [batch_size, channels, height, width]
    # 你需要根据实际情况调整目标张量的形状
    batch_targets = targets
    
    # 计算损失
    loss = criterion(predictions, batch_targets)
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
def yolo_save(model, path):
    # 创建更新字典
    updates = {
        "model": deepcopy(model).half() if isinstance(model, nn.Module) else model,
        "date": datetime.now().isoformat(),
        "version": __version__,  # 使用从 ultralytics 导入的 __version__
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }
    # 将模型的状态字典与更新字典合并并保存
    torch.save({**model.state_dict(), **updates}, path, use_dill=True)
yolo_save(model, r"./test.pt")
print("训练完成")

按照上面的方法去保存model,之后就可以使用ultralytics去载入模型了。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号