英文

T5大模型用于文本到SQL

该模型的目的是根据自然语言提示生成结构化的SQL查询。

介绍

在Text2SQL任务中,模型学习如何根据用自然语言提出的问题生成一个SQL查询。然而,在某些情况下,SQL查询包含未知列等,并且完全不考虑特定数据库的模式。

这就是我们的方法所发挥作用的地方。我们在训练过程中将数据库模式合并到输入问题中,以指定哪些列和关系可用于生成可应用的SQL查询。

数据库模式的呈现以及提示一起,使得模型能够学习模式到预期输出的映射。这使得模型能够更好地推广到在训练数据中不存在的模式。

基础模型

我们从 t5-large-LM-adapt 检查点中微调了该模型。

Spider和Spider-Syn数据集

该模型在 Spider Spider-Syn 数据集的训练集上进行了微调。我们不仅使用问题,还将数据库模式添加到问题中,因为我们希望模型能够在给定的数据库上生成问题

输入提示 :

Question:  What is the average, minimum, and maximum age for all French musicians?
Schema: "stadium" "Stadium_ID" int , "Location" text , "Name" text , "Capacity" int , "Highest" int , "Lowest" int ,
        "Average" int , foreign_key:  primary key: "Stadium_ID" [SEP] "singer" "Singer_ID" int , "Name" text , "Country" text ,
        "Song_Name" text , "Song_release_year" text , "Age" int , "Is_male" bool ,
        foreign_key:  primary key: "Singer_ID" [SEP],
        "concert" "concert_ID" int , "concert_Name" text , "Theme" text , "Year" text , foreign_key: "Stadium_ID" text from "stadium",
        "Stadium_ID" , primary key: "concert_ID" [SEP] "singer_in_concert",
        foreign_key: "concert_ID" int from "concert",
        "concert_ID" , "Singer_ID" text from "singer" "Singer_ID" , primary key: "concert_ID" "Singer_ID"

预期输出 :

SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'

在评估输出时,我们查询 SQLite 数据库并获得:

[[34.5, 25, 43]]

数据库模式的格式

模型训练时使用的标准化数据库模式:

table_name column1_name column1_type column2_name column2_type ... foreign_key: FK_name FK_type from table_name column_name primary key: column_name [SEP]
table_name2 ...

用法

使用? Transformers在PyTorch中回答给定上下文上的问题的方法如下:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_path = 'gaussalgo/T5-LM-Large-text2sql-spider'
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

question = "What is the average, minimum, and maximum age for all French musicians?"
schema = """
   "stadium" "Stadium_ID" int , "Location" text , "Name" text , "Capacity" int , "Highest" int , "Lowest" int , "Average" int , foreign_key:  primary key: "Stadium_ID" [SEP] "singer" "Singer_ID" int , "Name" text , "Country" text , "Song_Name" text , "Song_release_year" text , "Age" int , "Is_male" bool , foreign_key:  primary key: "Singer_ID" [SEP] "concert" "concert_ID" int , "concert_Name" text , "Theme" text , "Year" text , foreign_key: "Stadium_ID" text from "stadium" "Stadium_ID" , primary key: "concert_ID" [SEP] "singer_in_concert"  foreign_key: "concert_ID" int from "concert" "concert_ID" , "Singer_ID" text from "singer" "Singer_ID" , primary key: "concert_ID" "Singer_ID"
"""

input_text = " ".join(["Question: ",question, "Schema:", schema])

model_inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**model_inputs, max_length=512)

output_text = tokenizer.decode(outputs, skip_special_tokens=True)

print("SQL Query:")
print(output_text)

输出:

SQL Query:
SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'

评估

对Spider和Spider-syn数据集的开发集进行了评估。开发集中的数据库与训练集的数据库没有交集。这样我们确保模型在训练期间未接触到评估数据库。通过比较使用生成的查询和参考查询查询数据库的结果进行评估。Spider和Spider-Syn开发集各有1032个样本。

  • Spider开发集准确率:49.2%
  • Spider Syn开发集准确率:39.5%

训练

该模型已使用 Adaptor library 0.2.1进行了训练,使用Spider和Spider-syn数据集的训练集,并使用以下参数:

training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=1000,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=1000,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")

训练过程很容易复现,但我们不希望公开修改过的Spider数据集的副本,因为它依赖于Spider数据集。如果您想进一步探索,请随时通过新的PR或通过电子邮件(nikola.groverova(at)gaussalgo.com)与我们联系。