Model: facebook/rag-token-base

Language: English

RAG

This is the non-finetuned version of the RAG-Token model from the paper Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks by Patrick Lewis, Ethan Perez, Aleksandra Piktus, et al.

RAG consists of a question encoder, a retriever, and a generator. The retriever should be a RagRetriever instance. The question encoder can be any model that can be loaded with AutoModel, and the generator can be any model that can be loaded with AutoModelForSeq2SeqLM.
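Since the two components are freely combinable, a different generator checkpoint could in principle be plugged in. A minimal sketch, assuming t5-small as a hypothetical generator (this checkpoint itself uses facebook/bart-large, as shown in the creation recipe below):

from transformers import RagTokenForGeneration

# hypothetical combination: DPR question encoder + T5 generator
rag_t5 = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base", "t5-small"
)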

This model is a non-finetuned RAG-Token model and was created as follows:

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer

# combine a DPR question encoder with a BART generator
model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large"
)

question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# use the dummy dataset and an exact index so that creating the retriever
# does not trigger the full index download
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")

Note that the model is uncased, so all capitalized input letters are converted to lower-case.
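As a quick check of the uncased behaviour (a sketch; the DPR question-encoder tokenizer is an uncased BERT-style tokenizer, so casing does not change the token ids):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
# mixed-case and lower-case inputs produce identical token ids
assert tok("Who Holds The Record?").input_ids == tok("who holds the record?").input_ids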

Usage:

Note: the model uses the dummy retriever by default. Better results are obtained with the full retriever, by setting config.index_name="legacy" and config.use_dummy_dataset=False (a loading sketch follows the fine-tuning example below). The model can be fine-tuned as follows:

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
retriever = RagRetriever.from_pretrained("facebook/rag-token-base")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

# tokenize the question (source) and the expected answer (target labels);
# note that prepare_seq2seq_batch is deprecated in recent transformers releases
input_dict = tokenizer.prepare_seq2seq_batch(
    "who holds the record in 100m freestyle", "michael phelps", return_tensors="pt"
)

# a forward pass with labels returns the generation loss
outputs = model(input_dict["input_ids"], labels=input_dict["labels"])

loss = outputs.loss

# train on loss
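For inference with the full retriever, a minimal sketch following the pattern in the transformers documentation; passing index_name and use_dummy_dataset to RagRetriever.from_pretrained forwards them to the config. Note that the legacy index is a large download, and that generations from this non-finetuned checkpoint only become meaningful after fine-tuning:

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
# full retriever instead of the dummy one; triggers a large index download
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="legacy", use_dummy_dataset=False
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(input_ids=input_dict["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])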