Model:
facebook/dragon-plus-context-encoder
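DRAGON+ is a BERT-base-sized dense retriever. It is initialized from RetroMAE and further trained on augmented data derived from the MS MARCO corpus, following the approach described in How to Train Your DRAGON: Diverse Augmentation Towards Generalizable Dense Retrieval.
The associated GitHub repository is available at https://github.com/facebookresearch/dpr-scale/tree/main/dragon. We use an asymmetric dual encoder, with two distinctly parameterized encoders. The following models are also available: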
Model | Initialization | MARCO Dev (MRR@10) | BEIR (nDCG@10) | Query Encoder Path | Context Encoder Path |
---|---|---|---|---|---|
DRAGON+ | Shitao/RetroMAE | 39.0 | 47.4 | 1234321 | 1235321 |
DRAGON-RoBERTa | RoBERTa-base | 39.4 | 47.2 | 1236321 | 1237321 |
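Here is a toy sketch in plain Python that makes the "asymmetric dual encoder with two distinctly parameterized encoders" concrete. The linear encoders, the class name `AsymmetricDualEncoder`, and the weights are all hypothetical illustrations. They are not DRAGON's actual architecture, which uses two BERT-base encoders. The sketch shows only two things: the query tower and the context tower hold separate parameters, and relevance is the dot product of their outputs.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear_encode(weights: Matrix, features: Vector) -> Vector:
    """Hypothetical toy encoder: one output value per weight-matrix row."""
    return [sum(w * x for w, x in zip(row, features)) for row in weights]


def dot(a: Vector, b: Vector) -> float:
    """Dot product, used as the relevance score between the two towers."""
    return sum(x * y for x, y in zip(a, b))


class AsymmetricDualEncoder:
    """Toy asymmetric dual encoder (illustrative only): separate parameters per tower."""

    def __init__(self, query_weights: Matrix, context_weights: Matrix) -> None:
        # Two distinct parameter sets; nothing is shared between the towers.
        self.query_weights = query_weights
        self.context_weights = context_weights

    def encode_query(self, features: Vector) -> Vector:
        return linear_encode(self.query_weights, features)

    def encode_context(self, features: Vector) -> Vector:
        return linear_encode(self.context_weights, features)

    def score(self, query: Vector, context: Vector) -> float:
        # Encode each side with its own tower, then take the dot product.
        return dot(self.encode_query(query), self.encode_context(context))


# Usage: the same input vector gets a different embedding from each tower.
model = AsymmetricDualEncoder(
    query_weights=[[1.0, 0.0], [0.0, 1.0]],
    context_weights=[[2.0, 0.0], [0.0, 0.5]],
)
print(model.encode_query([1.0, 1.0]))       # [1.0, 1.0]
print(model.encode_context([1.0, 1.0]))     # [2.0, 0.5]
print(model.score([1.0, 1.0], [1.0, 1.0]))  # 2.5
```

A symmetric (Siamese) dual encoder would instead share a single set of weights between the two towers. Keeping them separate lets queries and passages be encoded differently, at the cost of roughly twice the encoder parameters.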
The model can be used directly with HuggingFace Transformers:
```python
import torch
from transformers import AutoTokenizer, AutoModel

# One tokenizer (loaded from the query-encoder repo) is used for both queries and contexts
tokenizer = AutoTokenizer.from_pretrained('facebook/dragon-plus-query-encoder')
query_encoder = AutoModel.from_pretrained('facebook/dragon-plus-query-encoder')
context_encoder = AutoModel.from_pretrained('facebook/dragon-plus-context-encoder')

# We use an MS MARCO query and passages as an example
query = "Where was Marie Curie born?"
contexts = [
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

# Apply tokenizer (pad and truncate the batch of contexts)
query_input = tokenizer(query, return_tensors='pt')
ctx_input = tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')

# Compute embeddings: take the last-layer hidden state of the [CLS] token
with torch.no_grad():  # inference only, so gradients are not needed
    query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
    ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]

# Compute similarity scores using dot product
score1 = query_emb @ ctx_emb[0]  # 396.5625
score2 = query_emb @ ctx_emb[1]  # 393.8340
```
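These dot-product scores matter only relative to one another. A higher score means the passage is judged more relevant to the query, so the Marie Curie passage (396.5625) ranks above the Pierre Curie passage (393.8340). The sketch below turns such scores into a ranking. `rank_by_score` is a hypothetical helper, and the scores are copied from the comments in the example rather than recomputed. With the tensors from the example, all scores can also be computed in one step as `query_emb @ ctx_emb.T`, which yields a tensor of shape `(1, len(contexts))`.

```python
from typing import List, Sequence, Tuple


def rank_by_score(
    contexts: Sequence[str], scores: Sequence[float], top_k: int = 1
) -> List[Tuple[str, float]]:
    """Hypothetical helper: return the top_k (context, score) pairs, highest score first."""
    if len(contexts) != len(scores):
        raise ValueError("contexts and scores must have the same length")
    # Sort by score in descending order; ties keep their original order.
    ranked = sorted(zip(contexts, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]


contexts = [
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace.",
]
scores = [396.5625, 393.8340]  # copied from the comments in the example above

for passage, score in rank_by_score(contexts, scores, top_k=2):
    print(f"{score:.4f}  {passage}")
```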