WIP Text to SQL - checkpoint

This commit is contained in:
Mahesh Murag
2024-09-24 00:04:36 +02:00
parent f9a07c5329
commit ce51e8c078
3 changed files with 36 additions and 18 deletions

View File

@@ -16,7 +16,6 @@ prompts:
tests:
- description: "Check syntax of simple query"
prompt: prompts.py:basic_text_to_sql
vars:
user_query: "What are the names of all employees in the Engineering department?"
assert:
@@ -45,7 +44,6 @@ tests:
"reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}"
}
- description: "Validate count of employees in Engineering department"
prompt: prompts.py:generate_prompt
vars:
user_query: "How many employees are in the Engineering department?"
assert:
@@ -90,7 +88,6 @@ tests:
f"Returned count: {count}, Expected count: {expected_count}."
}
- description: "Check specific employee details in Engineering department"
prompt: prompts.py:generate_prompt
vars:
user_query: "Give me the name, age, and salary of the oldest employee in the Engineering department."
assert:

View File

@@ -1,7 +1,9 @@
import sqlite3
def get_schema_info(db_path):
conn = sqlite3.connect(db_path)
DATABASE_PATH = '../data/data.db'
def get_schema_info():
conn = sqlite3.connect(DATABASE_PATH)
cursor = conn.cursor()
schema_info = []
@@ -22,8 +24,8 @@ def get_schema_info(db_path):
conn.close()
return "\n\n".join(schema_info)
def generate_prompt(user_query, db_path='../data/data.db'):
schema = get_schema_info(db_path)
def generate_prompt(user_query):
schema = get_schema_info()
return f"""
You are an AI assistant that converts natural language queries into SQL.
Given the following SQL database schema:
@@ -37,7 +39,7 @@ def generate_prompt(user_query, db_path='../data/data.db'):
Provide only the SQL query in your response, enclosed within <sql> tags.
"""
def generate_prompt_with_examples(user_query, db_path='../data/data.db'):
def generate_prompt_with_examples(user_query):
examples = """
Example 1:
<query>List all employees in the HR department.</<query>
@@ -52,7 +54,7 @@ def generate_prompt_with_examples(user_query, db_path='../data/data.db'):
SQL: SELECT name, age FROM employees ORDER BY age DESC LIMIT 1;
"""
schema = get_schema_info(db_path)
schema = get_schema_info()
return f"""
You are an AI assistant that converts natural language queries into SQL.
@@ -76,8 +78,8 @@ def generate_prompt_with_examples(user_query, db_path='../data/data.db'):
Provide only the SQL query in your response, enclosed within <sql> tags.
"""
def generate_prompt_with_cot(user_query, db_path='../data/data.db'):
schema = get_schema_info(db_path)
def generate_prompt_with_cot(user_query):
schema = get_schema_info()
examples = """
<example>
<query>List all employees in the HR department.</query>
@@ -89,25 +91,42 @@ def generate_prompt_with_cot(user_query, db_path='../data/data.db'):
</thought_process>
<sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>
</example>
<example>
<query>What is the average salary of employees hired in 2022?</query>
<thought_process>
1. We need to work with the employees table.
2. We need to filter for employees hired in 2022.
3. We'll use the YEAR function to extract the year from the hire_date.
4. We'll calculate the average of the salary column for the filtered rows.
</thought_process>
<sql>SELECT AVG(salary) FROM employees WHERE YEAR(hire_date) = 2022;</sql>
</example>
"""
return f"""
You are an AI assistant that converts natural language queries into SQL.
return f"""You are an AI assistant that converts natural language queries into SQL.
Given the following SQL database schema:
<schema>
{schema}
</schema>
Here are some examples of natural language queries, thought processes, and corresponding SQL queries:
Here are some examples of natural language queries, thought processes, and their corresponding SQL:
<examples>
{examples}
</examples>
Now, convert the following natural language query into SQL:
<query>
{user_query}
</query>
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
Then, provide the SQL query within <sql> tags.
Within <thought_process> tags, explain your thought process for creating the SQL query.
Then, within <sql> tags, provide your output SQL query.
"""
def generate_prompt_with_rag(user_query, db_path='../data/data.db'):
def generate_prompt_with_rag(user_query):
from vectordb import VectorDB
# Load the vector database
@@ -115,7 +134,7 @@ def generate_prompt_with_rag(user_query, db_path='../data/data.db'):
vectordb.load_db()
if not vectordb.embeddings:
with sqlite3.connect(db_path) as conn:
with sqlite3.connect() as conn:
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
schema_data = [

View File

@@ -900,7 +900,9 @@
" </examples>\n",
"\n",
" Now, convert the following natural language query into SQL:\n",
" <query>\n",
" {query}\n",
" </query>\n",
"\n",
" Within <thought_process> tags, explain your thought process for creating the SQL query.\n",
" Then, within <sql> tags, provide your output SQL query.\n",