WIP Text to SQL - checkpoint

This commit is contained in:
Mahesh Murag
2024-09-25 02:03:10 +02:00
parent ec4a65a92a
commit 609f9364fb
9 changed files with 181 additions and 242 deletions

View File

@@ -1,5 +1,3 @@
python_path: /opt/homebrew/bin/python3
providers:
- id: anthropic:messages:claude-3-haiku-20240307
label: "3 Haiku"
@@ -32,25 +30,8 @@ tests:
- type: contains
value: "</sql>"
- type: python
value: |
import re
value: file://tests/test_simple_query.py
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def check_sql(sql):
required_elements = ['select', 'from employees', 'join departments', "name = 'engineering'"]
return all(element in sql.lower() for element in required_elements)
sql = extract_sql(output)
result = check_sql(sql)
return {
"pass": result,
"score": 1 if result else 0,
"reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}"
}
- description: "Validate count of employees in Engineering department"
vars:
user_query: "How many employees are in the Engineering department?"
@@ -60,41 +41,8 @@ tests:
- type: contains
value: "</sql>"
- type: python
value: |
import re
import sqlite3
value: file://tests/test_employee_count.py
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
count = results[0][0] if results else 0
execution_success = True
except Exception as e:
execution_success = False
count = 0
print(f"SQL execution error: {e}")
expected_count = 20
return {
"pass": execution_success and count == expected_count,
"score": 1 if (execution_success and count == expected_count) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Returned count: {count}, Expected count: {expected_count}."
}
- description: "Check specific employee details in Engineering department"
vars:
user_query: "Give me the name, age, and salary of the oldest employee in the Engineering department."
@@ -104,133 +52,23 @@ tests:
- type: contains
value: "</sql>"
- type: python
value: |
import re
import sqlite3
value: file://tests/test_employee_details.py
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
row = results[0] if results else None
execution_success = True
except Exception as e:
execution_success = False
row = None
print(f"SQL execution error: {e}")
expected_result = {
"name": "Julia Clark",
"age": 64,
"salary": 103699.17
}
if row:
actual_result = {
"name": row[0],
"age": row[1],
"salary": row[2]
}
data_match = actual_result == expected_result
else:
data_match = False
return {
"pass": execution_success and data_match,
"score": 1 if (execution_success and data_match) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Data {'matches' if data_match else 'does not match'} expected result. "
f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}"
}
- description: "Find the average salary of employees in departments located in 'New York', but only for departments with more than 5 employees"
vars:
user_query: "What's the average salary for employees in New York-based departments that have more than 5 staff members?"
assert:
- type: python
value: |
import re
import sqlite3
value: file://tests/test_average_salary.py
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
# Check if results make sense (e.g., not empty, reasonable avg salary)
result_valid = len(results) > 0 and 40000 < results[0][0] < 200000
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
}
- description: "Find employees who earn more than their department's average salary, along with the percentage difference"
vars:
user_query: "Which employees earn above their department's average salary, and by what percentage?"
assert:
- type: python
value: |
import re
import sqlite3
value: file://tests/test_above_average_salary.py
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
# Check if results make sense (e.g., not empty, percentage above 0)
result_valid = len(results) > 0 and all(row[2] > 0 for row in results)
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
}
- description: "Complex hierarchical query with multiple aggregations: For each department, show the highest paid employee and how their salary compares to the average of the next 3 highest paid employees in that department"
- description: "Complex hierarchical query with multiple aggregations"
vars:
user_query: "For each department, show the name of the highest paid employee, their salary, and the percentage difference between their salary and the average salary of the next 3 highest paid employees in that department"
assert:
@@ -239,43 +77,8 @@ tests:
- type: contains
value: "AVG"
- type: python
value: |
import re
import sqlite3
value: file://tests/test_hierarchical_query.py
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and len(results[0]) == 4 # department, employee name, salary, percentage difference
if result_valid:
for row in results:
if not (isinstance(row[2], (int, float)) and isinstance(row[3], (int, float))):
result_valid = False
break
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid results obtained' if result_valid else 'Invalid or no results'}"
}
- description: "Department budget allocation analysis"
vars:
user_query: "Analyze the budget allocation across departments. Calculate the percentage of total salary budget each department consumes. Then, for each department, show the top 3 highest-paid employees and what percentage of the department's budget their salaries represent. Finally, calculate a 'budget efficiency' score for each department, defined as the department's percentage of total employees divided by its percentage of total salary budget."
@@ -287,41 +90,4 @@ tests:
- type: contains
value: "PARTITION BY"
- type: python
value: |
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and len(results[0]) >= 5 # department, budget %, top employees, their salary %, efficiency score
if result_valid:
for row in results:
if not (isinstance(row[1], float) and 0 <= row[1] <= 100 and
isinstance(row[-1], float)):
result_valid = False
break
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid budget analysis results obtained' if result_valid else 'Invalid or incomplete analysis results'}"
}
value: file://tests/test_budget_allocation.py

View File

@@ -0,0 +1,19 @@
from utils import extract_sql, execute_sql
def get_assert(output, context):
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and all(row[2] > 0 for row in results)
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
}

View File

@@ -0,0 +1,19 @@
from utils import extract_sql, execute_sql
def get_assert(output, context):
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and 40000 < results[0][0] < 200000
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
}

View File

@@ -0,0 +1,25 @@
from utils import extract_sql, execute_sql
def get_assert(output, context):
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and len(results[0]) >= 5 # department, budget %, top employees, their salary %, efficiency score
if result_valid:
for row in results:
if not (isinstance(row[1], float) and 0 <= row[1] <= 100 and
isinstance(row[-1], float)):
result_valid = False
break
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid budget analysis results obtained' if result_valid else 'Invalid or incomplete analysis results'}"
}

View File

@@ -0,0 +1,22 @@
from utils import extract_sql, execute_sql
def get_assert(output, context):
sql = extract_sql(output)
try:
results = execute_sql(sql)
count = results[0][0] if results else 0
execution_success = True
except Exception as e:
execution_success = False
count = 0
print(f"SQL execution error: {e}")
expected_count = 20
return {
"pass": execution_success and count == expected_count,
"score": 1 if (execution_success and count == expected_count) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Returned count: {count}, Expected count: {expected_count}."
}

View File

@@ -0,0 +1,37 @@
from utils import extract_sql, execute_sql
def get_assert(output, context):
sql = extract_sql(output)
try:
results = execute_sql(sql)
row = results[0] if results else None
execution_success = True
except Exception as e:
execution_success = False
row = None
print(f"SQL execution error: {e}")
expected_result = {
"name": "Julia Clark",
"age": 64,
"salary": 103699.17
}
if row:
actual_result = {
"name": row[0],
"age": row[1],
"salary": row[2]
}
data_match = actual_result == expected_result
else:
data_match = False
return {
"pass": execution_success and data_match,
"score": 1 if (execution_success and data_match) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Data {'matches' if data_match else 'does not match'} expected result. "
f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}"
}

View File

@@ -0,0 +1,24 @@
from utils import extract_sql, execute_sql
def get_assert(output, context):
sql = extract_sql(output)
try:
results = execute_sql(sql)
execution_success = True
result_valid = len(results) > 0 and len(results[0]) == 4 # department, employee name, salary, percentage difference
if result_valid:
for row in results:
if not (isinstance(row[2], (int, float)) and isinstance(row[3], (int, float))):
result_valid = False
break
except Exception as e:
execution_success = False
result_valid = False
print(f"SQL execution error: {e}")
return {
"pass": execution_success and result_valid,
"score": 1 if (execution_success and result_valid) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid results obtained' if result_valid else 'Invalid or no results'}"
}

View File

@@ -0,0 +1,12 @@
from utils import extract_sql
def get_assert(output, context):
sql = extract_sql(output)
required_elements = ['select', 'from employees', 'join departments', "name = 'engineering'"]
result = all(element in sql.lower() for element in required_elements)
return {
"pass": result,
"score": 1 if result else 0,
"reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}"
}

View File

@@ -0,0 +1,15 @@
# sql_utils.py
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results