WIP Text to SQL - checkpoint

This commit is contained in:
Mahesh Murag
2024-09-24 00:39:48 +02:00
parent ce51e8c078
commit 3be7f8fc9b
6 changed files with 858 additions and 123 deletions

Binary file not shown.

View File

@@ -0,0 +1,185 @@
user_query,[3.5 Sonnet] prompts.py:generate_prompt,[3.5 Sonnet] prompts.py:generate_prompt_with_examples,[3.5 Sonnet] prompts.py:generate_prompt_with_cot,[3.5 Sonnet] prompts.py:generate_prompt_with_rag
What are the names of all employees in the Engineering department?,"[PASS] (1.00)
<sql>
SELECT employees.name
FROM employees
JOIN departments ON employees.department_id = departments.id
WHERE departments.name = 'Engineering';
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<sql>
SELECT e.name
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering';
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<thought_process>
1. We need to join the employees and departments tables to get the department information for each employee.
2. We'll match employees.department_id with departments.id for the join.
3. We need to filter for employees in the Engineering department.
4. We only need to return the employee names.
5. We'll use a JOIN operation to combine the tables and a WHERE clause to filter for the Engineering department.
</thought_process>
<sql>
SELECT e.name
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering';
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<thought_process>
To answer this query, we need to:
1. Join the employees and departments tables since the department information is in a separate table.
2. Filter for employees in the Engineering department.
3. Select only the names of these employees.
The steps to create the SQL query:
1. Start with SELECT to choose the employee names.
2. FROM the employees table.
3. JOIN with the departments table to get department information.
4. Use WHERE to filter for the Engineering department.
5. The join condition will be on the department_id in employees and id in departments.
</thought_process>
<sql>
SELECT employees.name
FROM employees
JOIN departments ON employees.department_id = departments.id
WHERE departments.name = 'Engineering';
</sql>
Pass Reason: All assertions passed"
How many employees are in the Engineering department?,"[PASS] (1.00)
<sql>
SELECT COUNT(*)
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering';
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<sql>
SELECT COUNT(*) FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering';
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<thought_process>
1. We need to count employees in the Engineering department.
2. This requires joining the employees and departments tables.
3. We'll match employees.department_id with departments.id.
4. We need to filter for the Engineering department.
5. We'll use COUNT(*) to get the number of employees.
6. Since we're counting all employees in the department, we don't need to select any specific columns from the employees table.
</thought_process>
<sql>
SELECT COUNT(*)
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering';
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<thought_process>
To answer this query, we need to:
1. Join the employees and departments tables since the department name is in the departments table, but we need to count employees.
2. Filter for the Engineering department.
3. Count the number of employees in that department.
We'll use:
- A JOIN between employees and departments tables on department_id
- A WHERE clause to filter for the Engineering department
- COUNT(*) to count the number of employees
</thought_process>
<sql>
SELECT COUNT(*) AS employee_count
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering';
</sql>
Pass Reason: All assertions passed"
"Give me the name, age, and salary of the oldest employee in the Engineering department.","[PASS] (1.00)
<sql>
SELECT e.name, e.age, e.salary
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering'
ORDER BY e.age DESC
LIMIT 1
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<sql>
SELECT e.name, e.age, e.salary
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering'
ORDER BY e.age DESC
LIMIT 1;
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<thought_process>
1. We need to join the employees and departments tables to filter for the Engineering department.
2. We'll match employees.department_id with departments.id.
3. We need to filter for the Engineering department.
4. To find the oldest employee, we'll order by age in descending order and limit to 1 result.
5. We need to select the name, age, and salary columns from the employees table.
</thought_process>
<sql>
SELECT e.name, e.age, e.salary
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering'
ORDER BY e.age DESC
LIMIT 1;
</sql>
Pass Reason: All assertions passed","[PASS] (1.00)
<thought_process>
To solve this query, I'll follow these steps:
1. We need to join the employees and departments tables to get the department information.
2. We need to filter for employees in the Engineering department.
3. We need to find the oldest employee, which means the maximum age.
4. We need to select the name, age, and salary of this employee.
5. We'll use a subquery to find the maximum age in the Engineering department.
6. Then we'll use this subquery in the main query to get the desired information.
</thought_process>
<sql>
SELECT e.name, e.age, e.salary
FROM employees e
JOIN departments d ON e.department_id = d.id
WHERE d.name = 'Engineering'
AND e.age = (
SELECT MAX(e2.age)
FROM employees e2
JOIN departments d2 ON e2.department_id = d2.id
WHERE d2.name = 'Engineering'
)
LIMIT 1;
</sql>
Pass Reason: All assertions passed"
1 user_query [3.5 Sonnet] prompts.py:generate_prompt [3.5 Sonnet] prompts.py:generate_prompt_with_examples [3.5 Sonnet] prompts.py:generate_prompt_with_cot [3.5 Sonnet] prompts.py:generate_prompt_with_rag
2 What are the names of all employees in the Engineering department? [PASS] (1.00) <sql> SELECT employees.name FROM employees JOIN departments ON employees.department_id = departments.id WHERE departments.name = 'Engineering'; </sql> Pass Reason: All assertions passed [PASS] (1.00) <sql> SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering'; </sql> Pass Reason: All assertions passed [PASS] (1.00) <thought_process> 1. We need to join the employees and departments tables to get the department information for each employee. 2. We'll match employees.department_id with departments.id for the join. 3. We need to filter for employees in the Engineering department. 4. We only need to return the employee names. 5. We'll use a JOIN operation to combine the tables and a WHERE clause to filter for the Engineering department. </thought_process> <sql> SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering'; </sql> Pass Reason: All assertions passed [PASS] (1.00) <thought_process> To answer this query, we need to: 1. Join the employees and departments tables since the department information is in a separate table. 2. Filter for employees in the Engineering department. 3. Select only the names of these employees. The steps to create the SQL query: 1. Start with SELECT to choose the employee names. 2. FROM the employees table. 3. JOIN with the departments table to get department information. 4. Use WHERE to filter for the Engineering department. 5. The join condition will be on the department_id in employees and id in departments. </thought_process> <sql> SELECT employees.name FROM employees JOIN departments ON employees.department_id = departments.id WHERE departments.name = 'Engineering'; </sql> Pass Reason: All assertions passed
3 How many employees are in the Engineering department? [PASS] (1.00) <sql> SELECT COUNT(*) FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering'; </sql> Pass Reason: All assertions passed [PASS] (1.00) <sql> SELECT COUNT(*) FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering'; </sql> Pass Reason: All assertions passed [PASS] (1.00) <thought_process> 1. We need to count employees in the Engineering department. 2. This requires joining the employees and departments tables. 3. We'll match employees.department_id with departments.id. 4. We need to filter for the Engineering department. 5. We'll use COUNT(*) to get the number of employees. 6. Since we're counting all employees in the department, we don't need to select any specific columns from the employees table. </thought_process> <sql> SELECT COUNT(*) FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering'; </sql> Pass Reason: All assertions passed [PASS] (1.00) <thought_process> To answer this query, we need to: 1. Join the employees and departments tables since the department name is in the departments table, but we need to count employees. 2. Filter for the Engineering department. 3. Count the number of employees in that department. We'll use: - A JOIN between employees and departments tables on department_id - A WHERE clause to filter for the Engineering department - COUNT(*) to count the number of employees </thought_process> <sql> SELECT COUNT(*) AS employee_count FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering'; </sql> Pass Reason: All assertions passed
4 Give me the name, age, and salary of the oldest employee in the Engineering department. [PASS] (1.00) <sql> SELECT e.name, e.age, e.salary FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering' ORDER BY e.age DESC LIMIT 1 </sql> Pass Reason: All assertions passed [PASS] (1.00) <sql> SELECT e.name, e.age, e.salary FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering' ORDER BY e.age DESC LIMIT 1; </sql> Pass Reason: All assertions passed [PASS] (1.00) <thought_process> 1. We need to join the employees and departments tables to filter for the Engineering department. 2. We'll match employees.department_id with departments.id. 3. We need to filter for the Engineering department. 4. To find the oldest employee, we'll order by age in descending order and limit to 1 result. 5. We need to select the name, age, and salary columns from the employees table. </thought_process> <sql> SELECT e.name, e.age, e.salary FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering' ORDER BY e.age DESC LIMIT 1; </sql> Pass Reason: All assertions passed [PASS] (1.00) <thought_process> To solve this query, I'll follow these steps: 1. We need to join the employees and departments tables to get the department information. 2. We need to filter for employees in the Engineering department. 3. We need to find the oldest employee, which means the maximum age. 4. We need to select the name, age, and salary of this employee. 5. We'll use a subquery to find the maximum age in the Engineering department. 6. Then we'll use this subquery in the main query to get the desired information. </thought_process> <sql> SELECT e.name, e.age, e.salary FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering' AND e.age = ( SELECT MAX(e2.age) FROM employees e2 JOIN departments d2 ON e2.department_id = d2.id WHERE d2.name = 'Engineering' ) LIMIT 1; </sql> Pass Reason: All assertions passed

View File

@@ -4,6 +4,7 @@ python_path: /opt/homebrew/bin/python3
providers:
- id: anthropic:messages:claude-3-5-sonnet-20240620
label: "3.5 Sonnet"
config:
max_tokens: 4096
temperature: 0
@@ -12,7 +13,7 @@ prompts:
- prompts.py:generate_prompt
- prompts.py:generate_prompt_with_examples
- prompts.py:generate_prompt_with_cot
# - prompts.py:generate_prompt_with_rag
- prompts.py:generate_prompt_with_rag
tests:
- description: "Check syntax of simple query"
@@ -79,7 +80,7 @@ tests:
count = 0
print(f"SQL execution error: {e}")
expected_count = 3
expected_count = 20
return {
"pass": execution_success and count == expected_count,
@@ -124,9 +125,9 @@ tests:
print(f"SQL execution error: {e}")
expected_result = {
"name": "Charlie Davis",
"age": 31,
"salary": 85000.00
"name": "Julia Clark",
"age": 64,
"salary": 103699.17
}
if row:
@@ -148,6 +149,4 @@ tests:
}
output:
- type: csv
path: ./results.csv
outputPath: ../data/results.csv

View File

@@ -24,7 +24,8 @@ def get_schema_info():
conn.close()
return "\n\n".join(schema_info)
def generate_prompt(user_query):
def generate_prompt(context):
user_query = context['vars']['user_query']
schema = get_schema_info()
return f"""
You are an AI assistant that converts natural language queries into SQL.
@@ -39,7 +40,8 @@ def generate_prompt(user_query):
Provide only the SQL query in your response, enclosed within <sql> tags.
"""
def generate_prompt_with_examples(user_query):
def generate_prompt_with_examples(context):
user_query = context['vars']['user_query']
examples = """
Example 1:
<query>List all employees in the HR department.</<query>
@@ -78,7 +80,8 @@ def generate_prompt_with_examples(user_query):
Provide only the SQL query in your response, enclosed within <sql> tags.
"""
def generate_prompt_with_cot(user_query):
def generate_prompt_with_cot(context):
user_query = context['vars']['user_query']
schema = get_schema_info()
examples = """
<example>
@@ -126,15 +129,17 @@ def generate_prompt_with_cot(user_query):
Then, within <sql> tags, provide your output SQL query.
"""
def generate_prompt_with_rag(user_query):
def generate_prompt_with_rag(context):
from vectordb import VectorDB
# Load the vector database
vectordb = VectorDB()
vectordb.load_db()
user_query = context['vars']['user_query']
if not vectordb.embeddings:
with sqlite3.connect() as conn:
with sqlite3.connect(DATABASE_PATH) as conn:
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
schema_data = [
@@ -144,7 +149,7 @@ def generate_prompt_with_rag(user_query):
for col in cursor.execute(f"PRAGMA table_info({table[0]})").fetchall()
]
vectordb.load_data(schema_data)
relevant_schema = vectordb.search(user_query, k=10, similarity_threshold=0.3)
schema_info = "\n".join([f"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}"
for item in relevant_schema])

File diff suppressed because one or more lines are too long