微调Llama 3.2 Vision:提升AI医学图像解读能力

2024年11月26日 由 alex 发表 186 0

你是否曾好奇过AI模型是如何学会理解医学影像的?今天,我将带你了解一个激动人心的项目:对Meta的Llama 3.2视觉模型进行微调,以分析放射影像。


这是关于什么的?

想象一下,有一个AI助手能够查看X光片并提供详细的医学描述。这正是我们正在构建的内容。我们正在利用Meta强大的Llama 3.2视觉模型(一个拥有110亿参数的AI)并教它更好地理解医学影像。


微调前后的对比

这有趣的地方在于:在训练之前,模型给出的医学图像描述是笼统且有些模糊的。但经过我们的微调过程后,它变得更加精确和专业,说话更像是一位专业的放射科医生。


它是如何工作的?

这个过程就像通过例子来教学生一样。我们使用了一个名为“Radiology_mini”的数据集,其中包含X光片图像和专家描述。我们反复向模型展示这些图像,它学会了:

  • 识别特定的医学特征
  • 使用正确的医学术语
  • 像专业放射科医生一样构建其回应


幕后的魔法

我们使用了一种巧妙的技术,称为LoRA(低秩适应),这使得即使在单个GPU上也能训练这个庞大的模型。你可以把它想象成是在教模型更好地完成其工作,而无需重写其整个知识库。


结果

转变是显著的。在训练之前,模型给出了像“这张X光片似乎是上颌和下颌的全景视图…”这样的一般临床观察。训练后,它提供了更专注和结构化的观察,如“全景X光片显示双侧动脉瘤样骨囊肿(ABC)”——这对于医学专业人员来说更加精确和有用!


技术实现

让我们深入了解如何自己实现这一点。以下是一个包含代码的逐步指南:


设置和安装

首先,安装所需的包:


pip install unsloth
export HF_TOKEN=xxxxxxxxxxxxx  # Your Hugging Face token


完整代码

以下是按逻辑部分分解的完整实现:


import os
from unsloth import FastVisionModel
import torch
from datasets import load_dataset
from transformers import TextStreamer
from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
# Load the model
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Llama-3.2-11B-Vision-Instruct",
    load_in_4bit = True,
    use_gradient_checkpointing = "unsloth",
)
# Configure fine-tuning parameters
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules      = True,
    r = 16,
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)
# Load and prepare the dataset
dataset = load_dataset("unsloth/Radiology_mini", split = "train")
instruction = "You are an expert radiographer. Describe accurately what you see in this image."
def convert_to_conversation(sample):
    conversation = [
        { "role": "user",
          "content" : [
            {"type" : "text",  "text"  : instruction},
            {"type" : "image", "image" : sample["image"]} ]
        },
        { "role" : "assistant",
          "content" : [
            {"type" : "text",  "text"  : sample["caption"]} ]
        },
    ]
    return { "messages" : conversation }
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
# Configure the trainer
FastVisionModel.for_training(model)
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer),
    train_dataset = converted_dataset,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 30,
        learning_rate = 2e-4,
        fp16 = not is_bf16_supported(),
        bf16 = is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_seq_length = 2048,
    ),
)
# Train the model
trainer_stats = trainer.train()
# Test after training
print("\nAfter training:\n")
FastVisionModel.for_inference(model)
image = dataset[0]["image"]
instruction = "You are an expert radiographer. Describe accurately what you see in this image."
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.5, min_p = 0.1)
# Save and upload the model
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")
model.save_pretrained_merged("your-username/Llama-3.2-11B-Vision-Radiology-mini", tokenizer,)
model.push_to_hub_merged("your-username/Llama-3.2-11B-Vision-Radiology-mini", 
                        tokenizer, 
                        save_method = "merged_16bit", 
                        token = os.environ.get("HF_TOKEN"))


模型加载:我们以4位精度加载预训练的Llama 3.2视觉模型,以节省内存。


微调配置:我们启用对各种模型组件的微调,包括视觉层、语言层和注意力模块。


数据集准备:代码将放射学图像及其描述转换为模型可以理解的对话格式。


训练配置:我们使用具有特定参数的SFTTrainer:

  • 每个设备的批处理大小为2
  • 4个梯度累积步骤
  • 最大训练步骤为30
  • 学习率为2e-4
  • 线性学习率调度器


模型保存:训练后,我们保存LoRA权重和模型的合并版本。


总结

  1. 在训练期间始终监控你的GPU内存使用情况
  2. 从少量的训练步骤开始测试你的设置
  3. 确保你的训练数据质量高且标记正确
  4. 记录微调前后的结果以衡量改进情况

随意尝试不同的超参数,并根据你的特定用例调整代码。

文章来源:https://medium.com/@naman1011/fine-tuning-llama-3-2-vision-making-ai-better-at-reading-medical-images-bdf340fa8ee9
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消