Llama 2 是开源 LLM(语言模型)领域的重要里程碑。最大的模型及其微调变体位于Hugging Face的开源 LLM 排行榜的顶部。多个基准测试显示,它在性能上接近甚至超越了GPT-3.5。所有这些意味着开源 LLM 在复杂的 LLM 应用中越来越成为一种可行且可靠的选择,从RAG 系统到智能代理等各种应用场景。
然而,最小的 Llama-2-7B 模型(7B参数)的缺点是在生成 SQL 方面不太擅长,这使得它在结构化分析的应用场景中不实用。以一个例子来说明,我们试图让 Llama-2-7B 根据以下提示模板生成正确的 SQL 语句:
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
You must output the SQL query that answers the question.
### Input:
{input}
### Context:
{context}
### Response:
这里我们插入了 sql-create-context 数据集中的一个样本条目。
input: In 1981 which team picked overall 148?
context: CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)
与此同时,以下是生成的输出与正确的输出对比:
Generated output: SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;
Correct output: SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"
这显然远非理想。与 ChatGPT 和 GPT-4 不同,Llama-2 不可靠地生成格式良好且正确的 SQL 输出。
这正是微调发挥作用的地方——通过给定适当的文本到 SQL 数据集,我们可以教 Llama-2 在从自然语言生成 SQL 输出方面更加准确。在高层次上,微调涉及以某种方式修改模型的权重。微调模型的方法有很多,可以更新网络的所有参数,也可以更新部分参数,甚至只微调额外的参数(例如 LoRA 的工作原理)。
教程概述
在本文中,我们将向你展示如何使用文本到 SQL 数据集对 Llama-2 进行微调,并利用 LlamaIndex 的功能针对任何 SQL 数据库进行结构化分析。
以下是我们使用的技术栈:
1. 从 Hugging Face 数据集中使用 b-mc2/sql-create-context 作为训练数据集。
2. 使用 OpenLLaMa 的 open_llama_7b_v2 作为基础模型。
3. 使用 PEFT 进行高效的微调。
4. 使用 Modal 处理微调过程中的所有云计算和编排工作,也参考了优秀的 doppel-bot 存储库。
5. 使用 LlamaIndex 进行针对任何 SQL 数据库的文本到 SQL 推理。
如上所述,进行微调需要执行许多步骤。我们的目标是使其尽可能简单明了,方便直接使用。
步骤 1:加载用于微调 LLaMa 的训练数据
在这一步中,首先打开 Jupyter notebook。该 notebook 被组织为一系列可执行的脚本,每个脚本都执行加载数据所需的步骤。
我们的代码在整个编排过程中都使用 Modal,并且最好将 Modal 应用于 Python 脚本的顶层。这就是为什么很多单元格本身不包含 Python 代码块的原因。
首先,我们使用 Modal 加载 b-mc2/sql-create-context 数据集。这个任务很简单,只是加载数据集并将其格式化为 .jsonl 文件。
modal run src.load_data_sql --data-dir "data_sql"
正如我们所看到的,在幕后,这个任务相当简单:
# Modal stubs allow our function to run remotely
@stub.function(
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def load_data_sql(data_dir: str = "data_sql"):
from datasets import load_dataset
dataset = load_dataset("b-mc2/sql-create-context")
dataset_splits = {"train": dataset["train"]}
out_path = get_data_path(data_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
for key, ds in dataset_splits.items():
with open(out_path, "w") as f:
for item in ds:
newitem = {
"input": item["question"],
"context": item["context"],
"output": item["answer"],
}
f.write(json.dumps(newitem) + "\n")
步骤 2:运行微调脚本
接下来的步骤是在解析的数据集上运行我们的微调脚本。
modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"
微调脚本执行以下步骤。
将数据集分割为训练集和验证集。
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
将每个数据拆分转化为 (输入提示,标签) 元组:输入查询和上下文被格式化为相同的输入提示。然后,输入提示被分词,标签被设置为与输入提示完全相同的内容 — 这使得模型可以在下一个词预测上进行训练。
def generate_and_tokenize_prompt(data_point):
full_prompt = generate_prompt_sql(
data_point["input"],
data_point["context"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
raise NotImplementedError("not implemented yet")
return tokenized_full_prompt
输入提示与本博客开头给出的一样。
当运行微调脚本时,模型将保存在由 model_dir 指定的远程云目录中(如果未指定,则设置为默认值)。
步骤 3:评估
模型已经完成微调,并可以从云端提供服务。我们可以使用来自 sql-create-context 的示例数据运行一些基本评估,以比较微调模型与基准 Llama 2 模型的性能。
modal run src.eval_sql::main
结果显示,经过微调的模型有了显著的改进:
Input 1: {'input': 'Which region (year) has Abigail at number 7, Sophia at number 1 and Aaliyah at number 5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND
no_5 = "aaliyah"'}
Output 1 (finetuned model): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "aaliyah" AND no_5 = "sophia"
Output 1 (base model): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';
Input 2: {'input': 'Name the result/games for 54741', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}
Output 2 (finetuned model): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"
Output 2 (base model): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;
而基准模型会生成格式错误的输出或错误的 SQL 语句,
经过微调的模型能够生成更接近预期输出的结果。
步骤 4:将微调模型与 LlamaIndex 集成
现在,我们可以在 LlamaIndex 中使用这个模型进行任何数据库的文本到 SQL 转换。
我们首先定义一个测试 SQL 数据库,然后使用它来测试模型的推理能力。
我们创建了一个名为 city_stats 的示例表,其中包含城市名称、人口和国家信息,并填充了一些示例城市的数据。
db_file = "cities.db"
engine = create_engine(f"sqlite:///{db_file}")
metadata_obj = MetaData()
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
table_name,
metadata_obj,
Column("city_name", String(16), primary_key=True),
Column("population", Integer),
Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
这些数据存储在一个 cities.db 文件中。
然后,我们可以使用 Modal 将微调模型和这个数据库文件加载到 LlamaIndex 的 NLSQLTableQueryEngine 中 - 这个查询引擎允许用户轻松地开始在给定数据库上执行文本到 SQL 转换。
modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True
我们会得到以下类似的响应:
SQL Query: SELECT MAX(population) FROM city_stats WHERE country = "United States"
Response: [(2679000,)]
结论
本文为你提供了一个高级的方式,让你可以开始微调 Llama 2 模型来生成 SQL 语句,并展示了如何将其与 LlamaIndex 集成到你的文本到 SQL 工作流中。