mirror of
				https://github.com/anthropics/claude-cookbooks.git
				synced 2025-10-06 01:00:28 +03:00 
			
		
		
		
	WIP Text to SQL - checkpoint
This commit is contained in:
		
										
											Binary file not shown.
										
									
								
							| @@ -1,5 +1,3 @@ | ||||
| # promptfooconfig.yaml | ||||
|  | ||||
| python_path: /opt/homebrew/bin/python3 | ||||
|  | ||||
| providers: | ||||
| @@ -151,4 +149,175 @@ tests: | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. " | ||||
|                         f"Data {'matches' if data_match else 'does not match'} expected result. " | ||||
|                         f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}" | ||||
|           } | ||||
|           } | ||||
|   - description: "Find the average salary of employees in departments located in 'New York', but only for departments with more than 5 employees" | ||||
|     vars: | ||||
|       user_query: "What's the average salary for employees in New York-based departments that have more than 5 staff members?" | ||||
|     assert: | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               # Check if results make sense (e.g., not empty, reasonable avg salary) | ||||
|               result_valid = len(results) > 0 and 40000 < results[0][0] < 200000 | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}." | ||||
|           } | ||||
|   - description: "Find employees who earn more than their department's average salary, along with the percentage difference" | ||||
|     vars: | ||||
|       user_query: "Which employees earn above their department's average salary, and by what percentage?" | ||||
|     assert: | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               # Check if results make sense (e.g., not empty, percentage above 0) | ||||
|               result_valid = len(results) > 0 and all(row[2] > 0 for row in results) | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}." | ||||
|           } | ||||
|   - description: "Complex hierarchical query with multiple aggregations: For each department, show the highest paid employee and how their salary compares to the average of the next 3 highest paid employees in that department" | ||||
|     vars: | ||||
|       user_query: "For each department, show the name of the highest paid employee, their salary, and the percentage difference between their salary and the average salary of the next 3 highest paid employees in that department" | ||||
|     assert: | ||||
|       - type: contains | ||||
|         value: "PARTITION BY" | ||||
|       - type: contains | ||||
|         value: "AVG" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               result_valid = len(results) > 0 and len(results[0]) == 4  # department, employee name, salary, percentage difference | ||||
|               if result_valid: | ||||
|                   for row in results: | ||||
|                       if not (isinstance(row[2], (int, float)) and isinstance(row[3], (int, float))): | ||||
|                           result_valid = False | ||||
|                           break | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid results obtained' if result_valid else 'Invalid or no results'}" | ||||
|           } | ||||
|   - description: "Department budget allocation analysis" | ||||
|     vars: | ||||
|       user_query: "Analyze the budget allocation across departments. Calculate the percentage of total salary budget each department consumes. Then, for each department, show the top 3 highest-paid employees and what percentage of the department's budget their salaries represent. Finally, calculate a 'budget efficiency' score for each department, defined as the department's percentage of total employees divided by its percentage of total salary budget." | ||||
|     assert: | ||||
|       - type: contains | ||||
|         value: "WITH" | ||||
|       - type: contains | ||||
|         value: "ROW_NUMBER()" | ||||
|       - type: contains | ||||
|         value: "PARTITION BY" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               result_valid = len(results) > 0 and len(results[0]) >= 5  # department, budget %, top employees, their salary %, efficiency score | ||||
|               if result_valid: | ||||
|                   for row in results: | ||||
|                       if not (isinstance(row[1], float) and 0 <= row[1] <= 100 and | ||||
|                               isinstance(row[-1], float)): | ||||
|                           result_valid = False | ||||
|                           break | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid budget analysis results obtained' if result_valid else 'Invalid or incomplete analysis results'}" | ||||
|           } | ||||
|   | ||||
| @@ -153,18 +153,58 @@ def generate_prompt_with_rag(context): | ||||
|     relevant_schema = vectordb.search(user_query, k=10, similarity_threshold=0.3) | ||||
|     schema_info = "\n".join([f"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}" | ||||
|                              for item in relevant_schema]) | ||||
|     return f""" | ||||
|     You are an AI assistant that converts natural language queries into SQL. | ||||
|  | ||||
|     examples = """ | ||||
|     <example> | ||||
|     <query>List all employees in the HR department.</query> | ||||
|     <thought_process> | ||||
|     1. We need to join the employees and departments tables. | ||||
|     2. We'll match employees.department_id with departments.id. | ||||
|     3. We'll filter for the HR department. | ||||
|     4. We only need to return the employee names. | ||||
|     </thought_process> | ||||
|     <sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql> | ||||
|     </example> | ||||
|  | ||||
|     <example> | ||||
|     <query>What is the average salary of employees hired in 2022?</query> | ||||
|     <thought_process> | ||||
|     1. We need to work with the employees table. | ||||
|     2. We need to filter for employees hired in 2022. | ||||
|     3. We'll use SQLite's strftime function to extract the year from the hire_date. | ||||
|     4. We'll calculate the average of the salary column for the filtered rows. | ||||
|     </thought_process> | ||||
|     <sql>SELECT AVG(salary) FROM employees WHERE strftime('%Y', hire_date) = '2022';</sql> | ||||
|     </example> | ||||
|     """ | ||||
|  | ||||
|     return f"""You are an AI assistant that converts natural language queries into SQL. | ||||
|     Given the following relevant columns from the SQL database schema: | ||||
|  | ||||
|     <schema> | ||||
|     {schema_info} | ||||
|     </schema> | ||||
|  | ||||
|     Convert the following natural language query into SQL: | ||||
|     Here are some examples of natural language queries, thought processes, and their corresponding SQL: | ||||
|  | ||||
|     <examples> | ||||
|     {examples} | ||||
|     </examples> | ||||
|  | ||||
|     Now, convert the following natural language query into SQL: | ||||
|     <query> | ||||
|     {user_query} | ||||
|     </query> | ||||
|  | ||||
|     First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query. | ||||
|     Then, provide the SQL query within <sql> tags. | ||||
|     First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query. Consider the following steps: | ||||
|     1. Identify the relevant tables and columns from the provided schema. | ||||
|     2. Determine any necessary joins between tables. | ||||
|     3. Identify any filtering conditions. | ||||
|     4. Decide on the appropriate aggregations or calculations. | ||||
|     5. Structure the query logically. | ||||
|  | ||||
|     Ensure your SQL query is compatible with SQLite syntax. | ||||
|     """ | ||||
|     Then, within <sql> tags, provide your output SQL query. | ||||
|  | ||||
|     Ensure your SQL query is compatible with SQLite syntax and uses only the tables and columns provided in the schema. | ||||
|     If you're unsure about a particular table or column, use the information available in the provided schema. | ||||
|     """ | ||||
| @@ -1601,7 +1601,7 @@ | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 16, | ||||
|    "execution_count": 18, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
| @@ -1610,7 +1610,7 @@ | ||||
|      "text": [ | ||||
|       "Generated prompt:\n", | ||||
|       "You are an AI assistant that converts natural language queries into SQL.\n", | ||||
|       "    Given the following relevant columns from the SQL database schema:\n", | ||||
|       "    Given the following SQL database schema:\n", | ||||
|       "\n", | ||||
|       "    <schema>\n", | ||||
|       "    Table: employees, Column: salary, Type: REAL\n", | ||||
| @@ -1624,7 +1624,35 @@ | ||||
|       "Table: employees, Column: hire_date, Type: DATE\n", | ||||
|       "    </schema>\n", | ||||
|       "\n", | ||||
|       "    Convert the following natural language query into SQL:\n", | ||||
|       "    Here are some examples of natural language queries, thought processes, and their corresponding SQL:\n", | ||||
|       "\n", | ||||
|       "    <examples>\n", | ||||
|       "    \n", | ||||
|       "    <example>\n", | ||||
|       "    <query>List all employees in the HR department.</query>\n", | ||||
|       "    <thought_process>\n", | ||||
|       "    1. We need to join the employees and departments tables.\n", | ||||
|       "    2. We'll match employees.department_id with departments.id.\n", | ||||
|       "    3. We'll filter for the HR department.\n", | ||||
|       "    4. We only need to return the employee names.\n", | ||||
|       "    </thought_process>\n", | ||||
|       "    <sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>\n", | ||||
|       "    </example>\n", | ||||
|       "\n", | ||||
|       "    <example>\n", | ||||
|       "    <query>What is the average salary of employees hired in 2022?</query>\n", | ||||
|       "    <thought_process>\n", | ||||
|       "    1. We need to work with the employees table.\n", | ||||
|       "    2. We need to filter for employees hired in 2022.\n", | ||||
|       "    3. We'll use the YEAR function to extract the year from the hire_date.\n", | ||||
|       "    4. We'll calculate the average of the salary column for the filtered rows.\n", | ||||
|       "    </thought_process>\n", | ||||
|       "    <sql>SELECT AVG(salary) FROM employees WHERE YEAR(hire_date) = 2022;</sql>\n", | ||||
|       "    </example>\n", | ||||
|       "    \n", | ||||
|       "    </examples>\n", | ||||
|       "\n", | ||||
|       "    Now, convert the following natural language query into SQL:\n", | ||||
|       "    <query>\n", | ||||
|       "    What is the average salary of employees in each department?\n", | ||||
|       "    </query>\n", | ||||
| @@ -1635,46 +1663,25 @@ | ||||
|       "\n", | ||||
|       "Generated result:\n", | ||||
|       "<thought_process>\n", | ||||
|       "To answer this query, we need to:\n", | ||||
|       "1. Join the employees and departments tables to get department information for each employee.\n", | ||||
|       "2. Group the results by department.\n", | ||||
|       "3. Calculate the average salary for each group.\n", | ||||
|       "\n", | ||||
|       "Here's the step-by-step thought process:\n", | ||||
|       "1. We'll use the employees table as our main table since it contains the salary information.\n", | ||||
|       "2. We need to join the departments table to get the department names.\n", | ||||
|       "3. The join will be on employees.department_id = departments.id\n", | ||||
|       "4. We'll group the results by department name (or id, but name is more informative).\n", | ||||
|       "5. We'll use the AVG function to calculate the average salary for each group.\n", | ||||
|       "6. We'll select the department name and the average salary in the SELECT clause.\n", | ||||
|       "1. We need to work with both the employees and departments tables to get department names and employee salaries.\n", | ||||
|       "2. We'll need to join these tables using the department_id from employees and id from departments.\n", | ||||
|       "3. We want to group the results by department, so we'll use GROUP BY.\n", | ||||
|       "4. We need to calculate the average salary for each department.\n", | ||||
|       "5. We should select the department name and the average salary for each group.\n", | ||||
|       "</thought_process>\n", | ||||
|       "\n", | ||||
|       "<sql>\n", | ||||
|       "SELECT \n", | ||||
|       "    d.name AS department_name,\n", | ||||
|       "    AVG(e.salary) AS average_salary\n", | ||||
|       "FROM \n", | ||||
|       "    employees e\n", | ||||
|       "JOIN \n", | ||||
|       "    departments d ON e.department_id = d.id\n", | ||||
|       "GROUP BY \n", | ||||
|       "    d.name\n", | ||||
|       "ORDER BY \n", | ||||
|       "    d.name\n", | ||||
|       "SELECT d.name AS department_name, AVG(e.salary) AS average_salary\n", | ||||
|       "FROM employees e\n", | ||||
|       "JOIN departments d ON e.department_id = d.id\n", | ||||
|       "GROUP BY d.name;\n", | ||||
|       "</sql>\n", | ||||
|       "\n", | ||||
|       "Extracted SQL:\n", | ||||
|       "SELECT \n", | ||||
|       "    d.name AS department_name,\n", | ||||
|       "    AVG(e.salary) AS average_salary\n", | ||||
|       "FROM \n", | ||||
|       "    employees e\n", | ||||
|       "JOIN \n", | ||||
|       "    departments d ON e.department_id = d.id\n", | ||||
|       "GROUP BY \n", | ||||
|       "    d.name\n", | ||||
|       "ORDER BY \n", | ||||
|       "    d.name\n", | ||||
|       "SELECT d.name AS department_name, AVG(e.salary) AS average_salary\n", | ||||
|       "FROM employees e\n", | ||||
|       "JOIN departments d ON e.department_id = d.id\n", | ||||
|       "GROUP BY d.name;\n", | ||||
|       "\n", | ||||
|       "Query result:\n" | ||||
|      ] | ||||
| @@ -1783,21 +1790,7 @@ | ||||
|     "    schema_info = \"\\n\".join([f\"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}\" \n", | ||||
|     "                             for item in relevant_schema])\n", | ||||
|     "    \n", | ||||
|     "    return f\"\"\"You are an AI assistant that converts natural language queries into SQL.\n", | ||||
|     "    Given the following relevant columns from the SQL database schema:\n", | ||||
|     "\n", | ||||
|     "    <schema>\n", | ||||
|     "    {schema_info}\n", | ||||
|     "    </schema>\n", | ||||
|     "\n", | ||||
|     "    Convert the following natural language query into SQL:\n", | ||||
|     "    <query>\n", | ||||
|     "    {query}\n", | ||||
|     "    </query>\n", | ||||
|     "\n", | ||||
|     "    Within <thought_process> tags, explain your thought process for creating the SQL query.\n", | ||||
|     "    Then, within <sql> tags, provide your output SQL query.\n", | ||||
|     "    \"\"\"\n", | ||||
|     "    return generate_prompt_with_cot(schema_info, query)\n", | ||||
|     "\n", | ||||
|     "# Test the RAG-based prompt\n", | ||||
|     "user_query = \"What is the average salary of employees in each department?\"\n", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Mahesh Murag
					Mahesh Murag