英文

GPT-2 Base Thai

GPT-2 Base Thai是基于 OpenAI GPT-2 模型的因果语言模型。它是在 OSCAR 数据集上进行训练的,具体是unshuffled_deduplicated_th子集。该模型是从头开始训练的,评估损失为1.708,评估困惑度为5.516。

此模型是使用HuggingFace的Flax框架进行训练的,并且是HuggingFace组织的 JAX/Flax Community Week 的一部分。所有的训练是在由Google Cloud团队赞助的TPUv3-8 VM上进行的。

所有用于训练的必要脚本可以在 Files and versions 标签中找到,以及通过Tensorboard记录的 Training metrics

模型

Model #params Arch. Training/Validation data (text)
gpt2-base-thai 124M GPT-2 unshuffled_deduplicated_th Dataset

评估结果

该模型经过3个时代的训练,以下是训练结束后的最终结果。

train loss valid loss valid PPL total time
1.638 1.708 5.516 6:12:34

使用方法

作为因果语言模型

from transformers import pipeline

pretrained_name = "flax-community/gpt2-base-thai"

nlp = pipeline(
    "text-generation",
    model=pretrained_name,
    tokenizer=pretrained_name
)

nlp("สวัสดีตอนเช้า")

在PyTorch中进行特征提取

from transformers import GPT2Model, GPT2TokenizerFast

pretrained_name = "flax-community/gpt2-base-thai"
model = GPT2Model.from_pretrained(pretrained_name)
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_name)

prompt = "สวัสดีตอนเช้า"
encoded_input = tokenizer(prompt, return_tensors='pt')
output = model(**encoded_input)

团队成员