"""optillm proxy server.

Plugins are loaded in the following order: first from the installed package's
plugins folder, then from the current directory's optillm/plugins folder; a
local plugin with the same SLUG overrides the packaged one.
"""

import argparse
import asyncio
import glob
import importlib.util  # imported explicitly: load_plugins() relies on importlib.util
import json
import logging
import os
import re
import secrets
from concurrent.futures import ThreadPoolExecutor

from flask import Flask, Response, jsonify, request
from openai import AzureOpenAI, OpenAI

# Import the LiteLLM wrapper
from optillm.litellm_wrapper import LiteLLMWrapper

# Import approach modules
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
from optillm.moa import mixture_of_agents
from optillm.rto import round_trip_optimization
from optillm.self_consistency import advanced_self_consistency_approach
from optillm.pvg import inference_time_pv_game
from optillm.z3_solver import Z3SymPySolverSystem
from optillm.rstar import RStar
from optillm.cot_reflection import cot_reflection
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

logging_levels = {
    "notset": logging.NOTSET,
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

# Initialize Flask app
app = Flask(__name__)

def get_config():
    API_KEY = None
    # OpenAI, Azure, or LiteLLM API configuration
    if os.environ.get("OPENAI_API_KEY"):
        API_KEY = os.environ.get("OPENAI_API_KEY")
        base_url = server_config['base_url']
        if base_url != "":
            default_client = OpenAI(api_key=API_KEY, base_url=base_url)
        else:
            default_client = OpenAI(api_key=API_KEY)
    elif os.environ.get("AZURE_OPENAI_API_KEY"):
        API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
        API_VERSION = os.environ.get("AZURE_API_VERSION")
        AZURE_ENDPOINT = os.environ.get("AZURE_API_BASE")
        if API_KEY is not None:
            default_client = AzureOpenAI(
                api_key=API_KEY,
                api_version=API_VERSION,
                azure_endpoint=AZURE_ENDPOINT,
            )
        else:
            from azure.identity import DefaultAzureCredential, get_bearer_token_provider
            azure_credential = DefaultAzureCredential()
            token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
            default_client = AzureOpenAI(
                api_version=API_VERSION,
                azure_endpoint=AZURE_ENDPOINT,
                azure_ad_token_provider=token_provider,
            )
    else:
        default_client = LiteLLMWrapper()
    return default_client, API_KEY
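
# Client selection is driven entirely by environment variables. Illustrative
# (hypothetical) values:
#   OPENAI_API_KEY=sk-...                                 -> OpenAI client
#   AZURE_OPENAI_API_KEY=... with AZURE_API_VERSION and
#     AZURE_API_BASE=https://<resource>.openai.azure.com  -> AzureOpenAI client
#   neither key set                                       -> LiteLLMWrapper fallback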

# Server configuration
server_config = {
    'approach': 'bon',
    'mcts_simulations': 2,
    'mcts_exploration': 0.2,
    'mcts_depth': 1,
    'best_of_n': 3,
    'model': 'gpt-4o-mini',
    'rstar_max_depth': 3,
    'rstar_num_rollouts': 5,
    'rstar_c': 1.4,
    'n': 1,
    'base_url': '',
    'optillm_api_key': '',
    'return_full_response': False,
    'port': 8000,
    'log': 'info',
}

# List of known approaches
known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar",
                    "cot_reflection", "plansearch", "leap", "re2"]

plugin_approaches = {}

def load_plugins():
    # Clear existing plugins first but modify the global dict in place
    plugin_approaches.clear()

    # Get installed package plugins directory
    import optillm
    package_plugin_dir = os.path.join(os.path.dirname(optillm.__file__), 'plugins')

    # Get local project plugins directory
    current_dir = os.getcwd()
    local_plugin_dir = os.path.join(current_dir, 'optillm', 'plugins')

    plugin_dirs = []

    # Add package plugin dir
    plugin_dirs.append((package_plugin_dir, "package"))

    # Add local plugin dir only if it's different from the package dir
    if local_plugin_dir != package_plugin_dir:
        plugin_dirs.append((local_plugin_dir, "local"))

    for plugin_dir, source in plugin_dirs:
        logger.info(f"Looking for {source} plugins in: {plugin_dir}")

        if not os.path.exists(plugin_dir):
            logger.debug(f"{source.capitalize()} plugin directory not found: {plugin_dir}")
            continue

        plugin_files = glob.glob(os.path.join(plugin_dir, '*.py'))
        if not plugin_files:
            logger.debug(f"No plugin files found in {source} directory: {plugin_dir}")
            continue

        logger.info(f"Found {source} plugin files: {plugin_files}")

        for plugin_file in plugin_files:
            try:
                module_name = os.path.basename(plugin_file)[:-3]  # Remove .py extension
                spec = importlib.util.spec_from_file_location(module_name, plugin_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)

                if hasattr(module, 'SLUG') and hasattr(module, 'run'):
                    if module.SLUG in plugin_approaches:
                        logger.info(f"Overriding plugin {module.SLUG} with {source} version")
                    plugin_approaches[module.SLUG] = module.run
                    logger.info(f"Loaded {source} plugin: {module.SLUG}")
                else:
                    logger.warning(f"Plugin {module_name} from {source} missing required attributes (SLUG and run)")
            except Exception as e:
                logger.error(f"Error loading {source} plugin {plugin_file}: {str(e)}")

    if not plugin_approaches:
        logger.warning("No plugins loaded from any location")
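
# A minimal plugin sketch (hypothetical file; a plugin only needs a
# module-level SLUG string and a run() callable with this signature):
#
#     # optillm/plugins/echo_plugin.py
#     SLUG = "echo"
#
#     def run(system_prompt, initial_query, client, model):
#         # Must return a (response_text, completion_tokens) tuple, like the
#         # built-in approaches dispatched in execute_single_approach().
#         return f"echo: {initial_query}", 0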

def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
    if model == 'auto':
        return 'SINGLE', ['bon'], model

    parts = model.split('-')
    approaches = []
    operation = 'SINGLE'
    model_parts = []
    parsing_approaches = True

    for part in parts:
        if parsing_approaches:
            if part in known_approaches or part in plugin_approaches:
                approaches.append(part)
            elif '&' in part:
                operation = 'AND'
                approaches.extend(part.split('&'))
            elif '|' in part:
                operation = 'OR'
                approaches.extend(part.split('|'))
            else:
                parsing_approaches = False
                model_parts.append(part)
        else:
            model_parts.append(part)

    if not approaches:
        approaches = ['bon']
        operation = 'SINGLE'

    actual_model = '-'.join(model_parts)

    return operation, approaches, actual_model
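
# parse_combined_approach in action (illustrative model strings):
#   'moa-gpt-4o-mini'      -> ('SINGLE', ['moa'], 'gpt-4o-mini')
#   'bon&moa-gpt-4o-mini'  -> ('AND', ['bon', 'moa'], 'gpt-4o-mini')
#   'bon|moa-gpt-4o-mini'  -> ('OR', ['bon', 'moa'], 'gpt-4o-mini')
#   'gpt-4o-mini'          -> ('SINGLE', ['bon'], 'gpt-4o-mini')  # default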

def execute_single_approach(approach, system_prompt, initial_query, client, model):
    if approach in known_approaches:
        # Execute known approaches
        if approach == 'mcts':
            return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                  server_config['mcts_exploration'], server_config['mcts_depth'])
        elif approach == 'bon':
            return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
        elif approach == 'moa':
            return mixture_of_agents(system_prompt, initial_query, client, model)
        elif approach == 'rto':
            return round_trip_optimization(system_prompt, initial_query, client, model)
        elif approach == 'z3':
            z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
            return z3_solver.process_query(initial_query)
        elif approach == "self_consistency":
            return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
        elif approach == "pvg":
            return inference_time_pv_game(system_prompt, initial_query, client, model)
        elif approach == "rstar":
            rstar = RStar(system_prompt, client, model,
                          max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
                          c=server_config['rstar_c'])
            return rstar.solve(initial_query)
        elif approach == "cot_reflection":
            return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
        elif approach == 'plansearch':
            return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
        elif approach == 'leap':
            return leap(system_prompt, initial_query, client, model)
        elif approach == 're2':
            return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
    elif approach in plugin_approaches:
        return plugin_approaches[approach](system_prompt, initial_query, client, model)
    else:
        raise ValueError(f"Unknown approach: {approach}")
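
# Every approach returns a (response_text, completion_tokens) tuple, e.g.
# (hypothetical direct call, assuming an OpenAI-compatible `client`):
#   text, tokens = execute_single_approach('bon', "You are helpful.",
#                                          "User: hi", client, 'gpt-4o-mini')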

def execute_combined_approaches(approaches, system_prompt, initial_query, client, model):
    final_response = initial_query
    total_tokens = 0
    for approach in approaches:
        response, tokens = execute_single_approach(approach, system_prompt, final_response, client, model)
        final_response = response
        total_tokens += tokens
    return final_response, total_tokens
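
# 'AND' pipelines chain approaches: each approach's response becomes the next
# approach's query, so 'bon&moa-gpt-4o-mini' runs bon, then moa on bon's output.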

async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model):
    async def run_approach(approach):
        return await asyncio.to_thread(execute_single_approach, approach, system_prompt, initial_query, client, model)

    tasks = [run_approach(approach) for approach in approaches]
    results = await asyncio.gather(*tasks)
    responses, tokens = zip(*results)
    return list(responses), sum(tokens)
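
# 'OR' fans the same query out concurrently (one worker thread per approach via
# asyncio.to_thread) and returns every response, one choice per approach.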

def generate_streaming_response(final_response, model):
    # Yield the final response
    if isinstance(final_response, list):
        for index, response in enumerate(final_response):
            yield "data: " + json.dumps({
                "choices": [{"delta": {"content": response}, "index": index, "finish_reason": "stop"}],
                "model": model,
            }) + "\n\n"
    else:
        yield "data: " + json.dumps({
            "choices": [{"delta": {"content": final_response}, "index": 0, "finish_reason": "stop"}],
            "model": model,
        }) + "\n\n"

    # Yield the final message to indicate the stream has ended
    yield "data: [DONE]\n\n"
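
# Each yielded chunk is a server-sent event; an illustrative stream:
#   data: {"choices": [{"delta": {"content": "..."}, "index": 0,
#          "finish_reason": "stop"}], "model": "gpt-4o-mini"}
#
#   data: [DONE]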

def parse_conversation(messages):
    system_prompt = ""
    conversation = []
    optillm_approach = None

    for message in messages:
        role = message['role']
        content = message['content']

        if role == 'system':
            system_prompt, optillm_approach = extract_optillm_approach(content)
        elif role == 'user':
            if not optillm_approach:
                content, optillm_approach = extract_optillm_approach(content)
            conversation.append(f"User: {content}")
        elif role == 'assistant':
            conversation.append(f"Assistant: {content}")

    initial_query = "\n".join(conversation)
    return system_prompt, initial_query, optillm_approach
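
# Flattening sketch for a hypothetical message list:
#   [{"role": "system", "content": "Be brief."},
#    {"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello"}]
#   -> ("Be brief.", "User: Hi\nAssistant: Hello", None)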

def extract_optillm_approach(content):
    match = re.search(r'<optillm_approach>(.*?)</optillm_approach>', content)
    if match:
        approach = match.group(1)
        content = re.sub(r'<optillm_approach>.*?</optillm_approach>', '', content).strip()
        return content, approach
    return content, None
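
# Example (hypothetical content): an inline tag selects the approach and is
# stripped from the message before it reaches the model:
#   extract_optillm_approach("<optillm_approach>moa</optillm_approach> Hi")
#   -> ("Hi", "moa")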

# Optional API key configuration to secure the proxy
@app.before_request
def check_api_key():
    if server_config['optillm_api_key']:
        if request.path == "/health":
            return

        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({"error": "Invalid Authorization header. Expected format: 'Authorization: Bearer YOUR_API_KEY'"}), 401

        client_key = auth_header.split('Bearer ', 1)[1].strip()
        if not secrets.compare_digest(client_key, server_config['optillm_api_key']):
            return jsonify({"error": "Invalid API key"}), 401
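
# With an optillm API key configured, every request except /health must carry
# it (illustrative request, assuming the default port):
#   curl http://localhost:8000/v1/chat/completions \
#        -H "Authorization: Bearer $OPTILLM_API_KEY" ...
# secrets.compare_digest keeps the key comparison constant-time.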

@app.route('/v1/chat/completions', methods=['POST'])
def proxy():
    logger.info('Received request to /v1/chat/completions')
    data = request.get_json()
    auth_header = request.headers.get("Authorization")
    bearer_token = ""

    if auth_header and auth_header.startswith("Bearer "):
        # Extract the bearer token
        bearer_token = auth_header.split("Bearer ")[1].strip()
        logger.debug(f"Intercepted Bearer Token: {bearer_token}")

    logger.debug(f'Request data: {data}')

    stream = data.get('stream', False)
    messages = data.get('messages', [])
    model = data.get('model', server_config['model'])

    optillm_approach = data.get('optillm_approach', server_config['approach'])

    system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)

    if message_optillm_approach:
        optillm_approach = message_optillm_approach

    if optillm_approach != "auto":
        model = f"{optillm_approach}-{model}"

    base_url = server_config['base_url']
    default_client, api_key = get_config()

    operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
    logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')

    if bearer_token != "" and bearer_token.startswith("sk-"):
        api_key = bearer_token
        if base_url != "":
            client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            client = OpenAI(api_key=api_key)
    else:
        client = default_client

    try:
        if operation == 'SINGLE':
            final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
        elif operation == 'AND':
            final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
        elif operation == 'OR':
            # asyncio.run creates and tears down the event loop per request,
            # instead of leaking a manually created loop
            final_response, completion_tokens = asyncio.run(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
        else:
            raise ValueError(f"Unknown operation: {operation}")
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({"error": str(e)}), 500

    if stream:
        return Response(generate_streaming_response(final_response, model), content_type='text/event-stream')
    else:
        response_data = {
            'model': model,
            'choices': [],
            'usage': {
                'completion_tokens': completion_tokens,
            }
        }

        if isinstance(final_response, list):
            for index, response in enumerate(final_response):
                response_data['choices'].append({
                    'index': index,
                    'message': {
                        'role': 'assistant',
                        'content': response,
                    },
                    'finish_reason': 'stop'
                })
        else:
            response_data['choices'].append({
                'index': 0,
                'message': {
                    'role': 'assistant',
                    'content': final_response,
                },
                'finish_reason': 'stop'
            })

        logger.debug(f'API response: {response_data}')
        return jsonify(response_data), 200
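
# Minimal client-side sketch (assumes the proxy runs on localhost:8000; the
# model prefix picks the approach, falling back to --approach when absent):
#   from openai import OpenAI
#   client = OpenAI(api_key="sk-...", base_url="http://localhost:8000/v1")
#   client.chat.completions.create(
#       model="moa-gpt-4o-mini",
#       messages=[{"role": "user", "content": "What is 2+2?"}])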

@app.route('/v1/models', methods=['GET'])
def proxy_models():
    logger.info('Received request to /v1/models')
    default_client, API_KEY = get_config()
    try:
        if server_config['base_url']:
            client = OpenAI(api_key=API_KEY, base_url=server_config['base_url'])
        else:
            client = default_client

        # Fetch models using the OpenAI client and return the raw response
        models_response = client.models.list()

        logger.debug('Models retrieved successfully')
        return models_response.model_dump(), 200
    except Exception as e:
        logger.error(f"Error fetching models: {str(e)}")
        return jsonify({"error": f"Error fetching models: {str(e)}"}), 500

@app.route('/health', methods=['GET'])
def health():
    return jsonify({"status": "ok"}), 200

def parse_args():
    parser = argparse.ArgumentParser(description="Run LLM inference with various approaches.")

    # Define arguments and their corresponding environment variables
    args_env = [
        ("--optillm-api-key", "OPTILLM_API_KEY", str, "", "Optional API key for client authentication to optillm"),
        ("--approach", "OPTILLM_APPROACH", str, "auto", "Inference approach to use", ["auto"] + known_approaches),
        ("--mcts-simulations", "OPTILLM_SIMULATIONS", int, 2, "Number of MCTS simulations"),
        ("--mcts-exploration", "OPTILLM_EXPLORATION", float, 0.2, "Exploration weight for MCTS"),
        ("--mcts-depth", "OPTILLM_DEPTH", int, 1, "Simulation depth for MCTS"),
        ("--model", "OPTILLM_MODEL", str, "gpt-4o-mini", "OpenAI model to use"),
        ("--rstar-max-depth", "OPTILLM_RSTAR_MAX_DEPTH", int, 3, "Maximum depth for rStar algorithm"),
        ("--rstar-num-rollouts", "OPTILLM_RSTAR_NUM_ROLLOUTS", int, 5, "Number of rollouts for rStar algorithm"),
        ("--rstar-c", "OPTILLM_RSTAR_C", float, 1.4, "Exploration constant for rStar algorithm"),
        ("--n", "OPTILLM_N", int, 1, "Number of final responses to be returned"),
        ("--return-full-response", "OPTILLM_RETURN_FULL_RESPONSE", bool, False, "Return the full response including the CoT with <thinking> tags"),
        ("--port", "OPTILLM_PORT", int, 8000, "Specify the port to run the proxy"),
        ("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())),
    ]

    for arg, env, type_, default, help_text, *extra in args_env:
        env_value = os.environ.get(env)
        if env_value is not None:
            if type_ == bool:
                default = env_value.lower() in ('true', '1', 'yes')
            else:
                default = type_(env_value)
        if type_ == bool:
            # argparse's type=bool treats any non-empty string (even "False")
            # as True, so boolean options are exposed as flags instead
            parser.add_argument(arg, action='store_true', default=default, help=help_text)
        elif extra and extra[0]:  # Check if there are choices for this argument
            parser.add_argument(arg, type=type_, default=default, help=help_text, choices=extra[0])
        else:
            parser.add_argument(arg, type=type_, default=default, help=help_text)

    # Special handling for best_of_n to support both formats
    best_of_n_default = int(os.environ.get("OPTILLM_BEST_OF_N", 3))
    parser.add_argument("--best-of-n", "--best_of_n", dest="best_of_n", type=int, default=best_of_n_default,
                        help="Number of samples for best_of_n approach")

    # Special handling for base_url to support both formats
    base_url_default = os.environ.get("OPTILLM_BASE_URL", "")
    parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
                        help="Base url for OpenAI compatible endpoint")

    args = parser.parse_args()

    # Convert argument names to match server_config keys
    args_dict = vars(args)
    for key in list(args_dict.keys()):
        new_key = key.replace("-", "_")
        if new_key != key:
            args_dict[new_key] = args_dict.pop(key)

    return args
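
# Precedence sketch: CLI flags beat environment variables, because env values
# only replace the argparse defaults. Hypothetical invocation:
#   OPTILLM_PORT=9000 python optillm.py --port 8080   # server binds port 8080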

def main():
    global server_config
    args = parse_args()

    # Load plugins at startup
    load_plugins()

    # Update server_config with all argument values
    server_config.update(vars(args))

    port = server_config['port']

    # Set logging level from user request
    logging_level = server_config['log']
    if logging_level in logging_levels.keys():
        logger.setLevel(logging_levels[logging_level])

    logger.info(f"Starting server with approach: {server_config['approach']}")

    server_config_clean = server_config.copy()
    if server_config_clean['optillm_api_key']:
        server_config_clean['optillm_api_key'] = '[REDACTED]'
    logger.info(f"Server configuration: {server_config_clean}")

    app.run(host='0.0.0.0', port=port)


if __name__ == "__main__":
    main()