mirror of
				https://github.com/anthropics/claude-cookbooks.git
				synced 2025-10-06 01:00:28 +03:00 
			
		
		
		
	WIP Text to SQL - 3/4 evals
This commit is contained in:
		
										
											Binary file not shown.
										
									
								
							| @@ -1,37 +1,51 @@ | ||||
| # Learn more about building a configuration: https://promptfoo.dev/docs/configuration/guide | ||||
| description: "Text to SQL Evaluation" | ||||
| # promptfooconfig.yaml | ||||
|  | ||||
| prompts: | ||||
|   - "Write a tweet about {{topic}}" | ||||
|   - "Write a concise, funny tweet about {{topic}}" | ||||
|    | ||||
| python_path: /opt/homebrew/bin/python3 | ||||
|  | ||||
| providers: | ||||
|   - id: anthropic:messages:claude-3-5-sonnet-20240620 | ||||
|     label: "3.5 Sonnet" | ||||
|     config: | ||||
|       max_tokens: 4096 | ||||
|       temperature: 0 | ||||
|    | ||||
|  | ||||
| prompts:  | ||||
|   - prompts.py:generate_prompt | ||||
|   - prompts.py:generate_prompt_with_examples | ||||
|   - prompts.py:generate_prompt_with_cot | ||||
|   # - prompts.py:generate_prompt_with_rag | ||||
|  | ||||
| tests: | ||||
|   - vars: | ||||
|       topic: avocado toast | ||||
|   - description: "Simple query for employee names in Engineering" | ||||
|     prompt: prompts.py:basic_text_to_sql | ||||
|     vars: | ||||
|       user_query: "What are the names of all employees in the Engineering department?" | ||||
|     assert: | ||||
|       # For more information on assertions, see https://promptfoo.dev/docs/configuration/expected-outputs | ||||
|       - type: contains | ||||
|         value: "<sql>" | ||||
|       - type: contains | ||||
|         value: "</sql>" | ||||
|       - type: python | ||||
|         value: | | ||||
|           import re | ||||
|  | ||||
|       # Make sure output contains the word "avocado" | ||||
|       - type: icontains | ||||
|         value: avocado | ||||
|           def extract_sql(text): | ||||
|               match = re.search(r'<sql>(.*?)</sql>', text, re.DOTALL) | ||||
|               return match.group(1).strip() if match else "" | ||||
|  | ||||
|       # Prefer shorter outputs | ||||
|       - type: javascript | ||||
|         value: 1 / (output.length + 1) | ||||
|           def check_sql(sql): | ||||
|               required_elements = ['select', 'from employees', 'join departments', "name = 'engineering'"] | ||||
|               return all(element in sql.lower() for element in required_elements) | ||||
|  | ||||
|   - vars: | ||||
|       topic: new york city | ||||
|     assert: | ||||
|       # For more information on model-graded evals, see https://promptfoo.dev/docs/configuration/expected-outputs/model-graded | ||||
|       - type: llm-rubric | ||||
|         value: ensure that the output is funny | ||||
|           sql = extract_sql(output) | ||||
|           result = check_sql(sql) | ||||
|            | ||||
|           return { | ||||
|               "pass": result, | ||||
|               "score": 1 if result else 0, | ||||
|               "reason": f"SQL query {'is correct' if result else 'is incorrect or not found'}" | ||||
|           } | ||||
|  | ||||
|  | ||||
| output: | ||||
|   - type: csv | ||||
|     path: ./results.csv | ||||
							
								
								
									
										146
									
								
								skills/text_to_sql/evaluation/prompts.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								skills/text_to_sql/evaluation/prompts.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,146 @@ | ||||
| import sqlite3 | ||||
|  | ||||
def get_schema_info(db_path):
    """Return a human-readable summary of every table in a SQLite database.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        A string with one section per table, e.g.::

            Table: employees
              - id (INTEGER)
              - name (TEXT)

        Sections are separated by a blank line.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Enumerate user tables from the catalog.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        schema_info = []
        for (table_name,) in tables:
            # PRAGMA table_info does not support parameter binding, so the
            # name is interpolated; it comes from sqlite_master, not user
            # input, and is double-quoted to survive unusual identifiers.
            cursor.execute(f'PRAGMA table_info("{table_name}")')
            columns = cursor.fetchall()

            # col[1] is the column name, col[2] its declared type.
            table_info = f"Table: {table_name}\n"
            table_info += "\n".join(f"  - {col[1]} ({col[2]})" for col in columns)
            schema_info.append(table_info)

        return "\n\n".join(schema_info)
    finally:
        # Original leaked the connection if any query raised.
        conn.close()
|  | ||||
def generate_prompt(user_query, db_path='../data/data.db'):
    """Build a zero-shot text-to-SQL prompt embedding the full schema.

    Args:
        user_query: Natural-language question to convert into SQL.
        db_path: Path to the SQLite database whose schema is included.

    Returns:
        The prompt string; it instructs the model to answer in <sql> tags.
    """
    schema_text = get_schema_info(db_path)
    prompt = f"""
    You are an AI assistant that converts natural language queries into SQL. 
    Given the following SQL database schema:

    {schema_text}

    Convert the following natural language query into SQL:

    {user_query}

    Provide only the SQL query in your response, enclosed within <sql> tags.
    """
    return prompt
|  | ||||
def generate_prompt_with_examples(user_query, db_path='../data/data.db'):
    """Build a few-shot text-to-SQL prompt with worked examples.

    Fixes vs. the original: Example 1 had a malformed closing tag
    (``</<query>``) and used ``<output>`` tags, while Examples 2-3 used a
    bare ``User:/SQL:`` format. All examples now share one consistent
    <query>/<sql> format that matches the answer format requested of the
    model at the end of the prompt.

    Args:
        user_query: Natural-language question to convert into SQL.
        db_path: Path to the SQLite database whose schema is included.

    Returns:
        The prompt string; it instructs the model to answer in <sql> tags.
    """
    examples = """
        Example 1:
        <query>List all employees in the HR department.</query>
        <sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>

        Example 2:
        <query>What is the average salary of employees in the Engineering department?</query>
        <sql>SELECT AVG(e.salary) FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'Engineering';</sql>

        Example 3:
        <query>Who is the oldest employee?</query>
        <sql>SELECT name, age FROM employees ORDER BY age DESC LIMIT 1;</sql>
    """

    schema = get_schema_info(db_path)

    return f"""
        You are an AI assistant that converts natural language queries into SQL.
        Given the following SQL database schema:

        <schema>
        {schema}
        </schema>

        Here are some examples of natural language queries and their corresponding SQL:

        <examples>
        {examples}
        </examples>

        Now, convert the following natural language query into SQL:
        <query>
        {user_query}
        </query>

        Provide only the SQL query in your response, enclosed within <sql> tags.
    """
|  | ||||
def generate_prompt_with_cot(user_query, db_path='../data/data.db'):
    """Build a chain-of-thought text-to-SQL prompt.

    The model is asked to reason in <thought_process> tags before emitting
    the final query in <sql> tags; one worked example demonstrates both.

    Args:
        user_query: Natural-language question to convert into SQL.
        db_path: Path to the SQLite database whose schema is included.

    Returns:
        The prompt string.
    """
    worked_example = """
    <example>
    <query>List all employees in the HR department.</query>
    <thought_process>
    1. We need to join the employees and departments tables.
    2. We'll match employees.department_id with departments.id.
    3. We'll filter for the HR department.
    4. We only need to return the employee names.
    </thought_process>
    <sql>SELECT e.name FROM employees e JOIN departments d ON e.department_id = d.id WHERE d.name = 'HR';</sql>
    </example>
    """
    schema_text = get_schema_info(db_path)
    return f"""
    You are an AI assistant that converts natural language queries into SQL.
    Given the following SQL database schema:

    {schema_text}

    Here are some examples of natural language queries, thought processes, and corresponding SQL queries:

    {worked_example}

    Now, convert the following natural language query into SQL:
    {user_query}

    First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
    Then, provide the SQL query within <sql> tags.
    """
|  | ||||
def generate_prompt_with_rag(user_query, db_path='../data/data.db'):
    """Build a text-to-SQL prompt using retrieval over schema embeddings.

    Only the schema columns most similar to *user_query* (via the project's
    VectorDB) are included, instead of the full schema.

    Args:
        user_query: Natural-language question to convert into SQL.
        db_path: Path to the SQLite database used to seed the vector store
            when it is empty.

    Returns:
        The prompt string.
    """
    # Imported lazily so the other prompt builders work without voyageai.
    from vectordb import VectorDB

    store = VectorDB()
    store.load_db()

    # Seed the vector store from the live schema on first use.
    if not store.embeddings:
        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()  # materialize before reusing the cursor
            schema_data = []
            for table in tables:
                for col in cursor.execute(f"PRAGMA table_info({table[0]})").fetchall():
                    schema_data.append({
                        "text": f"Table: {table[0]}, Column: {col[1]}, Type: {col[2]}",
                        "metadata": {"table": table[0], "column": col[1], "type": col[2]},
                    })
        store.load_data(schema_data)

    relevant_schema = store.search(user_query, k=10, similarity_threshold=0.3)
    schema_lines = [
        f"Table: {item['metadata']['table']}, Column: {item['metadata']['column']}, Type: {item['metadata']['type']}"
        for item in relevant_schema
    ]
    schema_info = "\n".join(schema_lines)
    return f"""
    You are an AI assistant that converts natural language queries into SQL.
    Given the following relevant columns from the SQL database schema:

    {schema_info}

    Convert the following natural language query into SQL:

    {user_query}

    First, provide your thought process within <thought_process> tags, explaining how you'll approach creating the SQL query.
    Then, provide the SQL query within <sql> tags.

    Ensure your SQL query is compatible with SQLite syntax.
    """
							
								
								
									
										43
									
								
								skills/text_to_sql/evaluation/vectordb.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								skills/text_to_sql/evaluation/vectordb.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| import os | ||||
| import numpy as np | ||||
| import voyageai | ||||
| import pickle | ||||
| import json | ||||
|  | ||||
class VectorDB:
    """Tiny on-disk vector store backed by Voyage AI embeddings.

    Embeddings and metadata are pickled to ``db_path``; query embeddings are
    memoized in ``query_cache`` so repeated searches skip the embedding API.
    Requires VOYAGE_API_KEY in the environment to embed new data/queries.
    """

    def __init__(self, db_path='../data/vector_db.pkl'):
        self.client = voyageai.Client(api_key=os.getenv("VOYAGE_API_KEY"))
        self.db_path = db_path
        self.load_db()

    def load_db(self):
        """Load embeddings/metadata/query cache from disk, or start empty."""
        if os.path.exists(self.db_path):
            # SECURITY: pickle.load executes arbitrary code from the file.
            # Only load database files this process created itself.
            with open(self.db_path, "rb") as file:
                data = pickle.load(file)
            self.embeddings = data['embeddings']
            self.metadata = data['metadata']
            self.query_cache = json.loads(data['query_cache'])
        else:
            self.embeddings, self.metadata, self.query_cache = [], [], {}

    def load_data(self, data):
        """Embed and persist ``data`` (a list of {"text", "metadata"} dicts).

        No-op when the store already holds embeddings.
        """
        if self.embeddings:
            return
        texts = [item["text"] for item in data]
        # Embed in batches of 128, the API's per-call input limit.
        self.embeddings = [
            emb
            for start in range(0, len(texts), 128)
            for emb in self.client.embed(texts[start:start + 128], model="voyage-2").embeddings
        ]
        self.metadata = [item["metadata"] for item in data]  # store only the inner metadata
        self.save_db()

    def search(self, query, k=5, similarity_threshold=0.3):
        """Return up to ``k`` entries with dot-product similarity >= threshold.

        Each hit is ``{"metadata": ..., "similarity": ...}``, best first.
        """
        if not self.embeddings:
            # Fix: the original raised a shape error in np.dot on an empty
            # store; an empty store simply has no hits.
            return []
        if query not in self.query_cache:
            self.query_cache[query] = self.client.embed([query], model="voyage-2").embeddings[0]
            self.save_db()  # persist the new cache entry

        similarities = np.dot(self.embeddings, self.query_cache[query])
        top_indices = np.argsort(similarities)[::-1]

        return [{"metadata": self.metadata[i], "similarity": similarities[i]}
                for i in top_indices if similarities[i] >= similarity_threshold][:k]

    def save_db(self):
        """Pickle the store; query_cache is JSON-encoded for portability."""
        with open(self.db_path, "wb") as file:
            pickle.dump({"embeddings": self.embeddings, "metadata": self.metadata,
                         "query_cache": json.dumps(self.query_cache)}, file)
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user
	 Mahesh Murag
					Mahesh Murag