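This model generates structured SQL queries from natural-language prompts.
In the Text2SQL task, a model learns to generate an SQL query for a question posed in natural language. In some cases, however, the generated query references unknown columns or otherwise ignores the schema of the specific database entirely.
This is where our approach comes in: during training, we add the database schema to the input question to specify which columns and relations are available, so that the model generates SQL queries that are applicable to that database.
Presenting the database schema together with the prompt lets the model learn a mapping from the schema to the expected output, which helps it generalize to schemas that do not appear in the training data.
We fine-tuned this model from the t5-large-LM-adapt checkpoint.
The model was fine-tuned on the training splits of the Spider and Spider-Syn datasets. Rather than using the question alone, we append the database schema to it, because we want the model to generate queries for a given database.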
Input prompt:
Question: What is the average, minimum, and maximum age for all French musicians? Schema: "stadium" "Stadium_ID" int , "Location" text , "Name" text , "Capacity" int , "Highest" int , "Lowest" int , "Average" int , foreign_key: primary key: "Stadium_ID" [SEP] "singer" "Singer_ID" int , "Name" text , "Country" text , "Song_Name" text , "Song_release_year" text , "Age" int , "Is_male" bool , foreign_key: primary key: "Singer_ID" [SEP] "concert" "concert_ID" int , "concert_Name" text , "Theme" text , "Year" text , foreign_key: "Stadium_ID" text from "stadium" "Stadium_ID" , primary key: "concert_ID" [SEP] "singer_in_concert" foreign_key: "concert_ID" int from "concert" "concert_ID" , "Singer_ID" text from "singer" "Singer_ID" , primary key: "concert_ID" "Singer_ID"
Expected output:
SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'
When evaluating the output, we run it against the SQLite database and obtain:
[[34.5, 25, 43]]
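To illustrate where a result of this shape comes from, the sketch below runs the expected query with Python's built-in `sqlite3` module against a toy in-memory `singer` table. The rows are made up for illustration (they are not the Spider `concert_singer` data), so the numbers differ from the result above; what carries over is the shape: a single row holding the average, minimum and maximum age.

```python
import sqlite3

# Toy in-memory database: only the "singer" table from the example schema,
# filled with made-up rows (NOT the real Spider concert_singer data).
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE singer ("Singer_ID" int, "Name" text, "Country" text, '
    '"Song_Name" text, "Song_release_year" text, "Age" int, "Is_male" bool)'
)
conn.executemany(
    "INSERT INTO singer VALUES (?, ?, ?, ?, ?, ?, ?)",
    [
        (1, "Singer A", "France", "Song A", "2001", 25, 1),
        (2, "Singer B", "France", "Song B", "2005", 44, 0),
        (3, "Singer C", "Netherlands", "Song C", "2010", 52, 1),
    ],
)

# SQLite matches identifiers case-insensitively, so `age` resolves to "Age".
query = "SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'"

# fetchall() returns a list of row tuples; convert them to lists to match the format above.
result = [list(row) for row in conn.execute(query).fetchall()]
print(result)  # [[34.5, 25, 44]]
```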
Normalized database schema format used during training:
table_name column1_name column1_type column2_name column2_type ... foreign_key: FK_name FK_type from table_name column_name primary key: column_name [SEP] table_name2 ...
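The normalized format can be produced mechanically from a table description. The sketch below is inferred from the example schema in this card and is not the authors' preprocessing code; the `Table` and `ForeignKey` dataclasses and the `serialize_table` / `serialize_schema` functions are hypothetical names. Judging from the example, foreign-key columns are listed only after `foreign_key:` and are not repeated in the plain column list, and a table's entries are separated by ` , `; the sketch follows those conventions.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ForeignKey:
    column: str      # column in this table
    col_type: str    # its declared type, e.g. "int" or "text"
    ref_table: str   # referenced table
    ref_column: str  # referenced column


@dataclass
class Table:
    name: str
    columns: list[tuple[str, str]]  # (name, type) pairs, excluding FK columns
    foreign_keys: list[ForeignKey] = field(default_factory=list)
    primary_key: list[str] = field(default_factory=list)


def quote(name: str) -> str:
    return f'"{name}"'


def serialize_table(t: Table) -> str:
    parts = [quote(t.name)]
    # Plain columns: "name" type ,
    for col, typ in t.columns:
        parts.append(f"{quote(col)} {typ} ,")
    # Foreign keys: "col" type from "table" "col" , ... ,
    parts.append("foreign_key:")
    if t.foreign_keys:
        fks = " , ".join(
            f"{quote(fk.column)} {fk.col_type} from {quote(fk.ref_table)} {quote(fk.ref_column)}"
            for fk in t.foreign_keys
        )
        parts.append(fks + " ,")
    # Primary key columns, space-separated
    parts.append("primary key:")
    parts.extend(quote(c) for c in t.primary_key)
    return " ".join(parts)


def serialize_schema(tables: list[Table]) -> str:
    # Tables are separated by [SEP]
    return " [SEP] ".join(serialize_table(t) for t in tables)


concert = Table(
    "concert",
    [("concert_ID", "int"), ("concert_Name", "text"), ("Theme", "text"), ("Year", "text")],
    [ForeignKey("Stadium_ID", "text", "stadium", "Stadium_ID")],
    ["concert_ID"],
)
print(serialize_table(concert))
```

Serializing all four tables of the example database this way reproduces the schema string used in the prompt above.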
How to use: here is how to use 🤗 Transformers in PyTorch to generate an SQL query for a question over a given database schema:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_path = 'gaussalgo/T5-LM-Large-text2sql-spider'
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

question = "What is the average, minimum, and maximum age for all French musicians?"
schema = """
   "stadium" "Stadium_ID" int , "Location" text , "Name" text , "Capacity" int , "Highest" int , "Lowest" int , "Average" int , foreign_key: primary key: "Stadium_ID" [SEP] "singer" "Singer_ID" int , "Name" text , "Country" text , "Song_Name" text , "Song_release_year" text , "Age" int , "Is_male" bool , foreign_key: primary key: "Singer_ID" [SEP] "concert" "concert_ID" int , "concert_Name" text , "Theme" text , "Year" text , foreign_key: "Stadium_ID" text from "stadium" "Stadium_ID" , primary key: "concert_ID" [SEP] "singer_in_concert" foreign_key: "concert_ID" int from "concert" "concert_ID" , "Singer_ID" text from "singer" "Singer_ID" , primary key: "concert_ID" "Singer_ID"
"""

# Build the input in the training format: "Question: <question> Schema: <schema>"
input_text = " ".join(["Question: ", question, "Schema:", schema])

model_inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**model_inputs, max_length=512)

# generate() returns a batch of sequences; decode the first (and only) one
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("SQL Query:")
print(output_text)
Output:
SQL Query:
SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'
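The model was evaluated on the development splits of Spider and Spider-Syn. The databases in the development sets do not overlap with those in the training sets, which ensures the model never saw the evaluation databases during training. Evaluation is execution-based: we run both the generated query and the reference query against the database and compare the results. The Spider and Spider-Syn development sets each contain 1032 samples.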
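Execution-based comparison can be sketched as follows. This is a simplified illustration, not the authors' evaluation code or the official Spider evaluation script: `execution_match` and `execution_accuracy` are hypothetical helpers, rows are compared as multisets (so ORDER BY semantics are ignored), and a predicted query that fails to execute, for example because it references an unknown column, counts as a miss.

```python
import sqlite3
from collections import Counter


def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """Return True if both queries yield the same multiset of rows."""
    gold_rows = conn.execute(gold_sql).fetchall()
    try:
        pred_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        # Invalid SQL or an unknown column/table counts as a mismatch.
        return False
    # Order-insensitive comparison (a simplification: ignores ORDER BY).
    return Counter(pred_rows) == Counter(gold_rows)


def execution_accuracy(conn: sqlite3.Connection, pairs: list[tuple[str, str]]) -> float:
    """Fraction of (predicted, gold) pairs whose results match; pairs must be non-empty."""
    hits = sum(execution_match(conn, pred, gold) for pred, gold in pairs)
    return hits / len(pairs)


# Usage on a toy table with made-up rows.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE singer ("Name" text, "Country" text, "Age" int)')
conn.executemany(
    "INSERT INTO singer VALUES (?, ?, ?)",
    [("A", "France", 25), ("B", "France", 44), ("C", "Netherlands", 52)],
)
gold = "SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'"
pairs = [
    ("SELECT AVG(Age), MIN(Age), MAX(Age) FROM singer WHERE Country = 'France'", gold),  # equivalent
    ("SELECT avg(age) FROM singer", gold),     # runs, but returns a different result
    ("SELECT avg(salary) FROM singer", gold),  # unknown column -> execution error
]
print(execution_accuracy(conn, pairs))  # one of three pairs matches
```

Comparing execution results rather than query strings means that syntactically different but equivalent queries, like the first pair above, are counted as correct.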
The model was trained with the Adaptor library 0.2.1 on the training splits of Spider and Spider-Syn, using the following parameters:
from adaptor.utils import AdaptationArguments, StoppingStrategy

training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=1000,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=1000,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")
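The training process is easy to reproduce, but we do not want to publish our modified copy of the Spider dataset, since it depends on the original Spider dataset. If you would like to explore further, feel free to contact us through a new PR or by email (nikola.groverova(at)gaussalgo.com).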