GPT-2 Base Thai是基于 OpenAI GPT-2 模型的因果语言模型。它是在 OSCAR 数据集上进行训练的,具体是unshuffled_deduplicated_th子集。该模型是从头开始训练的,评估损失为1.708,评估困惑度为5.516。
此模型是使用HuggingFace的Flax框架进行训练的,并且是HuggingFace组织的 JAX/Flax Community Week 的一部分。所有的训练是在由Google Cloud团队赞助的TPUv3-8 VM上进行的。
所有用于训练的必要脚本可以在 Files and versions 标签中找到,以及通过Tensorboard记录的 Training metrics 。
Model | #params | Arch. | Training/Validation data (text) |
---|---|---|---|
gpt2-base-thai | 124M | GPT-2 | unshuffled_deduplicated_th Dataset |
该模型经过3个时代的训练,以下是训练结束后的最终结果。
train loss | valid loss | valid PPL | total time |
---|---|---|---|
1.638 | 1.708 | 5.516 | 6:12:34 |
from transformers import pipeline pretrained_name = "flax-community/gpt2-base-thai" nlp = pipeline( "text-generation", model=pretrained_name, tokenizer=pretrained_name ) nlp("สวัสดีตอนเช้า")
from transformers import GPT2Model, GPT2TokenizerFast pretrained_name = "flax-community/gpt2-base-thai" model = GPT2Model.from_pretrained(pretrained_name) tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_name) prompt = "สวัสดีตอนเช้า" encoded_input = tokenizer(prompt, return_tensors='pt') output = model(**encoded_input)