Model:
ITG/DialoGPT-medium-spanish-chitchat
This is a transformer-decoder model adapted for the single-turn dialogue task in Spanish. We fine-tuned the 345M-parameter model from Microsoft with a causal language modeling (CLM) objective.
We used one of the datasets available in the Bot Framework Tools repository: the professional-styled personality chat dataset in Spanish, which is available to download here.
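For orientation, the sketch below shows one way such question-answer pairs could be turned into single-turn CLM training examples in the DialoGPT style (`question<eos>answer<eos>`). The TSV layout and the `Question`/`Answer` column names are assumptions for illustration, not a description of our exact pipeline.

```python
# Hypothetical preprocessing sketch: build single-turn CLM examples in the
# DialoGPT style, i.e. "<question><eos><answer><eos>".
# The TSV layout and the "Question"/"Answer" column names are assumptions,
# not the exact format of the Bot Framework personality chat files.
import csv

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

def build_examples(tsv_path):
    examples = []
    with open(tsv_path, encoding="utf-8") as f:
        for row in csv.DictReader(f, delimiter="\t"):
            # The EOS token separates the user turn from the answer, so the
            # model learns to generate the answer after a single question.
            text = row["Question"] + tokenizer.eos_token + row["Answer"] + tokenizer.eos_token
            examples.append(tokenizer(text)["input_ids"])
    return examples
```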
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CHAT_TURNS = 5
MAX_LENGTH = 1000

model = AutoModelForCausalLM.from_pretrained('ITG/DialoGPT-medium-spanish-chitchat')
tokenizer = AutoTokenizer.from_pretrained('ITG/DialoGPT-medium-spanish-chitchat')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for i in range(CHAT_TURNS):
    user_input = input(f"Step - {i} >> user prompt ->")
    with torch.no_grad():
        # User turn, where "user_input" is the question (single-turn dialogue task)
        user_inputs_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
        user_inputs_ids = user_inputs_ids.to(device)
        # The chat history adds the generated tokens for the answer
        chat_history = model.generate(user_inputs_ids, max_length=MAX_LENGTH, pad_token_id=tokenizer.eos_token_id)
        # Decode just the generated answer tokens (do not include the user prompt again)
        step_model_answer = tokenizer.decode(chat_history[:, user_inputs_ids.shape[-1]:][0], skip_special_tokens=True)
    print(f"Step - {i} >> DialoGPT-spanish model answer -> {step_model_answer}")
```
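The loop above decodes greedily. For more varied answers you could swap the `model.generate` call for a sampling variant; the values below are illustrative settings, not the ones used to produce the demo conversation.

```python
# Drop-in replacement for the model.generate(...) call above.
# Illustrative sampling settings, not the configuration used for the demo.
chat_history = model.generate(
    user_inputs_ids,
    max_length=MAX_LENGTH,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,    # sample from the distribution instead of argmax decoding
    top_k=50,          # restrict sampling to the 50 most likely tokens
    top_p=0.95,        # nucleus sampling over 95% of the probability mass
    temperature=0.8,   # temperatures below 1 make the distribution peakier
)
```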
Role | Message |
---|---|
User | Hola, buenos dias |
Bot | Buenos días |
User | ¿Qué tal estás? |
Bot | Fenomenal, gracias. |
User | ¡Tengo hambre! |
Bot | Quizá picar algo te ayude. |
User | Estoy triste |
Bot | Siento escuchar eso. |
User | ¿Tienes personalidad? |
Bot | Soy un bot. |
User | ¡Cuenta un chiste! |
Bot | En realidad no tengo tanta gracia. |
User | Jajajajajajajaja |
Bot | ¡Me alegro de que te rías! |
Hyper-parameter | Value |
---|---|
Validation partition (%) | 20% |
Training batch size | 8 |
Learning rate | 5e-4 |
Max training epochs | 20 |
Warmup training steps (%) | 6% |
Weight decay | 0.01 |
Optimiser (beta1, beta2, epsilon) | AdamW (0.9, 0.999, 1e-08) |
Early-stopping metric (min delta, patience) | Validation loss (0.1, 3)
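For reference, this is roughly how the table above could map onto a `transformers` `Trainer` configuration. It is a sketch under assumptions, not our training script: the base checkpoint, output path, and the toy dataset are placeholders.

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Toy stand-in data: in practice these would be the tokenized
# "question<eos>answer<eos>" examples built from the dataset.
texts = ["Hola" + tokenizer.eos_token + "Buenos días" + tokenizer.eos_token]
train_dataset = [tokenizer(t) for t in texts]
eval_dataset = train_dataset

args = TrainingArguments(
    output_dir="dialogpt-spanish-chitchat",  # placeholder path
    per_device_train_batch_size=8,           # training batch size
    learning_rate=5e-4,
    num_train_epochs=20,                     # max epochs; early stopping may end sooner
    warmup_ratio=0.06,                       # 6% warmup steps
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,                       # AdamW settings from the table
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",       # monitor validation loss
    greater_is_better=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # mlm=False gives the causal language modelling (CLM) objective
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3,
                                     early_stopping_threshold=0.1)],
)
trainer.train()
```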
If you want to fine-tune your own dialogue model, we recommend starting from the DialoGPT model. You can check out the original GitHub repository.
Since our approach assigns a probability to any Unicode string, we can evaluate our language model on any dataset regardless of pre-processing, tokenization, or vocabulary size.
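As a concrete illustration, the snippet below scores an arbitrary string with the model's CLM head. The example sentence is arbitrary, and the printed perplexity is whatever the model produces, not a reported result.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ITG/DialoGPT-medium-spanish-chitchat")
tokenizer = AutoTokenizer.from_pretrained("ITG/DialoGPT-medium-spanish-chitchat")
model.eval()

text = "Hola, ¿qué tal estás?"  # any Unicode string can be scored
input_ids = tokenizer.encode(text, return_tensors="pt")

with torch.no_grad():
    # With labels set to the inputs, the model returns the mean
    # cross-entropy (negative log-likelihood per token) of the sequence.
    loss = model(input_ids, labels=input_ids).loss

print(f"avg NLL per token: {loss.item():.3f}")
print(f"perplexity: {torch.exp(loss).item():.2f}")
```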