利用Gradient.AI对RAG进行Llama2微调,提升医学推理能力

2024年01月25日 由 alex 发表 359 0

什么是RAG?


4


RAG 有什么帮助?


没有 RAG 的 LLM 只能了解它所训练的数据,也称为参数记忆(存储在模型参数中的信息)。


但有了 RAG,RAG 系统的检索器就可以访问外部信息源。因此,LLM 并不局限于其内部知识。外部来源可以是专有文件和数据,甚至是互联网。


检索器可以搜索和获取 LLM 不一定受过训练的信息。这将增加 LLM 的记忆,并在提示中作为上下文传递。也称为非参数记忆(模型参数之外的可用信息)


这里的工作在于创建知识库。有了经过适当策划的知识库,我们就能在以下方面占据一席之地:


  1. 将其扩展到所有来源
  2. 易于维护(更新/删除/插入)
  3. 与微调或再培训相比成本更低


通过向 LLM 提供上下文(检索到的额外信息),提高 LLM 响应的可信度。


RAG 的优势:


情境意识 来源引用 减少幻觉 增加的信息有助于 LLM 生成准确且与情境相适应的应答。


来源引用: 获取信息来源提高了法律硕士答复的透明度。


减少幻觉: 据观察,启用 RAG 的 LLM 系统比不启用 RAG 的系统更不易产生幻觉。


在这里,我们将使用经过医学推理微调的 LLM 实现一个开源 RAG。主要步骤包括:


  1. 加载模型和标记符
  2. 配置量化库
  3. 定义处理冗余的停止标准
  4. 构建文本生成管道
  5. 提示工程
  6. 生成索引
  7. 实例化检索器
  8. 提问和回答


我们将使用 Gradient.ai 进行微调。


什么是 Gradient?


Gradient 通过一个简单的微调和推理 Web API,让你可以轻松地个性化和构建开源 LLM。


它提供了一个具有简单网络 API 的平台,用于调整模型和生成补全。我们可以创建一个基础模型的私有实例,并在你的数据上对其进行指导,以查看它是如何实时学习的。


我们可以通过本地 CLI 以及 Python 和 Javascript SDK 访问网络 API。Gradient 为 RAG 完成了以下繁重的工作:


  1. 托管聊天模型
  2. 指导调整聊天模型
  3. 主机嵌入模型


一般术语


  • 指令调整(Instruct-tuning)是指根据大语言模型提供的指令,对其进行有监督的微调,使其与人类响应相一致,从而使其更有帮助。
  • 聊天模型的概念是跟踪所说的话,然后做出回应


利用渐变技术进行微调--技术栈


  1. gradient.ai: Gradient 工具可以让我们轻松地对模型进行微调,然后利用这些微调后的模型按代币定价。我们可以按需运行经过微调的模型。
  2. Llama-index : 它通过方便的封装器帮助对 LLM 进行微调。
  3. Langchain:开发 RAG 应用程序的协调框架


在这里,我们将在拥抱脸部的 mamachang/medical-reasoning 数据集上微调 Llama-2-7b-chat 模型。


在谷歌实验室(高内存)上实现:


5


代码实现


1. 注册


2. 创建一个新工作区


6


3. 安装所需的依赖项


!pip install llama-index -qU
!pip install langchain -qU
!pip install datasets -qU
!pip install gradientai -qU
!pip install chromadb pymupdf pypdf -qU


4. 设置所需的访问令牌


from google.colab import userdata
import os
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
HF_TOKEN = userdata.get('HUGGINGFACE_API_KEY')
os.environ['HF_AUTH_TOKEN'] = HF_TOKEN
os.environ['GRADIENT_ACCESS_TOKEN'] = userdata.get('GRADIENT_ACCESS_TOKEN')
os.environ['GRADIENT_WORKSPACE_ID'] = userdata.get('GRADIENT_WORKSPACE_ID')


5. 加载数据集


from datasets import load_dataset
#
dataset = load_dataset("mamachang/medical-reasoning")
dataset
#### Response
Downloading data: 100%
7.26M/7.26M [00:00<00:00, 20.8MB/s]
Generating train split:
3702/0 [00:00<00:00, 19042.50 examples/s]
DatasetDict({
    train: Dataset({
        features: ['output', 'input', 'instruction'],
        num_rows: 3702
    })
})


6. 格式化数据集



def create_prompt(sample):
  system_prompt_template = """<s>
Below is an instruction that describes a task.Write a response that appropriately completes the request.
### Instruction :<<user_instruction>>
### Condition:<<user_condition>>
### Response:
<<user_response>>
</s>
"""
  user_message = sample['instruction']
  user_response = sample['output']
  user_condition = sample["input"]
  prompt_template = system_prompt_template.replace("<<user_instruction>>",f"{user_message}").replace("<<user_response>>",f"{user_response}").replace("<<user_condition>>",f"{user_condition}")
  return {"inputs":prompt_template}


print(create_prompt(dataset["train"][1])["inputs"])
##### Response
<s>
Below is an instruction that describes a task.Write a response that appropriately completes the request.
### Instruction :Please answer with one of the option in the bracket. Write reasoning in between <analysis></analysis>. Write answer in between <answer></answer>.
### Condition:Q:A 23-year-old man comes to the physician because of a 2-day history of profuse watery diarrhea and abdominal cramps. Four days ago, he returned from a backpacking trip across Southeast Asia. Physical examination shows dry mucous membranes and decreased skin turgor. Stool culture shows gram-negative, oxidase-positive, curved rods that have a single polar flagellum. The pathogen responsible for this patient's condition most likely has which of the following characteristics?? 
{'A': 'Acts by activation of guanylate cyclase', 'B': 'Causes necrosis of Peyer patches of distal ileum', 'C': 'Infection commonly precedes Guillain-Barré syndrome', 'D': 'Grows well in medium with pH of 9', 'E': 'Forms spores in unfavorable environment'},
### Response:
<analysis>
This is a clinical vignette describing a 23-year-old man with profuse watery diarrhea and abdominal cramps after returning from a backpacking trip in Southeast Asia. The description of gram-negative curved rods with a single polar flagellum on stool culture suggests the pathogen is Vibrio cholerae. 
The question asks about the characteristics of the pathogen responsible for this patient's condition. Based on the clinical presentation and culture results, the pathogen is most likely Vibrio cholerae.
</analysis>
<answer>
D: Grows well in medium with pH of 9
</answer>
</s>


7. 在数据集中映射格式化函数


instruct_tune_dataset = dataset.map(create_prompt)
print(instruct_tune_dataset['train']['inputs'][0])
####Response
<s>
Below is an instruction that describes a task.Write a response that appropriately completes the request.
### Instruction :Please answer with one of the option in the bracket. Write reasoning in between <analysis></analysis>. Write answer in between <answer></answer>.
### Condition:Q:An 8-year-old boy is brought to the pediatrician by his mother with nausea, vomiting, and decreased frequency of urination. He has acute lymphoblastic leukemia for which he received the 1st dose of chemotherapy 5 days ago. His leukocyte count was 60,000/mm3 before starting chemotherapy. The vital signs include: pulse 110/min, temperature 37.0°C (98.6°F), and blood pressure 100/70 mm Hg. The physical examination shows bilateral pedal edema. Which of the following serum studies and urinalysis findings will be helpful in confirming the diagnosis of this condition? ? 
{'A': 'Hyperkalemia, hyperphosphatemia, hypocalcemia, and extremely elevated creatine kinase (MM)', 'B': 'Hyperkalemia, hyperphosphatemia, hypocalcemia, hyperuricemia, urine supernatant pink, and positive for heme', 'C': 'Hyperuricemia, hyperkalemia, hyperphosphatemia, lactic acidosis, and urate crystals in the urine', 'D': 'Hyperuricemia, hyperkalemia, hyperphosphatemia, and urinary monoclonal spike', 'E': 'Hyperuricemia, hyperkalemia, hyperphosphatemia, lactic acidosis, and oxalate crystals'},
### Response:
<analysis>
This is a clinical vignette describing an 8-year-old boy with acute lymphoblastic leukemia who recently started chemotherapy and now presents with nausea, vomiting, decreased urination, bilateral pedal edema, and other vital sign changes. 
The question asks which serum studies and urinalysis findings would help confirm the diagnosis. Based on the clinical history, the main diagnostic consideration is tumor lysis syndrome, which can occur after starting chemotherapy in a patient with a high tumor burden. 
Tumor lysis syndrome leads to rapid cell breakdown and release of intracellular contents into the bloodstream. This results in hyperuricemia, hyperkalemia, hyperphosphatemia and acute kidney injury. The urine may show urate crystals. 
So the correct answer should include these key lab abnormalities of tumor lysis syndrome.
</analysis>
<answer>
C: Hyperuricemia, hyperkalemia, hyperphosphatemia, lactic acidosis, and urate crystals in the urine
</answer>
</s>


8. 过滤数据集


我们需要考虑那些在训练模型上下文窗口内长度为2048个令牌的样本


pruned_dataset = instruct_tune_dataset.filter(lambda x : len(x["inputs"]) <= 2000)
pruned_dataset
#####Response
Filter: 100%
3702/3702 [00:00<00:00, 36571.25 examples/s]
DatasetDict({
    train: Dataset({
        features: ['output', 'input', 'instruction', 'inputs'],
        num_rows: 1742
    })
})


9. 使用lama-index初始化基本模型


from llama_index.llms import GradientBaseModelLLM
base_model_slug ="llama2-7b-chat"
base_llm = GradientBaseModelLLM(
    base_model_slug=base_model_slug,
    max_tokens=300,
    )


10. 初始化微调引擎


Llama_index构建了一个方便的包装器,可用于在gradient.ai上设置调优作业


梯度所需的参数如下:


  • base_model_slug:在API和CLI表中参考的model_id


  • name:已调优模型的名称


  • data_path:这将指向格式化的jsonl文件,该文件将被GradientFinetuneEngine用于提取训练示例。


  • 详细信息:显示日志


  • Max_steps:没有。在步骤中,模型将被微调


  • batch_size:不设置。一次用来训练的例子


保存为JSONL


我们可以利用数据集库的to_json以所需的格式导出数据集。


for split,dataset in instruct_tune_dataset.items():
  dataset.to_json(f"instruct_tune_{split}.jsonl")


创建了两个 jsonl 数据集


  • instruct_tune_test.jsonl
  • instruct_tune_train.jsonl


从我们的 finetune_engine.model_adapter_id 获取我们的 model_adapter_id (finetune_engine.model_adapter_id)。


11. 指示微调 Llama2


调用 finetune_engine 上的 finetune() 方法,开始发送示例以微调梯度模型!


epochs = 1
for i in range(epochs):
    print(f"** EPOCH {i} **")
    finetune_engine.finetune()
##### Response
** EPOCH 0 **
fine-tuning step 4: loss=4959.452, trainable tokens=2589
fine-tuning step 8: loss=3593.1758, trainable tokens=2083
fine-tuning step 12: loss=2385.9756, trainable tokens=1830
fine-tuning step 16: loss=2695.1196, trainable tokens=2385
fine-tuning step 20: loss=2558.9072, trainable tokens=2440
fine-tuning step 24: loss=2290.3503, trainable tokens=2238
fine-tuning step 28: loss=2508.7234, trainable tokens=2378
fine-tuning step 32: loss=2189.0679, trainable tokens=2448
fine-tuning step 36: loss=2164.287, trainable tokens=2087
fine-tuning step 40: loss=2251.7786, trainable tokens=2495
fine-tuning step 44: loss=1972.8579, trainable tokens=2134
fine-tuning step 48: loss=1937.082, trainable tokens=2254
fine-tuning step 52: loss=1627.0519, trainable tokens=1949
fine-tuning step 56: loss=2383.2615, trainable tokens=2410
fine-tuning step 60: loss=2107.046, trainable tokens=2326
fine-tuning step 64: loss=1630.6199, trainable tokens=1854
fine-tuning step 68: loss=2251.8284, trainable tokens=2187
fine-tuning step 72: loss=2036.4097, trainable tokens=2104
fine-tuning step 76: loss=1668.2056, trainable tokens=1971
fine-tuning step 80: loss=2215.5273, trainable tokens=2309
fine-tuning step 84: loss=2422.0725, trainable tokens=2767
fine-tuning step 88: loss=1638.0425, trainable tokens=1942
fine-tuning step 92: loss=1922.6688, trainable tokens=2179
fine-tuning step 96: loss=2072.206, trainable tokens=2314
fine-tuning step 100: loss=2040.1426, trainable tokens=2251
fine-tuning step 104: loss=1866.227, trainable tokens=2227
fine-tuning step 108: loss=2585.6978, trainable tokens=2236
fine-tuning step 112: loss=2392.5225, trainable tokens=2587
fine-tuning step 116: loss=1603.4441, trainable tokens=2009
fine-tuning step 120: loss=2469.9912, trainable tokens=2824
fine-tuning step 124: loss=1888.2914, trainable tokens=2267
fine-tuning step 128: loss=2602.9531, trainable tokens=2869
fine-tuning step 132: loss=1808.0504, trainable tokens=2017
fine-tuning step 136: loss=1542.7957, trainable tokens=2150
fine-tuning step 140: loss=1573.2522, trainable tokens=1857
fine-tuning step 144: loss=1730.1132, trainable tokens=2112
fine-tuning step 148: loss=3049.8752, trainable tokens=3137
fine-tuning step 152: loss=2016.0774, trainable tokens=2407
fine-tuning step 156: loss=1836.7034, trainable tokens=2234
fine-tuning step 160: loss=2342.8235, trainable tokens=2579
fine-tuning step 164: loss=1974.2599, trainable tokens=2285
fine-tuning step 168: loss=1858.6886, trainable tokens=2264
fine-tuning step 172: loss=2210.064, trainable tokens=2478
fine-tuning step 176: loss=1957.0702, trainable tokens=2297
fine-tuning step 180: loss=1710.1389, trainable tokens=2235
fine-tuning step 184: loss=2187.9094, trainable tokens=2323
fine-tuning step 188: loss=2093.4937, trainable tokens=2406
fine-tuning step 192: loss=2085.185, trainable tokens=2524
fine-tuning step 196: loss=1915.3977, trainable tokens=2185
fine-tuning step 200: loss=1581.7871, trainable tokens=2164
fine-tuning step 204: loss=1809.4635, trainable tokens=2148
fine-tuning step 208: loss=2299.2432, trainable tokens=2480
fine-tuning step 212: loss=2014.1716, trainable tokens=2443
fine-tuning step 216: loss=2187.0588, trainable tokens=2499
fine-tuning step 220: loss=1644.8057, trainable tokens=1836
fine-tuning step 224: loss=2132.4263, trainable tokens=2583
fine-tuning step 228: loss=1930.3994, trainable tokens=2330
fine-tuning step 232: loss=2130.0046, trainable tokens=2368
fine-tuning step 236: loss=1942.6675, trainable tokens=2097
fine-tuning step 240: loss=1636.968, trainable tokens=2102
fine-tuning step 244: loss=2037.2612, trainable tokens=2302
fine-tuning step 248: loss=1995.2659, trainable tokens=2416
fine-tuning step 252: loss=2166.8884, trainable tokens=2550
fine-tuning step 256: loss=2104.5574, trainable tokens=2579
fine-tuning step 260: loss=1405.8171, trainable tokens=1899
fine-tuning step 264: loss=2151.6802, trainable tokens=2167
fine-tuning step 268: loss=1799.2125, trainable tokens=2169
fine-tuning step 272: loss=1645.2439, trainable tokens=2094
fine-tuning step 276: loss=2313.582, trainable tokens=2540
fine-tuning step 280: loss=1976.7007, trainable tokens=2526
fine-tuning step 284: loss=1849.638, trainable tokens=2358
fine-tuning step 288: loss=1732.6041, trainable tokens=2135
fine-tuning step 292: loss=1679.3423, trainable tokens=2123
fine-tuning step 296: loss=1607.5176, trainable tokens=1956
fine-tuning step 300: loss=1757.9156, trainable tokens=2023
fine-tuning step 304: loss=1925.5841, trainable tokens=2221
fine-tuning step 308: loss=1934.6556, trainable tokens=2191
fine-tuning step 312: loss=1903.065, trainable tokens=2194
fine-tuning step 316: loss=2150.2397, trainable tokens=2255
fine-tuning step 320: loss=2202.9558, trainable tokens=2166
fine-tuning step 324: loss=2365.9102, trainable tokens=2612
fine-tuning step 328: loss=2113.0376, trainable tokens=2576
fine-tuning step 332: loss=2092.775, trainable tokens=2342
fine-tuning step 336: loss=2440.4287, trainable tokens=2714
fine-tuning step 340: loss=1984.5017, trainable tokens=2406
fine-tuning step 344: loss=2349.6685, trainable tokens=2523
fine-tuning step 348: loss=1596.2339, trainable tokens=1949
fine-tuning step 352: loss=1445.9419, trainable tokens=1794
fine-tuning step 356: loss=1521.0352, trainable tokens=1946
fine-tuning step 360: loss=1807.5386, trainable tokens=2272
fine-tuning step 364: loss=1626.7179, trainable tokens=2039
fine-tuning step 368: loss=1605.5002, trainable tokens=2013
fine-tuning step 372: loss=2036.533, trainable tokens=2370
fine-tuning step 376: loss=2117.618, trainable tokens=2649
fine-tuning step 380: loss=1831.4807, trainable tokens=2294
fine-tuning step 384: loss=1587.9495, trainable tokens=1941
fine-tuning step 388: loss=2029.3451, trainable tokens=2508
fine-tuning step 392: loss=1826.5919, trainable tokens=2080
fine-tuning step 396: loss=1466.8333, trainable tokens=1980
fine-tuning step 400: loss=1770.755, trainable tokens=1840
fine-tuning step 404: loss=2135.3428, trainable tokens=2522
fine-tuning step 408: loss=2007.6233, trainable tokens=2648
fine-tuning step 412: loss=1782.7094, trainable tokens=1972
fine-tuning step 416: loss=1998.5812, trainable tokens=2372
fine-tuning step 420: loss=1929.1812, trainable tokens=2402
fine-tuning step 424: loss=1856.892, trainable tokens=1895
fine-tuning step 428: loss=2020.7694, trainable tokens=2096
fine-tuning step 432: loss=2098.26, trainable tokens=2498
fine-tuning step 436: loss=2208.5173, trainable tokens=2332
fine-tuning step 440: loss=1819.5647, trainable tokens=2244
fine-tuning step 444: loss=2065.7, trainable tokens=2311
fine-tuning step 448: loss=2151.265, trainable tokens=2322
fine-tuning step 452: loss=1497.5433, trainable tokens=2116
fine-tuning step 456: loss=2042.0454, trainable tokens=2544
fine-tuning step 460: loss=1520.771, trainable tokens=1951
fine-tuning step 464: loss=2356.9346, trainable tokens=2675
fine-tuning step 468: loss=2105.2083, trainable tokens=2378
fine-tuning step 472: loss=1509.7086, trainable tokens=2154
fine-tuning step 476: loss=2111.8745, trainable tokens=2294
fine-tuning step 480: loss=2435.3896, trainable tokens=2933
fine-tuning step 484: loss=1714.5544, trainable tokens=2100
fine-tuning step 488: loss=1355.1733, trainable tokens=1795
fine-tuning step 492: loss=2004.9135, trainable tokens=2480
fine-tuning step 496: loss=2011.1257, trainable tokens=2251
fine-tuning step 500: loss=1698.9839, trainable tokens=1920
fine-tuning step 504: loss=2990.1262, trainable tokens=3106
fine-tuning step 508: loss=1853.8329, trainable tokens=2325
fine-tuning step 512: loss=2169.289, trainable tokens=2559
fine-tuning step 516: loss=1327.3379, trainable tokens=1773
fine-tuning step 520: loss=2037.6965, trainable tokens=2197
fine-tuning step 524: loss=1500.7937, trainable tokens=2068
fine-tuning step 528: loss=1308.2484, trainable tokens=1754
fine-tuning step 532: loss=1524.1477, trainable tokens=2083
fine-tuning step 536: loss=1651.4478, trainable tokens=1886
fine-tuning step 540: loss=2417.032, trainable tokens=2786
fine-tuning step 544: loss=2433.3733, trainable tokens=2577
fine-tuning step 548: loss=1719.5312, trainable tokens=2181
fine-tuning step 552: loss=2426.927, trainable tokens=2767
fine-tuning step 556: loss=1860.657, trainable tokens=1993
fine-tuning step 560: loss=2392.6145, trainable tokens=2814
fine-tuning step 564: loss=2013.239, trainable tokens=2620
fine-tuning step 568: loss=2231.7397, trainable tokens=2708
fine-tuning step 572: loss=1892.1404, trainable tokens=2051
fine-tuning step 576: loss=2008.4565, trainable tokens=2296
fine-tuning step 580: loss=1960.9663, trainable tokens=2280
fine-tuning step 584: loss=2558.5637, trainable tokens=2956
fine-tuning step 588: loss=1944.1198, trainable tokens=2108
fine-tuning step 592: loss=2536.519, trainable tokens=2733
fine-tuning step 596: loss=1576.0658, trainable tokens=1965
fine-tuning step 600: loss=2182.4795, trainable tokens=2492
fine-tuning step 604: loss=1763.933, trainable tokens=2109
fine-tuning step 608: loss=1844.563, trainable tokens=2255
fine-tuning step 612: loss=1679.5659, trainable tokens=2128
fine-tuning step 616: loss=1950.8341, trainable tokens=2363
fine-tuning step 620: loss=1709.6396, trainable tokens=2097
fine-tuning step 624: loss=2243.0486, trainable tokens=2454
fine-tuning step 628: loss=1992.2363, trainable tokens=2468
fine-tuning step 632: loss=1782.529, trainable tokens=2167
fine-tuning step 636: loss=2002.4166, trainable tokens=2101
fine-tuning step 640: loss=2211.5, trainable tokens=2597
fine-tuning step 644: loss=1918.258, trainable tokens=2239
fine-tuning step 648: loss=1759.8679, trainable tokens=2156
fine-tuning step 652: loss=2180.1433, trainable tokens=2588
fine-tuning step 656: loss=2505.6838, trainable tokens=3010
fine-tuning step 660: loss=1774.7881, trainable tokens=2151
fine-tuning step 664: loss=2163.466, trainable tokens=2661
fine-tuning step 668: loss=2049.1968, trainable tokens=2714
fine-tuning step 672: loss=1852.1995, trainable tokens=2375
fine-tuning step 676: loss=2021.3531, trainable tokens=2655
fine-tuning step 680: loss=1821.9147, trainable tokens=2114
fine-tuning step 684: loss=1833.266, trainable tokens=2053
fine-tuning step 688: loss=1544.6099, trainable tokens=2055
fine-tuning step 692: loss=1829.1609, trainable tokens=2386
fine-tuning step 696: loss=1773.6494, trainable tokens=2201
fine-tuning step 700: loss=1875.5953, trainable tokens=2340
fine-tuning step 704: loss=1694.0093, trainable tokens=2175
fine-tuning step 708: loss=2086.6123, trainable tokens=2295
fine-tuning step 712: loss=1671.3391, trainable tokens=2092
fine-tuning step 716: loss=1695.1176, trainable tokens=2335
fine-tuning step 720: loss=1693.6235, trainable tokens=2083
fine-tuning step 724: loss=1976.8916, trainable tokens=2206
fine-tuning step 728: loss=1655.9978, trainable tokens=2128
fine-tuning step 732: loss=1809.5564, trainable tokens=2332
fine-tuning step 736: loss=1721.8125, trainable tokens=2061
fine-tuning step 740: loss=1859.0647, trainable tokens=2321
fine-tuning step 744: loss=1836.2949, trainable tokens=2294
fine-tuning step 748: loss=1573.923, trainable tokens=1851
fine-tuning step 752: loss=1901.0651, trainable tokens=2303
fine-tuning step 756: loss=2193.9092, trainable tokens=2606
fine-tuning step 760: loss=1981.604, trainable tokens=2205
fine-tuning step 764: loss=2381.479, trainable tokens=2450
fine-tuning step 768: loss=1790.0127, trainable tokens=2032
fine-tuning step 772: loss=2219.024, trainable tokens=2728
fine-tuning step 776: loss=2284.9045, trainable tokens=2584
fine-tuning step 780: loss=1561.394, trainable tokens=1860
fine-tuning step 784: loss=2072.1267, trainable tokens=2356
fine-tuning step 788: loss=2330.8984, trainable tokens=2475
fine-tuning step 792: loss=1941.3085, trainable tokens=2380
fine-tuning step 796: loss=1750.3411, trainable tokens=2263
fine-tuning step 800: loss=1928.2925, trainable tokens=2428


12. 托管梯度嵌入模型


from langchain.embeddings import GradientEmbeddings
embeddings = GradientEmbeddings(model="bge-large")


我们需要做的就是提供访问令牌和工作区 ID(之前的环境中应该已经有了),然后选择 BGE 嵌入模型(目前唯一支持的嵌入模型,不过还有更多即将推出),这样就大功告成了!


13. 创建由 Gradient 和 LangChain 支持的 RAG 管道


import gradientai
from langchain.llms import GradientLLM
client = gradientai.Gradient()
#
models = client.list_models(only_base=False)
for model in models:
  if "adapter" in model.id:
    print(model.id, model.name)

#Load the finetuned latest Gradient Model
llm = GradientLLM(
    model=models[-1].id,
    model_kwargs=dict(max_generated_token_count=500),
)


14. 复制培训模板


from langchain.prompts import PromptTemplate
template = """"
Below is an instruction that describes a task.Write a response that appropriately completes the request.
### Instruction {user_instruction}
### Condition:{user_condition}
### Response:
"""
prompt = PromptTemplate(template=template, input_variables=["user_instruction", "user_condition"])


15. 创建一个简单的 LLMChain,将我们的提示链入 LLM。


from langchain.chains import LLMChain
#
llm_chain = LLMChain(prompt=prompt, llm=llm)
#
user_instruction= 'Please answer with only one of the option in the bracket. Write reasoning in between <analysis></analysis>. Write answer in between <answer></answer>.'
user_condition = """A 9-month-old infant is brought the pediatrician for immunizations and assessment. His parents report that he is eating well and produces several wet diapers every day. He is a happy and curious child. The boy was born at 39 weeks gestation via spontaneous vaginal delivery. He is up to date on all vaccines and is meeting all developmental milestones. The infant’s vital signs are normal. Physical growth is appropriate for his age. The physician notes a loud holosystolic murmur at the left sternal border (grade IV) and orders an echocardiogram which confirms the diagnosis of congenital heart defect. Based on echocardiogram findings, the pediatrician reassures the parents that the infant will be monitored, but most likely will not require surgical intervention. user_instruction
response = llm_chain.run(user_instruction=user_instruction,user_condition=user_condition)
print(response)

##### REsponse
<analysis>
This is a clinical vignette describing a 9-month-old infant with a congenital heart defect. The key findings are a loud holosystolic murmur at the left sternal border (grade IV), which confirms the diagnosis of a congenital heart defect. The pediatrician reassures the parents that the infant will be monitored but most likely will not require surgical intervention. 
Based on the murmur and echocardiogram findings, the most likely diagnosis is a ventricular septal defect (VSD). VSDs are the most common congenital heart defects and can be diagnosed by echocardiogram. They occur when there is a defect in the ventricular septum, allowing blood to flow between the ventricles. VSDs can cause murmurs and can be managed conservatively with close monitoring.
Atrial septal defects (ASDs) are less common than VSDs and occur when there is a defect in the atrial septum. They can also cause murmurs but are less likely than VSDs given the murmur described.
Coarctation of aorta and tetralogy of Fallot are less likely given the murmur and echocardiogram findings. Patent ductus arteriosus (PDA) can cause murmurs but is less likely than VSD given the murmur described.
</analysis>
<answer>
B: Ventricular septal defect
</answer>


user_instruction='Please answer with one of the option in the bracket. Write reasoning in between <analysis></analysis>. Write answer in between <answer></answer>'
user_condition = """Q:A 55-year-old man presents with burning and shooting in his feet and lower legs, which becomes more severe at night. In the past 6 months, the pain has become much worse and disturbs his sleep. He has a history of type 2 diabetes mellitus and essential hypertension. Which of the following best represent the etiology of this patient’s condition?? \n{'A': 'Autonomic neuropathy', 'B': 'Isolated cranial nerve neuropathy', 'C': 'Isolated peripheral nerve neuropathy', 'D': 'Distal symmetric sensorimotor polyneuropathy', 'E': 'Radiculopathy'}"""
response = llm_chain.run(user_instruction=user_instruction,user_condition=user_condition)
print(response)
#### Response
<analysis>
This is a clinical vignette describing a 55-year-old man with burning and shooting in his feet and lower legs that worsens at night. He has a history of type 2 diabetes mellitus and essential hypertension. The key findings are burning and shooting in the feet and lower legs, which worsens at night, and a history of diabetes and hypertension. This presentation is most consistent with distal symmetric sensorimotor polyneuropathy, which is caused by diabetes and hypertension. Autonomic neuropathy, cranial nerve neuropathy, and peripheral nerve neuropathy can cause burning and shooting in the feet, but would not worsen at night. Radiculopathy is less likely given the lack of back pain.
</analysis>
<answer>
D: Distal symmetric sensorimotor polyneuropathy
</answer>


让我们制作一个简单的 RAG 提示,看看它的效果如何。


template = """"
Below is an instruction that describes a task.Write a response that appropriately completes the request.
### Instruction {user_instruction}
### Condition:{user_condition}
### Response:
"""
rag_prompt = PromptTemplate(template=template, input_variables=["user_instruction", "user_condition"])
#
llm_chain = rag_prompt | llm
#
user_instruction = "Please answer with one of the option in the bracket. Write reasoning in between <analysis></analysis>. Write answer in between <answer></answer>."
user_condition = '''Q:A 34-year-old man presents to a clinic with complaints of abdominal discomfort and blood in the urine for 2 days. He has had similar abdominal discomfort during the past 5 years, although he does not remember passing blood in the urine. He has had hypertension for the past 2 years, for which he has been prescribed medication. There is no history of weight loss, skin rashes, joint pain, vomiting, change in bowel habits, and smoking. On physical examination, there are ballotable flank masses bilaterally. The bowel sounds are normal. Renal function tests are as follows:\nUrea 50 mg/dL\nCreatinine 1.4 mg/dL\nProtein Negative\nRBC Numerous\nThe patient underwent ultrasonography of the abdomen, which revealed enlarged kidneys and multiple anechoic cysts with well-defined walls. A CT scan confirmed the presence of multiple cysts in the kidneys. What is the most likely diagnosis?? \n{'A': 'Autosomal dominant polycystic kidney disease (ADPKD)', 'B': 'Autosomal recessive polycystic kidney disease (ARPKD)', 'C': 'Medullary cystic disease', 'D': 'Simple renal cysts', 'E': 'Acquired cystic kidney disease'}'''
response = llm_chain.invoke({"user_instruction" :user_instruction, "user_condition" : user_condition})
print(response)
##### Response
<analysis>
Based on the information provided in the question stem, the key findings are:
- 34-year-old man with abdominal discomfort and blood in the urine for 2 days 
- History of similar abdominal discomfort over the past 5 years
- Hypertension for 2 years, on medication
- Physical exam shows flank masses and normal bowel sounds
- Lab tests show elevated BUN and creatinine, normal RBC
- Ultrasound shows enlarged kidneys and multiple cysts
- CT confirms cysts in kidneys
The enlarged kidneys and multiple cysts on ultrasound and CT are classic for autosomal dominant polycystic kidney disease (ADPKD). The flank masses and elevated BUN/creatinine point towards a diagnosis of ADPKD.
Autosomal recessive polycystic kidney disease (ARPKD) can cause flank masses and hypertension, but would not explain the ultrasound findings. Medullary cystic disease can cause flank masses but not the ultrasound findings. Simple renal cysts would not cause flank masses or elevated BUN/creatinine. Acquired cystic kidney disease is rare and would not explain the ultrasound findings.
</analysis>
<answer>
A: Autosomal dominant polycystic kidney disease (ADPKD)
</answer>


在 LangChain 中创建 RAG 链


加载数据


from langchain.document_loaders.pdf import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
#
docs = PyPDFLoader("/content/71763-gale-encyclopedia-of-medicine.-vol.-1.-2nd-ed.pdf").load()


将加载的文件分割成块


from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 1250,
    chunk_overlap = 100,
    length_function = len,
    is_separator_regex = False
)
#
split_docs = text_splitter.split_documents(docs)
print(len(split_docs))


设置矢量存储


vectorstore = Chroma.from_documents(split_docs,
                                     embedding=embeddings,
                                    persist_directory="./chromadb"
                                     )
vectorstore.persist()


创造检索器


retriever = vs.as_retriever()


使用 LCEL 创建 RAG 管道


from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
rag_chain = (
    {
        "user_condition" : retriever, "user_instruction" : RunnablePassthrough()
    }
    | rag_prompt
    | llm
    | StrOutputParser()
)


询问 1.


rag_chain.invoke("""A 60 year old man compliants of  blood in the urine, hearing loss and eye problems.
Which of the following best represent the etiology of this patient’s condition?? \n{'A': 'Alport syndrome', 'B': 'Isolated cranial nerve neuropathy', 'C': 'Isolated peripheral nerve neuropathy', 'D': 'Alzheimer’s disease', 'E': 'Altitude sickness'}""")
#### Response
A: Alport syndrome
This is the best answer because Alport syndrome is a hereditary disease of the kidneys that primarily affects men, causing blood in the urine, hearing loss, and eye problems. It can progress to chronic kidney failure and uremia. The other options can be ruled out: isolated cranial nerve neuropathy, isolated peripheral nerve neuropathy, Alzheimer's disease, and altitude sickness do not fit with the clinical presentation of Alport syndrome.


询问 2.


rag_chain.invoke("""An elderly female  compliants of memory loss and retain their sense of self.
Which of the following best represent the etiology of this patient’s condition?? \n{'A': 'Amnesia', 'B': 'Amniocentesis', 'C': 'Amylase tests', 'D': 'Alzheimer’s disease', 'E': 'Altitude sickness'}""")
##### Response
A) Amnesia
This is the best answer because amnesia is the most direct cause of memory loss in this scenario. The other options describe various conditions that can cause memory loss, but amnesia specifically refers to the inability to form new memories due to damage to the hippocampus. 


结论


在这里,我们通过对 Llama2-intstruct-7b 基本模型进行指令调整,实现了 RAG 管道。微调的全部目的是为了方便推理答案选项。


文章来源:https://medium.com/@nayakpplaban/fine-tune-rag-with-llama2-using-gradient-ai-on-medical-reasoning-task-395bca26a551
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消