Add completion tokens

Asankhaya Sharma
2024-09-24 03:37:22 -07:00
parent b943abf7e4
commit 4eae055a89
13 changed files with 262 additions and 57 deletions

View File

@@ -123,37 +123,38 @@ def proxy():
     logger.info(f'Using approach {approach}, with {model}')
+    completion_tokens = 0
     try:
         if approach == 'mcts':
-            final_response = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
+            final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                             server_config['mcts_exploration'], server_config['mcts_depth'])
         elif approach == 'bon':
-            final_response = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
+            final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
         elif approach == 'moa':
-            final_response = mixture_of_agents(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model)
         elif approach == 'rto':
-            final_response = round_trip_optimization(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model)
         elif approach == 'z3':
             z3_solver = Z3SolverSystem(system_prompt, client, model)
-            final_response = z3_solver.process_query(initial_query)
+            final_response, completion_tokens = z3_solver.process_query(initial_query)
         elif approach == "self_consistency":
-            final_response = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
         elif approach == "pvg":
-            final_response = inference_time_pv_game(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model)
         elif approach == "rstar":
             rstar = RStar(system_prompt, client, model,
                           max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
                           c=server_config['rstar_c'])
-            final_response = rstar.solve(initial_query)
+            final_response, completion_tokens = rstar.solve(initial_query)
         elif approach == "cot_reflection":
-            final_response = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
+            final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
         elif approach == 'plansearch':
-            final_response = plansearch(system_prompt, initial_query, client, model, n=n)
+            final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n)
         elif approach == 'leap':
-            final_response = leap(system_prompt, initial_query, client, model)
+            final_response, completion_tokens = leap(system_prompt, initial_query, client, model)
         elif approach == 're2':
-            final_response = re2_approach(system_prompt, initial_query, client, model, n=n)
+            final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n)
         else:
             raise ValueError(f"Unknown approach: {approach}")
     except Exception as e:

@@ -162,7 +163,10 @@ def proxy():
     response_data = {
         'model': model,
-        'choices': []
+        'choices': [],
+        'usage': {
+            'completion_tokens': completion_tokens,
+        }
     }
     if isinstance(final_response, list):
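The usage block follows the OpenAI chat-completions response shape, so existing clients can read the aggregate count without changes. A minimal client-side sketch (the endpoint path, port, and model name are assumptions for illustration; this commit populates only completion_tokens, not prompt_tokens or total_tokens):

import requests

# Hypothetical request against a locally running proxy (URL and model name are assumptions).
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    },
).json()

answer = resp["choices"][0]["message"]["content"]
# Aggregate completion tokens across every internal LLM call the approach made:
tokens = resp["usage"]["completion_tokens"]
print(f"{answer!r} used {tokens} completion tokens")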

View File

@@ -3,6 +3,8 @@ import logging

 logger = logging.getLogger(__name__)

 def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
+    bon_completion_tokens = 0
+
     messages = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": initial_query}]
@@ -16,6 +18,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         temperature=1
     )
     completions = [choice.message.content for choice in response.choices]
+    bon_completion_tokens += response.usage.completion_tokens

     # Rate the completions
     rating_messages = messages.copy()
@@ -33,7 +36,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
         n=1,
         temperature=0.1
     )
+    bon_completion_tokens += rating_response.usage.completion_tokens
     try:
         rating = float(rating_response.choices[0].message.content.strip())
         ratings.append(rating)
@@ -43,4 +46,4 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     rating_messages = rating_messages[:-2]

     best_index = ratings.index(max(ratings))
-    return completions[best_index]
+    return completions[best_index], bon_completion_tokens
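One hardening note, not part of this commit: some OpenAI-compatible servers omit the usage field entirely, in which case response.usage.completion_tokens raises. A hedged sketch of a defensive accessor (hypothetical helper name):

def safe_completion_tokens(response) -> int:
    # Returns 0 when a server omits usage rather than raising AttributeError.
    usage = getattr(response, "usage", None)
    if usage is None:
        return 0
    return getattr(usage, "completion_tokens", 0) or 0

# e.g.: bon_completion_tokens += safe_completion_tokens(response)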

View File

@@ -4,6 +4,7 @@ import logging

 logger = logging.getLogger(__name__)

 def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False):
+    cot_completion_tokens = 0
     cot_prompt = f"""
     {system_prompt}
@@ -44,6 +45,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
     # Extract the full response
     full_response = response.choices[0].message.content
+    cot_completion_tokens += response.usage.completion_tokens
     logger.info(f"CoT with Reflection :\n{full_response}")

     # Use regex to extract the content within <thinking> and <output> tags
@@ -56,7 +58,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
     logger.info(f"Final output :\n{output}")

     if return_full_response:
-        return full_response
+        return full_response, cot_completion_tokens
     else:
-        return output
+        return output, cot_completion_tokens

View File

@@ -14,6 +14,7 @@ class LEAP:
         self.model = model
         self.low_level_principles = []
         self.high_level_principles = []
+        self.leap_completion_tokens = 0

     def extract_output(self, text: str) -> str:
         match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
@@ -46,6 +47,7 @@ class LEAP:
             """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         examples_str = self.extract_output(response.choices[0].message.content)
         logger.debug(f"Extracted examples: {examples_str}")
         examples = []
@@ -80,6 +82,7 @@ class LEAP:
             ],
             temperature=0.7,
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         generated_reasoning = response.choices[0].message.content
         generated_answer = self.extract_output(generated_reasoning)
         if generated_answer != correct_answer:
@@ -110,6 +113,7 @@ class LEAP:
             """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         self.low_level_principles.append(self.extract_output(response.choices[0].message.content))
         return self.low_level_principles
@@ -134,6 +138,7 @@ class LEAP:
             """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         self.high_level_principles = self.extract_output(response.choices[0].message.content).split("\n")
         return self.high_level_principles
@@ -154,6 +159,7 @@ class LEAP:
             """}
             ]
         )
+        self.leap_completion_tokens += response.usage.completion_tokens
         return response.choices[0].message.content

     def solve(self, initial_query: str) -> str:
@@ -171,4 +177,4 @@ class LEAP:
 def leap(system_prompt: str, initial_query: str, client, model: str) -> str:
     leap_solver = LEAP(system_prompt, client, model)
-    return leap_solver.solve(initial_query)
+    return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
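LEAP illustrates the pattern the class-based approaches (MCTS, RStar, PlanSearch, Z3, self-consistency) now share: a counter initialized in __init__, incremented after each API call, and returned alongside the result by the module-level wrapper. A condensed sketch of that pattern, with illustrative names only:

class TokenCountingApproach:
    def __init__(self, client, model: str):
        self.client = client
        self.model = model
        self.completion_tokens = 0  # accumulated across every API call

    def _ask(self, messages):
        response = self.client.chat.completions.create(model=self.model, messages=messages)
        self.completion_tokens += response.usage.completion_tokens
        return response.choices[0].message.content

    def solve(self, query: str):
        draft = self._ask([{"role": "user", "content": query}])
        final = self._ask([{"role": "user", "content": f"Improve: {draft}"}])
        return final, self.completion_tokens  # same tuple contract as leap()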

View File

@@ -32,6 +32,7 @@ class MCTS:
         self.node_labels = {}
         self.client = client
         self.model = model
+        self.completion_tokens = 0

     def select(self, node: MCTSNode) -> MCTSNode:
         logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
@@ -118,6 +119,7 @@ class MCTS:
             temperature=1
         )
         completions = [choice.message.content.strip() for choice in response.choices]
+        self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Received {len(completions)} completions from the model")
         return completions
@@ -140,6 +142,7 @@ class MCTS:
         )
         next_query = response.choices[0].message.content
+        self.completion_tokens += response.usage.completion_tokens
         logger.info(f"Generated next user query: {next_query}")
         return DialogueState(state.system_prompt, new_history, next_query)
@@ -161,7 +164,7 @@ class MCTS:
             n=1,
             temperature=0.1
         )
+        self.completion_tokens += response.usage.completion_tokens
         try:
             score = float(response.choices[0].message.content.strip())
             score = max(0, min(score, 1))  # Ensure the score is between 0 and 1
@@ -181,4 +184,4 @@ def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, n
     final_state = mcts.search(initial_state, num_simulations)
     response = final_state.conversation_history[-1]['content'] if final_state.conversation_history else ""
     logger.info(f"MCTS chat complete. Final response: {response[:100]}...")
-    return response
+    return response, mcts.completion_tokens

View File

@@ -3,6 +3,7 @@ import logging

 logger = logging.getLogger(__name__)

 def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str) -> str:
+    moa_completion_tokens = 0
     completions = []

     response = client.chat.completions.create(
@@ -16,6 +17,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         temperature=1
     )
     completions = [choice.message.content for choice in response.choices]
+    moa_completion_tokens += response.usage.completion_tokens

     critique_prompt = f"""
 Original query: {initial_query}
@@ -45,6 +47,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         temperature=0.1
     )
     critiques = critique_response.choices[0].message.content
+    moa_completion_tokens += critique_response.usage.completion_tokens

     final_prompt = f"""
 Original query: {initial_query}
@@ -76,5 +79,5 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
         n=1,
         temperature=0.1
     )
-    return final_response.choices[0].message.content
+    moa_completion_tokens += final_response.usage.completion_tokens
+    return final_response.choices[0].message.content, moa_completion_tokens
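Note that a single request with n > 1, as in the first mixture-of-agents call above, returns one usage object covering all n choices, so a single increment per request already accounts for every completion:

from openai import OpenAI  # assumption: the standard OpenAI client

client = OpenAI()
response = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical model name
    messages=[{"role": "user", "content": "Name a prime number."}],
    n=3,  # three completions in one request...
)
# ...but one usage object: completion_tokens already sums all three choices.
print(response.usage.completion_tokens)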

View File

@@ -8,6 +8,7 @@ class PlanSearch:
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
+        self.plansearch_completion_tokens = 0

     def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
         prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
@@ -28,7 +29,7 @@ Please provide {num_observations} observations."""
             {"role": "user", "content": prompt}
             ]
         )
+        self.plansearch_completion_tokens += response.usage.completion_tokens
         observations = response.choices[0].message.content.strip().split('\n')
         return [obs.strip() for obs in observations if obs.strip()]
@@ -55,7 +56,7 @@ Please provide {num_new_observations} new observations derived from the existing
             {"role": "user", "content": prompt}
             ]
         )
+        self.plansearch_completion_tokens += response.usage.completion_tokens
         new_observations = response.choices[0].message.content.strip().split('\n')
         return [obs.strip() for obs in new_observations if obs.strip()]
@@ -80,7 +81,7 @@ IS CRUCIAL."""
             {"role": "user", "content": prompt}
             ]
         )
+        self.plansearch_completion_tokens += response.usage.completion_tokens
         return response.choices[0].message.content.strip()

     def implement_solution(self, problem: str, solution: str) -> str:
@@ -105,7 +106,7 @@ Please implement the solution in Python."""
             {"role": "user", "content": prompt}
             ]
         )
+        self.plansearch_completion_tokens += response.usage.completion_tokens
         return response.choices[0].message.content.strip()

     def solve(self, problem: str, num_initial_observations: int = 3, num_derived_observations: int = 2) -> Tuple[str, str]:
@@ -134,4 +135,4 @@ Please implement the solution in Python."""
 def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1) -> List[str]:
     planner = PlanSearch(system_prompt, client, model)
-    return planner.solve_multiple(initial_query, n)
+    return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens
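plansearch is the one approach that returns a list of responses rather than a single string, which is why the proxy diff above ends with an isinstance(final_response, list) check. The branch body falls outside that diff's context lines; a hedged sketch of the shape it presumably takes:

# Hypothetical sketch; the real branch body is not shown in this commit.
if isinstance(final_response, list):
    response_data['choices'] = [
        {'index': i, 'message': {'role': 'assistant', 'content': text}}
        for i, text in enumerate(final_response)
    ]
else:
    response_data['choices'] = [
        {'index': 0, 'message': {'role': 'assistant', 'content': final_response}}
    ]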

View File

@@ -4,7 +4,10 @@ from typing import List, Tuple

 logger = logging.getLogger(__name__)

+pvg_completion_tokens = 0
+
 def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7) -> List[str]:
+    global pvg_completion_tokens
     role = "sneaky" if is_sneaky else "helpful"
     logger.info(f"Generating {num_solutions} {role} solutions")
@@ -34,11 +37,13 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
         max_tokens=4096,
         temperature=temperature,
     )
+    pvg_completion_tokens += response.usage.completion_tokens
     solutions = [choice.message.content for choice in response.choices]
     logger.debug(f"Generated {role} solutions: {solutions}")
     return solutions

 def verify_solutions(client, system_prompt: str, initial_query: str, solutions: List[str], model: str) -> List[float]:
+    global pvg_completion_tokens
     logger.info(f"Verifying {len(solutions)} solutions")
     verify_prompt = f"""{system_prompt}
 You are a verifier tasked with evaluating the correctness and clarity of solutions to the given problem.
@@ -75,6 +80,7 @@ Ensure that the Score is a single number between 0 and 10, and the Explanation i
         max_tokens=1024,
         temperature=0.2,
     )
+    pvg_completion_tokens += response.usage.completion_tokens
     rating = response.choices[0].message.content
     logger.debug(f"Raw rating for solution {i+1}: {rating}")
@@ -130,6 +136,7 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
     return "", 0.0

 def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3) -> str:
+    global pvg_completion_tokens
     logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")
     best_solution = ""
@@ -178,9 +185,10 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
             max_tokens=1024,
             temperature=0.5,
         )
+        pvg_completion_tokens += response.usage.completion_tokens
         initial_query = response.choices[0].message.content
         logger.debug(f"Refined query: {initial_query}")

     logger.info(f"Inference-time PV game completed. Best solution score: {best_score}")
-    return best_solution
+    return best_solution, pvg_completion_tokens
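Unlike the class-based approaches, pvg keeps its counter in a module-level global that this diff never resets, so a long-lived proxy process would report cumulative usage across requests. A tiny runnable illustration of the pitfall (illustrative names):

counter = 0  # module-level, like pvg_completion_tokens

def run(tokens_used: int) -> int:
    global counter
    counter += tokens_used
    return counter

print(run(10))  # 10
print(run(10))  # 20 -- the second request reports cumulative, not per-request, usage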

View File

@@ -17,6 +17,7 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
         str or list: The generated response(s) from the model.
     """
     logger.info("Using RE2 approach for query processing")
+    re2_completion_tokens = 0

     # Construct the RE2 prompt
     re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -32,11 +33,11 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
             messages=messages,
             n=n
         )
+        re2_completion_tokens += response.usage.completion_tokens
         if n == 1:
-            return response.choices[0].message.content.strip()
+            return response.choices[0].message.content.strip(), re2_completion_tokens
         else:
-            return [choice.message.content.strip() for choice in response.choices]
+            return [choice.message.content.strip() for choice in response.choices], re2_completion_tokens
     except Exception as e:
         logger.error(f"Error in RE2 approach: {str(e)}")
View File

@@ -30,6 +30,7 @@ class RStar:
         self.actions = ["A1", "A2", "A3", "A4", "A5"]
         self.original_question = None
         self.system = system
+        self.rstar_completion_tokens = 0
         logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")

     async def generate_response_async(self, prompt: str) -> str:
@@ -46,6 +47,7 @@ class RStar:
                 }
             ) as response:
                 result = await response.json()
+                self.rstar_completion_tokens += result['usage']['completion_tokens']
                 return result['choices'][0]['message']['content'].strip()

     async def expand_async(self, node: Node, action: str) -> Node:
@@ -101,7 +103,7 @@ class RStar:
         answers = [self.extract_answer(node.state) for node in final_trajectory]
         final_answer = self.select_best_answer(answers)
         logger.info(f"Selected final answer: {final_answer}")
-        return final_answer
+        return final_answer, self.rstar_completion_tokens

     def generate_response(self, prompt: str) -> str:
         logger.debug(f"Generating response for prompt: {prompt[:100]}...")
@@ -114,6 +116,7 @@ class RStar:
             max_tokens=4096,
             temperature=0.2
         )
+        self.rstar_completion_tokens += response.usage.completion_tokens
         generated_response = response.choices[0].message.content.strip()
         logger.debug(f"Generated response: {generated_response}")
         return generated_response
@@ -342,5 +345,4 @@ This rephrasing should help clarify the problem and guide the solution process."
         """
         Synchronous wrapper for solve_async method.
         """
-        return asyncio.run(self.solve_async(question))
+        return asyncio.run(self.solve_async(question))
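RStar's async path reads usage out of the raw aiohttp JSON rather than a typed client object; a backend that omits the usage block would make that line raise KeyError. A hedged defensive variant (hypothetical helper, not part of the commit):

def completion_tokens_from_json(result: dict) -> int:
    # .get() tolerates servers that omit the usage block entirely.
    usage = result.get('usage') or {}
    return int(usage.get('completion_tokens', 0))

# e.g.: self.rstar_completion_tokens += completion_tokens_from_json(result)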

View File

@@ -14,6 +14,7 @@ def extract_code_from_prompt(text):
     return text

 def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str) -> str:
+    rto_completion_tokens = 0
     messages = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": initial_query}]
@@ -26,6 +27,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
         temperature=0.1
     )
     c1 = response_c1.choices[0].message.content
+    rto_completion_tokens += response_c1.usage.completion_tokens

     # Generate description of the code (Q2)
     messages.append({"role": "assistant", "content": c1})
@@ -38,6 +40,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
         temperature=0.1
     )
     q2 = response_q2.choices[0].message.content
+    rto_completion_tokens += response_q2.usage.completion_tokens

     # Generate second code based on the description (C2)
     messages = [{"role": "system", "content": system_prompt},
@@ -50,6 +53,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
         temperature=0.1
     )
     c2 = response_c2.choices[0].message.content
+    rto_completion_tokens += response_c2.usage.completion_tokens

     c1 = extract_code_from_prompt(c1)
     c2 = extract_code_from_prompt(c2)
@@ -67,5 +71,6 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
         temperature=0.1
     )
     c3 = response_c3.choices[0].message.content
+    rto_completion_tokens += response_c3.usage.completion_tokens
-    return c3
+    return c3, rto_completion_tokens

View File

@@ -10,6 +10,7 @@ class AdvancedSelfConsistency:
         self.model = model
         self.num_samples = num_samples
         self.similarity_threshold = similarity_threshold
+        self.self_consistency_completion_tokens = 0

     def generate_responses(self, system_prompt: str, user_prompt: str) -> List[str]:
         responses = []
@@ -23,6 +24,7 @@ class AdvancedSelfConsistency:
             temperature=1,
             max_tokens=4096
         )
+        self.self_consistency_completion_tokens += response.usage.completion_tokens
         responses.append(response.choices[0].message.content)
         return responses
@@ -85,6 +87,6 @@ def advanced_self_consistency_approach(system_prompt: str, initial_query: str, c
         logger.debug(f"  Variants: {cluster['variants']}")

     if result['aggregated_result']['clusters']:
-        return result['aggregated_result']['clusters'][0]['answer']
+        return result['aggregated_result']['clusters'][0]['answer'], self_consistency.self_consistency_completion_tokens
     else:
-        return "No consistent answer found."
+        return "No consistent answer found.", self_consistency.self_consistency_completion_tokens

View File

@@ -6,12 +6,119 @@ import re
 import contextlib
 import logging
 import ast
+import math
+import itertools
+from fractions import Fraction
+import threading
+import ctypes
+import multiprocessing
+import traceback
+import sys

 class TimeoutException(Exception):
     pass

 def timeout_handler(signum, frame):
     raise TimeoutException("Execution timed out")

+def prepare_safe_globals():
+    safe_globals = {
+        'print': print,
+        '__builtins__': {
+            'True': True,
+            'False': False,
+            'None': None,
+            'abs': abs,
+            'float': float,
+            'int': int,
+            'len': len,
+            'max': max,
+            'min': min,
+            'round': round,
+            'sum': sum,
+            'complex': complex,
+        }
+    }
+    # Add common math functions
+    safe_globals.update({
+        'log': math.log,
+        'log2': math.log2,
+        'sqrt': math.sqrt,
+        'exp': math.exp,
+        'sin': math.sin,
+        'cos': math.cos,
+        'tan': math.tan,
+        'pi': math.pi,
+        'e': math.e,
+    })
+    # Add complex number support
+    safe_globals['I'] = complex(0, 1)
+    safe_globals['Complex'] = complex
+    return safe_globals
+
+def execute_code_in_process(code: str):
+    import z3
+    import math
+    import itertools
+    from fractions import Fraction
+
+    safe_globals = prepare_safe_globals()
+
+    # Add Z3 specific functions
+    z3_whitelist = set(dir(z3))
+    safe_globals.update({name: getattr(z3, name) for name in z3_whitelist})
+
+    # Ensure key Z3 components are available
+    safe_globals.update({
+        'z3': z3,
+        'Solver': z3.Solver,
+        'solver': z3.Solver,
+        'Optimize': z3.Optimize,
+        'sat': z3.sat,
+        'unsat': z3.unsat,
+        'unknown': z3.unknown,
+        'Real': z3.Real,
+        'Int': z3.Int,
+        'Bool': z3.Bool,
+        'And': z3.And,
+        'Or': z3.Or,
+        'Not': z3.Not,
+        'Implies': z3.Implies,
+        'If': z3.If,
+        'Sum': z3.Sum,
+        'ForAll': z3.ForAll,
+        'Exists': z3.Exists,
+        'model': z3.Model,
+    })
+
+    # Add custom functions
+    def as_numerical(x):
+        if z3.is_expr(x):
+            if z3.is_int_value(x) or z3.is_rational_value(x):
+                return float(x.as_decimal(20))
+            elif z3.is_algebraic_value(x):
+                return x.approx(20)
+        return float(x)
+    safe_globals['as_numerical'] = as_numerical
+
+    def Mod(x, y):
+        return x % y
+    safe_globals['Mod'] = Mod
+
+    def Rational(numerator, denominator=1):
+        return z3.Real(str(Fraction(numerator, denominator)))
+    safe_globals['Rational'] = Rational
+
+    output_buffer = io.StringIO()
+    with contextlib.redirect_stdout(output_buffer):
+        try:
+            exec(code, safe_globals, {})
+        except Exception:
+            return ("error", traceback.format_exc())
+    return ("success", output_buffer.getvalue())

 class Z3SolverSystem:
     def __init__(self, system_prompt: str, client, model: str, timeout: int = 30):
@@ -19,6 +126,7 @@ class Z3SolverSystem:
         self.model = model
         self.client = client
         self.timeout = timeout
+        self.z3_completion_tokens = 0
         logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

     def process_query(self, query: str) -> str:
@@ -26,17 +134,17 @@ class Z3SolverSystem:
             analysis = self.analyze_query(query)
             # print("Analysis: "+ analysis)
             if "SOLVER_CAN_BE_APPLIED: True" not in analysis:
-                return self.standard_llm_inference(query)
+                return self.standard_llm_inference(query), self.z3_completion_tokens

             formulation = self.extract_and_validate_expressions(analysis)
             # print("Formulation: "+ formulation)
             solver_result = self.solve_with_z3(formulation)
             # print(solver_result)
-            return self.generate_response(query, analysis, solver_result)
+            return self.generate_response(query, analysis, solver_result), self.z3_completion_tokens

         except Exception as e:
             logging.error(f"An error occurred while processing the query with Z3, returning standard llm inference results: {str(e)}")
-            return self.standard_llm_inference(query)
+            return self.standard_llm_inference(query), self.z3_completion_tokens

     def analyze_query(self, query: str) -> str:
         analysis_prompt = f"""Analyze the given query and determine if it can be solved using Z3:
@@ -45,7 +153,8 @@ class Z3SolverSystem:
 2. Determine the problem type (e.g., SAT, optimization).
 3. Decide if Z3 is suitable.

-If Z3 can be applied, provide Python code using Z3 to solve the problem.
+If Z3 can be applied, provide Python code using Z3 to solve the problem. Make sure you define any additional methods you need for solving the problem.
+The code will be executed in an environment with only Z3 available, so do not include any other libraries or modules.

 Query: {query}
@@ -71,6 +180,7 @@ Analysis:
             n=1,
             temperature=0.1
         )
+        self.z3_completion_tokens = analysis_response.usage.completion_tokens
         return analysis_response.choices[0].message.content

     def generate_response(self, query: str, analysis: str, solver_result: Dict[str, Any]) -> str:
@@ -98,6 +208,7 @@ Response:
             n=1,
             temperature=0.1
         )
+        self.z3_completion_tokens = response.usage.completion_tokens
         return response.choices[0].message.content

     def standard_llm_inference(self, query: str) -> str:
@@ -111,6 +222,7 @@ Response:
             n=1,
             temperature=0.1
         )
+        self.z3_completion_tokens = response.usage.completion_tokens
         return response.choices[0].message.content

     def extract_and_validate_expressions(self, analysis: str) -> str:
@@ -148,12 +260,14 @@ Provide corrected Z3 code:
             n=1,
             temperature=0.1
         )
+        self.z3_completion_tokens = response.usage.completion_tokens
         formulation = self.extract_and_validate_expressions(response.choices[0].message.content)

         return {"status": "failed", "output": "Failed to solve after multiple attempts."}

     def execute_solver_code(self, code: str) -> str:
         logging.info("Executing Z3 solver code")
+        logging.info(f"Code: {code}")

         # Define a whitelist of allowed Z3 names
         z3_whitelist = set(dir(z3))
@@ -169,16 +283,19 @@ Provide corrected Z3 code:
         for node in ast.walk(parsed_ast):
             if isinstance(node, ast.Import):
                 for alias in node.names:
-                    if alias.name != 'z3':
+                    if alias.name not in ['z3', 'math', 'fractions', 'itertools']:
                         logging.warning(f"Unauthorized import: {alias.name}")
                         return f"Error: Unauthorized import: {alias.name}"
-            elif isinstance(node, ast.ImportFrom) and node.module != 'z3':
-                logging.warning(f"Unauthorized import from: {node.module}")
-                return f"Error: Unauthorized import from: {node.module}"
+            elif isinstance(node, ast.ImportFrom) and node.module not in ['z3', 'math', 'fractions', 'itertools']:
+                logging.warning(f"Unauthorized import from: {node.module}")
+                return f"Error: Unauthorized import from: {node.module}"

         # Prepare a restricted global namespace
         safe_globals = {
             'z3': z3,
+            'math': math,
+            'itertools': itertools,
+            'Fraction': Fraction,
             'print': print,  # Allow print for output
             '__builtins__': {
                 'True': True,
@@ -192,19 +309,67 @@ Provide corrected Z3 code:
                 'min': min,
                 'round': round,
                 'sum': sum,
+                'complex': complex,
             }
         }
         safe_globals.update({name: getattr(z3, name) for name in z3_whitelist})

+        # Add common math functions
+        safe_globals.update({
+            'log': math.log,
+            'sqrt': math.sqrt,
+            'exp': math.exp,
+            'sin': math.sin,
+            'cos': math.cos,
+            'tan': math.tan,
+            'pi': math.pi,
+            'e': math.e,
+        })

-        # Execute the code
-        output_buffer = io.StringIO()
-        with contextlib.redirect_stdout(output_buffer):
+        # Add complex number support
+        safe_globals['I'] = complex(0, 1)
+        safe_globals['Complex'] = complex
+
+        # Add Z3 specific types and functions
+        safe_globals['Optimize'] = z3.Optimize
+
+        # Add custom functions for Z3 specific operations
+        def as_numerical(x):
+            if z3.is_expr(x):
+                if z3.is_int_value(x) or z3.is_rational_value(x):
+                    return float(x.as_decimal(20))
+                elif z3.is_algebraic_value(x):
+                    return x.approx(20)
+            return float(x)
+        safe_globals['as_numerical'] = as_numerical
+
+        # Add a custom Mod function that uses Z3's modulo operator
+        def Mod(x, y):
+            return x % y
+        safe_globals['Mod'] = Mod
+
+        # Add a custom Rational function to create rational numbers in Z3
+        def Rational(numerator, denominator=1):
+            return z3.Real(str(Fraction(numerator, denominator)))
+        safe_globals['Rational'] = Rational
+
+        # Execute the code in a separate process
+        ctx = multiprocessing.get_context('spawn')
+        with ctx.Pool(1) as pool:
+            async_result = pool.apply_async(execute_code_in_process, (code,))
             try:
-                exec(code, safe_globals, {})
-            except Exception as e:
-                logging.error(f"Execution error: {str(e)}")
-                return f"Error: Execution error: {str(e)}"
+                status, result = async_result.get(timeout=self.timeout)
+            except multiprocessing.TimeoutError:
+                pool.terminate()
+                logging.error("Execution timed out")
+                return "Error: Execution timed out"

+        if status == "error":
+            logging.error(f"Execution error: {result}")
+            return f"Error: {result}"

-        executed_output = output_buffer.getvalue()
         logging.info("Z3 solver code executed successfully")
-        return executed_output
+        return result
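Beyond token counting, this file also swaps the in-process exec (guarded only by the signal-based timeout_handler) for execution in a spawned worker process, which can be terminated even if the solver blocks inside native code. A standalone sketch of that timeout pattern, with a hypothetical payload and names:

import multiprocessing

def worker(code: str):
    # Stand-in for execute_code_in_process: run the payload, report the outcome.
    namespace = {}
    try:
        exec(code, namespace, {})
    except Exception as e:
        return ("error", str(e))
    return ("success", "ok")

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")  # fresh interpreter, no inherited state
    with ctx.Pool(1) as pool:
        async_result = pool.apply_async(worker, ("x = 1 + 1",))
        try:
            status, result = async_result.get(timeout=5)  # hard wall-clock limit
        except multiprocessing.TimeoutError:
            pool.terminate()  # unlike a signal handler, this also kills code stuck in C extensions
            status, result = "error", "timed out"
    print(status, result)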