mirror of
https://github.com/anthropics/claude-cookbooks.git
synced 2025-10-06 01:00:28 +03:00
WIP Text to SQL - checkpoint
This commit is contained in:
Binary file not shown.
@@ -1,5 +1,3 @@
|
||||
# promptfooconfig.yaml
|
||||
|
||||
python_path: /opt/homebrew/bin/python3
|
||||
|
||||
providers:
|
||||
@@ -151,4 +149,175 @@ tests:
|
||||
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
|
||||
f"Data {'matches' if data_match else 'does not match'} expected result. "
|
||||
f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}"
|
||||
}
|
||||
}
|
||||
- description: "Find the average salary of employees in departments located in 'New York', but only for departments with more than 5 employees"
|
||||
vars:
|
||||
user_query: "What's the average salary for employees in New York-based departments that have more than 5 staff members?"
|
||||
assert:
|
||||
- type: python
|
||||
value: |
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
def extract_sql(text):
    """Return the SQL found inside the first ``<sql>...</sql>`` pair of *text*.

    Args:
        text: Model output expected to contain a ``<sql>`` element.

    Returns:
        The tag contents with surrounding whitespace stripped, or ``""``
        when *text* is not a string or contains no ``<sql>`` pair.
    """
    # This helper is called outside the script's try block, so a non-string
    # input (e.g. None) must not be allowed to raise TypeError in re.search.
    if not isinstance(text, str):
        return ""
    # DOTALL lets '.' cross newlines so multi-line queries are captured; the
    # non-greedy group stops at the first closing tag.
    match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
    return match.group(1).strip() if match else ""
|
||||
|
||||
def execute_sql(sql, db_path='../data/data.db'):
    """Execute *sql* against a SQLite database and return all result rows.

    Args:
        sql: The SQL statement to run.
        db_path: Database to open. Defaults to the relative path
            ``'../data/data.db'`` used by this assertion script.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If connecting or executing fails. The connection is
            closed before the exception propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        return cursor.fetchall()
    finally:
        # Close even when execute() raises, so a bad generated query does
        # not leave the database handle open.
        conn.close()
|
||||
|
||||
sql = extract_sql(output)
|
||||
|
||||
try:
|
||||
results = execute_sql(sql)
|
||||
execution_success = True
|
||||
# Check if results make sense (e.g., not empty, reasonable avg salary)
|
||||
result_valid = len(results) > 0 and 40000 < results[0][0] < 200000
|
||||
except Exception as e:
|
||||
execution_success = False
|
||||
result_valid = False
|
||||
print(f"SQL execution error: {e}")
|
||||
|
||||
return {
|
||||
"pass": execution_success and result_valid,
|
||||
"score": 1 if (execution_success and result_valid) else 0,
|
||||
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
|
||||
}
|
||||
- description: "Find employees who earn more than their department's average salary, along with the percentage difference"
|
||||
vars:
|
||||
user_query: "Which employees earn above their department's average salary, and by what percentage?"
|
||||
assert:
|
||||
- type: python
|
||||
value: |
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
def extract_sql(text):
    """Return the SQL found inside the first ``<sql>...</sql>`` pair of *text*.

    Args:
        text: Model output expected to contain a ``<sql>`` element.

    Returns:
        The tag contents with surrounding whitespace stripped, or ``""``
        when *text* is not a string or contains no ``<sql>`` pair.
    """
    # This helper is called outside the script's try block, so a non-string
    # input (e.g. None) must not be allowed to raise TypeError in re.search.
    if not isinstance(text, str):
        return ""
    # DOTALL lets '.' cross newlines so multi-line queries are captured; the
    # non-greedy group stops at the first closing tag.
    match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
    return match.group(1).strip() if match else ""
|
||||
|
||||
def execute_sql(sql, db_path='../data/data.db'):
    """Execute *sql* against a SQLite database and return all result rows.

    Args:
        sql: The SQL statement to run.
        db_path: Database to open. Defaults to the relative path
            ``'../data/data.db'`` used by this assertion script.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If connecting or executing fails. The connection is
            closed before the exception propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        return cursor.fetchall()
    finally:
        # Close even when execute() raises, so a bad generated query does
        # not leave the database handle open.
        conn.close()
|
||||
|
||||
sql = extract_sql(output)
|
||||
|
||||
try:
|
||||
results = execute_sql(sql)
|
||||
execution_success = True
|
||||
# Check if results make sense (e.g., not empty, percentage above 0)
|
||||
result_valid = len(results) > 0 and all(row[2] > 0 for row in results)
|
||||
except Exception as e:
|
||||
execution_success = False
|
||||
result_valid = False
|
||||
print(f"SQL execution error: {e}")
|
||||
|
||||
return {
|
||||
"pass": execution_success and result_valid,
|
||||
"score": 1 if (execution_success and result_valid) else 0,
|
||||
"reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}."
|
||||
}
|
||||
- description: "Complex hierarchical query with multiple aggregations: For each department, show the highest paid employee and how their salary compares to the average of the next 3 highest paid employees in that department"
|
||||
vars:
|
||||
user_query: "For each department, show the name of the highest paid employee, their salary, and the percentage difference between their salary and the average salary of the next 3 highest paid employees in that department"
|
||||
assert:
|
||||
- type: contains
|
||||
value: "PARTITION BY"
|
||||
- type: contains
|
||||
value: "AVG"
|
||||
- type: python
|
||||
value: |
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
def extract_sql(text):
    """Return the SQL found inside the first ``<sql>...</sql>`` pair of *text*.

    Args:
        text: Model output expected to contain a ``<sql>`` element.

    Returns:
        The tag contents with surrounding whitespace stripped, or ``""``
        when *text* is not a string or contains no ``<sql>`` pair.
    """
    # This helper is called outside the script's try block, so a non-string
    # input (e.g. None) must not be allowed to raise TypeError in re.search.
    if not isinstance(text, str):
        return ""
    # DOTALL lets '.' cross newlines so multi-line queries are captured; the
    # non-greedy group stops at the first closing tag.
    match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
    return match.group(1).strip() if match else ""
|
||||
|
||||
def execute_sql(sql, db_path='../data/data.db'):
    """Execute *sql* against a SQLite database and return all result rows.

    Args:
        sql: The SQL statement to run.
        db_path: Database to open. Defaults to the relative path
            ``'../data/data.db'`` used by this assertion script.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If connecting or executing fails. The connection is
            closed before the exception propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        return cursor.fetchall()
    finally:
        # Close even when execute() raises, so a bad generated query does
        # not leave the database handle open.
        conn.close()
|
||||
|
||||
sql = extract_sql(output)
|
||||
|
||||
try:
|
||||
results = execute_sql(sql)
|
||||
execution_success = True
|
||||
result_valid = len(results) > 0 and len(results[0]) == 4 # department, employee name, salary, percentage difference
|
||||
if result_valid:
|
||||
for row in results:
|
||||
if not (isinstance(row[2], (int, float)) and isinstance(row[3], (int, float))):
|
||||
result_valid = False
|
||||
break
|
||||
except Exception as e:
|
||||
execution_success = False
|
||||
result_valid = False
|
||||
print(f"SQL execution error: {e}")
|
||||
|
||||
return {
|
||||
"pass": execution_success and result_valid,
|
||||
"score": 1 if (execution_success and result_valid) else 0,
|
||||
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid results obtained' if result_valid else 'Invalid or no results'}"
|
||||
}
|
||||
- description: "Department budget allocation analysis"
|
||||
vars:
|
||||
user_query: "Analyze the budget allocation across departments. Calculate the percentage of total salary budget each department consumes. Then, for each department, show the top 3 highest-paid employees and what percentage of the department's budget their salaries represent. Finally, calculate a 'budget efficiency' score for each department, defined as the department's percentage of total employees divided by its percentage of total salary budget."
|
||||
assert:
|
||||
- type: contains
|
||||
value: "WITH"
|
||||
- type: contains
|
||||
value: "ROW_NUMBER()"
|
||||
- type: contains
|
||||
value: "PARTITION BY"
|
||||
- type: python
|
||||
value: |
|
||||
import re
|
||||
import sqlite3
|
||||
|
||||
def extract_sql(text):
    """Return the SQL found inside the first ``<sql>...</sql>`` pair of *text*.

    Args:
        text: Model output expected to contain a ``<sql>`` element.

    Returns:
        The tag contents with surrounding whitespace stripped, or ``""``
        when *text* is not a string or contains no ``<sql>`` pair.
    """
    # This helper is called outside the script's try block, so a non-string
    # input (e.g. None) must not be allowed to raise TypeError in re.search.
    if not isinstance(text, str):
        return ""
    # DOTALL lets '.' cross newlines so multi-line queries are captured; the
    # non-greedy group stops at the first closing tag.
    match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
    return match.group(1).strip() if match else ""
|
||||
|
||||
def execute_sql(sql, db_path='../data/data.db'):
    """Execute *sql* against a SQLite database and return all result rows.

    Args:
        sql: The SQL statement to run.
        db_path: Database to open. Defaults to the relative path
            ``'../data/data.db'`` used by this assertion script.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If connecting or executing fails. The connection is
            closed before the exception propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        return cursor.fetchall()
    finally:
        # Close even when execute() raises, so a bad generated query does
        # not leave the database handle open.
        conn.close()
|
||||
|
||||
sql = extract_sql(output)
|
||||
|
||||
try:
|
||||
results = execute_sql(sql)
|
||||
execution_success = True
|
||||
result_valid = len(results) > 0 and len(results[0]) >= 5 # department, budget %, top employees, their salary %, efficiency score
|
||||
if result_valid:
|
||||
for row in results:
|
||||
if not (isinstance(row[1], float) and 0 <= row[1] <= 100 and
|
||||
isinstance(row[-1], float)):
|
||||
result_valid = False
|
||||
break
|
||||
except Exception as e:
|
||||
execution_success = False
|
||||
result_valid = False
|
||||
print(f"SQL execution error: {e}")
|
||||
|
||||
return {
|
||||
"pass": execution_success and result_valid,
|
||||
"score": 1 if (execution_success and result_valid) else 0,
|
||||
"reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid budget analysis results obtained' if result_valid else 'Invalid or incomplete analysis results'}"
|
||||
}
|
||||
|
||||
@@ -153,18 +153,58 @@ def generate_prompt_with_rag(context):
|
||||
relevant_schema = vectordb.search(user_query, k=10, similarity_threshold=0.3)
|
||||
schema_info = "\n".join([f"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}"
|
||||
for item in relevant_schema])
|
||||
return f"""
|
||||
You are an AI assistant that converts natural language queries into SQL.
|
||||
|
||||
examples = """
|
||||
<example>
|
||||
<query>List all employees in the HR department.</query>
|
||||
<thought_process>
|
||||
1. We need to join the employees and departments tables.
|
||||
2. We'll match employees.department_id with departments.id.
|
||||
3. We'll filter for the HR department.
|
||||
4. We only need to return the employee names.
|
||||
</thought_process>
|
||||
<sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<query>What is the average salary of employees hired in 2022?</query>
|
||||
<thought_process>
|
||||
1. We need to work with the employees table.
|
||||
2. We need to filter for employees hired in 2022.
|
||||
3. We'll use SQLite's strftime function to extract the year from the hire_date.
|
||||
4. We'll calculate the average of the salary column for the filtered rows.
|
||||
</thought_process>
|
||||
<sql>SELECT AVG(salary) FROM employees WHERE strftime('%Y', hire_date) = '2022';</sql>
|
||||
</example>
|
||||
"""
|
||||
|
||||
return f"""You are an AI assistant that converts natural language queries into SQL.
|
||||
Given the following relevant columns from the SQL database schema:
|
||||
|
||||
<schema>
|
||||
{schema_info}
|
||||
</schema>
|
||||
|
||||
Convert the following natural language query into SQL:
|
||||
Here are some examples of natural language queries, thought processes, and their corresponding SQL:
|
||||
|
||||
<examples>
|
||||
{examples}
|
||||
</examples>
|
||||
|
||||
Now, convert the following natural language query into SQL:
|
||||
<query>
|
||||
{user_query}
|
||||
</query>
|
||||
|
||||
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
|
||||
Then, provide the SQL query within <sql> tags.
|
||||
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query. Consider the following steps:
|
||||
1. Identify the relevant tables and columns from the provided schema.
|
||||
2. Determine any necessary joins between tables.
|
||||
3. Identify any filtering conditions.
|
||||
4. Decide on the appropriate aggregations or calculations.
|
||||
5. Structure the query logically.
|
||||
|
||||
Ensure your SQL query is compatible with SQLite syntax.
|
||||
"""
|
||||
Then, within <sql> tags, provide your output SQL query.
|
||||
|
||||
Ensure your SQL query is compatible with SQLite syntax and uses only the tables and columns provided in the schema.
|
||||
If you're unsure about a particular table or column, use the information available in the provided schema.
|
||||
"""
|
||||
@@ -1601,7 +1601,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -1610,7 +1610,7 @@
|
||||
"text": [
|
||||
"Generated prompt:\n",
|
||||
"You are an AI assistant that converts natural language queries into SQL.\n",
|
||||
" Given the following relevant columns from the SQL database schema:\n",
|
||||
" Given the following SQL database schema:\n",
|
||||
"\n",
|
||||
" <schema>\n",
|
||||
" Table: employees, Column: salary, Type: REAL\n",
|
||||
@@ -1624,7 +1624,35 @@
|
||||
"Table: employees, Column: hire_date, Type: DATE\n",
|
||||
" </schema>\n",
|
||||
"\n",
|
||||
" Convert the following natural language query into SQL:\n",
|
||||
" Here are some examples of natural language queries, thought processes, and their corresponding SQL:\n",
|
||||
"\n",
|
||||
" <examples>\n",
|
||||
" \n",
|
||||
" <example>\n",
|
||||
" <query>List all employees in the HR department.</query>\n",
|
||||
" <thought_process>\n",
|
||||
" 1. We need to join the employees and departments tables.\n",
|
||||
" 2. We'll match employees.department_id with departments.id.\n",
|
||||
" 3. We'll filter for the HR department.\n",
|
||||
" 4. We only need to return the employee names.\n",
|
||||
" </thought_process>\n",
|
||||
" <sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>\n",
|
||||
" </example>\n",
|
||||
"\n",
|
||||
" <example>\n",
|
||||
" <query>What is the average salary of employees hired in 2022?</query>\n",
|
||||
" <thought_process>\n",
|
||||
" 1. We need to work with the employees table.\n",
|
||||
" 2. We need to filter for employees hired in 2022.\n",
|
||||
" 3. We'll use the YEAR function to extract the year from the hire_date.\n",
|
||||
" 4. We'll calculate the average of the salary column for the filtered rows.\n",
|
||||
" </thought_process>\n",
|
||||
" <sql>SELECT AVG(salary) FROM employees WHERE YEAR(hire_date) = 2022;</sql>\n",
|
||||
" </example>\n",
|
||||
" \n",
|
||||
" </examples>\n",
|
||||
"\n",
|
||||
" Now, convert the following natural language query into SQL:\n",
|
||||
" <query>\n",
|
||||
" What is the average salary of employees in each department?\n",
|
||||
" </query>\n",
|
||||
@@ -1635,46 +1663,25 @@
|
||||
"\n",
|
||||
"Generated result:\n",
|
||||
"<thought_process>\n",
|
||||
"To answer this query, we need to:\n",
|
||||
"1. Join the employees and departments tables to get department information for each employee.\n",
|
||||
"2. Group the results by department.\n",
|
||||
"3. Calculate the average salary for each group.\n",
|
||||
"\n",
|
||||
"Here's the step-by-step thought process:\n",
|
||||
"1. We'll use the employees table as our main table since it contains the salary information.\n",
|
||||
"2. We need to join the departments table to get the department names.\n",
|
||||
"3. The join will be on employees.department_id = departments.id\n",
|
||||
"4. We'll group the results by department name (or id, but name is more informative).\n",
|
||||
"5. We'll use the AVG function to calculate the average salary for each group.\n",
|
||||
"6. We'll select the department name and the average salary in the SELECT clause.\n",
|
||||
"1. We need to work with both the employees and departments tables to get department names and employee salaries.\n",
|
||||
"2. We'll need to join these tables using the department_id from employees and id from departments.\n",
|
||||
"3. We want to group the results by department, so we'll use GROUP BY.\n",
|
||||
"4. We need to calculate the average salary for each department.\n",
|
||||
"5. We should select the department name and the average salary for each group.\n",
|
||||
"</thought_process>\n",
|
||||
"\n",
|
||||
"<sql>\n",
|
||||
"SELECT \n",
|
||||
" d.name AS department_name,\n",
|
||||
" AVG(e.salary) AS average_salary\n",
|
||||
"FROM \n",
|
||||
" employees e\n",
|
||||
"JOIN \n",
|
||||
" departments d ON e.department_id = d.id\n",
|
||||
"GROUP BY \n",
|
||||
" d.name\n",
|
||||
"ORDER BY \n",
|
||||
" d.name\n",
|
||||
"SELECT d.name AS department_name, AVG(e.salary) AS average_salary\n",
|
||||
"FROM employees e\n",
|
||||
"JOIN departments d ON e.department_id = d.id\n",
|
||||
"GROUP BY d.name;\n",
|
||||
"</sql>\n",
|
||||
"\n",
|
||||
"Extracted SQL:\n",
|
||||
"SELECT \n",
|
||||
" d.name AS department_name,\n",
|
||||
" AVG(e.salary) AS average_salary\n",
|
||||
"FROM \n",
|
||||
" employees e\n",
|
||||
"JOIN \n",
|
||||
" departments d ON e.department_id = d.id\n",
|
||||
"GROUP BY \n",
|
||||
" d.name\n",
|
||||
"ORDER BY \n",
|
||||
" d.name\n",
|
||||
"SELECT d.name AS department_name, AVG(e.salary) AS average_salary\n",
|
||||
"FROM employees e\n",
|
||||
"JOIN departments d ON e.department_id = d.id\n",
|
||||
"GROUP BY d.name;\n",
|
||||
"\n",
|
||||
"Query result:\n"
|
||||
]
|
||||
@@ -1783,21 +1790,7 @@
|
||||
" schema_info = \"\\n\".join([f\"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}\" \n",
|
||||
" for item in relevant_schema])\n",
|
||||
" \n",
|
||||
" return f\"\"\"You are an AI assistant that converts natural language queries into SQL.\n",
|
||||
" Given the following relevant columns from the SQL database schema:\n",
|
||||
"\n",
|
||||
" <schema>\n",
|
||||
" {schema_info}\n",
|
||||
" </schema>\n",
|
||||
"\n",
|
||||
" Convert the following natural language query into SQL:\n",
|
||||
" <query>\n",
|
||||
" {query}\n",
|
||||
" </query>\n",
|
||||
"\n",
|
||||
" Within <thought_process> tags, explain your thought process for creating the SQL query.\n",
|
||||
" Then, within <sql> tags, provide your output SQL query.\n",
|
||||
" \"\"\"\n",
|
||||
" return generate_prompt_with_cot(schema_info, query)\n",
|
||||
"\n",
|
||||
"# Test the RAG-based prompt\n",
|
||||
"user_query = \"What is the average salary of employees in each department?\"\n",
|
||||
|
||||
Reference in New Issue
Block a user