模型:

google/pix2struct-textcaps-base

英文

Pix2Struct模型卡 - 在TextCaps上微调

目录

  • 简介
  • 使用模型
  • 贡献
  • 引用
  • 简介

    Pix2Struct是一种图像编码器 - 文本解码器模型,它是在图像与文本对上进行训练的,用于各种任务,包括图像描述和视觉问答。可以在论文的表1中找到可用模型的完整列表:

    该模型的摘要声明如下:

    视觉语境是无处不在的 - 来源范围包括带有图示的教科书,带有图像和表格的网页,带有按钮和表单的移动应用程序等。也许由于这种多样性,以往的工作通常依赖于具有有限共享基础数据、模型架构和目标的领域特定配方。我们提出了Pix2Struct,一种在纯视觉语言理解中针对性地处理视错语的预训练图像到文本模型,并可以在包含视觉语境的任务上进行微调。Pix2Struct在登录将网页屏幕截图解析为简化的HTML的同时进行预训练。通过反映在HTML结构中的视觉元素丰富性,网络提供了适合下游任务多样性的大量预训练数据的源。直观地说,该目标包含了常见的预训练信号,例如OCR、语言建模和图像描述。除了新颖的预训练策略,我们还介绍了可变分辨率输入表示和更灵活的语言和视觉输入集成,其中诸如问题之类的语言提示直接渲染在输入图像的顶部。我们首次展示单个预训练模型可以在四个领域的九项任务中的六项任务中取得最先进的结果,这四个领域包括文件、插图、用户界面和自然图像。

    使用模型

    从T5x转换为huggingface

    您可以按照以下方式使用脚本 convert_pix2struct_checkpoint_to_pytorch.py

    python convert_pix2struct_checkpoint_to_pytorch.py --t5x_checkpoint_path PATH_TO_T5X_CHECKPOINTS --pytorch_dump_path PATH_TO_SAVE
    

    如果您正在转换大型模型,请运行:

    python convert_pix2struct_checkpoint_to_pytorch.py --t5x_checkpoint_path PATH_TO_T5X_CHECKPOINTS --pytorch_dump_path PATH_TO_SAVE --use-large
    

    保存后,您可以使用以下代码片段推送已转换的模型:

    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    model = Pix2StructForConditionalGeneration.from_pretrained(PATH_TO_SAVE)
    processor = Pix2StructProcessor.from_pretrained(PATH_TO_SAVE)
    
    model.push_to_hub("USERNAME/MODEL_NAME")
    processor.push_to_hub("USERNAME/MODEL_NAME")
    

    运行模型

    在CPU上使用完整精度:

    您可以在CPU上使用完整精度运行模型:

    import requests
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
    
    # image only
    inputs = processor(images=image, return_tensors="pt")
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    >>> A stop sign is on a street corner.
    

    在GPU上使用完整精度:

    您可以在GPU上使用完整精度运行模型:

    import requests
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to("cuda")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
    
    # image only
    inputs = processor(images=image, return_tensors="pt").to("cuda")
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    >>> A stop sign is on a street corner.
    

    在GPU上使用半精度:

    您可以在GPU上使用半精度运行模型:

    import requests
    import torch
    
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base", torch_dtype=torch.bfloat16).to("cuda")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
    
    # image only
    inputs = processor(images=image, return_tensors="pt").to("cuda", torch.bfloat16)
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    >>> A stop sign is on a street corner.
    

    使用不同的序列长度

    该模型在序列长度为2048上进行了训练。您可以尝试缩短序列长度以获得更高的内存效率推断,但对于小的序列长度(