介绍
在本篇文章中,我将探讨另一种称为 LLM 微调的方法。我使用一个名为 MLX 的工具在 macOS 上对 Meta 的 LLaMA-3 和 Mistral LLM 进行了微调,MLX 是一个专为Apple芯片上的机器学习研究定制的阵列框架。这种微调是通过一种名为低秩适配器(Low Rank Adapters,LoRA)的技术完成的。随后,这些经过微调的 LLM 在 Ollama 上运行。
LLM 微调
微调是指调整预先训练好的大型语言模型(LLM)的参数或权重,使其专门用于特定任务或领域的过程。虽然像 GPT 这样的预训练语言模型拥有广泛的通用语言知识,但它们往往缺乏专业领域的专业知识。通过在特定领域的数据上对模型进行训练,微调可以克服这一问题,从而提高其在目标应用中的准确性和有效性。这一过程包括让模型接触特定任务的示例,使其更深入地掌握该领域的细微差别。这一关键步骤将通用语言模型转变为专业工具,从而释放 LLM 在特定领域或应用中的全部潜力。然而,对 LLM 进行微调需要大量的计算资源,如 GPU,以确保高效的训练。
有多种 LLM 微调技术可供选择,包括低秩适配器 (LoRA)、量化 LoRA (QLoRA)、参数高效微调 (PEFT)、DeepSpeed 和 ZeRO。在本篇文章中,我将讨论在Apple MLX 框架内使用 LoRA 技术对 LLM 进行微调的问题。LoRA 由微软研究团队于 2021 年首次推出,它提供了一种参数高效的微调方法。传统方法需要对整个基础模型进行微调,涉及面广,成本高昂,而 LoRA 则不同,它只需添加少量可训练参数,同时冻结原始模型参数。
LoRA 的精髓在于为模型添加适配层,从而提高模型的效率和适应性。LoRA 并没有加入全新的层,而是通过引入低秩矩阵来修改现有层的行为。这种方法引入的额外参数极少,因此与完整的模型再训练相比,大大减少了计算开销和内存使用。通过将适应重点放在特定的模型组件上,LoRA 保留了原始权重中蕴含的基础知识,最大限度地降低了灾难性遗忘的风险。这种有针对性的调整不仅能保持模型的一般能力,还能实现快速迭代和特定任务的增强,从而使 LoRA 成为微调大型预训练模型的灵活、可扩展的解决方案。
RAG 与 LLM 微调对比
RAG(Retrieval-Augmented Generation,检索增强生成)通过提供对精心策划的数据库的访问来增强 LLM,使其能够动态检索相关信息以生成响应。相比之下,微调则是通过在特定的标注数据集上进行训练来调整模型参数,从而提高其在特定任务中的性能。微调修改的是模型本身,而 RAG 则扩展了模型可以访问的数据。
当你需要用初始训练时没有的数据来补充语言模型的提示时,可以使用 RAG。这可能包括实时数据、用户特定数据或与提示相关的上下文信息。RAG 是确保模型获得最新相关数据的理想工具。而微调则是训练模型以更准确地理解和执行特定任务的最佳方式。
利用Apple MLX 进行 LLM 微调
长期以来,人们一直认为只有在 Nvidia GPU 上才能进行 ML 训练和推理。然而,随着 ML 框架 MLX 的发布,这种观点发生了改变,它可以在Apple Silicon CPU/GPU 上进行 ML 训练和推理。由Apple公司开发的 MLX 库类似于 TensorFlow 和 PyTorch,支持 GPU 支持的任务。该库允许在新的 Apple Silicon(M 系列)芯片上对 LLM 进行微调。此外,MLX 还支持使用 LoRA 方法对 LLM 进行微调。我已经成功地使用带有 LoRA 的 MLX 微调了几个 LLM,包括 Llama-3 和 Mistral。
使用案例
在本篇文章中,我将讨论如何使用 MLX 和 LoRA 微调 mistralai/Mistral-7B-Instruct-v0.2 LLM,以完成文本到 SQL 的特定任务,其中包括根据用户提示生成自定义 SQL 查询。经过微调的 LLM 将根据用户的文本输入生成 SQL 查询。我使用了 gretelai/synthetic_text_to_sql 数据集,它是高质量合成文本到 SQL 样本的丰富集合。该数据集包含转换为 SQL 格式的文本输入。利用该数据集,我训练 Mistral-7B LLM 根据用户输入生成 SQL 查询(文本到 SQL)。
安装 MLX 和其他工具
首先,我需要安装 MLX 和一套所需的工具。下面列出了我已安装的工具,以及如何设置和配置 MLX 环境。
# used repository called mlxo
❯❯ git clone https://gitlab.com/rahasak-labs/mlxo.git
❯❯ cd mlxo
# create and activate virtial enviroument
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
# install mlx
❯❯ pip install -U mlx-lm
# install other requried pythong pakcages
❯❯ pip install pandas
❯❯ pip install pyarrow
安装 Huggingface-CLI
我要从 Hugging Face 获取 LLM(基础模型)和数据集。为此,我需要在 Hugging Face 上建立一个账户,并配置 huggingface-cli 命令行工具。
# setup account in hugging-face from here
https://huggingface.co/welcome
# create access token to read/write data from hugging-face through the cli
# this token required when login to huggingface cli
https://huggingface.co/settings/tokens
# setup hugginface-cli
❯❯ pip install huggingface_hub
❯❯ pip install "huggingface_hub[cli]"
# login to huggingface through cli
# it will ask the access token previously created
❯❯ huggingface-cli login
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
Setting a new token will erase the existing one.
To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/lambda.eranga/.cache/huggingface/token
Login successful
# once login the tokne will be saved in the ~/.cache/huggingface
❯❯ ls ~/.cache/huggingface
datasets
hub
token
准备数据集
MLX 要求数据采用特定格式。MLX 讨论了三种主要格式:聊天、完成和文本。在本用例中,我将使用完成格式,它遵循特定的提示和完成结构,如下所述。在这种情况下,我需要生成一个包含提示和完成的数据集。数据集的生成在 LLM 的微调过程中起着至关重要的作用,因为它直接影响到微调模型的准确性。微调 LLM 可以使用多种技术生成数据集。例如,本篇文章将讨论如何利用 LLM 和提示工程生成数据集。
{
"prompt": "What is the capital of France?",
"completion": "Paris."
}
Hugging Face的原始数据集结构如下,以 .paraquest 文件形式提供。
{
"id": 39325,
"domain": "public health",
"domain_description": "Community health statistics, infectious disease tracking data, healthcare access metrics, and public health policy analysis.",
"sql_complexity": "aggregation",
"sql_complexity_description": "aggregation functions (COUNT, SUM, AVG, MIN, MAX, etc.), and HAVING clause",
"sql_task_type": "analytics and reporting",
"sql_task_type_description": "generating reports, dashboards, and analytical insights",
"sql_prompt": "What is the total number of hospital beds in each state?",
"sql_context": "CREATE TABLE Beds (State VARCHAR(50), Beds INT); INSERT INTO Beds (State, Beds) VALUES ('California', 100000), ('Texas', 85000), ('New York', 70000);",
"sql": "SELECT State, SUM(Beds) FROM Beds GROUP BY State;",
"sql_explanation": "This query calculates the total number of hospital beds in each state in the Beds table. It does this by using the SUM function on the Beds column and grouping the results by the State column."
}
我利用数据集中的 sql_prompt、sql_context 和 sql 字段将数据集转换为完成格式。我将 sql_prompt 和 sql_context 合并为一个提示字段,并使用 sql 字段作为完成格式。此外,MLX 需要三组数据集:训练集、测试集和有效集。数据文件应为 JSONL 格式。下面是将 Hugging Face 的数据转换为这种格式的脚本。
import pandas as pd
def prepare_train():
df = pd.read_parquet('train.parquet')
df['prompt'] = df['sql_prompt'] + " with given SQL schema " + df['sql_context']
df.rename(columns={'sql': 'completion'}, inplace=True)
df = df[['prompt', 'completion']]
print(df.head(10))
# Convert the DataFrame to a JSON format, with each record on a new line
# save as .jsonl
df.to_json('train.jsonl', orient='records', lines=True)
def prepare_test_valid():
df = pd.read_parquet('test.parquet')
df['prompt'] = df['sql_prompt'] + " with given SQL schema " + df['sql_context']
df.rename(columns={'sql': 'completion'}, inplace=True)
df = df[['prompt', 'completion']]
# Calculate split index for two-thirds
split_index = int(len(df) * 2 / 3)
# Split the DataFrame into two parts
test_df = df[:split_index]
valid_df = df[split_index:]
print(test_df.head(10))
print(valid_df.head(10))
# Save the subsets to their respective JSONL files
test_df.to_json('test.jsonl', orient='records', lines=True)
valid_df.to_json('valid.jsonl', orient='records', lines=True)
prepare_train()
prepare_test_valid()
我从 Hugging Face 下载了数据文件,并将其放入数据目录。随后,我运行脚本生成了训练数据集、测试数据集和有效数据集的 JSONL 格式文件。以下是生成的数据文件的结构。
# activate virtual env
❯❯ source .venv/bin/activate
# data directory
# `test.parquet` and `train.parquet` downloaded from the huggingface
# https://huggingface.co/datasets/gretelai/synthetic_text_to_sql/tree/main
❯❯ ls -al data
prepare.py
test.parquet
train.parquet
# generate jsonl files
❯❯ cd data
❯❯ python prepare.py
# generated files
❯❯ ls -ls
test.jsonl
train.jsonl
valid.jsonl
# train.jsonl
{"prompt":"What is the total volume of timber sold by each salesperson, sorted by salesperson? with given SQL schema CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');","completion":"SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;"}
{"prompt":"List all the unique equipment types and their corresponding total maintenance frequency from the equipment_maintenance table. with given SQL schema CREATE TABLE equipment_maintenance (equipment_type VARCHAR(255), maintenance_frequency INT);","completion":"SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;"}
{"prompt":"How many marine species are found in the Southern Ocean? with given SQL schema CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));","completion":"SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';"}
{"prompt":"What is the total trade value and average price for each trader and stock in the trade_history table? with given SQL schema CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);","completion":"SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;"}
# test.jsonl
{"prompt":"What is the average explainability score of creative AI applications in 'Europe' and 'North America' in the 'creative_ai' table? with given SQL schema CREATE TABLE creative_ai (application_id INT, name TEXT, region TEXT, explainability_score FLOAT); INSERT INTO creative_ai (application_id, name, region, explainability_score) VALUES (1, 'ApplicationX', 'Europe', 0.87), (2, 'ApplicationY', 'North America', 0.91), (3, 'ApplicationZ', 'Europe', 0.84), (4, 'ApplicationAA', 'North America', 0.93), (5, 'ApplicationAB', 'Europe', 0.89);","completion":"SELECT AVG(explainability_score) FROM creative_ai WHERE region IN ('Europe', 'North America');"}
{"prompt":"Delete all records of rural infrastructure projects in Indonesia that have a completion date before 2010. with given SQL schema CREATE TABLE rural_infrastructure (id INT, project_name TEXT, sector TEXT, country TEXT, completion_date DATE); INSERT INTO rural_infrastructure (id, project_name, sector, country, completion_date) VALUES (1, 'Water Supply Expansion', 'Infrastructure', 'Indonesia', '2008-05-15'), (2, 'Rural Electrification', 'Infrastructure', 'Indonesia', '2012-08-28'), (3, 'Transportation Improvement', 'Infrastructure', 'Indonesia', '2009-12-31');","completion":"DELETE FROM rural_infrastructure WHERE country = 'Indonesia' AND completion_date < '2010-01-01';"}
{"prompt":"How many accidents have been recorded for SpaceX and Blue Origin rocket launches? with given SQL schema CREATE TABLE Accidents (id INT, launch_provider VARCHAR(255), year INT, description TEXT); INSERT INTO Accidents (id, launch_provider, year, description) VALUES (1, 'SpaceX', 2015, 'Falcon 9 explosion'), (2, 'Blue Origin', 2011, 'Propulsion system failure'), (3, 'SpaceX', 2016, 'Falcon 9 explosion');","completion":"SELECT launch_provider, COUNT(*) FROM Accidents GROUP BY launch_provider;"}
{"prompt":"What is the maximum quantity of seafood sold in a single transaction? with given SQL schema CREATE TABLE sales (id INT, location VARCHAR(20), quantity INT, price DECIMAL(5,2)); INSERT INTO sales (id, location, quantity, price) VALUES (1, 'Northeast', 50, 12.99), (2, 'Midwest', 75, 19.99), (3, 'West', 120, 14.49);","completion":"SELECT MAX(quantity) FROM sales;"}
# valid.jsonl
{"prompt":"What is the total number of tickets sold for all football games? with given SQL schema CREATE TABLE tickets (ticket_id INT, game_id INT, region VARCHAR(50), quantity INT); INSERT INTO tickets (ticket_id, game_id, region, quantity) VALUES (1, 1, 'Midwest', 500); INSERT INTO tickets (ticket_id, game_id, region, quantity) VALUES (2, 2, 'Northeast', 700); CREATE TABLE games (game_id INT, sport VARCHAR(50)); INSERT INTO games (game_id, sport) VALUES (1, 'Football'); INSERT INTO games (game_id, sport) VALUES (2, 'Basketball');","completion":"SELECT SUM(quantity) FROM tickets INNER JOIN games ON tickets.game_id = games.game_id WHERE sport = 'Football';"}
{"prompt":"What is the total revenue for the soccer team from ticket sales in London and Paris? with given SQL schema CREATE TABLE tickets (ticket_id INT, game_id INT, quantity INT, price DECIMAL(5,2)); INSERT INTO tickets VALUES (1, 1, 50, 25.99); INSERT INTO tickets VALUES (2, 2, 30, 19.99); CREATE TABLE games (game_id INT, team VARCHAR(20), location VARCHAR(20), price DECIMAL(5,2)); INSERT INTO games VALUES (1, 'Arsenal', 'London', 50.00); INSERT INTO games VALUES (2, 'PSG', 'Paris', 40.00);","completion":"SELECT SUM(tickets.quantity * games.price) FROM tickets INNER JOIN games ON tickets.game_id = games.game_id WHERE games.location IN ('London', 'Paris');"}
{"prompt":"Identify the number of security incidents that occurred in 'Europe' in the last month. with given SQL schema CREATE TABLE incidents (incident_id INT PRIMARY KEY, incident_date DATE, incident_location VARCHAR(50)); INSERT INTO incidents (incident_id, incident_date, incident_location) VALUES (1, '2022-01-01', 'HQ'), (2, '2022-02-15', 'Branch01'), (3, '2022-03-30', 'Asia'), (4, '2022-04-15', 'Europe'), (5, '2022-04-20', 'Europe');","completion":"SELECT COUNT(*) FROM incidents WHERE incident_location = 'Europe' AND incident_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);"}
{"prompt":"Identify the top 5 threat intelligence sources with the highest number of reported incidents in the last year, according to our Incident Tracking database. with given SQL schema CREATE TABLE IncidentTracking (id INT, source VARCHAR(50), incident_count INT, timestamp DATETIME); INSERT INTO IncidentTracking (id, source, incident_count, timestamp) VALUES (1, 'TechFirmA', 200, '2021-01-01 10:00:00'), (2, 'TechFirmB', 150, '2021-01-01 10:00:00');","completion":"SELECT source, SUM(incident_count) as total_incidents FROM IncidentTracking WHERE timestamp >= DATE_SUB(NOW(), INTERVAL 1 YEAR) GROUP BY source ORDER BY total_incidents DESC LIMIT 5;"}
微调/训练 LLM
下一步是使用我之前准备的数据集通过 MLX 对 Mistral-7B LLM 进行微调。首先,我使用 huggingface-cli 从 Hugging Face 下载了 mistralai/Mistral-7B-Instruct-v0.2 LLM。然后,我使用提供的数据集和 LoRA 对 LLM 进行了训练。LoRA 即低阶自适应(Low-Rank Adaptation),包括引入低阶矩阵来调整模型的行为,而无需进行大量的重新训练,从而在保留原始模型参数的同时实现高效、有针对性的自适应。
在配备 64GB 内存和 30 个 GPU 的 Mac M2 上,训练 LLM 和生成必要的适配器大约需要 36 分钟。
# download llm
❯❯ huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2
/Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/1296dc8fd9b21e6424c9c305c06db9ae60c03ace
# model is downloaded into ~/.cache/huggingface/hub/
❯❯ ls ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
blobs refs snapshots
# list all downloaded models from huggingface
❯❯ huggingface-cli scan-cache
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
-------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------
NousResearch/Meta-Llama-3-8B model 16.1G 14 2 days ago 2 days ago main /Users/lambda.eranga/.cache/huggingface/hub/models--NousResearch--Meta-Llama-3-8B
gpt2 model 2.9M 5 3 months ago 3 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--gpt2
mistralai/Mistral-7B-Instruct-v0.2 model 29.5G 17 5 hours ago 5 hours ago main /Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
sentence-transformers/all-MiniLM-L6-v2 model 91.6M 11 3 months ago 3 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2
# fine-tune llm
# --model - original model which download from huggin face
# --data data - data directory path with train.jsonl
# --batch-size 4 - batch size
# --lora-layers 16 - number of lora layers
# --iters 1000 - tranning iterations
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--data data \
--train \
--batch-size 4\
--lora-layers 16\
--iters 1000
# following is the tranning output
# when tranning is started, the initial validation loss is 1.939 and tranning loss is 1.908
# once is tranning finished, validation loss is 0.548 and tranning loss is is 0.534
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models wont be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 11 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 96.71it/s]
Loading datasets
Training
Trainable parameters: 0.024% (1.704M/7241.732M)
Starting training..., iters: 1000
Iter 1: Val loss 1.939, Val took 47.185s
Iter 10: Train loss 1.908, Learning Rate 1.000e-05, It/sec 0.276, Tokens/sec 212.091, Trained Tokens 7688, Peak mem 21.234 GB
Iter 20: Train loss 1.162, Learning Rate 1.000e-05, It/sec 0.330, Tokens/sec 275.972, Trained Tokens 16040, Peak mem 21.381 GB
Iter 30: Train loss 0.925, Learning Rate 1.000e-05, It/sec 0.360, Tokens/sec 278.166, Trained Tokens 23769, Peak mem 21.381 GB
Iter 40: Train loss 0.756, Learning Rate 1.000e-05, It/sec 0.289, Tokens/sec 258.912, Trained Tokens 32717, Peak mem 24.291 GB
---
Iter 960: Train loss 0.510, Learning Rate 1.000e-05, It/sec 0.360, Tokens/sec 283.649, Trained Tokens 717727, Peak mem 24.332 GB
Iter 970: Train loss 0.598, Learning Rate 1.000e-05, It/sec 0.398, Tokens/sec 276.395, Trained Tokens 724663, Peak mem 24.332 GB
Iter 980: Train loss 0.612, Learning Rate 1.000e-05, It/sec 0.419, Tokens/sec 280.406, Trained Tokens 731359, Peak mem 24.332 GB
Iter 990: Train loss 0.605, Learning Rate 1.000e-05, It/sec 0.371, Tokens/sec 292.855, Trained Tokens 739260, Peak mem 24.332 GB
Iter 1000: Val loss 0.548, Val took 36.479s
Iter 1000: Train loss 0.534, Learning Rate 1.000e-05, It/sec 2.469, Tokens/sec 1886.081, Trained Tokens 746899, Peak mem 24.332 GB
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Saved final adapter weights to adapters/adapters.safetensors.
# gpu usage while tranning
❯❯ sudo powermetrics --samplers gpu_power -i500 -n1
Machine model: Mac14,6
OS version: 23F79
Boot arguments:
Boot time: Wed Jun 19 20:50:45 2024
*** Sampled system activity (Thu Jun 20 16:25:57 2024 -0400) (503.31ms elapsed) ***
**** GPU usage ****
GPU HW active frequency: 1398 MHz
GPU HW active residency: 100.00% (444 MHz: 0% 612 MHz: 0% 808 MHz: 0% 968 MHz: 0% 1110 MHz: 0% 1236 MHz: 0% 1338 MHz: 0% 1398 MHz: 100%)
GPU SW requested state: (P1 : 0% P2 : 0% P3 : 0% P4 : 0% P5 : 0% P6 : 0% P7 : 0% P8 : 100%)
GPU SW state: (SW_P1 : 0% SW_P2 : 0% SW_P3 : 0% SW_P4 : 0% SW_P5 : 0% SW_P6 : 0% SW_P7 : 0% SW_P8 : 0%)
GPU idle residency: 0.00%
GPU Power: 45630 mW
# end of the tranning the LoRA adapters generated into the adapters folder
❯❯ ls adapters
0000100_adapters.safetensors
0000300_adapters.safetensors
0000500_adapters.safetensors
0000700_adapters.safetensors
0000900_adapters.safetensors
adapter_config.json
0000200_adapters.safetensors
0000400_adapters.safetensors
0000600_adapters.safetensors
0000800_adapters.safetensors
0001000_adapters.safetensors
adapters.safetensors
评估微调后的 LLM
LLM 现在已经训练完成,LoRA 适配器也已创建。我们可以利用这些适配器和原始 LLM 来测试微调后 LLM 的功能。最初,我使用 --train 参数用 MLX 测试了 LLM。随后,我向原始 LLM 和经过微调的 LLM 提出了同样的问题。通过比较,我们可以看到微调后的 LLM 是如何根据所提供的数据集对文本到 SQL 用例进行优化的。通过修改提示、数据集和其他参数等,可以进一步改进微调过程。
# test the llm with the test data
# --model - original model which download from huggin face
# --adapter-path adapters - location of the lora adapters
# --data data - data directory path with test.jsonl
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--data data \
--test
# testing output
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models wont be available and only tokenizers, configuration and file/data utilities can be used.
Fetching 11 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 127804.28it/s]
==========
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
SELECT * FROM Transactions WHERE region = 'Africa' UNION SELECT * FROM Customers WHERE region = 'Africa';
==========
Prompt: 63.182 tokens-per-sec
Generation: 21.562 tokens-per-sec
# first ask the question from original llm using mlx
# --model - original model which download from huggin face
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--prompt "List all transactions and customers from the 'Africa' region."
# it provides genric answer as below
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
I'm an AI language model and don't have the ability to directly access or list specific data from a database or system. I can only provide you with an example of how you might write a SQL query to retrieve the transactions and customers from the 'Africa' region based on the assumption that you have a database with a table named 'transactions' and 'customers' and both tables have a 'region' column.
```sql
-- Query to get all transactions from Africa
SELECT *
FROM transactions
WHERE region = 'Africa';
-- Query to get all customers from Africa
SELECT *
FROM customers
WHERE region = 'Africa';
```
These queries will return all transactions and customers that have the 'region' set to 'Africa'. Adjust the table and column names to match your specific database schema.
==========
Prompt: 67.650 tokens-per-sec
Generation: 23.078 tokens-per-sec
# same question asked from fine-tunneld llm with usingn adapter
# --model - original model which download from huggin face
# --max-tokens 500 - how many tokens to generate
# --adapter-path adapters - location of the lora adapters
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--adapter-path adapters \
--prompt "List all transactions and customers from the 'Africa' region."
# it provides specific answer with generated sql
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
SELECT * FROM Transactions WHERE region = 'Africa' UNION SELECT * FROM Customers WHERE region = 'Africa';
==========
Prompt: 64.955 tokens-per-sec
Generation: 21.667 tokens-per-sec
利用融合适配器建立新模型
完成微调后,我可以将新模型学习到的调整与现有模型的权重合并,这个过程称为融合。从技术上讲,这涉及更新预训练/基础模型的权重和参数,以纳入微调模型的改进。基本上,我可以继续将 LoRA 适配器文件融合回基础模型中。
完成微调过程后,我可以将新模型学习到的调整与现有模型的权重合并,这一过程称为融合。从技术上讲,这涉及到更新预训练/基础模型的权重和参数,以纳入微调模型的改进。从本质上讲,我可以继续将 LoRA 适配器融合回基础模型中。
# --model - original model which download from huggin face
# --adapter-path adapters - location of the lora adapters
# --save-path models/effectz-sql - new model path
# --de-quantize - use this flag if you want convert the model GGUF format later
❯❯ python -m mlx_lm.fuse \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--save-path models/effectz-sql \
--de-quantize
# output
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models wont be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 11 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 129599.28it/s]
De-quantizing model
# new model generatd in the models directory
❯❯ tree models
models
└── effectz-sql
├── config.json
├── model-00001-of-00003.safetensors
├── model-00002-of-00003.safetensors
├── model-00003-of-00003.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── tokenizer.json
└── tokenizer_config.json
# now i can directly ask question from the new model
# --model models/effectz-sql - new model path
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model models/effectz-sql \
--max-tokens 500 \
--prompt "List all transactions and customers from the 'Africa' region."
# otuput
==========
Prompt: <s>[INST] List all transactions and customers from the 'Africa' region. [/INST]
SELECT * FROM transactions WHERE region = 'Africa'; SELECT * FROM customers WHERE region = 'Africa';
==========
Prompt: 65.532 tokens-per-sec
Generation: 23.193 tokens-per-sec
创建 GGUF 模型
我想用 Ollama 运行这个新创建的模型,Ollama 是一个轻量级的灵活框架,专为在个人电脑上本地部署 LLM 而设计。要在 Ollama 中运行这个合并模型,我需要将其转换为 GGUF(格奥尔基-格尔加诺夫统一格式)文件。GGUF 是 Ollama 使用的一种标准化存储格式。为了将模型转换成 GGUF,我使用了另一个名为 llama.cpp 的工具,它是一个用 C++ 编写的开源软件库,可以对各种 LLM 进行推理。下面是将模型转换为 GGUF 格式并建立 Ollama 模型的方法。
# clone llama.cpp into same location where mlxo repo exists
❯❯ git clone https://github.com/ggerganov/llama.cpp.git
# directory stcture where llama.cpp and mlxo exists
❯❯ ls
llama.cpp
mlxo
# configure required packages in llama.cpp with setting virtual enviroument
❯❯ cd llama.cpp
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
❯❯ pip install -r requirements.txt
# llama.cpp contains a script `convert-hf-to-gguf.py` to convert hugging face model gguf
❯❯ ls convert-hf-to-gguf.py
convert-hf-to-gguf.py
# convert newly generated model(in mlxo/models/effectz-sql) to gguf
# --outfile ../mlxo/models/effectz-sql.gguf - output gguf model file path
# --outtype q8_0 - 8 bit quantize which helps improve inference speed
❯❯ python convert-hf-to-gguf.py ../mlxo/models/effectz-sql \
--outfile ../mlxo/models/effectz-sql.gguf \
--outtype q8_0
# output
INFO:hf-to-gguf:Loading model: effectz-sql
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Set model parameters
INFO:hf-to-gguf:gguf: context length = 32768
INFO:hf-to-gguf:gguf: embedding length = 4096
INFO:hf-to-gguf:gguf: feed forward length = 14336
INFO:hf-to-gguf:gguf: head count = 32
INFO:hf-to-gguf:gguf: key-value head count = 8
INFO:hf-to-gguf:gguf: rope theta = 1000000.0
INFO:hf-to-gguf:gguf: rms norm epsilon = 1e-05
INFO:hf-to-gguf:gguf: file type = 7
INFO:hf-to-gguf:Set model tokenizer
INFO:gguf.vocab:Setting special token type bos to 1
INFO:gguf.vocab:Setting special token type eos to 2
INFO:gguf.vocab:Setting special token type unk to 0
INFO:gguf.vocab:Setting add_bos_token to True
INFO:gguf.vocab:Setting add_eos_token to False
INFO:gguf.vocab:Setting chat_template to {{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
INFO:hf-to-gguf:Exporting model to '../mlxo/models/effectz-sql.gguf'
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
INFO:hf-to-gguf:gguf: loading model part 'model-00001-of-00003.safetensors'
INFO:hf-to-gguf:token_embd.weight, torch.bfloat16 --> Q8_0, shape = {4096, 32000}
---
INFO:hf-to-gguf:blk.31.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:output_norm.weight, torch.bfloat16 --> F32, shape = {4096}
Writing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.70G/7.70G [01:17<00:00, 99.6Mbyte/s]
INFO:hf-to-gguf:Model successfully exported to '../mlxo/models/effectz-sql.gguf'
# new gguf model generated in the mlxo/models
❯❯ cd ../mlxo
❯❯ ls models/effectz-sql.gguf
models/effectz-sql.gguf
构建并运行 Ollama 模型
现在我可以创建一个 Ollama Modelfile,并使用名为 effectz-sql.gguf 的 GGUF 模型文件构建一个 Ollama 模型。Ollama Modelfile 是一个配置文件,用于定义和管理 Ollama 平台上的模型。下面是创建 Modelfile 和生成新 Ollama 模型的方法。
# create file named `Modelfile` in models directory with following content
❯❯ cat models/Modelfile
FROM ./effectz-sql.gguf
# create ollama model
❯❯ ollama create effectz-sql -f models/Modelfile
transferring model data
using existing layer sha256:24ba3b41f3b846bd56142b713b12df8da3e7ab1c5ee9ae3f5afc76d87c69d796
using autodetected template mistral-instruct
creating new layer sha256:fd230ae885e75bf5240291df2bfb0090af58f4bdf3689deafcef415582915a33
writing manifest
success
# list ollama models
# effectz-sql:latest is the newly created model
❯❯ ollama ls
NAME ID SIZE MODIFIED
effectz-sql:latest 736275f4faa4 7.7 GB 17 seconds ago
mistral:latest 2ae6f6dd7a3d 4.1 GB 3 days ago
llama3:latest a6990ed6be41 4.7 GB 7 weeks ago
llama3:8b a6990ed6be41 4.7 GB 7 weeks ago
llama2:latest 78e26419b446 3.8 GB 2 months ago
llama2:13b d475bf4c50bc 7.4 GB 2 months ago
# run model with ollama and ask question
# it will convert the prompt into sql
❯❯ ollama run effectz-sql
>>> List all transactions and customers from the 'Africa' region.
SELECT * FROM transactions WHERE customer_region = 'Africa';<|im_end|>