Building a Self-Healing SQL Query Generator Agent with Pydantic.ai and Groq

Published December 23, 2024 by alex

Introduction

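In today's data-driven world, the ability to query databases efficiently is essential. Writing SQL, however, can be challenging, especially for people unfamiliar with database operations. This article presents a self-healing SQL query generator that turns natural-language requests into SQL queries and, when a generated query fails, retries with the database error as feedback. The result makes database interaction easier and more reliable.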


What is Pydantic AI?

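Pydantic AI is an agent framework designed to simplify building and managing AI agents. It is built on top of Pydantic, the library widely used in Python projects for data validation based on type hints.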


Key features and advantages


Ease of use

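Simplicity is one of Pydantic AI's standout traits. Creating an agent takes only a few lines of code: you specify a model name and a system prompt, which makes it easy even for newcomers to get started.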


Built on plain Python

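Unlike some frameworks that abstract away most of the underlying code, Pydantic AI is built on plain Python. This gives developers full control over, and visibility into, how their agents work, which makes customization and debugging easier when needed.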


Dependency injection

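Pydantic AI incorporates software-engineering best practices such as dependency injection: resources an agent needs, such as a database connection, are passed in when the agent runs instead of being created inside its tools. This makes code more flexible and easier to maintain and test, especially in larger projects.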
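In Pydantic AI, the agent declares a `deps_type`, and tools and system-prompt functions receive the dependencies through a `RunContext` argument as `ctx.deps`; the Deps dataclass later in this article plays that role. The sketch below shows the underlying pattern with the standard library only, so `ToolContext`, `Deps` and `count_users` here are hypothetical illustrations, not Pydantic AI APIs. Because the tool receives its connection instead of opening one, a test can inject a throwaway in-memory SQLite database.

```python
import sqlite3
from dataclasses import dataclass
from typing import Generic, TypeVar

DepsT = TypeVar("DepsT")


@dataclass
class ToolContext(Generic[DepsT]):
    """Hypothetical stand-in for the context object a framework passes to tools."""
    deps: DepsT


@dataclass
class Deps:
    conn: sqlite3.Connection


def count_users(ctx: ToolContext[Deps]) -> int:
    # The tool never opens its own connection; it uses whatever was injected.
    return ctx.deps.conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]


# In a test, inject an in-memory database instead of the production one.
test_conn = sqlite3.connect(":memory:")
test_conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
test_conn.executemany("INSERT INTO users (name) VALUES (?)", [("John Doe",), ("Jane Smith",)])
print(count_users(ToolContext(Deps(conn=test_conn))))  # prints 2
```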


Integration with observability tools

For those who care about observability, Pydantic AI integrates with Logfire, which can trace prompts, costs, and other important metrics for your agent.


Function calling and structured output

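Pydantic AI excels at function calling and structured output. You define models that describe the agent's response, and the output is validated against them so that it meets your specific requirements. This is especially useful when working with APIs or other systems that expect data in a particular format.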
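The Success and InvalidRequest models in the code walkthrough are this article's structured output types, told apart by their `type` field, and Pydantic AI validates the model's reply against them (in the output at the end, replies arrive as a call to a tool named `final_result_Success`). As a rough, stdlib-only sketch of what that validation step amounts to, the hypothetical `parse_response` function below parses a JSON reply, dispatches on `type`, and rejects malformed data. It is an illustration of the idea, not how Pydantic AI is implemented.

```python
import json
from dataclasses import dataclass
from typing import Union


@dataclass
class Success:
    sql_query: str
    explanation: str


@dataclass
class InvalidRequest:
    error_message: str


Response = Union[Success, InvalidRequest]


def parse_response(raw: str) -> Response:
    """Validate a JSON reply and dispatch on its "type" field, rejecting anything malformed."""
    data = json.loads(raw)
    if not isinstance(data, dict):
        raise ValueError("reply must be a JSON object")
    kind = data.get("type")
    if kind == "Success":
        query, explanation = data.get("sql_query"), data.get("explanation")
        if not isinstance(query, str) or not query:
            raise ValueError("sql_query must be a non-empty string")
        if not isinstance(explanation, str):
            raise ValueError("explanation must be a string")
        return Success(query, explanation)
    if kind == "InvalidRequest":
        message = data.get("error_message")
        if not isinstance(message, str):
            raise ValueError("error_message must be a string")
        return InvalidRequest(message)
    raise ValueError(f"unknown response type: {kind!r}")


reply = '{"type": "Success", "sql_query": "SELECT COUNT(*) FROM users", "explanation": "Counts users."}'
print(parse_response(reply))
```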


Cost tracking

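Understanding your agent's resource usage matters, especially in production. Pydantic AI reports token usage for every run (number of requests, request tokens, response tokens, and the total), which you can use to monitor consumption and estimate the associated cost.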
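The reported figures are token counts, so turning them into money is simple arithmetic. The sketch below sums the request and response tokens of the five successful runs shown in the output at the end of this article and applies placeholder per-token prices; the prices and the `RunUsage` class are assumptions for illustration, so check your provider's current price list. Note that usage is reported per run: the first, failed attempt of the draft-posts query was a separate run, so its tokens are not included in these numbers, and a caller such as query_database would have to add them up itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunUsage:
    """Hypothetical container for the per-run token counts printed in the output log."""
    request_tokens: int
    response_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.request_tokens + self.response_tokens


# Request/response tokens of the five successful runs in the output log
runs = [RunUsage(630, 103), RunUsage(626, 68), RunUsage(654, 90), RunUsage(627, 65), RunUsage(628, 44)]

# ASSUMED placeholder prices in USD per 1M tokens; not taken from the article
PRICE_IN_PER_M = 0.15
PRICE_OUT_PER_M = 0.60


def estimate_cost(usages: list) -> float:
    request = sum(u.request_tokens for u in usages)
    response = sum(u.response_tokens for u in usages)
    return (request * PRICE_IN_PER_M + response * PRICE_OUT_PER_M) / 1_000_000


total_tokens = sum(u.total_tokens for u in runs)
print(total_tokens)
print(f"${estimate_cost(runs):.6f}")
```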


Tech stack

This implementation uses the following technologies:

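  • Pydantic.ai: for building AI agents with strong type safety
  • Groq: a fast, efficient LLM provider
  • SQLite: a lightweight, serverless database
  • aiosqlite: an asynchronous SQLite driver for Python
  • Python 3.11+: the code is written with async/await throughout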


Workflow

The SQL agent follows a multi-step workflow:

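  • User input: receive a natural-language query request
  • Query generation: use the LLM to convert the request into SQL
  • Validation: check that the query is syntactically and structurally correct
  • Execution: test the query against the database
  • Self-healing: if execution fails, retry with the error as context
  • Result delivery: return the results or a detailed error message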
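The loop behind these steps can be sketched without any LLM at all. In the minimal sketch below, the hypothetical `fake_llm` function stands in for the agent: it first guesses a column that does not exist, and once it is shown the SQLite error it produces a working query. Validation and execution are merged into one `execute` call for brevity, and the in-memory schema is a trimmed-down assumption, not the article's full DB_SCHEMA.

```python
import sqlite3
from typing import Callable, List, Optional, Tuple

SCHEMA = """
CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER, title TEXT NOT NULL,
                    published BOOLEAN DEFAULT 0);
"""


def fake_llm(request: str, error: Optional[str]) -> str:
    """Hypothetical stand-in for the LLM: guesses a wrong column first, fixes it once shown the error."""
    if error is None:
        return "SELECT title FROM posts WHERE status = 'draft'"
    return "SELECT title FROM posts WHERE published = 0"


def self_healing_query(
    conn: sqlite3.Connection,
    request: str,
    generate: Callable[[str, Optional[str]], str],
    max_retries: int = 3,
) -> Tuple[str, List[tuple], int]:
    error: Optional[str] = None
    for attempt in range(1, max_retries + 1):
        sql = generate(request, error)  # query generation (error is None on the first try)
        try:
            rows = conn.execute(sql).fetchall()  # validation and execution in one step
            return sql, rows, attempt  # result delivery
        except sqlite3.Error as exc:
            error = str(exc)  # self-healing: the next attempt sees this error
    raise RuntimeError(f"no working query after {max_retries} attempts; last error: {error}")


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)
conn.executemany(
    "INSERT INTO posts (user_id, title, published) VALUES (?, ?, ?)",
    [(1, "First Post", 1), (1, "Draft Post", 0)],
)
sql, rows, attempts = self_healing_query(conn, "show me all draft posts", fake_llm)
print(attempts, sql, rows)
```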


Code walkthrough


Core components

The implementation consists of several key components:


# Models for type safety
class Success(BaseModel):
    type: str = Field("Success", pattern="^Success$")
    sql_query: Annotated[str, MinLen(1)]
    explanation: str

class InvalidRequest(BaseModel):
    type: str = Field("InvalidRequest", pattern="^InvalidRequest$")
    error_message: str

# Result type of the agent: either a generated query or a refusal
Response: TypeAlias = Union[Success, InvalidRequest]


Defining the dependencies


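@dataclass
class Deps:
    conn: aiosqlite.Connection  # open connection used to test generated queries
    db_schema: str = DB_SCHEMA  # schema text; DB_SCHEMA is defined in the implementation section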


Agent configuration


The agent combines the model, the Deps dependency type, the Response result type, and a system prompt that describes the database schema and the expected JSON reply. Its full definition appears in the implementation section below, right after the database schema is defined, because the prompt embeds DB_SCHEMA.


Query execution and retry logic


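async def query_database(prompt: str, conn: aiosqlite.Connection):
    """Generate SQL for `prompt`, test it against the database, and retry with the error on failure.

    Returns the agent's RunResult on success, or an InvalidRequest if no working query was produced.
    """
    max_retries = 3
    last_error: Optional[str] = None
    current_prompt = prompt  # keep the original request intact for the retry prompts
    deps = Deps(conn=conn)

    for _ in range(max_retries):
        try:
            result = await sqlagent.run(current_prompt, deps=deps)
            # The model may legitimately decline; InvalidRequest has no sql_query, so return it as is
            if isinstance(result.data, InvalidRequest):
                return result.data
            success, error = await execute_query(conn, result.data.sql_query)
            if success:
                return result
            last_error = error
        except Exception as e:
            last_error = str(e)
        # Self-healing step: feed the failure back to the model on the next attempt
        current_prompt = f"""
Previous query failed with error: {last_error}
Please generate a corrected SQL query for the original request: {prompt}
"""

    return InvalidRequest(
        error_message=f"Failed to generate valid query after {max_retries} attempts. Last error: {last_error}"
    )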


Implementation

Install the required dependencies


%pip install 'pydantic-ai-slim[openai,groq,logfire]'
%pip install aiosqlite


Set the API keys


from google.colab import userdata
import os
os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
os.environ["GROQ_API_KEY"] = userdata.get('GROQ_API_KEY')


Import the required dependencies


from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.models.groq import GroqModel
import aiosqlite
import asyncio
from typing import Union, TypeAlias, Annotated, Optional, Tuple
from dataclasses import dataclass
from pydantic import BaseModel, Field
from annotated_types import MinLen
from pydantic_ai import Agent, ModelRetry, RunContext


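Google Colab runs on a Jupyter notebook backend that already has an active event loop managing interaction and output. Without patching, any attempt to start another event loop, such as calling asyncio.run(), fails with a RuntimeError.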


What nest_asyncio does

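  • Patching the event loop: the nest_asyncio library modifies asyncio so that its functions can be called re-entrantly. Calling nest_asyncio.apply() makes the code reuse the event loop that is already running instead of trying to create a new one, so async tasks run smoothly in Colab.
  • Usage: in a Colab notebook, import nest_asyncio and apply it right after the import: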


import nest_asyncio
nest_asyncio.apply()
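To see the conflict that nest_asyncio works around, the stdlib-only sketch below calls asyncio.run() from inside a coroutine that is already running on an event loop, which is the situation a notebook cell is in. The call is refused with a RuntimeError; within a running loop, the unpatched alternative is simply to await the coroutine. As a hedged aside, IPython-based notebooks generally also accept a top-level `await main()` in a cell, which avoids the need for nest_asyncio.

```python
import asyncio


async def inner() -> str:
    return "done"


async def outer() -> str:
    # Inside a running loop (like a notebook cell), asyncio.run() refuses to start a second loop.
    coro = inner()
    try:
        return asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # the coroutine never ran; close it to avoid a "never awaited" warning
        print(f"RuntimeError: {exc}")
    # Without patching, the working alternative is to await on the loop that is already running.
    return await inner()


print(asyncio.run(outer()))
```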


Instantiate the LLMs


openai_model = OpenAIModel('gpt-4o-mini')
# The agent below uses openai_model; pass groq_model instead to run it on Groq
groq_model = GroqModel("llama3-groq-70b-8192-tool-use-preview")


Define the database schema (DB_SCHEMA)


# Define the schema for our example database
DB_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    email TEXT UNIQUE NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS posts (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    user_id INTEGER,
    title TEXT NOT NULL,
    content TEXT,
    published BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (user_id) REFERENCES users(id)
);
"""

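Before wiring the schema into the agent, it can be worth checking that DB_SCHEMA actually executes and that the columns are what the prompt promises. The stdlib sketch below loads a copy of the schema into an in-memory database and reads the columns of `posts` through `pragma_table_info()`, SQLite's table-valued form of `PRAGMA table_info` (assumed available, as it is in SQLite 3.16 and later). Because it is a plain SELECT, a query like this is also one possible way to answer a structural request such as "show me the structure of the posts" while respecting the prompt's SELECT-only rule.

```python
import sqlite3

# Copy of DB_SCHEMA from above, repeated so that this sketch runs on its own
DB_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    email TEXT UNIQUE NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS posts (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    user_id INTEGER,
    title TEXT NOT NULL,
    content TEXT,
    published BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (user_id) REFERENCES users(id)
);
"""

conn = sqlite3.connect(":memory:")
conn.executescript(DB_SCHEMA)

# pragma_table_info() is table-valued, so an ordinary SELECT can read the table structure
columns = [row[0] for row in conn.execute("SELECT name FROM pragma_table_info('posts') ORDER BY cid")]
print(columns)
```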

Defining the agent

The agent is PydanticAI's primary interface for interacting with large language models (LLMs).


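In some use cases a single agent controls an entire application or component, but multiple agents can also interact with each other to implement more complex workflows.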


The Agent class has full API documentation, but conceptually you can think of an agent as a container for:

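  • A system prompt: a set of instructions for the LLM, written by the developer.
  • One or more function tools: functions the LLM may call to obtain information while generating its response.
  • An optional structured result type: the structured data type the LLM must return at the end of a run.
  • A dependency type constraint: system prompt functions, tools, and result validators may all use dependencies when they run.
  • Optionally, a default LLM associated with the agent; the model to use can also be specified when running the agent.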


sqlagent = Agent(
    openai_model,  # swap in groq_model to run the agent on Groq instead
    deps_type=Deps,
    retries=3,
    result_type=Response,
    # f-string so that DB_SCHEMA is actually interpolated; {{ and }} render as literal braces
    system_prompt=f"""You are a proficient Database Administrator having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.
You must respond with a Success object containing a sql_query and an explanation.
Database schema:
{DB_SCHEMA}
Format your response exactly like this, with no additional text or formatting:
{{
    "type": "Success",
    "sql_query": "<your SQL query here>",
    "explanation": "<your explanation here>"
}}
Examples:
    User: show me all users who have published posts
    {{
        "type": "Success",
        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",
        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."
    }}
    User: count posts by user
    {{
        "type": "Success",
        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",
        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."
    }}
    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.
    Common fixes include:
    - Correcting column names
    - Fixing JOIN conditions
    - Adjusting GROUP BY clauses
    - Handling NULL values properly
If you cannot generate a valid query, respond with:
{{
    "type": "InvalidRequest",
    "error_message": "<explanation of why the request cannot be processed>"
}}
Important:
1. Respond with ONLY the JSON object, no additional text
2. Always include the "type" field as either "Success" or "InvalidRequest"
3. All queries must be SELECT statements
4. Provide clear explanations
5. Use proper JOIN conditions and WHERE clauses as needed
""",
)


Initialize the database and open a connection


async def init_database(db_path: str = "test.db") -> aiosqlite.Connection:
    """Initialize the database with schema"""
    conn = await aiosqlite.connect(db_path)
    
    # Enable foreign keys
    await conn.execute("PRAGMA foreign_keys = ON")
    
    # Create schema
    await conn.executescript(DB_SCHEMA)
    
    # Add some sample data if the tables are empty
    async with conn.execute("SELECT COUNT(*) FROM users") as cursor:
        count = await cursor.fetchone()
        if count[0] == 0:
            sample_data = """
            INSERT INTO users (name, email) VALUES 
                ('John Doe', 'john@example.com'),
                ('Jane Smith', 'jane@example.com');
                
            INSERT INTO posts (user_id, title, content, published) VALUES 
                (1, 'First Post', 'Hello World', TRUE),
                (1, 'Draft Post', 'Work in Progress', FALSE),
                (2, 'Jane''s Post', 'Hello from Jane', TRUE);
            """
            await conn.executescript(sample_data)
            await conn.commit()
    
    return conn


Helper function to execute a query


from typing import Tuple, Optional
async def execute_query(conn: aiosqlite.Connection, query: str) -> Tuple[bool, Optional[str]]:
    """
    Execute a SQL query and return success status and error message if any.
    Returns: (success: bool, error_message: Optional[str])
    """
    try:
        async with conn.execute(query) as cursor:
            await cursor.fetchone()
        return True, None
    except Exception as e:
        return False, str(e)
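execute_query only fetches the first row, which is enough to surface SQLite errors such as the "no such column: posts.status" seen in the output log. Note, though, that rule 3 of the system prompt ("All queries must be SELECT statements") is not enforced anywhere in the code. The synchronous stdlib sketch below mirrors execute_query and, as an added assumption, turns on SQLite's `PRAGMA query_only` so that write statements fail at the database level. Since the pragma is a SQLite setting, issuing it on the aiosqlite connection after init_database has seeded the data should presumably work the same way; `try_query` is a hypothetical name.

```python
import sqlite3
from typing import Optional, Tuple


def try_query(conn: sqlite3.Connection, query: str) -> Tuple[bool, Optional[str]]:
    """Synchronous analogue of execute_query: run the query, report success or the error text."""
    try:
        conn.execute(query).fetchone()
        return True, None
    except sqlite3.Error as exc:
        return False, str(exc)


conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT, published BOOLEAN DEFAULT 0);"
    "INSERT INTO posts (title, published) VALUES ('First Post', 1);"
)
# Assumption: enforce the prompt's SELECT-only rule in the database itself, not just in the prompt
conn.execute("PRAGMA query_only = ON")

print(try_query(conn, "SELECT title FROM posts WHERE published = 1"))  # (True, None)
print(try_query(conn, "SELECT status FROM posts"))  # fails: no such column
print(try_query(conn, "DELETE FROM posts"))  # fails: writes are refused while query_only is on
```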


Helper that calls the SQL agent and processes the test queries


async def main():
    # The agent uses openai_model and the notebook also instantiates groq_model, so check both keys
    for key in ("OPENAI_API_KEY", "GROQ_API_KEY"):
        if not os.getenv(key):
            raise ValueError(f"Please set the {key} environment variable")
    # Initialize database
    conn = await init_database("test.db")

    try:
        # Example queries to test
        test_queries = [
            "show me all users and the number of posts posted",
            "find users who have published posts",
            "show me all draft posts with their authors",
            "what is the count of users table",
            "show me the title of the posts published",
            "show me the structure of the posts",
            "show me the names of all the users"
        ]
        for query in test_queries:
            print(f"\nProcessing query: {query}")
            result = await query_database(query, conn)
            print(f"\nProcessing query result: {result}")
            if isinstance(result, InvalidRequest):
                print(f"Error: {result.error_message}")
            else:
                print("\n✅ Generated SQL:")
                print(result.data.sql_query)
                print("\n✅ Explanation:")
                print(result.data.explanation)
                print("\n✅ Cost:")
                print(result.usage())  # token usage of this run

                # Execute the query to show results
                try:
                    async with conn.execute(result.data.sql_query) as cursor:
                        rows = await cursor.fetchall()
                    print("\nResults:")
                    for row in rows:
                        print(row)
                except Exception as e:
                    print(f"Error executing query {query!r}: {e}")

            print("\n" + "=" * 50)
    finally:
        await conn.close()


Run the SQL agent


asyncio.run(main()) 

##################################RESPONSE####################################
Processing query: show me all users and the number of posts posted
Processing query result: RunResult(_all_messages=[ModelRequest(parts=[SystemPromptPart(content='You are a proficient Database Administrator  having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.\nYou must respond with a Success object containing a sql_query and an explanation.\n\nDatabase schema:\n{DB_SCHEMA}\n\nFormat your response exactly like this, with no additional text or formatting:\n{{\n    "type": "Success",\n    "sql_query": "<your SQL query here>",\n    "explanation": "<your explanation here>"\n}}\n\nExamples:\n    User: show me all users who have published posts\n    {{\n        "type": "Success",\n        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",\n        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."\n    }}\n\n    User: count posts by user\n    {{\n        "type": "Success",\n        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",\n        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."\n    }}\n\n    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.\n    Common fixes include:\n    - Correcting column names\n    - Fixing JOIN conditions\n    - Adjusting GROUP BY clauses\n    - Handling NULL values properly\n\nIf you cannot generate a valid query, respond with:\n{{\n    "type": "InvalidRequest",\n    "error_message": "<explanation of why the request cannot be processed>"\n}}\n\nImportant:\n1. Respond with ONLY the JSON object, no additional text\n2. Always include the "type" field as either "Success" or "InvalidRequest"\n3. All queries must be SELECT statements\n4. Provide clear explanations\n5. 
Use proper JOIN conditions and WHERE clauses as needed\n', part_kind='system-prompt'), UserPromptPart(content='show me all users and the number of posts posted', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 29, 420161, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'), ModelResponse(parts=[ToolCallPart(tool_name='final_result_Success', args=ArgsJson(args_json='{"type":"Success","sql_query":"SELECT users.id, users.name, COUNT(posts.id) AS post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name","explanation":"This query retrieves all users along with a count of the posts they have made. It uses a LEFT JOIN to include users even if they have not posted anything, ensuring that users with a post count of zero are still displayed."}'), tool_call_id='call_hFetDeIfPDo92NQNtyRJBuAS', part_kind='tool-call')], timestamp=datetime.datetime(2024, 12, 21, 16, 25, 29, tzinfo=datetime.timezone.utc), kind='response'), ModelRequest(parts=[ToolReturnPart(tool_name='final_result_Success', content='Final result processed.', tool_call_id='call_hFetDeIfPDo92NQNtyRJBuAS', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 31, 17975, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')], _new_message_index=0, data=Success(type='Success', sql_query='SELECT users.id, users.name, COUNT(posts.id) AS post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name', explanation='This query retrieves all users along with a count of the posts they have made. It uses a LEFT JOIN to include users even if they have not posted anything, ensuring that users with a post count of zero are still displayed.'), _usage=Usage(requests=1, request_tokens=630, response_tokens=103, total_tokens=733, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
✅ Generated SQL:
SELECT users.id, users.name, COUNT(posts.id) AS post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name
✅ Explanation:
This query retrieves all users along with a count of the posts they have made. It uses a LEFT JOIN to include users even if they have not posted anything, ensuring that users with a post count of zero are still displayed.
✅ Cost:
Usage(requests=1, request_tokens=630, response_tokens=103, total_tokens=733, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
? Results:
(1, 'John Doe', 2)
(2, 'Jane Smith', 1)
==================================================
Processing query: find users who have published posts
Processing query result: RunResult(_all_messages=[ModelRequest(parts=[SystemPromptPart(content='You are a proficient Database Administrator  having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.\nYou must respond with a Success object containing a sql_query and an explanation.\n\nDatabase schema:\n{DB_SCHEMA}\n\nFormat your response exactly like this, with no additional text or formatting:\n{{\n    "type": "Success",\n    "sql_query": "<your SQL query here>",\n    "explanation": "<your explanation here>"\n}}\n\nExamples:\n    User: show me all users who have published posts\n    {{\n        "type": "Success",\n        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",\n        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."\n    }}\n\n    User: count posts by user\n    {{\n        "type": "Success",\n        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",\n        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."\n    }}\n\n    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.\n    Common fixes include:\n    - Correcting column names\n    - Fixing JOIN conditions\n    - Adjusting GROUP BY clauses\n    - Handling NULL values properly\n\nIf you cannot generate a valid query, respond with:\n{{\n    "type": "InvalidRequest",\n    "error_message": "<explanation of why the request cannot be processed>"\n}}\n\nImportant:\n1. Respond with ONLY the JSON object, no additional text\n2. Always include the "type" field as either "Success" or "InvalidRequest"\n3. All queries must be SELECT statements\n4. Provide clear explanations\n5. 
Use proper JOIN conditions and WHERE clauses as needed\n', part_kind='system-prompt'), UserPromptPart(content='find users who have published posts', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 31, 26011, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'), ModelResponse(parts=[ToolCallPart(tool_name='final_result_Success', args=ArgsJson(args_json='{"sql_query":"SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE","explanation":"This query identifies all users who have at least one post marked as published by joining the users and posts tables on their respective IDs."}'), tool_call_id='call_1ecYqTquJhhYEiBSRELqQAuT', part_kind='tool-call')], timestamp=datetime.datetime(2024, 12, 21, 16, 25, 31, tzinfo=datetime.timezone.utc), kind='response'), ModelRequest(parts=[ToolReturnPart(tool_name='final_result_Success', content='Final result processed.', tool_call_id='call_1ecYqTquJhhYEiBSRELqQAuT', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 32, 229726, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')], _new_message_index=0, data=Success(type='Success', sql_query='SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE', explanation='This query identifies all users who have at least one post marked as published by joining the users and posts tables on their respective IDs.'), _usage=Usage(requests=1, request_tokens=626, response_tokens=68, total_tokens=694, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
✅ Generated SQL:
SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE
✅ Explanation:
This query identifies all users who have at least one post marked as published by joining the users and posts tables on their respective IDs.
✅ Cost:
Usage(requests=1, request_tokens=626, response_tokens=68, total_tokens=694, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
? Results:
(1, 'John Doe', 'john@example.com', '2024-12-21 14:05:32')
(2, 'Jane Smith', 'jane@example.com', '2024-12-21 14:05:32')
==================================================
Processing query: show me all draft posts with their authors
Processing query result: RunResult(_all_messages=[ModelRequest(parts=[SystemPromptPart(content='You are a proficient Database Administrator  having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.\nYou must respond with a Success object containing a sql_query and an explanation.\n\nDatabase schema:\n{DB_SCHEMA}\n\nFormat your response exactly like this, with no additional text or formatting:\n{{\n    "type": "Success",\n    "sql_query": "<your SQL query here>",\n    "explanation": "<your explanation here>"\n}}\n\nExamples:\n    User: show me all users who have published posts\n    {{\n        "type": "Success",\n        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",\n        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."\n    }}\n\n    User: count posts by user\n    {{\n        "type": "Success",\n        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",\n        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."\n    }}\n\n    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.\n    Common fixes include:\n    - Correcting column names\n    - Fixing JOIN conditions\n    - Adjusting GROUP BY clauses\n    - Handling NULL values properly\n\nIf you cannot generate a valid query, respond with:\n{{\n    "type": "InvalidRequest",\n    "error_message": "<explanation of why the request cannot be processed>"\n}}\n\nImportant:\n1. Respond with ONLY the JSON object, no additional text\n2. Always include the "type" field as either "Success" or "InvalidRequest"\n3. All queries must be SELECT statements\n4. Provide clear explanations\n5. 
Use proper JOIN conditions and WHERE clauses as needed\n', part_kind='system-prompt'), UserPromptPart(content='\nPrevious query failed with error: no such column: posts.status\nPlease generate a corrected SQL query for the original request: show me all draft posts with their authors\n', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 33, 648254, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'), ModelResponse(parts=[ToolCallPart(tool_name='final_result_Success', args=ArgsJson(args_json='{"sql_query": "SELECT posts.*, users.name as author_name FROM posts JOIN users ON posts.user_id = users.id WHERE posts.published = FALSE", "explanation": "This query selects all draft posts (those not published) along with their authors by joining the posts and users tables based on the user ID."}'), tool_call_id='call_PFcdShmLLrI4ATOqh4MplMAt', part_kind='tool-call')], timestamp=datetime.datetime(2024, 12, 21, 16, 25, 33, tzinfo=datetime.timezone.utc), kind='response'), ModelRequest(parts=[ToolReturnPart(tool_name='final_result_Success', content='Final result processed.', tool_call_id='call_PFcdShmLLrI4ATOqh4MplMAt', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 35, 126566, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')], _new_message_index=0, data=Success(type='Success', sql_query='SELECT posts.*, users.name as author_name FROM posts JOIN users ON posts.user_id = users.id WHERE posts.published = FALSE', explanation='This query selects all draft posts (those not published) along with their authors by joining the posts and users tables based on the user ID.'), _usage=Usage(requests=1, request_tokens=654, response_tokens=90, total_tokens=744, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
✅ Generated SQL:
SELECT posts.*, users.name as author_name FROM posts JOIN users ON posts.user_id = users.id WHERE posts.published = FALSE
✅ Explanation:
This query selects all draft posts (those not published) along with their authors by joining the posts and users tables based on the user ID.
✅ Cost:
Usage(requests=1, request_tokens=654, response_tokens=90, total_tokens=744, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
? Results:
(2, 1, 'Draft Post', 'Work in Progress', 0, '2024-12-21 14:05:32', 'John Doe')
==================================================
Processing query: what is the count of users table
Processing query result: RunResult(_all_messages=[ModelRequest(parts=[SystemPromptPart(content='You are a proficient Database Administrator  having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.\nYou must respond with a Success object containing a sql_query and an explanation.\n\nDatabase schema:\n{DB_SCHEMA}\n\nFormat your response exactly like this, with no additional text or formatting:\n{{\n    "type": "Success",\n    "sql_query": "<your SQL query here>",\n    "explanation": "<your explanation here>"\n}}\n\nExamples:\n    User: show me all users who have published posts\n    {{\n        "type": "Success",\n        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",\n        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."\n    }}\n\n    User: count posts by user\n    {{\n        "type": "Success",\n        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",\n        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."\n    }}\n\n    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.\n    Common fixes include:\n    - Correcting column names\n    - Fixing JOIN conditions\n    - Adjusting GROUP BY clauses\n    - Handling NULL values properly\n\nIf you cannot generate a valid query, respond with:\n{{\n    "type": "InvalidRequest",\n    "error_message": "<explanation of why the request cannot be processed>"\n}}\n\nImportant:\n1. Respond with ONLY the JSON object, no additional text\n2. Always include the "type" field as either "Success" or "InvalidRequest"\n3. All queries must be SELECT statements\n4. Provide clear explanations\n5. 
Use proper JOIN conditions and WHERE clauses as needed\n', part_kind='system-prompt'), UserPromptPart(content='what is the count of users table', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 35, 135611, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'), ModelResponse(parts=[ToolCallPart(tool_name='final_result_Success', args=ArgsJson(args_json='{"sql_query": "SELECT COUNT(*) as user_count FROM users", "explanation": "This query counts the total number of rows in the users table, providing the total number of users."}'), tool_call_id='call_vvBQ3QO4UwLHvpEpn0Gam5zt', part_kind='tool-call')], timestamp=datetime.datetime(2024, 12, 21, 16, 25, 35, tzinfo=datetime.timezone.utc), kind='response'), ModelRequest(parts=[ToolReturnPart(tool_name='final_result_Success', content='Final result processed.', tool_call_id='call_vvBQ3QO4UwLHvpEpn0Gam5zt', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 36, 421186, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')], _new_message_index=0, data=Success(type='Success', sql_query='SELECT COUNT(*) as user_count FROM users', explanation='This query counts the total number of rows in the users table, providing the total number of users.'), _usage=Usage(requests=1, request_tokens=627, response_tokens=65, total_tokens=692, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
✅ Generated SQL:
SELECT COUNT(*) as user_count FROM users
✅ Explanation:
This query counts the total number of rows in the users table, providing the total number of users.
✅ Cost:
Usage(requests=1, request_tokens=627, response_tokens=65, total_tokens=692, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
? Results:
(2,)
==================================================
Processing query: show me the title of the posts published
Processing query result: RunResult(_all_messages=[ModelRequest(parts=[SystemPromptPart(content='You are a proficient Database Administrator  having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.\nYou must respond with a Success object containing a sql_query and an explanation.\n\nDatabase schema:\n{DB_SCHEMA}\n\nFormat your response exactly like this, with no additional text or formatting:\n{{\n    "type": "Success",\n    "sql_query": "<your SQL query here>",\n    "explanation": "<your explanation here>"\n}}\n\nExamples:\n    User: show me all users who have published posts\n    {{\n        "type": "Success",\n        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",\n        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."\n    }}\n\n    User: count posts by user\n    {{\n        "type": "Success",\n        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",\n        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."\n    }}\n\n    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.\n    Common fixes include:\n    - Correcting column names\n    - Fixing JOIN conditions\n    - Adjusting GROUP BY clauses\n    - Handling NULL values properly\n\nIf you cannot generate a valid query, respond with:\n{{\n    "type": "InvalidRequest",\n    "error_message": "<explanation of why the request cannot be processed>"\n}}\n\nImportant:\n1. Respond with ONLY the JSON object, no additional text\n2. Always include the "type" field as either "Success" or "InvalidRequest"\n3. All queries must be SELECT statements\n4. Provide clear explanations\n5. 
Use proper JOIN conditions and WHERE clauses as needed\n', part_kind='system-prompt'), UserPromptPart(content='show me the title of the posts published', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 36, 428434, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'), ModelResponse(parts=[ToolCallPart(tool_name='final_result_Success', args=ArgsJson(args_json='{"sql_query":"SELECT title FROM posts WHERE published = TRUE","explanation":"This query retrieves the titles of all posts that have been marked as published."}'), tool_call_id='call_wkb9qISadI545XilG1RXjr9f', part_kind='tool-call')], timestamp=datetime.datetime(2024, 12, 21, 16, 25, 36, tzinfo=datetime.timezone.utc), kind='response'), ModelRequest(parts=[ToolReturnPart(tool_name='final_result_Success', content='Final result processed.', tool_call_id='call_wkb9qISadI545XilG1RXjr9f', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 37, 566200, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')], _new_message_index=0, data=Success(type='Success', sql_query='SELECT title FROM posts WHERE published = TRUE', explanation='This query retrieves the titles of all posts that have been marked as published.'), _usage=Usage(requests=1, request_tokens=628, response_tokens=44, total_tokens=672, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
✅ Generated SQL:
SELECT title FROM posts WHERE published = TRUE
✅ Explanation:
This query retrieves the titles of all posts that have been marked as published.
✅ Cost:
Usage(requests=1, request_tokens=628, response_tokens=44, total_tokens=672, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
Results:
('First Post',)
("Jane's Post",)
==================================================
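The RunResult dump above also records the system prompt exactly as it was sent, and it still contains the literal placeholder {DB_SCHEMA} along with doubled braces ({{ and }}). Doubled braces are only escapes inside an f-string or a str.format() template, and the agent configuration passes a plain triple-quoted string, so it appears the template was never interpolated and the model never received the actual schema (the queries presumably still worked because the few-shot examples name the users and posts columns). A minimal sketch of the fix, assuming the prompt is kept as a template: call .format() once before building the agent. The schema and the shortened template below are illustrative, not the article's exact text.

```python
# Illustrative schema only; the article's real DB_SCHEMA is defined elsewhere.
DB_SCHEMA = """
CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER, title TEXT, published BOOLEAN);
"""

# Shortened template: {DB_SCHEMA} is a placeholder, while {{ and }} are
# escaped braces that .format() turns into literal JSON braces.
PROMPT_TEMPLATE = """You are a proficient Database Administrator having expertise in generating SQL queries.

Database schema:
{DB_SCHEMA}

Format your response exactly like this:
{{
    "type": "Success",
    "sql_query": "<your SQL query here>",
    "explanation": "<your explanation here>"
}}
"""

# Interpolate once, before the prompt is handed to the Agent.
SYSTEM_PROMPT = PROMPT_TEMPLATE.format(DB_SCHEMA=DB_SCHEMA.strip())

print(SYSTEM_PROMPT)
```

SYSTEM_PROMPT would then be passed as system_prompt= when constructing the agent. An alternative that stays inside Pydantic AI is a dynamic system prompt: a function decorated with @sqlagent.system_prompt that takes a RunContext[Deps] and returns ctx.deps.db_schema, which reuses the db_schema field already on the Deps dataclass.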
Processing query: show me the structure of the posts
Processing query result: type='InvalidRequest' error_message="Failed to generate valid query after 3 attempts. Last error: 'InvalidRequest' object has no attribute 'sql_query'"
Error: Failed to generate valid query after 3 attempts. Last error: 'InvalidRequest' object has no attribute 'sql_query'
==================================================
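The second run shows the retry loop itself failing. The model answered "show me the structure of the posts" with an InvalidRequest, but, judging by the error text, the calling code then read .sql_query without checking the result type, so the loop kept failing with an AttributeError rather than a database error until the three attempts ran out. The sketch below is a minimal self-healing loop that branches on the result type before touching sql_query and feeds real execution errors back into the next attempt. It uses plain dataclasses in place of the article's Pydantic models, the synchronous stdlib sqlite3 driver instead of aiosqlite, and a hypothetical scripted generator (fake_generate) in place of a real sqlagent.run call; the SELECT-only check mirrors rule 3 of the system prompt.

```python
import sqlite3
from dataclasses import dataclass
from typing import Callable, Union

# Plain dataclasses standing in for the article's Pydantic Success / InvalidRequest models.
@dataclass
class Success:
    sql_query: str
    explanation: str

@dataclass
class InvalidRequest:
    error_message: str

Response = Union[Success, InvalidRequest]

def self_healing_query(
    generate: Callable[[str], Response],
    conn: sqlite3.Connection,
    request: str,
    max_attempts: int = 3,
) -> tuple[Response, list[tuple]]:
    """Generate SQL, execute it, and retry with the error text if execution fails."""
    prompt = request
    last_error = ""
    for _ in range(max_attempts):
        result = generate(prompt)
        # Branch on the type first: an InvalidRequest has no sql_query attribute,
        # and a deliberate refusal is returned as-is instead of being retried.
        if isinstance(result, InvalidRequest):
            return result, []
        try:
            # Minimal guard mirroring the prompt's "SELECT statements only" rule.
            if not result.sql_query.lstrip().upper().startswith("SELECT"):
                raise ValueError("only SELECT statements are allowed")
            return result, conn.execute(result.sql_query).fetchall()
        except (sqlite3.Error, ValueError) as exc:
            last_error = str(exc)
            # Feed the failing query and its error back so the next attempt can fix it.
            prompt = (f"{request}\n\nYour previous query:\n{result.sql_query}\n"
                      f"failed with: {last_error}\nPlease return a corrected query.")
    return InvalidRequest(
        f"Failed to generate valid query after {max_attempts} attempts. "
        f"Last error: {last_error}"
    ), []

# --- Demo with a hypothetical scripted "model" instead of a real LLM call ---
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO users (name) VALUES (?)", [("John Doe",), ("Jane Smith",)])

scripted = iter([
    Success("SELECT username FROM users", "first try uses a wrong column"),
    Success("SELECT name FROM users", "corrected after seeing the error"),
])
prompts_seen: list[str] = []

def fake_generate(prompt: str) -> Response:
    prompts_seen.append(prompt)
    return next(scripted)

result, rows = self_healing_query(fake_generate, conn, "show me the names of all the users")
print(result.explanation, rows)
```

Returning an InvalidRequest immediately is a design choice: retrying a deliberate refusal mostly spends tokens. The prefix check is intentionally simple; opening the database read-only (sqlite3.connect("file:app.db?mode=ro", uri=True), with a hypothetical file name) is a stricter safeguard. Pydantic AI also offers a built-in route for this feedback loop: raising ModelRetry inside a result validator (named output_validator in later releases) sends the message back to the model, subject to the agent's retry limit.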
Processing query: show me the names of all the users
Processing query result: RunResult(_all_messages=[ModelRequest(parts=[SystemPromptPart(content='You are a proficient Database Administrator  having expertise in generating SQL queries. Your task is to convert natural language requests into SQL queries for a SQLite database.\nYou must respond with a Success object containing a sql_query and an explanation.\n\nDatabase schema:\n{DB_SCHEMA}\n\nFormat your response exactly like this, with no additional text or formatting:\n{{\n    "type": "Success",\n    "sql_query": "<your SQL query here>",\n    "explanation": "<your explanation here>"\n}}\n\nExamples:\n    User: show me all users who have published posts\n    {{\n        "type": "Success",\n        "sql_query": "SELECT DISTINCT users.* FROM users JOIN posts ON users.id = posts.user_id WHERE posts.published = TRUE",\n        "explanation": "This query finds all users who have at least one published post by joining the users and posts tables."\n    }}\n\n    User: count posts by user\n    {{\n        "type": "Success",\n        "sql_query": "SELECT users.name, COUNT(posts.id) as post_count FROM users LEFT JOIN posts ON users.id = posts.user_id GROUP BY users.id, users.name",\n        "explanation": "This query counts the number of posts for each user, including users with no posts using LEFT JOIN."\n    }}\n\n    If you receive an error message about a previous query, analyze the error and fix the issues in your new query.\n    Common fixes include:\n    - Correcting column names\n    - Fixing JOIN conditions\n    - Adjusting GROUP BY clauses\n    - Handling NULL values properly\n\nIf you cannot generate a valid query, respond with:\n{{\n    "type": "InvalidRequest",\n    "error_message": "<explanation of why the request cannot be processed>"\n}}\n\nImportant:\n1. Respond with ONLY the JSON object, no additional text\n2. Always include the "type" field as either "Success" or "InvalidRequest"\n3. All queries must be SELECT statements\n4. Provide clear explanations\n5. Use proper JOIN conditions and WHERE clauses as needed\n', part_kind='system-prompt'), UserPromptPart(content='show me the names of all the users', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 40, 312604, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request'), ModelResponse(parts=[ToolCallPart(tool_name='final_result_Success', args=ArgsJson(args_json='{"sql_query":"SELECT name FROM users","explanation":"This query retrieves the names of all users from the users table."}'), tool_call_id='call_c06zk9SA5TyoqBG0sjVEhpdJ', part_kind='tool-call')], timestamp=datetime.datetime(2024, 12, 21, 16, 25, 40, tzinfo=datetime.timezone.utc), kind='response'), ModelRequest(parts=[ToolReturnPart(tool_name='final_result_Success', content='Final result processed.', tool_call_id='call_c06zk9SA5TyoqBG0sjVEhpdJ', timestamp=datetime.datetime(2024, 12, 21, 16, 25, 41, 194161, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')], _new_message_index=0, data=Success(type='Success', sql_query='SELECT name FROM users', explanation='This query retrieves the names of all users from the users table.'), _usage=Usage(requests=1, request_tokens=628, response_tokens=38, total_tokens=666, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
✅ Generated SQL:
SELECT name FROM users
✅ Explanation:
This query retrieves the names of all users from the users table.
✅ Cost:
Usage(requests=1, request_tokens=628, response_tokens=38, total_tokens=666, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
Results:
('John Doe',)
('Jane Smith',)
==================================================
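The "Cost" lines in the log are really token counts: the Usage record reports requests and tokens, not money. Both successful runs took a single request, with 628 request tokens each (the request side carries the long system prompt) and 44 and 38 response tokens respectively. To track spend across a session, these records can be summed and then priced. The sketch below uses a hypothetical UsageRecord dataclass that mirrors the fields printed in the log, and the per-million-token prices are placeholders rather than Groq's actual rates.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class UsageRecord:
    """Hypothetical mirror of the token fields printed in the log's Usage(...)."""
    requests: int
    request_tokens: int
    response_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.request_tokens + self.response_tokens

    def __add__(self, other: "UsageRecord") -> "UsageRecord":
        # Field-wise sum so that sum() can aggregate a whole session.
        return UsageRecord(
            self.requests + other.requests,
            self.request_tokens + other.request_tokens,
            self.response_tokens + other.response_tokens,
        )

def estimate_cost(u: UsageRecord, input_per_m: float, output_per_m: float) -> float:
    """Price tokens at per-million-token rates (placeholder values, not real pricing)."""
    return (u.request_tokens * input_per_m + u.response_tokens * output_per_m) / 1_000_000

# Token counts copied from the two successful runs in the log above.
runs = [UsageRecord(1, 628, 44), UsageRecord(1, 628, 38)]
session = sum(runs, UsageRecord(0, 0, 0))
print(session, session.total_tokens)
print(f"${estimate_cost(session, input_per_m=0.50, output_per_m=1.50):.6f}")
```

Keeping prices out of the record itself means the same token totals can be re-priced whenever the provider's rates change.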


Conclusion

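The SQL agent shows how modern AI techniques can make database interaction more accessible, while its self-healing mechanism keeps it robust. Despite its limitations, the system provides a solid foundation for querying databases in natural language.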


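This article has outlined the SQL agent's implementation, its architecture, and possible directions for improvement. Feel free to adapt and extend this structure to fit your own needs and audience.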

Source: https://medium.com/the-ai-forum/building-a-self-healing-sql-query-generator-agent-with-pydantic-ai-and-groq-7045910265c0