大型语言模型 (LLM) 领域不断发展,新进展迅速涌现。一个令人兴奋的领域是多模态 LLM (MLLM) 的开发,它能够理解文本和图像并与之交互。这为文档理解、视觉问答等任务开辟了无限可能。
但在本文中,我们将探索一个强大的组合:InternVL 模型和 QLoRA 微调技术。我们将重点关注如何针对任何特定用例轻松定制此类模型。我们将使用这些工具创建一个收据理解管道,以高精度提取公司名称、地址和购买总金额等关键信息。
了解任务和数据集
本项目旨在开发一个系统,利用 InternVL 的功能从扫描收据中准确提取特定信息。这项任务提出了一个独特的挑战,不仅需要强大的自然语言处理(NLP)能力,还需要解释输入图像的视觉布局的能力。这将使我们能够创建一个单一的、无 OCR 的端到端管道,在复杂的文档中显示出强大的泛化能力。
为了训练和评估我们的模型,我们将使用 SROIE 数据集。SROIE 提供了 1000 张扫描的收据图像,每张图像都标注了关键实体,例如
我们将使用模糊相似度得分来评估模型的性能,该指标用于衡量预测实体与地面实况实体之间的相似度。该指标的范围从 0(无关结果)到 100(完美预测)。
InternVL:多模态强力工具
InternVL 是 OpenGVLab 的多模态 LLM 系列,旨在出色地完成涉及图像和文本的任务。其架构结合了视觉模型(如 InternViT)和语言模型(如 InternLM2 或 Phi-3)。我们将重点介绍 Mini-InternVL-Chat-2B-V1-5 变体,这是一个非常适合在消费级 GPU 上运行的较小版本。
InternVL 的主要优势:
利用 QLoRA 进行微调:内存效率高的方法
为了进一步提高模型的性能,我们将使用 QLoRA,这是一种微调技术,可在保持性能的同时显著降低内存消耗。下面是它的工作原理:
代码演练:基准性能
让我们深入了解代码。首先,我们将评估 Mini-InternVL-Chat-2B-V1-5 在未进行任何微调的情况下的基准性能:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = InternVLChatModel.from_pretrained(
args.path,
device_map={"": 0},
quantization_config=quant_config if args.quant else None,
torch_dtype=torch.bfloat16,
)
tokenizer = InternLM2Tokenizer.from_pretrained(args.path)
# set the max number of tiles in `max_num`
model.eval()
pixel_values = (
load_image(image_base_path / "X51005255805.jpg", max_num=6)
.to(torch.bfloat16)
.cuda()
)
generation_config = dict(
num_beams=1,
max_new_tokens=512,
do_sample=False,
)
# single-round single-image conversation
question = (
"Extract the company, date, address and total in json format."
"Respond with a valid JSON only."
)
# print(model)
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(response)
结果:
```json
{
"company": "SAM SAM TRADING CO",
"date": "Fri, 29-12-2017",
"address": "67, JLN MENHAW 25/63 TNN SRI HUDA, 40400 SHAH ALAM",
"total": "RM 14.10"
}
```
此代码:
这次的零次评估结果令人印象深刻,平均模糊相似度达到 74.24%。这证明了 InternVL 无需微调即可理解收据并提取信息的能力。
微调: 利用 QLoRA 提高性能
为了进一步提高准确性,我们将使用 QLoRA 对模型进行微调。以下是我们的实施方法:
_data = load_data(args.data_path, fold="train")
# Quantization Config
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = InternVLChatModel.from_pretrained(
path,
device_map={"": 0},
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
tokenizer = InternLM2Tokenizer.from_pretrained(path)
# set the max number of tiles in `max_num`
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id
model.config.llm_config.use_cache = False
model = wrap_lora(model, r=128, lora_alpha=256)
training_data = SFTDataset(
data=_data, template=model.config.template, tokenizer=tokenizer
)
collator = CustomDataCollator(pad_token=tokenizer.pad_token_id, ignore_index=-100)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id
print("model.img_context_token_id", model.img_context_token_id)
train_params = TrainingArguments(
output_dir=str(BASE_PATH / "results_modified"),
num_train_epochs=EPOCHS,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
optim="paged_adamw_32bit",
save_steps=len(training_data) // 10,
logging_steps=len(training_data) // 50,
learning_rate=5e-4,
lr_scheduler_type="cosine",
warmup_steps=100,
weight_decay=0.001,
max_steps=-1,
group_by_length=False,
max_grad_norm=1.0,
)
# Trainer
fine_tuning = SFTTrainer(
model=model,
train_dataset=training_data,
dataset_text_field="###",
tokenizer=tokenizer,
args=train_params,
data_collator=collator,
max_seq_length=tokenizer.model_max_length,
)
print(fine_tuning.model.print_trainable_parameters())
# Training
fine_tuning.train()
# Save Model
fine_tuning.model.save_pretrained(refined_model)
此代码:
下面是基本模型和 QLoRA 微调模型的比较示例:
Ground Truth:
{
"company": "YONG TAT HARDWARE TRADING","company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2018",
"address": "NO 4,JALAN PERJIRANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR.",
"total": "72.00"
}
Prediction Base: KO
```json
{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2016",
"address": "JM092487-D",
"total": "67.92"
}
```
Prediction QLoRA: OK
{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2018",
"address": "NO 4, JALAN PERUBANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR",
"total": "72.00"
}
结果和结论
使用 QLoRA 进行微调后,我们的模型达到了 95.4% 的模糊相似度得分,比基准性能(74.24%)有了显著提高。这证明了 QLoRA 在提高模型准确性方面的强大功能,而无需大量计算资源(在 RTX 3080 GPU 上对 600 个样本进行 15 分钟的训练)。
我们利用 InternVL 和 QLoRA 成功构建了一个强大的收据理解管道。这种方法展示了多模态 LLM 在文档分析和信息提取等实际任务中的潜力。在这个用例中,我们利用几百个示例和几分钟的计算时间,在消费级 GPU 上提高了 30 个预测质量点。
多模态 LLM 的开发才刚刚开始,未来的可能性令人兴奋。在 MLLM时代,自动文档处理领域有着巨大的潜力。这些模型可以彻底改变我们从合同、发票和其他文档中提取信息的方式,而且只需要最少的训练数据。通过整合文本和视觉,它们能以前所未有的准确性分析复杂文档的布局,为更高效、更智能的信息管理铺平道路。
人工智能的未来是多模式的,而 InternVL 和 QLoRA 是强大的工具,可以帮助我们以较小的计算预算释放其潜力。