本存储库提供了一个36亿参数的日语GPT-NeoX模型。该模型基于 rinna/japanese-gpt-neox-3.6b-instruction-sft-v2 并经过对齐,可用作指令跟随的对话系统。
模型架构
这是一个36层,2816隐藏单元的基于Transformer的语言模型。
RLHF
模型的行为经过强化学习来对齐输入指令。特别地,模型经过两个阶段的训练,即有监督微调(SFT)和基于 PPO 的强化学习(RL)。
PPO vs SFT评估
我们对100个提示进行了人工评估和基于ChatGPT的自动评估,以评估通过强化学习获得的性能提升。
1239321 vs. 12310321 | win | tie | loss |
---|---|---|---|
Human evaluation | 47 % | 30% | 23% |
ChatGPT auto. evaluation | 63 % | 3% | 34% |
强化学习
我们使用了 CarperAI/trlx 和其实现的PPO算法进行RL阶段的训练。
RL数据是以下数据集的子集,并已翻译为日语。
模型系列
Variant | Link |
---|---|
3.6B PPO | 12313321 |
3.6B SFT-v2 | 12314321 |
3.6B SFT | 12315321 |
3.6B pretrained | 12316321 |
作者
我们采用了一种特殊的格式来构造输入。
以下是一个构造输入的示例。
prompt = [ { "speaker": "ユーザー", "text": "コンタクトレンズを慣れるにはどうすればよいですか?" }, { "speaker": "システム", "text": "これについて具体的に説明していただけますか?何が難しいのでしょうか?" }, { "speaker": "ユーザー", "text": "目が痛いのです。" }, { "speaker": "システム", "text": "分かりました、コンタクトレンズをつけると目がかゆくなるということですね。思った以上にレンズを外す必要があるでしょうか?" }, { "speaker": "ユーザー", "text": "いえ、レンズは外しませんが、目が赤くなるんです。" } ] prompt = [ f"{uttr['speaker']}: {uttr['text']}" for uttr in prompt ] prompt = "<NL>".join(prompt) prompt = ( prompt + "<NL>" + "システム: " ) print(prompt) # "ユーザー: コンタクトレンズを慣れるにはどうすればよいですか?<NL>システム: これについて具体的に説明していただけますか?何が難しいのでしょうか?<NL>ユーザー: 目が痛いのです。<NL>システム: 分かりました、コンタクトレンズをつけると目がかゆくなるということですね。思った以上にレンズを外す必要があるでしょうか?<NL>ユーザー: いえ、レンズは外しませんが、目が赤くなるんです。<NL>システム: "
import torch from transformers import AutoTokenizer, AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False) model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo") if torch.cuda.is_available(): model = model.to("cuda") token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") with torch.no_grad(): output_ids = model.generate( token_ids.to(model.device), do_sample=True, max_new_tokens=128, temperature=0.7, repetition_penalty=1.1, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id ) output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):]) output = output.replace("<NL>", "\n") print(output) """それは、コンタクトレンズが目に合わないために起こることがあります。レンズが目の表面に長時間触れ続けることが原因となることがあります。また、コンタクトレンズが汚れている可能性もあります。コンタクトレンズケースを定期的に洗浄したり、コンタクトレンズを正しくフィットさせるようにしたりすることが役立ちます。</s>"""
该模型使用了基于 sentencepiece 的分词器。
print(tokenizer.tokenize("吾輩は猫である")) # ['吾', '輩', 'は', '猫', 'である'] # instead of ['▁', '吾', '輩', 'は', '猫', 'である'] as in rinna/japanese-gpt-1b
print(tokenizer.tokenize(" 吾輩は 猫である ")) # ['▁', '▁', '吾', '輩', 'は', '▁', '▁', '猫', 'である', '▁', '▁', '▁'] # instead of ['▁', '吾', '輩', 'は', '▁猫', 'である'] as in rinna/japanese-gpt-1b
good_tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b", use_fast=False) bad_tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b") print(good_tokenizer.decode(good_tokenizer.encode("გამარჯობა 吾輩は 猫である "))) # 'გამარჯობა 吾輩は 猫である </s>' print(bad_tokenizer.decode(bad_tokenizer.encode("გამარჯობა 吾輩は 猫である "))) # 'გამარ[UNK]ობა 吾輩は 猫である </s>'