Mirror of https://github.com/codelion/optillm.git (synced 2025-05-28 09:39:38 +03:00)

Commit: Add completion tokens

Changed files: optillm.py (30 changed lines), plus the approach modules shown in the hunks below.
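In short, each optimization approach (mcts, bon, moa, rto, z3, self_consistency, pvg, rstar, cot_reflection, plansearch, leap, re2) now returns a (response, completion_tokens) pair instead of a bare response, and proxy() reports the accumulated count in an OpenAI-style usage block.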
@@ -123,37 +123,38 @@ def proxy():

logger.info(f'Using approach {approach}, with {model}')
completion_tokens = 0

try:
if approach == 'mcts':
final_response = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
server_config['mcts_exploration'], server_config['mcts_depth'])
elif approach == 'bon':
final_response = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
elif approach == 'moa':
final_response = mixture_of_agents(system_prompt, initial_query, client, model)
final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model)
elif approach == 'rto':
final_response = round_trip_optimization(system_prompt, initial_query, client, model)
final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model)
elif approach == 'z3':
z3_solver = Z3SolverSystem(system_prompt, client, model)
final_response = z3_solver.process_query(initial_query)
final_response, completion_tokens = z3_solver.process_query(initial_query)
elif approach == "self_consistency":
final_response = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
elif approach == "pvg":
final_response = inference_time_pv_game(system_prompt, initial_query, client, model)
final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model)
elif approach == "rstar":
rstar = RStar(system_prompt, client, model,
max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
c=server_config['rstar_c'])
final_response = rstar.solve(initial_query)
final_response, completion_tokens = rstar.solve(initial_query)
elif approach == "cot_reflection":
final_response = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
elif approach == 'plansearch':
final_response = plansearch(system_prompt, initial_query, client, model, n=n)
final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n)
elif approach == 'leap':
final_response = leap(system_prompt, initial_query, client, model)
final_response, completion_tokens = leap(system_prompt, initial_query, client, model)
elif approach == 're2':
final_response = re2_approach(system_prompt, initial_query, client, model, n=n)
final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n)
else:
raise ValueError(f"Unknown approach: {approach}")
except Exception as e:

@@ -162,7 +163,10 @@ def proxy():

response_data = {
'model': model,
'choices': []
'choices': [],
'usage': {
'completion_tokens': completion_tokens,
}
}

if isinstance(final_response, list):
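
For orientation, a minimal sketch (not taken from the diff) of how the (response, token count) pair returned by an approach feeds the new usage block assembled in proxy(); the model name and token count below are placeholders:

# Hypothetical values standing in for what an approach such as best_of_n_sampling
# now returns: (response_text, completion_tokens).
final_response, completion_tokens = "final answer text", 42

# Payload shape per the hunk above; 'choices' is filled from final_response further down.
response_data = {
    'model': 'gpt-4o-mini',  # placeholder model name
    'choices': [],
    'usage': {
        'completion_tokens': completion_tokens,
    },
}
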
@@ -3,6 +3,8 @@ import logging
logger = logging.getLogger(__name__)

def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3) -> str:
bon_completion_tokens = 0

messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}]

@@ -16,6 +18,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
temperature=1
)
completions = [choice.message.content for choice in response.choices]
bon_completion_tokens += response.usage.completion_tokens

# Rate the completions
rating_messages = messages.copy()
@@ -33,7 +36,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
n=1,
temperature=0.1
)

bon_completion_tokens += rating_response.usage.completion_tokens
try:
rating = float(rating_response.choices[0].message.content.strip())
ratings.append(rating)
@@ -43,4 +46,4 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
rating_messages = rating_messages[:-2]

best_index = ratings.index(max(ratings))
return completions[best_index]
return completions[best_index], bon_completion_tokens
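
A caller-side sketch (not part of the diff) showing the new return shape; it assumes an OpenAI-compatible client object as used throughout these modules, and the prompts and model name are illustrative:

from openai import OpenAI  # assumed client; any OpenAI-compatible client object works here

# best_of_n_sampling is the function shown in the hunks above; its import path
# depends on your checkout, so it is omitted here.
client = OpenAI()  # hypothetical configuration via environment variables
text, used_tokens = best_of_n_sampling(
    system_prompt="You are a helpful assistant.",
    initial_query="Summarize the main idea of best-of-n sampling.",
    client=client,
    model="gpt-4o-mini",  # placeholder model name
    n=3,
)
print(f"best-of-n answer used {used_tokens} completion tokens")
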
@@ -4,6 +4,7 @@ import logging
logger = logging.getLogger(__name__)

def cot_reflection(system_prompt, initial_query, client, model: str, return_full_response: bool=False):
cot_completion_tokens = 0
cot_prompt = f"""
{system_prompt}

@@ -44,6 +45,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full

# Extract the full response
full_response = response.choices[0].message.content
cot_completion_tokens += response.usage.completion_tokens
logger.info(f"CoT with Reflection :\n{full_response}")

# Use regex to extract the content within <thinking> and <output> tags
@@ -56,7 +58,7 @@ def cot_reflection(system_prompt, initial_query, client, model: str, return_full
logger.info(f"Final output :\n{output}")

if return_full_response:
return full_response
return full_response, cot_completion_tokens
else:
return output
return output, cot_completion_tokens
@@ -14,6 +14,7 @@ class LEAP:
self.model = model
self.low_level_principles = []
self.high_level_principles = []
self.leap_completion_tokens = 0

def extract_output(self, text: str) -> str:
match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
@@ -46,6 +47,7 @@ class LEAP:
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
examples_str = self.extract_output(response.choices[0].message.content)
logger.debug(f"Extracted examples: {examples_str}")
examples = []
@@ -80,6 +82,7 @@ class LEAP:
],
temperature=0.7,
)
self.leap_completion_tokens += response.usage.completion_tokens
generated_reasoning = response.choices[0].message.content
generated_answer = self.extract_output(generated_reasoning)
if generated_answer != correct_answer:
@@ -110,6 +113,7 @@ class LEAP:
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
self.low_level_principles.append(self.extract_output(response.choices[0].message.content))
return self.low_level_principles

@@ -134,6 +138,7 @@ class LEAP:
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
self.high_level_principles = self.extract_output(response.choices[0].message.content).split("\n")
return self.high_level_principles

@@ -154,6 +159,7 @@ class LEAP:
"""}
]
)
self.leap_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content

def solve(self, initial_query: str) -> str:
@@ -171,4 +177,4 @@ class LEAP:

def leap(system_prompt: str, initial_query: str, client, model: str) -> str:
leap_solver = LEAP(system_prompt, client, model)
return leap_solver.solve(initial_query)
return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
@@ -32,6 +32,7 @@ class MCTS:
self.node_labels = {}
self.client = client
self.model = model
self.completion_tokens = 0

def select(self, node: MCTSNode) -> MCTSNode:
logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
@@ -118,6 +119,7 @@ class MCTS:
temperature=1
)
completions = [choice.message.content.strip() for choice in response.choices]
self.completion_tokens += response.usage.completion_tokens
logger.info(f"Received {len(completions)} completions from the model")
return completions

@@ -140,6 +142,7 @@ class MCTS:
)

next_query = response.choices[0].message.content
self.completion_tokens += response.usage.completion_tokens
logger.info(f"Generated next user query: {next_query}")
return DialogueState(state.system_prompt, new_history, next_query)

@@ -161,7 +164,7 @@ class MCTS:
n=1,
temperature=0.1
)

self.completion_tokens += response.usage.completion_tokens
try:
score = float(response.choices[0].message.content.strip())
score = max(0, min(score, 1)) # Ensure the score is between 0 and 1
@@ -181,4 +184,4 @@ def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, n
final_state = mcts.search(initial_state, num_simulations)
response = final_state.conversation_history[-1]['content'] if final_state.conversation_history else ""
logger.info(f"MCTS chat complete. Final response: {response[:100]}...")
return response
return response, mcts.completion_tokens
@@ -3,6 +3,7 @@ import logging
logger = logging.getLogger(__name__)

def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str) -> str:
moa_completion_tokens = 0
completions = []

response = client.chat.completions.create(
@@ -16,6 +17,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
temperature=1
)
completions = [choice.message.content for choice in response.choices]
moa_completion_tokens += response.usage.completion_tokens

critique_prompt = f"""
Original query: {initial_query}
@@ -45,6 +47,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
temperature=0.1
)
critiques = critique_response.choices[0].message.content
moa_completion_tokens += critique_response.usage.completion_tokens

final_prompt = f"""
Original query: {initial_query}
@@ -76,5 +79,5 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
n=1,
temperature=0.1
)

return final_response.choices[0].message.content
moa_completion_tokens += final_response.usage.completion_tokens
return final_response.choices[0].message.content, moa_completion_tokens
@@ -8,6 +8,7 @@ class PlanSearch:
self.system_prompt = system_prompt
self.client = client
self.model = model
self.plansearch_completion_tokens = 0

def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
@@ -28,7 +29,7 @@ Please provide {num_observations} observations."""
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
observations = response.choices[0].message.content.strip().split('\n')
return [obs.strip() for obs in observations if obs.strip()]

@@ -55,7 +56,7 @@ Please provide {num_new_observations} new observations derived from the existing
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
new_observations = response.choices[0].message.content.strip().split('\n')
return [obs.strip() for obs in new_observations if obs.strip()]

@@ -80,7 +81,7 @@ IS CRUCIAL."""
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content.strip()

def implement_solution(self, problem: str, solution: str) -> str:
@@ -105,7 +106,7 @@ Please implement the solution in Python."""
{"role": "user", "content": prompt}
]
)

self.plansearch_completion_tokens += response.usage.completion_tokens
return response.choices[0].message.content.strip()

def solve(self, problem: str, num_initial_observations: int = 3, num_derived_observations: int = 2) -> Tuple[str, str]:
@@ -134,4 +135,4 @@ Please implement the solution in Python."""

def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1) -> List[str]:
planner = PlanSearch(system_prompt, client, model)
return planner.solve_multiple(initial_query, n)
return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens
@@ -4,7 +4,10 @@ from typing import List, Tuple

logger = logging.getLogger(__name__)

pvg_completion_tokens = 0

def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7) -> List[str]:
global pvg_completion_tokens
role = "sneaky" if is_sneaky else "helpful"
logger.info(f"Generating {num_solutions} {role} solutions")

@@ -34,11 +37,13 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
max_tokens=4096,
temperature=temperature,
)
pvg_completion_tokens += response.usage.completion_tokens
solutions = [choice.message.content for choice in response.choices]
logger.debug(f"Generated {role} solutions: {solutions}")
return solutions

def verify_solutions(client, system_prompt: str, initial_query: str, solutions: List[str], model: str) -> List[float]:
global pvg_completion_tokens
logger.info(f"Verifying {len(solutions)} solutions")
verify_prompt = f"""{system_prompt}
You are a verifier tasked with evaluating the correctness and clarity of solutions to the given problem.
@@ -75,6 +80,7 @@ Ensure that the Score is a single number between 0 and 10, and the Explanation i
max_tokens=1024,
temperature=0.2,
)
pvg_completion_tokens += response.usage.completion_tokens
rating = response.choices[0].message.content
logger.debug(f"Raw rating for solution {i+1}: {rating}")

@@ -130,6 +136,7 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
return "", 0.0

def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3) -> str:
global pvg_completion_tokens
logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")

best_solution = ""
@@ -178,9 +185,10 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
max_tokens=1024,
temperature=0.5,
)
pvg_completion_tokens += response.usage.completion_tokens
initial_query = response.choices[0].message.content
logger.debug(f"Refined query: {initial_query}")

logger.info(f"Inference-time PV game completed. Best solution score: {best_score}")

return best_solution
return best_solution, pvg_completion_tokens
@@ -17,6 +17,7 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
str or list: The generated response(s) from the model.
"""
logger.info("Using RE2 approach for query processing")
re2_completion_tokens = 0

# Construct the RE2 prompt
re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -32,11 +33,11 @@ def re2_approach(system_prompt, initial_query, client, model, n=1):
messages=messages,
n=n
)

re2_completion_tokens += response.usage.completion_tokens
if n == 1:
return response.choices[0].message.content.strip()
return response.choices[0].message.content.strip(), re2_completion_tokens
else:
return [choice.message.content.strip() for choice in response.choices]
return [choice.message.content.strip() for choice in response.choices], re2_completion_tokens

except Exception as e:
logger.error(f"Error in RE2 approach: {str(e)}")
@@ -30,6 +30,7 @@ class RStar:
self.actions = ["A1", "A2", "A3", "A4", "A5"]
self.original_question = None
self.system = system
self.rstar_completion_tokens = 0
logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")

async def generate_response_async(self, prompt: str) -> str:
@@ -46,6 +47,7 @@ class RStar:
}
) as response:
result = await response.json()
self.rstar_completion_tokens += result['usage']['completion_tokens']
return result['choices'][0]['message']['content'].strip()

async def expand_async(self, node: Node, action: str) -> Node:
@@ -101,7 +103,7 @@ class RStar:
answers = [self.extract_answer(node.state) for node in final_trajectory]
final_answer = self.select_best_answer(answers)
logger.info(f"Selected final answer: {final_answer}")
return final_answer
return final_answer, self.rstar_completion_tokens

def generate_response(self, prompt: str) -> str:
logger.debug(f"Generating response for prompt: {prompt[:100]}...")
@@ -114,6 +116,7 @@ class RStar:
max_tokens=4096,
temperature=0.2
)
self.rstar_completion_tokens += response.usage.completion_tokens
generated_response = response.choices[0].message.content.strip()
logger.debug(f"Generated response: {generated_response}")
return generated_response
@@ -342,5 +345,4 @@ This rephrasing should help clarify the problem and guide the solution process."
"""
Synchronous wrapper for solve_async method.
"""
return asyncio.run(self.solve_async(question))

return asyncio.run(self.solve_async(question))
@@ -14,6 +14,7 @@ def extract_code_from_prompt(text):
return text

def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str) -> str:
rto_completion_tokens = 0
messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}]

@@ -26,6 +27,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
temperature=0.1
)
c1 = response_c1.choices[0].message.content
rto_completion_tokens += response_c1.usage.completion_tokens

# Generate description of the code (Q2)
messages.append({"role": "assistant", "content": c1})
@@ -38,6 +40,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
temperature=0.1
)
q2 = response_q2.choices[0].message.content
rto_completion_tokens += response_q2.usage.completion_tokens

# Generate second code based on the description (C2)
messages = [{"role": "system", "content": system_prompt},
@@ -50,6 +53,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
temperature=0.1
)
c2 = response_c2.choices[0].message.content
rto_completion_tokens += response_c2.usage.completion_tokens

c1 = extract_code_from_prompt(c1)
c2 = extract_code_from_prompt(c2)
@@ -67,5 +71,6 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
temperature=0.1
)
c3 = response_c3.choices[0].message.content
rto_completion_tokens += response_c3.usage.completion_tokens

return c3
return c3, rto_completion_tokens
@@ -10,6 +10,7 @@ class AdvancedSelfConsistency:
self.model = model
self.num_samples = num_samples
self.similarity_threshold = similarity_threshold
self.self_consistency_completion_tokens = 0

def generate_responses(self, system_prompt: str, user_prompt: str) -> List[str]:
responses = []
@@ -23,6 +24,7 @@ class AdvancedSelfConsistency:
temperature=1,
max_tokens=4096
)
self.self_consistency_completion_tokens += response.usage.completion_tokens
responses.append(response.choices[0].message.content)
return responses

@@ -85,6 +87,6 @@ def advanced_self_consistency_approach(system_prompt: str, initial_query: str, c
logger.debug(f" Variants: {cluster['variants']}")

if result['aggregated_result']['clusters']:
return result['aggregated_result']['clusters'][0]['answer']
return result['aggregated_result']['clusters'][0]['answer'], self_consistency.self_consistency_completion_tokens
else:
return "No consistent answer found."
return "No consistent answer found.", self_consistency.self_consistency_completion_tokens
@@ -6,12 +6,119 @@ import re
import contextlib
import logging
import ast
import math
import itertools
from fractions import Fraction
import threading
import ctypes
import multiprocessing
import traceback
import sys

class TimeoutException(Exception):
pass

def timeout_handler(signum, frame):
raise TimeoutException("Execution timed out")
def prepare_safe_globals():
safe_globals = {
'print': print,
'__builtins__': {
'True': True,
'False': False,
'None': None,
'abs': abs,
'float': float,
'int': int,
'len': len,
'max': max,
'min': min,
'round': round,
'sum': sum,
'complex': complex,
}
}

# Add common math functions
safe_globals.update({
'log': math.log,
'log2': math.log2,
'sqrt': math.sqrt,
'exp': math.exp,
'sin': math.sin,
'cos': math.cos,
'tan': math.tan,
'pi': math.pi,
'e': math.e,
})

# Add complex number support
safe_globals['I'] = complex(0, 1)
safe_globals['Complex'] = complex

return safe_globals

def execute_code_in_process(code: str):
import z3
import math
import itertools
from fractions import Fraction

safe_globals = prepare_safe_globals()

# Add Z3 specific functions
z3_whitelist = set(dir(z3))
safe_globals.update({name: getattr(z3, name) for name in z3_whitelist})

# Ensure key Z3 components are available
safe_globals.update({
'z3': z3,
'Solver': z3.Solver,
'solver': z3.Solver,
'Optimize': z3.Optimize,
'sat': z3.sat,
'unsat': z3.unsat,
'unknown': z3.unknown,
'Real': z3.Real,
'Int': z3.Int,
'Bool': z3.Bool,
'And': z3.And,
'Or': z3.Or,
'Not': z3.Not,
'Implies': z3.Implies,
'If': z3.If,
'Sum': z3.Sum,
'ForAll': z3.ForAll,
'Exists': z3.Exists,
'model': z3.Model,
})

# Add custom functions
def as_numerical(x):
if z3.is_expr(x):
if z3.is_int_value(x) or z3.is_rational_value(x):
return float(x.as_decimal(20))
elif z3.is_algebraic_value(x):
return x.approx(20)
return float(x)

safe_globals['as_numerical'] = as_numerical

def Mod(x, y):
return x % y

safe_globals['Mod'] = Mod

def Rational(numerator, denominator=1):
return z3.Real(str(Fraction(numerator, denominator)))

safe_globals['Rational'] = Rational

output_buffer = io.StringIO()
with contextlib.redirect_stdout(output_buffer):
try:
exec(code, safe_globals, {})
except Exception:
return ("error", traceback.format_exc())
return ("success", output_buffer.getvalue())

class Z3SolverSystem:
def __init__(self, system_prompt: str, client, model: str, timeout: int = 30):
@@ -19,6 +126,7 @@ class Z3SolverSystem:
self.model = model
self.client = client
self.timeout = timeout
self.z3_completion_tokens = 0
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def process_query(self, query: str) -> str:
@@ -26,17 +134,17 @@ class Z3SolverSystem:
analysis = self.analyze_query(query)
# print("Analysis: "+ analysis)
if "SOLVER_CAN_BE_APPLIED: True" not in analysis:
return self.standard_llm_inference(query)
return self.standard_llm_inference(query) , self.z3_completion_tokens

formulation = self.extract_and_validate_expressions(analysis)
# print("Formulation: "+ formulation)
solver_result = self.solve_with_z3(formulation)
# print(solver_result)

return self.generate_response(query, analysis, solver_result)

return self.generate_response(query, analysis, solver_result), self.z3_completion_tokens
except Exception as e:
logging.error(f"An error occurred while processing the query with Z3, returning standard llm inference results: {str(e)}")
return self.standard_llm_inference(query)
return self.standard_llm_inference(query), self.z3_completion_tokens

def analyze_query(self, query: str) -> str:
analysis_prompt = f"""Analyze the given query and determine if it can be solved using Z3:
@@ -45,7 +153,8 @@ class Z3SolverSystem:
2. Determine the problem type (e.g., SAT, optimization).
3. Decide if Z3 is suitable.

If Z3 can be applied, provide Python code using Z3 to solve the problem.
If Z3 can be applied, provide Python code using Z3 to solve the problem. Make sure you define any additional methods you need for solving the problem.
The code will be executed in an environment with only Z3 available, so do not include any other libraries or modules.

Query: {query}

@@ -71,6 +180,7 @@ Analysis:
n=1,
temperature=0.1
)
self.z3_completion_tokens = analysis_response.usage.completion_tokens
return analysis_response.choices[0].message.content

def generate_response(self, query: str, analysis: str, solver_result: Dict[str, Any]) -> str:
@@ -98,6 +208,7 @@ Response:
n=1,
temperature=0.1
)
self.z3_completion_tokens = response.usage.completion_tokens
return response.choices[0].message.content

def standard_llm_inference(self, query: str) -> str:
@@ -111,6 +222,7 @@ Response:
n=1,
temperature=0.1
)
self.z3_completion_tokens = response.usage.completion_tokens
return response.choices[0].message.content

def extract_and_validate_expressions(self, analysis: str) -> str:
@@ -148,12 +260,14 @@ Provide corrected Z3 code:
n=1,
temperature=0.1
)
self.z3_completion_tokens = response.usage.completion_tokens
formulation = self.extract_and_validate_expressions(response.choices[0].message.content)

return {"status": "failed", "output": "Failed to solve after multiple attempts."}

def execute_solver_code(self, code: str) -> str:
logging.info("Executing Z3 solver code")
logging.info(f"Code: {code}")

# Define a whitelist of allowed Z3 names
z3_whitelist = set(dir(z3))
@@ -169,16 +283,19 @@ Provide corrected Z3 code:
for node in ast.walk(parsed_ast):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name != 'z3':
if alias.name not in ['z3', 'math', 'fractions', 'itertools']:
logging.warning(f"Unauthorized import: {alias.name}")
return f"Error: Unauthorized import: {alias.name}"
elif isinstance(node, ast.ImportFrom) and node.module != 'z3':
logging.warning(f"Unauthorized import from: {node.module}")
return f"Error: Unauthorized import from: {node.module}"
elif isinstance(node, ast.ImportFrom) and node.module not in ['z3', 'math', 'fractions', 'itertools']:
logging.warning(f"Unauthorized import from: {node.module}")
return f"Error: Unauthorized import from: {node.module}"

# Prepare a restricted global namespace
safe_globals = {
'z3': z3,
'math': math,
'itertools': itertools,
'Fraction': Fraction,
'print': print, # Allow print for output
'__builtins__': {
'True': True,
@@ -192,19 +309,67 @@ Provide corrected Z3 code:
'min': min,
'round': round,
'sum': sum,
'complex': complex,
}
}
safe_globals.update({name: getattr(z3, name) for name in z3_whitelist})

# Add common math functions
safe_globals.update({
'log': math.log,
'sqrt': math.sqrt,
'exp': math.exp,
'sin': math.sin,
'cos': math.cos,
'tan': math.tan,
'pi': math.pi,
'e': math.e,
})

# Execute the code
output_buffer = io.StringIO()
with contextlib.redirect_stdout(output_buffer):
# Add complex number support
safe_globals['I'] = complex(0, 1)
safe_globals['Complex'] = complex

# Add Z3 specific types and functions
safe_globals['Optimize'] = z3.Optimize

# Add custom functions for Z3 specific operations
def as_numerical(x):
if z3.is_expr(x):
if z3.is_int_value(x) or z3.is_rational_value(x):
return float(x.as_decimal(20))
elif z3.is_algebraic_value(x):
return x.approx(20)
return float(x)

safe_globals['as_numerical'] = as_numerical

# Add a custom Mod function that uses Z3's modulo operator
def Mod(x, y):
return x % y

safe_globals['Mod'] = Mod

# Add a custom Rational function to create rational numbers in Z3
def Rational(numerator, denominator=1):
return z3.Real(str(Fraction(numerator, denominator)))

safe_globals['Rational'] = Rational

# Execute the code in a separate process
ctx = multiprocessing.get_context('spawn')
with ctx.Pool(1) as pool:
async_result = pool.apply_async(execute_code_in_process, (code,))
try:
exec(code, safe_globals, {})
except Exception as e:
logging.error(f"Execution error: {str(e)}")
return f"Error: Execution error: {str(e)}"
status, result = async_result.get(timeout=self.timeout)
except multiprocessing.TimeoutError:
pool.terminate()
logging.error("Execution timed out")
return "Error: Execution timed out"

if status == "error":
logging.error(f"Execution error: {result}")
return f"Error: {result}"

executed_output = output_buffer.getvalue()
logging.info("Z3 solver code executed successfully")
return executed_output
return result
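
As a usage sketch (not part of the diff), the process-isolated executor introduced above can be exercised directly; the Z3 snippet below is illustrative and assumes the z3-solver package is installed:

# execute_code_in_process is the helper shown in the hunks above; it injects the
# whitelisted Z3 names (Int, Solver, print, ...) into the execution namespace.
code = """
x = Int('x')
s = Solver()
s.add(x > 3, x < 10)
print(s.check())
print(s.model())
"""

status, output = execute_code_in_process(code)
# status is "success" with the captured stdout, or "error" with a traceback string.
print(status)
print(output)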