Optimizing Retrieval Augmentation with Dynamic Top-K Tuning for Efficient Question Answering

Published November 7, 2023, by alex

In question answering systems, the static nature of the top-k retrieval setting poses a significant challenge: it often leads to either information overload or information scarcity, which in turn hurts both the efficiency and the accuracy of the generated responses. This one-size-fits-all approach ignores the nuanced needs of queries of varying complexity, leaving the language model's capabilities underused and incurring unnecessary computational cost. Our solution introduces a dynamic paradigm shift by training a cross-encoder that adjusts the retrieval breadth in real time. By assessing the complexity of each query, the system predicts the top-k value best suited for retrieval, ensuring that every question receives a tailored, precise, and resource-efficient answer. This not only streamlines the retrieval process but also significantly improves the overall performance of the question answering system.




Here is a more streamlined version of the process:


Training


1. We split each document into chunks and index those chunks in a vector database.


2. Iterating from k=1 to 10, we retrieve the top k passages at each iteration and prompt the language model to generate an answer to the given question.


3. Using a custom RankCorrectnessEvaluator, we rank the ten candidate answers against the reference answer and the query.


4. We compile a training dataset consisting of the query, the document, and, for each query-document pair, the best top-k value derived from the ranking step.


5. This dataset is used to train a cross-encoder that predicts the top-k value by learning from the query and the document context.


Inference


1. We first split the document into chunks and index those chunks in a vector database.


2. For a given document and query, we predict the top-k value and feed it to the retriever.


3. The large language model receives the top k passages as context and generates the answer.


4. We assess the effectiveness of the static versus dynamic retrieval strategies with evaluation metrics such as the CorrectnessEvaluator, the SemanticSimilarityEvaluator, and an LLM token counter (a compact sketch of this inference flow follows below).
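
Before getting into the details, here is a compact sketch of what the inference flow boils down to. It leans on the predict_top_k and get_summary helpers defined later in this post and uses the same LlamaIndex calls as the full implementation below; treat it as orientation rather than drop-in code.


# Compact sketch of inference with a dynamic top-k (predict_top_k and get_summary
# are defined later in this post; this is orientation code, not a full implementation).
from llama_index import Document, VectorStoreIndex

def answer_with_dynamic_top_k(paper_text: str, question: str) -> str:
    # 1. Chunk and index the document.
    vector_index = VectorStoreIndex.from_documents([Document(text=paper_text)])
    # 2. Let the trained cross-encoder pick a top-k for this (document, question) pair.
    summary = get_summary(paper_text)["choices"][0]["message"]["content"]
    predicted_top_k = predict_top_k(summary, question)
    # 3. Retrieve exactly that many passages and let the LLM answer.
    query_engine = vector_index.as_query_engine(similarity_top_k=predicted_top_k)
    return str(query_engine.query(question))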


Retrieval-Augmented Generation


Retrieval-augmented generation (RAG) is a powerful technique that combines the strengths of information retrieval and language generation to improve the quality of machine-generated text. By retrieving relevant document passages and using them as additional context, RAG models can produce answers that are more accurate, informative, and contextually relevant. This approach is especially valuable for complex question answering tasks, where a deep understanding of a broad range of topics is needed to produce coherent and factually correct responses. Moreover, since LLMs are only trained up to a certain cutoff date, they cannot access or incorporate events or knowledge that appeared after training, which makes retrieval augmentation essential for up-to-date and extended content understanding.


The Weaviate Vector Database


A vector database is a database specialized for handling vector embeddings: high-dimensional representations of data points, typically produced by machine learning models. These embeddings capture semantic relationships between data points, enabling efficient similarity search. This matters in particular for retrieval-augmented generation (RAG) systems, which depend on quickly finding the most relevant document passages in a large corpus to help generate accurate and contextually relevant responses. Weaviate is an open-source vector database. It lets you store data objects and vector embeddings from your favorite ML models and scales seamlessly to billions of data objects.
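
As a rough sketch of how Weaviate plugs into LlamaIndex (the same wiring appears in the implementation below; the URL is a placeholder for your own cluster):


import weaviate
from llama_index import Document, VectorStoreIndex
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.storage.storage_context import StorageContext

# Connect to a Weaviate cluster (placeholder URL) and point LlamaIndex at it.
client = weaviate.Client(url="https://<your-cluster>.weaviate.network")
vector_store = WeaviateVectorStore(weaviate_client=client, index_name="LlamaIndex")
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# Index a document; embeddings are stored in Weaviate and used for similarity search.
documents = [Document(text="Example paper text ...")]
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
query_engine = index.as_query_engine(similarity_top_k=3)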


End-to-End Evaluation with LlamaIndex


LlamaIndex provides a robust framework for evaluating retrieval-augmented generation systems, making sure the whole pipeline, from data retrieval to final response generation, runs effectively and yields accurate results. Its evaluation tooling offers a comprehensive set of metrics and evaluators that can diagnose and improve a RAG system's performance, producing more reliable and contextually relevant answers without exhaustive manual review. In this experiment I use the following:


1. CorrectnessEvaluator: it assesses the relevance and correctness of a generated answer against a reference answer.


2. SemanticSimilarityEvaluator: it assesses the quality of a generated answer via its semantic similarity to the reference answer.


In addition to these two metrics, I use the total LLM token count as an extra metric to compare the different strategies.
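
To make the metrics concrete, here is a minimal sketch of how these signals are wired up with LlamaIndex; the same calls appear in the evaluation loops further down, and the query, response, and reference strings here are made up:


import tiktoken
from llama_index import ServiceContext, set_global_service_context
from llama_index.llms import OpenAI
from llama_index.evaluation import CorrectnessEvaluator, SemanticSimilarityEvaluator
from llama_index.callbacks import CallbackManager, TokenCountingHandler

# Correctness is graded by GPT-4 against the reference answer.
gpt4_context = ServiceContext.from_defaults(llm=OpenAI(temperature=0, model="gpt-4"))
correctness = CorrectnessEvaluator(service_context=gpt4_context)
# Similarity compares embeddings of the generated and reference answers.
similarity = SemanticSimilarityEvaluator()
# The token counter is attached to the service context used for answering, so every
# LLM call made through it (e.g. by the query engine) is tallied.
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)
set_global_service_context(
    ServiceContext.from_defaults(llm=OpenAI(), callback_manager=CallbackManager([token_counter]))
)

correctness_result = await correctness.aevaluate(
    query="Which dataset is used?",
    response="The experiments use the QASPER dataset.",
    reference="QASPER",
)
similarity_result = await similarity.aevaluate(
    response="The experiments use the QASPER dataset.", reference="QASPER"
)
print(correctness_result.score, similarity_result.score)
token_counter.total_llm_token_count  # read (and reset) after each answered query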


Implementation


Let's start by importing the necessary libraries.


import os
import re
import pickle
import openai
import tiktoken
import random
import ast
import time
import pandas as pd
import weaviate
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from llama_index.vector_stores import WeaviateVectorStore
from IPython.display import Markdown, display
from llama_index import QueryBundle
from llama_index.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index import Document
from typing import Any, List, Optional
from tqdm.auto import tqdm
from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    ServiceContext,
    Response,
    set_global_service_context
)
from llama_index.storage.storage_context import StorageContext
from llama_index.vector_stores import VectorStoreQuery
from llama_index.schema import NodeWithScore
from llama_index.embeddings import OpenAIEmbedding
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index.evaluation import SemanticSimilarityEvaluator
from llama_index.embeddings import SimilarityMode
from llama_index.evaluation import CorrectnessEvaluator
from llama_index.evaluation.eval_utils import get_responses, get_results_df
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from dotenv import load_dotenv


Data Preparation


The experiment uses the QASPER dataset, a question answering dataset over scientific research papers. In the code below, we extract the paper text, the questions, and the answers, and build the training and test sets.


# Download QASPER dataset from HuggingFace https://huggingface.co/datasets/allenai/qasper
dataset = load_dataset("allenai/qasper")
# Split the dataset into train, validation, and test splits
train_dataset = dataset["train"]
validation_dataset = dataset["validation"]
test_dataset = dataset["test"]
random.seed(42)  # Set a random seed for reproducibility
# Randomly sample 800 rows from the training split
train_sampled_indices = random.sample(range(len(train_dataset)), 800)
train_samples = [train_dataset[i] for i in train_sampled_indices]

# Randomly sample 80 rows from the test split
test_sampled_indices = random.sample(range(len(test_dataset)), 80)
test_samples = [test_dataset[i] for i in test_sampled_indices]
# Get full-text paper data and questions from the QASPER training samples to build the cross-encoder training dataset
# Utility function to get full-text of the research papers from the dataset
def get_full_text(sample: dict) -> str:
    """
    :param dict sample: the row sample from QASPER
    """
    title = sample["title"]
    abstract = sample["abstract"]
    sections_list = sample["full_text"]["section_name"]
    paragraph_list = sample["full_text"]["paragraphs"]
    combined_sections_with_paras = ""
    if len(sections_list) == len(paragraph_list):
        combined_sections_with_paras += title + "\t"
        combined_sections_with_paras += abstract + "\t"
        for index in range(0, len(sections_list)):
            combined_sections_with_paras += str(sections_list[index]) + "\t"
            combined_sections_with_paras += "".join(paragraph_list[index])
        return combined_sections_with_paras
    else:
        print("Not the same number of sections as paragraphs list")
# utility function to extract list of questions from the dataset
def get_questions(sample: dict) -> List[str]:
    """
    :param dict sample: the row sample from QASPER
    """
    questions_list = sample["qas"]["question"]
    return questions_list
# Utility function to extract answers from the dataset
def get_answers(sample: dict) -> List[str]:
    """
    :param dict sample: the row sample from the train split of QASPER
    """
    final_answers_list = []
    answers = sample["qas"]["answers"]
    for answer in answers:
        local_answer = ""
        types_of_answers = answer["answer"][0]
        if types_of_answers["unanswerable"] == False:
            if types_of_answers["free_form_answer"] != "":
                local_answer = types_of_answers["free_form_answer"]
            else:
                local_answer = "Unacceptable"
        else:
            local_answer = "Unacceptable"
        final_answers_list.append(local_answer)
    return final_answers_list

doc_qa_dict_list = []
eval_doc_qa_answer_list = []
for train_sample in train_samples:
    full_text = get_full_text(train_sample)
    questions_list = get_questions(train_sample)
    answers_list = get_answers(train_sample)
    local_dict = {
        "paper": full_text,
        "questions": questions_list,
        "answers": answers_list,
    }
    doc_qa_dict_list.append(local_dict)
for test_sample in test_samples:
    full_text = get_full_text(test_sample)
    questions_list = get_questions(test_sample)
    answers_list = get_answers(test_sample)
    local_dict = {
        "paper": full_text,
        "questions": questions_list,
        "answers": answers_list,
    }
    eval_doc_qa_answer_list.append(local_dict)


Setting up Weaviate


You can set up a Weaviate vector database by creating an account and following the instructions mentioned here for creating and querying an index.


# cloud
client = weaviate.Client(
    url="https://....weaviate.network") #replace by url by your personal client url
client.schema.get()
weaviate_vector_store = WeaviateVectorStore(
    weaviate_client=client, index_name="LlamaIndex"
)


A Custom RankCorrectnessEvaluator


I went through the LlamaIndex documentation and could not find an evaluator that ranks multiple responses. The PairwiseComparisonEvaluator essentially compares two responses with respect to a reference answer and a query, but in this experiment we have multiple responses. So I created a RankCorrectnessEvaluator, which ranks the multiple responses associated with the corresponding k values used during retrieval. You can look at the prompt for more details.


from typing import Any, Optional, Sequence, Union, Dict
from llama_index.evaluation.base import BaseEvaluator, EvaluationResult
from llama_index.indices.service_context import ServiceContext
from llama_index.prompts import (
    BasePromptTemplate,
    ChatMessage,
    ChatPromptTemplate,
    MessageRole,
    PromptTemplate,
)
from llama_index.prompts.mixin import PromptDictType
RANKING_SYSTEM_TEMPLATE = """
You are an expert evaluation system for a question answering chatbot.
You are given the following information:
- a user query,
- a reference answer, and
- a list of generated answers, each associated with a different 'k' value.
Your job is to rank the generated answers in order of correctness and relevance to the user query and the reference answer, from best to worst.
Correctness should be judged based on:
- The number of overlapping tokens with the reference answer.
- The absence of incorrect information not present in the reference answer.
- The lack of unnecessary or irrelevant tokens that do not contribute to answering the query.
Please provide a ranked list of the 'k' values associated with the generated answers, starting with the 'k' value of the best answer and ending with the 'k' value of the worst answer.
Do not return answers in any other format.
You are given the following information:
- a user query,
- a reference answer, and
- generated answers.
User Query
query
Reference Answer
reference_answer
Generated Answers
k_1: answer_1
k_2: answer_2
...
k_10: answer_10
Based on the information provided and the criteria for correctness, rank the 'k' values from best to worst. 
For example:
["k_7", "k_2", "k_9", ..., "k_3"]
"""
DEFAULT_USER_TEMPLATE = """
## User Query
{query}
## Reference Answer
{reference_answer}
## Generated Answers
{generated_answers}
"""
DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
    message_templates=[
        ChatMessage(role=MessageRole.SYSTEM, content=RANKING_SYSTEM_TEMPLATE),
        ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
    ]
)
class RankCorrectnessEvaluator(BaseEvaluator):
    """Rank correctness evaluator.
    Evaluates and ranks the correctness of multiple generated answers for a question answering system.
    This evaluator depends on `reference` answer to be provided, in addition to the
    query string and multiple response strings.
    It outputs a ranked list of 'k' values associated with the generated answers.
    Args:
        service_context (Optional[ServiceContext]): Service context.
        eval_template (Optional[Union[BasePromptTemplate, str]]):
            Template for the evaluation prompt.
    """
    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        eval_template: Optional[Union[BasePromptTemplate, str]] = None,
    ) -> None:
        self._service_context = service_context or ServiceContext.from_defaults()
        self._eval_template: BasePromptTemplate
        if isinstance(eval_template, str):
            self._eval_template = PromptTemplate(eval_template)
        else:
            self._eval_template = eval_template or DEFAULT_EVAL_TEMPLATE
    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "eval_template": self._eval_template,
        }
    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "eval_template" in prompts:
            self._eval_template = prompts["eval_template"]
    async def aevaluate(
        self,
        query: Optional[str] = None,
        responses: Optional[Dict[str, str]] = None,
        reference: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        del kwargs  # Unused
        if query is None or responses is None or reference is None:
            raise ValueError("query, responses, and reference must be provided")
        generated_answers_str = "\n".join(
            [f"{k}: {answer}" for k, answer in responses.items()]
        )
        eval_response = await self._service_context.llm_predictor.apredict(
            prompt=self._eval_template,
            query=query,
            generated_answers=generated_answers_str,
            reference_answer=reference,
        )
        return EvaluationResult(
            query=query,
            response=eval_response
        )
train_evaluator_service_context = ServiceContext.from_defaults(llm=OpenAI("gpt-4"))
train_evaluator = RankCorrectnessEvaluator(service_context=train_evaluator_service_context)
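
As a quick usage illustration (the answers below are fabricated), the evaluator takes the query, the reference answer, and a dict of candidate answers keyed by their k value, and returns the ranked list of k's as a string:


# Illustrative call with fabricated answers; the retrieval loop below builds
# `responses` from the query engine for k = 1..10.
responses = {
    "k_1": "The model is evaluated on SQuAD.",
    "k_2": "The model is evaluated on SQuAD and Natural Questions.",
}
result = await train_evaluator.aevaluate(
    query="Which datasets are used for evaluation?",
    responses=responses,
    reference="SQuAD and Natural Questions",
)
print(result.response)  # e.g. '["k_2", "k_1"]'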


The Retrieval Loop


Here we iterate from k=1 to 10 and store all the responses in a dictionary whose keys are the k values and whose values are the corresponding responses. We then send these responses to the RankCorrectnessEvaluator, which returns the k value of the most relevant or correct response.


topk_training_dataset = []
for paper in tqdm(doc_qa_dict_list[:100]): 
    try: ## safety against any openai error
        questions_list = paper["questions"]
        documents = [Document(text=paper["paper"])]
        reference_answers_list = paper["answers"]
        assert len(questions_list) == len(reference_answers_list)
        for question, reference_answer in zip(questions_list, reference_answers_list):
            responses = {}
            if reference_answer == "Unacceptable":
                continue
            for k_val in tqdm(range(1, 11)):
                service_context = ServiceContext.from_defaults(chunk_size=512)
                node_parser = service_context.node_parser
                nodes = node_parser.get_nodes_from_documents(documents)
                storage_context = StorageContext.from_defaults(vector_store=weaviate_vector_store)
                storage_context.docstore.add_documents(nodes)
                weaviate_vector_index = VectorStoreIndex(nodes, storage_context=storage_context)
                weaviate_vector_retriever = VectorIndexRetriever(index=weaviate_vector_index, similarity_top_k=k_val)
                query_engine = RetrieverQueryEngine.from_args(weaviate_vector_retriever)
                response = query_engine.query(question)
                responses["k_"+str(k_val)] = response.response
            # rankcorrectness evaluator
            result = await train_evaluator.aevaluate(
                query=question,
                responses=responses,
                reference=reference_answer,
            )
            try:
                result = result.response
                eval_response_list = ast.literal_eval(result)
                ranked_k_values = [int(k.split('_')[1]) for k in eval_response_list]
            except:
                try:
                    list_pattern = r'\["k_[0-9]+(?:", "k_[0-9]+)*"\]'
                    match = re.search(list_pattern, result)  # result already holds the raw response string here
                    if match:
                        list_str = match.group(0)
                        eval_response_list = ast.literal_eval(list_str)
                        ranked_k_values = [int(k.split('_')[1]) for k in eval_response_list]
                    else:
                        continue
                except:
                    continue
            # insert train dataset
            best_response_insertion = (paper['paper'], question, str(ranked_k_values[0]))
            topk_training_dataset.append(best_response_insertion)
            with open('training_tuples.pkl', 'wb') as file:
                pickle.dump(topk_training_dataset, file)
            time.sleep(10)
    except:
      pass


Summarizing the Documents


The cross-encoder BERT model has a maximum input length of 512 tokens. I therefore use the prompt below to summarize the long paper text. The summarized document and the query, together with the label (the ground-truth top-k value), are then fed to the cross-encoder for training.


system_prompt = f"""Your task is to summarize the following document into a concise version of approximately 500 words. """Your task is to summarize the following document into a concise version of approximately 500 words. 
Focus on capturing the main themes, essential points, and key arguments presented. Omit any extraneous details or repetitive 
information to create a clear, coherent, and comprehensive summary that conveys the document's core message and intent. 
Please ensure the summary is well-structured, with a clear beginning, middle, and end, and maintains the original 
document's tone and perspective.
"""
def get_summary(document):
  res = openai.ChatCompletion.create(
      model="gpt-3.5-turbo-16k",
      messages=[
          {"role": "system", "content": system_prompt},
          {"role": "user", "content": document}
      ]
  )
  return res
# Reload the (document, question, top_k) tuples saved during the retrieval loop
with open('training_tuples.pkl', 'rb') as file:
    loaded_topk_training_dataset = pickle.load(file)
final_topk_training_dataset = []
for doc, qst, topk in tqdm(loaded_topk_training_dataset):
    response = get_summary(doc)
    summarized_doc = response["choices"][0]["message"]["content"]
    final_topk_training_dataset.append((summarized_doc, qst, topk))


Training the Cross-Encoder


Now we set up the training pipeline for a BERT-based cross-encoder model that classifies each (document, query) sequence pair into one of ten labels.


import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset
class CEDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # top-k labels 1..10 are shifted to class ids 0..9 for classification
        item['labels'] = torch.tensor(int(self.labels[idx]) - 1) 
        return item
    def __len__(self):
        return len(self.labels)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# loaded_final_topk_training_dataset is the summarized (doc, question, top_k) dataset built in the previous step
texts = [(doc, qst) for doc, qst, topk in loaded_final_topk_training_dataset]
labels = [topk for doc, qst, topk in loaded_final_topk_training_dataset]
encodings = tokenizer.batch_encode_plus(texts, truncation=True, padding=True, max_length=512)
ce_dataset = CEDataset(encodings, labels)
num_labels = 10  
ce_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)
training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=200,              
    per_device_train_batch_size=5,  
    warmup_steps=500,                
    weight_decay=0.01,               
    logging_dir='./logs',            
    logging_steps=20,
)
# Train the cross-encoder, then save the model
trainer = Trainer(
    model=ce_model,
    args=training_args,
    train_dataset=ce_dataset,
)
trainer.train()
ce_model.save_pretrained('./ce_model')
tokenizer.save_pretrained('./ce_model')




Baseline RAG Evaluation


Here we set up the evaluation metrics. Most importantly, we compare how these metrics behave under the static strategy, where top_k is fixed in the query engine (and hence in the retriever) for every question, versus our approach, which predicts a separate top_k value for each question based on the question and the document.


gpt4 = OpenAI(temperature=0, model="gpt-4")
service_context_gpt4 = ServiceContext.from_defaults(llm=gpt4)
evaluator_similarity = SemanticSimilarityEvaluator()
evaluator_gpt4_correctness = CorrectnessEvaluator(service_context=service_context_gpt4)
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)
callback_manager = CallbackManager([token_counter])
llm = OpenAI()
service_context = ServiceContext.from_defaults(
    llm=llm, callback_manager=callback_manager, embed_model="local"
)
set_global_service_context(service_context)


# df_test is assumed to be a DataFrame built from the test samples prepared earlier,
# e.g. df_test = pd.DataFrame(eval_doc_qa_answer_list)
similarity_scores_list = []
correctness_scores_list = []
total_llm_token_list = []
baseline_dict_list = []
for index, row in df_test.iterrows():
    documents = [Document(text=row["paper"])]
    query_list = row["questions"]
    reference_answers_list = row["answers"]
    number_of_accepted_queries = 0
    vector_index = VectorStoreIndex.from_documents(documents)
    query_engine = vector_index.as_query_engine(similarity_top_k=5)
    assert len(query_list) == len(reference_answers_list)
    similarity_local_score = 0
    correctness_local_score = 0
    total_llm_token_local = 0
    for index in range(0, len(query_list)):
        query = query_list[index]
        reference = reference_answers_list[index]
        if reference != "Unacceptable":
            number_of_accepted_queries += 1
            response = str(query_engine.query(query))
            baseline_dict = {
                "query": query,
                "response": response,
                "reference": reference,
            }
            baseline_dict_list.append(baseline_dict)
            similarity_eval_result = await evaluator_similarity.aevaluate(
                response=response, reference=reference
            )
            correctness_eval_result = await evaluator_gpt4_correctness.aevaluate(
                query=query,
                response=response,
                reference=reference,
            )
            similarity_score = similarity_eval_result.score
            correctness_score = correctness_eval_result.score
            total_llm_token = int(token_counter.total_llm_token_count)
            similarity_local_score += similarity_score
            correctness_local_score += correctness_score 
            total_llm_token_local += total_llm_token
            token_counter.reset_counts()
        else:
            pass
    if number_of_accepted_queries > 0:
        avg_similarity_local_score = (
            similarity_local_score / number_of_accepted_queries
        )
        similarity_scores_list.append(avg_similarity_local_score)
        avg_correctness_local_score = (
            correctness_local_score / number_of_accepted_queries
        )
        correctness_scores_list.append(avg_correctness_local_score)
        avg_total_llm_token_local = (
            total_llm_token_local / number_of_accepted_queries
        )
        total_llm_token_list.append(avg_total_llm_token_local)

overall_similarity_average_score = sum(similarity_scores_list) / len(
    similarity_scores_list
)
overall_correctness_average_score = sum(correctness_scores_list) / len(
    correctness_scores_list
)
overall_total_llm_token_average = sum(total_llm_token_list) / len(
    total_llm_token_list
)
df_responses = pd.DataFrame(baseline_dict_list)
df_responses.to_csv("Baseline_Responses_k_1.csv")


RAG Evaluation with the Trained Top-k Model


The only difference at inference time is that we call the BERT-based cross-encoder model to predict the top_k value.


model_path = './ce_model'  
loaded_ce_tokenizer = BertTokenizer.from_pretrained(model_path)
loaded_ce_model = BertForSequenceClassification.from_pretrained(model_path)
loaded_ce_model.eval()
def predict_top_k(document, question):
    inputs = loaded_ce_tokenizer.encode_plus(
        question, 
        document, 
        add_special_tokens=True, 
        return_tensors="pt",
        max_length=512,  
        truncation=True,
        padding="max_length"
    )
    with torch.no_grad():  
        outputs = loaded_ce_model(**inputs)
        prediction = torch.argmax(outputs.logits, dim=-1)
        predicted_top_k = prediction.item() + 1 
    return predicted_top_k
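
As a quick sanity check of the loaded model, one can summarize a test paper the same way as during training and ask for a single prediction; df_test and get_summary are assumed to come from the earlier steps:


# Illustrative single prediction; the evaluation loop below does the same per question.
sample_paper = df_test.iloc[0]["paper"]
sample_question = df_test.iloc[0]["questions"][0]
sample_summary = get_summary(sample_paper)["choices"][0]["message"]["content"]
print(predict_top_k(sample_summary, sample_question))  # an integer in 1..10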


similarity_scores_list = []
correctness_scores_list = []
total_llm_token_list = []
baseline_dict_list = []
for index, row in df_test.iterrows():
    documents = [Document(text=row["paper"])]
    query_list = row["questions"]
    reference_answers_list = row["answers"]
    number_of_accepted_queries = 0
    assert len(query_list) == len(reference_answers_list)
    similarity_local_score = 0
    correctness_local_score = 0
    total_llm_token_local = 0
    vector_index = VectorStoreIndex.from_documents(documents)
    response = get_summary(row["paper"])
    summarized_paper = response["choices"][0]["message"]["content"]
    for index in range(0, len(query_list)):
        query = query_list[index]
        reference = reference_answers_list[index]
        predicted_top_k = predict_top_k(summarized_paper, query)
        query_engine = vector_index.as_query_engine(similarity_top_k=predicted_top_k)
        if reference != "Unacceptable":
            number_of_accepted_queries += 1
            response = str(query_engine.query(query))
            baseline_dict = {
                "query": query,
                "response": response,
                "reference": reference,
            }
            baseline_dict_list.append(baseline_dict)
            similarity_eval_result = await evaluator_similarity.aevaluate(
                response=response, reference=reference
            )
            correctness_eval_result = await evaluator_gpt4_correctness.aevaluate(
                query=query,
                response=response,
                reference=reference,
            )
            similarity_score = similarity_eval_result.score
            correctness_score = correctness_eval_result.score
            total_llm_token = int(token_counter.total_llm_token_count)
            similarity_local_score += similarity_score
            correctness_local_score += correctness_score 
            total_llm_token_local += total_llm_token
            token_counter.reset_counts()
        else:
            pass
    if number_of_accepted_queries > 0:
        avg_similarity_local_score = (
            similarity_local_score / number_of_accepted_queries
        )
        similarity_scores_list.append(avg_similarity_local_score)
        avg_correctness_local_score = (
            correctness_local_score / number_of_accepted_queries
        )
        correctness_scores_list.append(avg_correctness_local_score)
        avg_total_llm_token_local = (
            total_llm_token_local / number_of_accepted_queries
        )
        total_llm_token_list.append(avg_total_llm_token_local)
overall_similarity_average_score = sum(similarity_scores_list) / len(
    similarity_scores_list
)
overall_correctness_average_score = sum(correctness_scores_list) / len(
    correctness_scores_list
)
overall_total_llm_token_average = sum(total_llm_token_list) / len(
    total_llm_token_list
)
df_responses = pd.DataFrame(baseline_dict_list)
df_responses.to_csv("Trained_topk.csv")


Below is the distribution of predicted top-k values for the queries in the test set.




Some questions in the test set were predicted to need a top-k of 1:


What datasets were used in this work?
Which datasets do they experiment on?
Which language pairs do they evaluate on?
Which domain are the conversations in?


And here are some questions with a predicted top-k of 8:


Did the survey provide insight into features commonly found to be predictive of abusive content on online platforms?
What are the opportunities presented by the use of Semantic Web technologies in Machine Translation?
Which other units of text do they experiment with (apart from BPE and ortographic syllables)?
Why is improvement on OntoNotes significantly smaller compared to improvement on WNUT 2017?


Results


The initial results of this project are genuinely promising. I used 100 training papers, each with multiple questions (about 1/8 of the full training data), and for testing I used 25 papers with multiple questions (roughly 1/3 of the full test data); this limitation was due to my personal resource constraints. The trained top-k model achieved a competitive similarity score of 0.835831, indicating that dynamically selecting top-k does not hurt the relevance of the retrieved information. Notably, a correctness score of 2.928571 shows the generated answers are highly accurate, on par even with the baseline k_10 model. This was achieved while substantially reducing the average total LLM token count to 1857.238095, a marked gain in computational and resource efficiency compared with the baseline k_5 and k_10 models. We all know how expensive GPT-4 is, so supplying unnecessary context inflates the number of input tokens and, in turn, the cost. These results underline the cross-encoder's effectiveness at adapting the retrieval process to the complexity of the query, optimizing both answer precision and the system's resource spend.




Conclusion


In summary, this project represents a move away from static top-k retrieval settings by introducing a dynamic approach that adapts to the complexity of individual queries. A cross-encoder is trained to predict the top-k value best suited for retrieval, ensuring that the large language model receives the optimal amount of context for each question. This approach not only preserves the relevance and correctness of the retrieved information, as the promising initial results show, but also significantly reduces the computational resources required. The project demonstrates the potential of a more nuanced and intelligent application of retrieval-augmented generation, leading to a question answering system that answers accurately while being economical with language model tokens.


Source: https://medium.com/@sauravjoshi23/optimizing-retrieval-augmentation-with-dynamic-top-k-tuning-for-efficient-question-answering-11961503d4ae