
lilt-en-funsd

This model is a fine-tuned version of SCUT-DLVCLab/lilt-roberta-en-base on the funsd-layoutlmv3 dataset. It achieves the following results on the evaluation set:

  • Loss: 1.6117
  • Answer: {'precision': 0.8821428571428571, 'recall': 0.9069767441860465, 'f1': 0.8943874471937237, 'number': 817}
  • Header: {'precision': 0.6126126126126126, 'recall': 0.5714285714285714, 'f1': 0.591304347826087, 'number': 119}
  • Question: {'precision': 0.9045045045045045, 'recall': 0.9322191272051996, 'f1': 0.9181527206218564, 'number': 1077}
  • Overall Precision: 0.8797
  • Overall Recall: 0.9006
  • Overall F1: 0.8900
  • Overall Accuracy: 0.8204
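
The overall precision, recall and F1 are consistent with a micro-average over the three entity types. The sketch below recovers entity counts from the per-entity figures above and recomputes the overall scores. It assumes seqeval-style micro-averaging, which the card does not state explicitly; the variable names are illustrative.

```python
# Per-entity metrics copied from the evaluation results above.
per_entity = {
    "answer": {"precision": 0.8821428571428571, "recall": 0.9069767441860465, "number": 817},
    "header": {"precision": 0.6126126126126126, "recall": 0.5714285714285714, "number": 119},
    "question": {"precision": 0.9045045045045045, "recall": 0.9322191272051996, "number": 1077},
}

true_positives = 0  # correctly predicted entities
predicted = 0       # all predicted entities
support = 0         # all gold entities ("number")

for metrics in per_entity.values():
    tp = round(metrics["recall"] * metrics["number"])  # recall = tp / number
    true_positives += tp
    predicted += round(tp / metrics["precision"])      # precision = tp / predicted
    support += metrics["number"]

# Micro-average: pool the counts first, then compute the ratios (assumption).
overall_precision = true_positives / predicted
overall_recall = true_positives / support
overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall)

print(round(overall_precision, 4), round(overall_recall, 4), round(overall_f1, 4))
```

The low Header scores barely move the pooled result because headers make up only 119 of the 2013 gold entities.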

Model usage

from transformers import LiltForTokenClassification, LayoutLMv3Processor
from PIL import Image, ImageDraw, ImageFont
import torch

# load model and processor from huggingface hub
model = LiltForTokenClassification.from_pretrained("philschmid/lilt-en-funsd")
processor = LayoutLMv3Processor.from_pretrained("philschmid/lilt-en-funsd")


# helper function to unnormalize bboxes for drawing onto the image
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]


label2color = {
    "B-HEADER": "blue",
    "B-QUESTION": "red",
    "B-ANSWER": "green",
    "I-HEADER": "blue",
    "I-QUESTION": "red",
    "I-ANSWER": "green",
}
# draw results onto the image
def draw_boxes(image, boxes, predictions):
    width, height = image.size
    unnormalized_boxes = [unnormalize_box(box, width, height) for box in boxes]

    # draw predictions over the image; tokens labelled "O" are skipped
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(predictions, unnormalized_boxes):
        if prediction == "O":
            continue
        draw.rectangle(box, outline=label2color[prediction])
        draw.text((box[0] + 10, box[1] - 10), text=prediction, fill=label2color[prediction], font=font)
    return image


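# run inference
def run_inference(image, model=model, processor=processor, output_image=True):
    # create model input; with only an image passed, the processor is expected to
    # run OCR itself (apply_ocr=True in its config, which needs pytesseract)
    encoding = processor(image, return_tensors="pt")
    # LiLT takes only text and layout inputs, so drop the pixel values
    del encoding["pixel_values"]
    # run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    # get labels
    labels = [model.config.id2label[prediction] for prediction in predictions]
    if output_image:
        return draw_boxes(image, encoding["bbox"][0].tolist(), labels)
    else:
        return labels


# load a test image; assumes the dataset is the nielsr/funsd-layoutlmv3 Hub repository
from datasets import load_dataset

dataset = load_dataset("nielsr/funsd-layoutlmv3")
run_inference(dataset["test"][34]["image"])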
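
With output_image=False, run_inference returns one BIO label per token. If entity-level spans are more convenient, the labels can be collapsed as sketched below. The helper group_bio_labels is hypothetical and not part of the model or of transformers; it treats a stray I- tag as the start of a new span, which is a lenient choice rather than the only valid one.

```python
def group_bio_labels(labels):
    """Collapse a BIO label sequence into (entity_type, start, end) spans; end is exclusive."""
    spans = []
    current_type, start = None, None
    for i, label in enumerate(labels):
        prefix, _, entity = label.partition("-")
        # An I- tag of the same type continues the open span.
        if prefix == "I" and entity == current_type:
            continue
        # Anything else closes the open span, if there is one.
        if current_type is not None:
            spans.append((current_type, start, i))
            current_type, start = None, None
        # B- opens a new span; a stray I- is also treated as a span start.
        if prefix in ("B", "I") and entity:
            current_type, start = entity, i
    # Close a span that runs to the end of the sequence.
    if current_type is not None:
        spans.append((current_type, start, len(labels)))
    return spans


example = ["O", "B-QUESTION", "I-QUESTION", "O", "B-ANSWER", "I-ANSWER", "I-ANSWER", "B-HEADER"]
print(group_bio_labels(example))
```

The indices refer to token positions, so they still include special tokens and sub-word pieces produced by the tokenizer.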

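Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • Learning rate: 5e-05
  • Train batch size: 8
  • Eval batch size: 8
  • Seed: 42
  • Optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
  • LR scheduler type: linear
  • Training steps: 2500
  • Mixed precision training: Native AMP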
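
For readers who want to set up a comparable run with the Hugging Face Trainer, the list above could map onto TrainingArguments keyword arguments roughly as follows. This is a sketch and not the author's training script: output_dir is a hypothetical name, and fp16=True is an assumption about what "Native AMP" refers to.

```python
# Keyword arguments mirroring the hyperparameters listed above.
training_kwargs = {
    "output_dir": "lilt-en-funsd",  # hypothetical output folder, not from the card
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "seed": 42,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "lr_scheduler_type": "linear",
    "max_steps": 2500,
    "fp16": True,  # assumption: "Native AMP" mixed precision; requires a CUDA GPU
}

# With transformers installed, the dict can be unpacked directly:
# from transformers import TrainingArguments
# training_args = TrainingArguments(**training_kwargs)
```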

Training results

| Training Loss | Epoch | Step | Validation Loss | Answer | Header | Question | Overall Precision | Overall Recall | Overall F1 | Overall Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| 0.0211 | 10.53 | 200 | 1.5528 | {'precision': 0.8458904109589042, 'recall': 0.9069767441860465, 'f1': 0.8753691671588896, 'number': 817} | {'precision': 0.5684210526315789, 'recall': 0.453781512605042, 'f1': 0.5046728971962617, 'number': 119} | {'precision': 0.896551724137931, 'recall': 0.89322191272052, 'f1': 0.8948837209302325, 'number': 1077} | 0.8596 | 0.8728 | 0.8662 | 0.8011 |
| 0.0132 | 21.05 | 400 | 1.3143 | {'precision': 0.8447058823529412, 'recall': 0.8788249694002448, 'f1': 0.8614277144571085, 'number': 817} | {'precision': 0.6020408163265306, 'recall': 0.4957983193277311, 'f1': 0.543778801843318, 'number': 119} | {'precision': 0.8854262144821264, 'recall': 0.8969359331476323, 'f1': 0.8911439114391144, 'number': 1077} | 0.8548 | 0.8659 | 0.8603 | 0.8095 |
| 0.0052 | 31.58 | 600 | 1.5747 | {'precision': 0.8482446206115515, 'recall': 0.9167686658506732, 'f1': 0.8811764705882352, 'number': 817} | {'precision': 0.6283185840707964, 'recall': 0.5966386554621849, 'f1': 0.6120689655172413, 'number': 119} | {'precision': 0.8997161778618732, 'recall': 0.883008356545961, 'f1': 0.8912839737582005, 'number': 1077} | 0.8626 | 0.8798 | 0.8711 | 0.8030 |
| 0.0073 | 42.11 | 800 | 1.4848 | {'precision': 0.8487972508591065, 'recall': 0.9069767441860465, 'f1': 0.8769230769230769, 'number': 817} | {'precision': 0.5190839694656488, 'recall': 0.5714285714285714, 'f1': 0.5439999999999999, 'number': 119} | {'precision': 0.8941947565543071, 'recall': 0.8867223769730733, 'f1': 0.8904428904428905, 'number': 1077} | 0.8514 | 0.8763 | 0.8636 | 0.7969 |
| 0.0057 | 52.63 | 1000 | 1.3993 | {'precision': 0.8852071005917159, 'recall': 0.9155446756425949, 'f1': 0.9001203369434416, 'number': 817} | {'precision': 0.5454545454545454, 'recall': 0.6050420168067226, 'f1': 0.5737051792828685, 'number': 119} | {'precision': 0.899090909090909, 'recall': 0.9182915506035283, 'f1': 0.9085898024804776, 'number': 1077} | 0.8710 | 0.8987 | 0.8846 | 0.8198 |
| 0.0023 | 63.16 | 1200 | 1.6463 | {'precision': 0.8961201501877347, 'recall': 0.8763769889840881, 'f1': 0.886138613861386, 'number': 817} | {'precision': 0.5625, 'recall': 0.5294117647058824, 'f1': 0.5454545454545455, 'number': 119} | {'precision': 0.888, 'recall': 0.9275766016713092, 'f1': 0.9073569482288827, 'number': 1077} | 0.8733 | 0.8833 | 0.8782 | 0.8082 |
| 0.001 | 73.68 | 1400 | 1.6476 | {'precision': 0.8676814988290398, 'recall': 0.9069767441860465, 'f1': 0.8868940754039496, 'number': 817} | {'precision': 0.6571428571428571, 'recall': 0.5798319327731093, 'f1': 0.6160714285714286, 'number': 119} | {'precision': 0.908256880733945, 'recall': 0.9192200557103064, 'f1': 0.9137055837563451, 'number': 1077} | 0.8785 | 0.8942 | 0.8863 | 0.8137 |
| 0.0014 | 84.21 | 1600 | 1.6493 | {'precision': 0.8814814814814815, 'recall': 0.8739290085679314, 'f1': 0.8776889981561156, 'number': 817} | {'precision': 0.6194690265486725, 'recall': 0.5882352941176471, 'f1': 0.603448275862069, 'number': 119} | {'precision': 0.894404332129964, 'recall': 0.9201485608170845, 'f1': 0.9070938215102976, 'number': 1077} | 0.8740 | 0.8818 | 0.8778 | 0.8041 |
| 0.0006 | 94.74 | 1800 | 1.6193 | {'precision': 0.8766467065868263, 'recall': 0.8959608323133414, 'f1': 0.8861985472154963, 'number': 817} | {'precision': 0.6068376068376068, 'recall': 0.5966386554621849, 'f1': 0.6016949152542374, 'number': 119} | {'precision': 0.8946428571428572, 'recall': 0.9303621169916435, 'f1': 0.912152935821575, 'number': 1077} | 0.8711 | 0.8967 | 0.8837 | 0.8137 |
| 0.0001 | 105.26 | 2000 | 1.6048 | {'precision': 0.8751472320376914, 'recall': 0.9094247246022031, 'f1': 0.8919567827130852, 'number': 817} | {'precision': 0.6140350877192983, 'recall': 0.5882352941176471, 'f1': 0.6008583690987125, 'number': 119} | {'precision': 0.9062784349408554, 'recall': 0.924791086350975, 'f1': 0.9154411764705882, 'number': 1077} | 0.8773 | 0.8987 | 0.8879 | 0.8194 |
| 0.0001 | 115.79 | 2200 | 1.6117 | {'precision': 0.8821428571428571, 'recall': 0.9069767441860465, 'f1': 0.8943874471937237, 'number': 817} | {'precision': 0.6126126126126126, 'recall': 0.5714285714285714, 'f1': 0.591304347826087, 'number': 119} | {'precision': 0.9045045045045045, 'recall': 0.9322191272051996, 'f1': 0.9181527206218564, 'number': 1077} | 0.8797 | 0.9006 | 0.8900 | 0.8204 |
| 0.0001 | 126.32 | 2400 | 1.6163 | {'precision': 0.8799048751486326, 'recall': 0.9057527539779682, 'f1': 0.8926417370325694, 'number': 817} | {'precision': 0.6052631578947368, 'recall': 0.5798319327731093, 'f1': 0.5922746781115881, 'number': 119} | {'precision': 0.9062784349408554, 'recall': 0.924791086350975, 'f1': 0.9154411764705882, 'number': 1077} | 0.8788 | 0.8967 | 0.8876 | 0.8192 |
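
The Epoch column can be cross-checked against the Step column and the train batch size of 8. The sketch below infers the number of optimizer steps per epoch from the logged values and, from that, a plausible range for the training set size. It assumes a single device and no gradient accumulation, neither of which the card states.

```python
train_batch_size = 8  # from the hyperparameter list

# (step, epoch) pairs copied from the table
logged = [(200, 10.53), (400, 21.05), (2400, 126.32)]

# Steps per epoch implied by the first row, rounded to a whole number.
steps_per_epoch = round(logged[0][0] / logged[0][1])

# If the last partial batch is kept, N examples give ceil(N / batch) steps,
# so N lies in the range below (assumption: no gradient accumulation).
min_examples = (steps_per_epoch - 1) * train_batch_size + 1
max_examples = steps_per_epoch * train_batch_size

# Every logged epoch value should equal step / steps_per_epoch to two decimals.
consistent = all(round(step / steps_per_epoch, 2) == epoch for step, epoch in logged)

print(steps_per_epoch, min_examples, max_examples, consistent)
```

At that rate, 2500 training steps correspond to well over 100 passes over the training data, which fits the near-zero training loss in the later rows.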

Framework versions

  • Transformers 4.24.0
  • PyTorch 1.12.1+cu113
  • Datasets 2.7.0
  • Tokenizers 0.12.1