DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary

Model description

This model was trained on 782,357 hypothesis-premise pairs from 4 natural language inference (NLI) datasets: MultiNLI, Fever-NLI, LingNLI and ANLI.

Note that the model was trained on binary NLI to predict either "entailment" or "not entailment". This makes it specifically suitable for zero-shot classification, where the difference between "neutral" and "contradiction" is irrelevant.

The base model is DeBERTa-v3-xsmall from Microsoft. The v3 variant of DeBERTa substantially outperforms previous versions of the model by including a different pre-training objective; see the DeBERTa-V3 paper.

For highest performance (but lower speed), it is recommended to use https://huggingface.co/MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli.

Intended uses & limitations

How to use the model
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)  # move the model to the same device as the inputs

premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was good."

inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
output = model(**inputs)  # pass input_ids and attention_mask together
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "not_entailment"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
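
Since the model collapses "neutral" and "contradiction" into a single "not_entailment" class, zero-shot classification reduces to comparing entailment probabilities across candidate labels. The snippet below is a minimal sketch reusing the model and tokenizer loaded above; the hypothesis template and candidate labels are illustrative assumptions, not part of the original model card.

text = "The new graphics card renders 4K games at 120 frames per second."
candidate_labels = ["technology", "politics", "sports"]  # hypothetical labels

scores = {}
for label in candidate_labels:
    hypo = f"This example is about {label}."  # assumed hypothesis template
    enc = tokenizer(text, hypo, truncation=True, return_tensors="pt").to(device)
    probs = torch.softmax(model(**enc)["logits"][0], -1)
    scores[label] = float(probs[0])  # index 0 = "entailment"

print(max(scores, key=scores.get), scores)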

Training data

The model was trained on 782,357 hypothesis-premise pairs from 4 NLI datasets: MultiNLI, Fever-NLI, LingNLI and ANLI.

Training procedure

DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary was trained using the Hugging Face trainer with the hyperparameters below.

from transformers import TrainingArguments

training_args = TrainingArguments(
    num_train_epochs=5,              # total number of training epochs
    learning_rate=2e-05,
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_ratio=0.1,                # ratio of total steps used for learning rate warmup
    weight_decay=0.06,               # strength of weight decay
    fp16=True                        # mixed precision training
)
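
For context, the following sketch shows how these arguments would plug into the Hugging Face Trainer; train_dataset and eval_dataset are placeholders for the tokenized hypothesis-premise pairs, not code from the original training script.

from transformers import Trainer

trainer = Trainer(
    model=model,                  # AutoModelForSequenceClassification with 2 labels
    args=training_args,           # the TrainingArguments defined above
    train_dataset=train_dataset,  # placeholder: tokenized training pairs
    eval_dataset=eval_dataset,    # placeholder: tokenized evaluation pairs
    tokenizer=tokenizer,
)
trainer.train()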

Evaluation results

The model was evaluated using the binary test sets for MultiNLI, ANLI and LingNLI and the binary dev set for Fever-NLI. The metric used is accuracy.

| dataset                                     | mnli-m-2c | mnli-mm-2c | fever-nli-2c | anli-all-2c | anli-r3-2c | lingnli-2c |
| :------------------------------------------ | :-------- | :--------- | :----------- | :---------- | :--------- | :--------- |
| accuracy                                    | 0.925     | 0.922      | 0.892        | 0.676       | 0.665      | 0.888      |
| speed (text/sec, CPU, 128 batch)            | 6.0       | 6.3        | 3.0          | 5.8         | 5.0        | 7.6        |
| speed (text/sec, GPU Tesla P100, 128 batch) | 473       | 487        | 230          | 390         | 340        | 586        |
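
The speed rows report throughput for batched inference. The timing sketch below shows one way such a number can be approximated with the model loaded earlier; the batch of repeated pairs is an assumption for illustration, not the author's benchmark script.

import time

premises = [premise] * 128    # assumed: one batch of 128 identical pairs
hypotheses = [hypothesis] * 128
enc = tokenizer(premises, hypotheses, truncation=True, padding=True, return_tensors="pt").to(device)

start = time.time()
with torch.no_grad():
    model(**enc)
print(f"{len(premises) / (time.time() - start):.1f} texts/sec")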

Limitations and bias

Please consult the original DeBERTa paper and the literature on the different NLI datasets for potential biases.

Citation

If you use this model, please cite: Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. ‘Less Annotating, More Classifying – Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT-NLI’. Preprint, June. Open Science Framework. https://osf.io/74b8k.

Ideas for cooperation or questions?

If you have questions or ideas for cooperation, contact me at m{dot}laurer{at}vu{dot}nl or on LinkedIn.

Debugging and issues

Note that DeBERTa-v3 was released on 06.12.21 and older versions of HF Transformers seem to have issues running the model (e.g. problems with the tokenizer). Using Transformers >= 4.13 might solve some issues.
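
As a quick guard, the installed version can be checked before loading the model; a minimal sketch (the packaging module ships as a dependency of Transformers):

from packaging import version
import transformers

# Sketch: fail early if the installed Transformers version predates DeBERTa-v3 support.
assert version.parse(transformers.__version__) >= version.parse("4.13.0"), \
    "DeBERTa-v3 models need Transformers >= 4.13"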