通过视觉模型提升RAG效果:增强文档处理

2024年10月24日 由 alex 发表 173 0

传统的检索增强生成 (RAG) 方法彻底改变了我们与文档的交互方式,但它仍然缺少关键的视觉上下文。如果 RAG 不仅可以读取,还可以查看,会怎么样?通过将视觉语言模型 (VLM) 与传统文本处理相结合,我们开发了一种双流 RAG 架构,可处理 PDF 文档中的文本和视觉内容。我们的方法利用多Qdrant’s向量功能来存储文本和图像嵌入,从而实现更丰富的上下文检索。查询时,系统不仅匹配文本 - 它实际上“看到”文档页面,从而获得更准确、更具有上下文感知的响应。在本文中,我们将探讨这种视觉增强型 RAG 系统如何为文档理解和检索开辟新的可能性。


2


架构:

我根据图表来解释一下这个创新的视觉增强RAG系统的架构。


该系统以 PDF 文档作为主要输入,经过双重处理流以最大限度地提取信息。在第一个流中,每页都转换为图像,而在并行流中,从每页中提取文本。这种双重方法确保在处理阶段不会丢失任何信息。然后,提取的内容被矢量化并存储在 Qdrant 中,这是一个矢量数据库,可以高效处理每个文档的多种矢量类型。Qdrant 中的每个条目都包含图像和文本矢量,以及必要的元数据,包括页码、base64 编码的页面图像和提取的文本。


当用户提交查询时,Qdrant 的预取功能就会发挥作用,根据向量相似性检索前三个最相关的结果(如本实现中配置的那样)。这就是架构变得特别有趣的地方——系统不会止步于传统的基于文本的检索。相同的用户查询以及检索到的 base64 编码图像被传递到视觉语言模型 (VLM),在本例中具体是 OpenAI 的视觉模型。这使系统能够对实际文档布局和内容进行视觉分析,从而提供额外的理解层面。


该架构的最后一部分涉及聚合语言学习模型 (LLM),该模型结合了基于文本的检索和视觉模型的分析结果。该聚合器综合了来自两个流的信息,生成了综合响应,充分利用了对文档的文本和视觉理解。结果是一个更强大且具有上下文感知能力的系统,可以从文本和视觉角度提供具有强大支持证据的答案。


这种架构的出色之处在于它不仅能够将文档理解为文本,还能理解文档的本来面目——包括布局、格式和通常包含关键上下文信息的视觉元素。这种双流方法与现代矢量搜索功能和视觉模型相结合,代表了 RAG 系统的重大进步。


实施:

让我们看看下面架构中的摄取部分。


3


让我们设计一个类,pdf_processor.py并调用所示的方法。


4


from pdf2image import convert_from_path
from pypdf import PdfReader
import os

class PDFProcessor:
    """
    A class to handle PDF processing operations including text extraction and image conversion.
    """
    def __init__(self, pdf_path, output_dir):
        """
        Initialize the PDF processor.
        Parameters:
        - pdf_path: str, path to the PDF file
        - output_dir: str, directory to save the outputs
        """
        self.pdf_path = pdf_path
        self.output_dir = output_dir
        self.saved_images = []
        self.page_texts = []
        self.page_dicts = []
        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)
    def extract_text(self):
        """
        Extract text from each page of the PDF.
        """
        print("Extracting text from PDF...")
        reader = PdfReader(self.pdf_path)
        # Extract text from each page
        for i, page in enumerate(reader.pages):
            text = page.extract_text()
            self.page_texts.append(text)
            # Save text to file
            text_file_path = os.path.join(self.output_dir, f'page_{i + 1}.txt')
            with open(text_file_path, 'w', encoding='utf-8') as f:
                f.write(text)
            print(f"Saved text from page {i + 1} to {text_file_path}")
    def convert_to_images(self, dpi=200, fmt='png'):
        """
        Convert each page of the PDF to images.
        Parameters:
        - dpi: int, resolution of output images
        - fmt: str, output image format
        """
        print("Converting PDF pages to images...")
        pages = convert_from_path(self.pdf_path, dpi=dpi)
        # Save each page as an image
        for i, page in enumerate(pages):
            image_path = os.path.join(self.output_dir, f'page_{i + 1}.{fmt}')
            page.save(image_path, fmt)
            self.saved_images.append(image_path)
            print(f"Saved image from page {i + 1} to {image_path}")
    def create_page_dicts(self, fmt='png'):
        """
        Create a list of dictionaries containing page information.
        Parameters:
        - fmt: str, image format used (needed for filenames)
        Returns:
        - list of dictionaries with page information
        """
        num_pages = max(len(self.saved_images) if self.saved_images else 0,
                        len(self.page_texts) if self.page_texts else 0)
        self.page_dicts = []
        for i in range(num_pages):
            page_dict = {
                "image": f"page_{i + 1}.{fmt}" if self.saved_images else None,
                "text": f"page_{i + 1}.txt" if self.page_texts else None
            }
            self.page_dicts.append(page_dict)
        return self.page_dicts
    def process(self, extract_images=True, extract_text=True, dpi=200, fmt='png'):
        """
        Process the PDF file with specified operations.
        Parameters:
        - extract_images: bool, whether to convert pages to images
        - extract_text: bool, whether to extract text
        - dpi: int, resolution of output images
        - fmt: str, output image format
        Returns:
        - tuple: (list of image paths, list of text content, list of page dictionaries)
        """
        try:
            if extract_text:
                self.extract_text()
            if extract_images:
                self.convert_to_images(dpi=dpi, fmt=fmt)
            self.create_page_dicts(fmt=fmt)
            return self.saved_images, self.page_texts, self.page_dicts
        except Exception as e:
            print(f"Error processing PDF: {str(e)}")
            return [], [], []
    def print_extracted_text(self):
        """
        Print the extracted text from each page with clear separation.
        """
        for i, text in enumerate(self.page_texts, 1):
            print(f"\n{'=' * 40}")
            print(f"Page {i}")
            print(f"{'=' * 40}")
            print(text.strip())

# Example driver usage
# if __name__ == "__main__":
#     # Example parameters
#     pdf_file = "data/rag.pdf"  # infact any pdf as input here.
#     output_folder = "pdf_output"
#
#     # Create processor instance
#     processor = PDFProcessor(pdf_file, output_folder)
#
#     # Process PDF - extract both images and text
#     image_paths, texts, page_dicts = processor.process(
#         extract_images=True,
#         extract_text=True,
#         dpi=200,
#         fmt='png'
#     )
#
#     print("\nProcessing complete.")
#     print("\nPage information:")
#     for i, page_info in enumerate(page_dicts, 1):
#         print(f"Page {i}:", page_info)


pdf_output收集图像和全文,以便进一步处理。现在,让我们再创建一个DataIndexerAndRetriever.py类,如下所示。


5


from dotenv import load_dotenv, find_dotenv
from qdrant_client import QdrantClient, models
from fastembed import TextEmbedding
from sentence_transformers import SentenceTransformer
from PIL import Image
import openai
import base64
import io
import os
from pdf_processor import PDFProcessor

class DataIndexerAndRetriever:
    def __init__(self, data_dir='./pdf_output', qdrant_url="http://localhost:6333", qdrant_api_key='th3s3cr3tk3y'):
        """
        Initialize the Research Paper Processor.
        Parameters:
        - data_dir: str, directory containing PDF output files
        - qdrant_url: str, Qdrant server URL
        - qdrant_api_key: str, Qdrant API key
        """
        # Load environment variables
        _ = load_dotenv(find_dotenv())
        self.data_dir = data_dir
        self.collection_name = 'research_papers'
        # Initialize models
        self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        self.image_embedding_model = SentenceTransformer("clip-ViT-B-32")
        self.text_embedding_model = TextEmbedding(
            model_name='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        )
        # Initialize OpenAI API
        api_key = 'sk-proj-api-key'
        self.openai_client = openai.OpenAI(api_key=api_key)
        # Initialize collection if it doesn't exist
        self._initialize_collection()
    def _initialize_collection(self):
        """Initialize Qdrant collection if it doesn't exist."""
        if not self.client.collection_exists(collection_name=self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    "clip-ViT-B-32": models.VectorParams(
                        size=512,
                        distance=models.Distance.COSINE
                    ),
                    "paraphrase-multilingual-MiniLM-L12-v2": models.VectorParams(
                        size=384,
                        distance=models.Distance.COSINE
                    ),
                }
            )
    def get_text_embeddings(self, text_file_path):
        """
        Get embeddings for text file content.
        Parameters:
        - text_file_path: str, path to text file
        Returns:
        - tuple: (text embeddings, full text content)
        """
        with open(file=text_file_path, mode='r') as data:
            full_text = data.read()
        return next(self.text_embedding_model.passage_embed(full_text)), full_text
    def image_to_base64(self, image_path):
        """
        Convert image to base64 and get embeddings.
        Parameters:
        - image_path: str, path to image file
        Returns:
        - tuple: (image embeddings, base64 encoded string)
        """
        try:
            with open(image_path, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
            with Image.open(image_path) as img:
                image_embedding = self.image_embedding_model.encode(img).tolist()
            return image_embedding, encoded_string
        except Exception as e:
            print(f"Error converting image to base64: {str(e)}")
            return None
    def base64_to_image(self, base64_string, output_path=None, fmt='png'):
        """
        Convert base64 string back to image.
        Parameters:
        - base64_string: str, base64 encoded image string
        - output_path: str, path to save decoded image (optional)
        - fmt: str, image format (default: 'png')
        Returns:
        - PIL.Image or str: Image object or path to saved image
        """
        try:
            image_data = base64.b64decode(base64_string)
            image = Image.open(io.BytesIO(image_data))
            if output_path:
                image.save(output_path, fmt)
                return output_path
            return image
        except Exception as e:
            print(f"Error converting base64 to image: {str(e)}")
            return None
    def index_pages(self, pages_data):
        """
        Process and index pages data.
        Parameters:
        - pages_data: list of dict, containing image and text file information
        """
        for index, obj in enumerate(pages_data):
            image_path = os.path.join(self.data_dir, obj["image"])
            text_file_path = os.path.join(self.data_dir, obj["text"])
            image_embedding, base64str = self.image_to_base64(image_path)
            text_embedding, full_text = self.get_text_embeddings(text_file_path=text_file_path)
            points = [
                models.PointStruct(
                    id=index + 1,
                    vector={
                        "clip-ViT-B-32": image_embedding,
                        "paraphrase-multilingual-MiniLM-L12-v2": text_embedding
                    },
                    payload={
                        "_id": index + 1,
                        "base64str": base64str,
                        "full_text": full_text,
                        "page": index + 1
                    }
                )
            ]
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
    def query_with_rrf(self, query_text: str = '', query_image_path: str = ''):
        """
        Query the collection using Reciprocal Rank Fusion.
        Parameters:
        - query_text: str, text query
        - query_image_path: str, path to query image
        Returns:
        - list: search results
        """
        text_embedding = None
        if query_text != '':
            text_embedding = next(self.text_embedding_model.embed(query_text)).tolist()
        image_embedding = None
        if query_image_path != '':
            with Image.open(query_image_path) as img:
                image_embedding = self.image_embedding_model.encode(img).tolist()
        prefetch = None
        if text_embedding and len(text_embedding) > 0:
            prefetch = [
                models.Prefetch(
                    query=text_embedding,
                    using="paraphrase-multilingual-MiniLM-L12-v2",
                    limit=3,
                )
            ]
        if image_embedding and len(image_embedding) > 0:
            prefetch = [
                models.Prefetch(
                    query=image_embedding,
                    using="clip-ViT-B-32",
                    limit=3,
                )
            ]
        results = self.client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF
            ),
            with_payload=True,
            limit=3,
        )
        return results
    # Function to ask a question about the image using OpenAI API
    def ask_image_question(self, base64_image, question):
        try:
            # Send the image and question to the OpenAI API
            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": question + ". Support your answer with evidence from given context. example: page number, section heading etc",
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f_"data:image/jpeg;base64,{base64_image}"
                                },
                            },
                        ],
                    }
                ],
            )
            # Extract and return the response
            answer = response.choices[0].message.content
            return answer
        except Exception as e:
            print(f"Error during API call: {e}")
            return None

# Example usage
if __name__ == "__main__":
    # Sample pages data
    # Example parameters
    pdf_file = "data/rag.pdf"
    output_folder = "pdf_output"
    # Create processor instance (open the below comments for the first time 
    # when you want to process the pdf file)
    # processor = PDFProcessor(pdf_file, output_folder)
    # Process PDF - extract both images and text
    # image_paths, texts, page_dicts = processor.process(
    #     extract_images=True,
    #     extract_text=True,
    #     dpi=200,
    #     fmt='png'
    # )
    # Initialize processor
    processor = DataIndexerAndRetriever()
    # Process pages (uncomment to run indexing into qdrant)
    # processor.index_pages(page_dicts)
    # Query example
    question = 'What is the OpenAI assistants workflow?'
    result = processor.query_with_rrf(query_text=question)
    for point in result.points:
        response = processor.ask_image_question(base64_image=point.payload['base64str'],
                                                question=question)
        print("-" * 50)
        print(response)


下面简要介绍代码的主要功能:

1. DataIndexerAndRetriever类处理具有文本和图像功能的双流文档处理:

  • 初始化与Qdrant向量数据库的连接,并加载所需的嵌入模型(CLIP 用于图像,MiniLM 用于文本)
  • 为视觉模型集成设置 OpenAI 客户端

2. 核心处理功能

  • 将 PDF 页面转换为图像和文本
  • 生成图像和文本内容的嵌入模型
  • 将数据存储在 Qdrant 中,每个文档页面有两个向量

3. 检索系统:

  • 使用互易等级融合 (RRF)在文本和图像向量中进行搜索
  • 默认返回前 3 个最相关的结果
  • 结果中包括原始的 base64 图像和全文

4. 视觉集成:

  • 使用OpenAI 的 GPT-4o视觉模型处理查询
  • 获取用户问题和相关页面图片
  • 根据文档上下文的证据返回答案

5. 主要工作流程:

  • 处理 PDF 文档
  • 索引 Qdrant 中的内容
  • 接受用户查询
  • 利用文本和视觉功能返回上下文答案


结果

观察我们提出的问题,看看 OpenAi 是如何做出明确回应的,如下面所示。


6

7


结论

总之,将视觉模型集成到检索增强生成(RAG)系统中代表了文档处理领域的一大进步。通过同时利用图像和文本数据,我们增强了索引和检索能力,从而获得更丰富、更贴近上下文的响应。这种创新方法不仅提高了信息检索的准确性,还提供了令人信服的证据,加强了从文件中获得的洞察力。随着我们继续探索视觉与语言模型之间的协同作用,我们将越来越有可能实现更有效、更细致的文档理解。

文章来源:https://medium.com/@manthapavankumar11/revolutionizing-rag-by-integrating-vision-models-for-enhanced-document-processing-b3aaa7ab386a
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消