WIP Text to SQL - 3/4 evals

This commit is contained in:
Mahesh Murag
2024-09-23 23:32:15 +02:00
parent c982d54567
commit 2de1571cfd
5 changed files with 1050 additions and 90 deletions

Binary file not shown.

View File

@@ -1,37 +1,51 @@
# Learn more about building a configuration: https://promptfoo.dev/docs/configuration/guide
# promptfooconfig.yaml
description: "Text to SQL Evaluation"
python_path: /opt/homebrew/bin/python3
providers:
  - id: anthropic:messages:claude-3-5-sonnet-20240620
    label: "3.5 Sonnet"
    config:
      max_tokens: 4096
      temperature: 0
prompts:
  - prompts.py:generate_prompt
  - prompts.py:generate_prompt_with_examples
  - prompts.py:generate_prompt_with_cot
  # - prompts.py:generate_prompt_with_rag
tests:
  - description: "Simple query for employee names in Engineering"
    prompt: prompts.py:basic_text_to_sql
    vars:
      user_query: "What are the names of all employees in the Engineering department?"
    assert:
      # For more information on assertions, see https://promptfoo.dev/docs/configuration/expected-outputs
      - type: contains
        value: "<sql>"
      - type: contains
        value: "</sql>"
      - type: python
        value: |
          import re

          def extract_sql(text):
              match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL)
              return match.group(1).strip() if match else ""

          def check_sql(sql):
              required_elements = ['select', 'from employees', 'join departments', "name = 'engineering'"]
              return all(element in sql.lower() for element in required_elements)

          sql = extract_sql(output)
          result = check_sql(sql)
          return {
              "pass": result,
              "score": 1 if result else 0,
              "reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}"
          }
output:
  - type: csv
    path: ./results.csv

View File

@@ -0,0 +1,146 @@
import sqlite3
def get_schema_info(db_path):
    """Return a human-readable summary of every table in an SQLite database.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        A string with one "Table: <name>" section per table, each listing
        its columns as " - name (TYPE)" lines; sections are separated by a
        blank line. Empty string if the database has no tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Enumerate all user tables from the catalog.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schema_info = []
        for (table_name,) in tables:
            # PRAGMA does not support parameter binding; quote the identifier
            # so unusual table names cannot break the statement.
            cursor.execute(f'PRAGMA table_info("{table_name}")')
            columns = cursor.fetchall()
            # In table_info rows, col[1] is the column name, col[2] its type.
            table_info = f"Table: {table_name}\n"
            table_info += "\n".join(f" - {col[1]} ({col[2]})" for col in columns)
            schema_info.append(table_info)
        return "\n\n".join(schema_info)
    finally:
        # Always release the connection; the original leaked it when a
        # query raised an exception.
        conn.close()
def generate_prompt(user_query, db_path='../data/data.db'):
    """Build a zero-shot text-to-SQL prompt.

    The live schema of the database at *db_path* is inlined so the model
    knows the available tables and columns before translating *user_query*.
    """
    schema_text = get_schema_info(db_path)
    lines = [
        "",
        "You are an AI assistant that converts natural language queries into SQL.",
        "Given the following SQL database schema:",
        schema_text,
        "Convert the following natural language query into SQL:",
        user_query,
        "Provide only the SQL query in your response, enclosed within <sql> tags.",
        "",
    ]
    return "\n".join(lines)
def generate_prompt_with_examples(user_query, db_path='../data/data.db'):
    """Build a few-shot text-to-SQL prompt with worked examples.

    Fixes two defects in the examples string: Example 1 closed its query
    with the malformed tag '</<query>', and Examples 2-3 used a different
    "User: / SQL:" layout than Example 1's tags. All three examples now
    share the same <query>/<output> structure.
    """
    examples = """
Example 1:
<query>List all employees in the HR department.</query>
<output>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</output>
Example 2:
<query>What is the average salary of employees in the Engineering department?</query>
<output>SELECT AVG(e.salary) FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering';</output>
Example 3:
<query>Who is the oldest employee?</query>
<output>SELECT name, age FROM employees ORDER BY age DESC LIMIT 1;</output>
"""
    schema = get_schema_info(db_path)
    return f"""
You are an AI assistant that converts natural language queries into SQL.
Given the following SQL database schema:
<schema>
{schema}
</schema>
Here are some examples of natural language queries and their corresponding SQL:
<examples>
{examples}
</examples>
Now, convert the following natural language query into SQL:
<query>
{user_query}
</query>
Provide only the SQL query in your response, enclosed within <sql> tags.
"""
def generate_prompt_with_cot(user_query, db_path='../data/data.db'):
    """Build a chain-of-thought text-to-SQL prompt.

    Shows the model a worked example containing an explicit
    <thought_process> section, then asks it to reason the same way
    before emitting its answer in <sql> tags.
    """
    worked_example = """
<example>
<query>List all employees in the HR department.</query>
<thought_process>
1. We need to join the employees and departments tables.
2. We'll match employees.department_id with departments.id.
3. We'll filter for the HR department.
4. We only need to return the employee names.
</thought_process>
<sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>
</example>
"""
    schema_text = get_schema_info(db_path)
    prompt = f"""
You are an AI assistant that converts natural language queries into SQL.
Given the following SQL database schema:
{schema_text}
Here are some examples of natural language queries, thought processes, and corresponding SQL queries:
{worked_example}
Now, convert the following natural language query into SQL:
{user_query}
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
Then, provide the SQL query within <sql> tags.
"""
    return prompt
def generate_prompt_with_rag(user_query, db_path='../data/data.db'):
    """Build a text-to-SQL prompt using retrieval over schema metadata.

    Each (table, column) pair is embedded into a vector store; only the
    columns most relevant to *user_query* are included in the prompt, so
    large schemas do not blow up the context.

    Args:
        user_query: Natural-language question to convert to SQL.
        db_path: Path to the SQLite database file.

    Returns:
        Prompt string instructing the model to reason in <thought_process>
        tags and answer in <sql> tags using only the retrieved columns.
    """
    from vectordb import VectorDB  # project-local module; imported lazily

    # VectorDB.__init__ already calls load_db(), so state is loaded (or
    # initialized empty) on construction — the original's extra load_db()
    # call re-read the pickle for no reason and has been removed.
    vectordb = VectorDB()
    if not vectordb.embeddings:
        # First run: index every column of every table.
        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
            schema_data = []
            for (table_name,) in tables:
                for col in cursor.execute(f"PRAGMA table_info({table_name})").fetchall():
                    schema_data.append({
                        "text": f"Table: {table_name}, Column: {col[1]}, Type: {col[2]}",
                        "metadata": {"table": table_name, "column": col[1], "type": col[2]},
                    })
        vectordb.load_data(schema_data)

    relevant_schema = vectordb.search(user_query, k=10, similarity_threshold=0.3)
    schema_info = "\n".join(
        f"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}"
        for item in relevant_schema
    )
    return f"""
You are an AI assistant that converts natural language queries into SQL.
Given the following relevant columns from the SQL database schema:
{schema_info}
Convert the following natural language query into SQL:
{user_query}
First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
Then, provide the SQL query within <sql> tags.
Ensure your SQL query is compatible with SQLite syntax.
"""

View File

@@ -0,0 +1,43 @@
import os
import numpy as np
import voyageai
import pickle
import json
class VectorDB:
    """Tiny on-disk vector store over Voyage AI embeddings.

    Embeddings, per-item metadata, and a query-embedding cache are
    persisted together in a single pickle file at ``db_path``.
    """

    def __init__(self, db_path='../data/vector_db.pkl'):
        # Client reads VOYAGE_API_KEY from the environment if not passed.
        self.client = voyageai.Client(api_key=os.getenv("VOYAGE_API_KEY"))
        self.db_path = db_path
        self.load_db()

    def load_db(self):
        """Load embeddings/metadata/query cache from disk, or start empty."""
        if os.path.exists(self.db_path):
            # SECURITY NOTE: pickle.load executes arbitrary code if the file
            # is attacker-controlled; only load trusted, locally-written files.
            with open(self.db_path, "rb") as file:
                data = pickle.load(file)
            self.embeddings = data['embeddings']
            self.metadata = data['metadata']
            # query_cache is stored as a JSON string inside the pickle.
            self.query_cache = json.loads(data['query_cache'])
        else:
            self.embeddings, self.metadata, self.query_cache = [], [], {}

    def load_data(self, data):
        """Embed and persist *data* (a list of {"text", "metadata"} dicts).

        No-op when embeddings already exist, so repeated calls do not
        re-embed (and re-pay for) the same corpus.
        """
        if not self.embeddings:
            texts = [item["text"] for item in data]
            # Embed in chunks of 128 to respect the API's batch limit.
            self.embeddings = [emb for batch in range(0, len(texts), 128)
                               for emb in self.client.embed(texts[batch:batch+128], model="voyage-2").embeddings]
            self.metadata = [item["metadata"] for item in data]  # Store only the inner metadata
            self.save_db()

    def search(self, query, k=5, similarity_threshold=0.3):
        """Return up to *k* items whose dot-product similarity >= threshold.

        Fix: returns [] on an empty store instead of crashing inside
        np.dot with a shape-mismatch error.
        """
        if not self.embeddings:
            return []
        if query not in self.query_cache:
            self.query_cache[query] = self.client.embed([query], model="voyage-2").embeddings[0]
            # Persist immediately so the query embedding isn't re-bought.
            self.save_db()
        similarities = np.dot(self.embeddings, self.query_cache[query])
        top_indices = np.argsort(similarities)[::-1]
        return [{"metadata": self.metadata[i], "similarity": similarities[i]}
                for i in top_indices if similarities[i] >= similarity_threshold][:k]

    def save_db(self):
        """Persist state; the query cache is serialized to JSON text first."""
        with open(self.db_path, "wb") as file:
            pickle.dump({"embeddings": self.embeddings, "metadata": self.metadata,
                         "query_cache": json.dumps(self.query_cache)}, file)

File diff suppressed because it is too large Load Diff