将多模态数据融入大型语言模型:方法与应用

2024年10月18日 由 alex 发表 44 0

大型语言模型 (LLM) 具有知识截止日期,无法回答针对其知识库中不存在的特定数据的查询。例如,LLM 无法回答有关公司去年会议记录的数据的查询。同样,LLM 容易产生幻觉并提供看似合理的错误答案。


为了解决这个问题,检索增强生成 (RAG) 解决方案变得越来越流行。RAG 的主要思想是将外部文档集成到 LLM 中,并指导其行为仅从外部知识库回答问题。这是通过将文档分块为较小的块、计算每个块的嵌入(数字表示)并将嵌入作为索引存储在专门的矢量数据库中来实现的。


3


语境检索 RAG

将用户的查询与向量数据库中的小块进行匹配的过程通常效果很好; 但是,它存在以下问题:

  • 一个问题的答案可能需要多个块,这些块可能彼此相距甚远。由于上下文缺失,不可能找到所有相关的块。例如,考虑一个法律文件的问题:“Alpha A 和 Beta B 公司之间终止合作关系的条件是什么?”文档中的一个块可能写着:“协议可以在特定条件下终止”。然而,由于缺乏任何上下文信息(没有公司名称),在检索过程中无法选择这个块。
  • 对于某些问题,老式的最佳匹配搜索比语义搜索效果更好,尤其是对于精确匹配。例如,在电子商务文档中,语义搜索方法对查询“产品 ID ZX-450 是什么? ”的回答可能会带来有关多种产品的信息,但缺少精确的“ ZX-450 ”产品。
  • 从向量数据库检索到的信息被传递到 LLM,LLM 根据查询生成最终答案。在此过程中,LLM 必须决定最合适的块来生成最终答案。检索到的块太多可能会导致响应中出现不相关的信息。因此,LLM 必须有一个排名机制。


为了解决这些问题,Anthropic 最近推出了一种向每个块添加上下文的方法,与单纯的 RAG 相比,该方法的性能有了显著的提升。将文档拆分成块后,该方法首先通过将块连同整个文档作为上下文一起发送到 LLM,为每个块分配一个简短的上下文。随后,将上下文附加的块保存到向量数据库中。他们进一步将上下文分块与最佳匹配相结合,使用bm25 检索器(使用 BM25 方法搜索文档)和重新排序模型(根据相关性为每个检索到的块分配排名分数)。


具有上下文检索功能的多模态 RAG

尽管性能显著提升,但 Anthropic 仅证明了这些方法仅适用于文本。许多文档中丰富的信息来源是图像(图形、图表)和复杂表格。如果我们仅解析文档中的文本,我们将无法深入了解文档中的其他模态。包含图像和复杂表格的文档需要高效的解析方法,这不仅需要从文档中正确提取它们,还需要理解它们。


使用 Anthropic 的最新模型 ( claude-3–5-sonnet-20240620 )为文档中的每个块分配上下文,在文档较大的情况下可能会产生高昂的成本,因为这需要将整个文档与每个块一起发送。尽管 Claude 的即时缓存技术可以通过在 API 调用之间缓存常用的上下文来显著降低此成本,但成本仍然远高于 OpenAI 的经济高效的模型(例如gpt-4o-mini)。


本文讨论了 Anthropic 方法的扩展,如下所示:

  • 使用LlamaParse将所有内容(从文本到表格到图像)提取为结构良好的 markdown。
  • 节点解析器用于将文档解析为节点,而不是使用文本分割器将文档分割成块。这不仅涉及分割文本,还涉及理解文档的结构、语义和元数据。
  • OpenAI 极具成本效益的 LLM gpt-4o-mini和嵌入模型text-embedding-3-small用于为每个节点分配上下文、生成最终响应以及计算节点的嵌入。


本文讨论的上下文检索实现是一种低成本、多模态 RAG 解决方案,通过 BM25 搜索和重新排序提高了检索性能。还将这种基于上下文检索的多模态 RAG (CMRAG) 的性能与基本 RAG 和 LlamaIndex 的上下文检索实现进行了比较。


本文实现 CMRAG 的总体方法如下:


4


多模态解析

为了运行本文讨论的代码,需要安装以下库。


!pip install llama-index ipython cohere rank-bm25 pydantic nest-asyncio python-dotenv openai llama-parse


LlamaParse 使用供应商提供的 多模态模型(如gpt-4o )进行 多模态 解析,以处理文档提取。


parser = LlamaParse(
  use_vendor_multimodal_model=True
  vendor_multimodal_model_name="openai-gpt-4o"
  vendor_multimodal_api_key=sk-proj-xxxxxx
)


在这种模式下,文档的每一页都会被截图,然后发送给多模态模型,并指示将其提取为 markdown。每一页的 markdown 结果将合并为最终输出。


最近推出的LlamaParse Premium 模式提供了先进的多模态文档解析功能,可将文本、表格和图片提取为结构良好的标记符,同时大幅减少内容缺失和幻觉


LlamaParse 高级模式的使用方法如下:


from llama_parse import LlamaParse
import os
# Function to read all files from a specified directory
def read_docs(data_dir) -> List[str]:
    files = []
    for f in os.listdir(data_dir):
        fname = os.path.join(data_dir, f)
        if os.path.isfile(fname):
            files.append(fname)
    return files
parser = LlamaParse(
    result_type="markdown",
    premium_mode=True,
    api_key=os.getenv("LLAMA_CLOUD_API_KEY")
)
files = read_docs(data_dir = DATA_DIR) 


我们首先从指定目录读取文档,使用解析器的get_json_result() 方法解析文档,并使用解析器的get_images() 方法获取图像字典。随后,使用retrieve_nodes() 方法提取节点并发送到 LLM,以便根据整个文档分配上下文。解析该文档(60 页)(包括获取图像字典)耗时 5 分 34 秒(一次性处理)。


print("Parsing...")
json_results = parser.get_json_result(files)
print("Getting image dictionaries...")
images = parser.get_images(json_results, download_path=image_dir)
print("Retrieving nodes...")


5


json_results[0]["pages"][3] 


6


上下文检索

通过retrieve_nodes() 函数从解析后的josn_results中提取单个节点和相关图像(截图)。每个节点会连同所有节点(以下代码中的doc变量)一起发送到_assign_context() 函数。_assign_context()函数使用提示模板CONTEXT_PROMPT_TMPL (采用并修改自此源)为每个节点添加简洁的上下文。这样,我们就将元数据、markdown 文本、上下文和原始文本整合到了节点中。


下面的代码展示了retrieve_nodes() 函数的实现。两个辅助函数_get_sorted_image_files()和get_img_page_number()分别用于获取按页排序的图像文件和图像的页码。总体目标是不像简单的 RAG 那样只依赖原始文本生成最终答案,而是考虑元数据、markdown 文本、上下文、原始文本以及检索到的节点的整个图像(截图)(节点元数据中的图像链接)来生成最终响应。


# Function to get page number of images using regex on file names
def get_img_page_number(file_name):
    match = re.search(r"-page-(\d+)\.jpg$", str(file_name))
    if match:
        return int(match.group(1))
    return 0
# Function to get image files sorted by page
def _get_sorted_image_files(image_dir):
    raw_files = [f for f in list(Path(image_dir).iterdir()) if f.is_file()]
    sorted_files = sorted(raw_files, key=get_img_page_number)
    return sorted_files
# Context prompt template for contextual chunking
CONTEXT_PROMPT_TMPL = """
You are an AI assistant specializing in document analysis. Your task is to provide brief, relevant context for a chunk of text from the given document.
Here is the document:
<document>
{document}
</document>
Here is the chunk we want to situate within the whole document:
<chunk>
{chunk}
</chunk>
Provide a concise context (2-3 sentences) for this chunk, considering the following guidelines:
1. Identify the main topic or concept discussed in the chunk.
2. Mention any relevant information or comparisons from the broader document context.
3. If applicable, note how this information relates to the overall theme or purpose of the document.
4. Include any key figures, dates, or percentages that provide important context.
5. Do not use phrases like "This chunk discusses" or "This section provides". Instead, directly state the context.
Please give a short succinct context to situate this chunk within the overall document to improve search retrieval of the chunk. 
Answer only with the succinct context and nothing else.
Context:
"""
CONTEXT_PROMPT = PromptTemplate(CONTEXT_PROMPT_TMPL)
# Function to generate context for each chunk
def _assign_context(document: str, chunk: str, llm) -> str:
    prompt = CONTEXT_PROMPT.format(document=document, chunk=chunk)
    response = llm.complete(prompt)
    context = response.text.strip()
    return context
# Function to create text nodes with context
def retrieve_nodes(json_results, image_dir, llm) -> List[TextNode]:
    nodes = []
    for result in json_results:
        json_dicts = result["pages"]
        document_name = result["file_path"].split('/')[-1]
        docs = [doc["md"] for doc in json_dicts]  # Extract text
        image_files = _get_sorted_image_files(image_dir)  # Extract images
        # Join all docs to create the full document text
        document_text = "\n\n".join(docs)
        for idx, doc in enumerate(docs):
            # Generate context for each chunk (page)
            context = _assign_context(document_text, doc, llm)
            # Combine context with the original chunk
            contextualized_content = f"{context}\n\n{doc}"
            # Create the text node with the contextualized content
            chunk_metadata = {"page_num": idx + 1}
            chunk_metadata["image_path"] = str(image_files[idx])
            chunk_metadata["parsed_text_markdown"] = docs[idx]
        
            node = TextNode(
                text=contextualized_content,
                metadata=chunk_metadata,
            )
            nodes.append(node)
    return nodes
# Get text nodes
text_node_with_context = retrieve_nodes(json_results, image_dir, llm)First page of the report (image by author)First page of the report (image by author)


下面是与报告第一页相对应的节点的描述。


7


利用 BM25 和重新排序加强上下文检索

所有包含元数据、原始文本、标记文本和上下文信息的节点都会被索引到一个矢量数据库中。为节点创建 BM25 索引,并保存在一个 pickle 文件中,用于查询推理。处理过的节点也会被保存起来,以供日后使用(text_node_with_context.pkl)。


    # Create the vector store index
    index = VectorStoreIndex(text_node_with_context, embed_model=embed_model)
    index.storage_context.persist(persist_dir=output_dir)
    # Build BM25 index
    documents = [node.text for node in text_node_with_context]
    tokenized_documents = [doc.split() for doc in documents]
    bm25 = BM25Okapi(tokenized_documents)
    # Save bm25 and text_node_with_context
    with open(os.path.join(output_dir, 'tokenized_documents.pkl'), 'wb') as f:
        pickle.dump(tokenized_documents, f)
    with open(os.path.join(output_dir, 'text_node_with_context.pkl'), 'wb') as f:
        pickle.dump(text_node_with_context, f)


现在,我们可以使用以下管道初始化查询引擎,进行查询。但在此之前,我们需要设置以下提示,以指导 LLM 生成最终响应的行为。初始化多模态 LLM(gpt-4o-mini)以生成最终响应。该提示可根据需要进行调整。


# Define the QA prompt template
RAG_PROMPT = """\
Below we give parsed text from documents in two different formats, as well as the image.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query. Generate the answer by analyzing parsed markdown, raw text and the related
image. Especially, carefully analyze the images to look for the required information.
Format the answer in proper format as deems suitable (bulleted lists, sections/sub-sections, tables, etc.)
Give the page's number and the document name where you find the response based on the Context.
Query: {query_str}
Answer: """
PROMPT = PromptTemplate(RAG_PROMPT)
# Initialize the multimodal LLM
MM_LLM = OpenAIMultiModal(model="gpt-4o-mini", temperature=0.0, max_tokens=16000)


在查询引擎中集成整个流程

下面的QueryEngine 类实现了上述工作流程。BM25 搜索的节点数(top_n_bm25 )和重新排序器重新排序的结果数(top_n )可以根据需要进行调整。通过切换 GitHub 代码中的best_match_25和re_ranking变量,可以选择或取消 BM25 搜索和重新排序。


以下是QueryEngine 类实现的整体工作流程。

  1. 查找查询嵌入
  2. 使用基于向量的检索从向量数据库中检索节点
  3. 使用 BM25 检索(如果选择)检索节点
  4. 合并 BM25 和基于向量检索的节点。查找节点的唯一数量(删除重复节点)
  5. 应用重新排序对合并结果进行重新排序(如果选择)。在此,我们使用 Cohere 的rerank-english-v2.0重新排序器模型。你可以在Cohere 网站上创建一个账户,获取试用版 API 密钥。
  6. 从与节点相关的图像中创建图像节点
  7. 从解析的标记符文本中创建上下文字符串
  8. 将节点图像发送至多模态 LLM 进行解释。
  9. 将文本节点、图像节点描述和元数据发送至 LLM,生成最终响应。


# DeFfine the QueryEngine integrating all methods
class QueryEngine(CustomQueryEngine):
    # Public fields
    qa_prompt: PromptTemplate
    multi_modal_llm: OpenAIMultiModal
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None
    # Private attributes using PrivateAttr
    _bm25: BM25Okapi = PrivateAttr()
    _llm: OpenAI = PrivateAttr()
    _text_node_with_context: List[TextNode] = PrivateAttr()
    _vector_index: VectorStoreIndex = PrivateAttr()
    def __init__(
        self,
        qa_prompt: PromptTemplate,
        bm25: BM25Okapi,
        multi_modal_llm: OpenAIMultiModal,
        vector_index: VectorStoreIndex,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        llm: OpenAI = None,
        text_node_with_context: List[TextNode] = None,
    ):
        super().__init__(
            qa_prompt=qa_prompt,
            retriever=None,
            multi_modal_llm=multi_modal_llm,
            node_postprocessors=node_postprocessors
        )
        self._bm25 = bm25
        self._llm = llm
        self._text_node_with_context = text_node_with_context
        self._vector_index = vector_index
    def custom_query(self, query_str: str):
        # Prepare the query bundle
        query_bundle = QueryBundle(query_str)
        bm25_nodes = []
        if best_match_25 == 1:  # if BM25 search is selected
            # Retrieve nodes using BM25
            query_tokens = query_str.split()
            bm25_scores = self._bm25.get_scores(query_tokens)
            top_n_bm25 = 5  # Adjust the number of top nodes to retrieve
            # Get indices of top BM25 scores
            top_indices_bm25 = bm25_scores.argsort()[-top_n_bm25:][::-1]
            bm25_nodes = [self._text_node_with_context[i] for i in top_indices_bm25]
            logging.info(f"BM25 nodes retrieved: {len(bm25_nodes)}")
        else:
            logging.info("BM25 not selected.")
        # Retrieve nodes using vector-based retrieval from the vector store
        vector_retriever = self._vector_index.as_query_engine().retriever
        vector_nodes_with_scores = vector_retriever.retrieve(query_bundle)
        # Specify the number of top vectors you want
        top_n_vectors = 5  # Adjust this value as needed
        # Get only the top 'n' nodes
        top_vector_nodes_with_scores = vector_nodes_with_scores[:top_n_vectors]
        vector_nodes = [node.node for node in top_vector_nodes_with_scores]
        logging.info(f"Vector nodes retrieved: {len(vector_nodes)}")
        # Combine nodes and remove duplicates
        all_nodes = vector_nodes + bm25_nodes
        unique_nodes_dict = {node.node_id: node for node in all_nodes}
        unique_nodes = list(unique_nodes_dict.values())
        logging.info(f"Unique nodes after deduplication: {len(unique_nodes)}")
        nodes = unique_nodes
        if re_ranking == 1:  # if re-ranking is selected
            # Apply Cohere Re-ranking to rerank the combined results
            documents = [node.get_content() for node in nodes]
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    reranked = cohere_client.rerank(
                        model="rerank-english-v2.0",
                        query=query_str,
                        documents=documents,
                        top_n=3  # top-3 re-ranked nodes
                    )
                    break
                except CohereError as e:
                    if attempt < max_retries - 1:
                        logging.warning(f"Error occurred: {str(e)}. Waiting for 60 seconds before retry {attempt + 1}/{max_retries}")
                        time.sleep(60)  # Wait before retrying
                    else:
                        logging.error("Error occurred. Max retries reached. Proceeding without re-ranking.")
                        reranked = None
                        break
            if reranked:
                reranked_indices = [result.index for result in reranked.results]
                nodes = [nodes[i] for i in reranked_indices]
            else:
                nodes = nodes[:3]  # Fallback to top 3 nodes
            logging.info(f"Nodes after re-ranking: {len(nodes)}")
        else:
            logging.info("Re-ranking not selected.")
        # Limit and filter node content for context string
        max_context_length = 16000  # Adjust as required
        current_length = 0
        filtered_nodes = []
        # Initialize tokenizer
        from transformers import GPT2TokenizerFast
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        for node in nodes:
            content = node.get_content(metadata_mode=MetadataMode.LLM).strip()
            node_length = len(tokenizer.encode(content))
            logging.info(f"Node ID: {node.node_id}, Content Length (tokens): {node_length}")
            if not content:
                logging.warning(f"Node ID: {node.node_id} has empty content. Skipping.")
                continue
            if current_length + node_length <= max_context_length:
                filtered_nodes.append(node)
                current_length += node_length
            else:
                logging.info(f"Reached max context length with Node ID: {node.node_id}")
                break
        logging.info(f"Filtered nodes for context: {len(filtered_nodes)}")
        # Create context string
        ctx_str = "\n\n".join(
            [n.get_content(metadata_mode=MetadataMode.LLM).strip() for n in filtered_nodes]
        )
        # Create image nodes from the images associated with the nodes
        image_nodes = []
        for n in filtered_nodes:
            if "image_path" in n.metadata:
                image_nodes.append(
                    NodeWithScore(node=ImageNode(image_path=n.metadata["image_path"]))
                )
            else:
                logging.warning(f"Node ID: {n.node_id} lacks 'image_path' metadata.")
        logging.info(f"Image nodes created: {len(image_nodes)}")
        # Prepare prompt for the LLM
        fmt_prompt = self.qa_prompt.format(context_str=ctx_str, query_str=query_str)
        # Use the multimodal LLM to interpret images and generate a response
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
            max_tokens=16000
        )
        logging.info(f"LLM response generated.")
        # Return the final response
        return Response(
            response=str(llm_response),
            source_nodes=filtered_nodes,
            metadata={
                "text_node_with_context": self._text_node_with_context,
                "image_nodes": image_nodes,
            },
        )
# Initialize the query engine with BM25, Cohere Re-ranking, and Query Expansion
query_engine = QueryEngine(
    qa_prompt=PROMPT,
    bm25=bm25,
    multi_modal_llm=MM_LLM,
    vector_index=index,
    node_postprocessors=[],
    llm=llm,
    text_node_with_context=text_node_with_context
)
print("All done")


使用 OpenAI 模型(尤其是gpt-4o-mini)的优势在于,上下文分配和查询推理运行的成本更低,上下文分配时间也更短。虽然 OpenAI 和 Anthropic 的基本层都能很快达到 API 调用的最大速率限制,但 Anthropic 基本层的重试时间各不相同,可能会过长。使用claude-3-5-sonnet-20240620 仅对本文档的前 20 页进行上下文分配,在及时缓存的情况下耗时约 170 秒,花费 20 美分(输入 + 输出令牌)。而gpt-4o-mini 与 Claude 3.5 Sonnet 相比,输入令牌的成本大约降低了 20 倍,输出令牌的成本大约降低了 25 倍。OpenAI 声称 对重复内容实施了提示缓存 ,可自动用于所有 API 调用。


相比之下,通过gpt-4o-mini对整个文档(60 页)中的节点进行上下文分配仅用了约 193 秒,且没有任何重试请求。


在实现QueryEngine 类后,我们可以按如下方式运行查询推理:


original_query = """What are the top countries to whose citizens the Finnish Immigration Service issued the highest number of first residence permits in 2023?
Which of these countries received the highest number of first residence permits?"""
response = query_engine.query(original_query)
display(Markdown(str(response)))


以下是对该查询的标记回复。


8


查询答复中引用的网页如下。


9


现在,让我们比较一下基于 gpt-4o-mini 的 RAG(LlamaParse 高级搜索 + 上下文检索 + BM25 + 重新排序)和基于 Claude 的 RAG(LlamaParse 高级搜索 + 上下文检索)的性能。我还实现了一个简单的基准 RAG,可在 GitHub 的笔记本中找到。以下是要比较的三种 RAG。

  1. LlamaIndex 中的简单 RAG 使用SentenceSplitter 将文档分割成块(chunk_size= 800,chunk_overlap= 400),创建向量索引和向量检索。
  2. CMRAG (claude-3-5-sonnet-20240620,voyage-3) - LlamaParse 高级模式 + 上下文检索
  3. CMRAG (gpt-4o-mini, text-embedding-3-small) - LlamaParse 高级模式 + 上下文检索 + BM25 + 重新排序


为简单起见,我们将这些 RAG 分别称为 RAG0、RAG1 和 RAG2。下面是报告中的三页,我向每个 RAG 提出了三个问题(每页一个问题)。红色矩形突出显示的区域为基本事实或正确答案的来源。


10

11

12


以下是三个 RAG 对每个问题的答复。


13


可以看出,RAG2 的表现非常出色。对于第一个问题,RAG0 提供了一个错误答案,因为该问题是通过图像提出的。RAG1 和 RAG2 都给出了正确答案。对于其他两个问题,RAG0 无法提供任何答案。而 RAG1 和 RAG2 都给出了正确答案。


总体而言,由于整合了 BM25、重新排序和更好的提示,RAG2 的性能在许多情况下与 RAG1 相当,甚至更好。它为情境、多模态 RAG 提供了一种经济高效的解决方案。假设文档嵌入(hyde)或查询扩展是这一流程中可能的整合。



文章来源:https://medium.com/towards-data-science/integrating-multimodal-data-into-a-large-language-model-d1965b8ab00c
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消