英文

LSG 模型

Transformers >= 4.23.1 此模型依赖于自定义的建模文件,您需要添加trust_remote_code=True 详见 #13467

LSG ArXiv paper . Github/转换脚本可在此 link 下找到。

  • 使用方法
  • 参数
  • 稀疏选择类型
  • 任务
  • 训练全局令牌

该模型是根据 XLM-RoBERTa-base 模型进行了适应,尚未进行额外的预训练。它使用相同数量的参数/层和相同的分词器。

该模型可以处理长序列,但比 Longformer 或 BigBird(来自 Transformers)更快且更高效,并依赖于本地 + 稀疏 + 全局注意力 (LSG)。

该模型要求序列长度为块大小的倍数。模型是“自适应的”,如果需要会自动填充序列(在配置中设置adaptive=True)。然而,建议使用分词器截断输入(截断=True),并在需要时可使用块大小的倍数进行填充(pad_to_multiple_of=...)。

支持编码器-解码器,但我尚未进行详尽测试。 使用 PyTorch 实现。

使用方法

该模型依赖于自定义的建模文件,需要添加trust_remote_code=True才能使用。

from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")

参数

您可以更改各种参数,如:

  • 全局令牌的数量 (num_global_tokens=1)
  • 本地块大小 (block_size=128)
  • 稀疏块大小 (sparse_block_size=128)
  • 稀疏因子 (sparsity_factor=2)
  • 遮罩第一个令牌 (遮罩第一个令牌,因为它与第一个全局令牌重复)
  • 参见 config.json 文件

默认参数在实践中表现良好。如果内存不足,请缩小块大小,增加稀疏因子,并删除注意力得分矩阵中的丢失比例。

from transformers import AutoModel

model = AutoModel.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", 
    trust_remote_code=True, 
    num_global_tokens=16,
    block_size=64,
    sparse_block_size=64,
    attention_probs_dropout_prob=0.0
    sparsity_factor=4,
    sparsity_type="none",
    mask_first_token=True
)

稀疏选择类型

有5种不同的稀疏选择模式。最佳类型取决于任务。 需注意,对于长度小于 2*block_size 的序列,类型没有影响。

  • 稀疏类型="norm",选择最高范数的令牌
    • 对于小的稀疏因子(2到4),效果最好
    • 其他参数:
  • 稀疏类型="pooling",使用平均池化合并令牌
    • 对于小的稀疏因子(2到4),效果最好
    • 其他参数:
  • 稀疏类型="lsh",使用 LSH 算法来聚类相似的令牌
    • 对于大的稀疏因子(4+),效果最好
    • LSH 基于随机投影,因此不同的种子可能导致推理结果略有不同
    • 其他参数:
      • lsg_num_pre_rounds=1,在计算质心前将令牌合并 n 次
  • 稀疏类型="stride",每个头使用分步机制
    • 每个头都会使用不同的以 sparsify_factor 为步长的令牌
    • 如果 sparsify_factor > num_heads,则不推荐使用
  • 稀疏类型="block_stride",每个头使用分块步进机制
    • 每个头都会使用以 sparsify_factor 为步长的令牌块
    • 如果 sparsify_factor > num_heads,则不推荐使用

任务

填充遮罩示例:

from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer

model = AutoModelForMaskedLM.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")

SENTENCES = ["Paris is the <mask> of France."]
pipeline = FillMaskPipeline(model, tokenizer)
output = pipeline(SENTENCES, top_k=1)
    
output = [o[0]["sequence"] for o in output]
> ['Paris is the capital of France.']

分类示例:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", 
    trust_remote_code=True, 
    pool_with_global=True, # pool with a global token instead of first token
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")

SENTENCE = "This is a test for sequence classification. " * 300
token_ids = tokenizer(
    SENTENCE, 
    return_tensors="pt", 
    #pad_to_multiple_of=... # Optional
    truncation=True
    )
output = model(**token_ids)

> SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

训练全局令牌

仅训练全局令牌和分类头部:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", 
    trust_remote_code=True, 
    pool_with_global=True, # pool with a global token instead of first token
    num_global_tokens=16
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")

for name, param in model.named_parameters():
    if "global_embeddings" not in name:
        param.requires_grad = False
    else:
        param.required_grad = True

XLM-RoBERTa

@article{DBLP:journals/corr/abs-2105-00572,
  author    = {Naman Goyal and
               Jingfei Du and
               Myle Ott and
               Giri Anantharaman and
               Alexis Conneau},
  title     = {Larger-Scale Transformers for Multilingual Masked Language Modeling},
  journal   = {CoRR},
  volume    = {abs/2105.00572},
  year      = {2021},
  url       = {https://arxiv.org/abs/2105.00572},
  eprinttype = {arXiv},
  eprint    = {2105.00572},
  timestamp = {Wed, 12 May 2021 15:54:31 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2105-00572.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}