WIP Text to SQL - checkpoint

This commit is contained in:
Mahesh Murag
2024-09-25 01:15:33 +02:00
parent 7b954b11d9
commit 5031fb36d4
4 changed files with 264 additions and 62 deletions

View File

@@ -1,5 +1,3 @@
# promptfooconfig.yaml
python_path: /opt/homebrew/bin/python3
providers:
@@ -151,4 +149,175 @@ tests:
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Data {'matches' if data_match else 'does not match'} expected result. "
f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}"
}
}
- description: "Find the average salary of employees in departments located in 'New York', but only for departments with more than 5 employees"
vars:
user_query: "What's the average salary for employees in New York-based departments that have more than 5 staff members?"
assert:
- type: python
value: |
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
# Check if results make sense (e.g., not empty, reasonable avg salary)
result_valid = len(results) > 0 and 40000 < results[0][0] < 200000
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
}
- description: "Find employees who earn more than their department's average salary, along with the percentage difference"
vars:
user_query: "Which employees earn above their department's average salary, and by what percentage?"
assert:
- type: python
value: |
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
# Check if results make sense (e.g., not empty, percentage above 0)
result_valid = len(results) > 0 and all(row[2] > 0 for row in results)
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
}
- description: "Complex hierarchical query with multiple aggregations: For each department, show the highest paid employee and how their salary compares to the average of the next 3 highest paid employees in that department"
vars:
user_query: "For each department, show the name of the highest paid employee, their salary, and the percentage difference between their salary and the average salary of the next 3 highest paid employees in that department"
assert:
- type: contains
value: "PARTITION BY"
- type: contains
value: "AVG"
- type: python
value: |
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and len(results[0]) == 4 # department, employee name, salary, percentage difference
if result_valid:
for row in results:
if not (isinstance(row[2], (int, float)) and isinstance(row[3], (int, float))):
result_valid = False
break
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid results obtained' if result_valid else 'Invalid or no results'}"
}
- description: "Department budget allocation analysis"
vars:
user_query: "Analyze the budget allocation across departments. Calculate the percentage of total salary budget each department consumes. Then, for each department, show the top 3 highest-paid employees and what percentage of the department's budget their salaries represent. Finally, calculate a 'budget efficiency' score for each department, defined as the department's percentage of total employees divided by its percentage of total salary budget."
assert:
- type: contains
value: "WITH"
- type: contains
value: "ROW_NUMBER()"
- type: contains
value: "PARTITION BY"
- type: python
value: |
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and len(results[0]) >= 5 # department, budget %, top employees, their salary %, efficiency score
if result_valid:
for row in results:
if not (isinstance(row[1], float) and 0 <= row[1] <= 100 and
isinstance(row[-1], float)):
result_valid = False
break
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid budget analysis results obtained' if result_valid else 'Invalid or incomplete analysis results'}"
}

View File

@@ -153,18 +153,58 @@ def generate_prompt_with_rag(context):
relevant_schema = vectordb.search(user_query, k=10, similarity_threshold=0.3)
schema_info = "\n".join([f"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}"
for item in relevant_schema])
return f"""
You are an AI assistant that converts natural language queries into SQL.
examples = """
<example>
<query>List all employees in the HR department.</query>
<thought_process>
1. We need to join the employees and departments tables.
2. We'll match employees.department_id with departments.id.
3. We'll filter for the HR department.
4. We only need to return the employee names.
</thought_process>
<sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>
</example>
<example>
<query>What is the average salary of employees hired in 2022?</query>
<thought_process>
1. We need to work with the employees table.
2. We need to filter for employees hired in 2022.
3. We'll use the YEAR function to extract the year from the hire_date.
4. We'll calculate the average of the salary column for the filtered rows.
</thought_process>
<sql>SELECT AVG(salary) FROM employees WHERE YEAR(hire_date) = 2022;</sql>
</example>
"""
return f"""You are an AI assistant that converts natural language queries into SQL.
Given the following relevant columns from the SQL database schema:
<schema>
{schema_info}
</schema>
Convert the following natural language query into SQL:
Here are some examples of natural language queries, thought processes, and their corresponding SQL:
<examples>
{examples}
</examples>
Now, convert the following natural language query into SQL:
<query>
{user_query}
</query>
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
Then, provide the SQL query within <sql> tags.
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query. Consider the following steps:
1. Identify the relevant tables and columns from the provided schema.
2. Determine any necessary joins between tables.
3. Identify any filtering conditions.
4. Decide on the appropriate aggregations or calculations.
5. Structure the query logically.
Ensure your SQL query is compatible with SQLite syntax.
"""
Then, within <sql> tags, provide your output SQL query.
Ensure your SQL query is compatible with SQLite syntax and uses only the tables and columns provided in the schema.
If you're unsure about a particular table or column, use the information available in the provided schema.
"""

View File

@@ -1601,7 +1601,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -1610,7 +1610,7 @@
"text": [
"Generated prompt:\n",
"You are an AI assistant that converts natural language queries into SQL.\n",
" Given the following relevant columns from the SQL database schema:\n",
" Given the following SQL database schema:\n",
"\n",
" <schema>\n",
" Table: employees, Column: salary, Type: REAL\n",
@@ -1624,7 +1624,35 @@
"Table: employees, Column: hire_date, Type: DATE\n",
" </schema>\n",
"\n",
" Convert the following natural language query into SQL:\n",
" Here are some examples of natural language queries, thought processes, and their corresponding SQL:\n",
"\n",
" <examples>\n",
" \n",
" <example>\n",
" <query>List all employees in the HR department.</query>\n",
" <thought_process>\n",
" 1. We need to join the employees and departments tables.\n",
" 2. We'll match employees.department_id with departments.id.\n",
" 3. We'll filter for the HR department.\n",
" 4. We only need to return the employee names.\n",
" </thought_process>\n",
" <sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>\n",
" </example>\n",
"\n",
" <example>\n",
" <query>What is the average salary of employees hired in 2022?</query>\n",
" <thought_process>\n",
" 1. We need to work with the employees table.\n",
" 2. We need to filter for employees hired in 2022.\n",
" 3. We'll use the YEAR function to extract the year from the hire_date.\n",
" 4. We'll calculate the average of the salary column for the filtered rows.\n",
" </thought_process>\n",
" <sql>SELECT AVG(salary) FROM employees WHERE YEAR(hire_date) = 2022;</sql>\n",
" </example>\n",
" \n",
" </examples>\n",
"\n",
" Now, convert the following natural language query into SQL:\n",
" <query>\n",
" What is the average salary of employees in each department?\n",
" </query>\n",
@@ -1635,46 +1663,25 @@
"\n",
"Generated result:\n",
"<thought_process>\n",
"To answer this query, we need to:\n",
"1. Join the employees and departments tables to get department information for each employee.\n",
"2. Group the results by department.\n",
"3. Calculate the average salary for each group.\n",
"\n",
"Here's the step-by-step thought process:\n",
"1. We'll use the employees table as our main table since it contains the salary information.\n",
"2. We need to join the departments table to get the department names.\n",
"3. The join will be on employees.department_id = departments.id\n",
"4. We'll group the results by department name (or id, but name is more informative).\n",
"5. We'll use the AVG function to calculate the average salary for each group.\n",
"6. We'll select the department name and the average salary in the SELECT clause.\n",
"1. We need to work with both the employees and departments tables to get department names and employee salaries.\n",
"2. We'll need to join these tables using the department_id from employees and id from departments.\n",
"3. We want to group the results by department, so we'll use GROUP BY.\n",
"4. We need to calculate the average salary for each department.\n",
"5. We should select the department name and the average salary for each group.\n",
"</thought_process>\n",
"\n",
"<sql>\n",
"SELECT \n",
" d.name AS department_name,\n",
" AVG(e.salary) AS average_salary\n",
"FROM \n",
" employees e\n",
"JOIN \n",
" departments d ON e.department_id = d.id\n",
"GROUP BY \n",
" d.name\n",
"ORDER BY \n",
" d.name\n",
"SELECT d.name AS department_name, AVG(e.salary) AS average_salary\n",
"FROM employees e\n",
"JOIN departments d ON e.department_id = d.id\n",
"GROUP BY d.name;\n",
"</sql>\n",
"\n",
"Extracted SQL:\n",
"SELECT \n",
" d.name AS department_name,\n",
" AVG(e.salary) AS average_salary\n",
"FROM \n",
" employees e\n",
"JOIN \n",
" departments d ON e.department_id = d.id\n",
"GROUP BY \n",
" d.name\n",
"ORDER BY \n",
" d.name\n",
"SELECT d.name AS department_name, AVG(e.salary) AS average_salary\n",
"FROM employees e\n",
"JOIN departments d ON e.department_id = d.id\n",
"GROUP BY d.name;\n",
"\n",
"Query result:\n"
]
@@ -1783,21 +1790,7 @@
" schema_info = \"\\n\".join([f\"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}\" \n",
" for item in relevant_schema])\n",
" \n",
" return f\"\"\"You are an AI assistant that converts natural language queries into SQL.\n",
" Given the following relevant columns from the SQL database schema:\n",
"\n",
" <schema>\n",
" {schema_info}\n",
" </schema>\n",
"\n",
" Convert the following natural language query into SQL:\n",
" <query>\n",
" {query}\n",
" </query>\n",
"\n",
" Within <thought_process> tags, explain your thought process for creating the SQL query.\n",
" Then, within <sql> tags, provide your output SQL query.\n",
" \"\"\"\n",
" return generate_prompt_with_cot(schema_info, query)\n",
"\n",
"# Test the RAG-based prompt\n",
"user_query = \"What is the average salary of employees in each department?\"\n",