One of the cool things to come out of the DeepSeek R1 release is that large language models (LLMs) now show "thinking" (<think>) markers in their responses, similar to ChatGPT's o1 and o3-mini models. Encouraging an LLM to think more deeply has plenty of benefits.
So here's what I've built: an agent that brings a similar reasoning process (chain-of-thought (CoT) responses) to a LangGraph SQL agent with tool calling. It's a "reason + act" (ReAct) agent that combines LangGraph's SQL toolkit with graph-based execution. Here's how it works.
Now, let's walk through the thinking process.
The agent starts with a system prompt that gives structure to its thinking process.
I've mapped out the exact steps our SQL agent goes through, from receiving a question to returning the final query:
The four-stage thinking process:
1. Reasoning stage (<reasoning> tags)
2. Analysis stage (<analysis> tags)
3. Query stage (<query> tags)
4. Validation stage (<error_check> and <final_check> tags)
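To make this concrete, here's a minimal sketch of what a response following this structure looks like (the content is invented for illustration; only the tag layout matters, and <error_check> appears only when a query fails):

<reasoning>
I need the Track and InvoiceLine tables to compute revenue per track.
</reasoning>
<analysis>
Join Track to InvoiceLine on TrackId, sum UnitPrice * Quantity, order by the total.
</analysis>
<query>
SELECT ... LIMIT 5;
</query>
<final_check>
Reasoning and analysis sections are present; the query was verified.
</final_check>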
Here's a visual representation of the process:
Here's the complete prompt template:
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why
- Describe my expected outcome
- Identify potential challenges
- Justify my query structure
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
The main pieces of our agent's thinking process are now in place: we've covered the flow and the detailed prompt that guides its reasoning. Next up is building the LangGraph SQL agent itself.
First, let's look at the graph implementation:
query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system),
    MessagesPlaceholder(variable_name="messages"),
])

query_gen_model = query_gen_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)

class State(TypedDict):
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)

def query_gen_node(state: State):
    return {"messages": [query_gen_model.invoke(state["messages"])]}

checkpointer = MemorySaver()

graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)

graph_builder.add_conditional_edges(
    "query_gen",
    tools_condition,
    {"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")

graph = graph_builder.compile(checkpointer=checkpointer)
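A quick note on the routing: LangGraph's prebuilt tools_condition inspects the last message in the state and routes to the tool node whenever the model requested a tool call, and to END otherwise. Conceptually it behaves something like this simplified sketch (illustrative only, not the actual library source):

def route(state: State):
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "tools"  # hand off to the ToolNode
    return END  # no tool call, so the agent is finished

The edge from query_gen_tools back to query_gen is what creates the ReAct loop: tool results flow back into the model, which keeps reasoning until it produces a final answer.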
Now comes the key part: how we extract and process the thinking from the agent's responses:
def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def process_event(event: Dict[str, Any]) -> Optional[str]:
    if 'query_gen' in event:
        messages = event['query_gen']['messages']
        for message in messages:
            content = message.content if hasattr(message, 'content') else ""
            # Print each thinking section as it appears
            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))
            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))
            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))
            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))
            # If the model called the query tool, return the SQL it ran
            if hasattr(message, 'tool_calls'):
                for tool_call in message.tool_calls:
                    tool_name = tool_call['name']
                    if tool_name == 'sql_db_query':
                        return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                # Try to extract SQL between triple backticks
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                if sql_match:
                    # Return just the SQL, not the whole <query> section
                    return sql_match.group(1).strip()
    return None
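To make the parsing concrete, here's a quick standalone check of extract_section on a made-up response fragment:

sample = """<reasoning>
I need the Track and InvoiceLine tables to compute revenue per track.
</reasoning>"""

print(extract_section(sample, "reasoning"))
# -> I need the Track and InvoiceLine tables to compute revenue per track.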
To use it, we simply stream results from graph.stream:
def run_query(query_text: str):
    print(f"\nAnalyzing: {query_text}")
    for event in graph.stream({"messages": [("user", query_text)]},
                              config={"configurable": {"thread_id": 12}}):
        if sql := process_event(event):
            print(f"\nGenerated SQL: {sql}")
            return sql
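Calling it looks like this; the thinking sections print as they stream in, and the generated SQL comes back as the return value:

sql = run_query("What are the top 5 best-selling tracks by revenue?")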
Here's the complete code that makes it all work:
import getpass
import os
import re
from typing import Annotated, Any, Dict, Optional
from typing_extensions import TypedDict
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
def _set_env(key: str):
    # Prompt for the key only if it isn't already set in the environment
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}: ")

_set_env("OPENAI_API_KEY")
db_file = "chinook.db"
engine = create_engine(f"sqlite:///{db_file}")
db = SQLDatabase(engine=engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o-mini"))
sql_db_toolkit_tools = toolkit.get_tools()
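# The toolkit exposes the standard SQL tools: sql_db_list_tables,
# sql_db_schema, sql_db_query_checker, and sql_db_query.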
query_gen_system = """
I am an SQL expert who helps analyze database queries. I have access to tools for interacting with the database. When given a question, I'll think through it carefully and explain my reasoning in natural language.
Then I'll walk through my analysis process:
1. First, I'll understand what tables and data I need
2. Then, I'll verify the schema and relationships
3. Finally, I'll construct an appropriate SQL query
For each query, I'll think about:
- What tables are involved and how they connect
- Any special conditions or filters needed
- How to handle potential edge cases
- The most efficient way to get the results
<reasoning>
I will **always** include this section before writing a query. Here, I will:
- Explain what information I need and why
- Describe my expected outcome
- Identify potential challenges
- Justify my query structure
If this section is missing, I will rewrite my response to include it.
</reasoning>
<analysis>
Here I break down the key components needed for the query:
- Required tables and joins
- Important columns and calculations
- Any specific filters or conditions
- Proper ordering and grouping
</analysis>
<query>
The final SQL query
</query>
<error_check>
If there's an error, I'll explain:
- What went wrong
- Why it happened
- How to fix it
</error_check>
<final_check>
Before finalizing, I will verify:
- Did I include a clear reasoning section?
- Did I explain my approach before querying?
- Did I provide an analysis of the query structure?
- If any of these are missing, I will revise my response.
</final_check>
Important rules:
1. Only use SELECT statements, no modifications
2. Verify all schema assumptions
3. Use proper SQLite syntax
4. Limit results to 10 unless specified
5. Double-check all joins and conditions
6. Always include tool_analysis and tool_reasoning for each tool call
"""
query_gen_prompt = ChatPromptTemplate.from_messages([
    ("system", query_gen_system),
    MessagesPlaceholder(variable_name="messages"),
])

query_gen_model = query_gen_prompt | ChatOpenAI(
    model="gpt-4o-mini", temperature=0).bind_tools(tools=sql_db_toolkit_tools)

class State(TypedDict):
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)

def query_gen_node(state: State):
    return {"messages": [query_gen_model.invoke(state["messages"])]}

checkpointer = MemorySaver()

graph_builder.add_node("query_gen", query_gen_node)
query_gen_tools_node = ToolNode(tools=sql_db_toolkit_tools)
graph_builder.add_node("query_gen_tools", query_gen_tools_node)
graph_builder.add_conditional_edges(
    "query_gen",
    tools_condition,
    {"tools": "query_gen_tools", END: END},
)
graph_builder.add_edge("query_gen_tools", "query_gen")
graph_builder.set_entry_point("query_gen")

graph = graph_builder.compile(checkpointer=checkpointer)
def format_section(title: str, content: str) -> str:
    if not content:
        return ""
    return f"\n{content}\n"

def extract_section(text: str, section: str) -> str:
    pattern = f"<{section}>(.*?)</{section}>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""
def process_event(event: Dict[str, Any]) -> Optional[str]:
    if 'query_gen' in event:
        messages = event['query_gen']['messages']
        for message in messages:
            content = message.content if hasattr(message, 'content') else ""
            reasoning = extract_section(content, "reasoning")
            if reasoning:
                print(format_section("", reasoning))
            analysis = extract_section(content, "analysis")
            if analysis:
                print(format_section("", analysis))
            error_check = extract_section(content, "error_check")
            if error_check:
                print(format_section("", error_check))
            final_check = extract_section(content, "final_check")
            if final_check:
                print(format_section("", final_check))
            if hasattr(message, 'tool_calls'):
                for tool_call in message.tool_calls:
                    tool_name = tool_call['name']
                    if tool_name == 'sql_db_query':
                        return tool_call['args']['query']
            query = extract_section(content, "query")
            if query:
                sql_match = re.search(
                    r'```sql\n(.*?)\n```', query, re.DOTALL)
                if sql_match:
                    # Return just the SQL, not the whole <query> section
                    return sql_match.group(1).strip()
    return None
def run_query(query_text: str):
    print(f"\nAnalyzing your question: {query_text}")
    final_sql = None
    for event in graph.stream({"messages": [("user", query_text)]},
                              config={"configurable": {"thread_id": 12}}):
        sql = process_event(event)
        if sql:
            final_sql = sql
    if final_sql:
        print(
            "\nBased on my analysis, here's the SQL query that will answer your question:")
        print(f"\n{final_sql}")
    return final_sql
def interactive_sql():
    print("\nWelcome to the SQL Assistant! Type 'exit' to quit.")
    while True:
        try:
            query = input("\nWhat would you like to know? ")
            if query.lower() in ['exit', 'quit']:
                print("\nThank you for using SQL Assistant!")
                break
            run_query(query)
        except KeyboardInterrupt:
            print("\nThank you for using SQL Assistant!")
            break
        except Exception as e:
            print(f"\nAn error occurred: {str(e)}")
            print("Please try again with a different query.")

if __name__ == "__main__":
    interactive_sql()
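Run the script (I'm assuming you saved it as sql_assistant.py; any name works) and you get an interactive loop along these lines:

$ python sql_assistant.py

Welcome to the SQL Assistant! Type 'exit' to quit.

What would you like to know?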
I've tested this implementation with several models (gpt-4o, gpt-4o-mini, Claude 3.5 Haiku), and the results are promising. Here's an example of the thinking output:
What are the top 5 best-selling tracks by revenue?
Analyzing your question: What are the top 5 best-selling tracks by revenue?
To determine the top 5 best-selling tracks by revenue, I need to analyze the relevant tables that contain information about tracks and their sales. Typically, this would involve a "tracks" table that includes track details and a "sales" or "orders" table that records sales transactions.
My expected outcome is a list of the top 5 tracks sorted by total revenue generated from sales. The challenge here is to ensure that I correctly join the tables and aggregate the sales data to calculate the total revenue for each track.
I will structure the query to:
1. Join the "tracks" table with the "sales" table on the track ID.
2. Sum the revenue for each track.
3. Order the results by total revenue in descending order.
4. Limit the results to the top 5 tracks.
I will first check the database schema to confirm the table names and their relationships.
- Required tables: "tracks" and "sales" (or equivalent names).
- Important columns: Track ID, track name, and revenue from sales.
- Specific filters: None needed, but I will aggregate sales data.
- Proper ordering: By total revenue in descending order, limited to 5 results.
Now, I will check the database for the existing tables to confirm their names and structure.
Now that I have confirmed the relevant tables and their structures, I can proceed to construct the SQL query. The "Track" table contains information about each track, including its ID and price. The "InvoiceLine" table records each sale, linking to the "Track" table via the TrackId, and includes the quantity sold and unit price.
To calculate the total revenue for each track, I will:
1. Join the "Track" table with the "InvoiceLine" table on the TrackId.
2. Multiply the UnitPrice by the Quantity for each sale to get the revenue for that sale.
3. Sum the revenue for each track.
4. Order the results by total revenue in descending order.
5. Limit the results to the top 5 tracks.
This approach will ensure that I accurately capture the best-selling tracks by revenue.
- Required tables: "Track" and "InvoiceLine".
- Important columns: TrackId, Name (from Track), UnitPrice, Quantity (from InvoiceLine).
- Specific filters: None needed, as I want all tracks.
- Proper ordering: By total revenue in descending order, limited to 5 results.
Now, I will construct the SQL query based on this analysis.
- I included a clear reasoning section explaining the need for the query.
- I provided an analysis of the query structure, detailing the tables and columns involved.
- I executed the query and received results without errors.
The query successfully returned the top 5 best-selling tracks by revenue. Here are the results:
1. **The Woman King** - $3.98
2. **The Fix** - $3.98
3. **Walkabout** - $3.98
4. **Hot Girl** - $3.98
5. **Gay Witch Hunt** - $3.98
All tracks generated the same revenue, which indicates that they may have been sold in equal quantities or at the same price point.
Everything is in order, and I have verified all steps.
Based on my analysis, here's the SQL query that will answer your question:
SELECT
    t.TrackId,
    t.Name,
    SUM(il.UnitPrice * il.Quantity) AS TotalRevenue
FROM
    Track t
JOIN
    InvoiceLine il ON t.TrackId = il.TrackId
GROUP BY
    t.TrackId, t.Name
ORDER BY
    TotalRevenue DESC
LIMIT 5;
As you can see, the reasoning process comes through clearly, laying out every thinking step. The output shows how our agent thinks, displaying its work at each stage rather than jumping straight to an answer. Feel free to adapt this approach to your own use case!