mirror of
https://github.com/codelion/optillm.git
synced 2025-05-28 09:39:38 +03:00
310 lines
9.9 KiB
Python
310 lines
9.9 KiB
Python
from typing import Dict, Any
|
|
from z3 import *
|
|
import sympy
|
|
import io
|
|
import re
|
|
import contextlib
|
|
import logging
|
|
import ast
|
|
import math
|
|
import multiprocessing
|
|
import traceback
|
|
|
|
class TimeoutException(Exception):
|
|
pass
|
|
|
|
def prepare_safe_globals():
|
|
safe_globals = {
|
|
'print': print,
|
|
'__builtins__': {
|
|
'True': True,
|
|
'False': False,
|
|
'None': None,
|
|
'abs': abs,
|
|
'float': float,
|
|
'int': int,
|
|
'len': len,
|
|
'max': max,
|
|
'min': min,
|
|
'round': round,
|
|
'sum': sum,
|
|
'complex': complex,
|
|
}
|
|
}
|
|
|
|
# Add common math functions
|
|
safe_globals.update({
|
|
'log': math.log,
|
|
'log2': math.log2,
|
|
'sqrt': math.sqrt,
|
|
'exp': math.exp,
|
|
'sin': math.sin,
|
|
'cos': math.cos,
|
|
'tan': math.tan,
|
|
'pi': math.pi,
|
|
'e': math.e,
|
|
})
|
|
|
|
# Add complex number support
|
|
safe_globals['I'] = complex(0, 1)
|
|
safe_globals['Complex'] = complex
|
|
|
|
return safe_globals
|
|
|
|
def execute_code_in_process(code: str):
|
|
import z3
|
|
import sympy
|
|
import math
|
|
import itertools
|
|
from fractions import Fraction
|
|
|
|
safe_globals = prepare_safe_globals()
|
|
|
|
# Add Z3 specific functions
|
|
z3_whitelist = set(dir(z3))
|
|
safe_globals.update({name: getattr(z3, name) for name in z3_whitelist})
|
|
|
|
# Add SymPy specific functions
|
|
sympy_whitelist = set(dir(sympy))
|
|
safe_globals.update({name: getattr(sympy, name) for name in sympy_whitelist})
|
|
|
|
# Ensure key Z3 and SymPy components are available
|
|
safe_globals.update({
|
|
'z3': z3,
|
|
'sympy': sympy,
|
|
'Solver': z3.Solver,
|
|
'solver': z3.Solver,
|
|
'Optimize': z3.Optimize,
|
|
'sat': z3.sat,
|
|
'unsat': z3.unsat,
|
|
'unknown': z3.unknown,
|
|
'Real': z3.Real,
|
|
'Int': z3.Int,
|
|
'Bool': z3.Bool,
|
|
'And': z3.And,
|
|
'Or': z3.Or,
|
|
'Not': z3.Not,
|
|
'Implies': z3.Implies,
|
|
'If': z3.If,
|
|
'Sum': z3.Sum,
|
|
'ForAll': z3.ForAll,
|
|
'Exists': z3.Exists,
|
|
'model': z3.Model,
|
|
'Symbol': sympy.Symbol,
|
|
'solve': sympy.solve,
|
|
'simplify': sympy.simplify,
|
|
'expand': sympy.expand,
|
|
'factor': sympy.factor,
|
|
'diff': sympy.diff,
|
|
'integrate': sympy.integrate,
|
|
'limit': sympy.limit,
|
|
'series': sympy.series,
|
|
})
|
|
|
|
# Add custom functions
|
|
def as_numerical(x):
|
|
if z3.is_expr(x):
|
|
if z3.is_int_value(x) or z3.is_rational_value(x):
|
|
return float(x.as_decimal(20))
|
|
elif z3.is_algebraic_value(x):
|
|
return x.approx(20)
|
|
return float(x)
|
|
|
|
safe_globals['as_numerical'] = as_numerical
|
|
|
|
def Mod(x, y):
|
|
return x % y
|
|
|
|
safe_globals['Mod'] = Mod
|
|
|
|
def Rational(numerator, denominator=1):
|
|
return z3.Real(str(Fraction(numerator, denominator)))
|
|
|
|
safe_globals['Rational'] = Rational
|
|
|
|
output_buffer = io.StringIO()
|
|
with contextlib.redirect_stdout(output_buffer):
|
|
try:
|
|
exec(code, safe_globals, {})
|
|
except Exception:
|
|
return ("error", traceback.format_exc())
|
|
return ("success", output_buffer.getvalue())
|
|
|
|
class Z3SymPySolverSystem:
|
|
def __init__(self, system_prompt: str, client, model: str, timeout: int = 30):
|
|
self.system_prompt = system_prompt
|
|
self.model = model
|
|
self.client = client
|
|
self.timeout = timeout
|
|
self.solver_completion_tokens = 0
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
def process_query(self, query: str) -> str:
|
|
try:
|
|
analysis = self.analyze_query(query)
|
|
if "SOLVER_CAN_BE_APPLIED: True" not in analysis:
|
|
return self.standard_llm_inference(query), self.solver_completion_tokens
|
|
|
|
formulation = self.extract_and_validate_expressions(analysis)
|
|
solver_result = self.solve_with_z3_sympy(formulation)
|
|
|
|
return self.generate_response(query, analysis, solver_result), self.solver_completion_tokens
|
|
except Exception as e:
|
|
logging.error(f"An error occurred while processing the query with Z3 and SymPy, returning standard llm inference results: {str(e)}")
|
|
return self.standard_llm_inference(query), self.solver_completion_tokens
|
|
|
|
def analyze_query(self, query: str) -> str:
|
|
analysis_prompt = f"""Analyze the given query and determine if it can be solved using Z3 or SymPy:
|
|
|
|
1. Identify variables, constraints, and objectives.
|
|
2. Determine the problem type (e.g., SAT, optimization, symbolic manipulation).
|
|
3. Decide if Z3, SymPy, or a combination of both is suitable.
|
|
|
|
If Z3 or SymPy can be applied, provide Python code using the appropriate library (or both) to solve the problem. Make sure you define any additional methods you need for solving the problem.
|
|
The code will be executed in an environment with Z3 and SymPy available, so do not include any other libraries or modules.
|
|
|
|
Query: {query}
|
|
|
|
Respond with:
|
|
SOLVER_CAN_BE_APPLIED: [True/False]
|
|
|
|
SOLVER_FORMULATION:
|
|
```python
|
|
# Z3 and/or SymPy code here
|
|
```
|
|
|
|
Analysis:
|
|
[Your step-by-step analysis]
|
|
"""
|
|
|
|
analysis_response = self.client.chat.completions.create(
|
|
model=self.model,
|
|
messages=[
|
|
{"role": "system", "content": self.system_prompt},
|
|
{"role": "user", "content": analysis_prompt}
|
|
],
|
|
max_tokens=1024,
|
|
n=1,
|
|
temperature=0.1
|
|
)
|
|
self.solver_completion_tokens = analysis_response.usage.completion_tokens
|
|
return analysis_response.choices[0].message.content
|
|
|
|
def generate_response(self, query: str, analysis: str, solver_result: Dict[str, Any]) -> str:
|
|
if solver_result.get("status") != "success":
|
|
return self.standard_llm_inference(query)
|
|
|
|
response_prompt = f"""Provide a clear answer to the query using the analysis and solver result:
|
|
|
|
Query: {query}
|
|
|
|
Analysis: {analysis}
|
|
|
|
Solver Result: {solver_result.get("output")}
|
|
|
|
Response:
|
|
"""
|
|
|
|
response = self.client.chat.completions.create(
|
|
model=self.model,
|
|
messages=[
|
|
{"role": "system", "content": self.system_prompt},
|
|
{"role": "user", "content": response_prompt}
|
|
],
|
|
max_tokens=4096,
|
|
n=1,
|
|
temperature=0.1
|
|
)
|
|
self.solver_completion_tokens = response.usage.completion_tokens
|
|
return response.choices[0].message.content
|
|
|
|
def standard_llm_inference(self, query: str) -> str:
|
|
response = self.client.chat.completions.create(
|
|
model=self.model,
|
|
messages=[
|
|
{"role": "system", "content": self.system_prompt},
|
|
{"role": "user", "content": query}
|
|
],
|
|
max_tokens=4096,
|
|
n=1,
|
|
temperature=0.1
|
|
)
|
|
self.solver_completion_tokens = response.usage.completion_tokens
|
|
return response.choices[0].message.content
|
|
|
|
def extract_and_validate_expressions(self, analysis: str) -> str:
|
|
formulation = re.search(r"```python\n([\s\S]+?)```", analysis)
|
|
if formulation:
|
|
return formulation.group(1).strip()
|
|
raise ValueError("No valid Z3 or SymPy formulation found in the analysis.")
|
|
|
|
def solve_with_z3_sympy(self, formulation: str, max_attempts: int = 3) -> Dict[str, Any]:
|
|
for attempt in range(max_attempts):
|
|
output = self.execute_solver_code(formulation)
|
|
if "Error:" not in output:
|
|
return {"status": "success", "output": output}
|
|
|
|
error_prompt = f"""Fix the Z3 or SymPy code that resulted in an error. Follow these steps:
|
|
|
|
1. Review the original code and the error message carefully.
|
|
2. Analyze the error and identify its root cause.
|
|
3. Think through the necessary changes to fix the error.
|
|
4. Generate a corrected version of the code.
|
|
|
|
Original Code:
|
|
{formulation}
|
|
|
|
Error Message:
|
|
{output}
|
|
|
|
Step-by-Step Analysis:
|
|
[Provide your step-by-step analysis here]
|
|
|
|
Corrected Z3 or SymPy Code:
|
|
```python
|
|
# Corrected code here
|
|
```
|
|
"""
|
|
response = self.client.chat.completions.create(
|
|
model=self.model,
|
|
messages=[
|
|
{"role": "system", "content": self.system_prompt},
|
|
{"role": "user", "content": error_prompt}
|
|
],
|
|
max_tokens=1024,
|
|
n=1,
|
|
temperature=0.1
|
|
)
|
|
self.solver_completion_tokens = response.usage.completion_tokens
|
|
formulation = self.extract_and_validate_expressions(response.choices[0].message.content)
|
|
|
|
return {"status": "failed", "output": "Failed to solve after multiple attempts."}
|
|
|
|
def execute_solver_code(self, code: str) -> str:
|
|
logging.info("Executing Z3 and SymPy solver code")
|
|
logging.info(f"Code: {code}")
|
|
|
|
# Parse the code into an AST
|
|
try:
|
|
_ = ast.parse(code)
|
|
except SyntaxError as e:
|
|
logging.error(f"Syntax error in provided code: {e}")
|
|
return f"Error: Syntax error: {e}"
|
|
|
|
# Execute the code in a separate process
|
|
ctx = multiprocessing.get_context('spawn')
|
|
with ctx.Pool(1) as pool:
|
|
async_result = pool.apply_async(execute_code_in_process, (code,))
|
|
try:
|
|
status, result = async_result.get(timeout=self.timeout)
|
|
except multiprocessing.TimeoutError:
|
|
pool.terminate()
|
|
logging.error("Execution timed out")
|
|
return "Error: Execution timed out"
|
|
|
|
if status == "error":
|
|
logging.error(f"Execution error: {result}")
|
|
return f"Error: {result}"
|
|
|
|
logging.info("Z3 and SymPy solver code executed successfully")
|
|
return result |