Model: nlpconnect/dpr-question_encoder_bert_uncased_L-2_H-128_A-2

English

dpr-question_encoder_bert_uncased_L-2_H-128_A-2

This model (google/bert_uncased_L-2_H-128_A-2) was trained from scratch on the data.retriever.nq-adv-hn-train data (facebookresearch/DPR). It achieves the following results on the evaluation set:

Evaluation data

Evaluation dataset: facebook-dpr-dev-dataset from the official DPR GitHub repository.

| model_name | data_name | num of queries | num of passages | R@10 | R@20 | R@50 | R@100 | R@1000 |
|---|---|---|---|---|---|---|---|---|
| nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2 (ours) | nq-dev dataset | 6445 | 199795 | 60.53% | 68.28% | 76.07% | 80.98% | 91.45% |
| nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2 (ours) | nq-dev dataset | 6445 | 199795 | 65.43% | 71.99% | 79.03% | 83.24% | 92.11% |
| *facebook/dpr-ctx_encoder-single-nq-base (hf/fb) | nq-dev dataset | 6445 | 199795 | 40.94% | 49.27% | 59.05% | 66.00% | 82.00% |

Evaluation dataset: UKPLab/beir test data, restricted to the first 200,000 documents.

| model_name | data_name | num of queries | num of passages | R@10 | R@20 | R@50 | R@100 | R@1000 |
|---|---|---|---|---|---|---|---|---|
| nlpconnect/dpr-ctx_encoder_bert_uncased_L-2_H-128_A-2 (ours) | nq-test dataset | 3452 | 200001 | 49.68% | 59.06% | 69.40% | 75.75% | 89.28% |
| nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2 (ours) | nq-test dataset | 3452 | 200001 | 51.62% | 61.09% | 70.10% | 76.07% | 88.70% |
| *facebook/dpr-ctx_encoder-single-nq-base (hf/fb) | nq-test dataset | 3452 | 200001 | 32.93% | 43.74% | 56.95% | 66.30% | 83.92% |

Note: * indicates that we evaluated the model on the same evaluation dataset.
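
The R@k columns report top-k retrieval recall: the fraction of queries for which at least one gold passage appears among the top k retrieved passages. A minimal sketch of this metric (hypothetical ranked results, not the actual evaluation script):

def recall_at_k(ranked_passage_ids, gold_passage_ids, k):
    # Fraction of queries whose top-k results contain at least one gold passage
    hits = 0
    for ranked, gold in zip(ranked_passage_ids, gold_passage_ids):
        if any(pid in gold for pid in ranked[:k]):
            hits += 1
    return hits / len(ranked_passage_ids)

# Hypothetical example: two queries with their ranked passage ids and gold ids
ranked = [[3, 7, 1, 9], [5, 2, 8, 4]]
gold = [{1}, {6}]
print(recall_at_k(ranked, gold, k=2))  # 0.0 -> neither query hits within top-2
print(recall_at_k(ranked, gold, k=3))  # 0.5 -> the first query hits at rank 3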

Usage (HuggingFace Transformers)

import numpy as np
from transformers import TFAutoModel, AutoTokenizer

# Load the passage (context) and question encoders together with their tokenizers
passage_encoder = TFAutoModel.from_pretrained("nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2")
query_encoder = TFAutoModel.from_pretrained("nlpconnect/dpr-question_encoder_bert_uncased_L-12_H-128_A-2")

p_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/dpr-ctx_encoder_bert_uncased_L-12_H-128_A-2")
q_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/dpr-question_encoder_bert_uncased_L-12_H-128_A-2")

# Combine each passage's title and text into a (title, text) tuple so the
# tokenizer encodes them as a sentence pair, as in the original DPR setup
def get_title_text_combined(passage_dicts):
    res = []
    for p in passage_dicts:
        res.append((p['title'], p['text']))
    return res

processed_passages = get_title_text_combined(passage_dicts)
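
Here `passage_dicts` is assumed to be a list of dicts with `title` and `text` keys, for example (illustrative data only):

passage_dicts = [
    {"title": "Dense Passage Retrieval", "text": "DPR encodes questions and passages into dense vectors for retrieval."},
    {"title": "Natural Questions", "text": "Natural Questions is an open-domain question answering benchmark."},
]
# processed_passages -> [("Dense Passage Retrieval", "DPR encodes ..."), ("Natural Questions", "Natural Questions is ...")]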

def extracted_passage_embeddings(processed_passages, model_config):
    # model_config is assumed to be a configuration object exposing passage_max_seq_len
    passage_inputs = p_tokenizer.batch_encode_plus(
                    processed_passages,
                    add_special_tokens=True,
                    truncation=True,
                    padding="max_length",
                    max_length=model_config.passage_max_seq_len,
                    return_token_type_ids=True
                )
    # The Keras .predict call takes the encoder inputs as a list of arrays:
    # [input_ids, attention_mask, token_type_ids]
    passage_embeddings = passage_encoder.predict([np.array(passage_inputs['input_ids']),
                                                  np.array(passage_inputs['attention_mask']),
                                                  np.array(passage_inputs['token_type_ids'])],
                                                  batch_size=512,
                                                  verbose=1)
    return passage_embeddings

passage_embeddings = extracted_passage_embeddings(processed_passages, model_config)


def extracted_query_embeddings(queries, model_config):
    # model_config is assumed to expose query_max_seq_len
    query_inputs = q_tokenizer.batch_encode_plus(
                    queries,
                    add_special_tokens=True,
                    truncation=True,
                    padding="max_length",
                    max_length=model_config.query_max_seq_len,
                    return_token_type_ids=True
                )
    query_embeddings = query_encoder.predict([np.array(query_inputs['input_ids']),
                                              np.array(query_inputs['attention_mask']),
                                              np.array(query_inputs['token_type_ids'])],
                                              batch_size=512,
                                              verbose=1)
    return query_embeddings

# queries: a list of question strings
query_embeddings = extracted_query_embeddings(queries, model_config)
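
Once query and passage embeddings are available, DPR retrieval reduces to a maximum inner product search: passages are scored by the dot product between query and passage vectors. A minimal sketch with NumPy, assuming the pooled embeddings have already been extracted into 2-D arrays (random data is used here purely for illustration):

import numpy as np

# Illustrative shapes: 5 queries and 1000 passages, 128-dimensional embeddings
query_vecs = np.random.rand(5, 128).astype("float32")
passage_vecs = np.random.rand(1000, 128).astype("float32")

# Score every passage against every query by dot product
scores = query_vecs @ passage_vecs.T                   # (num_queries, num_passages)
top_k = 10
top_indices = np.argsort(-scores, axis=1)[:, :top_k]   # top-k passage indices per query
print(top_indices.shape)                               # (5, 10)

For larger passage collections, an approximate nearest neighbor library such as FAISS is typically used instead of this brute-force dot product.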

Training hyperparameters

The following hyperparameters were used during training:

  • optimizer: None
  • training_precision: float32

Framework versions

  • Transformers 4.15.0
  • TensorFlow 2.7.0
  • Tokenizers 0.10.3