模型:
facebook/rag-sequence-base
这是一份非微调版本的RAG-Sequence模型,根据Patrick Lewis、Ethan Perez、Aleksandara Piktus等人的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 创建。
RAG由问题编码器、检索器和生成器组成。检索器应该是一个RagRetriever实例。问题编码器可以是可以使用AutoModel加载的任何模型,生成器可以是可以使用AutoModelForSeq2SeqLM加载的任何模型。
这个模型是一个非微调的RAG-Sequence模型,创建过程如下:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer
model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large")
question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)
model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")
请注意,该模型是非区分大小写的,因此所有大写输入字母都会转换为小写。
注意:模型默认使用虚拟检索器。通过设置config.index_name="legacy"和config.use_dummy_dataset=False可以使用完整的检索器获得更好的结果。可以按以下方式对模型进行微调:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
model = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt")
outputs = model(input_dict["input_ids"], labels=input_dict["labels"])
loss = outputs.loss
# train on loss