该模型使用了来自8个自然语言推理(NLI)数据集的1,279,665个假设-前提对进行训练,其中包括 MultiNLI 、 Fever-NLI 、 LingNLI 和 DocNLI (包括 ANLI 、QNLI、DUC、CNN/DailyMail和Curation)。
这是模型库中唯一经过8个NLI数据集训练的模型,包括使用超长文本进行长距离推理的DocNLI。请注意,该模型通过将“neural”和“contradiction”的类别合并为“not-entailment”来创建更多的训练数据。
基本模型是 DeBERTa-v3-small from Microsoft 。DeBERTa的v3变体通过使用不同的预训练目标显著优于之前的模型版本,请参阅原始 DeBERTa paper 的附录11以及 DeBERTa-V3 paper 。
from transformers import AutoTokenizer, AutoModelForSequenceClassification import torch model_name = "MoritzLaurer/DeBERTa-v3-small-mnli-fever-docnli-ling-2c" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing." hypothesis = "The movie was good." input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt") output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu" prediction = torch.softmax(output["logits"][0], -1).tolist() label_names = ["entailment", "neutral", "contradiction"] prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)} print(prediction)
该模型使用了来自8个NLI数据集的1,279,665个假设-前提对进行训练,其中包括 MultiNLI 、 Fever-NLI 、 LingNLI 和 DocNLI (包括 ANLI 、QNLI、DUC、CNN/DailyMail和Curation)。
DeBERTa-v3-small-mnli-fever-docnli-ling-2c使用Hugging Face训练器进行训练,使用了以下超参数。
training_args = TrainingArguments( num_train_epochs=3, # total number of training epochs learning_rate=2e-05, per_device_train_batch_size=32, # batch size per device during training per_device_eval_batch_size=32, # batch size for evaluation warmup_ratio=0.1, # number of warmup steps for learning rate scheduler weight_decay=0.06, # strength of weight decay fp16=True # mixed precision training )
该模型使用MultiNLI和ANLI的二进制测试集以及Fever-NLI的二进制开发集进行评估(两个类别而不是三个类别)。所使用的度量标准是准确度。
mnli-m-2c | mnli-mm-2c | fever-nli-2c | anli-all-2c | anli-r3-2c |
---|---|---|---|---|
0.927 | 0.921 | 0.892 | 0.684 | 0.673 |
关于潜在偏见,请参考原始DeBERTa论文和不同NLI数据集的文献。
如果您要引用此模型,请引用原始的DeBERTa论文、相关的NLI数据集,并包含Hugging Face模型库上该模型的链接。
如果您有任何问题或合作想法,请通过m{dot}laurer{at}vu{dot}nl或 LinkedIn 与我联系。
请注意,DeBERTa-v3是最近发布的,较旧版本的HF Transformers似乎存在运行该模型的问题(例如与标记器相关的问题)。使用Transformers==4.13可能可以解决部分问题。