该存储库提供了一个基础大小的日本语RoBERTa模型。该模型是使用GitHub存储库 rinnakk/japanese-pretrained-models 中的代码由 rinna Co., Ltd. 进行训练的
from transformers import AutoTokenizer, AutoModelForMaskedLM tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-roberta-base", use_fast=False) tokenizer.do_lower_case = True # due to some bug of tokenizer config loading model = AutoModelForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
要预测一个遮蔽标记,请确保在句子之前添加一个[CLS]标记,以便模型能够正确编码它,因为它在模型训练过程中被使用。
A) 直接在输入字符串中输入 [MASK],B) 在标记化之后用 [MASK] 替换一个标记,将会产生不同的标记序列,因此会得到不同的预测结果。在标记化之后使用 [MASK] 更合适(因为它与模型预训练的方式一致)。然而,Huggingface推理API只支持在输入字符串中输入 [MASK],并产生不太稳定的预测。
当为Roberta*模型提供position_ids时,Huggingface的transformers库会自动构建它,但是从padding_idx而不是0开始(请参见 issue 和Huggingface的create_position_ids_from_input_ids()函数),这与rinna/japanese-roberta-base不符合预期,因为相应tokenizer的padding_idx不是0。因此,请确保自己构造position_ids,并使其从位置ID 0 开始。
这是一个示例,说明我们的模型如何作为一个遮蔽语言模型工作。注意以下代码示例与运行Huggingface Inference API之间的区别。
# original text text = "4年に1度オリンピックは開かれる。" # prepend [CLS] text = "[CLS]" + text # tokenize tokens = tokenizer.tokenize(text) print(tokens) # output: ['[CLS]', '▁4', '年に', '1', '度', 'オリンピック', 'は', '開かれる', '。'] # mask a token masked_idx = 5 tokens[masked_idx] = tokenizer.mask_token print(tokens) # output: ['[CLS]', '▁4', '年に', '1', '度', '[MASK]', 'は', '開かれる', '。'] # convert to ids token_ids = tokenizer.convert_tokens_to_ids(tokens) print(token_ids) # output: [4, 1602, 44, 24, 368, 6, 11, 21583, 8] # convert to tensor import torch token_tensor = torch.LongTensor([token_ids]) # provide position ids explicitly position_ids = list(range(0, token_tensor.size(1))) print(position_ids) # output: [0, 1, 2, 3, 4, 5, 6, 7, 8] position_id_tensor = torch.LongTensor([position_ids]) # get the top 10 predictions of the masked token with torch.no_grad(): outputs = model(input_ids=token_tensor, position_ids=position_id_tensor) predictions = outputs[0][0, masked_idx].topk(10) for i, index_t in enumerate(predictions.indices): index = index_t.item() token = tokenizer.convert_ids_to_tokens([index])[0] print(i, token) """ 0 総会 1 サミット 2 ワールドカップ 3 フェスティバル 4 大会 5 オリンピック 6 全国大会 7 党大会 8 イベント 9 世界選手権 """
一个12层,768隐藏大小的基于Transformer的遮蔽语言模型。
该模型在 Japanese CC-100 和 Japanese Wikipedia 上训练,以优化遮蔽语言建模目标,在8台V100 GPU上进行了约15天的训练。在从CC-100中抽样的开发集上达到约3.9的困惑度。
该模型使用基于 sentencepiece 的标记器,词汇表是使用官方的sentencepiece训练脚本在日本维基百科上训练的。