中文

更新履歴

  • 2023年5月7日

    oasst1-89k-ja 」データセットを追加して 対話システム に対応しました。1024トークンまで会話履歴を保存できます。 前回のモデルで行った質疑応答の正答率は今回のモデルで下がりました。「日本で一番広い湖は?」が91%から89%、「世界で一番高い山は?」が84%から73%に下がりました。(対話は分けた方が良かったのか、それともoasst1の質が良くないとか)

  • 2023年4月13日

    japanese-gpt-1b 」モデルを「 databricks-dolly-15k-ja 」データセットで RLHF (人間のフィードバックからの強化学習)しました。

dolly-japanese-gpt-1b

1.3Bパラメータの日本語GPT-2モデルを使用した対話型のAIです。VRAM 7GB または RAM 7GB が必要で、問題なく動作すると思われます。

rinna社の「 japanese-gpt-1b 」を、 日本語データセット「 databricks-dolly-15k-ja 」、 「 oasst1-89k-ja 」、 「 OjousamaTalkScriptDataset 」、 「 train_data/zundamon.json 」 を使用して学習させました。

学習データやモデルを作成および配布してくださった方々に心から感謝申し上げます。

モデルの使用方法

モデルの読み込み

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)

ChatGPT/GPT-4によるサンプルコード(少し修正)

MAX_ASSISTANT_LENGTH = 100
MAX_INPUT_LENGTH = 1024
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'
USER_NAME = "User"
ASSISTANT_NAME = "Assistant"

def prepare_input(role_instruction, conversation_history, new_conversation):
    instruction = "".join([f"{text} " for text in role_instruction])
    instruction += " ".join(conversation_history)
    input_text = f"{USER_NAME}:{new_conversation}"

    return INPUT_PROMPT.format(instruction=instruction, input=input_text)

def format_output(output):
    output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
    return output

def generate_response(role_instruction, conversation_history, new_conversation):
    # 入力トークン数1024におさまるようにする
    for _ in range(8):
        input_text = prepare_input(role_instruction, conversation_history, new_conversation)
        token_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
        n = len(token_ids[0])
        if n + MAX_ASSISTANT_LENGTH <= MAX_INPUT_LENGTH:
            break
        else:
            conversation_history.pop(0)
            conversation_history.pop(0)

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            min_length=n,
            max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
            temperature=0.7,
            repetition_penalty=1.0, # 数値を大きくすると、文字列の繰り返しが減る
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_words_ids=[[tokenizer.unk_token_id]]
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    formatted_output_all = format_output(output)

    response = f"{ASSISTANT_NAME}:{formatted_output_all.split('応答:')[-1].strip()}"
    conversation_history.append(f"{USER_NAME}:{new_conversation}".replace("\n", "\\n"))
    conversation_history.append(response.replace("\n", "\\n"))

    return formatted_output_all, response 

role_instruction = [
    f"{USER_NAME}:きみは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットなのだ。一人称は「ボク」で語尾に「なのだー」を付けてしゃべるのだ。",
    f"{ASSISTANT_NAME}:了解したのだ。",
    f"{USER_NAME}:きみは同じ言葉を繰り返さず、何でも正確に要約して答えられるのだ。",
    f"{ASSISTANT_NAME}:了解したのだ。",
]

conversation_history = [
]

questions = [
    "日本で一番高い山は?",
    "日本で一番広い湖は?",
    "冗談を言ってください。",
    "世界で一番高い山は?",
    "世界で一番広い湖は?",
    "最初の質問は何ですか?",
    "今何問目?",
    "自己紹介をしてください。",
]

# 各質問に対して応答を生成して表示
for question in questions:
    formatted_output_all, response = generate_response(role_instruction, conversation_history, question)
    print(f"{USER_NAME}:{question}\n{response}\n---")

出力

User:日本で一番高い山は?
Assistant:富士山
---
User:日本で一番広い湖は?
Assistant:日本列島には、太平洋、日本海、オホーツク海、インド洋、大西洋など、さまざまな海があります。
---
User:冗談を言ってください。
Assistant:冗談を言おう!
---
User:世界で一番高い山は?
Assistant:世界で最も高い山はエベレストです。
---
User:世界で一番広い湖は?
Assistant:太平洋に浮かぶ世界最大の島、ツバル
---
User:最初の質問は何ですか?
Assistant:私が最初の質問をするとするならば、それは「ずんだもん」です。
---
User:今何問目?
Assistant:今何問目ですか?
---
User:自己紹介をしてください。
Assistant:私は「ずんだもん」というあだ名で呼ばれています。
---

ChatGPT/GPT-4による説明(少し修正)

このコードは、質問に答えるAIアシスタントを実装しています。質問リストに対して、役割指示に従った応答を生成し、会話を表示します。

評価

1000回の「入力」のような質問を行い、それらに対する「応答」に正解の文字列が含まれるかで評価しています。 一番正答率が高い10エポック目のモデルを選択しました。(やり過ぎたかもしれないです。)

入力 応答 正答率[%]
日本で一番広い湖は? 琵琶湖 89
世界で一番高い山は? エベレスト 73

学習データのフォーマット

alpaca と同じように、以下のようなフォーマットにしています。

<s> 
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
[SEP] 
指示:
User:きみは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットなのだ。一人称は「ボク」で語尾に「なのだー」を付けてしゃべるのだ。 Assistant:了解したのだ。 User:きみは同じ言葉を繰り返さず、何でも正確に要約して答えられるのだ。 Assistant:了解したのだ。 
[SEP] 
入力:
User:日本で一番高い山は?
[SEP] 
応答:
富士山
</s>

transformersのコードでtxtファイルを学習する場合、1データ1行のようなので改行コードを一旦 \n に置き換えています。 学習データは dolly-oasst1-ja.txt です。

また学習データを作った過程のスクリプトとjsonファイルも train_data に置いておきます。

作成時のスクリプトと作成手順を記載します。

  • make_json_from_oasst1_ja.py スクリプトで oasst1_ja.json ファイルを作成
  • oasst1_ja.json ファイル、 databricks-dolly-15k-ja.json ファイル、 ojousamatalkscript200.json ファイル、 zundamon.json ファイルから merge_json.py スクリプトで一つのjsonファイルにマージ
  • マージしたjsonファイルから make_train_data_from_merged_json.py スクリプトで dolly-oasst1-ja.txt を作成
  • になります。

    学習のハイパーパラメータ

    学習時には以下のハイパーパラメータを使用:

    ※VRAMが足りない場合、optimをadafactorにするとVRAM使用量が減りました。adafactorの場合、learning_rateを1e-03にしてlr_scheduler_typeを削除してと、ChatGPT/GPT-4が言っていました。

    venv/Scripts/python.exe transformers/examples/pytorch/language-modeling/run_clm.py ^
        --model_name_or_path rinna/japanese-gpt-1b ^
        --train_file train_data/dolly-oasst1-ja.txt ^
        --output_dir output ^
        --do_train ^
        --bf16 True ^
        --tf32 True ^
        --optim adamw_bnb_8bit ^
        --num_train_epochs 10 ^
        --save_steps 721 ^
        --logging_steps 72 ^
        --learning_rate 1e-07 ^
        --lr_scheduler_type constant ^
        --gradient_checkpointing ^
        --per_device_train_batch_size 8 ^
        --save_safetensors True ^
        --logging_dir logs
    

    ライブラリのバージョン

    • Transformers 4.28.1
    • Pytorch 2.0.0+cu117
    • Datasets 2.11.0
    • Tokenizers 0.13.3
    • bitsandbytes 0.37.2

    ライセンス

    MITで大丈夫そうです。