【指南】从头开始掌握缓存增强生成 (CAG)

2025年01月15日 由 alex 发表 337 0

逐步教程:在Python中实现缓存增强的生成(Cache-Augmented Generation)。在本文中,我们将简要探讨缓存增强的生成(CAG)背后的理论,以理解其核心概念。


基本概念

简而言之,RAG的策略涉及将外部知识编码为向量并存储在向量数据库中。在查询大型语言模型(LLM)之前,输入查询也被编码为向量,并检索与查询向量相似度最高的知识向量。然后,将这些检索到的信息添加到给LLM的提示中,以生成响应。这种方法功能强大,并且在理论上可以扩展到非常大的知识源。然而,它在文档选择方面引入了潜在错误,这取决于文档如何被分块以及用于创建向量数据库的嵌入模型的质量。


CAG提供了一种更简单的方法。如果你的外部知识库规模可控,CAG涉及直接将整个知识库与查询一起包含在提示中。然后,LLM可以处理查询和知识库以生成响应。这种策略消除了对向量数据库和相似度计算的需求。CAG受益于LLM的最新进展,如Llama、Mixtral和Gemma等模型,这些模型在更大的上下文窗口下表现出更高的性能和效率。


然而,如果CAG的朴素实现是在每个提示中都包含整个知识库,这将导致非常慢的推理时间。这是因为LLM通常一次生成一个标记,并且每个预测都依赖于整个前面的上下文。这就是CAG的关键创新之处:通过将知识库预加载到模型的上下文中,并使用动态缓存策略(特别是键值缓存),我们可以避免为每个新查询重复处理知识库。模型有效地“记住”了处理过的知识,使其能够在推理期间仅关注查询。


以下是基本概念概述:


10


代码教程:实现CAG

本节将深入探讨CAG概念的实际实现。我们的代码将基于hhhuang的工作,特别是他们GitHub仓库中的kvache.py脚本。核心思想来自原始的CAG研究论文。


代码将使用与研究论文相同的LLM模型:“Llama-3.1B-Instruct”,并且已在Kaggle笔记本环境中成功测试。这确保了代码功能可以轻松适应你自己的项目。


我们将从设置环境开始,然后深入探讨kvache.py脚本的细节。该脚本可能侧重于创建和利用键值缓存,以在所选的LLM中实现CAG功能。


在深入代码本身之前,让我们确保已安装必要的库:


#!pip install -U bitsandbytes 
import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForCausalLM)
import bitsandbytes as bnb
from transformers.cache_utils import DynamicCache


使用的版本


transformers  : 4.44.2
bitsandbytes  : 0.45.0
torch         : 2.4.1+cu121
### GPU 
Kaggle GPU T4x2
CUDA Version   : 12.6
Driver Version : 560.35.03 


登录Hugging Face

需要使用LLama-3.1模型

  1. 创建账户:访问https://huggingface.co/并注册一个免费账户。
  2. 生成访问令牌:进入你的个人资料设置(右上角)-> 访问令牌 -> 创建一个新令牌。此令牌授予对Hugging Face功能的访问权限,如上传微调模型。


from huggingface_hub import notebook_login
notebook_login()


准备知识

为了这次演示,我们将为模型提供一些背景信息以供其使用。这些信息包括模拟的临床报告和与医疗器械相关的事件。需要强调的是,所有这些数据都是完全合成的,并非基于真实事件。


这些知识非常特定于医疗器械领域。如果没有首先提供这个上下文,一个标准的预训练LLM将无法回答有关这些报告的问题。换句话说,模型需要这些特定的知识来理解和回答有关报告的问题。


knowledge = """
Incident 1: Glucose Meter Malfunction Leads to Hyperglycemia
    Patient: John Miller, 62 years old
    Device: GlucoFast Ultra glucose meter, manufactured by MediTech Solutions Inc.
    Incident: Mr. Miller, a known diabetic, used his GlucoFast Ultra meter to check his blood glucose level before dinner. 
    The meter displayed a reading of 90 mg/dL, which was within his target range. 
    However, shortly after eating, he began experiencing symptoms of hyperglycemia, including excessive thirst, frequent urination, and blurred vision. 
    A subsequent check with a hospital-grade blood glucose analyzer revealed a blood glucose level of 250 mg/dL.
    Investigation: It was determined that the GlucoFast Ultra meter was providing falsely low readings, likely due to a faulty batch of test strips. MediTech Solutions Inc. 
    issued a recall for the affected lot of test strips.
    Outcome: Mr. Miller was treated for hyperglycemia and recovered fully.
Incident 2: Heart Pump Failure During Surgery
    Patient: Jane Doe, 58 years old
    Device: CardioAssist Ventricular Assist Device (VAD), manufactured by HeartLife Technologies.
    Incident: Ms. Doe was undergoing a heart transplant surgery. During the procedure, 
    the CardioAssist VAD, which was supporting her circulation, suddenly malfunctioned, causing a critical drop in blood pressure.
    Investigation: The investigation revealed a software glitch in the VAD's control system, causing it to unexpectedly shut down. 
    HeartLife Technologies issued a software update to address the issue.
    Outcome: The surgical team was able to stabilize Ms. Doe and complete the transplant successfully. 
    However, the incident caused a delay in the procedure and increased the risk of complications.
Incident 3: X-Ray Machine Overexposure
    Patient: Robert Smith, 45 years old
    Device: XR-5000 Digital X-Ray System, manufactured by Imaging Dynamics Corp.
    Incident: Mr. Smith was undergoing a routine chest X-ray. Due to a malfunction in the X-Ray system's calibration, 
    he received a significantly higher dose of radiation than intended.
    Investigation: The investigation revealed a faulty sensor in the X-ray machine's control panel, 
    which led to an incorrect radiation output. Imaging Dynamics Corp. issued a service bulletin to inspect and recalibrate all affected XR-5000 systems.
    Outcome: Mr. Smith was informed of the overexposure and monitored for any potential long-term effects of the increased radiation dose. 
    knowledge = The immediate risk was considered low, but long-term risks could not be fully excluded.
"""


预加载知识

现在,我们将创建一个简单的函数,将准备好的知识预加载到模型中。此过程使用Hugging Face的动态缓存机制(特别是键值缓存)来高效地存储处理过的知识。对于缓存增强生成(CAG)而言,这一预加载步骤至关重要,因为它允许模型“记住”知识,并在推理过程中避免冗余计算。


该函数基本上将准备好的知识文本作为输入,并通过模型处理一次。然后,将注意力层产生的键值状态存储在缓存中。随后的查询可以利用这些缓存的信息,从而显著加快生成过程。


本质上,该函数返回代表预处理知识的“键”和“值”,这些键值在生成阶段准备就绪可供使用。这就是模型如何高效地融入外部知识,而无需对每个新查询重新处理的方式。


def preprocess_knowledge(
    model,
    tokenizer,
    prompt: str) -> DynamicCache:
    """
    Prepare knowledge kv cache for CAG.
    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: The knowledge to preprocess, which is basically a prompt
    Returns:
        DynamicCache: KV Cache
    """
    embed_device = model.model.embed_tokens.weight.device # check which device are used 
    input_ids    = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False)
    return outputs.past_key_values


准备知识和创建键值缓存数据

在生成键值(KV)缓存数据之前,我们需要格式化提示并向模型提供指令。这个提示的结构,包括任何特定的指令或特殊标记,都至关重要,并且高度依赖于所选的模型。


不同的语言模型有不同的输入要求。一些模型使用独特的特殊标记(如<s>、[CLS]或<bos>)来表示序列的开始、分隔输入的不同部分或指示特定任务。因此,根据所使用的特定模型定制提示和指令是至关重要的。


在我们的案例中,我们将根据所使用模型(假定为Llama-3.1-Instruct)的要求来格式化提示和指令。这将确保模型正确处理知识并生成适当的键值缓存数据。


def prepare_kvcache(documents, answer_instruction: str = None):
    # Prepare the knowledges kvcache
    if answer_instruction is None:
        answer_instruction = "Answer the question with a super short answer."
    knowledges = f"""
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    You are an medical assistant for giving short answers 
    based on given reports.<|eot_id|>
    <|start_header_id|>user<|end_header_id|>
    Context information is bellow.
    ------------------------------------------------
    {documents}
    ------------------------------------------------
    {answer_instruction}
    Question:
    """
    # Get the knowledge cache
    kv = preprocess_knowledge(model, tokenizer, knowledges)
    kv_len = kv.key_cache[0].shape[-2]
    print("kvlen: ", kv_len)
    return kv, kv_len

knowledge_cache, kv_len  = prepare_kvcache(documents =knowledge)
# kvlen:  610


将知识预加载到键值(KV)缓存后,我们存储其长度。这一点至关重要,因为查询会扩展KV缓存。为了保持后续查询仅包含预加载知识的上下文一致性,我们在每个查询后将KV缓存截断回其原始长度。这确保了每个查询都在预期的知识库上操作,防止查询之间发生不希望的交互。


查询回答

在将知识预加载到大型语言模型(LLM)的键值(KV)缓存后,我们现在可以回答有关报告的问题了。至关重要的第一步是实现一个清理函数。如上所述,该函数将负责在每个查询后将KV缓存恢复到其原始状态(仅包含预加载的知识)。


def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV Cache to the original length.
    """
    for i in range(len(kv.key_cache)):
        kv.key_cache[i] = kv.key_cache[i][:, :, :origin_len, :]
        kv.value_cache[i] = kv.value_cache[i][:, :, :origin_len, :]


这个函数处理预测过程,其中包括利用预加载的知识(存储在KV缓存中)来回答查询:


def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 300
) -> torch.Tensor:
    """
    Generate text with greedy decoding.
    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids
        past_key_values: KV Cache for knowledge
        max_new_tokens: Maximum new tokens to generate
    """
    embed_device = model.model.embed_tokens.weight.device
    origin_ids = input_ids
    input_ids = input_ids.to(embed_device)
    output_ids = input_ids.clone()
    next_token = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=next_token, 
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(embed_device)
            past_key_values = outputs.past_key_values
            output_ids = torch.cat([output_ids, next_token], dim=1)
            if (next_token.item() in model.config.eos_token_id) and (_ > 0):
                break
    return output_ids[:, origin_ids.shape[-1]:]


开始预测过程

我们现在准备开始预测过程。这包括利用高效存储在键值(KV)缓存中的预加载知识来生成对用户查询的答案。


query = 'which Patient experienced issues with blood glucose meter, 
what was the problem ?'
clean_up(knowledge_cache, kv_len)
input_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)
output = generate(model, input_ids, knowledge_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, temperature=None)
print(f"Response of the model:\n {generated_text}")


Response of the model:
assistant
Mr. Miller experienced issues with the blood glucose meter. 
The problem was that it 
provided falsely low readings due to a faulty batch of test strips


结论

通过最终的代码片段,你现在可以向模型测试各种问题,它将基于预缓存的知识生成答案。本文提供了一个关于实现缓存增强生成(CAG)的基本且简化的概述。


本演示使用了一个包含有限数量示例的小型知识库。然而,如果你正在处理一个显著更大的数据集(例如,超过1,000个示例),则预加载模型和生成KV缓存可能会变得计算成本高昂。在这种情况下,强烈建议将生成的KV缓存数据存储到磁盘上。这样,你可以直接加载预计算的缓存,避免每次都需要重新生成,这对于大规模应用中的可扩展性至关重要。虽然对于这个小规模的演示来说这不是必需的,但对于CAG的实际、现实世界应用来说,这种优化是必不可少的。


def write_kv_cache(kv: DynamicCache, path: str):
    """
    Write the KV Cache to a file.
    """
    torch.save(kv, path)
def read_kv_cache(path: str) -> DynamicCache:
    """
    Read the KV Cache from a file.
    """
    kv = torch.load(path, weights_only=True)
    return kv

文章来源:https://medium.com/@sabaybiometzger/cache-augmented-generation-cag-from-scratch-441adf71c6a3
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消