How to Build Graph RAG with Custom Embeddings (e.g., Ollama Embeddings)

Published by alex on January 2, 2025

In this article, we will walk through how to build a graph-based retrieval-augmented generation (RAG) system using Neo4j, a graph database, together with a custom embedding model such as Ollama. We will cover each step in detail and provide a complete code example to help you get started quickly.


Overview of the steps:

  1. Set up Neo4j and connect to the database: install Neo4j and configure the database credentials needed to connect.
  2. Integrate custom embeddings: use a model such as Ollama to generate embeddings for your text.
  3. Process PDF documents: split the documents into manageable chunks with attached metadata for efficient processing (this is not limited to PDFs; other document types work as well, see the LangChain documentation).
  4. Store the data in Neo4j: save the processed chunks and their relationships (e.g., keywords) in the graph database.
  5. Create and verify the vector index: set up a vector index for embedding-based similarity search.
  6. Perform vector search: use embeddings to find chunks similar to a given query.
  7. Build the RAG chain: retrieve relevant chunks from the vector store and combine them with a language model (LLM) in a pipeline that answers questions.


Step-by-Step Guide


Step 1: Set Up Neo4j and Connect to the Database

First, make sure Neo4j is installed and running, either locally or on a server. Install the neo4j Python package and configure your connection:


from langchain_community.graphs import Neo4jGraph
url = "bolt://localhost:7687"
username = "neo4j"
password = "your_password"
graph = Neo4jGraph(url=url, username=username, password=password)
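
As an optional sanity check, you can run a trivial Cypher query to confirm that the connection works before moving on:


result = graph.query("RETURN 1 AS ok")
print(result)  # a successful connection prints [{'ok': 1}]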


Step 2: Integrate Custom Embeddings

We will use the langchain_ollama library to generate embeddings for our text. Initialize the embedding model:


from langchain_ollama import OllamaEmbeddings
embeddings = OllamaEmbeddings(model="llama3.2") # or any ollama model
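
Different Ollama models produce embeddings of different sizes, and the size must match the dimension of the vector index we create later (embedding_dim = 3072 below). As a minimal sketch, you can check the dimension up front:


sample = embeddings.embed_query("hello world")
print(len(sample))  # should equal embedding_dim (3072 for the llama3.2 model used in this article)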


Step 3: Process PDF Documents

Define a function to load and process a PDF document. We will split the text into chunks, extract metadata, and generate keywords.


from typing import List, Dict
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from collections import Counter
import re
def _extract_keywords(text: str, top_n: int = 5) -> List[str]:
    # Tokenize the text, drop common stop words and very short tokens,
    # then return the most frequent remaining words as keywords.
    words = re.findall(r"\w+", text.lower())
    stop_words = {
        "the", "a", "an", "and", "or", "but", "in", "on",
        "at", "to", "for", "of", "with", "by",
    }
    filtered_words = [
        word for word in words if word not in stop_words and len(word) > 2
    ]
    return [word for word, count in Counter(filtered_words).most_common(top_n)]

def load_and_process_pdf(
    pdf_path: str, chunk_size: int = 1000, chunk_overlap: int = 200
) -> List[Dict]:
    loader = PyPDFLoader(pdf_path)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    splits = text_splitter.split_documents(pages)
    processed_chunks = []
    for i, chunk in enumerate(splits):
        metadata = {
            "chunk_id": i,
            "source": pdf_path,
            "page_number": chunk.metadata.get("page", None),
            "total_length": len(chunk.page_content),
            "keywords": _extract_keywords(chunk.page_content),
            "text_preview": (
                chunk.page_content[:100] + "..."
                if len(chunk.page_content) > 100
                else chunk.page_content
            ),
        }
        processed_chunks.append({"text": chunk.page_content, "metadata": metadata})
    return processed_chunks
pdf_path = "Grokking Deep Reinforcement Learning by Miguel Morales 1.pdf"
chunks = load_and_process_pdf(pdf_path)
print(f"Total chunks created: {len(chunks)}")
for i, chunk in enumerate(chunks[:3]):
    print(f"\nChunk {i}:")
    print(f"Text Preview: {chunk['metadata']['text_preview']}")
    print(f"Keywords: {chunk['metadata']['keywords']}")
    print(f"Page Number: {chunk['metadata']['page_number']}")


Step 4: Store the Data in Neo4j

Save the document chunks and their metadata in Neo4j. Each chunk becomes a node that is connected to keyword nodes.


def create_graph_from_chunks(chunks: List[Dict]):
    # Warning: this wipes the entire database before loading the new chunks
    graph.query("MATCH (n) DETACH DELETE n")
    create_chunk_query = """
    MERGE (chunk:Chunk {chunk_id: $chunk_id})
    ON CREATE SET
        chunk.source = $source,
        chunk.page_number = $page_number,
        chunk.total_length = $total_length,
        chunk.text_preview = $text_preview,
        chunk.full_text = $full_text
    WITH chunk
    UNWIND $keywords AS keyword
    MERGE (kw:Keyword {name: keyword})
    MERGE (chunk)-[:HAS_KEYWORD]->(kw)
    RETURN chunk
    """
    for chunk in chunks:
        graph.query(
            create_chunk_query,
            params={
                "chunk_id": chunk["metadata"]["chunk_id"],
                "source": chunk["metadata"]["source"],
                "page_number": chunk["metadata"]["page_number"],
                "total_length": chunk["metadata"]["total_length"],
                "text_preview": chunk["metadata"]["text_preview"],
                "full_text": chunk["text"],
                "keywords": chunk["metadata"]["keywords"],
            },
        )
create_graph_from_chunks(chunks[:200])  # load only the first 200 chunks for this demo
# After storing the data, create a unique constraint to ensure data integrity
graph.query(
    """
CREATE CONSTRAINT unique_chunk IF NOT EXISTS 
    FOR (c:Chunk) REQUIRE c.chunk_id IS UNIQUE
"""
)
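# Optional sanity check: confirm how many Chunk and Keyword nodes were stored
# (labels as created above)
print(graph.query("MATCH (c:Chunk) RETURN count(c) AS chunks"))
print(graph.query("MATCH (k:Keyword) RETURN count(k) AS keywords"))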
embedding_dim = 3072  # must match the dimensionality of the embedding model (3072 for llama3.2 here)


Step 5: Create and Verify the Vector Index

Set up a vector index over the embeddings for similarity search.


def generate_embedding(text: str) -> List[float]:
    # Embed the text, L2-normalize the vector, and pad/truncate it to embedding_dim
    try:
        embedding = embeddings.embed_query(text)
        embedding = [float(x) for x in embedding]
        magnitude = sum(x * x for x in embedding) ** 0.5
        if magnitude > 0:
            embedding = [x / magnitude for x in embedding]
        if len(embedding) != embedding_dim:
            if len(embedding) < embedding_dim:
                embedding.extend([0.0] * (embedding_dim - len(embedding)))
            else:
                embedding = embedding[:embedding_dim]
        return embedding
    except Exception as e:
        print(f"Error generating embedding: {e}")
        return [0.0] * embedding_dim

# we create the vector index using the above function for generating embeddings
def create_vector_index(chunks: List[Dict]):
    try:
        graph.query(
            """
            DROP INDEX chunk_vector_index IF EXISTS 
        """
        )
        graph.query(
            """
            CALL db.index.vector.createNodeIndex(
                'chunk_vector_index',
                'Chunk',
                'embedding',
                $dim,
                'cosine'
            )
            """,
            params={"dim": embedding_dim},
        )
        batch_size = 10
        total_processed = 0
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i : i + batch_size]
            batch_embeddings = []
            for chunk in batch:
                embedding = generate_embedding(chunk["text"])
                batch_embeddings.append(
                    {"chunk_id": chunk["metadata"]["chunk_id"], "embedding": embedding}
                )
            batch_update_query = """
            UNWIND $batch AS item
            MATCH (chunk:Chunk {chunk_id: item.chunk_id})
            SET chunk.embedding = item.embedding
            """
            graph.query(batch_update_query, params={"batch": batch_embeddings})
            total_processed += len(batch)
            print(f"Processed {total_processed}/{len(chunks)} chunks")
    except Exception as e:
        print(f"Error creating vector index: {e}")
        raise

try:
    create_vector_index(chunks[:200])
except Exception as e:
    print(f"Failed to create vector index: {e}")


Step 6: Perform Vector Search

Query the vector index to find chunks similar to a given query. We first verify that the index exists, then score the stored embeddings against the query with cosine similarity.


def verify_vector_index():
    query = """
    SHOW INDEXES
    YIELD name, type, labelsOrTypes, properties, options
    WHERE name = 'chunk_vector_index'
    """
    return graph.query(query)

def vector_search(query: str, top_k: int = 3) -> List[Dict]:
    # Embed the query, then score every Chunk node by cosine similarity
    try:
        query_embedding = embeddings.embed_query(query)
        search_query = """
        MATCH (c:Chunk)
        WITH c, vector.similarity.cosine(c.embedding, $embedding) AS score
        WHERE score > 0.7
        RETURN 
            c.chunk_id AS chunk_id,
            c.source AS source,
            c.page_number AS page_number,
            c.text_preview AS text_preview,
            c.full_text AS full_text,
            c.total_length AS total_length,
            score
        ORDER BY score DESC
        LIMIT $limit
        """
        results = graph.query(
            search_query, params={"embedding": query_embedding, "limit": top_k}
        )
        return results
    except Exception as e:
        print(f"Vector search error: {e}")
        return []

print(verify_vector_index())
print(vector_search("What is a Markov Decision Process?"))
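
The raw output is a list of dictionaries, which is hard to scan. A small helper for printing the hits more legibly (field names follow the RETURN clause of vector_search above):


def print_search_results(results: List[Dict]):
    for r in results:
        print(f"[score={r['score']:.3f}] chunk {r['chunk_id']} (page {r['page_number']})")
        print(f"  {r['text_preview']}")

print_search_results(vector_search("What is a Markov Decision Process?"))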


Step 7: Build the RAG Chain

Use LangChain to build a retrieval pipeline that feeds the retrieved chunks into a language model (LLM) to answer questions.


from langchain.prompts import PromptTemplate
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Neo4jVector

neo4j_vector_store = Neo4jVector.from_existing_graph(
    embedding=embeddings,  
    url=url,
    username=username,
    password=password,
    index_name='chunk_vector_index',  
    node_label='Chunk',  
    text_node_properties=['full_text'], 
    embedding_node_property='embedding'
)
retriever = neo4j_vector_store.as_retriever()
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Initialize the Ollama model
llm = OllamaLLM(model="llama3.2")
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
Always say "thanks for asking!" at the end of the answer.
{context}
Question: {question}
Helpful Answer:"""
custom_rag_prompt = PromptTemplate.from_template(template)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)
print(rag_chain.invoke("what is a markov decision process?"))
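
If you also want to inspect which chunks were retrieved for a question (useful for debugging retrieval separately from the LLM), a minimal sketch:


question = "what is a markov decision process?"
docs = retriever.invoke(question)  # on older LangChain versions, use retriever.get_relevant_documents(question)
for i, doc in enumerate(docs):
    print(f"--- retrieved chunk {i} ---")
    print(doc.page_content[:200])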


This article has covered the full pipeline, from data ingestion to answering queries with a RAG system. Go experiment with your own dataset!

Source: https://medium.com/@la_boukouffallah/how-to-create-a-graph-based-retrieval-augmented-generation-rag-with-custom-embeddings-like-5dd84ae095e9