利用Apple MLX微调LLM进行医学诊断预测

2024年11月12日 由 alex 发表 283 0

介绍

我之前讨论了微调大型语言模型(LLM)的基础知识以及使用Apple MLX框架的特定用例,包括如何构建LLM的自定义版本。在本文中,我将探索一个更高级的LLM微调用例,即用于医疗诊断预测。在这里,我使用包含疾病和症状自然语言描述的医疗数据集对mistralai/Mistral-7B-Instruct-v0.2进行了微调,使模型能够根据输入的症状预测可能的诊断。微调过程是在配备M2芯片的Apple Silicon Mac上,使用LoRA和Apple MLX框架完成的。微调后,这些自定义模型通过Ollama运行。


LLM微调

微调是调整预训练大型语言模型(LLM)的参数或权重,以使其专用于特定任务或领域的过程。虽然像GPT这样的预训练语言模型拥有广泛的一般语言知识,但它们往往缺乏特定领域的专业知识。微调通过用领域特定数据训练模型来克服这一缺陷,从而提高模型在目标应用中的准确性和有效性。这个过程涉及让模型接触任务特定的示例,使其能够更深入地理解领域的细微差别。这一关键步骤将通用语言模型转变为专用工具,从而释放LLM在特定领域或应用中的全部潜力。然而,微调LLM需要大量的计算资源,如GPU,以确保高效训练。


有多种LLM微调技术可供选择,包括低秩适配器(LoRA)、量化LoRA(QLoRA)、参数高效微调(PEFT)、DeepSpeed和ZeRO。在本帖中,我将讨论在Apple MLX框架内使用LoRA技术对LLM进行微调。LoRA由微软的研究团队于2021年首次提出,它提供了一种参数高效的微调方法。与需要微调整个基础模型(这可能非常庞大且成本高昂)的传统方法不同,LoRA在保持原始模型参数冻结的同时,只增加少量可训练参数。


LoRA的核心在于向模型中添加适配器层,从而提高其效率和适应性。LoRA不是引入全新的层,而是通过引入低秩矩阵来修改现有层的行为。这种方法引入的额外参数极少,因此与完整的模型重新训练相比,显著降低了计算开销和内存使用量。通过将适应集中在特定的模型组件上,LoRA保留了原始权重中嵌入的基础知识,从而最小化了灾难性遗忘的风险。这种有针对性的适应不仅保持了模型的通用能力,还实现了快速迭代和任务特定的增强,使LoRA成为微调大型预训练模型的灵活且可扩展的解决方案。


RAG与LLM微调的比较

检索增强生成(RAG)通过为LLM提供一个精选的数据库,使其能够动态检索相关信息以生成响应,从而增强LLM的功能。相比之下,微调涉及通过在特定的标记数据集上训练模型来调整其参数,以提高其在特定任务上的性能。微调会修改模型本身,而RAG则扩展了模型可以访问的数据。


当你需要用未在模型初始训练时可用的数据来补充语言模型的提示时,请使用RAG。这可以包括实时数据、用户特定数据或与提示相关的上下文信息。RAG非常适合确保模型能够访问最新和最相关的数据。另一方面,微调对于训练模型以更准确地理解和执行特定任务而言是最佳的。


32


使用Apple MLX进行LLM微调

长久以来,人们一直认为机器学习(ML)的训练和推理只能在Nvidia GPU上进行。然而,随着ML框架MLX的发布,这一观念发生了改变。MLX使得在Apple Silicon CPU/GPU上进行ML训练和推理成为可能。由Apple开发的MLX库类似于TensorFlow和PyTorch,支持GPU支持的任务。该库允许在新款Apple Silicon(M系列)芯片上对大型语言模型(LLM)进行微调。此外,MLX还支持使用LoRA方法进行LLM微调。我已经成功使用MLX和LoRA微调了多个LLM,包括Llama-3和Mistral。


应用场景

在这篇文章中,我将探索一个更高级的LLM微调应用场景——医学诊断预测。我将使用包含疾病和症状自然语言描述的医学数据集,对mistralai/Mistral-7B-Instruct-v0.2进行微调。通过用此数据集训练mistralai/Mistral-7B-Instruct-v0.2模型,模型将学会根据输入的症状预测可能的诊断。


设置MLX和其他工具

首先,我需要安装MLX以及一系列所需工具。以下是我已安装的工具列表,以及我如何设置和配置MLX环境的说明。


# used repository called mlxm
❯❯ git clone https://gitlab.com/rahasak-labs/mlxm.git
❯❯ cd mlxm

# create and activate virtial enviroument 
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate

# install mlx
❯❯ pip install -U mlx-lm

# install other requried pythong pakcages
❯❯ pip install pandas
❯❯ pip install pyarrow


设置Huggingface-CLI

我从Hugging Face获取大型语言模型(基础模型)和数据集。为此,我需要在Hugging Face上设置一个账户,并配置huggingface-cli命令行工具。


# setup account in hugging-face from here
https://huggingface.co/welcome

# create access token to read/write data from hugging-face through the cli
# this token required when login to huggingface cli
https://huggingface.co/settings/tokens

# setup hugginface-cli
❯❯ pip install huggingface_hub
❯❯ pip install "huggingface_hub[cli]"

# login to huggingface through cli
# it will ask the access token previously created 
❯❯ huggingface-cli login
    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|
    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/lambda.eranga/.cache/huggingface/token
Login successful

# once login the tokne will be saved in the ~/.cache/huggingface
❯❯ ls ~/.cache/huggingface
datasets 
hub      
token


准备数据集

MLX要求数据以特定的格式呈现。在MLX中,主要讨论了三种格式:聊天、补全和文本。对于本用例,我将使用文本格式的数据,该格式将上下文、问题和回答等类似信息组合成一个单一的自然语言短语。这种格式要求生成一个包含所有相关信息(上下文、问题和回答)的单一文本字段的数据集。数据集的生成在LLM的微调中起着至关重要的作用,因为它直接影响微调模型的准确性。可以采用各种技术来生成用于微调LLM的数据集。例如,这篇文章讨论了如何使用带有提示工程的LLM来生成数据集。


原始数据集是以.csv文件的形式结构化的。为了使用Apple MLX进行微调,我将其转换为了基于文本的格式。每条文本记录都将上下文、问题和回答信息组合成一个连贯的自然语言短语。


# original csv record
label,text
Psoriasis,"I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches."
# converted text type record
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Psoriasis."}


MLX需要三组数据集来进行微调:训练集、测试集和验证集。数据文件应采用JSONL格式。下面的脚本将CSV文件中的数据转换为JSONL格式。在转换过程中,它将提示、问题和回答数据组合成一个单一的自然语言文本短语,将所有元素无缝地捕捉到一个文本字段中。


import pandas as pd
import json
import random
# load csv data
file_path = './s2d.csv'
df = pd.read_csv(file_path)
# create text type data
jsonl_data = []
for _, row in df.iterrows():
    diagnosis = row['label']
    symptoms = row['text']
    prompt = f"You are a medical diagnosis expert. You will give patient symptoms: '{symptoms}'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with {diagnosis}."
    jsonl_data.append({"text": prompt})
# shuffle the data
random.shuffle(jsonl_data)
# calculate split indices
total_records = len(jsonl_data)
train_split = int(total_records * 2 / 3)
test_split = int(total_records * 1 / 6)
# split the data
train_data = jsonl_data[:train_split]
test_data = jsonl_data[train_split:train_split + test_split]
valid_data = jsonl_data[train_split + test_split:]
# write to JSONL files
with open('train.jsonl', 'w') as train_file:
    for entry in train_data:
        train_file.write(json.dumps(entry) + '\n')
with open('test.jsonl', 'w') as test_file:
    for entry in test_data:
        test_file.write(json.dumps(entry) + '\n')
with open('valid.jsonl', 'w') as valid_file:
    for entry in valid_data:
        valid_file.write(json.dumps(entry) + '\n')
print("data successfully saved to train.jsonl, test.jsonl, and valid.jsonl")


我已将数据集放置在数据目录中。随后,我运行了脚本以生成训练集、测试集和验证集的JSONL格式文件。以下是生成的数据文件的结构。


# activate virtual env
❯❯ source .venv/bin/activate

# data directory
❯❯ ls -al data
prepare.py
s2d.csv

# generate jsonl files
❯❯ cd data
❯❯ python prepare.py

# generated files
❯❯ ls -ls
test.jsonl
train.jsonl
valid.jsonl
# train.jsonl
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing fatigue, difficulty walking, diarrhea, night sweats, tremors.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Influenza."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing difficulty breathing, weight loss, fever.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Pneumonia."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing dizziness, dry skin, rapid heartbeat, shortness of breath, vision changes.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Mumps."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing vomiting, shortness of breath, night sweats, rapid heartbeat.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Hepatitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing loss of appetite, tremors, fatigue, difficulty breathing, increased urination.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Scoliosis."}
# test.jsonl
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing difficulty walking, fever, loss of appetite, fatigue, shortness of breath, sore throat.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Depression."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing fever, rapid heartbeat, hair loss.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Asthma."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing hair loss, dizziness, rapid heartbeat.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Hepatitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing skin rash, abdominal pain, difficulty breathing.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Anxiety."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have lots of itchy spots on my skin, and sometimes they turn red or bumpy. There are also some weird patches that are different colors than the rest of my skin, and sometimes I get these weird bumps that look like little balls.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Fungal infection."}
# valid.jsonl
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing dizziness, increased urination, muscle pain.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Bronchitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing sore throat, rapid heartbeat, hair loss, loss of appetite.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Bronchitis."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing skin rash, vomiting, muscle pain, joint pain.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Gout."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been enduring frequent headaches, blurred vision, excessive appetite, a sore neck, anxiety, irritability, and digestive difficulties including indigestion and acid reflux.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Migraine."}
{"text": "You are a medical diagnosis expert. You will give patient symptoms: 'I have been experiencing weight loss, vision changes, vomiting.'. Question: 'What is the diagnosis I have?'. Response: You may be diagnosed with Tonsillitis."}


微调/训练LLM

下一步是使用我之前准备的数据集,通过MLX对Mistral-7B LLM进行微调。首先,我使用huggingface-cli从Hugging Face下载了mistralai/Mistral-7B-Instruct-v0.2 LLM。然后,我使用提供的数据集和LoRA(Low-Rank Adaptation,低秩适应)方法训练了LLM。LoRA方法通过引入低秩矩阵来调整模型的行为,而无需进行大量的重新训练,从而在保留原始模型参数的同时,实现高效、有针对性的适应。


在配备64GB RAM和30个GPU的Mac M2上,训练LLM并生成必要的适配器大约需要25分钟。


# download llm
❯❯ huggingface-cli download mistralai/Mistral-7B-Instruct-v0.2
/Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2/snapshots/1296dc8fd9b21e6424c9c305c06db9ae60c03ace

# model is downloaded into ~/.cache/huggingface/hub/
❯❯ ls ~/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
blobs     refs      snapshots

# list all downloaded models from huggingface
❯❯ huggingface-cli scan-cache
REPO ID                                REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
-------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------
BAAI/bge-reranker-base                 model             1.1G        6 3 months ago  4 months ago  main /Users/lambda.eranga/.cache/huggingface/hub/models--BAAI--bge-reranker-base
NousResearch/Meta-Llama-3-8B           model            16.1G       14 2 months ago  5 months ago  main /Users/lambda.eranga/.cache/huggingface/hub/models--NousResearch--Meta-Llama-3-8B
gpt2                                   model             2.9M        5 8 months ago  8 months ago  main /Users/lambda.eranga/.cache/huggingface/hub/models--gpt2
infgrad/stella_en_1.5B_v5              model           240.7K        6 4 months ago  4 months ago  main /Users/lambda.eranga/.cache/huggingface/hub/models--infgrad--stella_en_1.5B_v5
mistralai/Mistral-7B-Instruct-v0.2     model            29.5G       21 1 day ago     1 day ago     main /Users/lambda.eranga/.cache/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.2
sentence-transformers/all-MiniLM-L6-v2 model            91.6M       11 8 months ago  8 months ago  main /Users/lambda.eranga/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2
Done in 0.0s. Scanned 6 repo(s) for a total of 46.8G.
Got 1 warning(s) while scanning. Use -vvv to print details

# fine-tune llm 
# --model - original model which download from huggin face
# --data data - data directory path with train.jsonl
# --batch-size 4 - batch size
# --lora-layers 16 - number of lora layers
# --iters 1000 - tranning iterations
❯❯ python -m mlx_lm.lora \
  --model mistralai/Mistral-7B-Instruct-v0.2 \
  --data data \
  --train \
  --batch-size 4\
  --lora-layers 16\
  --iters 1000
# following is the tranning output
# when tranning is started, the initial validation loss is 1.939 and tranning loss is 1.908
# once is tranning finished, validation loss is 0.548 and tranning loss is is 0.534
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 23515.47it/s]
Loading datasets
Training
Trainable parameters: 0.024% (1.704M/7241.732M)
Starting training..., iters: 1000
Iter 1: Val loss 3.846, Val took 16.852s
Iter 10: Train loss 3.264, Learning Rate 1.000e-05, It/sec 1.008, Tokens/sec 265.250, Trained Tokens 2631, Peak mem 15.193 GB
Iter 20: Train loss 1.790, Learning Rate 1.000e-05, It/sec 1.125, Tokens/sec 270.366, Trained Tokens 5034, Peak mem 15.193 GB
Iter 30: Train loss 1.199, Learning Rate 1.000e-05, It/sec 1.024, Tokens/sec 266.239, Trained Tokens 7634, Peak mem 15.193 GB
Iter 40: Train loss 0.815, Learning Rate 1.000e-05, It/sec 1.166, Tokens/sec 271.254, Trained Tokens 9960, Peak mem 15.193 GB
Iter 50: Train loss 0.985, Learning Rate 1.000e-05, It/sec 1.074, Tokens/sec 265.638, Trained Tokens 12433, Peak mem 15.193 GB
Iter 60: Train loss 0.907, Learning Rate 1.000e-05, It/sec 1.032, Tokens/sec 265.612, Trained Tokens 15007, Peak mem 15.193 GB
Iter 70: Train loss 0.912, Learning Rate 1.000e-05, It/sec 1.038, Tokens/sec 266.206, Trained Tokens 17571, Peak mem 15.193 GB
Iter 80: Train loss 0.995, Learning Rate 1.000e-05, It/sec 1.015, Tokens/sec 262.973, Trained Tokens 20162, Peak mem 15.193 GB
Iter 90: Train loss 0.669, Learning Rate 1.000e-05, It/sec 1.170, Tokens/sec 269.866, Trained Tokens 22469, Peak mem 15.193 GB
---
Iter 870: Train loss 0.614, Learning Rate 1.000e-05, It/sec 1.105, Tokens/sec 261.226, Trained Tokens 219852, Peak mem 15.319 GB
Iter 880: Train loss 0.831, Learning Rate 1.000e-05, It/sec 0.956, Tokens/sec 269.099, Trained Tokens 222667, Peak mem 15.319 GB
Iter 890: Train loss 0.734, Learning Rate 1.000e-05, It/sec 1.081, Tokens/sec 268.167, Trained Tokens 225148, Peak mem 15.319 GB
Iter 900: Train loss 0.747, Learning Rate 1.000e-05, It/sec 1.037, Tokens/sec 276.490, Trained Tokens 227815, Peak mem 15.319 GB
Iter 900: Saved adapter weights to adapters/adapters.safetensors and adapters/0000900_adapters.safetensors.
Iter 910: Train loss 0.732, Learning Rate 1.000e-05, It/sec 1.080, Tokens/sec 273.424, Trained Tokens 230346, Peak mem 15.319 GB
Iter 920: Train loss 0.790, Learning Rate 1.000e-05, It/sec 1.036, Tokens/sec 269.653, Trained Tokens 232950, Peak mem 15.319 GB
Iter 930: Train loss 0.868, Learning Rate 1.000e-05, It/sec 0.948, Tokens/sec 265.881, Trained Tokens 235754, Peak mem 15.319 GB
Iter 940: Train loss 0.631, Learning Rate 1.000e-05, It/sec 1.135, Tokens/sec 266.720, Trained Tokens 238103, Peak mem 15.319 GB
Iter 950: Train loss 0.689, Learning Rate 1.000e-05, It/sec 1.000, Tokens/sec 253.929, Trained Tokens 240643, Peak mem 15.319 GB
Iter 960: Train loss 0.834, Learning Rate 1.000e-05, It/sec 0.955, Tokens/sec 270.962, Trained Tokens 243479, Peak mem 15.319 GB
Iter 970: Train loss 0.762, Learning Rate 1.000e-05, It/sec 0.999, Tokens/sec 270.069, Trained Tokens 246182, Peak mem 15.319 GB
Iter 980: Train loss 0.605, Learning Rate 1.000e-05, It/sec 1.043, Tokens/sec 257.387, Trained Tokens 248650, Peak mem 15.319 GB
Iter 990: Train loss 0.656, Learning Rate 1.000e-05, It/sec 1.136, Tokens/sec 268.814, Trained Tokens 251016, Peak mem 15.319 GB
Iter 1000: Val loss 0.795, Val took 15.497s
Iter 1000: Train loss 0.665, Learning Rate 1.000e-05, It/sec 12.319, Tokens/sec 3157.469, Trained Tokens 253579, Peak mem 15.319 GB
Iter 1000: Saved adapter weights to adapters/adapters.safetensors and adapters/0001000_adapters.safetensors.
Saved final adapter weights to adapters/adapters.safetensors.

# gpu usage while tranning
❯❯ sudo powermetrics --samplers gpu_power -i500 -n1
Machine model: Mac14,6
OS version: 23F79
Boot arguments:
Boot time: Tue Oct  8 08:36:47 2024
*** Sampled system activity (Sun Nov 10 15:19:00 2024 -0500) (502.79ms elapsed) ***
**** GPU usage ****
GPU HW active frequency: 1397 MHz
GPU HW active residency:  98.61% (444 MHz: .05% 612 MHz:   0% 808 MHz:   0% 968 MHz:   0% 1110 MHz:   0% 1236 MHz:   0% 1338 MHz:   0% 1398 MHz:  99%)
GPU SW requested state: (P1 :   0% P2 :   0% P3 :   0% P4 :   0% P5 :   0% P6 :   0% P7 :   0% P8 : 100%)
GPU SW state: (SW_P1 :   0% SW_P2 :   0% SW_P3 :   0% SW_P4 :   0% SW_P5 :   0% SW_P6 :   0% SW_P7 :   0% SW_P8 :   0%)
GPU idle residency:   1.39%
GPU Power: 44541 mW

# end of the tranning the LoRA adapters generated into the adapters folder
❯❯ ls adapters
0000100_adapters.safetensors 
0000300_adapters.safetensors 
0000500_adapters.safetensors 
0000700_adapters.safetensors 
0000900_adapters.safetensors 
adapter_config.json
0000200_adapters.safetensors 
0000400_adapters.safetensors 
0000600_adapters.safetensors 
0000800_adapters.safetensors 
0001000_adapters.safetensors 
adapters.safetensors


评估微调后的LLM

LLM现在已经训练完成,并且已经创建了LoRA适配器。我们可以将这些适配器与原始的LLM一起使用,以测试微调后的LLM的功能。最初,我使用--train参数通过MLX测试了LLM。随后,我向原始的LLM和微调后的LLM提出了相同的问题。通过这种对比,我们可以看到基于提供的医疗诊断预测用例数据集,微调后的LLM是如何得到优化的。通过修改提示、数据集和其他参数等,可以进一步提高微调过程的效果。


# test the llm with the test data
# --model - original model which download from huggin face
# --adapter-path adapters - location of the lora adapters
# --data data - data directory path with test.jsonl
❯❯ python -m mlx_lm.lora \
  --model mistralai/Mistral-7B-Instruct-v0.2 \
  --adapter-path adapters \
  --data data \
  --test
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 37418.77it/s]
Loading datasets
Testing
Test loss 0.754, Test ppl 2.125.

# first ask the question from original llm using mlx
# --model - original model which download from huggin face
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
    --model mistralai/Mistral-7B-Instruct-v0.2 \
    --max-tokens 500 \
    --prompt "You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?"
# it provide genearic answer based on the original knowlege base of the llm
Prompt: <s> [INST] You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have? [/INST]
Based on the symptoms you have provided, there are several possible conditions that could be causing memory loss and stiffness. Here are some possibilities:
1. Alzheimer`s disease: Memory loss is a common symptom of Alzheimer`s disease, which is a progressive brain disorder that causes a gradual decline in cognitive abilities. Stiffness can also be a symptom, particularly in the advanced stages of the disease.
2. Parkinson`s disease: Parkinson`s disease is a neurological disorder that affects movement. One of its primary symptoms is stiffness or rigidity in the muscles. Memory loss is also a common symptom, particularly in the later stages of the disease.
3. Rheumatoid arthritis: This is a chronic inflammatory disorder that affects the joints and other body systems. Stiffness is a hallmark symptom, particularly in the morning or after a long period of inactivity. Memory loss is not a common symptom of rheumatoid arthritis, but it can occur in some cases.
4. Multiple sclerosis: Multiple sclerosis is a chronic autoimmune disorder that affects the central nervous system. Memory loss and stiffness are both possible symptoms. Memory loss can take the form of difficulty with short-term memory, learning new information, or recalling old information.
5. Depression: Depression can cause memory loss and cognitive impairment, particularly in older adults. Stiffness is not a common symptom of depression, but it can occur in some cases, particularly if the person is physically inactive due to depressive symptoms.
It`s important to note that these are just possibilities based on the symptoms you have provided. A proper diagnosis can only be made through a thorough medical evaluation by a healthcare professional. If you are concerned about your symptoms, I would recommend making an appointment with your doctor to discuss them further.
==========
Prompt: 135.174 tokens-per-sec
Generation: 22.959 tokens-per-sec

# same question asked from fine-tunneld llm with usingn adapter
# --model - original model which download from huggin face
# --max-tokens 500 - how many tokens to generate
# --adapter-path adapters - location of the lora adapters
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
    --model mistralai/Mistral-7B-Instruct-v0.2 \
    --max-tokens 500 \
    --adapter-path adapters \
    --prompt "You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?"
# it provide specific answer based on the dataset used to fine-tune the llm
Prompt: <s> [INST] You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have? [/INST]
 You indicated that you have been experiencing memory loss and stiffness.dhdxmême¶ Memory loss is a common symptom of Parkinson`s Disease, which can also manifest as difficulty with thinking and problem-solving. Stiffness, also known as rigidity, can also be a symptom of Parkinson`s Disease. You should consult a healthcare professional for a proper diagnosis. Other symptoms of Parkinson`s Disease include tremors, difficulty walking, and loss of balance. If you are diagnosed with Parkinson`s Disease, there are treatments available that can help manage your symptoms. It is important to seek medical advice as soon as possible to begin treatment and slow the progression of the disease.
==========
Prompt: 126.813 tokens-per-sec
Generation: 21.583 tokens-per-sec


使用融合适配器构建新模型

完成微调后,我可以将新模型学习到的调整与现有模型的权重进行合并,这一过程称为融合。从技术上讲,这涉及更新预训练/基础模型的权重和参数,以纳入微调模型的改进。基本上,我可以继续将LoRA适配器文件融合回基础模型中。


在完成微调过程后,我可以将新模型学习到的调整与现有模型的权重进行合并,这一过程被称为融合。技术上,这涉及更新预训练/基础模型的权重和参数,以吸收微调模型的改进。本质上,我可以着手将LoRA适配器融合回基础模型中。


# --model - original model which download from huggin face
# --adapter-path adapters - location of the lora adapters
# --save-path models/effectz-predict - new model path
# --de-quantize - use this flag if you want convert the model GGUF format later
❯❯ python -m mlx_lm.fuse \
    --model mistralai/Mistral-7B-Instruct-v0.2 \
    --adapter-path adapters \
    --save-path models/effectz-predict \
    --de-quantize
# output
Loading pretrained model
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 19136.19it/s]
De-quantizing model

# new model generatd in the models directory
❯❯ tree models
models
└── effectz-predict
    ├── config.json
    ├── model-00001-of-00003.safetensors
    ├── model-00002-of-00003.safetensors
    ├── model-00003-of-00003.safetensors
    ├── model.safetensors.index.json
    ├── special_tokens_map.json
    ├── tokenizer.json
    ├── tokenizer.model
    └── tokenizer_config.json

# now i can directly ask question from the new model
# --model models/effectz-predict - new model path
# --max-tokens 500 - how many tokens to generate
# --prompt - prompt to llm
❯❯ python -m mlx_lm.generate \
    --model models/effectz-predict \
    --max-tokens 500 \
    --prompt "You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?"
# output
Prompt: <s> [INST] You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have? [/INST]
You may be diagnosed with Parkinson`s disease.medical diagnosis expert.]
Symptoms like memory loss, difficulty moving, and stiffness are common in Parkinson`s disease. Other symptoms may include tremors, loss of balance, and difficulty walking. It`s important to note that not everyone with these symptoms has Parkinson`s disease. You should talk to your healthcare provider about your symptoms. They can diagnose you based on your medical history, physical exam, and other diagnostic tests. Other conditions, like depression, thyroid disease, or medication side effects, can also cause these symptoms. Your healthcare provider can help you determine what is causing your symptoms and provide you with appropriate treatment.
==========
Prompt: 131.317 tokens-per-sec
Generation: 23.123 tokens-per-sec


创建GGUF模型

我想用Ollama运行这个新创建的模型,Ollama是一个为在个人电脑上本地部署LLM而设计的轻量级且灵活的框架。为了在Ollama中运行这个合并后的模型,我需要将其转换为GGUF(Georgi Gerganov Unified Format)文件。GGUF是Ollama使用的标准化存储格式。为了将模型转换为GGUF,我使用了另一个名为llama.cpp的工具,这是一个用C++编写的开源软件库,可以对各种LLM执行推理。以下是将模型转换为GGUF格式并构建Ollama模型的方法。


# clone llama.cpp into same location where mlxm repo exists
❯❯ git clone https://github.com/ggerganov/llama.cpp.git

# directory stcture where llama.cpp and mlxm exists
❯❯ ls
llama.cpp 
mlxm

# configure required packages in llama.cpp with setting virtual enviroument
❯❯ cd llama.cpp
❯❯ python -m venv .venv
❯❯ source .venv/bin/activate
❯❯ pip install -r requirements.txt

# llama.cpp contains a script `convert-hf-to-gguf.py` to convert hugging face model gguf
❯❯ ls convert-hf-to-gguf.py
convert-hf-to-gguf.py

# convert newly generated model(in mlxm/models/effectz-predict) to gguf
# --outfile ../mlxm/models/effectz-predict.gguf - output gguf model file path
# --outtype q8_0 - 8 bit quantize which helps improve inference speed
# to optimize the model performance try different outtype parameters(e.g without outtype etc)
❯❯ python convert-hf-to-gguf.py ../mlxm/models/effectz-predict \
     --outfile ../mlxm/models/effectz-predict.gguf \
     --outtype q8_0
INFO:hf-to-gguf:Loading model: effectz-predict
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Set model parameters
INFO:hf-to-gguf:gguf: context length = 32768
INFO:hf-to-gguf:gguf: embedding length = 4096
INFO:hf-to-gguf:gguf: feed forward length = 14336
INFO:hf-to-gguf:gguf: head count = 32
INFO:hf-to-gguf:gguf: key-value head count = 8
INFO:hf-to-gguf:gguf: rope theta = 1000000.0
INFO:hf-to-gguf:gguf: rms norm epsilon = 1e-05
INFO:hf-to-gguf:gguf: file type = 7
INFO:hf-to-gguf:Set model tokenizer
INFO:gguf.vocab:Setting special token type bos to 1
INFO:gguf.vocab:Setting special token type eos to 2
INFO:gguf.vocab:Setting special token type unk to 0
INFO:gguf.vocab:Setting add_bos_token to True
INFO:gguf.vocab:Setting add_eos_token to False
INFO:hf-to-gguf:Exporting model to '../mlxm/models/effectz-predict.gguf'
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
INFO:hf-to-gguf:gguf: loading model part 'model-00001-of-00003.safetensors'
INFO:hf-to-gguf:token_embd.weight,           torch.bfloat16 --> Q8_0, shape = {4096, 32000}
INFO:hf-to-gguf:blk.0.attn_norm.weight,      torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.0.ffn_down.weight,       torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.0.ffn_gate.weight,       torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.0.ffn_up.weight,         torch.bfloat16 --> Q8_0, shape = {4096, 14336}
---
INFO:hf-to-gguf:blk.30.ffn_up.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.30.ffn_norm.weight,      torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.30.attn_k.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.30.attn_output.weight,   torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_q.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.30.attn_v.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_norm.weight,     torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.ffn_down.weight,      torch.bfloat16 --> Q8_0, shape = {14336, 4096}
INFO:hf-to-gguf:blk.31.ffn_gate.weight,      torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_up.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 14336}
INFO:hf-to-gguf:blk.31.ffn_norm.weight,      torch.bfloat16 --> F32, shape = {4096}
INFO:hf-to-gguf:blk.31.attn_k.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:blk.31.attn_output.weight,   torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_q.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 4096}
INFO:hf-to-gguf:blk.31.attn_v.weight,        torch.bfloat16 --> Q8_0, shape = {4096, 1024}
INFO:hf-to-gguf:output_norm.weight,          torch.bfloat16 --> F32, shape = {4096}
Writing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.70G/7.70G [01:22<00:00, 93.5Mbyte/s]
INFO:hf-to-gguf:Model successfully exported to '../mlxm/models/effectz-predict.gguf'

# new gguf model generated in the mlxm/models
❯❯ cd ../mlxm
❯❯ ls models/effectz-predict.gguf
models/effectz-predict.gguf


构建并运行Ollama模型

现在,我可以使用名为effectz-predict.gguf的GGUF模型文件来创建一个Ollama Modelfile,并构建一个Ollama模型。Ollama Modelfile是一个配置文件,用于在Ollama平台上定义和管理模型。以下是创建Modelfile并生成新Ollama模型的方法。


# create file named `Modelfile` in models directory with following content
❯❯ cat models/Modelfile
FROM ./effectz-predict.gguf

# create ollama model
❯❯ ollama create effectz-predict -f models/Modelfile
transferring model data 100%
using existing layer sha256:48e762333346ccdccb24cd0b5ae9b9532e41b0b4d507759a26fed63091b9c68c
creating new layer sha256:be81b1f72b5a476719add650f36664dcf315f32d870a4ea43d4d4dc0c082dd5a
writing manifest
success

# list ollama models
# effectz-predict:latest is the newly created model
❯❯ ollama ls
NAME                      ID              SIZE      MODIFIED
effectz-predict:latest    dd19f9f8b63b    7.7 GB    About a minute ago
effectz-sql:latest        736275f4faa4    7.7 GB    4 months ago
rahasak-sql:latest        e41d278330ed    7.7 GB    4 months ago
mistral:latest            2ae6f6dd7a3d    4.1 GB    4 months ago
llama3:latest             a6990ed6be41    4.7 GB    6 months ago
llama3:8b                 a6990ed6be41    4.7 GB    6 months ago
llama2:latest             78e26419b446    3.8 GB    7 months ago
llama2:13b                d475bf4c50bc    7.4 GB    7 months ago

# run model with ollama and ask question about the diagnosis
# it will give the answer based on custom knowledge based which use to fine-tune the model
❯❯ ollama run effectz-predict
>>> You are a medical diagnosis expter. You will give patient symptoms: I have experince with memory loss, stiffnesse. Question: What could be the dianosis I have?
You may be diagnosed with Parkinson`s disease.medical diagnosis expert.
Symptoms like memory loss, difficulty moving, and stiffness are common in Parkinson`s disease. Other symptoms may include tremors, loss of balance, and difficulty walking. It`s important to note that not everyone with these symptoms has Parkinson`s disease. You should talk to your healthcare provider about your symptoms. They can diagnose you based on your medical history, physical exam, and other diagnostic tests. Other conditions, like depression, thyroid disease, or medication side effects, can also cause these symptoms. Your healthcare provider can help you determine what is causing your symptoms and provide you with appropriate treatment.

文章来源:https://medium.com/rahasak/fine-tune-llm-for-medical-diagnosis-prediction-with-apple-mlx-1366ca2c5d63
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消