Model:

google/pix2struct-textcaps-large

Model card for Pix2Struct - large version, fine-tuned on TextCaps

Table of Contents

  • Introduction
  • Using the model
  • Contribution
  • Citation

    Introduction

    Pix2Struct is an image encoder - text decoder model that is trained on image-text pairs for various tasks, including image captioning and visual question answering. The full list of available models can be found in Table 1 of the paper:

    The abstract of the paper states:

    Visually-situated language is ubiquitous - sources range from textbooks with diagrams to web pages with images and tables, to mobile apps with buttons and forms. Perhaps due to this diversity, previous work has typically relied on domain-specific recipes with limited sharing of the underlying data, model architectures, and objectives. We present Pix2Struct, a pretrained image-to-text model for purely visual language understanding, which can be finetuned on tasks containing visually-situated language. Pix2Struct is pretrained by learning to parse masked screenshots of web pages into simplified HTML. The web, with its richness of visual elements cleanly reflected in the HTML structure, provides a large source of pretraining data well suited to the diversity of downstream tasks. Intuitively, this objective subsumes common pretraining signals such as OCR, language modeling, and image captioning. In addition to the novel pretraining strategy, we introduce a variable-resolution input representation and a more flexible integration of language and vision inputs, where language prompts such as questions are rendered directly on top of the input image. For the first time, we show that a single pretrained model can achieve state-of-the-art results in six out of nine tasks across four domains: documents, illustrations, user interfaces, and natural images.

    Using the model

    Converting from T5x to Hugging Face

    You can use the following script for the conversion:

    python convert_pix2struct_checkpoint_to_pytorch.py --t5x_checkpoint_path PATH_TO_T5X_CHECKPOINTS --pytorch_dump_path PATH_TO_SAVE
    

    If you are converting a large model, run:

    python convert_pix2struct_checkpoint_to_pytorch.py --t5x_checkpoint_path PATH_TO_T5X_CHECKPOINTS --pytorch_dump_path PATH_TO_SAVE --use-large
    

    Once saved, you can push your converted model with the following snippet:

    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    model = Pix2StructForConditionalGeneration.from_pretrained(PATH_TO_SAVE)
    processor = Pix2StructProcessor.from_pretrained(PATH_TO_SAVE)
    
    model.push_to_hub("USERNAME/MODEL_NAME")
    processor.push_to_hub("USERNAME/MODEL_NAME")
    

    Running the model

    In full precision, on CPU:

    You can run the model in full precision on CPU:

    import requests
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
    
    # image only
    inputs = processor(images=image, return_tensors="pt")
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    >>> A street scene with a sign that says "STOP".
    

    In full precision, on GPU:

    You can run the model in full precision on GPU:

    import requests
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large").to("cuda")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-large")
    
    # image only
    inputs = processor(images=image, return_tensors="pt").to("cuda")
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    >>> A street scene with a sign that says "STOP".
    

    In half precision, on GPU:

    You can run the model in half precision on GPU:

    import requests
    import torch
    
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large", torch_dtype=torch.bfloat16).to("cuda")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-large")
    
    # image only
    inputs = processor(images=image, return_tensors="pt").to("cuda", torch.bfloat16)
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    >>> A street scene with a sign that says "STOP".
    

    Using different sequence lengths

    This model was trained with a sequence length of 4096. You can reduce the sequence length for more memory-efficient inference, but you may observe some performance degradation for small sequence lengths (<1024). Just pass max_patches when calling the processor:

    inputs = processor(images=image, return_tensors="pt", max_patches=1024)
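
    Intuitively, max_patches caps the encoder's sequence length: the image is rescaled (preserving aspect ratio) so that its grid of 16x16 patches fits within the budget, which is why a smaller budget trades resolution for memory. The sketch below approximates that rescaling arithmetic for illustration; approx_patch_grid is a hypothetical helper, not part of transformers (the actual logic lives inside the image processor):

```python
import math

PATCH = 16  # Pix2Struct extracts 16x16 patches

def approx_patch_grid(width, height, max_patches):
    """Approximate the patch grid for an image under a max_patches budget.

    The image is rescaled so that the number of 16x16 patches
    (rows * cols) does not exceed max_patches.
    """
    # Scale factor that makes the patch count hit the budget exactly
    scale = math.sqrt(max_patches * (PATCH / width) * (PATCH / height))
    rows = max(min(math.floor(scale * height / PATCH), max_patches), 1)
    cols = max(min(math.floor(scale * width / PATCH), max_patches), 1)
    return rows, cols

# A 1024x768 image: shrinking the budget shrinks the encoder sequence length
print(approx_patch_grid(1024, 768, 4096))  # (55, 73) -> 4015 patches
print(approx_patch_grid(1024, 768, 1024))  # (27, 36) -> 972 patches
```

    Since attention cost grows with the number of patches, quartering max_patches roughly quarters the encoder's memory footprint, at the cost of feeding the model a lower-resolution view of the image.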
    

    Conditional generation

    You can also prepend some input text to perform conditional generation:

    import requests
    from PIL import Image
    from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
    
    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    text = "A picture of"
    
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large")
    processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-large")
    
    # image + text
    inputs = processor(images=image, text=text, return_tensors="pt")
    
    predictions = model.generate(**inputs)
    print(processor.decode(predictions[0], skip_special_tokens=True))
    

    Contribution

    This model was originally contributed by Kenton Lee, Mandar Joshi et al. and added to the Hugging Face ecosystem by Younes Belkada.

    Citation

    If you want to cite this work, please consider citing the original paper:

    @misc{https://doi.org/10.48550/arxiv.2210.03347,
      doi       = {10.48550/ARXIV.2210.03347},
      url       = {https://arxiv.org/abs/2210.03347},
      author    = {Lee, Kenton and Joshi, Mandar and Turc, Iulia and Hu, Hexiang and Liu, Fangyu and Eisenschlos, Julian and Khandelwal, Urvashi and Shaw, Peter and Chang, Ming-Wei and Toutanova, Kristina},
      keywords  = {Computation and Language (cs.CL), Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
      title     = {Pix2Struct: Screenshot Parsing as Pretraining for Visual Language Understanding},
      publisher = {arXiv},
      year      = {2022},
      copyright = {Creative Commons Attribution 4.0 International}
    }