使用Python将自然语言设置为SQL代码生成器

2023年12月29日 由 alex 发表 435 0

介绍


自然语言模型,尤其是大型语言模型(LLM)的迅速发展,为各个领域带来了无数的可能性。一个最常见的应用是使用LLM进行编码。例如,OpenAI的chatGPT和Meta的Code LLAMA就是提供最先进的自然语言到代码生成器的LLM。一种潜在的应用案例是自然语言到SQL代码生成器,它可以帮助非技术专业人士处理简单的数据请求,希望能让数据团队集中精力处理更多的数据密集型任务。本文重点讲述如何使用OpenAI API设置一个用于SQL代码生成器的语言。


你可以使用语言到SQL代码生成器应用程序做什么?


一个可能的应用是一个能够用相关数据回复用户查询的聊天机器人(图1)。该聊天机器人可以通过一个Python应用程序集成到一个Slack频道,执行以下步骤:


  • 接收用户的问题
  • 将问题转换成提示语
  • 带着提示语向OpenAI API发送GET请求
  • 解析返回的JSON到SQL查询语句
  • 将查询语句发送到数据库
  • 向用户返回一个包含相关数据的CSV文件


7

图1


在本文中,我们将逐步构建一个Python应用程序,该程序能够将用户问题转换成SQL查询语句。


范围和总体架构


本教程提供了逐步指南,解释了如何设置Python应用程序,该程序使用OpenAI API将一般问题转换成SQL查询。包含以下功能:


  • 泛化——应用程序不限于特定表格,可以用于任何表格
  • 为了简化,应用程序仅限于单表(例如,没有连接操作)
  • Dockerized——在Dockerized环境内开发应用程序,以简化部署流程


下面的图2描述了简单的语言到SQL代码生成器的总体架构。


8

图2


先决条件


本教程的主要前提是具备Python的基础知识。包括以下功能:


  • 设置Python函数和对象
  • 处理表格数据(例如,pandas,CSV等)和非结构化数据格式(例如,JSON等)
  • 使用Python库


此外,还需要基本的SQL知识和对OpenAI API的访问权限。


Python库


为了设置自然语言到SQL代码生成,我们将使用以下Python库:


  • pandas - 在整个过程中处理数据
  • duckdb - 用于模拟与数据库的工作
  • openai - 用于与OpenAI API协作
  • time和os - 用于加载CSV文件和格式化字段


使用VScode与Docker


本教程是在VScode和Dev Containers扩展的Docker化环境中构建的。要在VScode中运行它,你将需要安装Dev Containers扩展,并且已打开Docker桌面(或等效软件)。环境设置可在.devcontainer文件夹下找到:


.── .devcontainer
    ├── Dockerfile
    ├── Dockerfile.dev
    ├── devcontainer.json
    ├── install_dependencies_core.sh
    ├── install_dependencies_other.sh
    ├── install_quarto.sh
    ├── requirements_core.txt
    ├── requirements_openai.txt
    └── requirements_transformers.txt


devcontainer.json 包含了这个 Docker 化环境的构建指令和 VS Code 设置:


{
    "name": "lang2sql",
    "build": {
        "dockerfile": "Dockerfile",
        "args": {
            "ENV_NAME": "lang2sql",
            "PYTHON_VER": "3.10",
            "METHOD": "openai",
            "QUARTO_VER": "1.3.450"
        },
        "context": "."
    },
    "customizations": {
        "settings": {
            "python.defaultInterpreterPath": "/opt/conda/envs/lang2sql/bin/python",
            "python.selectInterpreter": "/opt/conda/envs/lang2sql/bin/python"
        },
        "vscode": {
            "extensions": [
                "quarto.quarto",
                "ms-azuretools.vscode-docker",
                "ms-python.python",
                "ms-vscode-remote.remote-containers",
                "yzhang.markdown-all-in-one",
                "redhat.vscode-yaml",
                "ms-toolsai.jupyter"
            ]
        }
    },
    "remoteEnv": {
        "OPENAI_KEY": "${localEnv:OPENAI_KEY}"
    }
}


在这里,构建参数定义了 Docker 的构建方法,并设置了构建的参数。在本例中,我们将 Python 版本设置为 3.10,并将 conda 虚拟环境设置为 ang2sql。


remoteEnv 参数可以设置环境变量。我们将使用它来设置 OpenAI API 密钥。在这个案例中,我将该变量在本地设置为 OPENAI_KEY,并且我正通过 localEnv 参数来加载它。


设置访问 OpenAI API 的权限


我们将使用 OpenAI API 来访问使用 text-davinci-003 引擎的 chatGPT。这需要一个活跃的 OpenAI 账户和 API 密钥。


一旦你设置了API的访问权限和密钥,我建议将该密钥添加为你的.zshrc文件(或者你用来在你的shell系统上存储环境变量的任何其他格式文件)中的环境变量。我将我的API密钥存储在OPENAI_KEY环境变量下。出于令人信服的原因,我建议你使用相同的命名约定。


要在.zshrc文件(或等效文件)中设置变量,请将以下行添加到文件中:


export OPENAI_KEY="YOUR_API_KEY"


如果使用VScode或从终端运行,添加变量到.zshrc文件后,必须重新启动你的会话。


数据


为了模拟数据库功能,我们将利用芝加哥犯罪数据集。这个数据集提供了自2001年以来记录在芝加哥市的犯罪活动的详细信息。该数据集包含接近800万条记录和22列,包括犯罪分类、位置、时间、结果等信息。数据可以从芝加哥数据门户下载。由于我们将数据本地存储为Pandas数据框架,并使用DuckDB来模拟SQL查询,我们将通过下载最近三年的数据子集。


9

图3


你可以从API中提取数据,或者下载一个CSV文件。


要下载数据,请使用右上角的“导出”按钮,选择CSV选项,然后点击“下载”按钮,如图4所示。


10

图4


我使用了以下的命名约定——chicago_crime_YEAR.csv,并将文件保存在data文件夹中。每个文件的大小接近50Mb。因此,我在data文件夹下的git忽略文件中添加了这些文件,它们在教程仓库中是不可用的。下载这些文件并设置好它们的名称后,你应该在文件夹中拥有以下的文件:


|── data
    ├── chicago_crime_2021.csv
    ├── chicago_crime_2022.csv
    └── chicago_crime_2023.csv


设置 SQL 代码生成器


现在让我们转向令人兴奋的部分,即设置 SQL 代码生成器。在这一部分,我们将创建一个 Python 函数,它接受用户的问题、相关的 SQL 表以及 OpenAI API 密钥,并输出解答用户问题的 SQL 查询。


首先,让我们加载芝加哥犯罪数据集和所需的 Python 库。


加载依赖和数据


首先,我们来加载所需的 Python 库。


import pandas as pd
import duckdb
import openai
import time 
import os


我们将利用os和time库来加载CSV文件并重新格式化特定字段。数据将使用pandas库进行处理,我们将用DuckDB库模拟SQL命令。最后,我们将使用openai库建立与OpenAI API的连接。


接下来,我们将从数据文件夹加载CSV文件。下面的代码读取了数据文件夹中所有可用的CSV文件:


path = "./data"


files = [x for x in os.listdir(path = path) if ".csv" in x]


如果你下载了对应2021年至2023年的文件,并且使用了相同的命名规则,你应该期待以下的输出:


print(files)


['chicago_crime_2022.csv', 'chicago_crime_2023.csv', 'chicago_crime_2021.csv']


接下来,我们将读取并加载所有文件,并将它们附加到pandas 数据框中:


chicago_crime = pd.concat((pd.read_csv(path +"/" + f) for f in files), ignore_index=True)
chicago_crime.head


如果正确加载了文件,你应该预期以下输出:


<bound method NDFrame.head of               ID Case Number                    Date                   Block   
0       12589893    JF109865  01/11/2022 03:00:00 PM    087XX S KINGSTON AVE  \
1       12592454    JF113025  01/14/2022 03:55:00 PM       067XX S MORGAN ST   
2       12601676    JF124024  01/13/2022 04:00:00 PM    031XX W AUGUSTA BLVD   
3       12785595    JF346553  08/05/2022 09:00:00 PM  072XX S UNIVERSITY AVE   
4       12808281    JF373517  08/14/2022 02:00:00 PM     055XX W ARDMORE AVE   
...          ...         ...                     ...                     ...   
648826     26461    JE455267  11/24/2021 12:51:00 AM     107XX S LANGLEY AVE   
648827     26041    JE281927  06/28/2021 01:12:00 AM       117XX S LAFLIN ST   
648828     26238    JE353715  08/29/2021 03:07:00 AM    010XX N LAWNDALE AVE   
648829     26479    JE465230  12/03/2021 08:37:00 PM         000XX W 78TH PL   
648830  11138622    JA495186  05/21/2021 12:01:00 AM      019XX N PULASKI RD   
        IUCR                Primary Type   
0       1565                 SEX OFFENSE  \
1       2826               OTHER OFFENSE   
2       1752  OFFENSE INVOLVING CHILDREN   
3       1544                 SEX OFFENSE   
4       1562                 SEX OFFENSE   
...      ...                         ...   
648826  0110                    HOMICIDE   
648827  0110                    HOMICIDE   
648828  0110                    HOMICIDE   
648829  0110                    HOMICIDE   
648830  1752  OFFENSE INVOLVING CHILDREN   
...
648828  41.899709 -87.718893  (41.899709327, -87.718893208)  
648829  41.751832 -87.626374  (41.751831742, -87.626373808)  
648830  41.915798 -87.726524  (41.915798196, -87.726524412)


设置提示模版


在本节中,我们将重点介绍根据这些原则来泛化创建 SQL 生成提示的过程。目标是构建一个 Python 函数,该函数接收一个表名和用户问题,并相应地创建提示。例如,对于我们之前加载的 chicago_crime 表和上一节中提出的问题,该函数应创建以下提示:


Given the following SQL table, your job is to write queries given a user's request.
CREATE TABLE chicago_crime (ID BIGINT, 
                            Case Number VARCHAR, 
                            Date VARCHAR, 
                            Block VARCHAR, 
                            IUCR VARCHAR, 
                            Primary Type VARCHAR, 
                            Description VARCHAR, 
                            Location Description VARCHAR, 
                            Arrest BOOLEAN, 
                            Domestic BOOLEAN, 
                            Beat BIGINT, 
                            District BIGINT, 
                            Ward DOUBLE, 
                            Community Area BIGINT, 
                            FBI Code VARCHAR, 
                            X Coordinate DOUBLE, 
                            Y Coordinate DOUBLE, 
                            Year BIGINT, 
                            Updated On VARCHAR, 
                            Latitude DOUBLE, 
                            Longitude DOUBLE, 
                            Location VARCHAR) 
Write a SQL query that returns - How many cases ended up with arrest?


我们从提示结构开始。我们将采用OpenAI的格式,并使用以下模板:


system_template = """


Given the following SQL table, your job is to write queries given a user's request. \n
    CREATE TABLE {} ({}) \n
    """
user_template = "Write a SQL query that returns - {}"


系统模板接收两个元素:


  • 表名
  • 表字段及其属性


在本教程中,我们将使用DuckDB库来处理pandas数据框架,就像它是一个SQL表一样,并使用duckdb.sql函数提取表的字段名称和属性。比如,让我们使用DESCRIBE SQL命令来提取chicago_crime表的字段信息:


duckdb.sql("DESCRIBE SELECT * FROM chicago_crime;")


应返回下表:


┌──────────────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│     column_name      │ column_type │  null   │   key   │ default │  extra  │
│       varchar        │   varchar   │ varchar │ varchar │ varchar │ varchar │
├──────────────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ ID                   │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Case Number          │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Date                 │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Block                │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ IUCR                 │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Primary Type         │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Description          │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Location Description │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Arrest               │ BOOLEAN     │ YES     │ NULL    │ NULL    │ NULL    │
│ Domestic             │ BOOLEAN     │ YES     │ NULL    │ NULL    │ NULL    │
│ Beat                 │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ District             │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Ward                 │ DOUBLE      │ YES     │ NULL    │ NULL    │ NULL    │
│ Community Area       │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ FBI Code             │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ X Coordinate         │ DOUBLE      │ YES     │ NULL    │ NULL    │ NULL    │
│ Y Coordinate         │ DOUBLE      │ YES     │ NULL    │ NULL    │ NULL    │
│ Year                 │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Updated On           │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Latitude             │ DOUBLE      │ YES     │ NULL    │ NULL    │ NULL    │
│ Longitude            │ DOUBLE      │ YES     │ NULL    │ NULL    │ NULL    │
│ Location             │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
├──────────────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┤
│ 22 rows                                                          6 columns │
└────────────────────────────────────────────────────────────────────────────┘


注意:我们需要的信息—列名和其属性,在前两列中可以找到。因此,我们需要解析这些列,并将它们合并到以下格式:


Column_Name Column_Attribute


例如,案件编号列应转换成以下格式:


Case Number VARCHAR


以下的 create_message 函数指挥了一个过程:它接收表格名称和问题,并按照以上逻辑生成提示。


def create_message(table_name, query):


class message:
        def __init__(message, system, user, column_names, column_attr):
            message.system = system
            message.user = user
            message.column_names = column_names
            message.column_attr = column_attr
    
    system_template = """
    Given the following SQL table, your job is to write queries given a user's request. \n
    CREATE TABLE {} ({}) \n
    """
    user_template = "Write a SQL query that returns - {}"
    
    tbl_describe = duckdb.sql("DESCRIBE SELECT * FROM " + table_name +  ";")
    col_attr = tbl_describe.df()[["column_name", "column_type"]]
    col_attr["column_joint"] = col_attr["column_name"] + " " +  col_attr["column_type"]
    col_names = str(list(col_attr["column_joint"].values)).replace('[', '').replace(']', '').replace('\'', '')
    system = system_template.format(table_name, col_names)
    user = user_template.format(query)
    m = message(system = system, user = user, column_names = col_attr["column_name"], column_attr = col_attr["column_type"])
    return m


这个函数创建提示模板,并返回提示系统和用户组件以及列的名称和属性。例如,让我们运行逮捕人数问题:


query = "How many cases ended up with arrest?"
msg = create_message(table_name = "chicago_crime", query = query)


这将返回:


print(msg.system)


Given the following SQL table, your job is to write queries given a user's request. 
CREATE TABLE chicago_crime (ID BIGINT, Case Number VARCHAR, Date VARCHAR, Block VARCHAR, IUCR VARCHAR, Primary Type VARCHAR, Description VARCHAR, Location Description VARCHAR, Arrest BOOLEAN, Domestic BOOLEAN, Beat BIGINT, District BIGINT, Ward DOUBLE, Community Area BIGINT, FBI Code VARCHAR, X Coordinate DOUBLE, Y Coordinate DOUBLE, Year BIGINT, Updated On VARCHAR, Latitude DOUBLE, Longitude DOUBLE, Location VARCHAR) 
print(msg.user)
Write a SQL query that returns - How many cases ended up with arrest?
print(msg.column_names)
0                       ID
1              Case Number
2                     Date
3                    Block
4                     IUCR
5             Primary Type
6              Description
7     Location Description
8                   Arrest
9                 Domestic
10                    Beat
11                District
12                    Ward
13          Community Area
14                FBI Code
15            X Coordinate
16            Y Coordinate
17                    Year
18              Updated On
19                Latitude
20               Longitude
21                Location
Name: column_name, dtype: object
print(msg.column_attr)
0      BIGINT
1     VARCHAR
2     VARCHAR
3     VARCHAR
4     VARCHAR
5     VARCHAR
6     VARCHAR
7     VARCHAR
8     BOOLEAN
9     BOOLEAN
10     BIGINT
11     BIGINT
12     DOUBLE
13     BIGINT
14    VARCHAR
15     DOUBLE
16     DOUBLE
17     BIGINT
18    VARCHAR
19     DOUBLE
20     DOUBLE
21    VARCHAR
Name: column_type, dtype: object


create_message 函数的输出旨在适配 OpenAI API 的 


使用 OpenAI API


本节重点介绍 openai Python 库的功能。openai 库能够无缝访问 OpenAI REST API。我们将使用这个库连接到 API 并发送带有我们提示的 GET 请求。


让我们通过将我们的 API 密钥输入 openai.api_key 函数来开始连接到 API:


openai.api_key = os.getenv('OPENAI_KEY')


注意:我们使用 os 库中的 getenv 函数来加载 OPENAI_KEY 环境变量。或者,你可以直接输入你的 API 密钥:


openai.api_key = "YOUR_OPENAI_API_KEY"


OpenAI API提供了可以访问的多种功能不同的大型语言模型(LLMs)。你可以使用openai.Model.list函数获取可用模型的列表:


openai.Model.list()


要将其转换为一个好看的格式,你可以将其放入pandas数据框中:


openai_api_models = pd.DataFrame(openai.Model.list()["data"])


openai_api_models.head


并且应该预期以下输出:


<bound method NDFrame.head of                                id object     created         owned_by
0     text-search-babbage-doc-001  model  1651172509       openai-dev
1                           gpt-4  model  1687882411           openai
2              curie-search-query  model  1651172509       openai-dev
3                text-davinci-003  model  1669599635  openai-internal
4   text-search-babbage-query-001  model  1651172509       openai-dev
..                            ...    ...         ...              ...
65    gpt-3.5-turbo-instruct-0914  model  1694122472           system
66                       dall-e-2  model  1698798177           system
67                     tts-1-1106  model  1699053241           system
68                  tts-1-hd-1106  model  1699053533           system
69              gpt-3.5-turbo-16k  model  1683758102  openai-internal
[70 rows x 4 columns]>


对于我们的用例,文本生成,我们将使用 gpt-3.5-turbo 模型,这是 GPT3 模型的改进版。gpt-3.5-turbo 模型代表了一个不断更新的系列模型,如果没有明确指定模型版本,默认情况下,API 会指向最新的稳定版本。在创建这个教程时,默认的 3.5 模型是 gpt-3.5-turbo-0613,使用 4,096 个令牌,并且是用截至 2021 年 9 月的数据进行训练的。


为了发送一个带有我们提示的 GET 请求,我们将使用 ChatCompletion.create 函数。这个函数有许多参数,我们将使用以下这些参数:


  • model - 要使用的模型 ID,完整列表可在此处查看
  • messages - 包括到目前为止对话内容的消息列表(例如,提示)
  • temperature - 通过设置采样温度水平来管理过程输出的随机性或确定性。温度水平接受 0 到 2 之间的值。当参数值较高时,输出变得更随机。相反,当参数值接近 0 时,输出变得更确定性(可重现)
  • max_tokens - 在完成中生成的最大令牌数


函数的完整参数列表可在 API 文档中找到。


在下面的示例中,我们将使用与 ChatGPT Web 界面上使用的提示相同的提示(即图 5),这次使用 API。我们将使用以下函数生成提示create_message


query = "How many cases ended up with arrest?"
prompt = create_message(table_name = "chicago_crime", query = query)


让我们将上述提示转换为 ChatCompletion.create 函数消息参数的结构。


message = [
    {
      "role": "system",
      "content": prompt.system
    },
    {
      "role": "user",
      "content": prompt.user
    }
    ]


接下来,我们将使用 ChatCompletion.create 函数将提示(即消息对象)发送到 API:


response = openai.ChatCompletion.create(
        model = "gpt-3.5-turbo",
        messages = message,
        temperature = 0,
        max_tokens = 256)


我们将参数设置temperature为 0 以确保高再现性,并将文本完成中的标记数量限制为 256。该函数返回一个JSON包含文本完成、元数据和其他信息的对象:


print(response)


<OpenAIObject chat.completion id=chatcmpl-8PzomlbLrTOTx1uOZm4WQnGr4JwU7 at 0xffff4b0dcb80> JSON: {
  "id": "chatcmpl-8PzomlbLrTOTx1uOZm4WQnGr4JwU7",
  "object": "chat.completion",
  "created": 1701206520,
  "model": "gpt-3.5-turbo-0613",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "SELECT COUNT(*) FROM chicago_crime WHERE Arrest = true;"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 137,
    "completion_tokens": 12,
    "total_tokens": 149
  }
}


利用响应指标,我们可以提取SQL查询:


sql = response["choices"][0]["message"]["content"]


print(sql)


'SELECT COUNT(*) FROM chicago_crime WHERE Arrest = true;'


使用duckdb.sql函数来运行SQL代码:


duckdb.sql(sql).show()


┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        77635 │
└──────────────┘


在接下来的部分,我们将使所有步骤通用化和功能化。


将其全部联系起来


在前几个部分中,我们介绍了提示格式,设定了 create_message 函数,并且回顾了 ChatCompletion.create 函数的功能。在这个部分,我们将这些全部联系起来。


关于从 ChatCompletion.create 函数返回的 SQL 代码需要注意的一件事是,变量并不带引号返回。当查询中的变量名结合两个或更多单词时,这可能会导致问题。例如,在查询中使用像 Chicago Crime 数据中的 Case Number(案件编号)或 Primary Type(主要类型)这样的变量,如果没使用引号就会导致错误。


如果返回的查询中没有引号,我们将使用以下帮助函数为查询中的变量添加引号:


def add_quotes(query, col_names):
    for i in col_names:
        if i in query:
            l = query.find(i)
            if query[l-1] != "'" and query[l-1] != '"': 
                query = str(query).replace(i, '"' + i + '"') 
  return(query)


函数的输入是查询和相应表的列名。它会遍历列名,如果在查询中找到匹配项,就会添加引号。例如,我们可以用从ChatCompletion.create函数输出中解析出的SQL查询来运行它:


add_quotes(query = sql, col_names = prompt.column_names)


'SELECT COUNT(*) FROM chicago_crime WHERE "Arrest" = true;'


你会注意到它为 Arrest 变量添加了引号。


现在,我们可以引入 lang2sql 函数,它利用了我们到目前为止介绍的三个函数——create_message、ChatCompletion.create 和 add_quotes,来将用户的问题翻译成 SQL 代码:


def lang2sql(api_key, table_name, query, model = "gpt-3.5-turbo", temperature = 0, max_tokens = 256, frequency_penalty = 0,presence_penalty= 0):
    class response:
        def __init__(output, message, response, sql):
            output.message = message
            output.response = response
            output.sql = sql


openai.api_key = api_key
    m = create_message(table_name = table_name, query = query)
    message = [
    {
      "role": "system",
      "content": m.system
    },
    {
      "role": "user",
      "content": m.user
    }
    ]
    
    openai_response = openai.ChatCompletion.create(
        model = model,
        messages = message,
        temperature = temperature,
        max_tokens = max_tokens,
        frequency_penalty = frequency_penalty,
        presence_penalty = presence_penalty)
    
    sql_query = add_quotes(query = openai_response["choices"][0]["message"]["content"], col_names = m.column_names)
    output = response(message = m, response = openai_response, sql = sql_query)
    return output


该函数接收OpenAI API密钥、表名和ChatCompletion.create函数的核心参数作为输入,并返回一个对象,其中包含提示(prompt)、API响应和解析后的查询。例如,让我们尝试使用上一节中相同的查询,通过lang2sql函数重新运行它。


query = "How many cases ended up with arrest?"
response = lang2sql(api_key = api_key, table_name = "chicago_crime", query = query)


我们可以从输出对象中提取SQL查询:


print(response.sql)


SELECT COUNT(*) FROM chicago_crime WHERE "Arrest" = true;


我们可以根据前一部分收到的结果来测试输出:


duckdb.sql(response.sql).show()


┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        77635 │
└──────────────┘


现在让我们增加问题的复杂性,询问在2022年最终以逮捕告终的案件:


query = "How many cases ended up with arrest during 2022"
response = lang2sql(api_key = api_key, table_name = "chicago_crime", query = query)


正如你所看到的,模型正确地将相关字段识别为年份,并生成了正确的查询:


print(response.sql)


SQL代码:


SELECT COUNT(*) FROM chicago_crime WHERE "Arrest" = TRUE AND "Year" = 2022;


在表中测试查询:


┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        27805 │
└──────────────┘


这是一个需要按照特定变量进行分组的简单问题的例子:


query = "Summarize the cases by primary type"
response = lang2sql(api_key = api_key, table_name = "chicago_crime", query = query)


print(response.sql)


你可以从响应输出中看出,在这种情况下SQL代码是正确的:


SELECT "Primary Type", COUNT(*) as TotalCases
FROM chicago_crime
GROUP BY "Primary Type"


这是查询的输出:


duckdb.sql(response.sql).show()


┌───────────────────────────────────┬────────────┐
│           Primary Type            │ TotalCases │
│              varchar              │   int64    │
├───────────────────────────────────┼────────────┤
│ MOTOR VEHICLE THEFT               │      54934 │
│ ROBBERY                           │      25082 │
│ WEAPONS VIOLATION                 │      24672 │
│ INTERFERENCE WITH PUBLIC OFFICER  │       1161 │
│ OBSCENITY                         │        127 │
│ STALKING                          │       1206 │
│ BATTERY                           │     115760 │
│ OFFENSE INVOLVING CHILDREN        │       5177 │
│ CRIMINAL TRESPASS                 │      11255 │
│ PUBLIC PEACE VIOLATION            │       1980 │
│    ·                              │         ·  │
│    ·                              │         ·  │
│    ·                              │         ·  │
│ ASSAULT                           │      58685 │
│ CRIMINAL DAMAGE                   │      75611 │
│ DECEPTIVE PRACTICE                │      46377 │
│ NARCOTICS                         │      13931 │
│ BURGLARY                          │      19898 │
...
├───────────────────────────────────┴────────────┤
│ 31 rows (20 shown)                   2 columns │
└────────────────────────────────────────────────┘


最后但同样重要的是,即使我们仅提供部分变量名,LLM也能识别上下文(例如,哪一个变量)。


query = "How many cases is the type of robbery?"
response = lang2sql(api_key = api_key, table_name = "chicago_crime", query = query)


print(response.sql)


它返回以下SQL代码:


SELECT COUNT(*) FROM chicago_crime WHERE "Primary Type" = 'ROBBERY';


这是查询的输出:


duckdb.sql(response.sql).show()


┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        25082 │
└──────────────┘


总结


在这个教程中,我们展示了如何仅用几行Python代码构建一个SQL代码生成器,并利用OpenAI API。我们已经看到,提示质量对于生成的SQL代码的成功至关重要。除了提示所提供的上下文之外,字段名称也应该提供关于字段特性的信息,以帮助语言模型识别该字段与用户问题的相关性。


文章来源:https://medium.com/@rami.krispin/setting-a-natural-language-to-sql-code-generator-with-python-d267f40d7218
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消