模型:
facebook/rag-token-base
这是一篇关于RAG-Token模型的非微调版本,参考自Patrick Lewis、Ethan Perez、Aleksandara Piktus等人的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 。
RAG由一个问题编码器、检索器和生成器组成。检索器应该是一个RagRetriever实例。问题编码器可以是可以使用AutoModel加载的任何模型,而生成器可以是可以使用AutoModelForSeq2SeqLM加载的任何模型。
这个模型是一个非微调的RAG-Token模型,创建过程如下:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer model = RagTokenForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large") question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base") generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large") tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer) model.config.use_dummy_dataset = True model.config.index_name = "exact" retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer) model.save_pretrained("./") tokenizer.save_pretrained("./") retriever.save_pretrained("./")
请注意,该模型是不区分大小写的,所以所有大写输入字母都会转换为小写。
注意:模型使用默认的dummy检索器。通过设置config.index_name="legacy"和config.use_dummy_dataset=False,使用完整的检索器可以获得更好的结果。可以按照以下步骤微调模型:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base") retriever = RagRetriever.from_pretrained("facebook/rag-token-base") model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever) input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt") outputs = model(input_dict["input_ids"], labels=input_dict["labels"]) loss = outputs.loss # train on loss