语言:
from typing import List, Optional, TypedDict
from dataclasses import dataclass
from itertools import chain

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


@dataclass
class H2PersonaChatHyperparametersV1:
    """Settings that control how the model input is assembled.

    model_name: str - model identifier; a name containing "t5" disables
        the leading BOS token when the input is built
    chat_history_pair_length: int - dialogue pairs amount from the end
    persona_max_length: int - truncation limit, in tokens, applied to each
        persona sentence
    chat_max_length: int - truncation limit, in tokens, applied to each
        history utterance
    debug_status: int - copied onto DialogBotV1.debug_status
    """

    model_name: str = "facebook/bart-base"
    chat_history_pair_length: int = 7

    persona_max_length: int = 14
    chat_max_length: int = 25

    debug_status: int = 0


class PersonaChatDatasetSampleV1(TypedDict):
    """
    persona: List[str] - person fact sentence set
    history: List[str] - chatting history
    """

    persona: List[str]
    history: List[str]
    # NOTE(review): DialogBotV1._get_sample builds this dict without
    # sample_id; nothing in this file reads the key.
    sample_id: str


class H2Seq2SeqInferenceSampleDictV1(TypedDict):
    """Model input as plain Python lists of token ids / mask values."""

    input_ids: List[int]
    attention_mask: List[int]


class H2Seq2SeqInferenceSampleDictV2(TypedDict):
    """Model input as tensors, as passed to ``model.generate``."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor


def flat_list(list_of_lists: List[List]) -> List:
    """Concatenate the inner lists of *list_of_lists* into one flat list."""
    return list(chain.from_iterable(list_of_lists))


class H2Seq2SeqInferencePersonaSampleV1:
    """Builds the token-id input for one (persona, history) sample.

    The produced sequence is:
    [BOS] " [CONTEXT] " history " [KNOWLEDGE] " persona [EOS]
    """

    def __init__(
        self,
        dataset_sample: PersonaChatDatasetSampleV1,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
    ) -> None:
        self.dataset_sample = dataset_sample
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters

    def add_spaces_after(
        self,
        items: List[str],
    ) -> List[str]:
        """Return a new list with a single space appended to every item."""
        return [item + " " for item in items]

    @property
    def bos_token_id(self) -> List[int]:
        """BOS id wrapped in a list, or [] for "t5" models / tokenizers without BOS."""
        if "t5" in self.hyperparameters.model_name:
            return []

        if self.tokenizer.bos_token_id is None:
            return []

        return [self.tokenizer.bos_token_id]

    @property
    def eos_token_id(self) -> List[int]:
        """EOS id wrapped in a list, or [] when the tokenizer defines none."""
        if self.tokenizer.eos_token_id is None:
            return []

        return [self.tokenizer.eos_token_id]

    def add_sep_beetween(self, items: List[str], sep=" EOS ") -> List[str]:
        """Return a new list where every item except the first is prefixed with *sep*.

        The input list is left untouched: mutating it in place made repeated
        calls on the same persona list accumulate separators.
        """
        return [
            item if index == 0 else sep + item
            for index, item in enumerate(items)
        ]

    def add_spaces_between(self, items: List[str]) -> List[str]:
        """Append a space to every item, then strip whitespace from the last one.

        An empty input yields an empty list.
        """
        if not items:
            return []

        spaced = self.add_spaces_after(items)
        spaced[-1] = spaced[-1].strip()
        return spaced

    def _encode_sentences(
        self,
        sentences: List[str],
        max_length: int,
    ) -> List[int]:
        """Tokenize each sentence (truncated to *max_length*) and concatenate the ids."""
        # Nothing to encode: skip the tokenizer call for an empty batch.
        if not sentences:
            return []

        encoded = self.tokenizer.batch_encode_plus(
            sentences,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
        )
        return flat_list(encoded["input_ids"])

    def get_sample(self) -> H2Seq2SeqInferenceSampleDictV1:
        """Assemble ``input_ids`` and an all-ones ``attention_mask`` for the sample.

        Only the last ``chat_history_pair_length * 2 + 1`` history utterances
        are kept. Utterances are joined with " EOS " and persona sentences
        with a single space before tokenization.
        """
        history_window = self.hyperparameters.chat_history_pair_length * 2 + 1
        dialog_history = self.dataset_sample["history"][-history_window:]
        dialog_history = self.add_sep_beetween(dialog_history)

        persona = self.add_sep_beetween(
            self.dataset_sample["persona"],
            sep=" ",
        )

        knowledge_ids = self.tokenizer.encode(
            " [KNOWLEDGE] ",
            add_special_tokens=False,
        )
        context_ids = self.tokenizer.encode(
            " [CONTEXT] ",
            add_special_tokens=False,
        )

        encoded_history = self._encode_sentences(
            dialog_history,
            max_length=self.hyperparameters.chat_max_length,
        )
        encoded_persona = self._encode_sentences(
            persona,
            max_length=self.hyperparameters.persona_max_length,
        )

        input_ids = [
            *self.bos_token_id,
            *context_ids,
            *encoded_history,
            *knowledge_ids,
            *encoded_persona,
            *self.eos_token_id,
        ]
        attention_mask = [1] * len(input_ids)

        return H2Seq2SeqInferenceSampleDictV1(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )


class DialogBotV1:
    """Stateful chat bot: keeps a persona and a growing dialogue history."""

    def __init__(
        self,
        model: AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
        history: Optional[List[str]] = None,
        persona: Optional[List[str]] = None,
        device: str = "cuda",
        shuffle_persona: bool = True,
    ):
        """Store the model, tokenizer and conversation state.

        history / persona default to fresh empty lists. A list passed by the
        caller is stored as-is (not copied), so ``next_response`` appends to
        the caller's history list.
        device is handed to ``Tensor.to`` when inputs are built.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        self.device = device
        # NOTE(review): shuffle_persona and debug_status are stored but not
        # read anywhere else in this class.
        self.shuffle_persona = shuffle_persona

        self.debug_status = hyperparameters.debug_status

        # The None defaults used to be overwritten by an unconditional
        # assignment, leaving self.history / self.persona as None.
        self.history = [] if history is None else history
        self.persona = [] if persona is None else persona

    def _get_sample(
        self,
        persona: List[str],
        history: List[str],
    ) -> H2Seq2SeqInferenceSampleDictV2:
        """Build a batch-of-one tensor input from *persona* and *history*."""
        dataset_sample = PersonaChatDatasetSampleV1(
            persona=persona,
            history=history,
        )

        sample_builder = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample=dataset_sample,
        )
        token_sample = sample_builder.get_sample()

        # NOTE(review): printed on every call; debug_status is not consulted.
        print(self.tokenizer.decode(token_sample["input_ids"]))

        # unsqueeze(0) adds the leading batch dimension.
        return {
            key: torch.tensor(values).unsqueeze(0).to(self.device)
            for key, values in token_sample.items()
        }

    def next_response(
        self,
        **generation_params,
    ) -> str:
        """Generate the bot's next utterance and append it to the history.

        generation_params are forwarded unchanged to ``model.generate``.
        """
        sample = self._get_sample(
            persona=self.persona,
            history=self.history,
        )
        generated = self.generate_response(
            sample,
            **generation_params,
        )
        decoded = self.tokenizer.batch_decode(
            generated,
            skip_special_tokens=True,
        )
        response = decoded[0]
        self.history.append(response)
        return response

    def generate_response(
        self,
        sample: H2Seq2SeqInferenceSampleDictV2,
        **generation_params,
    ):
        """Run ``model.generate`` on *sample* with gradients disabled.

        generation_params - https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation
        """
        with torch.no_grad():
            return self.model.generate(
                **sample,
                **generation_params,
            )


PRETRAINED_MODEL_NAME_OR_PATH = "DeepPavlov/bart-base-en-persona-chat"

PAIR_DIALOG_HISTORY_LENGTH = 2

# CHAT_MAX_LENGTH for single sentence, in tokens
CHAT_MAX_LENGTH = 25
# PERSONA_MAX_LENGTH for single sentence, in tokens
PERSONA_MAX_LENGTH = 19

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

if torch.cuda.is_available():
    model.half()

hyperparameters = H2PersonaChatHyperparametersV1(
    chat_history_pair_length=PAIR_DIALOG_HISTORY_LENGTH,
    persona_max_length=PERSONA_MAX_LENGTH,
    chat_max_length=CHAT_MAX_LENGTH,
    model_name=PRETRAINED_MODEL_NAME_OR_PATH,
)

persona = [
    "I like to play guitar.",
    "I hate onions.",
]

history = [
    "I hate to talk about politics, what about you?",
]

persona_bot = DialogBotV1(
    model=model,
    tokenizer=tokenizer,
    hyperparameters=hyperparameters,
    history=history,
    persona=persona,
    device=device,
)

GENERATION_PARAMS = {
    "max_new_tokens": 60,
    "penalty_alpha": 0.15,
    "top_k": 10,
}
response = persona_bot.next_response(
    **GENERATION_PARAMS,
)

print(response)
# i am not into politics. i am into music.
[需要更多信息]
def persona_chat_dataset_tranformer_v1(
    initial_dataset_path: str,
    output_folder: str,
) -> None:
    """Split a persona-chat JSON file into train/valid/test JSON files.

    The input JSON must be an object with "train" and "valid" keys. "train"
    is written unchanged to ``train.json``; the first half of "valid"
    (``len // 2`` items) goes to ``valid.json`` and the remainder to
    ``test.json``. ``output_folder`` is created if it does not exist.

    example
        persona_chat_dataset_tranformer_v1(
            initial_dataset_path="./datasets/persona_chat/persona_chat.json",
            output_folder="./datasets/persona_chat",
        )

    Raises:
        AssertionError: if either argument is None.
        KeyError: if the input lacks the "train" or "valid" key.
    """
    # Imported locally so the function does not depend on module-level
    # imports that the surrounding file may not provide.
    import json
    from pathlib import Path

    assert initial_dataset_path is not None, "initial_dataset_path is None"
    assert output_folder is not None, "output_folder is None"

    with open(initial_dataset_path, encoding="utf-8") as f:
        initial_dataset = json.load(f)

    train_dataset = initial_dataset["train"]
    full_valid = initial_dataset["valid"]
    split_index = len(full_valid) // 2
    valid_dataset = full_valid[:split_index]
    test_dataset = full_valid[split_index:]

    print(
        f"Dataset lengths: train {len(train_dataset)}, valid {len(valid_dataset)}, test {len(test_dataset)}"
    )

    # save json files
    target_dir = Path(output_folder)
    target_dir.mkdir(parents=True, exist_ok=True)
    splits = (
        ("train.json", train_dataset),
        ("valid.json", valid_dataset),
        ("test.json", test_dataset),
    )
    for file_name, split in splits:
        with open(target_dir / file_name, "w", encoding="utf-8") as f:
            json.dump(split, f)

    print("Datasets saved.")