fix edge cases

This commit is contained in:
Asankhaya Sharma
2025-05-16 10:05:18 +08:00
parent 7e2ebf9649
commit 6d6f8503cc
2 changed files with 22 additions and 6 deletions

View File

@@ -21,17 +21,24 @@ from optillm.plugins.spl.config import (
# Setup logging
logger = logging.getLogger(__name__)
def select_relevant_strategies(query: str, problem_type: str, db: Any, max_strategies: int = MAX_STRATEGIES_FOR_INFERENCE) -> List[Strategy]:
def select_relevant_strategies(query: str, problem_type: str, db: Any, learning_mode: bool = False, max_strategies: int = MAX_STRATEGIES_FOR_INFERENCE) -> List[Strategy]:
"""
Select the most relevant strategies for a given problem to be used during inference.
This controls how many strategies are included in the system prompt augmentation.
Only selects strategies of the matching problem type with success rate >= MIN_SUCCESS_RATE_FOR_INFERENCE.
When in inference mode (not learning_mode), only strategies with:
- A matching problem type
- Success rate >= MIN_SUCCESS_RATE_FOR_INFERENCE
- At least 5 attempts
are selected.
In learning mode, strategies with fewer attempts are also considered.
Args:
query: The problem/query text
problem_type: The type of problem
db: Strategy database
learning_mode: Whether we're in learning mode (affects filtering criteria)
max_strategies: Maximum number of strategies to return
Returns:
@@ -41,13 +48,22 @@ def select_relevant_strategies(query: str, problem_type: str, db: Any, max_strat
type_specific = db.get_strategies_for_problem(problem_type)
logger.info(f"Found {len(type_specific)} strategies for problem type '{problem_type}'")
# Filter strategies by minimum success rate
# Filter strategies by minimum success rate and attempts
qualified_strategies = []
for strategy in type_specific:
if strategy.success_rate >= MIN_SUCCESS_RATE_FOR_INFERENCE or strategy.total_attempts < 5:
# In learning mode, we're more lenient with new strategies
if learning_mode and strategy.total_attempts < 5:
logger.info(f"Strategy {strategy.strategy_id} included (learning mode - only {strategy.total_attempts} attempts so far)")
qualified_strategies.append(strategy)
# For inference or well-tested strategies, we require minimum success rate
elif strategy.success_rate >= MIN_SUCCESS_RATE_FOR_INFERENCE and strategy.total_attempts >= 5:
logger.info(f"Strategy {strategy.strategy_id} qualified - success rate {strategy.success_rate:.2f} >= minimum {MIN_SUCCESS_RATE_FOR_INFERENCE:.2f} with {strategy.total_attempts} attempts")
qualified_strategies.append(strategy)
else:
logger.info(f"Strategy {strategy.strategy_id} skipped - success rate {strategy.success_rate:.2f} < minimum {MIN_SUCCESS_RATE_FOR_INFERENCE:.2f}")
if strategy.total_attempts < 5:
logger.info(f"Strategy {strategy.strategy_id} skipped - insufficient attempts ({strategy.total_attempts} < 5) in inference mode")
else:
logger.info(f"Strategy {strategy.strategy_id} skipped - success rate {strategy.success_rate:.2f} < minimum {MIN_SUCCESS_RATE_FOR_INFERENCE:.2f}")
if not qualified_strategies:
logger.info(f"No strategies meet the minimum success rate threshold ({MIN_SUCCESS_RATE_FOR_INFERENCE:.2f}) for problem type '{problem_type}'")

View File

@@ -119,7 +119,7 @@ def run_spl(system_prompt: str, initial_query: str, client, model: str, request_
existing_strategies = db.get_strategies_for_problem(problem_type)
# 6. Select relevant strategies for this problem (using inference limit)
selected_strategies = select_relevant_strategies(initial_query, problem_type, db, MAX_STRATEGIES_FOR_INFERENCE)
selected_strategies = select_relevant_strategies(initial_query, problem_type, db, learning_mode, MAX_STRATEGIES_FOR_INFERENCE)
# Log the selected strategies
for i, strategy in enumerate(selected_strategies, 1):