mirror of
				https://github.com/anthropics/claude-cookbooks.git
				synced 2025-10-06 01:00:28 +03:00 
			
		
		
		
	WIP Text to SQL - checkpoint
This commit is contained in:
		| @@ -1,5 +1,3 @@ | ||||
| python_path: /opt/homebrew/bin/python3 | ||||
|  | ||||
| providers: | ||||
|   - id: anthropic:messages:claude-3-haiku-20240307 | ||||
|     label: "3 Haiku" | ||||
| @@ -32,25 +30,8 @@ tests: | ||||
|       - type: contains | ||||
|         value: "</sql>" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|         value: file://tests/test_simple_query.py | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def check_sql(sql): | ||||
|               required_elements = ['select', 'from employees', 'join departments', "name = 'engineering'"] | ||||
|               return all(element in sql.lower() for element in required_elements) | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|           result = check_sql(sql) | ||||
|            | ||||
|           return { | ||||
|               "pass": result, | ||||
|               "score": 1 if result else 0, | ||||
|               "reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}" | ||||
|           } | ||||
|   - description: "Validate count of employees in Engineering department" | ||||
|     vars: | ||||
|       user_query: "How many employees are in the Engineering department?" | ||||
| @@ -60,41 +41,8 @@ tests: | ||||
|       - type: contains | ||||
|         value: "</sql>" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|         value: file://tests/test_employee_count.py | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               count = results[0][0] if results else 0 | ||||
|               execution_success = True | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               count = 0 | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           expected_count = 20 | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and count == expected_count, | ||||
|               "score": 1 if (execution_success and count == expected_count) else 0, | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. " | ||||
|                         f"Returned count: {count}, Expected count: {expected_count}." | ||||
|           } | ||||
|   - description: "Check specific employee details in Engineering department" | ||||
|     vars: | ||||
|       user_query: "Give me the name, age, and salary of the oldest employee in the Engineering department." | ||||
| @@ -104,133 +52,23 @@ tests: | ||||
|       - type: contains | ||||
|         value: "</sql>" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|         value: file://tests/test_employee_details.py | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               row = results[0] if results else None | ||||
|               execution_success = True | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               row = None | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           expected_result = { | ||||
|               "name": "Julia Clark", | ||||
|               "age": 64, | ||||
|               "salary": 103699.17 | ||||
|           } | ||||
|  | ||||
|           if row: | ||||
|               actual_result = { | ||||
|                   "name": row[0], | ||||
|                   "age": row[1], | ||||
|                   "salary": row[2] | ||||
|               } | ||||
|               data_match = actual_result == expected_result | ||||
|           else: | ||||
|               data_match = False | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and data_match, | ||||
|               "score": 1 if (execution_success and data_match) else 0, | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. " | ||||
|                         f"Data {'matches' if data_match else 'does not match'} expected result. " | ||||
|                         f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}" | ||||
|           } | ||||
|   - description: "Find the average salary of employees in departments located in 'New York', but only for departments with more than 5 employees" | ||||
|     vars: | ||||
|       user_query: "What's the average salary for employees in New York-based departments that have more than 5 staff members?" | ||||
|     assert: | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|         value: file://tests/test_average_salary.py | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               # Check if results make sense (e.g., not empty, reasonable avg salary) | ||||
|               result_valid = len(results) > 0 and 40000 < results[0][0] < 200000 | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}." | ||||
|           } | ||||
|   - description: "Find employees who earn more than their department's average salary, along with the percentage difference" | ||||
|     vars: | ||||
|       user_query: "Which employees earn above their department's average salary, and by what percentage?" | ||||
|     assert: | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|         value: file://tests/test_above_average_salary.py | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               # Check if results make sense (e.g., not empty, percentage above 0) | ||||
|               result_valid = len(results) > 0 and all(row[2] > 0 for row in results) | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully with valid results' if (execution_success and result_valid) else 'failed or produced invalid results'}." | ||||
|           } | ||||
|   - description: "Complex hierarchical query with multiple aggregations: For each department, show the highest paid employee and how their salary compares to the average of the next 3 highest paid employees in that department" | ||||
|   - description: "Complex hierarchical query with multiple aggregations" | ||||
|     vars: | ||||
|       user_query: "For each department, show the name of the highest paid employee, their salary, and the percentage difference between their salary and the average salary of the next 3 highest paid employees in that department" | ||||
|     assert: | ||||
| @@ -239,43 +77,8 @@ tests: | ||||
|       - type: contains | ||||
|         value: "AVG" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|         value: file://tests/test_hierarchical_query.py | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               result_valid = len(results) > 0 and len(results[0]) == 4  # department, employee name, salary, percentage difference | ||||
|               if result_valid: | ||||
|                   for row in results: | ||||
|                       if not (isinstance(row[2], (int, float)) and isinstance(row[3], (int, float))): | ||||
|                           result_valid = False | ||||
|                           break | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid results obtained' if result_valid else 'Invalid or no results'}" | ||||
|           } | ||||
|   - description: "Department budget allocation analysis" | ||||
|     vars: | ||||
|       user_query: "Analyze the budget allocation across departments. Calculate the percentage of total salary budget each department consumes. Then, for each department, show the top 3 highest-paid employees and what percentage of the department's budget their salaries represent. Finally, calculate a 'budget efficiency' score for each department, defined as the department's percentage of total employees divided by its percentage of total salary budget." | ||||
| @@ -287,41 +90,4 @@ tests: | ||||
|       - type: contains | ||||
|         value: "PARTITION BY" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|           import sqlite3 | ||||
|  | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|           def execute_sql(sql): | ||||
|               conn = sqlite3.connect('../data/data.db') | ||||
|               cursor = conn.cursor() | ||||
|               cursor.execute(sql) | ||||
|               results = cursor.fetchall() | ||||
|               conn.close() | ||||
|               return results | ||||
|  | ||||
|           sql = extract_sql(output) | ||||
|            | ||||
|           try: | ||||
|               results = execute_sql(sql) | ||||
|               execution_success = True | ||||
|               result_valid = len(results) > 0 and len(results[0]) >= 5  # department, budget %, top employees, their salary %, efficiency score | ||||
|               if result_valid: | ||||
|                   for row in results: | ||||
|                       if not (isinstance(row[1], float) and 0 <= row[1] <= 100 and | ||||
|                               isinstance(row[-1], float)): | ||||
|                           result_valid = False | ||||
|                           break | ||||
|           except Exception as e: | ||||
|               execution_success = False | ||||
|               result_valid = False | ||||
|               print(f"SQL execution error: {e}") | ||||
|  | ||||
|           return { | ||||
|               "pass": execution_success and result_valid, | ||||
|               "score": 1 if (execution_success and result_valid) else 0, | ||||
|               "reason": f"SQL {'executed successfully' if execution_success else 'failed to execute'}. {'Valid budget analysis results obtained' if result_valid else 'Invalid or incomplete analysis results'}" | ||||
|           } | ||||
|         value: file://tests/test_budget_allocation.py | ||||
| @@ -0,0 +1,19 @@ | ||||
| from utils import extract_sql, execute_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by running its SQL and checking the third column is positive.

    The SQL is taken from ``output`` with ``extract_sql`` and run with
    ``execute_sql``. The grade passes only when the query runs, returns at
    least one row, and ``row[2]`` is greater than zero in every row.

    Args:
        output: Model response text expected to contain the SQL to run.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """
    query = extract_sql(output)

    try:
        rows = execute_sql(query)
        # all() is vacuously True for an empty result, so emptiness is checked first.
        result_valid = len(rows) > 0 and all(r[2] > 0 for r in rows)
        execution_success = True
    except Exception as e:
        # Any failure (execution or row inspection) is reported as a failed grade.
        execution_success = False
        result_valid = False
        print(f"SQL execution error: {e}")

    passed = execution_success and result_valid
    if passed:
        outcome = 'executed successfully with valid results'
    else:
        outcome = 'failed or produced invalid results'

    return {
        "pass": passed,
        "score": 1 if passed else 0,
        "reason": f"SQL {outcome}.",
    }
							
								
								
									
										19
									
								
								skills/text_to_sql/evaluation/tests/test_average_salary.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								skills/text_to_sql/evaluation/tests/test_average_salary.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| from utils import extract_sql, execute_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by running its SQL and range-checking the first cell.

    The SQL is taken from ``output`` with ``extract_sql`` and run with
    ``execute_sql``. The grade passes only when the query runs, returns at
    least one row, and the first column of the first row lies strictly
    between 40000 and 200000.

    Args:
        output: Model response text expected to contain the SQL to run.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """
    query = extract_sql(output)

    try:
        rows = execute_sql(query)
        execution_success = True
        if len(rows) > 0:
            first_cell = rows[0][0]
            # Sanity bounds only; the exact expected average is not pinned.
            result_valid = 40000 < first_cell < 200000
        else:
            result_valid = False
    except Exception as e:
        # Covers both execution errors and a non-comparable first cell.
        execution_success = False
        result_valid = False
        print(f"SQL execution error: {e}")

    passed = execution_success and result_valid
    if passed:
        outcome = 'executed successfully with valid results'
    else:
        outcome = 'failed or produced invalid results'

    return {
        "pass": passed,
        "score": 1 if passed else 0,
        "reason": f"SQL {outcome}.",
    }
| @@ -0,0 +1,25 @@ | ||||
| from utils import extract_sql, execute_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by running its SQL and checking the result's shape and types.

    The SQL is taken from ``output`` with ``extract_sql`` and run with
    ``execute_sql``. The result is valid when there is at least one row, the
    first row has five or more columns, and in every row the second column
    is a float within [0, 100] and the last column is a float.

    Args:
        output: Model response text expected to contain the SQL to run.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """

    def _row_ok(row):
        # Second column must be a float in [0, 100]; last column must be a float.
        return (isinstance(row[1], float) and 0 <= row[1] <= 100 and
                isinstance(row[-1], float))

    query = extract_sql(output)

    try:
        rows = execute_sql(query)
        execution_success = True
        # department, budget %, top employees, their salary %, efficiency score
        result_valid = len(rows) > 0 and len(rows[0]) >= 5
        if result_valid:
            result_valid = all(_row_ok(row) for row in rows)
    except Exception as e:
        execution_success = False
        result_valid = False
        print(f"SQL execution error: {e}")

    passed = execution_success and result_valid
    run_status = 'executed successfully' if execution_success else 'failed to execute'
    if result_valid:
        data_status = 'Valid budget analysis results obtained'
    else:
        data_status = 'Invalid or incomplete analysis results'

    return {
        "pass": passed,
        "score": 1 if passed else 0,
        "reason": f"SQL {run_status}. {data_status}",
    }
							
								
								
									
										22
									
								
								skills/text_to_sql/evaluation/tests/test_employee_count.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								skills/text_to_sql/evaluation/tests/test_employee_count.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| from utils import extract_sql, execute_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by running its SQL and comparing the first cell to 20.

    The SQL is taken from ``output`` with ``extract_sql`` and run with
    ``execute_sql``. The first column of the first row is read as the count;
    an empty result or a failed execution yields a count of 0.

    Args:
        output: Model response text expected to contain the SQL to run.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """
    expected_count = 20
    query = extract_sql(output)

    try:
        rows = execute_sql(query)
        if rows:
            count = rows[0][0]
        else:
            count = 0
        execution_success = True
    except Exception as e:
        execution_success = False
        count = 0
        print(f"SQL execution error: {e}")

    passed = execution_success and count == expected_count
    run_status = 'executed successfully' if execution_success else 'execution failed'

    return {
        "pass": passed,
        "score": 1 if passed else 0,
        "reason": f"SQL {run_status}. Returned count: {count}, Expected count: {expected_count}.",
    }
							
								
								
									
										37
									
								
								skills/text_to_sql/evaluation/tests/test_employee_details.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								skills/text_to_sql/evaluation/tests/test_employee_details.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| from utils import extract_sql, execute_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by running its SQL and comparing the first row to a fixed record.

    The SQL is taken from ``output`` with ``extract_sql`` and run with
    ``execute_sql``. The first three columns of the first returned row are
    read as name, age and salary and compared for exact equality with the
    expected record below.

    A row with fewer than three columns is reported as a mismatch rather
    than raising ``IndexError``, so a query that selects too few columns
    yields a failed grade instead of crashing the assertion.

    Args:
        output: Model response text expected to contain the SQL to run.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """
    sql = extract_sql(output)

    try:
        results = execute_sql(sql)
        row = results[0] if results else None
        execution_success = True
    except Exception as e:
        execution_success = False
        row = None
        print(f"SQL execution error: {e}")

    expected_result = {
        "name": "Julia Clark",
        "age": 64,
        "salary": 103699.17
    }

    data_match = False
    if not row:
        actual_display = 'No data'
    elif len(row) < 3:
        # Too few columns to build the record; show the raw row for debugging.
        actual_display = f"{tuple(row)!r} (fewer than 3 columns)"
    else:
        actual_result = {
            "name": row[0],
            "age": row[1],
            "salary": row[2]
        }
        data_match = actual_result == expected_result
        actual_display = actual_result

    passed = execution_success and data_match

    return {
        "pass": passed,
        "score": 1 if passed else 0,
        "reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
                  f"Data {'matches' if data_match else 'does not match'} expected result. "
                  f"Actual: {actual_display}, Expected: {expected_result}"
    }
| @@ -0,0 +1,24 @@ | ||||
| from utils import extract_sql, execute_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by running its SQL and checking for four-column numeric rows.

    The SQL is taken from ``output`` with ``extract_sql`` and run with
    ``execute_sql``. The result is valid when there is at least one row, the
    first row has exactly four columns, and the third and fourth columns of
    every row are ints or floats.

    Args:
        output: Model response text expected to contain the SQL to run.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """

    def _numeric(value):
        return isinstance(value, (int, float))

    query = extract_sql(output)

    try:
        rows = execute_sql(query)
        execution_success = True
        # department, employee name, salary, percentage difference
        result_valid = len(rows) > 0 and len(rows[0]) == 4
        if result_valid:
            result_valid = all(_numeric(row[2]) and _numeric(row[3]) for row in rows)
    except Exception as e:
        execution_success = False
        result_valid = False
        print(f"SQL execution error: {e}")

    passed = execution_success and result_valid
    run_status = 'executed successfully' if execution_success else 'failed to execute'
    data_status = 'Valid results obtained' if result_valid else 'Invalid or no results'

    return {
        "pass": passed,
        "score": 1 if passed else 0,
        "reason": f"SQL {run_status}. {data_status}",
    }
							
								
								
									
										12
									
								
								skills/text_to_sql/evaluation/tests/test_simple_query.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								skills/text_to_sql/evaluation/tests/test_simple_query.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| from utils import extract_sql | ||||
|  | ||||
def get_assert(output, context):
    """Grade a response by checking its SQL for a fixed set of substrings.

    The SQL is taken from ``output`` with ``extract_sql`` and lowercased;
    the grade passes only when every required fragment occurs in it. The
    query is not executed.

    Args:
        output: Model response text expected to contain the SQL to inspect.
        context: Accepted for the assertion interface; not used here.

    Returns:
        A dict with ``pass`` (bool), ``score`` (1 or 0) and ``reason`` (str).
    """
    lowered_sql = extract_sql(output).lower()
    required_elements = ['select', 'from employees', 'join departments', "name = 'engineering'"]

    result = True
    for fragment in required_elements:
        if fragment not in lowered_sql:
            result = False
            break

    verdict = 'is correct' if result else 'is incorrect or not found'

    return {
        "pass": result,
        "score": 1 if result else 0,
        "reason": f"SQL query {verdict}",
    }
							
								
								
									
										15
									
								
								skills/text_to_sql/evaluation/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								skills/text_to_sql/evaluation/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| # utils.py | ||||
| import re | ||||
| import sqlite3 | ||||
|  | ||||
def extract_sql(text):
    """Return the contents of the first ``<sql>...</sql>`` span in ``text``.

    The match is non-greedy and may span multiple lines. Surrounding
    whitespace is stripped from the captured contents.

    Args:
        text: String to search.

    Returns:
        The stripped text between the first pair of tags, or ``""`` when no
        such pair is found.
    """
    found = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
    if not found:
        return ""
    return found.group(1).strip()
|  | ||||
def execute_sql(sql, db_path='../data/data.db'):
    """Run ``sql`` against a SQLite database and return every result row.

    The connection is closed in a ``finally`` block so it is released even
    when the statement raises.

    NOTE: ``sql`` is executed verbatim with no parameter binding or
    read-only guard, so callers are responsible for what they pass in.

    Args:
        sql: SQL statement to execute.
        db_path: Path of the SQLite database file. The default is a relative
            path, so it is resolved against the process's current working
            directory.

    Returns:
        The list of row tuples produced by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If the database cannot be opened or the statement
            fails.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        return cursor.fetchall()
    finally:
        conn.close()
		Reference in New Issue
	
	Block a user
	 Mahesh Murag
					Mahesh Murag