模型:

tinkoff-ai/ruDialoGPT-medium

英文

此代模型是基于 sberbank-ai/rugpt3medium_based_on_gpt2 的。它是根据大量的对话数据进行训练的,可以用于构建生成式对话代理。

此模型的上下文大小为3。

我们在一份私人验证集上计算了 this paper 中介绍的指标:

  • 适宜性:Crowdsourcers被问及模型的回答是否在给定的情境下有意义
  • 特异性:Crowdsourcers被问及模型的回答是否对给定的情境具体,换句话说,我们不希望我们的模型给出泛泛而论的回答
  • SSA,即上述两个指标的平均值(适宜性特异性平均)
sensibleness specificity SSA
1233321 0.64 0.5 0.57
1234321 0.78 0.69 0.735

如何使用:

import torch
from transformers import AutoTokenizer, AutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
model = AutoModelWithLMHead.from_pretrained('tinkoff-ai/ruDialoGPT-medium')
inputs = tokenizer('@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@', return_tensors='pt')
generated_token_ids = model.generate(
    **inputs,
    top_k=10,
    top_p=0.95,
    num_beams=3,
    num_return_sequences=3,
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=1.2,
    repetition_penalty=1.2,
    length_penalty=1.0,
    eos_token_id=50257,
    max_new_tokens=40
)
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
context_with_response