Files
autothink/optillm.py
2025-05-13 09:36:20 +08:00

882 lines
35 KiB
Python

import argparse
import logging
import os
import secrets
from flask import Flask, request, jsonify
from cerebras.cloud.sdk import Cerebras
from openai import AzureOpenAI, OpenAI
from flask import Response
import json
import importlib
import glob
import asyncio
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Optional, Union, Dict, Any, List
from importlib.metadata import version
from dataclasses import fields
# Import approach modules
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
from optillm.moa import mixture_of_agents
from optillm.rto import round_trip_optimization
from optillm.self_consistency import advanced_self_consistency_approach
from optillm.pvg import inference_time_pv_game
from optillm.z3_solver import Z3SymPySolverSystem
from optillm.rstar import RStar
from optillm.cot_reflection import cot_reflection
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach
from optillm.cepo.cepo import cepo, CepoConfig, init_cepo_config
# Setup logging: configure the root handler at INFO. main() can later change
# this module logger's level from the --log option.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Accepted --log values mapped to stdlib logging levels. Insertion order is
# preserved because parse_args() exposes list(logging_levels.keys()) as the
# CLI choices.
logging_levels = {
    level_name.lower(): getattr(logging, level_name)
    for level_name in ("NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
}
# Initialize Flask app. The API-key hook and the /v1/chat/completions,
# /v1/models and /health handlers below are registered on this instance
# through its decorators.
app = Flask(__name__)
def get_config():
    """Build the default upstream client from environment variables.

    Providers are checked in priority order; the first whose variable is set
    (non-empty) wins:

    1. ``OPTILLM_API_KEY``      -> local inference client (``optillm.inference``)
    2. ``CEREBRAS_API_KEY``     -> ``Cerebras`` client
    3. ``OPENAI_API_KEY``       -> ``OpenAI`` client
    4. ``AZURE_OPENAI_API_KEY`` -> ``AzureOpenAI`` client, configured from
       ``AZURE_API_VERSION`` and ``AZURE_API_BASE``
    5. otherwise                -> ``LiteLLMWrapper``

    For Cerebras and OpenAI, a non-empty ``server_config['base_url']`` is passed
    to the client as ``base_url``.

    Returns:
        ``(default_client, api_key)``. ``api_key`` is ``None`` on the LiteLLM
        fallback.
    """
    api_key = None
    if os.environ.get("OPTILLM_API_KEY"):
        # Use the local inference engine.
        from optillm.inference import create_inference_client
        api_key = os.environ.get("OPTILLM_API_KEY")
        default_client = create_inference_client()
    elif os.environ.get("CEREBRAS_API_KEY"):
        api_key = os.environ.get("CEREBRAS_API_KEY")
        base_url = server_config['base_url']
        if base_url != "":
            default_client = Cerebras(api_key=api_key, base_url=base_url)
        else:
            default_client = Cerebras(api_key=api_key)
    elif os.environ.get("OPENAI_API_KEY"):
        api_key = os.environ.get("OPENAI_API_KEY")
        base_url = server_config['base_url']
        if base_url != "":
            default_client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            default_client = OpenAI(api_key=api_key)
    elif os.environ.get("AZURE_OPENAI_API_KEY"):
        api_key = os.environ.get("AZURE_OPENAI_API_KEY")
        # This branch is entered only when AZURE_OPENAI_API_KEY is non-empty,
        # so key-based auth is always used. A DefaultAzureCredential
        # (azure.identity) fallback used to sit behind an
        # `API_KEY is not None` check that could never be false here. It was
        # unreachable and has been removed.
        # NOTE(review): Entra ID / managed-identity auth would need its own
        # explicit trigger condition to become reachable.
        default_client = AzureOpenAI(
            api_key=api_key,
            api_version=os.environ.get("AZURE_API_VERSION"),
            azure_endpoint=os.environ.get("AZURE_API_BASE"),
        )
    else:
        # No provider key configured: fall back to the LiteLLM wrapper.
        from optillm.litellm_wrapper import LiteLLMWrapper
        default_client = LiteLLMWrapper()
    return default_client, api_key
# Server configuration. These are built-in defaults; main() overwrites them
# with the parsed CLI/env values via server_config.update(vars(args)).
server_config = dict(
    approach='none',             # used when a request names no approach
    mcts_simulations=2,          # MCTS knobs; a request body may override them
    mcts_exploration=0.2,
    mcts_depth=1,
    best_of_n=3,                 # sample count for the 'bon' approach
    model='gpt-4o-mini',         # used when a request names no model
    rstar_max_depth=3,           # rStar search parameters
    rstar_num_rollouts=5,
    rstar_c=1.4,
    n=1,                         # responses per request unless the request sets n
    base_url='',                 # '' -> use the client's default endpoint
    optillm_api_key='',          # '' -> proxy authentication disabled
    return_full_response=False,  # forwarded to cot_reflection
    port=8000,
    log='info',
)

# Built-in approach names dispatched directly by execute_single_approach().
known_approaches = (
    "none mcts bon moa rto z3 self_consistency "
    "pvg rstar cot_reflection plansearch leap re2 cepo"
).split()

# Plugin SLUG -> plugin `run` callable; filled in place by load_plugins().
plugin_approaches = dict()
def normalize_message_content(messages):
    """Return shallow copies of ``messages`` with list-form content flattened to text.

    When a message's ``content`` is a list of parts, the ``text`` of every dict
    part whose ``type`` is ``'text'`` is joined with single spaces; other parts
    are dropped. Messages whose content is not a list are copied unchanged.
    The input messages are never mutated. Some models don't handle list-format
    content correctly, hence this step.
    """
    def _flatten(message):
        flattened = message.copy()
        body = message.get('content', '')
        if isinstance(body, list):
            texts = [
                part.get('text', '')
                for part in body
                if isinstance(part, dict) and part.get('type') == 'text'
            ]
            flattened['content'] = ' '.join(texts)
        return flattened

    return [_flatten(message) for message in messages]
def none_approach(
    client: Any,
    model: str,
    original_messages: List[Dict[str, str]],
    **kwargs
) -> Dict[str, Any]:
    """Direct proxy: forward the request to the upstream endpoint unchanged.

    The only changes are stripping a ``'none-'`` model prefix and flattening
    list-form message content to strings via ``normalize_message_content``.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model: Model identifier; a leading ``'none-'`` prefix is removed.
        original_messages: Messages from the incoming request.
        **kwargs: Additional completion parameters, passed through verbatim.

    Returns:
        The upstream response converted with ``model_dump()`` when the object
        provides it; otherwise the client's return value as-is.

    Raises:
        Any exception from the client call is logged and re-raised.
    """
    # Strip the 'none-' routing prefix if present.
    model = model.removeprefix('none-')
    try:
        # Normalize message content so it is always a string.
        normalized_messages = normalize_message_content(original_messages)
        response = client.chat.completions.create(
            model=model,
            messages=normalized_messages,
            **kwargs
        )
        # SDK response objects are converted to plain dicts for JSON output.
        if hasattr(response, 'model_dump'):
            return response.model_dump()
        return response
    except Exception as e:
        logger.error(f"Error in none approach: {str(e)}")
        raise
def load_plugins():
    """Discover plugin modules and register them in ``plugin_approaches``.

    A plugin is a ``*.py`` file defining a ``SLUG`` and a ``run`` callable.
    Two directories are scanned in this order:

    1. ``plugins`` inside the installed ``optillm`` package;
    2. ``<base>/optillm/plugins``, where ``<base>`` is
       ``server_config['plugins_dir']`` if non-empty, else the current working
       directory. It is skipped when it is the same path as (1).

    A later plugin with the same SLUG overwrites an earlier one, so local
    plugins override packaged ones. An error while loading one plugin is
    logged and does not stop the others from loading.

    NOTE: plugin files are executed as Python code. Only point ``plugins_dir``
    (or the working directory) at trusted locations.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly.
    import importlib.util
    import optillm

    # Clear existing plugins, mutating the global dict in place so other
    # references to it stay valid.
    plugin_approaches.clear()

    package_plugin_dir = os.path.join(os.path.dirname(optillm.__file__), 'plugins')
    current_dir = os.getcwd() if server_config.get("plugins_dir", "") == "" else server_config["plugins_dir"]
    local_plugin_dir = os.path.join(current_dir, 'optillm', 'plugins')

    # Package first, then local, so local SLUGs override packaged ones.
    plugin_dirs = [(package_plugin_dir, "package")]
    if local_plugin_dir != package_plugin_dir:
        plugin_dirs.append((local_plugin_dir, "local"))

    for plugin_dir, source in plugin_dirs:
        logger.info(f"Looking for {source} plugins in: {plugin_dir}")
        if not os.path.exists(plugin_dir):
            logger.debug(f"{source.capitalize()} plugin directory not found: {plugin_dir}")
            continue

        # glob order depends on the filesystem; sort for a deterministic load
        # (and therefore override) order.
        plugin_files = sorted(glob.glob(os.path.join(plugin_dir, '*.py')))
        if not plugin_files:
            logger.debug(f"No plugin files found in {source} directory: {plugin_dir}")
            continue

        logger.info(f"Found {source} plugin files: {plugin_files}")
        for plugin_file in plugin_files:
            try:
                module_name = os.path.splitext(os.path.basename(plugin_file))[0]
                spec = importlib.util.spec_from_file_location(module_name, plugin_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)

                if hasattr(module, 'SLUG') and hasattr(module, 'run'):
                    if module.SLUG in plugin_approaches:
                        logger.info(f"Overriding {source} plugin: {module.SLUG}")
                    plugin_approaches[module.SLUG] = module.run
                    logger.info(f"Loaded {source} plugin: {module.SLUG}")
                else:
                    logger.warning(f"Plugin {module_name} from {source} missing required attributes (SLUG and run)")
            except Exception as e:
                logger.error(f"Error loading {source} plugin {plugin_file}: {str(e)}")

    if not plugin_approaches:
        logger.warning("No plugins loaded from any location")
def get_config_path():
    """Locate the CePO YAML configuration file.

    A project-local ``optillm/cepo/configs/cepo_config.yaml`` is preferred. It
    is resolved under ``server_config['config_dir']`` when that is non-empty,
    else under the current working directory. It wins when it exists and is
    not the same path as the copy bundled with the installed package.
    Otherwise the bundled path is returned, whether or not it exists.
    """
    import optillm

    config_tail = os.path.join('cepo', 'configs', 'cepo_config.yaml')
    package_config_path = os.path.join(os.path.dirname(optillm.__file__), config_tail)

    configured_dir = server_config.get("config_dir", "")
    base_dir = configured_dir if configured_dir != "" else os.getcwd()
    local_config_path = os.path.join(base_dir, 'optillm', config_tail)

    if os.path.exists(local_config_path) and local_config_path != package_config_path:
        logger.debug(f"Using local config from: {local_config_path}")
        return local_config_path

    logger.debug(f"Using package config from: {package_config_path}")
    return package_config_path
def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
    """Split an ``approach[-approach…]-model`` string into its parts.

    Leading ``-``-separated segments are consumed as approaches while they are
    known approach names, plugin SLUGs, or contain ``&`` (operation ``'AND'``)
    or ``|`` (operation ``'OR'``). Everything from the first other segment
    onward is the model name. If no approach is found, ``['none']`` is used.
    The literal ``'auto'`` maps to ``('SINGLE', ['none'], 'auto')``.

    Returns:
        ``(operation, approaches, actual_model)``.
    """
    if model == 'auto':
        return 'SINGLE', ['none'], model

    def _is_approach_token(token):
        return (token in known_approaches or token in plugin_approaches
                or '&' in token or '|' in token)

    segments = model.split('-')
    cut = 0
    while cut < len(segments) and _is_approach_token(segments[cut]):
        cut += 1
    actual_model = '-'.join(segments[cut:])

    selected = []
    mode = 'SINGLE'
    for token in segments[:cut]:
        if token in known_approaches or token in plugin_approaches:
            selected.append(token)
        elif '&' in token:
            mode = 'AND'
            selected.extend(token.split('&'))
        else:
            mode = 'OR'
            selected.extend(token.split('|'))

    if not selected:
        return 'SINGLE', ['none'], actual_model
    return mode, selected, actual_model
def _none_approach_inputs(system_prompt, initial_query):
    """Return ``(messages, kwargs)`` to forward for the ``'none'`` approach.

    Inside a Flask request, the raw request body is forwarded: its
    ``messages`` plus every other field except ``model``, ``messages``,
    ``stream``, ``n`` and ``optillm_approach``. Outside a request context, a
    minimal conversation is rebuilt from ``system_prompt``/``initial_query``
    and no extra parameters are sent.
    """
    from flask import has_request_context

    # `request` is a context-local proxy. Touching it outside a request raises
    # RuntimeError, which `hasattr` does not catch, so check explicitly.
    if has_request_context():
        data = request.get_json()
        messages = data.get('messages', [])
        kwargs = {k: v for k, v in data.items()
                  if k not in ['model', 'messages', 'stream', 'n', 'optillm_approach']}
        return messages, kwargs

    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})
    messages.append({'role': 'user', 'content': initial_query})
    return messages, {}


def _run_plugin_approach(plugin_func, system_prompt, initial_query, client, model, request_config):
    """Invoke a plugin's ``run`` callable, which may be sync or async.

    ``request_config`` is passed only when the plugin's signature declares a
    ``request_config`` parameter, so legacy plugins keep working.
    """
    import inspect

    extra = {}
    if 'request_config' in inspect.signature(plugin_func).parameters:
        extra['request_config'] = request_config

    if inspect.iscoroutinefunction(plugin_func):
        # asyncio.run creates a fresh loop, always closes it (even on error)
        # and does not leave a closed loop installed as the thread's current
        # event loop.
        return asyncio.run(plugin_func(system_prompt, initial_query, client, model, **extra))
    return plugin_func(system_prompt, initial_query, client, model, **extra)


def execute_single_approach(approach, system_prompt, initial_query, client, model, request_config: dict = None):
    """Run one named approach.

    Args:
        approach: A built-in name from ``known_approaches`` or a plugin SLUG
            registered in ``plugin_approaches``.
        system_prompt: System prompt text.
        initial_query: The flattened conversation. Inside an ``AND`` pipeline,
            this is the previous approach's output instead.
        client: Upstream OpenAI-compatible client.
        model: Model identifier passed to the approach.
        request_config: Per-request settings. Forwarded only to plugins whose
            ``run`` signature declares ``request_config``.

    Returns:
        The approach's result. Callers unpack it as ``(response, tokens)``.
        For ``'none'`` this is the upstream response paired with 0.

    Raises:
        ValueError: if ``approach`` is neither built-in nor a loaded plugin.
    """
    if approach in known_approaches:
        if approach == 'none':
            messages, kwargs = _none_approach_inputs(system_prompt, initial_query)
            response = none_approach(original_messages=messages, client=client, model=model, **kwargs)
            # The upstream response already carries its own token usage, so no
            # extra tokens are reported here.
            return response, 0
        elif approach == 'mcts':
            return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
                                  server_config['mcts_exploration'], server_config['mcts_depth'])
        elif approach == 'bon':
            return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
        elif approach == 'moa':
            return mixture_of_agents(system_prompt, initial_query, client, model)
        elif approach == 'rto':
            return round_trip_optimization(system_prompt, initial_query, client, model)
        elif approach == 'z3':
            z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
            return z3_solver.process_query(initial_query)
        elif approach == "self_consistency":
            return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
        elif approach == "pvg":
            return inference_time_pv_game(system_prompt, initial_query, client, model)
        elif approach == "rstar":
            rstar = RStar(system_prompt, client, model,
                          max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
                          c=server_config['rstar_c'])
            return rstar.solve(initial_query)
        elif approach == "cot_reflection":
            return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
        elif approach == 'plansearch':
            return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
        elif approach == 'leap':
            return leap(system_prompt, initial_query, client, model)
        elif approach == 're2':
            return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
        elif approach == 'cepo':
            # cepo_config is a module-level global assigned in main().
            return cepo(system_prompt, initial_query, client, model, cepo_config)
    elif approach in plugin_approaches:
        return _run_plugin_approach(plugin_approaches[approach], system_prompt, initial_query,
                                    client, model, request_config)
    else:
        raise ValueError(f"Unknown approach: {approach}")
def execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None):
    """Run approaches one after another (the ``AND`` operation).

    Each approach receives the previous approach's response as its query.
    Returns the last response together with the summed token counts.
    """
    current = initial_query
    token_sum = 0
    for step in approaches:
        current, used = execute_single_approach(step, system_prompt, current, client, model, request_config)
        token_sum += used
    return current, token_sum
async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model, request_config: dict = None):
    """Run every approach concurrently on the same query (the ``OR`` operation).

    Each approach executes in a worker thread via ``asyncio.to_thread``.
    Returns the list of responses, in the order of ``approaches``, and the
    summed token counts.
    """
    outcomes = await asyncio.gather(*(
        asyncio.to_thread(execute_single_approach, name, system_prompt, initial_query, client, model, request_config)
        for name in approaches
    ))
    responses, token_counts = zip(*outcomes)
    return list(responses), sum(token_counts)
def execute_n_times(n: int, approaches, operation: str, system_prompt: str, initial_query: str, client: Any, model: str,
                    request_config: dict = None) -> Tuple[Union[str, List[str]], int]:
    """Execute the approach pipeline ``n`` times and collect the responses.

    Args:
        n (int): Number of times to run the pipeline.
        approaches (list): Approaches to execute.
        operation (str): ``'SINGLE'`` (first approach only), ``'AND'``
            (sequential pipeline) or ``'OR'`` (all approaches in parallel).
        system_prompt (str): System prompt.
        initial_query (str): Initial query.
        client: OpenAI-compatible client instance.
        model (str): Model identifier.
        request_config (dict, optional): Per-request settings forwarded to
            every approach.

    Returns:
        Tuple[Union[str, List[str]], int]: The single response when ``n == 1``
        and exactly one response was produced; otherwise a flat list of all
        responses. List results (e.g. from ``'OR'``) are flattened into it.
        The second element is the total token count.

    Raises:
        ValueError: for an unknown ``operation``.
    """
    responses = []
    total_tokens = 0
    for _ in range(n):
        if operation == 'SINGLE':
            response, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
        elif operation == 'AND':
            response, tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model, request_config)
        elif operation == 'OR':
            # asyncio.run always closes its loop, even if an approach raises,
            # and does not leave a closed loop installed as the current loop.
            response, tokens = asyncio.run(
                execute_parallel_approaches(approaches, system_prompt, initial_query, client, model, request_config)
            )
        else:
            raise ValueError(f"Unknown operation: {operation}")

        # OR yields a list per run: flatten it. Otherwise append the single response.
        if isinstance(response, list):
            responses.extend(response)
        else:
            responses.append(response)
        total_tokens += tokens

    if n == 1 and len(responses) == 1:
        return responses[0], total_tokens
    return responses, total_tokens
def generate_streaming_response(final_response, model):
    """Yield an already-complete response as Server-Sent Events chunks.

    A list produces one ``data:`` chunk per element, with ``index`` set to the
    element's position. Any other value produces a single chunk at index 0.
    The stream always ends with ``data: [DONE]``.
    """
    items = final_response if isinstance(final_response, list) else [final_response]
    for position, text in enumerate(items):
        payload = {
            "choices": [{"delta": {"content": text}, "index": position, "finish_reason": "stop"}],
            "model": model,
        }
        yield f"data: {json.dumps(payload)}\n\n"
    # Sentinel telling the client the stream has ended.
    yield "data: [DONE]\n\n"
def extract_contents(response_obj):
    """Collect the first choice's message content from one or more responses.

    Accepts a single response dict or a list of them. Returns a list of
    content strings; a response with no choices, no message, or empty content
    contributes nothing.
    """
    batch = response_obj if isinstance(response_obj, list) else [response_obj]
    found = []
    for item in batch:
        choices = item.get('choices')
        if not choices:
            continue
        first_message = choices[0].get('message')
        if not first_message:
            continue
        text = first_message.get('content')
        if text:
            found.append(text)
    return found
def parse_conversation(messages):
    """Flatten chat messages into a system prompt and a tagged transcript.

    Args:
        messages: Chat messages, each a dict with a ``role`` key and a
            ``content`` that is a string, a list of content parts, or
            missing/``None`` (treated as empty text).

    Returns:
        ``(system_prompt, initial_query, optillm_approach)``:

        - ``initial_query`` joins user/assistant turns as ``"User: …"`` /
          ``"Assistant: …"`` lines.
        - ``optillm_approach`` comes from an ``<optillm_approach>`` tag. A tag
          in a system message takes precedence; otherwise the first tagged
          user message is used. It is ``None`` if no tag is found.
        - Tags are removed from system and user text. Messages with other
          roles are ignored.
    """
    system_prompt = ""
    conversation = []
    optillm_approach = None

    for message in messages:
        role = message['role']
        content = message.get('content')

        if isinstance(content, list):
            # Keep only the text parts; tolerate text parts without 'text'.
            text_content = ' '.join(
                item.get('text', '') for item in content
                if isinstance(item, dict) and item.get('type') == 'text'
            )
        elif content is None:
            text_content = ""
        else:
            text_content = content

        if role == 'system':
            system_prompt, system_approach = extract_optillm_approach(text_content)
            # A system message without a tag must not clear an approach that
            # an earlier message already selected.
            if system_approach:
                optillm_approach = system_approach
        elif role == 'user':
            # Always strip the tag so it never leaks into the prompt, but only
            # adopt it if no approach has been chosen yet.
            text_content, user_approach = extract_optillm_approach(text_content)
            if not optillm_approach:
                optillm_approach = user_approach
            conversation.append(f"User: {text_content}")
        elif role == 'assistant':
            conversation.append(f"Assistant: {text_content}")

    initial_query = "\n".join(conversation)
    return system_prompt, initial_query, optillm_approach
def tagged_conversation_to_messages(response_text):
    """Convert a tagged conversation string (or list of strings) into messages.

    Text containing ``User:`` / ``Assistant:`` markers is split into one
    ``{'role': ..., 'content': ...}`` dict per marker. Any text before the
    first marker is discarded. Inputs without markers are returned unchanged.

    Args:
        response_text: A string, or a list of strings.

    Returns:
        For a string: a list of message dicts if it had markers, otherwise the
        original string. For a list: the original list if no element had
        markers; otherwise a list where each tagged element is replaced by its
        message list and untagged elements stay strings.
    """
    def has_conversation_tags(text):
        return "User:" in text or "Assistant:" in text

    def process_single_response(text):
        if not has_conversation_tags(text):
            return text

        messages = []
        # Split *before* each marker. The lookahead must not contain a
        # capturing group: re.split returns captured groups as extra list
        # items, which would add an empty duplicate message for every marker.
        for part in re.split(r'(?=User:|Assistant:)', text.strip()):
            part = part.strip()
            if part.startswith('User:'):
                messages.append({'role': 'user', 'content': part[len('User:'):].strip()})
            elif part.startswith('Assistant:'):
                messages.append({'role': 'assistant', 'content': part[len('Assistant:'):].strip()})
        return messages

    if isinstance(response_text, list):
        processed = [process_single_response(text) for text in response_text]
        # If none of the responses had tags, return the original list.
        if all(isinstance(p, str) for p in processed):
            return response_text
        return processed
    return process_single_response(response_text)
def extract_optillm_approach(content):
    """Pull an inline ``<optillm_approach>…</optillm_approach>`` tag out of text.

    Returns:
        ``(content_without_tags, approach)``. The approach is the text of the
        first tag with surrounding whitespace removed, so ``<optillm_approach>
        moa </optillm_approach>`` yields ``'moa'``. All tags are removed and
        the remaining content is stripped. Without a tag, returns
        ``(content, None)`` with ``content`` untouched.
    """
    match = re.search(r'<optillm_approach>(.*?)</optillm_approach>', content)
    if match is None:
        return content, None
    approach = match.group(1).strip()
    content = re.sub(r'<optillm_approach>.*?</optillm_approach>', '', content).strip()
    return content, approach
# Optional API key configuration to secure the proxy
@app.before_request
def check_api_key():
    """Reject requests that lack the configured proxy API key.

    Active only when ``server_config['optillm_api_key']`` is non-empty;
    ``/health`` is always allowed. Expects ``Authorization: Bearer <key>``.
    Returns a 401 JSON error for a missing/malformed header or a wrong key.
    Otherwise returns ``None`` so Flask continues to the route handler.
    """
    expected_key = server_config['optillm_api_key']
    if not expected_key or request.path == "/health":
        return None

    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Invalid Authorization header. Expected format: 'Authorization: Bearer YOUR_API_KEY'"}), 401

    client_key = auth_header.split('Bearer ', 1)[1].strip()
    # Compare as UTF-8 bytes. secrets.compare_digest raises TypeError for str
    # arguments containing non-ASCII characters, which would otherwise turn a
    # bad key into a 500 instead of a 401.
    if not secrets.compare_digest(client_key.encode('utf-8'), expected_key.encode('utf-8')):
        return jsonify({"error": "Invalid API key"}), 401
    return None
def _unwrap_tagged_response(response):
    """Reduce tagged-conversation output to the final message content.

    If a response contains ``User:``/``Assistant:`` markers, only the content
    of its last parsed message is kept. Responses without markers are
    returned unchanged. Works on a single response or a list of them.
    """
    if isinstance(response, list):
        processed = tagged_conversation_to_messages(response)
        if processed != response:  # at least one element had markers
            return [msg[-1]['content'] if isinstance(msg, list) and msg else msg
                    for msg in processed]
        return response
    messages = tagged_conversation_to_messages(response)
    if isinstance(messages, list) and messages:
        return messages[-1]['content']
    return response


def _build_completion_payload(response, model, completion_tokens):
    """Wrap one response (or a list of them) in an OpenAI-style chat completion dict."""
    contents = response if isinstance(response, list) else [response]
    return {
        'model': model,
        'choices': [
            {
                'index': index,
                'message': {
                    'role': 'assistant',
                    'content': content,
                },
                'finish_reason': 'stop',
            }
            for index, content in enumerate(contents)
        ],
        'usage': {
            'completion_tokens': completion_tokens,
        },
    }


@app.route('/v1/chat/completions', methods=['POST'])
def proxy():
    """Handle an OpenAI-compatible ``/v1/chat/completions`` request.

    The approach is taken from an ``<optillm_approach>`` tag in the messages,
    else the ``optillm_approach`` body field, else
    ``server_config['approach']``. Unless it is ``"auto"``, it is prefixed
    onto the model name, which ``parse_combined_approach`` then splits back
    into ``(operation, approaches, model)``. A ``Bearer sk-…`` token replaces
    the configured API key and is used with a plain ``OpenAI`` client.

    Returns:
        A JSON chat completion, or an SSE stream when ``stream`` is true.
        Returns 400 if the body is not a JSON object, and 500 with
        ``{"error": …}`` if an approach fails.
    """
    logger.info('Received request to /v1/chat/completions')
    data = request.get_json()
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    auth_header = request.headers.get("Authorization")
    bearer_token = ""
    if auth_header and auth_header.startswith("Bearer "):
        bearer_token = auth_header.split("Bearer ", 1)[1].strip()
        # Deliberately do not log the token value: it is a credential.
        logger.debug("Intercepted Bearer token in Authorization header")
    logger.debug(f'Request data: {data}')

    stream = data.get('stream', False)
    messages = data.get('messages', [])
    model = data.get('model', server_config['model'])
    n = data.get('n', server_config['n'])  # n from the request, else the server default
    response_format = data.get("response_format", None)

    # Forward every field not handled explicitly to the approaches, then add
    # back the explicitly handled ones with their resolved values.
    explicit_keys = {'stream', 'messages', 'model', 'n', 'response_format'}
    request_config = {k: v for k, v in data.items() if k not in explicit_keys}
    request_config.update({
        "stream": stream,
        "n": n,
        "response_format": response_format,
    })

    optillm_approach = data.get('optillm_approach', server_config['approach'])

    # NOTE(review): these per-request MCTS overrides are written into the
    # module-level server_config, so they persist into later requests and are
    # shared between concurrent ones.
    server_config['mcts_depth'] = data.get('mcts_depth', server_config['mcts_depth'])
    server_config['mcts_exploration'] = data.get('mcts_exploration', server_config['mcts_exploration'])
    server_config['mcts_simulations'] = data.get('mcts_simulations', server_config['mcts_simulations'])

    system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)
    if message_optillm_approach:
        optillm_approach = message_optillm_approach
    if optillm_approach != "auto":
        model = f"{optillm_approach}-{model}"

    base_url = server_config['base_url']
    default_client, api_key = get_config()

    operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
    logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')

    if bearer_token != "" and bearer_token.startswith("sk-"):
        api_key = bearer_token
        if base_url != "":
            client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            client = OpenAI(api_key=api_key)
    else:
        client = default_client

    try:
        contains_none = any(approach == 'none' for approach in approaches)

        if operation == 'SINGLE' and approaches[0] == 'none':
            # Direct proxy. With n > 1, make n separate upstream calls.
            if n > 1:
                result = []
                completion_tokens = 0
                for _ in range(n):
                    single, tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)
                    result.append(single)
                    completion_tokens += tokens
            else:
                result, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model, request_config)

            logger.debug(f'Direct proxy response: {result}')
            if stream:
                return Response(generate_streaming_response(extract_contents(result), model), content_type='text/event-stream')
            return jsonify(result), 200

        if operation in ('AND', 'OR') and contains_none:
            raise ValueError("'none' approach cannot be combined with other approaches")

        # Non-'none' approaches, run n times.
        response, completion_tokens = execute_n_times(n, approaches, operation, system_prompt, initial_query, client, model, request_config)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({"error": str(e)}), 500

    response = _unwrap_tagged_response(response)

    if stream:
        return Response(generate_streaming_response(response, model), content_type='text/event-stream')

    response_data = _build_completion_payload(response, model, completion_tokens)
    logger.debug(f'API response: {response_data}')
    return jsonify(response_data), 200
@app.route('/v1/models', methods=['GET'])
def proxy_models():
    """List the models available from the upstream provider.

    Uses a plain ``OpenAI`` client against ``server_config['base_url']`` when
    one is configured, otherwise the default client from ``get_config()``.
    Any failure, including client construction, returns a 500 JSON error.
    """
    logger.info('Received request to /v1/models')
    try:
        default_client, api_key = get_config()
        if server_config['base_url']:
            client = OpenAI(api_key=api_key, base_url=server_config['base_url'])
        else:
            client = default_client

        models_response = client.models.list()
        logger.debug('Models retrieved successfully')
        # Flask views can only return dicts/lists/strings/Response objects.
        # An SDK result object is converted via model_dump() when it provides
        # one, matching how none_approach handles completions.
        if hasattr(models_response, 'model_dump'):
            models_response = models_response.model_dump()
        return models_response, 200
    except Exception as e:
        logger.error(f"Error fetching models: {str(e)}")
        return jsonify({"error": f"Error fetching models: {str(e)}"}), 500
@app.route('/health', methods=['GET'])
def health():
    """Liveness endpoint; check_api_key lets /health through without a key."""
    payload = {"status": "ok"}
    return jsonify(payload), 200
def parse_args():
    """Build and parse the optillm command-line interface.

    Each option listed in ``args_env`` takes its default from the named
    environment variable when that variable is set, else from the hard-coded
    value. CePO options are generated from the fields of ``CepoConfig``.
    Boolean options accept ``true``/``1``/``yes`` (case-insensitive) as True,
    any other value as False, and a bare flag as True.

    Returns:
        argparse.Namespace with underscore-separated attribute names.
    """
    parser = argparse.ArgumentParser(description="Run LLM inference with various approaches.")

    try:
        from optillm import __version__ as package_version
    except ImportError:
        package_version = "unknown"

    parser.add_argument('--version', action='version',
                        version=f'%(prog)s {package_version}',
                        help="Show program's version number and exit")

    def _to_bool(value):
        # Same truthy spellings for CLI values and environment variables.
        return str(value).lower() in ('true', '1', 'yes')

    # (flag, env var, type, default, help[, choices])
    # "auto" is the default approach, so it must also be an accepted choice.
    # NOTE: main() calls load_plugins() after parse_args(), so
    # plugin_approaches is still empty here and plugin SLUGs are not yet valid
    # --approach choices.
    args_env = [
        ("--optillm-api-key", "OPTILLM_API_KEY", str, "", "Optional API key for client authentication to optillm"),
        ("--approach", "OPTILLM_APPROACH", str, "auto", "Inference approach to use", ["auto"] + known_approaches + list(plugin_approaches.keys())),
        ("--mcts-simulations", "OPTILLM_SIMULATIONS", int, 2, "Number of MCTS simulations"),
        ("--mcts-exploration", "OPTILLM_EXPLORATION", float, 0.2, "Exploration weight for MCTS"),
        ("--mcts-depth", "OPTILLM_DEPTH", int, 1, "Simulation depth for MCTS"),
        ("--model", "OPTILLM_MODEL", str, "gpt-4o-mini", "OpenAI model to use"),
        ("--rstar-max-depth", "OPTILLM_RSTAR_MAX_DEPTH", int, 3, "Maximum depth for rStar algorithm"),
        ("--rstar-num-rollouts", "OPTILLM_RSTAR_NUM_ROLLOUTS", int, 5, "Number of rollouts for rStar algorithm"),
        ("--rstar-c", "OPTILLM_RSTAR_C", float, 1.4, "Exploration constant for rStar algorithm"),
        ("--n", "OPTILLM_N", int, 1, "Number of final responses to be returned"),
        ("--return-full-response", "OPTILLM_RETURN_FULL_RESPONSE", bool, False, "Return the full response including the CoT with <thinking> tags"),
        ("--port", "OPTILLM_PORT", int, 8000, "Specify the port to run the proxy"),
        ("--log", "OPTILLM_LOG", str, "info", "Specify the logging level", list(logging_levels.keys())),
        ("--launch-gui", "OPTILLM_LAUNCH_GUI", bool, False, "Launch a Gradio chat interface"),
        ("--plugins-dir", "OPTILLM_PLUGINS_DIR", str, "", "Path to the plugins directory"),
    ]

    for arg, env, type_, default, help_text, *extra in args_env:
        env_value = os.environ.get(env)
        if env_value is not None:
            default = _to_bool(env_value) if type_ == bool else type_(env_value)

        options = {'type': type_, 'default': default, 'help': help_text}
        if type_ == bool:
            # argparse's `type=bool` maps every non-empty string (even
            # "False") to True, so parse the text instead. nargs='?' plus
            # const=True also accepts the bare flag, e.g. `--launch-gui`.
            options.update(type=_to_bool, nargs='?', const=True)
        if extra and extra[0]:  # choices provided for this argument
            options['choices'] = extra[0]
        parser.add_argument(arg, **options)

    # best_of_n accepts both dash and underscore spellings.
    best_of_n_default = int(os.environ.get("OPTILLM_BEST_OF_N", 3))
    parser.add_argument("--best-of-n", "--best_of_n", dest="best_of_n", type=int, default=best_of_n_default,
                        help="Number of samples for best_of_n approach")

    # base_url accepts both dash and underscore spellings.
    base_url_default = os.environ.get("OPTILLM_BASE_URL", "")
    parser.add_argument("--base-url", "--base_url", dest="base_url", type=str, default=base_url_default,
                        help="Base url for OpenAI compatible endpoint")

    default_config_path = get_config_path()

    # One option per CePO configuration field.
    for field in fields(CepoConfig):
        parser.add_argument(f"--cepo_{field.name}",
                            dest=f"cepo_{field.name}",
                            type=field.type,
                            default=None,
                            help=f"CePO configuration for {field.name}")
    parser.add_argument("--cepo_config_file",
                        dest="cepo_config_file",
                        type=str,
                        default=default_config_path,
                        help="Path to CePO configuration file")

    args = parser.parse_args()

    # Normalize attribute names to match server_config keys.
    args_dict = vars(args)
    for key in list(args_dict.keys()):
        new_key = key.replace("-", "_")
        if new_key != key:
            args_dict[new_key] = args_dict.pop(key)

    return args
def main():
    """Command-line entry point: configure and start the optillm proxy server.

    Parses CLI/env options into ``server_config``, loads plugins, applies the
    log level and initializes the global ``cepo_config``. It then serves the
    Flask app on ``0.0.0.0:<port>``. With ``--launch-gui``, the API server
    runs in a daemon thread and a Gradio chat UI pointed at it is launched.
    """
    global server_config
    global cepo_config

    args = parse_args()
    server_config.update(vars(args))
    load_plugins()
    port = server_config['port']

    # Apply the requested log level to this module's logger.
    logging_level = server_config['log']
    if logging_level in logging_levels:
        logger.setLevel(logging_levels[logging_level])

    cepo_config = init_cepo_config(server_config)
    if args.approach == 'cepo':
        logger.info(f"CePO Config: {cepo_config}")

    logger.info(f"Starting server with approach: {server_config['approach']}")
    # Never log the proxy API key itself.
    server_config_clean = server_config.copy()
    if server_config_clean['optillm_api_key']:
        server_config_clean['optillm_api_key'] = '[REDACTED]'
    logger.info(f"Server configuration: {server_config_clean}")

    if server_config.get('launch_gui'):
        try:
            import gradio as gr
            import threading

            # Serve the API in the background so the Gradio UI can use it.
            server_thread = threading.Thread(target=app.run, kwargs={'host': '0.0.0.0', 'port': port})
            server_thread.daemon = True
            server_thread.start()

            base_url = f"http://localhost:{port}/v1"
            logger.info(f"Launching Gradio interface connected to {base_url}")
            demo = gr.load_chat(
                base_url,
                model=server_config['model'],
                token=None
            )
            demo.launch(server_name="0.0.0.0", share=False)
        except ImportError:
            logger.error("Gradio is required for GUI. Install it with: pip install gradio")
            return
        # server_thread already owns `port`. Falling through to app.run() below
        # would try to bind the same port a second time, so keep serving on
        # the existing thread instead.
        server_thread.join()
        return

    app.run(host='0.0.0.0', port=port)
# Start the server when this module is run directly as a script.
if __name__ == "__main__":
    main()