使用FastAPI和Redis缓存加速机器学习模型服务

2025年04月24日 由 alex 发表 2400 0

Redis 是一种开源的内存数据结构存储,是在机器学习应用程序中进行缓存的绝佳选择。它的速度、持久性和对各种数据结构的支持使其成为处理实时推理任务的高吞吐量需求的理想选择。


在本教程中,我们将探讨 Redis 缓存在 Machnie 学习工作流程中的重要性。我们将演示如何使用 FastAPI 和 Redis 构建强大的机器学习应用程序。本教程将介绍如何在 Windows 上安装 Redis、在本地运行它,以及将其集成到机器学习项目中。最后,我们将通过发送重复和唯一请求来测试应用程序,以验证 Redis 缓存系统是否正常运行。


为什么在机器学习中使用 Redis 缓存?

在当今快节奏的数字环境中,用户希望从机器学习应用程序中获得即时结果。例如,考虑一个电子商务平台,该平台使用推荐模型向用户推荐产品。通过实施 Redis 来缓存重复请求,该平台可以显著缩短响应时间。


当用户请求产品推荐时,系统首先检查请求是否已缓存。如果有,则缓存的响应将在微秒内返回,从而提供无缝体验。否则,模型将处理请求,生成建议,并将结果存储在 Redis 中以供将来请求使用。这种方法不仅可以提高用户满意度,还可以优化服务器资源,使模型能够有效地处理更多请求。


使用 Redis 构建网络钓鱼电子邮件分类应用程序

在这个项目中,我们将构建一个网络钓鱼电子邮件分类应用程序。该过程包括从 Kaggle 加载和处理数据集、在处理后的数据上训练机器学习模型、评估其性能、保存训练后的模型,最后构建具有 Redis 集成的 FastAPI 应用程序。


1. 设置

  • 从 Kaggle 下载网络钓鱼电子邮件检测数据集,并将其放入 'data/ ' 目录中。
  • 首先,你需要安装 Redis。在终端中运行以下命令以安装 Redis Python 客户端:


pip install redis

 

  • 如果你使用的是 Windows 且未安装适用于 Linux 的 Windows 子系统 (WSL),请按照 Microsoft 的指南启用 WSL 并从 Microsoft Store 安装 Linux 发行版(例如 Ubuntu)。
  • 设置 WSL 后,打开 WSL 终端并执行以下命令以安装 Redis:


sudo apt update
sudo apt install redis-server

 


  • 要启动 Redis 服务器,请运行:


sudo service redis-server start

 


你应该会看到一条确认消息,指示“redis-server”已成功启动。


2. 模型训练

训练脚本加载数据集、处理数据、训练模型并将其保存在本地。


import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary
    # Assume dataset has columns "text" and "label"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]
    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    # Create a pipeline with TF-IDF and Logistic Regression
    pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(stop_words="english")),
            ("clf", LogisticRegression(solver="liblinear")),
        ]
    )
    # Train the model
    pipeline.fit(X_train, y_train)
    # Save the trained model to a file
    joblib.dump(pipeline, "phishing_model.pkl")
    print("Model trained and saved as phishing_model.pkl")
if __name__ == "__main__":
    main()

 

python train.py


Model trained and saved as phishing_model.pkl

 

3. 模型评估

评估脚本加载数据集和保存的模型文件以执行模型评估。


import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import joblib
def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary
    # Assume dataset has columns "text" and "label"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]
    # Split the dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    # Load the trained model
    model = joblib.load("phishing_model.pkl")
    # Make predictions on the test set
    y_pred = model.predict(X_test)
    # Evaluate the model
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))
if __name__ == "__main__":
    main()

 

结果近乎完美,F1 分数也非常出色。


python validate.py

 

Accuracy:  0.9723860589812332
Classification Report:
                precision    recall  f1-score   support
Phishing Email       0.96      0.97      0.96      1457
    Safe Email       0.98      0.97      0.98      2273
      accuracy                           0.97      3730
     macro avg       0.97      0.97      0.97      3730
  weighted avg       0.97      0.97      0.97      3730

 

4. 使用 Redis 进行模型服务

为了提供模型,我们将使用 FastAPI 创建一个 REST API 并集成 Redis 来缓存预测。


import asyncio
import json
import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis
# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")
app = FastAPI()
# Define the request and response data models
class PredictionRequest(BaseModel):
    text: str
class PredictionResponse(BaseModel):
    prediction: str
    probability: float
@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
    # Use the email text as a cache key
    cache_key = f"prediction:{data.text}"
    cached = await redis_client.get(cache_key)
    if cached:
        return json.loads(cached)
    # Run model inference in a thread to avoid blocking the event loop
    pred = await asyncio.to_thread(model.predict, [data.text])
    prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())
    result = {"prediction": str(pred[0]), "probability": float(prob)}
    # Cache the result for 1 hour (3600 seconds)
    await redis_client.setex(cache_key, 3600, json.dumps(result))
    return result
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
 


python serve.py


INFO:     Started server process [17640]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


你可以通过转到 URL http://localhost:8000/docs 来检查 REST API 文档


awan_accelerate_ml_model_serving_fastapi_redis_caching_2


该项目的源码、配置文件、模型和数据集可在 kingabzpro/Redis-ml-project GitHub 仓库中找到。


Redis 缓存在机器学习应用程序中的工作原理

以下是 Redis 缓存如何在我们的机器学习应用程序中运行的分步说明,以及用于说明该过程的图表:


awan_accelerate_ml_model_serving_fastapi_redis_caching_1


1. 客户端提交输入数据以请求机器学习模型进行预测。

2. 根据输入数据生成唯一标识符,以检查预测是否已经存在。

3. 系统使用生成的键查询 Redis 缓存,以搜索以前存储的预测。

  • 如果找到缓存的预测,则会在 JSON 响应中检索并返回该预测。
  • 如果未找到缓存的预测,则输入数据将传递给机器学习模型以生成新的预测。

4. 新生成的预测存储在 Redis 缓存中以备将来使用。

5. 最终结果以 JSON 格式返回给客户端。


测试网络钓鱼电子邮件分类应用程序

构建完我们的网络钓鱼电子邮件分类应用程序后,是时候测试其功能了。在本节中,我们将通过使用 'cURL' 命令发送多封电子邮件文本并分析响应来评估应用程序。此外,我们将验证 Redis 数据库,以确保缓存系统按预期工作。


 使用 CURL 命令测试 API

为了测试 API,我们将向 '/predict' 端点发送 5 个请求。其中,三个请求将包含唯一的电子邮件文本,而另外两个请求将是之前发送的电子邮件的副本。这将允许我们验证预测准确性和缓存机制。


echo "\n===== Testing API Endpoint with 5 Requests =====\n"
# First unique email
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Second unique email
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
# First duplicate (same as first email)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Third unique email
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'
# Second duplicate (same as second email)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
  'http://localhost:8000/predict' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"

 

当你运行上述脚本时,API 应返回每封电子邮件的预测。对于重复请求,应从 Redis 缓存中检索响应,以确保更快的响应时间。


sh test.sh

 

\n===== Testing API Endpoint with 5 Requests =====\n
\n----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}\n\n----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n===== Test Complete =====\n
Now run 'python check_redis.py' to verify the Redis cache entries
 


验证 Redis 缓存

为了确认缓存系统是否正常工作,我们将使用 Python 脚本 'check_redis.py' 来检查 Redis 数据库。此脚本检索缓存的预测并以表格格式显示它们。


import redis
import json
from tabulate import tabulate
def main():
    # Connect to Redis (ensure Redis is running on localhost:6379)
    redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
    # Retrieve all keys that start with "prediction:"
    keys = redis_client.keys("prediction:*")
    total_entries = len(keys)
    print(f"Total number of cached prediction entries: {total_entries}\n")
    table_data = []
    # Process only the first 5 entries
    for key in keys[:5]:
        # Remove the 'prediction:' prefix to get the original email text
        email_text = key.replace("prediction:", "", 1)
        # Retrieve the cached value
        value = redis_client.get(key)
        try:
            data = json.loads(value)
        except json.JSONDecodeError:
            data = {}
        prediction = data.get("prediction", "N/A")
        # Display only the first 7 words of the email text
        words = email_text.split()
        truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")
        table_data.append([truncated_text, prediction])
    # Print table using tabulate (only two columns now)
    headers = ["Email Text (First 7 Words)", "Prediction"]
    print(tabulate(table_data, headers=headers, tablefmt="pretty"))
if __name__ == "__main__":
    main()
 


当你运行脚本时,它将以表格格式显示缓存条目的数量和缓存的预测。check_redis.py


python check_redis.py

 

Total number of cached prediction entries: 3
+--------------------------------------------------+----------------+
|            Email Text (First 7 Words)            |   Prediction   |                            
+--------------------------------------------------+----------------+
|  congratulations you have won a free iphone,...  | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
|      todays floor meeting you may get a...       |   Safe Email   |
+--------------------------------------------------+----------------+


总结

通过测试具有多个请求的网络钓鱼电子邮件分类应用程序,我们成功地证明了该 API 可以准确识别网络钓鱼电子邮件,同时使用 Redis 高效缓存重复请求。这种缓存机制通过减少重复输入的冗余计算来显著提高性能,这在 API 处理大量流量的实际应用程序中尤其有用。


尽管这是一个相对简单的机器学习模型,但在处理更大、更复杂的模型(例如图像识别)时,缓存的好处变得更加明显。例如,如果你正在部署大规模图像分类模型,则缓存频繁处理的输入的预测可以节省大量计算资源并显著缩短响应时间。

文章来源:https://www.kdnuggets.com/accelerate-machine-learning-model-serving-with-fastapi-and-redis-caching
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消