英文

LSG模型

Transformers >= 4.23.1 此模型依赖于自定义的建模文件,您需要添加trust_remote_code=True 请查看 #13467

LSG ArXiv paper . 可在此处找到Github/conversion脚本 link .

  • 用法
  • 参数
  • 稀疏选择类型
  • 任务
  • 训练全局标记

此模型是在没有额外预训练的情况下从 BERT-base-uncased 改编而来的。它使用相同数量的参数/层和相同的分词器。

此模型可以处理长序列,但比Longformer或BigBird(来自Transformers)更快且更高效,它依赖于本地 + 稀疏 + 全局注意力(LSG)。

该模型要求序列的长度是块大小的倍数。该模型是“自适应”的,如果需要,它会自动填充序列(在配置中设置adaptive=True)。但是,由于分词器的存在,建议截断输入(截断=True)并可选地进行块大小的倍数填充(pad_to_multiple_of = ...)。

支持编码器-解码器,但我没有进行广泛测试。 采用PyTorch实现。

用法

此模型依赖于自定义的建模文件,您需要添加trust_remote_code=True以使用它。

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("ccdv/lsg-bert-base-uncased-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bert-base-uncased-4096")

参数

您可以更改各种参数,例如:

  • 全局标记的数量(num_global_tokens=1)
  • 本地块大小(block_size=128)
  • 稀疏块大小(sparse_block_size=128)
  • 稀疏因子(sparsity_factor=2)
  • 遮盖第一个标记(因为它与第一个全局标记重复)
  • 请参见config.json文件

默认参数在实践中效果很好。如果内存不足,请减小块大小,增加稀疏因子,并消除注意力分数矩阵中的丢失。

from transformers import AutoModel

model = AutoModel.from_pretrained("ccdv/lsg-bert-base-uncased-4096", 
    trust_remote_code=True, 
    num_global_tokens=16,
    block_size=64,
    sparse_block_size=64,
    attention_probs_dropout_prob=0.0
    sparsity_factor=4,
    sparsity_type="none",
    mask_first_token=True
)

稀疏选择类型

有5种不同的稀疏选择模式。最佳类型取决于任务。 请注意,对于长度小于2 *块大小的序列,类型无效。

  • 稀疏类型="norm",选择最高范数的标记
    • 最适合较小的稀疏因子(2到4)
    • 附加参数:
  • 稀疏类型="pooling",使用平均池化合并标记
    • 最适合较小的稀疏因子(2到4)
    • 附加参数:
  • 稀疏类型="lsh",使用LSH算法将相似的标记聚类
    • 最适合较大的稀疏因子(4+)
    • LSH依赖于随机投影,因此使用不同的种子进行推断可能会稍有不同
    • 附加参数:
      • lsg_num_pre_rounds=1,在计算质心之前预合并标记n次
  • 稀疏类型="stride",使用每个头部的步幅机制
    • 每个头部将使用以sparsify_factor为步幅的不同标记
    • 如果sparsify_factor > num_heads,则不建议使用
  • 稀疏类型="block_stride",使用每个头部的块步幅机制
    • 每个头部将使用以sparsify_factor为步幅的块标记
    • 如果sparsify_factor > num_heads,则不建议使用

任务

填充掩码示例:

from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer

model = AutoModelForMaskedLM.from_pretrained("ccdv/lsg-bert-base-uncased-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bert-base-uncased-4096")

SENTENCES = "Paris is the [MASK] of France."
pipeline = FillMaskPipeline(model, tokenizer)
output = pipeline(SENTENCES)

> 'Paris is the capital of France.'

分类示例:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-bert-base-uncased-4096", 
    trust_remote_code=True, 
    pool_with_global=True, # pool with a global token instead of first token
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bert-base-uncased-4096")

SENTENCE = "This is a test for sequence classification. " * 300
token_ids = tokenizer(
    SENTENCE, 
    return_tensors="pt", 
    #pad_to_multiple_of=... # Optional
    truncation=True
    )
output = model(**token_ids)

> SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

训练全局标记

仅训练全局标记和分类头部:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-bert-base-uncased-4096", 
    trust_remote_code=True, 
    pool_with_global=True, # pool with a global token instead of first token
    num_global_tokens=16
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bert-base-uncased-4096")

for name, param in model.named_parameters():
    if "global_embeddings" not in name:
        param.requires_grad = False
    else:
        param.required_grad = True

BERT

@article{DBLP:journals/corr/abs-1810-04805,
  author    = {Jacob Devlin and
               Ming{-}Wei Chang and
               Kenton Lee and
               Kristina Toutanova},
  title     = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language
               Understanding},
  journal   = {CoRR},
  volume    = {abs/1810.04805},
  year      = {2018},
  url       = {http://arxiv.org/abs/1810.04805},
  archivePrefix = {arXiv},
  eprint    = {1810.04805},
  timestamp = {Tue, 30 Oct 2018 20:39:56 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1810-04805.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}