LlamaIndex:通过RAG混合搜索中的Alpha调整增强检索性能

2024年02月04日 由 alex 发表 795 0

简介


检索适当的数据块、节点或上下文是构建高效检索增强生成(RAG)应用程序的一个重要方面。然而,基于向量或嵌入的搜索可能无法有效处理所有类型的用户查询。


为了解决这个问题,混合搜索结合了基于关键字的方法(BM25)和矢量(嵌入)搜索技术。混合搜索有一个特定参数,即 Alpha,用于平衡关键词(BM25)和矢量搜索在为 RAG 应用程序检索正确上下文时的权重。(alpha=0.0 - 关键字搜索(BM25),alpha=1.0 - 向量搜索)


但有趣的地方在于:微调 Alpha 不仅仅是一项任务,更是一门艺术。实现理想的平衡对于释放混合搜索的全部潜力至关重要。这需要针对 RAG 系统中各种类型的用户查询调整不同的 Alpha 值。


在本文中,我们将利用 LlamaIndex 的检索评估模块,在命中率和 MRR 指标的帮助下,对 Weaviate 向量数据库中的 Alpha 值进行调整。


在深入实施之前,让我们先了解一下本文将使用的不同查询类型和指标。


不同的用户查询类型:


RAG 应用程序中的用户查询因个人意图而异。针对这些不同的查询类型,必须对 Alpha 参数进行微调。这一过程包括将每个用户查询路由到特定的 Alpha 值,以实现有效的检索和响应合成。微软已经确定了各种用户查询类别,我们选择了其中几类来调整我们的混合搜索。以下是我们考虑的不同用户查询类型:


  1. 网络搜索查询: 类似于通常输入搜索引擎的简短查询。
  2. 概念搜索查询: 抽象的问题,需要详细的、多句式的回答。
  3. 事实搜索查询: 只有一个明确答案的查询。
  4. 关键词查询: 仅由关键识别词组成的简明查询。
  5. 拼写错误查询: 包含错别字、移位和常见拼写错误的查询。
  6. 精确子字符串搜索: 与原始上下文中的子字符串完全匹配的查询。


让我们来看看这些不同用户查询类型的示例:


1.网页搜索查询


Transfer capabilities of LLaMA language model to non-English languages

2.概念寻求查询

What is the dual-encoder architecture used in recent works on dense retrievers?

3.事实寻求查询

What is the total number of propositions the English Wikipedia dump is segmented into in FACTOID WIKI?

4. 关键字查询

GTR retriever recall rate

5. 拼写错误的查询

What is the advntage of prposition retrieval over sentnce or passage retrieval?

6. 精确子串搜索

first kwords for the GTR retriever. Finer-grained


检索评估指标:


我们将使用命中率和 MRR 指标进行检索评估。让我们来了解一下这些指标。


命中率:


命中率衡量的是在前 k 个结果块/上下文中出现正确块/上下文的查询比例。简单地说,它评估的是我们的系统在前 k 个数据块中正确识别数据块的频率。


平均互易排名(MRR):


MRR 通过考虑每个查询中排名最高的相关块/上下文的位置来评估系统的准确性。它计算的是所有查询中这些位置的倒数平均值。例如,如果第一个相关块/上下文位于列表顶部,则其倒数排名为 1;如果它是第二个项目,则倒数排名变为 1/2,这种模式会一直持续下去。


开始实施


我们将采用系统的方法来实施实验工作流程,包括以下步骤:


  1. 数据下载。
  2. 数据加载。
  3. Weaviate 客户端设置。
  4. 创建索引和插入节点
  5. 定义 LLM (GPT-4)
  6. 定义 CohereAI Reranker。
  7. 生成各种查询类型的合成查询。
  8. 定义自定义检索器
  9. 检索评估和指标计算函数
  10. 针对不同查询类型和 Alpha 值进行检索评估。


首先,让我们定义实现过程中的一些基本函数。


  1. get_weaviate_client - 设置 weaviate 客户端。
  2. load_documents - 从文件路径加载文档。
  3. create_nodes - 使用文本分割器分割文档,创建节点。
  4. connect_index - 连接到 weaviate 索引。
  5. insert_nodes_index - 向索引中插入节点。


def get_weaviate_client(api_key, url):
  auth_config = weaviate.AuthApiKey(api_key=api_key)
  client = weaviate.Client(
    url=url,
    auth_client_secret=auth_config
  )
  return client
def load_documents(file_path, num_pages=None):
  if num_pages:
    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()[:num_pages]
  else:
    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
  return documents
def create_nodes(documents, chunk_size=512, chunk_overlap=0):
  node_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
  nodes = node_parser.get_nodes_from_documents(documents)
  return nodes
def connect_index(weaviate_client):
  vector_store = WeaviateVectorStore(weaviate_client=weaviate_client)
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
  index = VectorStoreIndex([], storage_context=storage_context)
  return index
def insert_nodes_index(index, nodes):
  index.insert_nodes(nodes)


下载数据


!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2312.04511.pdf" -O "llm_compiler.pdf""Mozilla" "https://arxiv.org/pdf/2312.04511.pdf" -O "llm_compiler.pdf"
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2401.01055.pdf" -O "llama_beyond_english.pdf"
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2312.06648.pdf" -O "dense_x_retrieval.pdf"


加载数据


# load documents, we will skip references and appendices from the papers.
documents1 = load_documents("llm_compiler.pdf", 12)
documents2 = load_documents("dense_x_retrieval.pdf", 9)
documents3 = load_documents("llama_beyond_english.pdf", 7)
# create nodes
nodes1 = create_nodes(documents1)
nodes2 = create_nodes(documents2)
nodes3 = create_nodes(documents3)


设置 Weaviate 客户端


url = 'cluster URL''cluster URL'
api_key = 'your api key'
client = get_weaviate_client(api_key, url)


创建索引并插入节点


index = connect_index(client)
insert_nodes_index(index, nodes1)


定义 LLM


# Deing LLM for query generation
llm = OpenAI(model='gpt-4', temperature=0.1)


 创建合成查询


我们将创建前面讨论过的查询,检查笔记本中每种查询类型的提示,并为每种查询类型编写代码。显示代码片段以供参考。


queries = generate_question_context_pairs(
    nodes, 
  llm=llm, 
  num_questions_per_chunk=2, 2, 
  qa_generate_prompt_tmpl = qa_template
)


定义重链器


reranker = CohereRerank(api_key=os.environ['COHERE_API_KEY'], top_n=4)'COHERE_API_KEY'], top_n=4)


定义CustomRetriever


我们将定义 CustomRetriever 类,以便在使用或不使用检索器的情况下执行检索操作。


class CustomRetriever(BaseRetriever):
    """Custom retriever that performs hybrid search with and without reranker"""
    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        reranker: CohereRerank
    ) -> None:
        """Init params."""
        self._vector_retriever = vector_retriever
        self._reranker = reranker
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        retrieved_nodes = self._vector_retriever.retrieve(query_bundle)
        if self._reranker != None:
            retrieved_nodes = self._reranker.postprocess_nodes(retrieved_nodes, query_bundle)
        else:
            retrieved_nodes = retrieved_nodes[:4]
        return retrieved_nodes
    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes given query.
        Implemented by the user.
        """
        return self._retrieve(query_bundle)
    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        if isinstance(str_or_query_bundle, str):
            str_or_query_bundle = QueryBundle(str_or_query_bundle)
        return await self._aretrieve(str_or_query_bundle)


定义用于寻回器评估和指标计算的函数


我们将研究使用和不使用 Reranker 时,不同阿尔法值的检索器性能。


# Alpha values and datasets to test
alpha_values = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
# Function to evaluate retriever and return results
async def evaluate_retriever(alpha, dataset, reranker=None):
    retriever = VectorIndexRetriever(index,
                                     vector_store_query_mode="hybrid",
                                     similarity_top_k=10,
                                     alpha=alpha)
    custom_retriever = CustomRetriever(retriever,
                                       reranker)
    retriever_evaluator = RetrieverEvaluator.from_metric_names(["mrr", "hit_rate"], retriever=custom_retriever)
    eval_results = await retriever_evaluator.aevaluate_dataset(dataset)
    return eval_results
# Function to calculate and store metrics
def calculate_metrics(eval_results):
    metric_dicts = []
    for eval_result in eval_results:
        metric_dict = eval_result.metric_vals_dict
        metric_dicts.append(metric_dict)
    full_df = pd.DataFrame(metric_dicts)
    hit_rate = full_df["hit_rate"].mean()
    mrr = full_df["mrr"].mean()
    return hit_rate, mrr


检索评估


在这里,我们对不同的查询类型(数据集)和 alpha 值进行检索评估,以了解哪种 alpha 值适合哪种查询类型。你需要相应地插入 Reranker,以计算使用和不使用 Reranker 时的检索评估结果。


# Asynchronous function to loop over datasets and alpha values and evaluate
async def main():
    results_df = pd.DataFrame(columns=['Dataset', 'Alpha', 'Hit Rate', 'MRR'])
    for dataset in datasets_single_document.keys():
        for alpha in alpha_values:
            eval_results = await evaluate_retriever(alpha, datasets_single_document[dataset])
            hit_rate, mrr = calculate_metrics(eval_results)
            new_row = pd.DataFrame({'Dataset': [dataset], 'Alpha': [alpha], 'Hit Rate': [hit_rate], 'MRR': [mrr]})
            results_df = pd.concat([results_df, new_row], ignore_index=True)
    # Determine the grid size for subplots
    num_rows = len(datasets_single_document) // 2 + len(datasets_single_document) % 2
    num_cols = 2
    # Plotting the results in a grid
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, num_rows * 4), squeeze=False)  # Ensure axes is always 2D
    for i, dataset in enumerate(datasets_single_document):
        ax = axes[i // num_cols, i % num_cols]
        dataset_df = results_df[results_df['Dataset'] == dataset]
        ax.plot(dataset_df['Alpha'], dataset_df['Hit Rate'], marker='o', label='Hit Rate')
        ax.plot(dataset_df['Alpha'], dataset_df['MRR'], marker='o', linestyle='--', label='MRR')
        ax.set_xlabel('Alpha')
        ax.set_ylabel('Metric Value')
        ax.set_title(f'{dataset}')
        ax.legend()
        ax.grid(True)
    # If the number of datasets is odd, remove the last (empty) subplot
    if len(datasets_single_document) % num_cols != 0:
        fig.delaxes(axes[-1, -1])  # Remove the last subplot if not needed
    # Adjust layout to prevent overlap
    plt.tight_layout()
    plt.show()
# Run the main function
asyncio.run(main())


分析结果:


完成实施阶段后,我们现在将注意力转向分析结果。我们进行了两组实验:一组针对单个文档,另一组针对多个文档。这些实验在阿尔法值、用户查询类型以及包含或不包含重新搜索器等方面各不相同。附图显示了实验结果,重点是作为检索评价指标的命中率和 MRR(平均互易等级)。


使用单一文档


无Reranker


6


使用 Reranker:


7


有多个文档:


无 Reranker


8


使用 Reranker:


9


观察结果


  1. 在 reranker 的帮助下,单个和多个文档索引的命中率和 MRR 均有所提高。这一次又一次地证明,在 RAG 应用程序中使用 reranker 是非常有用的。
  2. 虽然大多数情况下混合搜索比关键字/矢量搜索更胜一筹,但仍应根据 RAG 应用程序中的用户查询,针对不同查询类型进行仔细评估。
  3. 索引单个文档和多个文档时的行为是不同的,这表明最好在向索引中添加文档时调整 alpha。


总结


在这篇博文中,我们研究了如何在混合搜索系统中针对一系列查询类型调整 Alpha。有趣的是,当索引单个文档或多个文档时,结果是如何变化的。今后,你可以考虑尝试使用来自不同领域的文档,对各种查询类型采用不同的查询长度。


文章来源:https://medium.com/llamaindex-blog/llamaindex-enhancing-retrieval-performance-with-alpha-tuning-in-hybrid-search-in-rag-135d0c9b8a00
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消