Fix strategy selection

This commit is contained in:
Asankhaya Sharma
2025-05-13 09:36:20 +08:00
parent bfb702a7aa
commit 02de0928fc
3 changed files with 272 additions and 77 deletions

View File

@@ -594,12 +594,18 @@ def proxy():
# Extract response_format if present
response_format = data.get("response_format", None)
# Create request config with all parameters
request_config = {
# Explicit keys that we are already handling
explicit_keys = {'stream', 'messages', 'model', 'n', 'response_format'}
# Copy the rest into request_config
request_config = {k: v for k, v in data.items() if k not in explicit_keys}
# Add the explicitly handled ones
request_config.update({
"stream": stream,
"n": n,
"response_format": response_format # Add response_format to config
}
"response_format": response_format, # Add response_format to config
})
optillm_approach = data.get('optillm_approach', server_config['approach'])
logger.debug(data)

View File

@@ -40,6 +40,16 @@ STRATEGY_METRICS_PATH = os.path.join(STRATEGY_DB_DIR, 'metrics.json')
# Default max tokens for reasoning LLMs
DEFAULT_MAX_TOKENS = 4096
# How often to perform maintenance operations (merge, limit per type, prune):
# maintenance runs when total_queries % MAINTENANCE_INTERVAL == 0,
# i.e. every 10 queries (previously every 100).
MAINTENANCE_INTERVAL = 10
# Similarity thresholds (TF-IDF cosine similarity).
# Creation: a query whose similarity to an existing strategy of the same
# problem type is >= this value reuses that strategy; below it, a new
# strategy is created.
STRATEGY_CREATION_THRESHOLD = 0.7
# Merging: default similarity_threshold for merge_similar_strategies().
# Lowered from 0.8 so that more strategy pairs qualify for merging
# (presumably pairs at or above this value are merged -- confirm in that method).
STRATEGY_MERGING_THRESHOLD = 0.6
# Maximum strategies kept per problem type (reduced from an implicit 3)
MAX_STRATEGIES_PER_TYPE = 2
# Ensure data directory exists
os.makedirs(STRATEGY_DB_DIR, exist_ok=True)
@@ -159,6 +169,11 @@ class Strategy:
if len(self.reasoning_examples) >= 5:
self.reasoning_examples.pop(0) # Remove oldest example
self.reasoning_examples.append(reasoning.strip())
def add_example(self, example: str) -> None:
    """Record a problem example for this strategy.

    The example is stripped of surrounding whitespace before it is stored,
    in the same way reasoning examples are stripped before being stored.
    Empty, whitespace-only, and already-recorded examples are ignored.

    Args:
        example: Text of the example to attach to the strategy.
    """
    if not example:
        return
    cleaned = example.strip()
    # Compare the normalized text so "foo" and "  foo " count as one example,
    # and so a whitespace-only string is never stored.
    if cleaned and cleaned not in self.examples:
        self.examples.append(cleaned)
class StrategyDatabase:
"""Manages a collection of problem-solving strategies."""
@@ -176,7 +191,8 @@ class StrategyDatabase:
"strategies_refined": 0,
"successful_resolutions": 0,
"last_strategy_id": 0,
"reasoning_examples_collected": 0
"reasoning_examples_collected": 0,
"strategies_merged": 0
}
self._load()
@@ -310,6 +326,13 @@ class StrategyDatabase:
self.metrics["reasoning_examples_collected"] += 1
self._save()
def add_example_to_strategy(self, strategy_id: str, example: str) -> None:
    """Attach an example to the strategy with the given ID and persist the change.

    Nothing happens (and nothing is saved) when the strategy cannot be found
    or the example is empty.

    Args:
        strategy_id: ID of the strategy to update.
        example: Example text to hand to the strategy's ``add_example``.
    """
    target = self.get_strategy_by_id(strategy_id)
    if not (target and example):
        return
    target.add_example(example)
    self._save()
def get_similar_strategies(self, query: str, n: int = 5) -> List[Tuple[Strategy, float]]:
"""Find strategies similar to a query using TF-IDF similarity."""
if not self.strategies:
@@ -336,6 +359,105 @@ class StrategyDatabase:
logger.error(f"Error finding similar strategies: {str(e)}")
return []
def find_similar_strategy(self, problem_type: str, query: str, threshold: float = STRATEGY_CREATION_THRESHOLD) -> Optional[Tuple[Strategy, float]]:
    """
    Look up the strategy of a given problem type whose text best matches a query.

    Similarity is the TF-IDF cosine similarity between the query and each
    candidate's ``strategy_text``.

    Args:
        problem_type: Only strategies with this problem type are considered
        query: Text to compare against the strategy texts
        threshold: Minimum similarity (inclusive) for the best candidate to count

    Returns:
        Optional[Tuple[Strategy, float]]: The best-matching strategy and its score,
        or None when there are no candidates, the best score is below the
        threshold, or the similarity computation fails (the error is logged)
    """
    if not self.strategies:
        return None

    # Restrict the search to strategies of the requested type
    candidates = [s for s in self.strategies if s.problem_type == problem_type]
    if not candidates:
        return None

    try:
        # Fit one vocabulary over all candidate texts plus the query (last row)
        corpus = [s.strategy_text for s in candidates]
        corpus.append(query)
        matrix = TfidfVectorizer(stop_words='english').fit_transform(corpus)

        # Score the query row against every candidate row
        scores = cosine_similarity(matrix[-1], matrix[:-1]).flatten()

        if len(scores) > 0:
            best = scores.argmax()
            best_score = scores[best]
            if best_score >= threshold:
                return (candidates[best], float(best_score))
    except Exception as e:
        logger.error(f"Error finding similar strategy: {str(e)}")

    return None
def find_similar_examples(self, problem_type: str, query: str, threshold: float = STRATEGY_CREATION_THRESHOLD) -> Optional[Tuple[Strategy, float]]:
    """
    Look up the strategy of a given problem type whose stored examples best match a query.

    For each candidate strategy, the query is compared (TF-IDF cosine
    similarity, vocabulary fitted per strategy) against that strategy's
    examples; the strategy's score is its best-matching example.

    Args:
        problem_type: Only strategies with this problem type are considered
        query: Text to compare against each strategy's examples
        threshold: Minimum similarity (inclusive) for the best strategy to count

    Returns:
        Optional[Tuple[Strategy, float]]: The strategy holding the most similar
        example and that similarity, or None when nothing reaches the threshold
        or the similarity computation fails (the error is logged)
    """
    if not self.strategies:
        return None

    # Restrict the search to strategies of the requested type
    candidates = [s for s in self.strategies if s.problem_type == problem_type]
    if not candidates:
        return None

    best_score = 0.0
    best_match = None

    try:
        for candidate in candidates:
            # Strategies without examples cannot be compared
            if not candidate.examples:
                continue

            # Query is appended last, so it is the final row of the matrix
            matrix = TfidfVectorizer(stop_words='english').fit_transform(candidate.examples + [query])
            scores = cosine_similarity(matrix[-1], matrix[:-1]).flatten()
            if len(scores) == 0:
                continue

            # A strategy is represented by its single closest example
            top = scores.max()
            if top > best_score:
                best_score, best_match = top, candidate

        if best_match and best_score >= threshold:
            return (best_match, float(best_score))
    except Exception as e:
        logger.error(f"Error finding similar examples: {str(e)}")

    return None
def get_next_strategy_id(self) -> str:
"""Generate a unique ID for a new strategy."""
self.metrics["last_strategy_id"] += 1
@@ -365,7 +487,7 @@ class StrategyDatabase:
self._save()
return pruned_count
def merge_similar_strategies(self, similarity_threshold: float = 0.8) -> int:
def merge_similar_strategies(self, similarity_threshold: float = STRATEGY_MERGING_THRESHOLD) -> int:
"""Merge strategies that are very similar to each other."""
if len(self.strategies) <= 1:
return 0
@@ -395,6 +517,8 @@ class StrategyDatabase:
# Remove the second strategy
self.strategies.pop(j)
merged_count += 1
self.metrics["strategies_merged"] += 1
logger.info(f"Merged strategies {self.strategies[i].strategy_id} and {merged_strategy.strategy_id} with similarity {similarity:.2f}")
else:
j += 1
else:
@@ -436,6 +560,63 @@ class StrategyDatabase:
merged.reasoning_examples = merged.reasoning_examples[-5:]
return merged
def limit_strategies_per_type(self, max_per_type: int = MAX_STRATEGIES_PER_TYPE) -> int:
    """
    Limit the number of strategies per problem type to the specified maximum.

    Within each problem type that exceeds the limit, strategies are ranked by
    a score of 70% success rate plus 30% recency and the lowest-ranked excess
    is dropped. Recency decays linearly from 1.0 (used today) to 0.0 (used 30
    or more days ago) and is always kept within [0.0, 1.0]; a strategy that
    was never used, or whose ``last_used`` timestamp cannot be interpreted,
    gets a recency of 0.0.

    When anything is removed, the vector cache is invalidated and the
    database is saved.

    Args:
        max_per_type: Maximum number of strategies to keep per problem type

    Returns:
        int: Number of strategies removed
    """
    # Group strategies by problem type
    strategies_by_type = {}
    for strategy in self.strategies:
        strategies_by_type.setdefault(strategy.problem_type, []).append(strategy)

    # Identity of every strategy selected for removal. Using id() keeps the
    # final filter O(n) and independent of how Strategy defines equality.
    removal_ids = set()

    for strategies in strategies_by_type.values():
        if len(strategies) <= max_per_type:
            continue

        scored_strategies = []
        for strategy in strategies:
            recency_score = 0.0
            if strategy.last_used:
                try:
                    last_used = datetime.fromisoformat(strategy.last_used)
                    # Use a "now" with the same awareness as the stored
                    # timestamp; mixing naive and aware datetimes raises.
                    days_since = (datetime.now(tz=last_used.tzinfo) - last_used).days
                except (TypeError, ValueError) as e:
                    # One bad timestamp must not abort the whole maintenance
                    # pass; rank the strategy as if it had never been used.
                    logger.warning(f"Ignoring invalid last_used for strategy {strategy.strategy_id}: {str(e)}")
                else:
                    # Clamp on both sides: a timestamp in the future (clock
                    # skew) would otherwise push the score above 1.0.
                    recency_score = min(1.0, max(0.0, 1.0 - days_since / 30.0))

            score = (0.7 * strategy.success_rate) + (0.3 * recency_score)
            scored_strategies.append((strategy, score))

        # Best first; everything past the limit is marked for removal
        scored_strategies.sort(key=lambda x: x[1], reverse=True)
        removal_ids.update(id(strategy) for strategy, _ in scored_strategies[max_per_type:])

    # Remove marked strategies
    initial_count = len(self.strategies)
    self.strategies = [s for s in self.strategies if id(s) not in removal_ids]
    removed_count = initial_count - len(self.strategies)

    if removed_count > 0:
        self.vectors = None  # Invalidate vector cache
        self._save()
        logger.info(f"Removed {removed_count} excess strategies to maintain max {max_per_type} per type")

    return removed_count
def extract_thinking(response: str) -> Tuple[str, Optional[str]]:
"""
@@ -627,9 +808,9 @@ def generate_strategy(problem: str, problem_type: str, client, model: str, db: S
examples=[problem]
)
def should_create_new_strategy(problem_type: str, query: str, existing_strategies: List[Strategy], db: StrategyDatabase) -> bool:
def should_create_new_strategy(problem_type: str, query: str, existing_strategies: List[Strategy], db: StrategyDatabase) -> Tuple[bool, Optional[Strategy]]:
"""
Determine whether to create a new strategy for a problem type that already has strategies.
Determine whether to create a new strategy or update an existing one.
Args:
problem_type: The type of problem
@@ -638,49 +819,44 @@ def should_create_new_strategy(problem_type: str, query: str, existing_strategie
db: Strategy database
Returns:
bool: True if a new strategy should be created
Tuple[bool, Optional[Strategy]]:
- Boolean indicating if a new strategy should be created
- The similar strategy to update (if any)
"""
# If there are no existing strategies, definitely create one
if not existing_strategies:
return True
return True, None
# Calculate the similarity of the query to the examples in existing strategies
max_similarity = 0.0
# Get all examples from existing strategies
all_examples = []
for strategy in existing_strategies:
all_examples.extend(strategy.examples)
if all_examples:
try:
# Vectorize examples and query
vectorizer = TfidfVectorizer(stop_words='english')
vectors = vectorizer.fit_transform(all_examples + [query])
# If we already have enough strategies for this problem type, check if the query
# is similar to any existing strategy
if len(existing_strategies) >= MAX_STRATEGIES_PER_TYPE:
# First, check similarity based on strategy text
similar_strategy_result = db.find_similar_strategy(problem_type, query)
if similar_strategy_result:
similar_strategy, similarity = similar_strategy_result
logger.info(f"Found similar strategy {similar_strategy.strategy_id} with text similarity {similarity:.2f}")
return False, similar_strategy
# Get similarity between query and each example
query_vector = vectors[-1]
example_vectors = vectors[:-1]
# Calculate similarities
similarities = cosine_similarity(query_vector, example_vectors).flatten()
# Get max similarity
if len(similarities) > 0:
max_similarity = similarities.max()
except Exception as e:
logger.error(f"Error calculating similarity: {str(e)}")
# Next, check similarity based on examples
similar_examples_result = db.find_similar_examples(problem_type, query)
if similar_examples_result:
similar_strategy, similarity = similar_examples_result
logger.info(f"Found strategy {similar_strategy.strategy_id} with similar examples, similarity {similarity:.2f}")
return False, similar_strategy
# If the query is very different from existing examples (low similarity),
# or if we have few strategies for this problem type, create a new one
if max_similarity < 0.5 or len(existing_strategies) < 3:
logger.info(f"Creating new strategy for {problem_type} (max similarity: {max_similarity:.2f}, existing strategies: {len(existing_strategies)})")
return True
# If we have fewer than the maximum allowed strategies for this type,
# check strategy similarity before creating a new one
similar_strategy_result = db.find_similar_strategy(problem_type, query, threshold=STRATEGY_CREATION_THRESHOLD)
if similar_strategy_result:
similar_strategy, similarity = similar_strategy_result
logger.info(f"Found similar strategy {similar_strategy.strategy_id} with text similarity {similarity:.2f}")
return False, similar_strategy
logger.info(f"Not creating new strategy for {problem_type} (max similarity: {max_similarity:.2f}, existing strategies: {len(existing_strategies)})")
return False
# If we get here, we should create a new strategy
logger.info(f"No similar strategy found for {problem_type}, creating a new one")
return True, None
def select_relevant_strategies(query: str, problem_type: str, db: StrategyDatabase, max_strategies: int = 3) -> List[Strategy]:
def select_relevant_strategies(query: str, problem_type: str, db: StrategyDatabase, max_strategies: int = MAX_STRATEGIES_PER_TYPE) -> List[Strategy]:
"""
Select the most relevant strategies for a given problem.
@@ -978,34 +1154,54 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
existing_strategies = db.get_strategies_for_problem(problem_type)
logger.info(f"Found {len(existing_strategies)} existing strategies for {problem_type}")
# 3. Determine if we need to create a new strategy
need_new_strategy = False
# 3. Determine if we need to create a new strategy or update an existing one
similar_strategy = None
if not existing_strategies:
# No strategies exist for this problem type
need_new_strategy = True
elif not inference_only:
# In learning mode, check if we should create a new strategy
need_new_strategy = should_create_new_strategy(problem_type, initial_query, existing_strategies, db)
if not inference_only:
# In learning mode, check if we should create a new strategy or update an existing one
should_create, similar_strategy = should_create_new_strategy(
problem_type,
initial_query,
existing_strategies,
db
)
if should_create:
# Create a new strategy
logger.info(f"Creating new strategy for {problem_type}")
new_strategy = generate_strategy(initial_query, problem_type, client, model, db)
db.add_strategy(new_strategy)
logger.info(f"Added new strategy with ID: {new_strategy.strategy_id}")
elif similar_strategy:
# Update existing strategy with new example
logger.info(f"Updating existing strategy {similar_strategy.strategy_id} with new example")
db.add_example_to_strategy(similar_strategy.strategy_id, initial_query)
# 4. Create a new strategy if needed
if need_new_strategy and not inference_only:
logger.info(f"Generating new strategy for {problem_type}")
# Pass the db instance to generate_strategy to ensure consistent ID generation
new_strategy = generate_strategy(initial_query, problem_type, client, model, db)
db.add_strategy(new_strategy)
logger.info(f"Added new strategy with ID: {new_strategy.strategy_id}")
# Make sure the new strategy is included in our list
existing_strategies.append(new_strategy)
# 4. Perform database maintenance (more frequently than before)
if not inference_only and db.metrics["total_queries"] % MAINTENANCE_INTERVAL == 0:
# 4.1 Merge similar strategies
merged_count = db.merge_similar_strategies(similarity_threshold=STRATEGY_MERGING_THRESHOLD)
logger.info(f"Merged {merged_count} similar strategies")
# 4.2 Limit strategies per problem type
limited_count = db.limit_strategies_per_type(max_per_type=MAX_STRATEGIES_PER_TYPE)
# 4.3 Prune low-performing strategies
pruned_count = db.prune_strategies()
logger.info(f"Pruned {pruned_count} low-performing strategies")
# 5. Select relevant strategies for this problem
# 5. Re-select strategies (in case the database changed in step 4)
existing_strategies = db.get_strategies_for_problem(problem_type)
# 6. Select relevant strategies for this problem
selected_strategies = select_relevant_strategies(initial_query, problem_type, db)
# Log the selected strategies
for i, strategy in enumerate(selected_strategies, 1):
logger.info(f"Selected strategy {i}: {strategy.strategy_id} (success rate: {strategy.success_rate:.2f})")
# 6. If no strategies selected, use fallback
# 7. If no strategies selected, use fallback
if not selected_strategies:
logger.info(f"No strategies selected, using fallback strategy")
fallback_strategy = Strategy(
@@ -1022,18 +1218,18 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
)
selected_strategies = [fallback_strategy]
# 7. Augment the system prompt with the selected strategies
# 8. Augment the system prompt with the selected strategies
augmented_prompt = augment_system_prompt(system_prompt, selected_strategies)
logger.info(f"Augmented system prompt with {len(selected_strategies)} strategies")
# 8. Forward the request to the LLM with the augmented prompt
# 9. Forward the request to the LLM with the augmented prompt
try:
# Create a copy of request_config without spl_inference_only
request_params = {}
if request_config:
request_params = {k: v for k, v in request_config.items() if k != 'spl_inference_only'}
# Ensure max_tokens is set to at least 4096 for reasoning LLMs
# Ensure max_tokens is set to at least DEFAULT_MAX_TOKENS for reasoning LLMs
if 'max_tokens' not in request_params:
request_params['max_tokens'] = DEFAULT_MAX_TOKENS
elif request_params['max_tokens'] < DEFAULT_MAX_TOKENS:
@@ -1061,7 +1257,7 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
# Only perform learning operations if not in inference-only mode
if not inference_only:
# 9. Evaluate the effectiveness of the strategies
# 10. Evaluate the effectiveness of the strategies
strategy_effectiveness = evaluate_strategy_effectiveness(
final_response,
thinking,
@@ -1070,7 +1266,7 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
model
)
# 10. Update strategy metrics based on effectiveness
# 11. Update strategy metrics based on effectiveness
for strategy_id, effective in strategy_effectiveness.items():
# Skip temporary fallback strategies
if strategy_id != "fallback_temporary":
@@ -1082,7 +1278,7 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
db.add_reasoning_example(strategy_id, thinking)
logger.info(f"Added reasoning example to strategy {strategy_id}")
# 11. Periodically refine strategies (after every 10 uses)
# 12. Periodically refine strategies (after every 10 uses)
for strategy in selected_strategies:
# Skip temporary fallback strategies
if (strategy.strategy_id != "fallback_temporary" and
@@ -1091,14 +1287,6 @@ def run(system_prompt: str, initial_query: str, client, model: str, request_conf
logger.info(f"Refining strategy {strategy.strategy_id} after {strategy.total_attempts} attempts")
refined_strategy = refine_strategy(strategy, initial_query, final_response, thinking, client, model)
db.refine_strategy(strategy.strategy_id, refined_strategy.strategy_text)
# 12. Periodically prune low-performing strategies and merge similar ones (after every 100 queries)
if db.metrics["total_queries"] % 100 == 0:
pruned_count = db.prune_strategies()
logger.info(f"Pruned {pruned_count} low-performing strategies")
merged_count = db.merge_similar_strategies()
logger.info(f"Merged {merged_count} similar strategies")
else:
logger.info("Skipping strategy evaluation and refinement in inference-only mode")

View File

@@ -40,7 +40,7 @@ def load_optillm_bench() -> datasets.Dataset:
"""Load the OptiLLM Bench dataset."""
try:
dataset = load_dataset("codelion/optillmbench")
return dataset["train"] # We use the test split for evaluation
return dataset["test"] # We use the test split for evaluation
except Exception as e:
logger.error(f"Error loading dataset: {e}")
raise
@@ -182,6 +182,7 @@ def evaluate_model(
],
temperature=0.2,
max_tokens=4096,
extra_body= {"spl_inference_only": True},
)
# Calculate time taken