模型:
facebook/rag-sequence-base
这是一份非微调版本的RAG-Sequence模型,根据Patrick Lewis、Ethan Perez、Aleksandara Piktus等人的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 创建。
RAG由问题编码器、检索器和生成器组成。检索器应该是一个RagRetriever实例。问题编码器可以是可以使用AutoModel加载的任何模型,生成器可以是可以使用AutoModelForSeq2SeqLM加载的任何模型。
这个模型是一个非微调的RAG-Sequence模型,创建过程如下:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large") question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base") generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer) model.config.use_dummy_dataset = True model.config.index_name = "exact" retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer) model.save_pretrained("./") tokenizer.save_pretrained("./") retriever.save_pretrained("./")
请注意,该模型是非区分大小写的,因此所有大写输入字母都会转换为小写。
注意:模型默认使用虚拟检索器。通过设置config.index_name="legacy"和config.use_dummy_dataset=False可以使用完整的检索器获得更好的结果。可以按以下方式对模型进行微调:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base") retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base") model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever) input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt") outputs = model(input_dict["input_ids"], labels=input_dict["labels"]) loss = outputs.loss # train on loss