WIP Text to SQL - checkpoint

This commit is contained in:
Mahesh Murag
2024-09-24 23:35:12 +02:00
parent 58e73112df
commit 73c424d6bf
4 changed files with 1108 additions and 531 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -152,7 +152,4 @@ tests:
"reason": f"SQL {'executed successfully' if execution_success else 'execution failed'}. "
f"Data {'matches' if data_match else 'does not match'} expected result. "
f"Actual: {actual_result if row else 'No data'}, Expected: {expected_result}"
}
outputPath: ../data/results.csv
}

View File

@@ -1474,24 +1474,12 @@
}
],
"source": [
"def generate_sql_with_cot(prompt):\n",
" response = client.messages.create(\n",
" model=MODEL_NAME,\n",
" max_tokens=1000,\n",
" temperature=0,\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": prompt}\n",
" ]\n",
" )\n",
" return response.content[0].text.strip()\n",
"\n",
"# Generate SQL using the chain-of-thought prompt\n",
"result = generate_sql_with_cot(prompt)\n",
"result = generate_sql(prompt)\n",
"print(\"Raw response from Claude:\")\n",
"print(result)\n",
"\n",
"# Extract thought process and SQL query using simple string manipulation\n",
"# Note: For more robust parsing, consider using an XML parsing library\n",
"# Extract thought process and SQL query\n",
"thought_process = result.split('<thought_process>')[1].split('</thought_process>')[0].strip()\n",
"sql = result.split('<sql>')[1].split('</sql>')[0].strip()\n",
"\n",
@@ -1818,7 +1806,7 @@
"print(prompt)\n",
"\n",
"# Generate and execute SQL\n",
"result = generate_sql_with_cot(prompt)\n",
"result = generate_sql(prompt)\n",
"print(\"\\nGenerated result:\")\n",
"print(result)\n",
"\n",
@@ -2018,7 +2006,7 @@
" Explain your changes in <thought_process> tags and provide the corrected SQL in <sql> tags.\n",
" \"\"\"\n",
" \n",
" response = generate_sql_with_cot(prompt)\n",
" response = generate_sql(prompt)\n",
" sql = response.split('<sql>')[1].split('</sql>')[0].strip()\n",
" \n",
" print(f\"\\nAttempt {attempt + 1}:\")\n",