模型:
aubmindlab/aragpt2-large
您可以在我们的论文中找到更多信息 AraGPT2
此存储库中的代码被用于训练所有的GPT2变体。代码支持使用TPUEstimator API在GPU和TPU上训练和微调GPT2。
GPT2-base和medium使用来自gpt2文件夹的代码,并且可以使用 minimaxir/gpt-2-simple 仓库中的代码训练模型。这些模型使用lamb优化器进行训练,遵循与gpt2相同的架构,与transformers库完全兼容。
GPT2-large和GPT2-mega使用了 imcaspar/gpt2-ml 库,并遵循grover架构。您可以使用grover/modeling_gpt2.py中的pytorch类直接替换transformers库中的类(它应该支持transformers的v4.x版本)。这两个模型都使用adafactor优化器进行训练,因为adam和lamb优化器内存使用过多,导致模型无法适应TPU核心的1批次。
AraGPT2训练使用的是与AraBERTv2相同的大型阿拉伯语数据集。
from transformers import GPT2TokenizerFast, pipeline #for base and medium from transformers import GPT2LMHeadModel #for large and mega # pip install arabert from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel from arabert.preprocess import ArabertPreprocessor MODEL_NAME='aubmindlab/aragpt2-large' arabert_prep = ArabertPreprocessor(model_name=MODEL_NAME) text="" text_clean = arabert_prep.preprocess(text) model = GPT2LMHeadModel.from_pretrained(MODEL_NAME) tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME) generation_pipeline = pipeline("text-generation",model=model,tokenizer=tokenizer) #feel free to try different decoding settings generation_pipeline(text, pad_token_id=tokenizer.eos_token_id, num_beams=10, max_length=200, top_p=0.9, repetition_penalty = 3.0, no_repeat_ngram_size = 3)[0]['generated_text'] >>>
请按照链接 here 中的指南进行操作
创建训练TFRecords
python create_pretraining_data.py --input_file=<RAW TEXT FILE with documents/article separated by an empty line> --output_file=<OUTPUT TFRecord> --tokenizer_dir=<Directory with the GPT2 Tokenizer files>
微调
python3 run_pretraining.py \\\r\n --input_file="gs://<GS_BUCKET>/pretraining_data/*" \\\r\n --output_dir="gs://<GS_BUCKET>/pretraining_model/" \\\r\n --config_file="config/small_hparams.json" \\\r\n --batch_size=128 \\\r\n --eval_batch_size=8 \\\r\n --num_train_steps= \\\r\n --num_warmup_steps= \\\r\n --learning_rate= \\\r\n --save_checkpoints_steps= \\\r\n --max_seq_length=1024 \\\r\n --max_eval_steps= \\\r\n --optimizer="lamb" \\\r\n --iterations_per_loop=5000 \\\r\n --keep_checkpoint_max=10 \\\r\n --use_tpu=True \\\r\n --tpu_name=<TPU NAME> \\\r\n --do_train=True \\\r\n --do_eval=False
Model | Optimizer | Context size | Embedding Size | Num of heads | Num of layers | Model Size / Num of Params |
---|---|---|---|---|---|---|
AraGPT2-base | lamb | 1024 | 768 | 12 | 12 | 527MB/135M |
AraGPT2-medium | lamb | 1024 | 1024 | 16 | 24 | 1.38G/370M |
AraGPT2-large | adafactor | 1024 | 1280 | 20 | 36 | 2.98GB/792M |
AraGPT2-mega | adafactor | 1024 | 1536 | 25 | 48 | 5.5GB/1.46B |
所有模型都可以在HuggingFace模型页面上以 aubmindlab 名称找到。检查点可以以PyTorch、TF2和TF1格式使用。
有关数据集来源,请参见数据集部分
Model | Hardware | num of examples (seq len = 1024) | Batch Size | Num of Steps | Time (in days) |
---|---|---|---|---|---|
AraGPT2-base | TPUv3-128 | 9.7M | 1792 | 125K | 1.5 |
AraGPT2-medium | TPUv3-8 | 9.7M | 1152 | 85K | 1.5 |
AraGPT2-large | TPUv3-128 | 9.7M | 256 | 220k | 3 |
AraGPT2-mega | TPUv3-128 | 9.7M | 256 | 780K | 9 |
用于新AraBERT模型的预训练数据也用于GPT2和ELECTRA
该数据集包含77GB或200,095,961行或8,655,948,860个单词或82,232,988,358个字符(在应用Farasa分词之前)
对于新数据集,我们将经过彻底过滤的未随机排列的OSCAR语料库添加到了之前用于AraBERTv1的数据集中,但我们删除了先前爬取的网站:
GPT2 Arabic生成的文本是由训练有素的神经网络模型自动生成的,该模型在大量文本上进行了训练,不代表作者或其机构的官方态度和偏好。GPT2 Arabic生成的文本仅应用于研究和科学目的。如果侵犯了您的权益和利益,或违反了社会道德,请不要传播。
@inproceedings{antoun-etal-2021-aragpt2, title = "{A}ra{GPT}2: Pre-Trained Transformer for {A}rabic Language Generation", author = "Antoun, Wissam and Baly, Fady and Hajj, Hazem", booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop", month = apr, year = "2021", address = "Kyiv, Ukraine (Virtual)", publisher = "Association for Computational Linguistics", url = "https://www.aclweb.org/anthology/2021.wanlp-1.21", pages = "196--207", }
感谢TensorFlow Research Cloud(TFRC)提供免费访问Cloud TPUs的机会,没有这个计划我们无法完成这项工作,还要感谢 AUB MIND Lab 成员对我们的持续支持。同时感谢 Yakshof 和Assafir提供数据和存储访问的支持。还要感谢Habib Rahal( https://www.behance.net/rahalhabib ),为AraBERT代表提供了面孔。
Wissam Antoun : Linkedin | Twitter | Github | wfa07@mail.aub.edu | wissam.antoun@gmail.com
Fady Baly : Linkedin | Twitter | Github | fgb06@mail.aub.edu | baly.fady@gmail.com