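Redis is an open-source, in-memory data structure store and an excellent choice for caching in machine learning applications. Its speed, persistence options, and support for a variety of data structures make it well suited to the high-throughput demands of real-time inference.
In this tutorial, we will explore why Redis caching matters in machine learning workflows. We will demonstrate how to build a robust machine learning application with FastAPI and Redis. The tutorial covers installing Redis on Windows, running it locally, and integrating it into a machine learning project. Finally, we will test the application by sending both duplicate and unique requests to verify that the Redis caching system works correctly.
Why Use Redis Caching in Machine Learning?
In today's fast-paced digital landscape, users expect instant results from machine learning applications. Consider, for example, an e-commerce platform that uses a recommendation model to suggest products to users. By using Redis to cache repeated requests, the platform can significantly reduce response times.
When a user requests product recommendations, the system first checks whether the request is already cached. If it is, the cached response is returned within microseconds, providing a seamless experience. If not, the model processes the request, generates recommendations, and stores the result in Redis for future requests. This approach not only improves user satisfaction but also makes better use of server resources, allowing the model to handle more requests efficiently.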
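The check-then-compute flow described above is commonly called the cache-aside pattern. The following is a minimal sketch of it, assuming a plain Python dictionary in place of Redis so that it runs without a server. The names `recommend_products` and `get_recommendations` are hypothetical and are not part of the project.

```python
import json

# A plain dict stands in for Redis here so the sketch runs without a server.
cache = {}

# Counts how often the (hypothetical) model is actually invoked.
model_calls = 0


def recommend_products(user_id):
    """Hypothetical stand-in for an expensive recommendation model."""
    global model_calls
    model_calls += 1
    return [f"product-{user_id}-{i}" for i in range(3)]


def get_recommendations(user_id):
    """Return cached recommendations if present, otherwise compute and cache them."""
    key = f"recommendations:{user_id}"
    cached = cache.get(key)
    if cached is not None:
        # Cache hit: skip the model entirely.
        return json.loads(cached)
    # Cache miss: run the model, then store the serialized result.
    result = recommend_products(user_id)
    cache[key] = json.dumps(result)
    return result


first = get_recommendations(42)   # cache miss, the model runs
second = get_recommendations(42)  # cache hit, the model is skipped
print(first == second, model_calls)
```

Both calls return the same recommendations, but the model function runs only once. With a real Redis client the structure stays the same; only the dictionary reads and writes become network calls.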
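Building a Phishing Email Classification Application with Redis
In this project, we will build a phishing email classification application. The process involves loading and processing a dataset from Kaggle, training a machine learning model on the processed data, evaluating its performance, saving the trained model, and finally building a FastAPI application with Redis integration.
1. Setup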
pip install redis
sudo apt update
sudo apt install redis-server
sudo service redis-server start
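You should see a confirmation message indicating that "redis-server" has started successfully.
2. Model Training
The training script loads the dataset, processes the data, trains the model, and saves it locally.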
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline


def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The email body is in the "Email Text" column and the label in "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Create a pipeline with TF-IDF and Logistic Regression
    pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(stop_words="english")),
            ("clf", LogisticRegression(solver="liblinear")),
        ]
    )

    # Train the model
    pipeline.fit(X_train, y_train)

    # Save the trained model to a file
    joblib.dump(pipeline, "phishing_model.pkl")
    print("Model trained and saved as phishing_model.pkl")


if __name__ == "__main__":
    main()
python train.py
Model trained and saved as phishing_model.pkl
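3. Model Evaluation
The evaluation script loads the dataset and the saved model file, then evaluates the model on the held-out test split.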
import joblib
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split


def main():
    # Load dataset
    df = pd.read_csv("data/Phishing_Email.csv")  # adjust the path as necessary

    # The email body is in the "Email Text" column and the label in "Email Type"
    X = df["Email Text"].fillna("")
    y = df["Email Type"]

    # Split the dataset (same test_size and random_state as in training,
    # so the test set contains only emails the model has not seen)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Load the trained model
    model = joblib.load("phishing_model.pkl")

    # Make predictions on the test set
    y_pred = model.predict(X_test)

    # Evaluate the model
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("Classification Report:")
    print(classification_report(y_test, y_pred))


if __name__ == "__main__":
    main()
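The results are strong: accuracy is about 97%, and the F1 scores are 0.96 for phishing emails and 0.98 for safe emails.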
python validate.py
Accuracy: 0.9723860589812332
Classification Report:
                precision    recall  f1-score   support

Phishing Email       0.96      0.97      0.96      1457
    Safe Email       0.98      0.97      0.98      2273

      accuracy                           0.97      3730
     macro avg       0.97      0.97      0.97      3730
  weighted avg       0.97      0.97      0.97      3730
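To make the "macro avg" and "weighted avg" rows of the report concrete, the sketch below recomputes the two F1 averages from the per-class values. It starts from the rounded figures printed in the report, so it only approximates the exact calculation that scikit-learn performs on unrounded scores.

```python
# Per-class F1 scores and supports copied from the classification report above.
# These are rounded values, so the results are approximations.
report = {
    "Phishing Email": {"f1": 0.96, "support": 1457},
    "Safe Email": {"f1": 0.98, "support": 2273},
}

total_support = sum(row["support"] for row in report.values())

# Macro average: unweighted mean of the per-class scores.
macro_f1 = sum(row["f1"] for row in report.values()) / len(report)

# Weighted average: each class contributes in proportion to its support.
weighted_f1 = (
    sum(row["f1"] * row["support"] for row in report.values()) / total_support
)

print(f"total support: {total_support}")
print(f"macro avg F1: {macro_f1:.2f}")
print(f"weighted avg F1: {weighted_f1:.2f}")
```

Both averages come out at 0.97 after rounding, matching the report. They coincide here because the two classes score similarly; with a larger gap between classes, the weighted average would lean toward the class with more samples.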
4. Serving the Model with Redis
To serve the model, we will use FastAPI to create a REST API and integrate Redis to cache predictions.
import asyncio
import json

import joblib
from fastapi import FastAPI
from pydantic import BaseModel
import redis.asyncio as redis

# Create an asynchronous Redis client (make sure Redis is running on localhost:6379)
redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Load the trained model (synchronously)
model = joblib.load("phishing_model.pkl")

app = FastAPI()


# Define the request and response data models
class PredictionRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    prediction: str
    probability: float


@app.post("/predict", response_model=PredictionResponse)
async def predict_email(data: PredictionRequest):
    # Use the email text as a cache key
    cache_key = f"prediction:{data.text}"
    cached = await redis_client.get(cache_key)
    if cached:
        return json.loads(cached)

    # Run model inference in a thread to avoid blocking the event loop
    pred = await asyncio.to_thread(model.predict, [data.text])
    prob = await asyncio.to_thread(lambda: model.predict_proba([data.text])[0].max())

    result = {"prediction": str(pred[0]), "probability": float(prob)}

    # Cache the result for 1 hour (3600 seconds)
    await redis_client.setex(cache_key, 3600, json.dumps(result))
    return result


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
python serve.py
INFO: Started server process [17640]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
You can view the REST API documentation at http://localhost:8000/docs.
The source code, configuration files, model, and dataset for this project are available in the kingabzpro/Redis-ml-project GitHub repository.
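The serving code stores each prediction with a one-hour expiry through `setex`. As a rough illustration of what that expiry means, the sketch below implements a tiny in-memory cache with the same set-with-TTL and get behavior. The `TTLCache` class and the fake clock are hypothetical helpers for this example only; Redis handles expiry internally and does not work this way under the hood.

```python
import time


class TTLCache:
    """Minimal in-memory cache whose entries expire, loosely mimicking SETEX and GET."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._store = {}  # key -> (expiry timestamp, value)

    def setex(self, key, ttl_seconds, value):
        # Record the value together with the moment it stops being valid.
        self._store[key] = (self._clock() + ttl_seconds, value)

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if self._clock() >= expires_at:
            # Entry is stale: drop it and report a miss.
            del self._store[key]
            return None
        return value


# A fake clock lets us "advance time" without sleeping.
now = [0.0]
cache = TTLCache(clock=lambda: now[0])

cache.setex("prediction:hello", 3600, '{"prediction": "Safe Email"}')
hit = cache.get("prediction:hello")   # still fresh
now[0] = 3601.0                       # just over one hour later
miss = cache.get("prediction:hello")  # expired, so None
print(hit, miss)
```

In the application, an expired entry simply looks like a cache miss, so the next identical request runs the model again and refreshes the cache. A shorter TTL keeps predictions fresher after a model update, while a longer one saves more computation.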
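How Redis Caching Works in a Machine Learning Application
Here is a step-by-step explanation of how Redis caching operates in our machine learning application, along with a diagram illustrating the process:
1. The client submits input data to request a prediction from the machine learning model.
2. A unique identifier is generated from the input data to check whether a prediction already exists.
3. The system queries the Redis cache with the generated key to look for a previously stored prediction. If one is found, it is returned directly; otherwise, the input is passed to the model, which generates a new prediction.
4. The newly generated prediction is stored in the Redis cache for future use.
5. The final result is returned to the client in JSON format.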
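Step 2 can be implemented in several ways. The serving code in this tutorial uses the raw email text as the key; the sketch below shows one possible alternative that hashes the text into a fixed-length identifier. The `make_cache_key` function and its whitespace normalization are assumptions made for illustration and are not part of the project.

```python
import hashlib


def make_cache_key(text, prefix="prediction"):
    """Build a fixed-length cache key from arbitrary input text."""
    # Treat inputs that differ only in surrounding whitespace as the same request.
    normalized = text.strip()
    digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
    return f"{prefix}:{digest}"


key_a = make_cache_key("urgent action required")
key_b = make_cache_key("  urgent action required\n")
key_c = make_cache_key("congratulations you have won")
print(key_a == key_b, key_a == key_c, len(key_a))
```

Hashed keys keep key length constant no matter how long the email is. The trade-off is that the original text can no longer be read back from the key, so a script that inspects the cache would need the text stored alongside the prediction.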
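Testing the Phishing Email Classification Application
With the phishing email classification application built, it is time to test it. In this section, we will evaluate the application by sending several email texts with cURL commands and analyzing the responses. We will also inspect the Redis database to make sure the caching system works as expected.
Testing the API with cURL Commands
To test the API, we will send five requests to the /predict endpoint. Three of them contain unique email texts, and the other two are duplicates of emails sent earlier. This lets us verify both the predictions and the caching mechanism.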
echo "\n===== Testing API Endpoint with 5 Requests =====\n"
# First unique email
echo "\n----- Request 1 (First unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Second unique email
echo "\n\n----- Request 2 (Second unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
# First duplicate (same as first email)
echo "\n\n----- Request 3 (Duplicate of first email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "todays floor meeting you may get a few pointed questions about today article about lays potential severance of $ 80 mm"
}'
# Third unique email
echo "\n\n----- Request 4 (Third unique email) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "congratulations you have won a free iphone, click here to claim your prize now before it expires"
}'
# Second duplicate (same as second email)
echo "\n\n----- Request 5 (Duplicate of second email - should be cached) -----"
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"text": "urgent action required: your account has been compromised, click here to reset your password immediately"
}'
echo "\n\n===== Test Complete =====\n"
echo "Now run 'python check_redis.py' to verify the Redis cache entries"
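When you run the script above, the API should return a prediction for each email. For duplicate requests, the response should be retrieved from the Redis cache, which makes for faster response times. The literal \n sequences in the output appear because the echo command in this shell does not interpret backslash escapes.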
sh test.sh
\n===== Testing API Endpoint with 5 Requests =====\n
\n----- Request 1 (First unique email) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 2 (Second unique email) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n----- Request 3 (Duplicate of first email - should be cached) -----
{"prediction":"Safe Email","probability":0.7791625553383463}\n\n----- Request 4 (Third unique email) -----
{"prediction":"Phishing Email","probability":0.9169092144856761}\n\n----- Request 5 (Duplicate of second email - should be cached) -----
{"prediction":"Phishing Email","probability":0.8895319031315131}\n\n===== Test Complete =====\n
Now run 'python check_redis.py' to verify the Redis cache entries
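The curl output shows identical predictions for the duplicate requests, but it does not show how long each request took. One way to observe the effect of the cache is to time a first request and a repeated request. The sketch below does this with the standard library only; the helper names are hypothetical, and it assumes the API from this tutorial is running at http://localhost:8000.

```python
import json
import time
import urllib.request

API_URL = "http://localhost:8000/predict"  # endpoint used in this tutorial


def build_predict_request(text, url=API_URL):
    """Build a POST request equivalent to the curl calls above."""
    body = json.dumps({"text": text}).encode("utf-8")
    headers = {"accept": "application/json", "Content-Type": "application/json"}
    return urllib.request.Request(url, data=body, headers=headers, method="POST")


def timed_predict(text, url=API_URL):
    """Send one request and return (parsed JSON response, elapsed seconds)."""
    request = build_predict_request(text, url)
    start = time.perf_counter()
    with urllib.request.urlopen(request) as response:
        payload = json.loads(response.read().decode("utf-8"))
    return payload, time.perf_counter() - start


def compare_first_and_repeat(text):
    """Send the same text twice; the second call should be served from the cache."""
    first, first_seconds = timed_predict(text)
    repeat, repeat_seconds = timed_predict(text)
    print(f"first:  {first} in {first_seconds * 1000:.1f} ms")
    print(f"repeat: {repeat} in {repeat_seconds * 1000:.1f} ms")
    return first == repeat
```

With the server running, calling `compare_first_and_repeat("some email text that has not been sent before")` prints both timings. The numbers depend on the machine, and for a model as small as this one the gap may be modest, since HTTP overhead can dominate both requests.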
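Verifying the Redis Cache
To confirm that the caching system works, we will use a Python script, check_redis.py, to inspect the Redis database. The script retrieves the cached predictions and displays them in a table.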
import json

import redis
from tabulate import tabulate


def main():
    # Connect to Redis (ensure Redis is running on localhost:6379)
    redis_client = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

    # Retrieve all keys that start with "prediction:"
    keys = redis_client.keys("prediction:*")
    total_entries = len(keys)
    print(f"Total number of cached prediction entries: {total_entries}\n")

    table_data = []
    # Process only the first 5 entries
    for key in keys[:5]:
        # Remove the 'prediction:' prefix to get the original email text
        email_text = key.replace("prediction:", "", 1)

        # Retrieve the cached value
        value = redis_client.get(key)
        try:
            data = json.loads(value)
        except json.JSONDecodeError:
            data = {}

        prediction = data.get("prediction", "N/A")

        # Display only the first 7 words of the email text
        words = email_text.split()
        truncated_text = " ".join(words[:7]) + ("..." if len(words) > 7 else "")

        table_data.append([truncated_text, prediction])

    # Print table using tabulate (only two columns now)
    headers = ["Email Text (First 7 Words)", "Prediction"]
    print(tabulate(table_data, headers=headers, tablefmt="pretty"))


if __name__ == "__main__":
    main()
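When you run check_redis.py, it prints the number of cached entries and shows the cached predictions in a table. Since only three of the five requests were unique, the cache should hold three entries.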
python check_redis.py
Total number of cached prediction entries: 3
+--------------------------------------------------+----------------+
|            Email Text (First 7 Words)            |   Prediction   |
+--------------------------------------------------+----------------+
|  congratulations you have won a free iphone,...  | Phishing Email |
| urgent action required: your account has been... | Phishing Email |
|      todays floor meeting you may get a...       |   Safe Email   |
+--------------------------------------------------+----------------+
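The check script lists keys with `keys("prediction:*")`, which is fine for a handful of entries but scans the whole keyspace in one blocking call. For larger caches, redis-py also offers `scan_iter`, which walks the keyspace incrementally. The sketch below uses it and also reads the remaining time-to-live of each entry. The `summarize_cache` function is a hypothetical helper; it only assumes that the client object provides `scan_iter`, `get`, and `ttl`, as redis-py clients do.

```python
def summarize_cache(client, pattern="prediction:*", limit=5):
    """Collect up to `limit` cached entries as (key, value, ttl_seconds) tuples.

    `client` is expected to be a redis-py client, for example
    redis.Redis(host="localhost", port=6379, db=0, decode_responses=True).
    """
    rows = []
    # scan_iter walks the keyspace incrementally instead of in one blocking call.
    for key in client.scan_iter(match=pattern):
        if len(rows) >= limit:
            break
        # ttl returns the remaining lifetime of the key in seconds.
        rows.append((key, client.get(key), client.ttl(key)))
    return rows
```

Called with the same client settings as check_redis.py, this returns up to five entries. Because the API caches each prediction for 3600 seconds, the TTL values should count down from that figure and the entries should disappear once they reach zero.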
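Summary
By testing the phishing email classification application with multiple requests, we demonstrated that the API can identify phishing emails while using Redis to cache repeated requests efficiently. This caching mechanism improves performance by avoiding redundant computation for repeated inputs, which is especially useful in real-world applications where the API handles heavy traffic.
Although this is a relatively simple machine learning model, the benefits of caching become more pronounced with larger, more complex models such as those used for image recognition. For example, if you are deploying a large-scale image classification model, caching the predictions for frequently processed inputs can save substantial compute resources and significantly reduce response times.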