提高自定义用例中MLLM性能的简单方法

2024年06月28日 由 alex 发表 179 0

大型语言模型 (LLM) 领域不断发展,新进展迅速涌现。一个令人兴奋的领域是多模态 LLM (MLLM) 的开发,它能够理解文本和图像并与之交互。这为文档理解、视觉问答等任务开辟了无限可能。


但在本文中,我们将探索一个强大的组合:InternVL 模型和 QLoRA 微调技术。我们将重点关注如何针对任何特定用例轻松定制此类模型。我们将使用这些工具创建一个收据理解管道,以高精度提取公司名称、地址和购买总金额等关键信息。


了解任务和数据集

本项目旨在开发一个系统,利用 InternVL 的功能从扫描收据中准确提取特定信息。这项任务提出了一个独特的挑战,不仅需要强大的自然语言处理(NLP)能力,还需要解释输入图像的视觉布局的能力。这将使我们能够创建一个单一的、无 OCR 的端到端管道,在复杂的文档中显示出强大的泛化能力。


为了训练和评估我们的模型,我们将使用 SROIE 数据集。SROIE 提供了 1000 张扫描的收据图像,每张图像都标注了关键实体,例如


  • 公司: 商店或企业名称。
  • 日期:购买日期。
  • 地址:商店地址。
  • 总额: 支付的总金额。


9


我们将使用模糊相似度得分来评估模型的性能,该指标用于衡量预测实体与地面实况实体之间的相似度。该指标的范围从 0(无关结果)到 100(完美预测)。


InternVL:多模态强力工具

InternVL 是 OpenGVLab 的多模态 LLM 系列,旨在出色地完成涉及图像和文本的任务。其架构结合了视觉模型(如 InternViT)和语言模型(如 InternLM2 或 Phi-3)。我们将重点介绍 Mini-InternVL-Chat-2B-V1-5 变体,这是一个非常适合在消费级 GPU 上运行的较小版本。


InternVL 的主要优势:


  • 高效: 体积小巧,可实现高效的训练和推理。
  • 准确性: 尽管体积较小,但它在各种基准测试中取得了极具竞争力的性能。
  • 多模式能力: 它将图像和文本理解完美地结合在一起。


利用 QLoRA 进行微调:内存效率高的方法

为了进一步提高模型的性能,我们将使用 QLoRA,这是一种微调技术,可在保持性能的同时显著降低内存消耗。下面是它的工作原理:


  1. 量化: 预训练的 LLM 被量化为 4 位精度,从而减少了内存占用。
  2. 低阶适配器(LoRA): LoRA 并不修改预训练模型的所有参数,而是在网络中添加小型可训练适配器。这些适配器可捕捉特定任务的信息,而无需更改主模型。
  3. 高效训练: 量化与 LoRA 的结合,即使在内存有限的 GPU 上也能实现高效的微调。


代码演练:基准性能

让我们深入了解代码。首先,我们将评估 Mini-InternVL-Chat-2B-V1-5 在未进行任何微调的情况下的基准性能:


quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = InternVLChatModel.from_pretrained(
    args.path,
    device_map={"": 0},
    quantization_config=quant_config if args.quant else None,
    torch_dtype=torch.bfloat16,
)
tokenizer = InternLM2Tokenizer.from_pretrained(args.path)
# set the max number of tiles in `max_num`
model.eval()
pixel_values = (
    load_image(image_base_path / "X51005255805.jpg", max_num=6)
    .to(torch.bfloat16)
    .cuda()
)
generation_config = dict(
    num_beams=1,
    max_new_tokens=512,
    do_sample=False,
)
# single-round single-image conversation
question = (
    "Extract the company, date, address and total in json format."
    "Respond with a valid JSON only."
)
# print(model)
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(response)


结果:


```json
{
  "company": "SAM SAM TRADING CO",
  "date": "Fri, 29-12-2017",
  "address": "67, JLN MENHAW 25/63 TNN SRI HUDA, 40400 SHAH ALAM",
  "total": "RM 14.10"
}
```


此代码:


  1. 从 "Hugging Face "中心加载模型。
  2. 加载收据图像样本并将其转换为张量。
  3. 提出一个问题,要求模型从图像中提取相关信息。
  4. 运行模型并以 JSON 格式输出提取的信息。


这次的零次评估结果令人印象深刻,平均模糊相似度达到 74.24%。这证明了 InternVL 无需微调即可理解收据并提取信息的能力。


微调: 利用 QLoRA 提高性能

为了进一步提高准确性,我们将使用 QLoRA 对模型进行微调。以下是我们的实施方法:


_data = load_data(args.data_path, fold="train")
# Quantization Config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = InternVLChatModel.from_pretrained(
    path,
    device_map={"": 0},
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
tokenizer = InternLM2Tokenizer.from_pretrained(path)
# set the max number of tiles in `max_num`
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id
model.config.llm_config.use_cache = False
model = wrap_lora(model, r=128, lora_alpha=256)
training_data = SFTDataset(
    data=_data, template=model.config.template, tokenizer=tokenizer
)
collator = CustomDataCollator(pad_token=tokenizer.pad_token_id, ignore_index=-100)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id
print("model.img_context_token_id", model.img_context_token_id)
train_params = TrainingArguments(
    output_dir=str(BASE_PATH / "results_modified"),
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    optim="paged_adamw_32bit",
    save_steps=len(training_data) // 10,
    logging_steps=len(training_data) // 50,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.001,
    max_steps=-1,
    group_by_length=False,
    max_grad_norm=1.0,
)
# Trainer
fine_tuning = SFTTrainer(
    model=model,
    train_dataset=training_data,
    dataset_text_field="###",
    tokenizer=tokenizer,
    args=train_params,
    data_collator=collator,
    max_seq_length=tokenizer.model_max_length,
)
print(fine_tuning.model.print_trainable_parameters())
# Training
fine_tuning.train()
# Save Model
fine_tuning.model.save_pretrained(refined_model)


此代码:


  1. 加载已启用量化的模型。
  2. 用 LoRA 封装模型,添加可训练的适配器。
  3. 从 SROIE 数据集创建数据集。
  4. 定义训练参数,如学习率、批量大小和历时。
  5. 初始化训练器以处理训练过程。
  6. 在 SROIE 数据集上训练模型。
  7. 保存微调后的模型。


下面是基本模型和 QLoRA 微调模型的比较示例:


Ground Truth: 
{
    "company": "YONG TAT HARDWARE TRADING","company": "YONG TAT HARDWARE TRADING",
    "date": "13/03/2018",
    "address": "NO 4,JALAN PERJIRANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR.",
    "total": "72.00"
}
Prediction Base: KO
```json
{
  "company": "YONG TAT HARDWARE TRADING",
  "date": "13/03/2016",
  "address": "JM092487-D",
  "total": "67.92"
}
```
Prediction QLoRA: OK
{
    "company": "YONG TAT HARDWARE TRADING",
    "date": "13/03/2018",
    "address": "NO 4, JALAN PERUBANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR",
    "total": "72.00"
}


结果和结论

使用 QLoRA 进行微调后,我们的模型达到了 95.4% 的模糊相似度得分,比基准性能(74.24%)有了显著提高。这证明了 QLoRA 在提高模型准确性方面的强大功能,而无需大量计算资源(在 RTX 3080 GPU 上对 600 个样本进行 15 分钟的训练)。


我们利用 InternVL 和 QLoRA 成功构建了一个强大的收据理解管道。这种方法展示了多模态 LLM 在文档分析和信息提取等实际任务中的潜力。在这个用例中,我们利用几百个示例和几分钟的计算时间,在消费级 GPU 上提高了 30 个预测质量点。


多模态 LLM 的开发才刚刚开始,未来的可能性令人兴奋。在 MLLM时代,自动文档处理领域有着巨大的潜力。这些模型可以彻底改变我们从合同、发票和其他文档中提取信息的方式,而且只需要最少的训练数据。通过整合文本和视觉,它们能以前所未有的准确性分析复杂文档的布局,为更高效、更智能的信息管理铺平道路。


人工智能的未来是多模式的,而 InternVL 和 QLoRA 是强大的工具,可以帮助我们以较小的计算预算释放其潜力。

文章来源:https://medium.com/towards-data-science/a-simple-recipe-to-boost-the-performance-of-mllms-on-your-custom-use-case-6014440f5373
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消