How to Improve Text-to-SQL Agents with Step-by-Step Reasoning

Published February 11, 2025 by alex

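One cool outcome of the DeepSeek R1 release is that large language models (LLMs) now show their thinking inside <think> tags in their responses, much as ChatGPT's o1 and o3-mini surface their reasoning. Encouraging LLMs to think more deeply has several benefits: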

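  • No more black-box answers! You can watch the reasoning behind an LLM's response in real time.
  • Users can see how the model reached its conclusion.
  • Mistakes in the prompt become easy to spot and fix.
  • Transparency makes the AI's decisions feel more trustworthy.
  • Collaboration gets easier when humans and the AI share the same reasoning.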


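So here I am: I've built a retrieval-augmented generation (RAG) setup that brings a similar reasoning process, in the form of chain-of-thought (CoT) responses, to a LangGraph SQL agent with tool calling. It is a ReAct (Reason + Act) agent that combines LangChain's SQL toolkit with LangGraph's graph-based execution. Here's how it works:

[Figure: how the agent works]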


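Now, let's walk through the thinking process.


The agent starts from a system prompt that gives its thinking a structure:


I've mapped out the exact steps our SQL agent takes, from receiving a question to returning the final query: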


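The four-stage thinking process


Reasoning stage (<reasoning> tag)

  • Explain what information is needed
  • Describe the expected outcome
  • Identify potential challenges
  • Justify the chosen approach


Analysis stage (<analysis> tag)

  • Required tables and joins
  • Required columns
  • Filters and conditions
  • Ordering and grouping logic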


Query stage (<query> tag)

Construct the SQL according to these rules:

  • SELECT statements only
  • Correct syntax
  • LIMIT 10 by default
  • Verified schema
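The prompt only asks the model to follow these rules; nothing in the agent checks them. As a rough sketch of what a code-side guard could look like, the hypothetical helper below (`enforce_query_rules` is my own name, not part of the article's agent or of any library) rejects anything that isn't a single SELECT and appends the default LIMIT when none is present. The checks are deliberately naive: a semicolon inside a string literal is rejected, and a LIMIT inside a subquery counts as a limit.

```python
import re

DEFAULT_LIMIT = 10  # mirrors the prompt rule "limit results to 10 unless specified"


def enforce_query_rules(sql: str, default_limit: int = DEFAULT_LIMIT) -> str:
    """Reject non-SELECT input and add a LIMIT clause if the query has none."""
    statement = sql.strip().rstrip(";").strip()
    # A remaining semicolon suggests several statements were chained together.
    if ";" in statement:
        raise ValueError("Only a single statement is allowed")
    if not re.match(r"select\b", statement, re.IGNORECASE):
        raise ValueError("Only SELECT statements are allowed")
    # Naive check: any "LIMIT <number>" anywhere in the text counts.
    if not re.search(r"\blimit\s+\d+", statement, re.IGNORECASE):
        statement = f"{statement} LIMIT {default_limit}"
    return statement + ";"


print(enforce_query_rules("SELECT Name FROM Track"))
# SELECT Name FROM Track LIMIT 10;
print(enforce_query_rules("select * from Track limit 5;"))
# select * from Track limit 5;
```

A guard like this would sit between the model's tool call and the database, so a prompt slip can't turn into a write.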


Validation stage (<error_check> and <final_check> tags)

  • Verify the reasoning
  • Confirm the approach
  • Check for completeness
  • Validate the output
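The validation stage relies on the model checking itself. If you'd rather confirm in code that a response contains the stages you care about, a small checker like the one below could do it. This is only a sketch: `missing_stages` is a hypothetical helper, and defaulting to the reasoning and analysis tags is my assumption about which stages should always be present.

```python
import re
from typing import List, Sequence

# Tag names are the ones used in the prompt; the default selection is an assumption.
DEFAULT_REQUIRED = ("reasoning", "analysis")


def missing_stages(response: str, required: Sequence[str] = DEFAULT_REQUIRED) -> List[str]:
    """Return the tags that have no non-empty <tag>...</tag> block in the response."""
    missing = []
    for tag in required:
        match = re.search(rf"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
        if not match or not match.group(1).strip():
            missing.append(tag)
    return missing


sample = "<reasoning>Need revenue per track.</reasoning>\n<query>SELECT 1</query>"
print(missing_stages(sample))
# ['analysis']
```

An empty result means every required stage was found; a non-empty one could be used to ask the model to try again.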


Here's a visual representation of the process:


[Figure: the four-stage thinking process]


Here's the complete prompt template:


query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
 
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why  
- Describe my expected outcome  
- Identify potential challenges  
- Justify my query structure  
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""


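That covers the main part of our agent's thinking process: the flow, and the detailed prompt that guides its reasoning. Now for the next part: building the LangGraph SQL agent.


First, let's look at the graph implementation: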


query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system),
    MessagesPlaceholder(variable_name="messages"),
])
query_gen_model = query_gen_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)

class State(TypedDict):
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)

def query_gen_node(state: State):
    return {"messages": [query_gen_model.invoke(state["messages"])]}

checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
    "query_gen",
    tools_condition,
    {"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)
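To make the control flow concrete, here's a dependency-free sketch of the loop this graph encodes: call the model, run any tools it asks for, feed the results back, and stop once it stops asking. It is not how LangGraph is implemented. `run_react_loop`, `fake_model`, and the plain-dict message format are stand-ins I made up for illustration; only the tool name `sql_db_list_tables` comes from the SQL toolkit.

```python
from typing import Callable, Dict, List


def run_react_loop(model: Callable[[List[dict]], dict],
                   tools: Dict[str, Callable[..., str]],
                   question: str,
                   max_turns: int = 10) -> List[dict]:
    """Alternate between the model and the tools until no tool is requested."""
    messages = [{"role": "user", "content": question}]
    for _ in range(max_turns):
        reply = model(messages)            # plays the role of the "query_gen" node
        messages.append(reply)
        calls = reply.get("tool_calls", [])
        if not calls:                      # no tool calls: route to END
            break
        for call in calls:                 # plays the role of "query_gen_tools"
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "name": call["name"], "content": result})
    return messages


def fake_model(messages: List[dict]) -> dict:
    # Scripted stand-in for the LLM: list the tables once, then answer.
    if not any(m["role"] == "tool" for m in messages):
        return {"role": "assistant", "content": "",
                "tool_calls": [{"name": "sql_db_list_tables", "args": {}}]}
    return {"role": "assistant",
            "content": "<query>SELECT Name FROM Track LIMIT 10</query>"}


fake_tools = {"sql_db_list_tables": lambda: "Track, InvoiceLine"}
history = run_react_loop(fake_model, fake_tools, "List some tracks")
print([m["role"] for m in history])
# ['user', 'assistant', 'tool', 'assistant']
```

The conditional edge in the graph makes the same decision as the `if not calls` branch, and the edge from "query_gen_tools" back to "query_gen" is the loop itself.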


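Now comes the key part: how we extract and process the thinking from the agent's responses:

  • Extract each thinking stage from the reasoning tags we defined
  • Format the output so it's readable
  • Capture the final SQL query as it's generated
  • Show the agent's thinking in real time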


def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
def process_event(event: Dict[str, Any]) -> Optional[str]:
    if 'query_gen' in event:
        messages = event['query_gen']['messages']
        for message in messages:
            content = message.content if hasattr(message, 'content') else ""
            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))
            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))
            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))
            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))
            if hasattr(message, 'tool_calls'):
                for tool_call in message.tool_calls:
                    tool_name = tool_call['name']
                    if tool_name == 'sql_db_query':
                        return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                # Prefer the SQL between triple backticks; fall back to the raw section
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                return sql_match.group(1).strip() if sql_match else query
    return None
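To see what the tag extraction does on its own, here's a small standalone run against a made-up response. `extract_section` is copied from the code above; `strip_sql_fence` and the sample text are hypothetical additions for illustration.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to keep the example readable


def extract_section(text: str, section: str) -> str:
    # Non-greedy match across newlines, same as in the article.
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""


def strip_sql_fence(query_section: str) -> str:
    """Return the SQL inside a sql code fence, or the text unchanged if there is none."""
    pattern = re.escape(FENCE) + r"sql\s*(.*?)\s*" + re.escape(FENCE)
    fenced = re.search(pattern, query_section, re.DOTALL)
    return fenced.group(1) if fenced else query_section


sample = (
    "<reasoning>\nI need total revenue per track.\n</reasoning>\n"
    f"<query>\n{FENCE}sql\nSELECT Name FROM Track LIMIT 10;\n{FENCE}\n</query>"
)

print(extract_section(sample, "reasoning"))
# I need total revenue per track.
print(strip_sql_fence(extract_section(sample, "query")))
# SELECT Name FROM Track LIMIT 10;
print(repr(extract_section(sample, "analysis")))
# ''
```

A missing tag comes back as an empty string, which is why the processing code can simply test each section for truthiness before printing it.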


To use it, we simply stream the results from graph.stream:


def run_query(query_text: str):
    print(f"\nAnalyzing: {query_text}")
    for event in graph.stream({"messages": [("user", query_text)]},
                              config={"configurable": {"thread_id": 12}}):
        if sql := process_event(event):
            print(f"\nGenerated SQL: {sql}")
            return sql
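This version of run_query stops at the first SQL it sees. If the agent revises its query after an error, you may want the last one instead. Below is a sketch of that idea run against fake events, so it works without a model or a database. It assumes each streamed event is a dict keyed by node name, which is the shape process_event already relies on; `last_sql_from_events` and the fake messages are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable, Optional


def last_sql_from_events(events: Iterable[Dict[str, Any]]) -> Optional[str]:
    """Scan streamed updates and keep the most recent sql_db_query tool call."""
    last_sql = None
    for event in events:
        for message in event.get("query_gen", {}).get("messages", []):
            for call in getattr(message, "tool_calls", None) or []:
                if call["name"] == "sql_db_query":
                    last_sql = call["args"]["query"]
    return last_sql


# Fake stream: the model lists tables, the tool answers, then the model queries.
fake_events = [
    {"query_gen": {"messages": [SimpleNamespace(content="", tool_calls=[
        {"name": "sql_db_list_tables", "args": {}}])]}},
    {"query_gen_tools": {"messages": [SimpleNamespace(content="Track, InvoiceLine")]}},
    {"query_gen": {"messages": [SimpleNamespace(content="", tool_calls=[
        {"name": "sql_db_query",
         "args": {"query": "SELECT Name FROM Track LIMIT 10"}}])]}},
]

print(last_sql_from_events(fake_events))
# SELECT Name FROM Track LIMIT 10
```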


Here's the complete code that makes it all work:


import getpass
import os
import re
from typing import Annotated, Any, Dict, Optional

from typing_extensions import TypedDict
from sqlalchemy import create_engine
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

def _set_env(key: str):
    # Prompt for the variable only when it is not already set,
    # so the API key never has to be written into the source file.
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}: ")

_set_env("OPENAI_API_KEY")
db_file = "chinook.db"
engine = create_engine(f"sqlite:///{db_file}")
db = SQLDatabase(engine=engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o-mini"))
sql_db_toolkit_tools = toolkit.get_tools()
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
 
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why  
- Describe my expected outcome  
- Identify potential challenges  
- Justify my query structure  
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system),
    MessagesPlaceholder(variable_name="messages"),
])
query_gen_model = query_gen_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)

class State(TypedDict):
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)

def query_gen_node(state: State):
    return {"messages": [query_gen_model.invoke(state["messages"])]}

checkpointer = MemorySaver()
graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
    "query_gen",
    tools_condition,
    {"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")
graph = graph_builder.compile(checkpointer=checkpointer)

def format_section(title: str, content: str) -> str:
    if not content:
        return ""
    return f"\n{content}\n"

def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def process_event(event: Dict[str, Any]) -> Optional[str]:
    if 'query_gen' in event:
        messages = event['query_gen']['messages']
        for message in messages:
            content = message.content if hasattr(message, 'content') else ""
            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))
            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))
            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))
            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))
            if hasattr(message, 'tool_calls'):
                for tool_call in message.tool_calls:
                    tool_name = tool_call['name']
                    if tool_name == 'sql_db_query':
                        return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                # Prefer the SQL between triple backticks; fall back to the raw section
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                return sql_match.group(1).strip() if sql_match else query
    return None

def run_query(query_text: str):
    print(f"\nAnalyzing your question: {query_text}")
    final_sql = None
    for event in graph.stream({"messages": [("user", query_text)]},
                              config={"configurable": {"thread_id": 12}}):
        sql = process_event(event)
        if sql:
            final_sql = sql
    if final_sql:
        print(
            "\nBased on my analysis, here's the SQL query that will answer your question:")
        print(f"\n{final_sql}")
        return final_sql

def interactive_sql():
    print("\nWelcome to the SQL Assistant! Type 'exit' to quit.")
    while True:
        try:
            query = input("\nWhat would you like to know? ")
            if query.lower() in ['exit', 'quit']:
                print("\nThank you for using SQL Assistant!")
                break
            run_query(query)
        except KeyboardInterrupt:
            print("\nThank you for using SQL Assistant!")
            break
        except Exception as e:
            print(f"\nAn error occurred: {str(e)}")
            print("Please try again with a different query.")

if __name__ == "__main__":
    interactive_sql()
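One thing to be aware of in the code above: the checkpointer stores conversation state per thread_id, and run_query always passes 12. Every question in an interactive session therefore lands in the same thread, with earlier turns carried along as context. If you'd prefer each question to start fresh, one option is a new thread ID per call. A minimal sketch, where `new_thread_config` is a hypothetical helper:

```python
import uuid


def new_thread_config() -> dict:
    """Build a config dict with a fresh thread_id so no earlier state is reused."""
    return {"configurable": {"thread_id": str(uuid.uuid4())}}


config = new_thread_config()
print(sorted(config["configurable"]))
# ['thread_id']
```

You would pass the result as the `config` argument to graph.stream in place of the fixed dictionary. Keeping a fixed ID is still the right choice when follow-up questions should build on earlier ones.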


I've tested this implementation with several models (GPT-4o, GPT-4o mini, and Claude 3.5 Haiku), and the results are promising. Here's an example of the thinking output:


What are the top 5 best-selling tracks by revenue?
Analyzing your question: What are the top 5 best-selling tracks by revenue?
To determine the top 5 best-selling tracks by revenue, I need to analyze the relevant tables that contain information about tracks and their sales. Typically, this would involve a "tracks" table that includes track details and a "sales" or "orders" table that records sales transactions.
My expected outcome is a list of the top 5 tracks sorted by total revenue generated from sales. The challenge here is to ensure that I correctly join the tables and aggregate the sales data to calculate the total revenue for each track.
I will structure the query to:
1. Join the "tracks" table with the "sales" table on the track ID.
2. Sum the revenue for each track.
3. Order the results by total revenue in descending order.
4. Limit the results to the top 5 tracks.
I will first check the database schema to confirm the table names and their relationships.

- Required tables: "tracks" and "sales" (or equivalent names).
- Important columns: Track ID, track name, and revenue from sales.
- Specific filters: None needed, but I will aggregate sales data.
- Proper ordering: By total revenue in descending order, limited to 5 results.
Now, I will check the database for the existing tables to confirm their names and structure.

Now that I have confirmed the relevant tables and their structures, I can proceed to construct the SQL query. The "Track" table contains information about each track, including its ID and price. The "InvoiceLine" table records each sale, linking to the "Track" table via the TrackId, and includes the quantity sold and unit price.
To calculate the total revenue for each track, I will:
1. Join the "Track" table with the "InvoiceLine" table on the TrackId.
2. Multiply the UnitPrice by the Quantity for each sale to get the revenue for that sale.
3. Sum the revenue for each track.
4. Order the results by total revenue in descending order.
5. Limit the results to the top 5 tracks.
This approach will ensure that I accurately capture the best-selling tracks by revenue.

- Required tables: "Track" and "InvoiceLine".
- Important columns: TrackId, Name (from Track), UnitPrice, Quantity (from InvoiceLine).
- Specific filters: None needed, as I want all tracks.
- Proper ordering: By total revenue in descending order, limited to 5 results.
Now, I will construct the SQL query based on this analysis.

- I included a clear reasoning section explaining the need for the query.
- I provided an analysis of the query structure, detailing the tables and columns involved.
- I executed the query and received results without errors.
The query successfully returned the top 5 best-selling tracks by revenue. Here are the results:
1. **The Woman King** - $3.98
2. **The Fix** - $3.98
3. **Walkabout** - $3.98
4. **Hot Girl** - $3.98
5. **Gay Witch Hunt** - $3.98
All tracks generated the same revenue, which indicates that they may have been sold in equal quantities or at the same price point. 
Everything is in order, and I have verified all steps.

Based on my analysis, here's the SQL query that will answer your question:
SELECT 
    t.TrackId, 
    t.Name, 
    SUM(il.UnitPrice * il.Quantity) AS TotalRevenue
FROM 
    Track t
JOIN 
    InvoiceLine il ON t.TrackId = il.TrackId
GROUP BY 
    t.TrackId, t.Name
ORDER BY 
    TotalRevenue DESC
LIMIT 5;
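If you want to check the shape of this query without the Chinook file, you can run it against a throwaway in-memory SQLite database. The sketch below uses the same table and column names as the generated SQL, but the three tracks and four invoice lines are toy rows I made up, not Chinook data, and the limit is lowered to 2 so the cut-off is visible.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE InvoiceLine (
        InvoiceLineId INTEGER PRIMARY KEY,
        TrackId INTEGER REFERENCES Track(TrackId),
        UnitPrice REAL,
        Quantity INTEGER
    );
    -- Toy rows for illustration only
    INSERT INTO Track VALUES (1, 'Alpha'), (2, 'Beta'), (3, 'Gamma');
    INSERT INTO InvoiceLine (TrackId, UnitPrice, Quantity) VALUES
        (1, 1.0, 1), (1, 1.0, 2), (2, 2.0, 2), (3, 0.5, 1);
""")

TOP_TRACKS_SQL = """
    SELECT t.TrackId, t.Name, SUM(il.UnitPrice * il.Quantity) AS TotalRevenue
    FROM Track t
    JOIN InvoiceLine il ON t.TrackId = il.TrackId
    GROUP BY t.TrackId, t.Name
    ORDER BY TotalRevenue DESC
    LIMIT 2;
"""

rows = conn.execute(TOP_TRACKS_SQL).fetchall()
print(rows)
# [(2, 'Beta', 4.0), (1, 'Alpha', 3.0)]
```

Alpha's two invoice lines are summed into one row (1.0 + 2.0), which is the aggregation the agent described in its reasoning.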


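As you can see, the reasoning is laid out clearly, with every thinking step visible. Rather than jumping straight to the answer, the agent shows its work at each step. Feel free to adapt this approach to your own use case!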

Source: https://medium.com/@yia333/implementing-reasoning-in-text-to-sql-agents-f979331176b4