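Background
In this post, I explore a more advanced use case for LLM fine-tuning: prediction tasks in the cybersecurity domain. Specifically, I present a mechanism for real-time network attack detection using a fine-tuned LLM. I fine-tuned mistralai/Mistral-7B-Instruct-v0.2 on a network PCAP dataset containing both malicious and benign packets. Trained on this dataset, the model learns to predict anomalies in real-time network data such as PCAP data. The fine-tuning was performed on an Apple Silicon Mac with an M2 chip, using LoRA and the Apple MLX framework. After fine-tuning, the resulting custom model is run for inference.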
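LLM Fine-Tuning
Fine-tuning is the process of adjusting the parameters, or weights, of a pretrained large language model (LLM) so that it specializes in a particular task or domain. Pretrained language models such as GPT carry broad general language knowledge, but they often lack domain-specific expertise. Fine-tuning addresses this limitation by training the model on domain-specific data, which improves its accuracy and effectiveness in the target application. The process exposes the model to task-specific examples so that it picks up the nuances of the domain. This step turns a general-purpose language model into a specialized tool. Fine-tuning an LLM does, however, require substantial compute resources, such as GPUs, to train efficiently.
Several techniques and tools are used for LLM fine-tuning, including Low-Rank Adaptation (LoRA), Quantized LoRA (QLoRA), Parameter-Efficient Fine-Tuning (PEFT), DeepSpeed, and ZeRO. In this post, I discuss fine-tuning an LLM with LoRA inside the Apple MLX framework. LoRA, first proposed by a Microsoft research team in 2021, is a parameter-efficient fine-tuning method. Traditional approaches fine-tune the entire base model, which can be very large and expensive. LoRA instead keeps the original model parameters frozen and adds only a small number of trainable parameters.
The core of LoRA is adding adapters to the model. Rather than introducing entirely new layers, LoRA modifies the behavior of existing layers by adding low-rank matrices to them. This introduces very few extra parameters, so compute cost and memory usage are much lower than for full retraining. Because only specific model components are adapted, the base knowledge embedded in the original weights is preserved and the risk of catastrophic forgetting is reduced. This targeted adaptation keeps the model's general capabilities while allowing fast iteration on task-specific improvements, which makes LoRA a flexible and scalable way to fine-tune large pretrained models.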
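To make the low-rank idea concrete, here is a minimal pure-Python sketch of a single LoRA-adapted linear layer. It is an illustration under stated assumptions, not MLX's implementation: the function names, the tiny matrix sizes, and the `alpha` value are all invented for the example. The frozen weight `W` is left untouched, and the trainable pair `A` and `B` adds the update `(alpha / r) * B (A x)`.

```python
# Minimal LoRA forward pass in pure Python (illustrative only; real
# implementations use tensor libraries such as MLX or PyTorch).

def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(x, W, A, B, alpha):
    """Compute h = W x + (alpha / r) * B (A x).

    W: frozen base weight, shape (d_out, d_in)
    A: trainable down-projection, shape (r, d_in)
    B: trainable up-projection, shape (d_out, r)
    """
    r = len(A)
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


d_in, d_out, r = 4, 3, 2
W = [[0.1 * (i + j) for j in range(d_in)] for i in range(d_out)]
A = [[0.5, -0.5, 0.25, 0.0], [0.0, 1.0, 0.0, -1.0]]
x = [1.0, 2.0, 3.0, 4.0]

# B starts at zero, so before training the adapter contributes nothing
# and the layer output equals the frozen layer's output W x.
B_zero = [[0.0] * r for _ in range(d_out)]
h_initial = lora_forward(x, W, A, B_zero, alpha=16)

# After training, B is non-zero and shifts the output.
B_trained = [[0.1, 0.0], [0.0, 0.2], [-0.1, 0.1]]
h_adapted = lora_forward(x, W, A, B_trained, alpha=16)

full_params = d_out * d_in        # parameters in W
lora_params = r * (d_in + d_out)  # parameters in A and B
print(h_initial, h_adapted, full_params, lora_params)
```

At these toy sizes the adapter is not smaller than `W`. The saving appears at realistic sizes: for a 4096 x 4096 projection with rank 8, `A` and `B` together hold 65,536 parameters against roughly 16.8 million in `W`.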
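RAG vs. LLM Fine-Tuning
Retrieval-augmented generation (RAG) improves an LLM's output by giving it access to a curated database from which it can dynamically retrieve relevant information when generating a response. Fine-tuning, by contrast, adjusts the model's parameters by training it on a specific labeled dataset to improve its performance on a particular task. Fine-tuning changes the model itself, while RAG extends the data the model can access.
Use RAG when you need to supplement the prompt with data that was not available when the model was originally trained. This can include real-time data, user-specific data, or contextual information relevant to the prompt. RAG is well suited to making sure the model has access to the most recent and most relevant data. Fine-tuning, on the other hand, is the better choice for training a model to understand and perform a specific task more accurately.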
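LLM Fine-Tuning with Apple MLX
For a long time, it was widely assumed that machine learning (ML) training and inference could only be done on Nvidia GPUs. That changed with the release of the MLX framework, which makes ML training and inference possible on Apple Silicon CPUs and GPUs. Developed by Apple, the MLX library is similar to TensorFlow and PyTorch and supports GPU-accelerated workloads. It allows large language models (LLMs) to be fine-tuned on the newer Apple Silicon (M-series) chips, and it supports fine-tuning with the LoRA method. I have used MLX and LoRA to fine-tune several LLMs, including Llama-3 and Mistral.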
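Use Case
In this post, I explore a more advanced use case for LLM fine-tuning: real-time network attack detection. I fine-tune mistralai/Mistral-7B-Instruct-v0.2 on a PCAP dataset that contains both malicious and benign packets. By training on this dataset, the model learns to predict anomalies in real-time network data such as PCAP data.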
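Setting Up MLX and Other Tools
First, I need to install MLX along with a few other required tools. Below are the tools I installed and the steps I used to set up and configure the MLX environment.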
# clone the repository used in this post (mlxa)
❯❯ git clone https://gitlab.com/rahasak-labs/mlxa.git
❯❯ cd mlxa
# create and activate a virtual environment
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
# install mlx
❯❯ pip install -U mlx-lm
# install other required python packages
❯❯ pip install pandas
❯❯ pip install pyarrow
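Setting Up huggingface-cli
I get the large language model (the base model) and datasets from Hugging Face. To do that, I need an account on Hugging Face and a configured huggingface-cli command-line tool.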
# setup account in hugging-face from here
https://huggingface.co/welcome
# create access token to read/write data from hugging-face through the cli
# this token required when login to huggingface cli
https://huggingface.co/settings/tokens
# set up huggingface-cli
❯❯ pip install huggingface_hub
❯❯ pip install "huggingface_hub[cli]"
# login to huggingface through cli
# it will ask the access token previously created
❯❯ huggingface-cli login
A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
Setting a new token will erase the existing one.
To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/lambda.eranga/.cache/huggingface/token
Login successful
# after login, the token is saved in ~/.cache/huggingface
❯❯ ls ~/.cache/huggingface
datasets
hub
token
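Preparing the Dataset
MLX requires data in a specific format. The MLX documentation describes three main formats: chat, completion, and text. For this use case I use the completion format, in which each record has a prompt (the input given to the LLM) and a completion (the LLM's response), each written as a natural-language phrase. Dataset generation plays a critical role in LLM fine-tuning because it directly affects the accuracy of the fine-tuned model. Various techniques can be used to generate datasets for fine-tuning. For example, this article discusses how to generate datasets by combining prompt engineering techniques with an LLM.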
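As a quick illustration of the three formats, the sketch below builds one record of each kind and serializes it as a JSONL line. The field names (`messages`, `prompt`/`completion`, `text`) are the ones I understand the mlx-lm LoRA documentation to use. The record contents are invented for the example and are not taken from the dataset.

```python
import json

# One illustrative record per format. The values are made up for this example.
chat_record = {
    "messages": [
        {"role": "user", "content": "Is this traffic normal or anomaly? source port: 34293, protocol: tcp"},
        {"role": "assistant", "content": "anomaly network traffic"},
    ]
}
completion_record = {
    "prompt": "Is this traffic normal or anomaly? source port: 34293, protocol: tcp",
    "completion": "anomaly network traffic",
}
text_record = {
    "text": "Traffic from source port 34293 over tcp is anomaly network traffic."
}

# Each record becomes exactly one line in a .jsonl file.
jsonl_lines = [json.dumps(r) for r in (chat_record, completion_record, text_record)]
for line in jsonl_lines:
    print(line)
```

The completion format keeps the input and the expected answer in separate fields, which fits a classification-style task like this one where the answer is a short fixed phrase.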
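The original dataset is organized as a .csv file. To fine-tune with Apple MLX, I converted it to the completion format. Each record contains a prompt and a completion, with the connection attributes merged into one coherent natural-language phrase. Below is a sample of the data after conversion to the completion format.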
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 34293, network protocol: tcp, duration of the connection: 2.998556, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 41106, network protocol: tcp, duration of the connection: 2.998779, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 45986, network protocol: tcp, duration of the connection: 2.998807, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.1, source port: 3, network protocol: icmp, duration of the connection: 1.999924, connection state: OTH, missed_bytes: 0, number of packets sent from source to destination: 2, number of ip bytes sent from source to destination: 136, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "normal network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 51216, network protocol: tcp, duration of the connection: 2.999077, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 51022, network protocol: tcp, duration of the connection: 2.999300, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
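MLX needs three datasets for fine-tuning: training, test, and validation. The data files must be in JSONL format. The script below converts the PCAP data from a CSV file to JSONL. During conversion, it combines the PCAP attributes of each CSV record into a single natural-language phrase that serves as the prompt. The completion field is generated from the record's label field.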
import json
import random

import pandas as pd

# Read the pipe-delimited CSV file of labeled connection records
csv_file_path = "pcap-labeled.csv"
pcap_data = pd.read_csv(csv_file_path, delimiter="|")

# Records that will be written out as JSONL
jsonl_data = []

# Fields that must not contain the '-' placeholder
fields_to_check = [
    'id.orig_h', 'id.orig_p',
    'proto', 'conn_state',
    'duration', 'missed_bytes',
    'orig_pkts', 'resp_pkts',
    'orig_ip_bytes', 'resp_ip_bytes'
]

for _, row in pcap_data.iterrows():
    # Skip rows with '-' in any of the specified fields
    if any(row[field] == '-' for field in fields_to_check):
        continue

    # Construct the label
    label = "anomaly" if row['label'] == "Malicious" else "normal"

    # Create the JSONL record
    record = {
        "prompt": (
            f"You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. "
            f"Here are the attributes, "
            f"'source ip address: {row['id.orig_h']}, source port: {row['id.orig_p']}, "
            f"network protocol: {row['proto']}, "
            f"duration of the connection: {row['duration']}, "
            f"connection state: {row['conn_state']}, "
            f"missed_bytes: {row['missed_bytes']}, number of packets sent from source to destination: {row['orig_pkts']}, "
            f"number of ip bytes sent from source to destination: {row['orig_ip_bytes']}, number of packets sent from destination to source: {row['resp_pkts']}, "
            f"number of ip bytes sent from destination to source: {row['resp_ip_bytes']}'. "
        ),
        "completion": f"{label} network traffic"
    }
    jsonl_data.append(record)

# Shuffle the data
random.shuffle(jsonl_data)

# Calculate the split index: 70% train, 30% test
train_split = int(len(jsonl_data) * 7 / 10)

# Split the data
train_data = jsonl_data[:train_split]
test_data = jsonl_data[train_split:]

# Save train.jsonl
train_file_path = 'train.jsonl'
with open(train_file_path, 'w', encoding='utf-8') as train_file:
    for entry in train_data:
        train_file.write(json.dumps(entry) + '\n')

# Save test.jsonl
test_file_path = 'test.jsonl'
with open(test_file_path, 'w', encoding='utf-8') as test_file:
    for entry in test_data:
        test_file.write(json.dumps(entry) + '\n')

# Save valid.jsonl
# Note: this reuses the test split, so the validation and test sets are identical
valid_file_path = 'valid.jsonl'
with open(valid_file_path, 'w', encoding='utf-8') as valid_file:
    for entry in test_data:
        valid_file.write(json.dumps(entry) + '\n')
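The script writes the same records to both test.jsonl and valid.jsonl. If separate validation and test sets are wanted, one option is a three-way split. The sketch below is an alternative I am suggesting, not the method used in this post: the function name `split_records`, the 70/15/15 percentages, and the fixed seed are all assumptions chosen for illustration.

```python
import random


def split_records(records, train_pct=70, valid_pct=15, seed=42):
    """Shuffle records and split them into disjoint train/valid/test lists.

    Percentages are integers so the split sizes are computed with exact
    integer arithmetic; whatever remains after train and valid goes to test.
    """
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)  # seeded for reproducibility
    n = len(shuffled)
    train_end = n * train_pct // 100
    valid_end = train_end + n * valid_pct // 100
    return shuffled[:train_end], shuffled[train_end:valid_end], shuffled[valid_end:]


# Hypothetical records standing in for the converted PCAP rows.
records = [
    {"prompt": f"record {i}", "completion": "normal network traffic"}
    for i in range(100)
]
train, valid, test = split_records(records)
print(len(train), len(valid), len(test))
```

With disjoint sets, the validation loss reported during training and the final test loss come from different records, so the test result is a more independent check.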
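I placed the dataset in the data directory and then ran the script to generate the JSONL files for the training, test, and validation sets. The generated data files look as follows.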
# activate virtual env
❯❯ source .venv/bin/activate
# data directory
❯❯ ls -al data
prepare.py
s2d.csv
# generate jsonl files
❯❯ cd data
❯❯ python prepare.py
# generated files
❯❯ ls -ls
test.jsonl
train.jsonl
valid.jsonl
# train.jsonl
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 33037, network protocol: tcp, duration of the connection: 2.998781, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 45506, network protocol: tcp, duration of the connection: 2.998830, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 39972, network protocol: tcp, duration of the connection: 2.998819, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 40262, network protocol: tcp, duration of the connection: 2.998574, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 60633, network protocol: tcp, duration of the connection: 2.999060, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
# test.jsonl
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 55520, network protocol: tcp, duration of the connection: 2.998521, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 45461, network protocol: tcp, duration of the connection: 2.998792, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 37168, network protocol: tcp, duration of the connection: 2.998772, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "normal network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 60316, network protocol: tcp, duration of the connection: 2.998572, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 53036, network protocol: tcp, duration of the connection: 2.998555, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
# valid.jsonl
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 46557, network protocol: tcp, duration of the connection: 2.998811, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47863, network protocol: tcp, duration of the connection: 2.998817, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 38867, network protocol: tcp, duration of the connection: 2.998996, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 54329, network protocol: tcp, duration of the connection: 2.998824, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
{"prompt": "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 60526, network protocol: tcp, duration of the connection: 2.999080, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. ", "completion": "anomaly network traffic"}
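Most of the sample records shown here are labeled as anomaly, so it can be worth checking the label balance of each generated file before training. The following is a small sketch for doing that. The helper name `label_counts` and the inline sample lines are hypothetical, and in practice the lines would come from reading train.jsonl, test.jsonl, or valid.jsonl.

```python
import json
from collections import Counter


def label_counts(jsonl_lines):
    """Count how often each completion label appears in JSONL lines."""
    counts = Counter()
    for line in jsonl_lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        counts[json.loads(line)["completion"]] += 1
    return counts


# Inline stand-in for the contents of a .jsonl file.
sample_lines = [
    '{"prompt": "p1", "completion": "anomaly network traffic"}',
    '{"prompt": "p2", "completion": "anomaly network traffic"}',
    '{"prompt": "p3", "completion": "normal network traffic"}',
    '{"prompt": "p4", "completion": "anomaly network traffic"}',
]

counts = label_counts(sample_lines)
total = sum(counts.values())
for label, count in counts.most_common():
    print(f"{label}: {count} ({count / total:.0%})")
```

A heavily skewed distribution matters when reading the results later, because a model that always answers with the majority label can still reach a low loss.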
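Fine-Tuning/Training the LLM
The next step is to fine-tune the Mistral-7B LLM with MLX using the dataset prepared earlier. First, I downloaded mistralai/Mistral-7B-Instruct-v0.2 from Hugging Face with huggingface-cli. Then I trained the LLM on the dataset using LoRA. LoRA (Low-Rank Adaptation) adjusts the model's behavior by introducing low-rank matrices, which avoids extensive retraining and keeps the original model parameters intact while still allowing efficient, targeted adaptation.
On a Mac M2 with 64 GB of RAM and a 30-core GPU, training the LLM and generating the adapters took about 25 minutes.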
# download llm
❯❯ huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2
/Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/1296dc8fd9b21e6424c9c305c06db9ae60c03ace
# model is downloaded into ~/.cache/huggingface/hub/
❯❯ ls ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
blobs refs snapshots
# list all downloaded models from huggingface
❯❯ huggingface-cli scan-cache
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
-------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------
BAAI/bge-reranker-base model 1.1G 6 3 months ago 4 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--BAAI--bge-reranker-base
NousResearch/Meta-Llama-3-8B model 16.1G 14 2 months ago 5 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--NousResearch--Meta-Llama-3-8B
gpt2 model 2.9M 5 8 months ago 8 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--gpt2
infgrad/stella_en_1.5B_v5 model 240.7K 6 4 months ago 4 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--infgrad--stella_en_1.5B_v5
mistralai/Mistral-7B-Instruct-v0.2 model 29.5G 21 1 day ago 1 day ago main /Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
sentence-transformers/all-MiniLM-L6-v2 model 91.6M 11 8 months ago 8 months ago main /Users/lambda.eranga/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2
Done in 0.0s. Scanned 6 repo(s) for a total of 46.8G.
Got 1 warning(s) while scanning. Use -vvv to print details
# fine-tune llm
# --model - original model downloaded from Hugging Face
# --data data - data directory path containing train.jsonl
# --batch-size 4 - batch size
# --lora-layers 16 - number of layers to fine-tune with LoRA
# --iters 1000 - training iterations
❯❯ python -m mlx_lm.lora \
  --model mistralai/Mistral-7B-Instruct-v0.2 \
  --data data \
  --train \
  --batch-size 4 \
  --lora-layers 16 \
  --iters 1000
# following is the training output
# when training starts, the initial validation loss is 3.346 and the first reported training loss is 2.713
# once training finishes, the validation loss is 0.149 and the training loss is 0.146
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 17980.26it/s]
Loading datasets
Training
Trainable parameters: 0.024% (1.704M/7241.732M)
Starting training..., iters: 1000
Iter 1: Val loss 3.346, Val took 35.688s
Iter 10: Train loss 2.713, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 314.028, Trained Tokens 6640, Peak mem 16.083 GB
Iter 20: Train loss 1.179, Learning Rate 1.000e-05, It/sec 0.474, Tokens/sec 314.641, Trained Tokens 13280, Peak mem 16.083 GB
Iter 30: Train loss 0.432, Learning Rate 1.000e-05, It/sec 0.474, Tokens/sec 314.606, Trained Tokens 19920, Peak mem 16.083 GB
Iter 40: Train loss 0.484, Learning Rate 1.000e-05, It/sec 0.476, Tokens/sec 313.707, Trained Tokens 26511, Peak mem 16.083 GB
Iter 50: Train loss 0.361, Learning Rate 1.000e-05, It/sec 0.471, Tokens/sec 311.593, Trained Tokens 33123, Peak mem 16.083 GB
Iter 60: Train loss 0.335, Learning Rate 1.000e-05, It/sec 0.472, Tokens/sec 312.100, Trained Tokens 39739, Peak mem 16.083 GB
Iter 70: Train loss 0.315, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 313.810, Trained Tokens 46367, Peak mem 16.083 GB
Iter 80: Train loss 0.311, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 313.270, Trained Tokens 52994, Peak mem 16.083 GB
---
Iter 910: Train loss 0.165, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 311.057, Trained Tokens 603153, Peak mem 16.213 GB
Iter 920: Train loss 0.154, Learning Rate 1.000e-05, It/sec 0.469, Tokens/sec 310.728, Trained Tokens 609781, Peak mem 16.213 GB
Iter 930: Train loss 0.158, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 311.830, Trained Tokens 616409, Peak mem 16.213 GB
Iter 940: Train loss 0.148, Learning Rate 1.000e-05, It/sec 0.465, Tokens/sec 308.425, Trained Tokens 623037, Peak mem 16.213 GB
Iter 950: Train loss 0.156, Learning Rate 1.000e-05, It/sec 0.468, Tokens/sec 310.664, Trained Tokens 629669, Peak mem 16.213 GB
Iter 960: Train loss 0.166, Learning Rate 1.000e-05, It/sec 0.466, Tokens/sec 308.836, Trained Tokens 636297, Peak mem 16.213 GB
Iter 970: Train loss 0.150, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 311.695, Trained Tokens 642925, Peak mem 16.213 GB
Iter 980: Train loss 0.151, Learning Rate 1.000e-05, It/sec 0.474, Tokens/sec 314.600, Trained Tokens 649564, Peak mem 16.213 GB
Iter 990: Train loss 0.146, Learning Rate 1.000e-05, It/sec 0.470, Tokens/sec 312.381, Trained Tokens 656204, Peak mem 16.213 GB
Iter 1000: Val loss 0.149, Val took 35.181s
Iter 1000: Train loss 0.146, Learning Rate 1.000e-05, It/sec 4.690, Tokens/sec 3106.646, Trained Tokens 662828, Peak mem 16.213 GB
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Saved final adapter weights to adapters/adapters.safetensors.
# gpu usage while training
❯❯ sudo powermetrics --samplers gpu_power -i500 -n1
Password:
Machine model: Mac14,6
OS version: 23F79
Boot arguments:
Boot time: Fri Dec 6 09:26:57 2024
*** Sampled system activity (Mon Dec 9 08:43:34 2024 -0500) (503.37ms elapsed) ***
**** GPU usage ****
GPU HW active frequency: 1397 MHz
GPU HW active residency: 98.22% (444 MHz: .09% 612 MHz: 0% 808 MHz: 0% 968 MHz: 0% 1110 MHz: 0% 1236 MHz: 0% 1338 MHz: 0% 1398 MHz: 98%)
GPU SW requested state: (P1 : 0% P2 : 0% P3 : 0% P4 : 0% P5 : 0% P6 : 0% P7 : 0% P8 : 100%)
GPU SW state: (SW_P1 : 0% SW_P2 : 0% SW_P3 : 0% SW_P4 : 0% SW_P5 : 0% SW_P6 : 0% SW_P7 : 0% SW_P8 : 0%)
GPU idle residency: 1.78%
GPU Power: 42023 mW
# at the end of training, the LoRA adapters are written to the adapters folder
❯❯ ls adapters
0000100_adapters.safetensors
0000300_adapters.safetensors
0000500_adapters.safetensors
0000700_adapters.safetensors
0000900_adapters.safetensors
adapter_config.json
0000200_adapters.safetensors
0000400_adapters.safetensors
0000600_adapters.safetensors
0000800_adapters.safetensors
0001000_adapters.safetensors
adapters.safetensors
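The "Trainable parameters: 0.024% (1.704M/7241.732M)" line in the training output can be reproduced with simple arithmetic. The sketch below rests on assumptions that are not stated in this post: that LoRA was applied with rank 8 to the query and value projections of each adapted layer, and that Mistral-7B uses a hidden size of 4096 with 8 key/value heads of dimension 128. Under those assumptions the count matches the log.

```python
# Assumed Mistral-7B attention dimensions.
hidden_size = 4096
num_kv_heads = 8
head_dim = 128

# Assumed LoRA settings; only the layer count comes from the command line.
rank = 8
adapted_layers = 16  # --lora-layers 16

# Query projection maps hidden -> hidden; value projection maps
# hidden -> num_kv_heads * head_dim (grouped-query attention).
q_in, q_out = hidden_size, hidden_size
v_in, v_out = hidden_size, num_kv_heads * head_dim


def lora_param_count(d_in, d_out, r):
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


per_layer = lora_param_count(q_in, q_out, rank) + lora_param_count(v_in, v_out, rank)
trainable = per_layer * adapted_layers

total_reported = 7241.732e6  # total parameter count printed in the log
print(f"{trainable / 1e6:.3f}M trainable")
print(f"{100 * trainable / total_reported:.3f}% of all parameters")
```

The agreement with the log is consistent with these assumptions but does not prove them. The adapter_config.json file in the adapters folder is the place to confirm the actual rank and target layers.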
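Evaluating the Fine-Tuned LLM
The LLM is now trained and the LoRA adapters have been created. These adapters can be used together with the original LLM to test the fine-tuned model. First, I tested the LLM with MLX using the --test flag. Then I asked the original LLM and the fine-tuned LLM the same question. This comparison shows how the fine-tuned LLM has been adapted to the network attack detection dataset. The fine-tuning results can be improved further by changing the prompt, the dataset, and other parameters.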
# test the llm with the test data
# --model - original model downloaded from Hugging Face
# --adapter-path adapters - location of the lora adapters
# --data data - data directory path containing test.jsonl
❯❯ python -m mlx_lm.lora \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--data data \
--test
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 54471.48it/s]
Loading datasets
Testing
Test loss 0.151, Test ppl 1.163.
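The two numbers on the test line are related: perplexity is the exponential of the average cross-entropy loss. The short check below is a sketch that assumes the reported loss is measured in nats, and it recovers the reported perplexity from the reported loss.

```python
import math

test_loss = 0.151  # value printed by the --test run

# Perplexity is exp(loss) when the loss is an average negative
# log-likelihood in nats.
perplexity = math.exp(test_loss)
print(f"{perplexity:.3f}")
```

A perplexity close to 1 means the model assigns high probability to the expected completion tokens. It says nothing by itself about how often the predicted label is correct.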
# first, ask the question to the original llm using mlx
# --model - original model downloaded from Hugging Face
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--prompt "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'."
# it gives a generic answer based on the original knowledge of the llm
Prompt: <s> [INST] You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. [/INST]
Based on the provided network traffic attributes, the traffic appears to be normal. Here's why:
1. Source IP address: The IP address 192.168.100.103 falls within a private IP address range, which is commonly used for internal networks. This is not an indicator of anomalous traffic.
2. Source port: The source port number 47525 is not associated with any known malicious activity. It is a high-numbered ephemeral port, which is commonly used for outgoing connections from applications.
3. Network protocol: The use of TCP (Transmission Control Protocol) is a normal and common protocol used for reliable data transfer between applications.
4. Duration of the connection: The connection duration of 2.998777 seconds is within the normal range for a typical connection.
5. Connection state: The connection state S0 indicates that the connection is still in the initial SYN (synchronize) stage, which is normal for the beginning of a connection.
6. Missed bytes and packets: The absence of missed bytes and packets indicates that all data was received as intended.
7. Number of packets and IP bytes: The number of packets and IP bytes sent and received is within a reasonable range for a typical connection.
Therefore, based on the given network traffic attributes, the traffic appears to be normal.
==========
Prompt: 348.322 tokens-per-sec
Generation: 20.630 tokens-per-sec
# the same question asked to the fine-tuned llm, using the adapter
# --model - original model downloaded from Hugging Face
# --max-tokens 500 - how many tokens to generate
# --adapter-path adapters - location of the lora adapters
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--max-tokens 500 \
--adapter-path adapters \
--prompt "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'."
# it gives a specific answer based on the dataset used to fine-tune the llm
Prompt: <s> [INST] You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. [/INST]
anomaly network traffic
==========
Prompt: 351.853 tokens-per-sec
Generation: 21.462 tokens-per-sec
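A single prompt shows the change in behavior, but measuring detection quality needs the generated answers for many test records compared against the expected completions. The sketch below shows one way to score such a comparison. The helper names, the prefix-matching rule, and the sample lists are all hypothetical, and producing the generated texts (for example by running the model over test.jsonl) is left out.

```python
def normalize_label(text):
    """Map generated text to 'anomaly', 'normal', or 'unknown'."""
    lowered = text.strip().lower()
    if lowered.startswith("anomaly"):
        return "anomaly"
    if lowered.startswith("normal"):
        return "normal"
    return "unknown"


def classification_metrics(expected, predicted, positive="anomaly"):
    """Compute accuracy, plus precision and recall for the positive label."""
    tp = fp = fn = correct = 0
    for exp_label, pred_label in zip(expected, predicted):
        if exp_label == pred_label:
            correct += 1
        if pred_label == positive and exp_label == positive:
            tp += 1
        elif pred_label == positive:
            fp += 1
        elif exp_label == positive:
            fn += 1
    return {
        "accuracy": correct / len(expected) if expected else 0.0,
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


# Hypothetical expected completions and model outputs.
expected_texts = [
    "anomaly network traffic",
    "normal network traffic",
    "anomaly network traffic",
    "normal network traffic",
]
generated_texts = [
    "anomaly network traffic",
    "anomaly network traffic",
    "anomaly network traffic",
    "Based on the provided attributes, the traffic appears to be normal.",
]

expected = [normalize_label(t) for t in expected_texts]
predicted = [normalize_label(t) for t in generated_texts]
metrics = classification_metrics(expected, predicted)
print(metrics)
```

The prefix rule only recognizes answers in the short fine-tuned style, so a long free-form answer like the base model's is counted as unknown. For a detector, recall on the anomaly label is usually the number to watch, since a missed attack costs more than a false alarm.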
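Building a New Model with Fused Adapters
Once fine-tuning is complete, I can merge the adjustments learned during fine-tuning into the weights of the existing model, a process known as fusing. Technically, this updates the weights and parameters of the pretrained base model to incorporate the improvements from fine-tuning. In practice, I fuse the LoRA adapter files back into the base model.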
# --model - original model downloaded from hugging face
# --adapter-path adapters - location of the lora adapters
# --save-path models/effectz-attack - new model path
# --de-quantize - use this flag if you want to convert the model to GGUF format later
❯❯ python -m mlx_lm.fuse \
--model mistralai/Mistral-7B-Instruct-v0.2 \
--adapter-path adapters \
--save-path models/effectz-attack \
--de-quantize
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 19929.74it/s]
De-quantizing model
# new model generated in the models directory
❯❯ tree models
models
└── effectz-attack
├── config.json
├── model-00001-of-00003.safetensors
├── model-00002-of-00003.safetensors
├── model-00003-of-00003.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer.model
└── tokenizer_config.json
# now i can query the new model directly, without the adapter
# --model models/effectz-attack - new model path
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
--model models/effectz-attack \
--max-tokens 500 \
--prompt "You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. "
# output
Prompt: <s> [INST] You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'. [/INST]
anomaly network traffic
==========
Prompt: 346.093 tokens-per-sec
Generation: 23.113 tokens-per-sec
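Conceptually, the fusing step in this section replaces each adapted weight matrix W with W + scale · B · A, where A and B are the low-rank LoRA matrices. The toy sketch below uses plain Python lists and made-up 3x4 matrices (not the real Mistral weights, and not the actual `mlx_lm.fuse` code) to show that the fused matrix gives the same output as running the base weight and the adapter side by side.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def fuse_lora(w, a, b, scale):
    """Return W + scale * (B @ A), the merged weight matrix."""
    delta = matmul(b, a)
    return [[wv + scale * dv for wv, dv in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


# Toy example: W is 3x4 (out x in), LoRA rank r = 2. All values are illustrative.
W = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.0, 0.0, 0.0], [0.0, -2.0, 1.0, 3.0]]
A = [[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0]]  # r x in
B = [[0.5, 0.0], [0.0, 1.0], [1.0, -1.0]]          # out x r
scale = 2.0
x = [[1.0], [2.0], [3.0], [4.0]]                    # input column vector

# Path 1: fuse first, then a single matrix multiply.
W_fused = fuse_lora(W, A, B, scale)
y_fused = matmul(W_fused, x)

# Path 2: base output plus the scaled adapter output, as used before fusing.
base_out = matmul(W, x)
adapter_out = matmul(B, matmul(A, x))
y_adapter = [[p[0] + scale * q[0]] for p, q in zip(base_out, adapter_out)]

print(y_fused, y_adapter)
```

Because the two paths agree, the fused model no longer needs `--adapter-path` at inference time, which is why the `mlx_lm.generate` call above works with only `--model models/effectz-attack`.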
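Creating the GGUF Model
I want to run this newly created model with Ollama, a lightweight and flexible framework for deploying LLMs locally on a personal computer. To run the fused model in Ollama, I first need to convert it into a GGUF file. GGUF is the binary model file format that came out of the llama.cpp project, and it is the format Ollama loads. For the conversion I use llama.cpp itself, an open-source C++ library that runs inference for a wide range of LLMs. Here is how to convert the model to GGUF so that it can be built into an Ollama model.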
# clone llama.cpp into the same location as the mlxa repo
❯❯ git clone https://github.com/ggerganov/llama.cpp.git
# directory structure where llama.cpp and mlxa sit side by side
❯❯ ls
llama.cpp
mlxa
# set up a virtual environment in llama.cpp and install the required packages
❯❯ cd llama.cpp
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
❯❯ pip install -r requirements.txt
# llama.cpp contains a script `convert-hf-to-gguf.py` to convert a hugging face model to gguf
# (newer llama.cpp checkouts name this script `convert_hf_to_gguf.py`)
❯❯ ls convert-hf-to-gguf.py
convert-hf-to-gguf.py
# convert the newly generated model (in mlxa/models/effectz-attack) to gguf
# --outfile ../mlxa/models/effectz-attack.gguf - output gguf model file path
# --outtype q8_0 - 8-bit quantization, which shrinks the file and usually speeds up inference
# to trade quality against speed, try other outtype values (e.g. omit --outtype)
❯❯ python convert-hf-to-gguf.py ../mlxa/models/effectz-attack \
--outfile ../mlxa/models/effectz-attack.gguf \
--outtype q8_0
# output
INFO:hf-to-gguf:Loading model: effectz-attack
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Set model parameters
INFO:hf-to-gguf:gguf: context length = 32768
INFO:hf-to-gguf:gguf: embedding length = 4096
INFO:hf-to-gguf:gguf: feed forward length = 14336
INFO:hf-to-gguf:gguf: head count = 32
INFO:hf-to-gguf:gguf: key-value head count = 8
INFO:hf-to-gguf:gguf: rope theta = 1000000.0
INFO:hf-to-gguf:gguf: rms norm epsilon = 1e-05
INFO:hf-to-gguf:gguf: file type = 7
INFO:hf-to-gguf:Set model tokenizer
INFO:gguf.vocab:Setting special token type bos to 1
INFO:gguf.vocab:Setting special token type eos to 2
INFO:gguf.vocab:Setting special token type unk to 0
INFO:gguf.vocab:Setting add_bos_token to True
INFO:gguf.vocab:Setting add_eos_token to False
---
INFO:hf-to-gguf:blk.29.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.30.attn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.30.ffn_down.weight, torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.30.ffn_gate.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.30.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.30.ffn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.30.attn_k.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.30.attn_output.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_q.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.ffn_down.weight, torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.31.ffn_gate.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_up.weight, torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_norm.weight, torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.attn_k.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_output.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_q.weight, torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_v.weight, torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:output_norm.weight, torch.bfloat16 --> F32, shape = {4096}
Writing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.70G/7.70G [01:21<00:00, 94.9Mbyte/s]
INFO:hf-to-gguf:Model successfully exported to '../mlxa/models/effectz-attack.gguf'
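The 7.70G written by the converter can be sanity-checked with a little arithmetic. The sketch below rebuilds the parameter count from the tensor shapes printed in the log and applies the Q8_0 layout (blocks of 32 int8 weights plus one 2-byte scale, so 34 bytes per 32 weights). The vocabulary size of 32000 and the assumption that the embedding and output matrices are also stored as Q8_0 are mine and do not appear in the log excerpt; metadata and tokenizer data are ignored.

```python
# Shapes taken from the converter log above.
EMBED = 4096      # embedding length
FFN = 14336       # feed forward length
KV_DIM = 1024     # second dimension of attn_k / attn_v
N_BLOCKS = 32     # blk.0 .. blk.31
VOCAB_SIZE = 32000  # assumption: not shown in the log excerpt

per_block_matrix_params = (
    2 * EMBED * EMBED      # attn_q, attn_output
    + 2 * EMBED * KV_DIM   # attn_k, attn_v
    + 3 * EMBED * FFN      # ffn_gate, ffn_up, ffn_down
)
# Token embedding and output projection are assumed to be VOCAB_SIZE x EMBED each.
matrix_params = N_BLOCKS * per_block_matrix_params + 2 * VOCAB_SIZE * EMBED
# Two norm vectors per block plus the final output_norm, kept as F32 in the log.
norm_params = (2 * N_BLOCKS + 1) * EMBED

Q8_0_BLOCK_BYTES = 34    # 32 int8 weights + one 2-byte fp16 scale
Q8_0_BLOCK_WEIGHTS = 32
F32_BYTES = 4

estimated_bytes = (
    matrix_params * Q8_0_BLOCK_BYTES // Q8_0_BLOCK_WEIGHTS
    + norm_params * F32_BYTES
)
print(f"parameters: {matrix_params + norm_params:,}")
print(f"estimated q8_0 size: {estimated_bytes / 1e9:.2f} GB")
```

The estimate lands very close to the size reported by the converter, which suggests the quantized file contains what I expect: roughly 8.5 bits per weight for the matrices and full precision for the small norm vectors.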
# new gguf model generated in mlxa/models
❯❯ cd ../mlxa
❯❯ ls models/effectz-attack.gguf
models/effectz-attack.gguf
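Before handing the file to Ollama, it can be worth confirming that it really is a GGUF file and not a truncated write. A GGUF file starts with the four magic bytes `GGUF`, followed by a little-endian version number and two counters. The helper below is a small sketch that reads only this header; it assumes GGUF version 2 or later, where both counters are 64-bit, and the function name is my own.

```python
import struct


def read_gguf_header(path):
    """Return (version, tensor_count, metadata_kv_count) from a GGUF file header.

    Assumes GGUF v2+ layout: 4-byte magic, uint32 version, two uint64 counters.
    """
    with open(path, "rb") as f:
        header = f.read(24)
    if len(header) < 24 or header[:4] != b"GGUF":
        raise ValueError(f"{path} does not look like a GGUF file")
    version, tensor_count, kv_count = struct.unpack("<IQQ", header[4:24])
    return version, tensor_count, kv_count


# Usage (path from this article, run from the mlxa directory):
# print(read_gguf_header("models/effectz-attack.gguf"))
```

A `ValueError` here usually means the conversion was interrupted or the wrong file was passed, which is cheaper to find out now than after `ollama create` has copied several gigabytes.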
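Building and Running the Ollama Model
Now I can use the GGUF model file, effectz-attack.gguf, to write an Ollama Modelfile and build an Ollama model from it. A Modelfile is the configuration file Ollama uses to define and manage a model. Here is how to create the Modelfile and generate the new Ollama model.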
# create a file named `Modelfile` in the models directory with the following content
❯❯ cat models/Modelfile
FROM ./effectz-attack.gguf
# create ollama model
❯❯ ollama create effectz-attack -f models/Modelfile
transferring model data 100%
using existing layer sha256:189106bcbbfa1942e25555311d0097b1b2604d4a44416ae90436763ea8e17886
creating new layer sha256:633247d6e759ef95ed88aba596b6882716cdf3888236ab14762070b42126f039
writing manifest
success
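The Modelfile above contains only a `FROM` line. Modelfiles also accept `TEMPLATE` and `PARAMETER` instructions, and for a classifier it may help to pin the `[INST] ... [/INST]` wrapper that `mlx_lm.generate` printed earlier and to make decoding deterministic. The sketch below renders such a file from Python. The helper name, the template string, and the parameter values are my assumptions rather than settings used in this article, and whether an explicit template is needed can depend on the Ollama version.

```python
def render_modelfile(gguf_path, template=None, parameters=None):
    """Build the text of an Ollama Modelfile from a GGUF path and optional settings."""
    lines = [f"FROM {gguf_path}"]
    if template is not None:
        # Triple quotes let the template span multiple lines inside the Modelfile.
        lines.append(f'TEMPLATE """{template}"""')
    for name, value in (parameters or {}).items():
        lines.append(f"PARAMETER {name} {value}")
    return "\n".join(lines) + "\n"


modelfile = render_modelfile(
    "./effectz-attack.gguf",
    template="[INST] {{ .Prompt }} [/INST]",       # assumed Mistral-style wrapper
    parameters={"temperature": 0, "num_predict": 16},  # illustrative values
)
print(modelfile)
```

Writing the returned text to `models/Modelfile` and re-running `ollama create` would rebuild the model with these settings; the short `num_predict` limit reflects that the expected answer is only a few words long.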
# list ollama models
# effectz-attack:latest is the newly created model
❯❯ ollama ls
NAME ID SIZE MODIFIED
effectz-attack:latest 7aa76a0412bc 7.7 GB 44 seconds ago
effectz-predict:latest e87d93dff4c4 14 GB 4 weeks ago
effectz-sql:latest 736275f4faa4 7.7 GB 5 months ago
rahasak-sql:latest e41d278330ed 7.7 GB 5 months ago
mistral:latest 2ae6f6dd7a3d 4.1 GB 5 months ago
llama3:latest a6990ed6be41 4.7 GB 7 months ago
llama3:8b a6990ed6be41 4.7 GB 7 months ago
llama2:latest 78e26419b446 3.8 GB 8 months ago
llama2:13b d475bf4c50bc 7.4 GB 8 months ago
# run the model with ollama and ask it to classify the network traffic
# it answers based on the custom dataset used to fine-tune the model
❯❯ ollama run effectz-attack
>>> You are an expert in network traffic classification. Based on the provided network traffic attributes, you must determine whether the traffic is normal or anomaly. Here are the attributes, 'source ip address: 192.168.100.103, source port: 47525, network protocol: tcp, duration of the connection: 2.998777, connection state: S0, missed_bytes: 0, number of packets sent from source to destination: 3, number of ip bytes sent from source to destination: 180, number of packets sent from destination to source: 0, number of ip bytes sent from destination to source: 0'.
anomaly network traffic
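For real-time detection the model has to be called from code rather than from the interactive prompt. Ollama exposes a local REST API, and the sketch below posts a prompt to its `/api/generate` endpoint and reduces the reply to a label. It assumes Ollama is listening on its default port 11434 and that the model is named `effectz-attack` as above; `parse_label` and `classify` are names I made up for illustration.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # assumed default local endpoint


def parse_label(text):
    """Map the model's free-form answer to 'anomaly', 'normal', or 'unknown'."""
    lowered = text.lower()
    if "anomaly" in lowered:
        return "anomaly"
    if "normal" in lowered:
        return "normal"
    return "unknown"


def classify(prompt, model="effectz-attack", url=OLLAMA_URL, timeout=60):
    """Send one prompt to Ollama and return the parsed label."""
    payload = json.dumps({"model": model, "prompt": prompt, "stream": False})
    request = urllib.request.Request(
        url,
        data=payload.encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    return parse_label(body["response"])


# Usage, with Ollama running locally and `prompt` holding the text shown above:
# print(classify(prompt))
```

The keyword matching is deliberately simple and relies on the fine-tuned model answering with a short label such as "anomaly network traffic"; a longer, free-form answer that mentions both words would need stricter parsing.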