WIP Text to SQL - count rows

This commit is contained in:
Mahesh Murag
2024-09-23 23:38:19 +02:00
parent 2de1571cfd
commit 29e1f97dff

View File

@@ -15,7 +15,7 @@ prompts:
# - prompts.py:generate_prompt_with_rag
tests:
- description: "Simple query for employee names in Engineering"
- description: "Check syntax of simple query"
prompt: prompts.py:basic_text_to_sql
vars:
user_query: "What are the names of all employees in the Engineering department?"
@@ -44,7 +44,51 @@ tests:
"score": 1 if result else 0,
"reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}"
}
- description: "Validate count of employees in Engineering department"
prompt: prompts.py:generate_prompt
vars:
user_query: "How many employees are in the Engineering department?"
assert:
- type: contains
value: "<sql>"
- type: contains
value: "</sql>"
- type: python
value: |
import re
import sqlite3
def extract_sql(text):
match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
return match.group(1).strip() if match else ""
def execute_sql(sql):
conn = sqlite3.connect('../data/data.db')
cursor = conn.cursor()
cursor.execute(sql)
results = cursor.fetchall()
conn.close()
return results
sql = extract_sql(output)
try:
results = execute_sql(sql)
count = results[0][0] if results else 0
execution_success = True
except Exception as e:
execution_success = False
count = 0
print(f"SQL execution error: {e}")
expected_count = 3
return {
"pass": execution_success and count == expected_count,
"score": 1 if (execution_success and count == expected_count) else 0,
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Returned count: {count}, Expected count: {expected_count}."
}
output:
- type: csv