在过去十年中,AI 研究和工具的惊人进步催生了更精确、更可靠的机器学习模型,以及将AI功能集成到现有应用变得更加简单的库和框架。
然而,在高要求的生产环境中,大规模推理仍然是相当大的挑战。例如,假设我们有一个简单的搜索服务,它接收一些关键词,然后使用语言模型将它们嵌入到一个向量中,并在某个向量数据库中搜索类似的文本或文档。这是一个非常流行的用例,也是 RAG 架构的核心部分,通常用于将生成式 AI 应用于特定领域的知识和数据。
就其本身而言,这似乎是一个相对简单的用例。有许多开源的语言模型和模型中心,我们可以在几行代码中使用它们作为嵌入模型。如果我们进一步假设我们需要存储和查询的向量数量相对适中(例如,少于 1M),那么有很多简单的向量存储和搜索选项:从纯内存存储到像 Postgres、Redis 或 Elastic 这样的数据库。
但是,如果我们的服务需要每秒处理数千甚至数十万个请求怎么办?如果我们需要在每个请求上维持相对较低(低于一秒)的延迟怎么办?如果我们需要快速扩容以应对请求高峰怎么办?
虽然我们的用例确实非常简单,但规模和负载要求无疑使它成为一个挑战。可扩展的高吞吐量系统通常基于多个小型高性能二进制文件实例,这些二进制文件可以快速引导、扩展和服务请求。这在机器学习系统的上下文中,特别是在深度学习中带来了挑战,因为通用库通常笨重而笨拙,部分原因是大多数库都是用 Python 实现的,在苛刻的环境中不容易扩展。
因此,面对这些挑战,我们要么选择使用一些付费的服务平台来处理规模问题,要么我们将不得不使用多种技术创建我们自己的专用服务层。
为了应对这些挑战,Hugging Face 引入了 Candle 框架,被描述为“一个以性能和易用性为中心的极简主义机器学习框架”。Candle 使我们能够使用一个类似 torch 的 API 在 Rust 中构建健壮且轻量级的模型推理服务。基于 Candle 的推理服务将容易扩展,快速引导,并且以极快的速度处理请求,这使它更适合应对规模和韧性挑战的云原生无服务器环境。
本帖的目的是以之前描述的流行用例为例,展示如何使用 Candle 框架实现端到端的解决方案。我们将深入探讨一个基于 Candle 和 Axum(作为我们的 web 框架)的相对简单但健壮的向量嵌入和搜索 REST 服务的实现。我们将使用一个特定的新闻标题数据集,但代码可以非常容易地扩展到任何文本数据集。
这将是一个非常有实践性的帖子。第2节介绍了我们要实现的推理服务的主要设计或流程,以及我们将开发和使用的涉及组件。第3节关注 Candle 框架,并展示如何使用 Bert 模型实现向量嵌入和搜索功能。第4节展示如何使用 Axum 将模型推理功能包装成 REST web 服务。第5节解释如何创建我们的服务将需要的实际嵌入和工件。第6节总结。
从准备实现的推理服务的主要构建块开始。主要要求是创建一个 HTTP REST 端点,该端点将接收包含一些关键字的文本查询,并响应与搜索查询最相似的前 5 条新闻标题。
对于这项任务,我们将使用 Bert 作为语言模型,因为它通常在文档嵌入任务上表现良好。在离线批处理过程中(将在第5节中解释),我们将使用 Bert 嵌入大约 20K 新闻标题或文档,并为每个标题创建一个向量嵌入,该服务将使用这些嵌入来搜索匹配项。无论是嵌入还是文本标题都将序列化到二进制格式,每个服务实例将加载这些格式以便服务查询请求。
如上所述,接收到带有搜索查询的请求后,服务将首先使用语言模型将搜索查询嵌入为一个向量。接下来,它将搜索预先加载的嵌入,以找到 N 个最相似的向量(每个向量代表一个新闻标题)。最后,它将使用最相似向量的索引来获取它所代表的实际文本标题,使用映射文件。
我们将通过创建一个名为 BertInferenceModel 的模块或 rust 库来实现这一点,它将提供主要功能:模型加载、句子嵌入和向量搜索。该模块将被 Axum 网络框架实现的 REST 服务使用。下一节将着重实现模块,而后续部分将专门讨论网络服务本身。
请注意,接下来的部分包含许多代码示例,为了清晰起见,它们仅显示实现的主要功能。
本节将重点放在实现模块上,该模块将作为 Candle 库 API 的抽象层。我们将以名为 BertInferenceModel 的结构实现此功能,该结构包含 3 个主要功能:模型加载、推理(或句子嵌入)和使用余弦相似度进行简单的向量搜索。
pub struct BertInferenceModel {
model: BertModel,
tokenizer: Tokenizer,
device: Device,
embeddings: Tensor,
}
BertInferenceModel 将封装我们从 Hugging Face 存储库下载的 Bert 模型和分词器,基本上封装了它们的功能。
从 Hugging Face Hub 加载模型
BertInferenceModel 将通过一个名为 load() 的函数实例化,该函数将返回一个新的结构实例,该实例已加载相关模型和分词器,并准备好进行推理任务。load 函数的参数包括我们希望加载的模型的名称和修订版本(我们将使用 Bert 句子转换器)和嵌入文件路径。
pub fn load(
model_name: &str,
revision: &str,
embeddings_filename: &str,
embeddings_key: &str,
) -> anyhow::Result<Self> {}
如下面的 load 函数代码所示,加载模型涉及创建一个包含 Hugging Face 相关存储库属性(例如,名称和修订版)的 Repo 结构,然后创建一个 API 结构实际连接到存储库并下载相关文件(模型权重由 HuggingFace 的 safetensors 格式表示)。
api.get 函数返回相关文件的本地名称(无论是下载还是仅从缓存中读取)。文件将仅下载一次,而后续对 api.get 的调用将简单地使用缓存版本。我们使用分词器配置文件实例化一个分词器结构,并使用权重文件(以 safetensors 格式)和配置文件构建我们的模型。
在我们加载了模型和分词器之后,我们终于可以加载实际的嵌入文件,我们将使用这个文件来搜索匹配项。稍后我将展示如何使用相同的模型生成它,然后将其序列化到文件中。使用 HuggingFace 的 safetensors 模块将嵌入作为 Tensor 加载非常简单,我们需要的只是文件名和保存 Tensor 时给定的键。
现在我们有了一个加载的模型和分词器,以及内存中的嵌入向量,我们完成了返回给调用函数的 BertInferenceModel 的初始化,并可以继续实现推理方法。
句子推理和嵌入
推理功能也相当简单。我们首先使用加载的分词器对句子进行编码(第 5 行)。encode() 函数返回一个 Encoding 结构,其 get_ids() 函数返回我们想要嵌入的句子中单词的数字表示形式的数组。接下来,我们将令牌 id 数组包装在一个张量中,我们可以将张量提供给我们的嵌入模型,并使用模型的 forward 函数来获取代表该句子的嵌入向量(第 10 行)。
我们在第 12 行从嵌入模型获得的向量的维度是 [128, 384]。之所以如此,是因为 Bert 使用大小为 384 的向量来表示每个标记或单词,而句子向量的最大输入长度为 128(因为我们的输入只有几个单词,所以大部分是填充)。换句话说,我们实际上得到了一个代表我们句子中每个单词的大小为 384 的向量,除了填充和其他指令标记。
接下来,我们需要将句子向量从大小为[128, 384]的张量压缩为一个大小为[1, 384]的单一向量,该向量将代表或捕捉句子的"本质",这样我们就可以在嵌入中与其他句子进行匹配,并找到与之相似的句子。为此,部分原因是我们处理的输入是短关键字而不是长句子,我们将使用最大池化(max pooling),它通过在给定张量的每个维度上取最大值来创建一个新向量,以便捕获每个维度中最突出的特征。如下所示,使用Candle的API实现这一点相当简单。最后,我们使用L2规范化以避免偏差,并通过确保所有向量具有相同的大小,来改善余弦相似度度量。以下是池化和规范化函数的实际实现情况。
测量向量相似度
虽然这不是直接与Candle库相关,但我们的模块还将提供一个向量搜索实用方法,它将接收一个向量,并使用其内部嵌入来返回最相似向量的索引。
这是非常简单地实现的:我们首先创建一个元组集合(第7行),其中元组的第一个成员将代表相关文本的索引,第二个成员将代表余弦相似度分数。然后,我们遍历所有索引,并测量每个索引与我们需要匹配的给定向量之间的余弦相似度。最后,我们将具有(索引,相似度分数)的元组添加到集合中,对它进行排序,并返回请求的前N个。
现在我们已经有了一个封装了主要模型功能的结构,我们需要将它封装在一个REST服务中。我们将创建一个REST端点,它有一个POST路由,该路由将接收JSON有效载荷中的几个关键字,并返回预加载嵌入中最相似向量的索引。在请求时,服务将关键字嵌入到向量中,搜索其内存嵌入中的相似性,并返回最相似向量的索引。服务将使用索引查找文本映射文件中的相应标题。
我们将使用出色的Axum Web框架实现服务。大部分相关代码是典型的Axum样板代码,所以我不会过多讲解如何使用Axum创建REST端点的细节。像许多Web框架一样,构建REST端点通常涉及创建一个Router,并在某些路由上注册一个处理函数来处理请求。然而,ML模型服务层有额外的复杂性,需要管理模型本身的状态和持久性。模型加载在性能方面可能非常昂贵,因为它涉及加载模型文件的IO(无论它是从Hugging Face的仓库还是本地)。同样,我们需要找到一种方法来缓存和重用模型来处理多个请求。
为了满足这些要求,Axum提供了一个应用程序State功能,我们可以使用它来初始化和持久化我们想要注入到每个请求上下文中的任何资产。让我们先逐行查看服务初始化的整个代码,看看这是如何工作的。
每个服务实例都将创建并加载模型封装,然后将其缓存以供其接收的每个请求重复使用。在第3行,我们通过调用load()函数来创建模型封装,来引导并加载模型。除了我们从HF加载的Bert模型的名称和版本,我们还需要指定嵌入文件的位置,我们将其加载到内存中,以便搜索相似向量,以及我们在创建嵌入时使用的关键字。
除了实际的模型,我们还需要为每个请求重新缓存映射文件。在服务器使用模型嵌入关键字后,它搜索其嵌入文件中最相似的向量,然后返回它们的索引。然后,服务器使用映射文件来提取与最相似向量索引对应的实际文本。在一个更健壮的生产系统中,在从模型封装中接收到最相似向量的索引后,服务将从一些快速访问数据库中获取实际文本,虽然在我们这个例子中,预加载在文件中存储的字符串列表就足够了。在第10行,我们加载了之前保存为二进制文件的列表。
现在我们有了两个需要缓存和重用的资产——模型(封装)和映射文件。Axum让我们可以通过使用Arc(一个线程安全的引用计数指针)来实现这一点,每个请求都将共享这个指针。正如你在第15行看到的,我们围绕由模型封装和映射文件组成的元组创建了一个新的Arc。在第17-19行,我们创建了一个新的HTTP路由到将处理每个请求的函数。
为了缓存元组并使其对每个请求可用,我们使用with_state(state)函数将其注入到相关请求上下文中。让我们看看具体是怎么做的。
处理请求
我们的服务将处理携带以下有效载荷的HTTP POST请求,有效载荷将包含关键字和我们想要接收的相似向量或标题的数量。
{
"text": "europe climate change storm",
"num_results":5
}
我们将实现处理函数的对应请求和响应结构体,Axum将在需要时处理序列化和反序列化。
#[derive(Deserialize)]
struct ReqPayload {
keywords: String,
num_results: u32,
}
#[derive(Serialize)]
struct ResPayload {
text: Vec<String>,
}
接下来,我们可以继续处理函数本身。处理程序将接受两个参数:我们之前初始化的应用程序状态(Axum将负责将其注入到每个函数调用中),以及我们之前定义的请求结构体。
处理每个请求将包括4个主要阶段,到现在应该很直接。在第5行,我们首先提取一个引用,指向包含模型和映射文件引用的状态元组。在第6行,我们使用模型将关键字嵌入到一个向量中。在第9行,我们搜索N个最相似的向量。score_vector_similarity()函数返回一个元组向量,每个元组包含一个索引和余弦相似度分数。最后,我们遍历索引元组,从映射文件中提取与索引对应的字符串,并将其包装在响应有效载荷结构体中。
尽管这不一定说明什么,但我在我的Mac上测试了大约20K向量的嵌入,并得到了平均响应时间为100ms的不错结果。对于基于Bert的向量嵌入+向量搜索来说还不错。
curl -s -w "\\nTotal time: %{time_total}s\\n" \
-X POST http://localhost:3000/similar \
-H "Content-Type: application/json" \
-d '{"text": "self driving cars navigation", "num_results": 3}' | jq
{
"text": [
"Item:Stereo Acoustic Perception ... (index: 8441 score:0.8516491)",
"Item:Vision-based Navigation of ... (index: 7253 score:0.85097575)",
"Item:Learning On-Road Visual ..... (index: 30670 score:0.8500275)"
]
}
Total time: 0.091665s
(这个例子是使用在Arxiv论文摘要数据集上生成的嵌入创建的。实际数据集在这里可以在公共领域许可下获得。)
在我们结束之前,流程中还有一个最后的组件需要涵盖。到目前为止,我们已经假设了一个嵌入文件的存在,我们在其中搜索相似向量。然而,我尚未解释如何自己创建嵌入文件。
回想一下,在上一节创建的结构——BertInferenceModel中,已经包含了一个函数,可以将一组关键字嵌入到一个向量中。当我们需要嵌入多组关键字时,我们需要做的就是批量处理它们。
使用BertInferenceModel的主要不同之处在于使用分词器的encode_batch函数而不是encode,它接收一个字符串向量而不是一个字符串。然后我们简单地将所有向量堆叠成一个单一张量,然后像我们用单一向量嵌入一样将其输入到模型的forward()函数(您可以在下面链接的配套仓库中看到该函数的完整源代码)。
一旦我们有了这样的函数可以嵌入多个字符串,那么嵌入生成器本身就很简单了。它使用rayon crate以并行方式嵌入文本文件,然后将结果堆叠在一起以创建一个单一张量。最后,它使用safetensors格式将嵌入写入磁盘。嵌入是该管道中的一个重要资产,因为它需要复制到每个服务实例。
现在我们可以得出结论:-)
机器学习工程中最大的挑战之一是规模化地运行推理。AI绝不是轻量级的,因此规模化推理负荷往往变得非常昂贵或过度设计。这正是Hugging Face的Candle库试图解决的挑战。通过在Rust中使用类似Torch的API,它使我们能够创建一个轻便且快速的模型服务层,可以轻松扩展并在无服务器环境中运行。
这篇文章介绍了如何使用Candle创建一个端到端的模型推理层,该层可以为向量嵌入和搜索提供请求服务。解释了如何将Bert/语句转换器模型封装在一个内存占用小的库中,并将其用在基于Axum的REST服务中。
Hugging Face的Candle库的真正价值在于它能够弥合强大的机器学习功能和有效资源利用之间的差距。通过利用Rust的性能和安全特性,Candle为更可持续、成本效益更高的机器学习解决方案铺平了道路。这对于寻求在没有开销的情况下规模化部署AI的组织特别有利。我希望借助Candle,我们将看到一系列新的机器学习应用,这些应用不仅性能高,而且更轻便,更能适应各种环境。