问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型

创作时间:
作者:
@小白创作中心

通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型

引用
1
来源
1.
https://17aitech.com/?p=33382

随着多模态大模型的发展,其不仅限于文字处理,更能够在图像、视频、音频方面进行识别与理解。医疗领域中,医生们往往需要对各种医学图像进行处理,以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合,那么这会极大地提升诊断效率。

项目目标

训练一个医疗多模态大模型,用于图像诊断。

实现过程

1. 数据集准备

为了训练模型,需要准备大量的医学图像数据。通过搜索我们找到以下训练数据:

数据名称:MedTrinity-25M

数据地址https://github.com/UCSC-VLAA/MedTrinity-25M

数据简介:MedTrinity-25M数据集是一个用于医学图像分析和计算机视觉研究的大型数据集。

数据来源:该数据集由加州大学圣克鲁兹分校(UCSC)提供,旨在促进医学图像处理和分析的研究。

数据量:MedTrinity-25M包含约2500万条医学图像数据,涵盖多种医学成像技术,如CT、MRI和超声等。

数据内容

该数据集有两份,分别是

25Mdemo

25Mfull

。25Mdemo(约162,000条)数据集内容如下:

25Mfull(约24,800,000条)数据集内容如下:

2. 数据下载

2.1 安装Hugging Face的Datasets库

pip install datasets

2.2 下载数据集

from datasets import load_dataset
# 加载数据集
ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")

执行结果:

说明:

  • 以上方法是使用HuggingFace的Datasets库下载数据集,下载的路径为当前脚本所在路径下的cache文件夹。
  • 使用HuggingFace下载需要能够访问https://huggingface.co/ 并且在网站上申请数据集读取权限才可以。
  • 如果没有权限访问HuggingFace,可以关注以下公众号后,回复 “MedTrinity”获取百度网盘下载地址。

2.3 预览数据集

# 查看训练集的前1个样本
print(ds['train'][:1])

运行结果:

{
    'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x15DD6D06530>],
    'id': ['8031efe0-1b5c-11ef-8929-000066532cad'],
    'caption': ['The image is a non-contrasted computed tomography (CT) scan of the brain, showing the cerebral structures without any medical devices present. The region of interest, located centrally and in the middle of the image, exhibits an area of altered density, which is indicative of a brain hemorrhage. This area is distinct from the surrounding brain tissue, suggesting a possible hematoma or bleeding within the brain parenchyma. The location and characteristics of this abnormality may suggest a relationship with the surrounding brain tissue, potentially causing a mass effect or contributing to increased intracranial pressure.']
}

使用如下命令对数据集的图片进行可视化查看:

# 可视化image内容
from PIL import Image
import matplotlib.pyplot as plt
image = ds['train'][0]['image']  # 获取第一张图像
plt.imshow(image)
plt.axis('off')  # 不显示坐标轴
plt.show()

运行结果:

3. 数据预处理

由于后续我们要通过LLama Factory进行多模态大模型微调,所以我们需要对上述的数据集进行预处理以符合LLama Factory的要求。

3.1 LLama Factory数据格式

查看LLama Factory的多模态数据格式要求如下:

[
  {
    "messages": [
      {
        "content": "<image>他们是谁?",
        "role": "user"
      },
      {
        "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
        "role": "assistant"
      },
      {
        "content": "他们在做什么?",
        "role": "user"
      },
      {
        "content": "他们在足球场上庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  }
]

3.2 实现数据格式转换脚本

from datasets import load_dataset
import os
import json
from PIL import Image

def save_images_and_json(ds, output_dir="mllm_data"):
    """
    将数据集中的图像和对应的 JSON 信息保存到指定目录。
    参数:
    ds: 数据集对象,包含图像和标题。
    output_dir: 输出目录,默认为 "mllm_data"。
    """
    # 创建输出目录
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # 创建一个列表来存储所有的消息和图像信息
    all_data = []
    # 遍历数据集中的每个项目
    for item in ds:
        img_path = f"{output_dir}/{item['id']}.jpg"  # 图像保存路径
        image = item["image"]  # 假设这里是一个 PIL 图像对象
        # 将图像对象保存为文件
        image.save(img_path)  # 使用 PIL 的 save 方法
        # 添加消息和图像信息到列表中
        all_data.append(
            {
                "messages": [
                    {
                        "content": "<image>图片中的诊断结果是怎样?",
                        "role": "user",
                    },
                    {
                        "content": item["caption"],  # 从数据集中获取的标题
                        "role": "assistant",
                    },
                ],
                "images": [img_path],  # 图像文件路径
            }
        )
    # 创建 JSON 文件
    json_file_path = f"{output_dir}/mllm_data.json"
    with open(json_file_path, "w", encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False)  # 确保中文字符正常显示

if __name__ == "__main__":
    # 加载数据集
    ds = load_dataset("UCSC-VLAA/MedTrinity-25M", "25M_demo", cache_dir="cache")
    # 保存数据集中的图像和 JSON 信息
    save_images_and_json(ds['train'])

运行结果:

4. 模型下载

本次微调,我们使用阿里最新发布的多模态大模型:

Qwen2-VL-2B-Instruct

作为底座模型。

模型说明地址https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct

使用如下命令下载模型

git lfs install
# 下载模型
git clone https://www.modelscope.cn/Qwen/Qwen2-VL-2B-Instruct.git

5. 环境准备

5.1 机器环境

硬件:

  • 显卡:4080 Super
  • 显存:16GB

软件:

  • 系统:Ubuntu 20.04 LTS
  • python:3.10
  • pytorch:2.1.2 + cuda12.1

5.2 准备虚拟环境

# 创建python3.10版本虚拟环境
conda create --name train_env python=3.10
# 激活环境
conda activate train_env
# 安装依赖包
pip install streamlit torch torchvision
# 安装Qwen2建议的transformers版本
pip install git+https://github.com/huggingface/transformers
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号