init implementation

This commit is contained in:
Asankhaya Sharma
2025-05-07 10:38:40 +08:00
parent 2ab4e6e896
commit 23de9188de
11 changed files with 1452 additions and 1 deletion

View File

@@ -343,7 +343,7 @@ Check this log file for connection issues, tool execution errors, and other diag
| Approach | Slug | Description |
| ------------------------------------ | ------------------ | ---------------------------------------------------------------------------------------------- |
| Cerebras Planning and Optimization | `cepo` | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques |
| CoT with Reflection | `cot_reflection` | Implements chain-of-thought reasoning with \<thinking\>, \<reflection\> and \<output\> sections |
| PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language |
| ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice |
@@ -359,6 +359,7 @@ Check this log file for connection issues, tool execution errors, and other diag
| CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting |
| Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation |
| Thinkdeeper | N/A for proxy | Implements the `reasoning_effort` param from OpenAI for reasoning models like DeepSeek R1 |
| AutoThink | N/A for proxy | Combines query complexity classification with steering vectors to enhance reasoning |
## Implemented plugins

View File

@@ -0,0 +1,95 @@
# AutoThink
AutoThink is an adaptive thinking approach for Large Language Models that combines query complexity classification with steering vector guidance to enhance model reasoning capabilities.
## Overview
AutoThink combines several advanced techniques to optimize the thinking process of LLMs:
1. **Query Complexity Classification**: Uses an adaptive classifier to determine if a query requires HIGH or LOW complexity reasoning
2. **Token Budget Allocation**: Dynamically allocates thinking tokens based on query complexity
3. **Steering Vector Guidance**: Applies activation-based steering vectors to guide the model's reasoning process
4. **Controlled Thinking Process**: Manages explicit thinking phases with start and end tokens
## How It Works
### 1. Query Classification
AutoThink uses the `adaptive-classifier/llm-router` model to classify incoming queries into one of two tiers (a short code sketch follows this list):
- **HIGH**: Complex queries requiring deep reasoning, multi-step calculations, or thorough exploration
- **LOW**: Simpler queries requiring less extensive reasoning
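A minimal sketch of this step, mirroring the `ComplexityClassifier` wrapper added in `classifier.py` in this commit (assumes the `adaptive-classifier` package is installed):
```python
from adaptive_classifier import AdaptiveClassifier

# Load the router model AutoThink uses for complexity classification.
classifier = AdaptiveClassifier.from_pretrained("adaptive-classifier/llm-router")

# predict() returns (label, score) tuples, e.g. [("HIGH", 0.82), ("LOW", 0.18)].
predictions = classifier.predict("Prove that the product of two odd integers is odd.")
label, confidence = max(predictions, key=lambda p: p[1])
print(f"Complexity: {label} ({confidence:.2f})")
```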
### 2. Token Budget
Based on the classification, AutoThink allocates different token budgets for the thinking phase (see the sketch after this list):
- **HIGH**: 70-90% of max tokens allocated for thinking
- **LOW**: 20-40% of max tokens allocated for thinking
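In the implementation these ranges are realized as fixed minimum/maximum thinking-token counts (the defaults in `processor.py`'s `DEFAULT_CONFIG`). A condensed sketch of the selection logic; the helper name is illustrative:
```python
from typing import Tuple

# Defaults from processor.py: (min_thinking_tokens, max_thinking_tokens).
HIGH_BUDGET = (1024, 4096)
LOW_BUDGET = (256, 1024)

def get_token_budget(complexity: str) -> Tuple[int, int]:
    """Pick the thinking-token budget from the classifier's label."""
    return HIGH_BUDGET if complexity == "HIGH" else LOW_BUDGET
```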
### 3. Steering Vectors
AutoThink uses pre-extracted steering vectors from datasets like `codelion/Qwen3-0.6B-pts-steering-vectors`. These vectors represent different reasoning patterns:
- **Depth and thoroughness**: Encourages detailed, step-by-step reasoning
- **Numerical accuracy**: Promotes precise calculations and verification
- **Self-correction**: Facilitates error detection and correction
- **Exploration**: Supports considering multiple approaches
- **Organization**: Improves logical structure in responses
During inference, the model's internal activations are modified based on these vectors to enhance specific reasoning capabilities.
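The modification itself is small: the selected vector is normalized and added to (or subtracted from) the hidden state of the most recent token at the target layer, with the strength bounded for stability. A simplified sketch of `_apply_steering_vector` from `steering.py` (the helper name here is illustrative):
```python
import torch

def steer_last_token(hidden_states: torch.Tensor,
                     steering_vector: torch.Tensor,
                     strength: float = 2.0,
                     positive: bool = True) -> torch.Tensor:
    """Add a scaled, normalized steering vector to the last token's activation."""
    out = hidden_states.clone()
    direction = torch.nn.functional.normalize(steering_vector, dim=0)
    sign = 1.0 if positive else -1.0
    safe_strength = min(max(strength, 0.0), 3.0)  # bound the strength, as in steering.py
    out[-1, -1, :] = out[-1, -1, :] + sign * safe_strength * direction
    return out
```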
### 4. Controlled Thinking Process
The generation process includes (condensed in the sketch after this list):
1. A thinking phase marked by `<think>` and `</think>` tokens
2. Automatic adjustment of thinking time based on query complexity
3. Dynamic application of steering vectors
4. Graceful transition to the final response
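A condensed view of that control logic; the full loop in `processor.py` also tracks thought-switch markers, EOS handling, and the KV cache, and the function name below is illustrative:
```python
import random
from typing import List, Optional

def control_step(next_token_id: int, n_thinking_tokens: int,
                 min_tokens: int, max_tokens: int,
                 end_think_id: int, transitions: List[str]) -> Optional[str]:
    """Return text to force into the stream, or None to keep the sampled token."""
    if n_thinking_tokens >= max_tokens:
        return "</think>"                  # budget exhausted: force the end-think token
    if next_token_id == end_think_id and n_thinking_tokens < min_tokens:
        return random.choice(transitions)  # ended too early: inject "Wait,", "Alternatively,", ...
    return None                            # otherwise keep the sampled token
```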
## Configuration
AutoThink can be configured with:
```python
{
    "model_name": "your-model-name",
    "classifier_model": "adaptive-classifier/llm-router",
    "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
    "target_layer": 19,  # Layer to apply steering vectors
    "high_complexity_min_tokens": 1024,
    "high_complexity_max_tokens": 4096,
    "low_complexity_min_tokens": 256,
    "low_complexity_max_tokens": 1024,
    "pattern_strengths": {
        "depth_and_thoroughness": 2.5,  # Steering strength for different patterns
        "numerical_accuracy": 2.0,
        "self_correction": 3.0,
        "exploration": 2.0,
        "organization": 1.5
    }
}
```
## Usage
```python
from optillm.autothink import autothink_decode
response = autothink_decode(
    model,
    tokenizer,
    messages,
    {
        "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
        "target_layer": 19
    }
)
```
## Benefits
- **Adaptive Resource Usage**: Models think more on complex problems and less on simple ones
- **Enhanced Reasoning**: Steering vectors guide the model toward better reasoning patterns
- **Efficiency**: Better performance without increasing model size
- **Customizability**: Can be tailored for different domains using domain-specific steering vector datasets

View File

@@ -0,0 +1,7 @@
"""
AutoThink - Adaptive thinking approach for LLMs with query complexity classification and steering vectors.
"""
from .autothink import autothink_decode, AutoThinkProcessor
__all__ = ["autothink_decode", "AutoThinkProcessor"]

View File

@@ -0,0 +1,88 @@
"""
AutoThink main implementation.
This module provides the main implementation of AutoThink, combining
query complexity classification with steering vectors to enhance reasoning.
"""
import logging
from typing import Dict, List, Any, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer
from .processor import AutoThinkProcessor as InternalAutoThinkProcessor  # aliased: the wrapper class below reuses the public name
logger = logging.getLogger(__name__)
class AutoThinkProcessor:
"""
Main AutoThink processor class for external use.
Wraps the internal processor implementation.
"""
def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, config: Dict[str, Any] = None):
"""
Initialize the AutoThink processor.
Args:
model: Language model
tokenizer: Model tokenizer
config: Configuration dictionary
"""
self.config = config or {}
self.processor = None
self.model = model
self.tokenizer = tokenizer
def __call__(self, messages: List[Dict[str, str]]) -> str:
"""
Process messages with AutoThink's controlled thinking.
Args:
messages: List of message dictionaries
Returns:
Generated response
"""
# Create processor on first use to allow for model loading
if self.processor is None:
self.processor = self._create_processor()
return self.processor.process(messages)
def _create_processor(self):
"""Create the internal processor instance."""
return InternalAutoThinkProcessor(self.config, self.tokenizer, self.model)
def autothink_decode(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
messages: List[Dict[str, str]],
request_config: Optional[Dict[str, Any]] = None
) -> str:
"""
Main plugin execution function with AutoThink's controlled thinking process.
Args:
model: Language model
tokenizer: Model tokenizer
messages: List of message dictionaries
request_config: Optional configuration dictionary
Returns:
Generated response with thinking process
"""
logger.info("Starting AutoThink processing")
# Create config dictionary
config = {}
if request_config:
config.update(request_config)
try:
processor = AutoThinkProcessor(model, tokenizer, config)
response = processor(messages)
return response
except Exception as e:
logger.error(f"Error in AutoThink processing: {str(e)}")
raise

View File

@@ -0,0 +1,152 @@
"""
Query complexity classifier for AutoThink.
This module provides functionality to classify queries as HIGH or LOW complexity
using the adaptive-classifier model.
"""
import logging
from typing import Dict, Any, Tuple, Optional, List, Union
import os
import sys
logger = logging.getLogger(__name__)
class ComplexityClassifier:
"""
Classifies queries as HIGH or LOW complexity for token budget allocation.
Uses the adaptive-classifier model for classification.
"""
def __init__(self, model_name: str = "adaptive-classifier/llm-router"):
"""
Initialize the complexity classifier.
Args:
model_name: HuggingFace model name or path for the classifier
"""
self.model_name = model_name
self.classifier = None
# Load model
self._load_model()
def _load_model(self):
"""Load the classification model using adaptive-classifier library."""
try:
# Check if adaptive-classifier is installed
try:
import adaptive_classifier
except ImportError:
logger.info("Installing adaptive-classifier library...")
os.system(f"{sys.executable} -m pip install adaptive-classifier")
import adaptive_classifier
# Import the AdaptiveClassifier class
from adaptive_classifier import AdaptiveClassifier
logger.info(f"Loading complexity classifier model: {self.model_name}")
self.classifier = AdaptiveClassifier.from_pretrained(self.model_name)
logger.info("Classifier loaded successfully")
except Exception as e:
logger.error(f"Error loading complexity classifier: {e}")
# Fallback to basic classification if model fails to load
self.classifier = None
def predict(self, text: str) -> List[Tuple[str, float]]:
"""
Predict the complexity label for a given text.
Args:
text: The query text to classify
Returns:
List of (label, score) tuples sorted by confidence
"""
if self.classifier is None:
logger.warning("Classifier not loaded. Using fallback classification.")
return self._fallback_classification(text)
try:
# Make prediction using the AdaptiveClassifier
predictions = self.classifier.predict(text)
logger.debug(f"Classifier predictions: {predictions}")
# Make sure predictions are in the expected format
if isinstance(predictions, list) and all(isinstance(p, tuple) and len(p) == 2 for p in predictions):
# Sort by confidence (assuming higher score = higher confidence)
predictions.sort(key=lambda x: x[1], reverse=True)
return predictions
else:
logger.warning(f"Unexpected prediction format: {predictions}")
return self._fallback_classification(text)
except Exception as e:
logger.error(f"Error during classification: {e}")
return self._fallback_classification(text)
def _fallback_classification(self, text: str) -> List[Tuple[str, float]]:
"""
Simple heuristic classification when model isn't available.
Args:
text: The query text
Returns:
List of (label, score) tuples
"""
# Count key indicators of complexity
complexity_indicators = [
"explain", "analyze", "compare", "evaluate", "synthesize",
"how", "why", "complex", "detail", "thorough", "comprehensive",
"step by step", "calculate", "prove", "justify", "multiple",
"consequences", "implications", "differentiate", "frameworks"
]
# Count mentions of complexity indicators
count = sum(1 for indicator in complexity_indicators if indicator.lower() in text.lower())
# Calculate complexity probability based on count and text length
text_length_factor = min(len(text) / 100, 2.0) # Cap at 2.0
indicator_factor = min(count / 3, 1.5) # Cap at 1.5
# Combined factor determines HIGH vs LOW
complexity_score = text_length_factor * indicator_factor
if complexity_score > 1.0:
return [("HIGH", 0.7), ("LOW", 0.3)]
else:
return [("LOW", 0.8), ("HIGH", 0.2)]
def is_high_complexity(self, text: str, threshold: float = 0.5) -> bool:
"""
Determine if a query is high complexity.
Args:
text: The query text
threshold: Confidence threshold for HIGH classification
Returns:
Boolean indicating if the query is high complexity
"""
predictions = self.predict(text)
for label, score in predictions:
if label == "HIGH" and score >= threshold:
return True
return False
def get_complexity_with_confidence(self, text: str) -> Tuple[str, float]:
"""
Get the complexity label and confidence score.
Args:
text: The query text
Returns:
Tuple of (complexity_label, confidence_score)
"""
predictions = self.predict(text)
return predictions[0] # Return highest confidence prediction

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""
Example usage of AutoThink.
This script demonstrates how to use AutoThink with a language model.
"""
import torch
import argparse
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
from optillm.autothink import autothink_decode
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser(description="Run AutoThink demo")
parser.add_argument("--model", type=str, default="deepseek-ai/deepseek-r1-llama-8b",
help="Model name or path")
parser.add_argument("--steering-dataset", type=str,
default="codelion/Qwen3-0.6B-pts-steering-vectors",
help="Steering vectors dataset")
parser.add_argument("--target-layer", type=int, default=19,
help="Target layer for steering")
parser.add_argument("--query", type=str,
default="Explain quantum computing to me in detail",
help="Query to process")
args = parser.parse_args()
# Load model and tokenizer
try:
logger.info(f"Loading model: {args.model}")
# Determine device
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Load model with appropriate configuration based on device
model_kwargs = {"trust_remote_code": True}
if device == "cuda":
model_kwargs["torch_dtype"] = torch.float16
model_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
# Ensure proper PAD token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Model and tokenizer loaded successfully")
# Create AutoThink configuration
config = {
"steering_dataset": args.steering_dataset,
"target_layer": args.target_layer,
"pattern_strengths": {
"depth_and_thoroughness": 2.5,
"numerical_accuracy": 2.0,
"self_correction": 3.0,
"exploration": 2.0,
"organization": 1.5
}
}
# Create messages
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": args.query}
]
# Process with AutoThink
logger.info("Running AutoThink processing...")
response = autothink_decode(model, tokenizer, messages, config)
# Print response
print("\n" + "=" * 80)
print("QUERY:", args.query)
print("-" * 80)
print(response)
print("=" * 80 + "\n")
except Exception as e:
logger.error(f"Error in AutoThink demo: {str(e)}")
raise
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,376 @@
"""
AutoThink processor implementation.
This module implements the AutoThink processor for controlled thinking
with query complexity classification and steering vectors.
"""
import torch
import random
import logging
from transformers import PreTrainedModel, PreTrainedTokenizer, DynamicCache
from typing import Dict, List, Any, Optional, Union, Tuple
from .classifier import ComplexityClassifier
from .steering import SteeringVectorManager, install_steering_hooks, remove_steering_hooks
logger = logging.getLogger(__name__)
# Default configurations
DEFAULT_CONFIG = {
# General configuration
"min_thinking_tokens": 256,
"max_thinking_tokens": 2048,
"max_thoughts": 64,
"prefill": "",
"start_think_token": "<think>",
"end_think_token": "</think>",
# Complexity-specific configurations
"high_complexity_min_tokens": 1024,
"high_complexity_max_tokens": 4096,
"low_complexity_min_tokens": 256,
"low_complexity_max_tokens": 1024,
# Thought switch tokens
"thought_switch_tokens": [
"Wait,",
"Alternatively,",
"However,",
"Additionally,",
"Let's consider,",
"On second thought,",
"Actually,",
"Furthermore,",
"Looking at it differently,",
"To be thorough,"
],
# Classifier configuration
"classifier_model": "adaptive-classifier/llm-router",
"complexity_threshold": 0.6,
# Steering configuration
"steering_dataset": "",
"target_layer": 19,
"pattern_strengths": {
"depth_and_thoroughness": 2.5,
"numerical_accuracy": 2.0,
"self_correction": 3.0,
"exploration": 2.0,
"organization": 1.5
}
}
class AutoThinkProcessor:
"""
AutoThink processor for controlled thinking with
complexity classification and steering vectors.
"""
def __init__(self, config: Dict[str, Any], tokenizer: PreTrainedTokenizer, model: PreTrainedModel):
"""
Initialize the AutoThink processor.
Args:
config: Configuration dictionary
tokenizer: Model tokenizer
model: Language model
"""
# Merge default config with provided config
self.config = {**DEFAULT_CONFIG, **config}
self.tokenizer = tokenizer
self.model = model
# Initialize classifier
self.classifier = ComplexityClassifier(self.config["classifier_model"])
# Get token IDs for think markers
start_tokens = self.tokenizer.encode(self.config['start_think_token'])
end_tokens = self.tokenizer.encode(self.config['end_think_token'])
self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1]
self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1]
# Store thought switch markers as token sequences
self.thought_switch_sequences = []
for phrase in self.config["thought_switch_tokens"]:
token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
self.thought_switch_sequences.append(token_ids)
logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}")
logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}")
# Track thought switches
self.thought_count = 0
self.current_sequence = [] # Track recent tokens for sequence matching
self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences)
# Initialize steering vector manager and hooks if dataset is provided
self.steering_manager = None
self.steering_hooks = []
if self.config["steering_dataset"]:
self._setup_steering()
def _setup_steering(self):
"""Set up steering vector management."""
try:
# Initialize steering vector manager
self.steering_manager = SteeringVectorManager(
dataset_name=self.config["steering_dataset"],
target_layer=self.config["target_layer"]
)
# Set pattern strengths
if "pattern_strengths" in self.config:
for pattern, strength in self.config["pattern_strengths"].items():
self.steering_manager.set_steering_strength(pattern, strength)
# Create tokenized contexts for efficient matching
self.steering_manager.create_tokenized_contexts(self.tokenizer)
# Install hooks on the model
self.steering_hooks = install_steering_hooks(
self.model,
self.steering_manager,
self.tokenizer
)
logger.info(f"Set up steering with {len(self.steering_hooks)} hooks")
except Exception as e:
logger.error(f"Error setting up steering: {e}")
self.steering_manager = None
self.steering_hooks = []
def _cleanup_steering(self):
"""Clean up steering hooks."""
if self.steering_hooks:
remove_steering_hooks(self.steering_hooks)
self.steering_hooks = []
def classify_complexity(self, query: str) -> Tuple[str, float]:
"""
Classify query complexity.
Args:
query: The query to classify
Returns:
Tuple of (complexity_label, confidence_score)
"""
complexity, confidence = self.classifier.get_complexity_with_confidence(query)
logger.info(f"Query classified as {complexity} with confidence {confidence:.2f}")
return complexity, confidence
def get_token_budget(self, complexity: str) -> Tuple[int, int]:
"""
Get token budget based on complexity.
Args:
complexity: Complexity label (HIGH or LOW)
Returns:
Tuple of (min_tokens, max_tokens)
"""
if complexity == "HIGH":
return (
self.config["high_complexity_min_tokens"],
self.config["high_complexity_max_tokens"]
)
else:
return (
self.config["low_complexity_min_tokens"],
self.config["low_complexity_max_tokens"]
)
def is_thought_switch(self, token: int) -> bool:
"""
Check if adding this token creates a thought switch sequence.
Args:
token: Token ID to check
Returns:
Boolean indicating if this completes a thought switch
"""
# Add new token to current sequence
self.current_sequence.append(token)
# Keep only the most recent tokens that could match our sequences
if len(self.current_sequence) > self.max_sequence_length:
self.current_sequence = self.current_sequence[-self.max_sequence_length:]
# Check if current sequence ends with any thought switch sequence
for sequence in self.thought_switch_sequences:
if len(sequence) <= len(self.current_sequence) and \
self.current_sequence[-len(sequence):] == sequence:
return True
return False
@torch.inference_mode()
def process(self, messages: List[Dict[str, str]]) -> str:
"""
Process messages with AutoThink's controlled thinking.
Args:
messages: List of message dictionaries
Returns:
Generated response
"""
try:
# Extract the query from the messages
query = self._extract_query(messages)
# Classify query complexity
complexity, confidence = self.classify_complexity(query)
# Get token budget based on complexity
min_tokens, max_tokens = self.get_token_budget(complexity)
logger.info(f"Using token budget: {min_tokens}-{max_tokens} for {complexity} complexity")
# Prepare messages with thinking start token
thinking_messages = messages.copy()
thinking_messages.append({
"role": "assistant",
"content": f"{self.config['start_think_token']}\n{self.config['prefill']}"
})
# Tokenize the messages
tokens = self.tokenizer.apply_chat_template(
thinking_messages,
continue_final_message=True,
return_tensors="pt"
).to(self.model.device)
# Update token history in steering hooks
if self.steering_hooks:
token_ids = tokens[0].tolist()
for hook, _ in self.steering_hooks:
hook.update_token_history(token_ids)
# Try to match with a steering vector
hook.try_match()
# Generate with controlled thinking
kv = DynamicCache()
n_thinking_tokens = 0
seen_end_think = False
response_chunks = []
while True:
out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
logits = out.logits[0, -1, :]
# Check if we need to force end thinking
force_end = (n_thinking_tokens >= max_tokens or
self.thought_count >= self.config["max_thoughts"])
if force_end and not seen_end_think:
logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
next_token = self.end_think_token
response_chunks.append(self.tokenizer.decode([next_token]))
seen_end_think = True
tokens = torch.tensor([[next_token]]).to(tokens.device)
continue
else:
next_token = torch.multinomial(
torch.softmax(logits, dim=-1), 1
).item()
kv = out.past_key_values
next_str = self.tokenizer.decode([next_token])
# Update steering hooks with new token
if self.steering_hooks:
for hook, _ in self.steering_hooks:
hook.update_token_history([next_token])
# Check if this is a thought-switching token (only if not in conclusion phase)
if not seen_end_think and self.is_thought_switch(next_token):
self.thought_count += 1
logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
self.current_sequence = []
# Handle natural end think token
if next_token == self.end_think_token:
seen_end_think = True
logger.debug("Found end think token")
# If we haven't reached minimum tokens, continue with thought transition
if n_thinking_tokens < min_tokens:
replacement = random.choice(self.config["thought_switch_tokens"])
logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
response_chunks.append(replacement)
replacement_tokens = self.tokenizer.encode(replacement)
n_thinking_tokens += len(replacement_tokens)
tokens = torch.tensor([replacement_tokens]).to(tokens.device)
self.thought_count += 1
seen_end_think = False
continue
# Handle EOS token
if next_token == self.model.config.eos_token_id:
logger.debug("Found EOS token")
if seen_end_think:
logger.debug("Reached EOS after end think token - stopping generation")
response_chunks.append(next_str)
break
elif n_thinking_tokens < min_tokens:
# Continue with thought transition if under minimum tokens
replacement = random.choice(self.config["thought_switch_tokens"])
logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
response_chunks.append(replacement)
replacement_tokens = self.tokenizer.encode(replacement)
n_thinking_tokens += len(replacement_tokens)
tokens = torch.tensor([replacement_tokens]).to(tokens.device)
self.thought_count += 1
continue
else:
# Force end think token and continue generating for natural conclusion
logger.debug("Reached EOS without end think token - adding end token and continuing generation")
response_chunks.append(self.tokenizer.decode([self.end_think_token]))
tokens = torch.tensor([[self.end_think_token]]).to(tokens.device)
seen_end_think = True
continue
# Normal token processing
response_chunks.append(next_str)
if not seen_end_think:
n_thinking_tokens += 1
tokens = torch.tensor([[next_token]]).to(tokens.device)
# Clean up steering hooks
self._cleanup_steering()
# Join all chunks and add framing tokens
response = "".join(response_chunks)
full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"
logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}")
return full_response
except Exception as e:
# Clean up steering hooks in case of error
self._cleanup_steering()
logger.error(f"Error in AutoThink processing: {str(e)}")
raise
def _extract_query(self, messages: List[Dict[str, str]]) -> str:
"""
Extract the query from messages for classification.
Args:
messages: List of message dictionaries
Returns:
Extracted query string
"""
# Get the last user message
user_messages = [m["content"] for m in messages if m["role"] == "user"]
if user_messages:
return user_messages[-1]
# Fallback to concatenated messages
return " ".join(m["content"] for m in messages)

View File

@@ -0,0 +1,603 @@
"""
Steering vector manager for AutoThink.
This module provides functionality to load and apply steering vectors
from Hugging Face datasets during inference.
"""
import torch
import logging
import random
import json
import datasets
from typing import Dict, List, Any, Tuple, Optional, Union
from collections import defaultdict
logger = logging.getLogger(__name__)
class SteeringVectorManager:
"""
Manager for loading and applying steering vectors from a dataset.
"""
def __init__(
self,
dataset_name: str,
target_layer: int = 19,
cache_dir: Optional[str] = None,
device: Optional[str] = None
):
"""
Initialize the steering vector manager.
Args:
dataset_name: Name of the HuggingFace dataset containing steering vectors
target_layer: Target layer for applying steering vectors
cache_dir: Directory for caching the dataset
device: Device to use for tensors
"""
self.dataset_name = dataset_name
self.target_layer = target_layer
self.cache_dir = cache_dir
self.device = device or (
"cuda" if torch.cuda.is_available() else
"mps" if torch.backends.mps.is_available() else
"cpu"
)
# Storage for steering vectors
self.steering_vectors = []
self.pattern_to_vectors = {}
self.tokenized_contexts = {}
# Default steering strengths
self.default_strength = 2.0
self.pattern_strengths = {
"depth_and_thoroughness": 2.5,
"numerical_accuracy": 2.0,
"self_correction": 3.0,
"exploration": 2.0,
"organization": 1.5,
"unknown": 1.0
}
# If dataset is provided, load it
if dataset_name:
self.load_dataset()
def load_dataset(self):
"""Load steering vectors from the HuggingFace dataset."""
try:
logger.info(f"Loading steering vectors from dataset: {self.dataset_name}")
# Load the dataset
dataset = datasets.load_dataset(self.dataset_name, cache_dir=self.cache_dir)
# Get the main split (usually 'train')
main_split = list(dataset.keys())[0]
vector_data = dataset[main_split]
# Load each item as a steering vector
for item in vector_data:
# Convert dataset item to proper format
vector = self._process_dataset_item(item)
if vector:
self.steering_vectors.append(vector)
# Group by reasoning pattern
pattern = vector.get("reasoning_pattern", "unknown")
if pattern not in self.pattern_to_vectors:
self.pattern_to_vectors[pattern] = []
self.pattern_to_vectors[pattern].append(vector)
logger.info(f"Loaded {len(self.steering_vectors)} steering vectors")
logger.info(f"Found {len(self.pattern_to_vectors)} reasoning patterns: {list(self.pattern_to_vectors.keys())}")
# Log the first vector for debugging
if self.steering_vectors:
first_vector = self.steering_vectors[0]
logger.info(f"First vector sample - pattern: {first_vector.get('reasoning_pattern', 'missing')}")
if 'pivot_context' in first_vector:
context_len = len(first_vector['pivot_context'])
logger.info(f"First vector pivot_context length: {context_len}")
except Exception as e:
logger.error(f"Error loading steering vectors: {e}")
self.steering_vectors = []
self.pattern_to_vectors = {}
def _process_dataset_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a dataset item into a steering vector.
Args:
item: Dataset item
Returns:
Processed steering vector or None if invalid
"""
try:
# Check if item has the required fields
required_fields = ["pivot_context", "steering_vector", "reasoning_pattern"]
if not all(field in item for field in required_fields):
return None
# Convert steering_vector to a proper format if it's a string or list
steering_vector = item["steering_vector"]
if isinstance(steering_vector, str):
# Try to parse JSON string
try:
steering_vector = json.loads(steering_vector)
except json.JSONDecodeError:
# Try comma-separated format
steering_vector = [float(x) for x in steering_vector.strip("[]").split(",")]
# Ensure we have a proper list
if not isinstance(steering_vector, list):
logger.warning(f"Invalid steering vector format: {type(steering_vector)}")
return None
# Create the steering vector dictionary
vector = {
"pivot_context": item["pivot_context"],
"pivot_token": item.get("pivot_token", ""),
"pivot_token_id": item.get("pivot_token_id", -1),
"prob_before": item.get("prob_before", 0.0),
"prob_after": item.get("prob_after", 0.0),
"prob_delta": item.get("prob_delta", 0.0),
"model_id": item.get("model_id", ""),
"task_type": item.get("task_type", "unknown"),
"steering_vector": steering_vector,
"cluster_id": item.get("cluster_id", -1),
"reasoning_pattern": item.get("reasoning_pattern", "unknown"),
"cluster_vector": item.get("cluster_vector", steering_vector),
"steering_layer": item.get("steering_layer", self.target_layer),
}
return vector
except Exception as e:
logger.error(f"Error processing dataset item: {e}")
return None
def create_tokenized_contexts(self, tokenizer):
"""
Pre-tokenize context patterns for efficient matching during generation.
Args:
tokenizer: Tokenizer for encoding contexts
"""
# Get configurations
max_pts_tokens = 256 # Maximum tokens to store for matching
count = 0
for vector in self.steering_vectors:
# Get the context
context = vector.get("pivot_context", "")
if not context:
continue
# Pre-tokenize the context for faster matching
tokenized_context = tokenizer.encode(context, add_special_tokens=False)
# Keep only up to max_pts_tokens
if len(tokenized_context) > max_pts_tokens:
tokenized_context = tokenized_context[-max_pts_tokens:]
# Store the tokenized context with its vector
tuple_key = tuple(tokenized_context)
self.tokenized_contexts[tuple_key] = vector
# Store additional shorter versions for partial matching
for suffix_len in [4, 8, 12]:
if len(tokenized_context) > suffix_len:
suffix = tokenized_context[-suffix_len:]
suffix_tuple = tuple(suffix)
if suffix_tuple not in self.tokenized_contexts:
self.tokenized_contexts[suffix_tuple] = vector
count += 1
# Log statistics
logger.info(f"Pre-tokenized {count} contexts into {len(self.tokenized_contexts)} token patterns")
# Count patterns by length for debugging
length_counts = {}
for key in self.tokenized_contexts.keys():
length = len(key)
if length not in length_counts:
length_counts[length] = 0
length_counts[length] += 1
logger.info(f"Token pattern length distribution: {sorted(length_counts.items())}")
def get_steering_strength(self, pattern: str) -> float:
"""
Get the steering strength for a specific pattern.
Args:
pattern: The reasoning pattern
Returns:
The steering strength
"""
return self.pattern_strengths.get(pattern, self.default_strength)
def set_steering_strength(self, pattern: str, strength: float):
"""
Set the steering strength for a specific pattern.
Args:
pattern: The reasoning pattern
strength: The steering strength
"""
self.pattern_strengths[pattern] = strength
logger.info(f"Set strength for {pattern} to {strength}")
def get_pattern_vectors(self, pattern: str) -> List[Dict[str, Any]]:
"""
Get all steering vectors for a specific reasoning pattern.
Args:
pattern: The reasoning pattern
Returns:
List of steering vectors
"""
return self.pattern_to_vectors.get(pattern, [])
class SteeringHook:
"""Hook for applying steering vectors during generation."""
def __init__(self, manager: SteeringVectorManager, layer_num: int, tokenizer=None):
"""
Initialize the steering hook.
Args:
manager: The steering vector manager
layer_num: The layer number to apply steering to
tokenizer: Tokenizer for token-based matching
"""
self.manager = manager
self.layer_num = layer_num
self.tokenizer = tokenizer
# For token-based matching
self.token_history = [] # Store token IDs for matching
self.max_history = 256 # Maximum tokens to keep in history
# State tracking
self.match_found = False
self.current_vector = None
self.last_pattern = None
# Single pattern for entire request
self.active_pattern = None # Currently active pattern
self.generation_started = False
logger.info(f"Initialized hook for layer {layer_num}")
def __call__(self, module, input_tensors, output):
"""
Apply steering to the output of a layer.
Args:
module: The module being hooked
input_tensors: The input tensors
output: The output tensor
Returns:
Modified output tensor
"""
try:
# Skip if no active pattern is set
if not self.active_pattern:
return output
# Apply steering vector if available
if self.current_vector is not None:
# Get the appropriate steering strength
pattern = self.current_vector.get("reasoning_pattern", "unknown")
strength = self.manager.get_steering_strength(pattern)
# Keep strength within safe bounds
safe_strength = min(max(strength, 0.1), 2.0)
# Log when pattern changes
if pattern != self.last_pattern:
logger.info(f"Switching to {pattern} reasoning pattern with strength {safe_strength}")
self.last_pattern = pattern
# Apply the steering vector
try:
if isinstance(output, tuple):
# Some models return a tuple
hidden_states = output[0]
modified_hidden_states = self._apply_steering_vector(hidden_states, self.current_vector, safe_strength)
# Validate the result
if modified_hidden_states.shape == hidden_states.shape:
return (modified_hidden_states,) + output[1:]
else:
logger.error(f"Modified hidden states have wrong shape. Expected {hidden_states.shape}, got {modified_hidden_states.shape}")
return output
else:
# Direct tensor output
return self._apply_steering_vector(output, self.current_vector, safe_strength)
except Exception as e:
logger.error(f"Error applying steering: {e}")
return output
return output
except Exception as e:
logger.error(f"Critical error in hook: {e}")
return output
def _apply_steering_vector(self, hidden_states: torch.Tensor,
steering_vector: Dict[str, Any],
scaling_factor: float = 2.0) -> torch.Tensor:
"""
Apply a steering vector to hidden states.
Args:
hidden_states: The hidden states tensor
steering_vector: Dictionary with steering vector data
scaling_factor: Factor to scale the steering vector by
Returns:
Modified hidden states tensor
"""
try:
# Make a deep clone
hidden_states_clone = hidden_states.clone().detach()
# Check what kind of vector we're using
vector_data = None
if "steering_vector" in steering_vector:
vector_data = steering_vector["steering_vector"]
vector_type = "steering_vector"
elif "cluster_vector" in steering_vector:
vector_data = steering_vector["cluster_vector"]
vector_type = "cluster_vector"
else:
logger.warning("No valid vector found in steering data")
return hidden_states
# Convert vector to tensor
vector = torch.tensor(vector_data,
dtype=hidden_states.dtype,
device=hidden_states.device)
# Log vector info
pattern = steering_vector.get("reasoning_pattern", "unknown")
logger.debug(f"Applying {vector_type} for pattern '{pattern}' with scaling {scaling_factor}")
# Apply scaling based on prob_delta if available
if "prob_delta" in steering_vector:
prob_delta = abs(steering_vector["prob_delta"])
prob_delta_capped = min(max(prob_delta, 0.1), 2.0)
scaling_factor *= prob_delta_capped
# Check if the token is positive or negative
is_positive = steering_vector.get("is_positive", True)
# Verify shapes are compatible
hs_shape = hidden_states.shape
vector_shape = vector.shape
if len(vector_shape) != 1 or vector_shape[0] != hs_shape[-1]:
logger.error(f"Shape mismatch - hidden_states: {hs_shape}, vector: {vector_shape}")
return hidden_states
# Bound scaling factor for safety
safe_scaling = min(max(scaling_factor, 0.0), 3.0)
# Apply steering
if len(hs_shape) >= 3 and hs_shape[0] > 0 and hs_shape[1] > 0:
# Apply to the last token's representation
if is_positive:
# Normalize vector to prevent numerical instability
vector_norm = torch.nn.functional.normalize(vector, dim=0)
hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] + safe_scaling * vector_norm
else:
vector_norm = torch.nn.functional.normalize(vector, dim=0)
hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] - safe_scaling * vector_norm
# Check for NaN or inf values
if torch.isnan(hidden_states_clone).any() or torch.isinf(hidden_states_clone).any():
logger.error("NaN or inf values detected after applying vector, reverting to original")
return hidden_states
else:
logger.error(f"Hidden states shape not suitable for steering: {hs_shape}")
return hidden_states
return hidden_states_clone
except Exception as e:
logger.error(f"Unexpected error applying steering vector: {e}")
return hidden_states
def update_token_history(self, new_tokens: List[int]):
"""
Update the token history with new tokens.
Args:
new_tokens: New token IDs to add
"""
# Add to token history
self.token_history.extend(new_tokens)
# Trim history if needed
if len(self.token_history) > self.max_history:
self.token_history = self.token_history[-self.max_history:]
# Log token updates periodically
if random.random() < 0.01:
logger.debug(f"Token history updated, now has {len(self.token_history)} tokens")
def try_match(self) -> bool:
"""
Try to match the current context with a steering vector.
Returns:
Boolean indicating if a match was found
"""
# If we already have an active pattern, don't try to match again
if self.generation_started and self.active_pattern:
return False
# Only attempt pattern matching at the beginning of generation
self.generation_started = True
# Try token-based matching
match_result = self._try_token_match()
# If a match is found, set this as the permanent pattern for this generation
if match_result and self.current_vector:
new_pattern = self.current_vector.get("reasoning_pattern", "unknown")
self.active_pattern = new_pattern
logger.info(f"Selected '{new_pattern}' pattern for this request")
return match_result
def _try_token_match(self) -> bool:
"""
Try to match using token-based context.
Returns:
Boolean indicating if a match was found
"""
# Ensure we have enough tokens
if len(self.token_history) < 4:
return False
# Track best match
best_match = {
'length': 0,
'vector': None,
'is_partial': True
}
# Check for matches in tokenized contexts
for tokenized_context, vector in self.manager.tokenized_contexts.items():
token_list = list(tokenized_context)
token_len = len(token_list)
# Try partial matching for shorter contexts
if len(self.token_history) < token_len:
# Only try partial matching if we have enough context tokens
if len(self.token_history) >= 4:
# Calculate how many tokens to match
match_len = min(len(self.token_history), max(4, token_len // 2))
# Try to match the end of the token sequence
if self.token_history[-match_len:] == token_list[-match_len:]:
# Track this match - prefer longer matches
if match_len > best_match['length']:
best_match = {
'length': match_len,
'vector': vector,
'is_partial': True,
'match_len': match_len,
'token_len': token_len
}
else:
# Full matching when we have enough tokens
if self.token_history[-token_len:] == token_list:
# Track this match - full matches are preferred
if token_len >= best_match['length']:
best_match = {
'length': token_len,
'vector': vector,
'is_partial': False,
'match_len': token_len,
'token_len': token_len
}
# Apply best match if found
if best_match['vector'] is not None:
match_type = "PARTIAL" if best_match['is_partial'] else "FULL"
self.match_found = True
self.current_vector = best_match['vector']
pattern = best_match['vector'].get("reasoning_pattern", "unknown")
logger.info(f"Found {match_type} token match ({best_match['match_len']}/{best_match['token_len']} tokens) for {pattern} pattern")
return True
return False
def reset(self):
"""Reset the hook state."""
self.match_found = False
self.current_vector = None
self.token_history = []
self.last_pattern = None
self.active_pattern = None
self.generation_started = False
def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None) -> List[Tuple]:
"""
Install steering hooks on a model.
Args:
model: The model to install hooks on
manager: The steering vector manager
tokenizer: Tokenizer for token-based matching
Returns:
List of installed hooks
"""
hooks = []
# Target layer is specified in the manager
layer_num = manager.target_layer
logger.info(f"Attempting to install hook on layer {layer_num}")
# First, log model structure to help with debugging
model_type = type(model).__name__
logger.info(f"Model type is {model_type}")
# Find the appropriate module - depends on model architecture
module = None
if hasattr(model, 'transformer'):
logger.info("Model has 'transformer' attribute")
if hasattr(model.transformer, 'h') and layer_num < len(model.transformer.h):
module = model.transformer.h[layer_num]
logger.info(f"Using transformer.h[{layer_num}]")
elif hasattr(model, 'model'):
logger.info("Model has 'model' attribute")
if hasattr(model.model, 'layers') and layer_num < len(model.model.layers):
module = model.model.layers[layer_num]
logger.info(f"Using model.layers[{layer_num}]")
elif hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers') and layer_num < len(model.model.decoder.layers):
module = model.model.decoder.layers[layer_num]
logger.info(f"Using model.decoder.layers[{layer_num}]")
elif hasattr(model, 'layers') and layer_num < len(model.layers):
module = model.layers[layer_num]
logger.info(f"Using layers[{layer_num}]")
if module is None:
logger.error(f"Could not find appropriate module for layer {layer_num}")
logger.error("Model structure not compatible with current hook installation logic")
return []
# Create and register hook
hook = SteeringHook(manager, layer_num, tokenizer)
handle = module.register_forward_hook(hook)
# Return both hook object and handle for later removal
hooks.append((hook, handle))
logger.info(f"Installed hook on layer {layer_num} successfully")
return hooks
def remove_steering_hooks(hooks):
"""
Remove steering hooks from a model.
Args:
hooks: List of (hook, handle) tuples
"""
for _, handle in hooks:
handle.remove()
logger.info(f"Removed {len(hooks)} hooks")

View File

@@ -20,6 +20,7 @@ import traceback
from optillm.cot_decoding import cot_decode
from optillm.entropy_decoding import entropy_decode
from optillm.thinkdeeper import thinkdeeper_decode
from optillm.autothink import autothink_decode
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -1467,6 +1468,34 @@ class InferenceClient:
responses = [result]
logprobs_results = [None]
completion_tokens = len(pipeline.tokenizer.encode(result))
elif decoding == "autothink":
# Get steering dataset configuration
steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
target_layer = kwargs.get("target_layer", 19)
# Prepare AutoThink configuration
autothink_config = {
"steering_dataset": steering_dataset,
"target_layer": target_layer,
"pattern_strengths": kwargs.get("pattern_strengths", {
"depth_and_thoroughness": 2.5,
"numerical_accuracy": 2.0,
"self_correction": 3.0,
"exploration": 2.0,
"organization": 1.5
})
}
# Process with AutoThink
result = autothink_decode(
pipeline.current_model,
pipeline.tokenizer,
messages,
autothink_config
)
responses = [result]
logprobs_results = [None]
completion_tokens = len(pipeline.tokenizer.encode(result))
else:
raise ValueError(f"Unknown specialized decoding approach: {decoding}")

View File

@@ -27,4 +27,5 @@ spacy<3.8.0
cerebras_cloud_sdk
outlines[transformers]
sentencepiece
adaptive-classifier
mcp

View File

@@ -45,6 +45,7 @@ setup(
"outlines[transformers]",
"sentencepiece",
"mcp",
"adaptive-classifier",
],
entry_points={
'console_scripts': [