Mirror of https://github.com/codelion/optillm.git
Synced 2025-05-28 09:39:38 +03:00

Commit: init implementation
@@ -343,7 +343,7 @@ Check this log file for connection issues, tool execution errors, and other diag

| Approach                           | Slug             | Description                                                                                                |
| ---------------------------------- | ---------------- | ---------------------------------------------------------------------------------------------------------- |
| Cerebras Planning and Optimization | `cepo`           | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques   |
| CoT with Reflection                | `cot_reflection` | Implements chain-of-thought reasoning with \<thinking\>, \<reflection\> and \<output\> sections             |
| PlanSearch                         | `plansearch`     | Implements a search algorithm over candidate plans for solving a problem in natural language                |
| ReRead                             | `re2`            | Implements rereading to improve reasoning by processing queries twice                                       |

@@ -359,6 +359,7 @@ Check this log file for connection issues, tool execution errors, and other diag

| CoT Decoding     | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting       |
| Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation         |
| Thinkdeeper      | N/A for proxy | Implements the `reasoning_effort` param from OpenAI for reasoning models like DeepSeek R1 |
| AutoThink        | N/A for proxy | Combines query complexity classification with steering vectors to enhance reasoning       |

## Implemented plugins
95  optillm/autothink/README.md  Normal file
@@ -0,0 +1,95 @@

# AutoThink

AutoThink is an adaptive thinking approach for Large Language Models that combines query complexity classification with steering vector guidance to enhance model reasoning capabilities.

## Overview

AutoThink combines several advanced techniques to optimize the thinking process of LLMs:

1. **Query Complexity Classification**: Uses an adaptive classifier to determine if a query requires HIGH or LOW complexity reasoning
2. **Token Budget Allocation**: Dynamically allocates thinking tokens based on query complexity
3. **Steering Vector Guidance**: Applies activation-based steering vectors to guide the model's reasoning process
4. **Controlled Thinking Process**: Manages explicit thinking phases with start and end tokens

## How It Works

### 1. Query Classification

AutoThink uses the `adaptive-classifier/llm-router` model to classify incoming queries:

- **HIGH**: Complex queries requiring deep reasoning, multi-step calculations, or thorough exploration
- **LOW**: Simpler queries requiring less extensive reasoning
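In code, this routing step boils down to a couple of calls on the `ComplexityClassifier` helper added in `optillm/autothink/classifier.py` (the query text and printed scores below are illustrative only):

```python
from optillm.autothink.classifier import ComplexityClassifier

classifier = ComplexityClassifier("adaptive-classifier/llm-router")

query = "Prove that the sum of the first n odd numbers is n^2."
label, confidence = classifier.get_complexity_with_confidence(query)
print(label, confidence)  # e.g. ("HIGH", 0.8) for a multi-step proof

# is_high_complexity applies a confidence threshold (0.5 by default)
if classifier.is_high_complexity(query):
    print("allocate the larger thinking-token budget")
```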
### 2. Token Budget

Based on the classification, AutoThink allocates different token budgets for the thinking phase:

- **HIGH**: 70-90% of max tokens allocated for thinking
- **LOW**: 20-40% of max tokens allocated for thinking

### 3. Steering Vectors

AutoThink uses pre-extracted steering vectors from datasets like `codelion/Qwen3-0.6B-pts-steering-vectors`. These vectors represent different reasoning patterns:

- **Depth and thoroughness**: Encourages detailed, step-by-step reasoning
- **Numerical accuracy**: Promotes precise calculations and verification
- **Self-correction**: Facilitates error detection and correction
- **Exploration**: Supports considering multiple approaches
- **Organization**: Improves logical structure in responses

During inference, the model's internal activations are modified based on these vectors to enhance specific reasoning capabilities.
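To make the activation edit concrete, here is a minimal standalone sketch of the update performed by the hook in `optillm/autothink/steering.py` (simplified; the real hook also validates shapes, caps the strength, and handles tuple outputs): the selected vector is normalized and added to the hidden state of the most recent token, scaled by the pattern's steering strength.

```python
import torch
import torch.nn.functional as F

def apply_steering(hidden_states: torch.Tensor,
                   steering_vector: torch.Tensor,
                   strength: float = 2.0) -> torch.Tensor:
    """Nudge the last token's hidden state along a normalized steering direction."""
    direction = F.normalize(steering_vector, dim=0)
    steered = hidden_states.clone()
    # hidden_states has shape [batch, seq_len, hidden_dim]; only the newest token is steered
    steered[-1, -1, :] = steered[-1, -1, :] + strength * direction
    return steered
```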
### 4. Controlled Thinking Process

The generation process includes:

1. A thinking phase marked by `<think>` and `</think>` tokens
2. Automatic adjustment of thinking time based on query complexity
3. Dynamic application of steering vectors
4. Graceful transition to the final response

## Configuration

AutoThink can be configured with:

```python
{
    "model_name": "your-model-name",
    "classifier_model": "adaptive-classifier/llm-router",
    "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
    "target_layer": 19,  # Layer to apply steering vectors
    "high_complexity_min_tokens": 1024,
    "high_complexity_max_tokens": 4096,
    "low_complexity_min_tokens": 256,
    "low_complexity_max_tokens": 1024,
    "pattern_strengths": {
        "depth_and_thoroughness": 2.5,  # Steering strength for different patterns
        "numerical_accuracy": 2.0,
        "self_correction": 3.0,
        "exploration": 2.0,
        "organization": 1.5
    }
}
```

## Usage

```python
from optillm.autothink import autothink_decode

response = autothink_decode(
    model,
    tokenizer,
    messages,
    {
        "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
        "target_layer": 19
    }
)
```

## Benefits

- **Adaptive Resource Usage**: Models think more on complex problems and less on simple ones
- **Enhanced Reasoning**: Steering vectors guide the model toward better reasoning patterns
- **Efficiency**: Better performance without increasing model size
- **Customizability**: Can be tailored for different domains using domain-specific steering vector datasets
7  optillm/autothink/__init__.py  Normal file
@@ -0,0 +1,7 @@

"""
AutoThink - Adaptive thinking approach for LLMs with query complexity classification and steering vectors.
"""

from .autothink import autothink_decode, AutoThinkProcessor

__all__ = ["autothink_decode", "AutoThinkProcessor"]
88  optillm/autothink/autothink.py  Normal file
@@ -0,0 +1,88 @@

"""
AutoThink main implementation.

This module provides the main implementation of AutoThink, combining
query complexity classification with steering vectors to enhance reasoning.
"""

import logging
from typing import Dict, List, Any, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer

# Import the internal processor under an alias so the public wrapper class below
# does not shadow it; otherwise _create_processor would instantiate the wrapper
# itself with the wrong argument order.
from .processor import AutoThinkProcessor as _InternalAutoThinkProcessor

logger = logging.getLogger(__name__)


class AutoThinkProcessor:
    """
    Main AutoThink processor class for external use.
    Wraps the internal processor implementation.
    """

    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, config: Dict[str, Any] = None):
        """
        Initialize the AutoThink processor.

        Args:
            model: Language model
            tokenizer: Model tokenizer
            config: Configuration dictionary
        """
        self.config = config or {}
        self.processor = None
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, messages: List[Dict[str, str]]) -> str:
        """
        Process messages with AutoThink's controlled thinking.

        Args:
            messages: List of message dictionaries

        Returns:
            Generated response
        """
        # Create processor on first use to allow for model loading
        if self.processor is None:
            self.processor = self._create_processor()

        return self.processor.process(messages)

    def _create_processor(self):
        """Create the internal processor instance."""
        return _InternalAutoThinkProcessor(self.config, self.tokenizer, self.model)


def autothink_decode(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    messages: List[Dict[str, str]],
    request_config: Optional[Dict[str, Any]] = None
) -> str:
    """
    Main plugin execution function with AutoThink's controlled thinking process.

    Args:
        model: Language model
        tokenizer: Model tokenizer
        messages: List of message dictionaries
        request_config: Optional configuration dictionary

    Returns:
        Generated response with thinking process
    """
    logger.info("Starting AutoThink processing")

    # Create config dictionary
    config = {}
    if request_config:
        config.update(request_config)

    try:
        processor = AutoThinkProcessor(model, tokenizer, config)
        response = processor(messages)
        return response

    except Exception as e:
        logger.error(f"Error in AutoThink processing: {str(e)}")
        raise
152  optillm/autothink/classifier.py  Normal file
@@ -0,0 +1,152 @@

"""
Query complexity classifier for AutoThink.

This module provides functionality to classify queries as HIGH or LOW complexity
using the adaptive-classifier model.
"""

import logging
from typing import Dict, Any, Tuple, Optional, List, Union
import os
import sys

logger = logging.getLogger(__name__)


class ComplexityClassifier:
    """
    Classifies queries as HIGH or LOW complexity for token budget allocation.
    Uses the adaptive-classifier model for classification.
    """

    def __init__(self, model_name: str = "adaptive-classifier/llm-router"):
        """
        Initialize the complexity classifier.

        Args:
            model_name: HuggingFace model name or path for the classifier
        """
        self.model_name = model_name
        self.classifier = None

        # Load model
        self._load_model()

    def _load_model(self):
        """Load the classification model using the adaptive-classifier library."""
        try:
            # Check if adaptive-classifier is installed
            try:
                import adaptive_classifier
            except ImportError:
                logger.info("Installing adaptive-classifier library...")
                os.system(f"{sys.executable} -m pip install adaptive-classifier")
                import adaptive_classifier

            # Import the AdaptiveClassifier class
            from adaptive_classifier import AdaptiveClassifier

            logger.info(f"Loading complexity classifier model: {self.model_name}")
            self.classifier = AdaptiveClassifier.from_pretrained(self.model_name)
            logger.info("Classifier loaded successfully")

        except Exception as e:
            logger.error(f"Error loading complexity classifier: {e}")
            # Fall back to basic classification if the model fails to load
            self.classifier = None

    def predict(self, text: str) -> List[Tuple[str, float]]:
        """
        Predict the complexity label for a given text.

        Args:
            text: The query text to classify

        Returns:
            List of (label, score) tuples sorted by confidence
        """
        if self.classifier is None:
            logger.warning("Classifier not loaded. Using fallback classification.")
            return self._fallback_classification(text)

        try:
            # Make prediction using the AdaptiveClassifier
            predictions = self.classifier.predict(text)
            logger.debug(f"Classifier predictions: {predictions}")

            # Make sure predictions are in the expected format
            if isinstance(predictions, list) and all(isinstance(p, tuple) and len(p) == 2 for p in predictions):
                # Sort by confidence (assuming higher score = higher confidence)
                predictions.sort(key=lambda x: x[1], reverse=True)
                return predictions
            else:
                logger.warning(f"Unexpected prediction format: {predictions}")
                return self._fallback_classification(text)

        except Exception as e:
            logger.error(f"Error during classification: {e}")
            return self._fallback_classification(text)

    def _fallback_classification(self, text: str) -> List[Tuple[str, float]]:
        """
        Simple heuristic classification when the model isn't available.

        Args:
            text: The query text

        Returns:
            List of (label, score) tuples
        """
        # Count key indicators of complexity
        complexity_indicators = [
            "explain", "analyze", "compare", "evaluate", "synthesize",
            "how", "why", "complex", "detail", "thorough", "comprehensive",
            "step by step", "calculate", "prove", "justify", "multiple",
            "consequences", "implications", "differentiate", "frameworks"
        ]

        # Count mentions of complexity indicators
        count = sum(1 for indicator in complexity_indicators if indicator.lower() in text.lower())

        # Calculate complexity probability based on count and text length
        text_length_factor = min(len(text) / 100, 2.0)  # Cap at 2.0
        indicator_factor = min(count / 3, 1.5)  # Cap at 1.5

        # Combined factor determines HIGH vs LOW
        complexity_score = text_length_factor * indicator_factor

        if complexity_score > 1.0:
            return [("HIGH", 0.7), ("LOW", 0.3)]
        else:
            return [("LOW", 0.8), ("HIGH", 0.2)]

    def is_high_complexity(self, text: str, threshold: float = 0.5) -> bool:
        """
        Determine if a query is high complexity.

        Args:
            text: The query text
            threshold: Confidence threshold for HIGH classification

        Returns:
            Boolean indicating if the query is high complexity
        """
        predictions = self.predict(text)

        for label, score in predictions:
            if label == "HIGH" and score >= threshold:
                return True

        return False

    def get_complexity_with_confidence(self, text: str) -> Tuple[str, float]:
        """
        Get the complexity label and confidence score.

        Args:
            text: The query text

        Returns:
            Tuple of (complexity_label, confidence_score)
        """
        predictions = self.predict(text)
        return predictions[0]  # Return highest confidence prediction
98  optillm/autothink/example.py  Normal file
@@ -0,0 +1,98 @@

#!/usr/bin/env python3
"""
Example usage of AutoThink.

This script demonstrates how to use AutoThink with a language model.
"""

import torch
import argparse
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer

from optillm.autothink import autothink_decode

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(description="Run AutoThink demo")
    parser.add_argument("--model", type=str, default="deepseek-ai/deepseek-r1-llama-8b",
                        help="Model name or path")
    parser.add_argument("--steering-dataset", type=str,
                        default="codelion/Qwen3-0.6B-pts-steering-vectors",
                        help="Steering vectors dataset")
    parser.add_argument("--target-layer", type=int, default=19,
                        help="Target layer for steering")
    parser.add_argument("--query", type=str,
                        default="Explain quantum computing to me in detail",
                        help="Query to process")

    args = parser.parse_args()

    # Load model and tokenizer
    try:
        logger.info(f"Loading model: {args.model}")

        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model)

        # Load model with appropriate configuration based on device
        model_kwargs = {"trust_remote_code": True}

        if device == "cuda":
            model_kwargs["torch_dtype"] = torch.float16
            model_kwargs["device_map"] = "auto"

        model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)

        # Ensure proper PAD token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Model and tokenizer loaded successfully")

        # Create AutoThink configuration
        config = {
            "steering_dataset": args.steering_dataset,
            "target_layer": args.target_layer,
            "pattern_strengths": {
                "depth_and_thoroughness": 2.5,
                "numerical_accuracy": 2.0,
                "self_correction": 3.0,
                "exploration": 2.0,
                "organization": 1.5
            }
        }

        # Create messages
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": args.query}
        ]

        # Process with AutoThink
        logger.info("Running AutoThink processing...")
        response = autothink_decode(model, tokenizer, messages, config)

        # Print response
        print("\n" + "=" * 80)
        print("QUERY:", args.query)
        print("-" * 80)
        print(response)
        print("=" * 80 + "\n")

    except Exception as e:
        logger.error(f"Error in AutoThink demo: {str(e)}")
        raise


if __name__ == "__main__":
    main()
376  optillm/autothink/processor.py  Normal file
@@ -0,0 +1,376 @@

"""
AutoThink processor implementation.

This module implements the AutoThink processor for controlled thinking
with query complexity classification and steering vectors.
"""

import torch
import random
import logging
from transformers import PreTrainedModel, PreTrainedTokenizer, DynamicCache
from typing import Dict, List, Any, Optional, Union, Tuple

from .classifier import ComplexityClassifier
from .steering import SteeringVectorManager, install_steering_hooks, remove_steering_hooks

logger = logging.getLogger(__name__)

# Default configurations
DEFAULT_CONFIG = {
    # General configuration
    "min_thinking_tokens": 256,
    "max_thinking_tokens": 2048,
    "max_thoughts": 64,
    "prefill": "",
    "start_think_token": "<think>",
    "end_think_token": "</think>",

    # Complexity-specific configurations
    "high_complexity_min_tokens": 1024,
    "high_complexity_max_tokens": 4096,
    "low_complexity_min_tokens": 256,
    "low_complexity_max_tokens": 1024,

    # Thought switch tokens
    "thought_switch_tokens": [
        "Wait,",
        "Alternatively,",
        "However,",
        "Additionally,",
        "Let's consider,",
        "On second thought,",
        "Actually,",
        "Furthermore,",
        "Looking at it differently,",
        "To be thorough,"
    ],

    # Classifier configuration
    "classifier_model": "adaptive-classifier/llm-router",
    "complexity_threshold": 0.6,

    # Steering configuration
    "steering_dataset": "",
    "target_layer": 19,
    "pattern_strengths": {
        "depth_and_thoroughness": 2.5,
        "numerical_accuracy": 2.0,
        "self_correction": 3.0,
        "exploration": 2.0,
        "organization": 1.5
    }
}


class AutoThinkProcessor:
    """
    AutoThink processor for controlled thinking with
    complexity classification and steering vectors.
    """

    def __init__(self, config: Dict[str, Any], tokenizer: PreTrainedTokenizer, model: PreTrainedModel):
        """
        Initialize the AutoThink processor.

        Args:
            config: Configuration dictionary
            tokenizer: Model tokenizer
            model: Language model
        """
        # Merge default config with provided config
        self.config = {**DEFAULT_CONFIG, **config}
        self.tokenizer = tokenizer
        self.model = model

        # Initialize classifier
        self.classifier = ComplexityClassifier(self.config["classifier_model"])

        # Get token IDs for think markers
        start_tokens = self.tokenizer.encode(self.config['start_think_token'])
        end_tokens = self.tokenizer.encode(self.config['end_think_token'])
        self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1]
        self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1]

        # Store thought switch markers as token sequences
        self.thought_switch_sequences = []
        for phrase in self.config["thought_switch_tokens"]:
            token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
            self.thought_switch_sequences.append(token_ids)
            logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}")
            logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}")

        # Track thought switches
        self.thought_count = 0
        self.current_sequence = []  # Track recent tokens for sequence matching
        self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences)

        # Initialize steering vector manager and hooks if a dataset is provided
        self.steering_manager = None
        self.steering_hooks = []

        if self.config["steering_dataset"]:
            self._setup_steering()

    def _setup_steering(self):
        """Set up steering vector management."""
        try:
            # Initialize steering vector manager
            self.steering_manager = SteeringVectorManager(
                dataset_name=self.config["steering_dataset"],
                target_layer=self.config["target_layer"]
            )

            # Set pattern strengths
            if "pattern_strengths" in self.config:
                for pattern, strength in self.config["pattern_strengths"].items():
                    self.steering_manager.set_steering_strength(pattern, strength)

            # Create tokenized contexts for efficient matching
            self.steering_manager.create_tokenized_contexts(self.tokenizer)

            # Install hooks on the model
            self.steering_hooks = install_steering_hooks(
                self.model,
                self.steering_manager,
                self.tokenizer
            )

            logger.info(f"Set up steering with {len(self.steering_hooks)} hooks")

        except Exception as e:
            logger.error(f"Error setting up steering: {e}")
            self.steering_manager = None
            self.steering_hooks = []

    def _cleanup_steering(self):
        """Clean up steering hooks."""
        if self.steering_hooks:
            remove_steering_hooks(self.steering_hooks)
            self.steering_hooks = []

    def classify_complexity(self, query: str) -> Tuple[str, float]:
        """
        Classify query complexity.

        Args:
            query: The query to classify

        Returns:
            Tuple of (complexity_label, confidence_score)
        """
        complexity, confidence = self.classifier.get_complexity_with_confidence(query)
        logger.info(f"Query classified as {complexity} with confidence {confidence:.2f}")
        return complexity, confidence

    def get_token_budget(self, complexity: str) -> Tuple[int, int]:
        """
        Get token budget based on complexity.

        Args:
            complexity: Complexity label (HIGH or LOW)

        Returns:
            Tuple of (min_tokens, max_tokens)
        """
        if complexity == "HIGH":
            return (
                self.config["high_complexity_min_tokens"],
                self.config["high_complexity_max_tokens"]
            )
        else:
            return (
                self.config["low_complexity_min_tokens"],
                self.config["low_complexity_max_tokens"]
            )

    def is_thought_switch(self, token: int) -> bool:
        """
        Check if adding this token creates a thought switch sequence.

        Args:
            token: Token ID to check

        Returns:
            Boolean indicating if this completes a thought switch
        """
        # Add new token to current sequence
        self.current_sequence.append(token)

        # Keep only the most recent tokens that could match our sequences
        if len(self.current_sequence) > self.max_sequence_length:
            self.current_sequence = self.current_sequence[-self.max_sequence_length:]

        # Check if current sequence ends with any thought switch sequence
        for sequence in self.thought_switch_sequences:
            if len(sequence) <= len(self.current_sequence) and \
               self.current_sequence[-len(sequence):] == sequence:
                return True

        return False

    @torch.inference_mode()
    def process(self, messages: List[Dict[str, str]]) -> str:
        """
        Process messages with AutoThink's controlled thinking.

        Args:
            messages: List of message dictionaries

        Returns:
            Generated response
        """
        try:
            # Extract the query from the messages
            query = self._extract_query(messages)

            # Classify query complexity
            complexity, confidence = self.classify_complexity(query)

            # Get token budget based on complexity
            min_tokens, max_tokens = self.get_token_budget(complexity)
            logger.info(f"Using token budget: {min_tokens}-{max_tokens} for {complexity} complexity")

            # Prepare messages with thinking start token
            thinking_messages = messages.copy()
            thinking_messages.append({
                "role": "assistant",
                "content": f"{self.config['start_think_token']}\n{self.config['prefill']}"
            })

            # Tokenize the messages
            tokens = self.tokenizer.apply_chat_template(
                thinking_messages,
                continue_final_message=True,
                return_tensors="pt"
            ).to(self.model.device)

            # Update token history in steering hooks
            if self.steering_hooks:
                token_ids = tokens[0].tolist()
                for hook, _ in self.steering_hooks:
                    hook.update_token_history(token_ids)
                    # Try to match with a steering vector
                    hook.try_match()

            # Generate with controlled thinking
            kv = DynamicCache()
            n_thinking_tokens = 0
            seen_end_think = False
            response_chunks = []

            while True:
                out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
                logits = out.logits[0, -1, :]

                # Check if we need to force end thinking
                force_end = (n_thinking_tokens >= max_tokens or
                             self.thought_count >= self.config["max_thoughts"])

                if force_end and not seen_end_think:
                    logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
                    next_token = self.end_think_token
                    response_chunks.append(self.tokenizer.decode([next_token]))
                    seen_end_think = True
                    tokens = torch.tensor([[next_token]]).to(tokens.device)
                    continue
                else:
                    next_token = torch.multinomial(
                        torch.softmax(logits, dim=-1), 1
                    ).item()

                kv = out.past_key_values
                next_str = self.tokenizer.decode([next_token])

                # Update steering hooks with new token
                if self.steering_hooks:
                    for hook, _ in self.steering_hooks:
                        hook.update_token_history([next_token])

                # Check if this is a thought-switching token (only if not in conclusion phase)
                if not seen_end_think and self.is_thought_switch(next_token):
                    self.thought_count += 1
                    logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
                    self.current_sequence = []

                # Handle natural end think token
                if next_token == self.end_think_token:
                    seen_end_think = True
                    logger.debug("Found end think token")

                    # If we haven't reached minimum tokens, continue with a thought transition
                    if n_thinking_tokens < min_tokens:
                        replacement = random.choice(self.config["thought_switch_tokens"])
                        logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
                        response_chunks.append(replacement)
                        replacement_tokens = self.tokenizer.encode(replacement)
                        n_thinking_tokens += len(replacement_tokens)
                        tokens = torch.tensor([replacement_tokens]).to(tokens.device)
                        self.thought_count += 1
                        seen_end_think = False
                        continue

                # Handle EOS token
                if next_token == self.model.config.eos_token_id:
                    logger.debug("Found EOS token")
                    if seen_end_think:
                        logger.debug("Reached EOS after end think token - stopping generation")
                        response_chunks.append(next_str)
                        break
                    elif n_thinking_tokens < min_tokens:
                        # Continue with a thought transition if under minimum tokens
                        replacement = random.choice(self.config["thought_switch_tokens"])
                        logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
                        response_chunks.append(replacement)
                        replacement_tokens = self.tokenizer.encode(replacement)
                        n_thinking_tokens += len(replacement_tokens)
                        tokens = torch.tensor([replacement_tokens]).to(tokens.device)
                        self.thought_count += 1
                        continue
                    else:
                        # Force end think token and continue generating for a natural conclusion
                        logger.debug("Reached EOS without end think token - adding end token and continuing generation")
                        response_chunks.append(self.tokenizer.decode([self.end_think_token]))
                        tokens = torch.tensor([[self.end_think_token]]).to(tokens.device)
                        seen_end_think = True
                        continue

                # Normal token processing
                response_chunks.append(next_str)
                if not seen_end_think:
                    n_thinking_tokens += 1
                tokens = torch.tensor([[next_token]]).to(tokens.device)

            # Clean up steering hooks
            self._cleanup_steering()

            # Join all chunks and add framing tokens
            response = "".join(response_chunks)
            full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"

            logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}")
            return full_response

        except Exception as e:
            # Clean up steering hooks in case of error
            self._cleanup_steering()
            logger.error(f"Error in AutoThink processing: {str(e)}")
            raise

    def _extract_query(self, messages: List[Dict[str, str]]) -> str:
        """
        Extract the query from messages for classification.

        Args:
            messages: List of message dictionaries

        Returns:
            Extracted query string
        """
        # Get the last user message
        user_messages = [m["content"] for m in messages if m["role"] == "user"]

        if user_messages:
            return user_messages[-1]

        # Fallback to concatenated messages
        return " ".join(m["content"] for m in messages)
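As an aside, the `is_thought_switch` logic above is plain suffix matching over recent token IDs. A self-contained toy version (with made-up token IDs, not tied to any real tokenizer) behaves like this:

```python
# Toy illustration of the suffix matching used for thought-switch detection.
switch_sequences = [[101, 7], [5, 9, 3]]   # pretend encodings of "Wait," and "Alternatively,"
max_len = max(len(s) for s in switch_sequences)
current: list = []

def is_thought_switch(token: int) -> bool:
    current.append(token)
    del current[:-max_len]                 # keep only the longest useful suffix
    return any(current[-len(s):] == s for s in switch_sequences)

assert not is_thought_switch(5)
assert not is_thought_switch(9)
assert is_thought_switch(3)                # completes [5, 9, 3]
```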
603  optillm/autothink/steering.py  Normal file
@@ -0,0 +1,603 @@

"""
Steering vector manager for AutoThink.

This module provides functionality to load and apply steering vectors
from Hugging Face datasets during inference.
"""

import torch
import logging
import random
import json
import datasets
from typing import Dict, List, Any, Tuple, Optional, Union
from collections import defaultdict

logger = logging.getLogger(__name__)


class SteeringVectorManager:
    """
    Manager for loading and applying steering vectors from a dataset.
    """

    def __init__(
        self,
        dataset_name: str,
        target_layer: int = 19,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None
    ):
        """
        Initialize the steering vector manager.

        Args:
            dataset_name: Name of the HuggingFace dataset containing steering vectors
            target_layer: Target layer for applying steering vectors
            cache_dir: Directory for caching the dataset
            device: Device to use for tensors
        """
        self.dataset_name = dataset_name
        self.target_layer = target_layer
        self.cache_dir = cache_dir
        self.device = device or (
            "cuda" if torch.cuda.is_available() else
            "mps" if torch.backends.mps.is_available() else
            "cpu"
        )

        # Storage for steering vectors
        self.steering_vectors = []
        self.pattern_to_vectors = {}
        self.tokenized_contexts = {}

        # Default steering strengths
        self.default_strength = 2.0
        self.pattern_strengths = {
            "depth_and_thoroughness": 2.5,
            "numerical_accuracy": 2.0,
            "self_correction": 3.0,
            "exploration": 2.0,
            "organization": 1.5,
            "unknown": 1.0
        }

        # If a dataset is provided, load it
        if dataset_name:
            self.load_dataset()

    def load_dataset(self):
        """Load steering vectors from the HuggingFace dataset."""
        try:
            logger.info(f"Loading steering vectors from dataset: {self.dataset_name}")

            # Load the dataset
            dataset = datasets.load_dataset(self.dataset_name, cache_dir=self.cache_dir)

            # Get the main split (usually 'train')
            main_split = list(dataset.keys())[0]
            vector_data = dataset[main_split]

            # Load each item as a steering vector
            for item in vector_data:
                # Convert dataset item to proper format
                vector = self._process_dataset_item(item)
                if vector:
                    self.steering_vectors.append(vector)

                    # Group by reasoning pattern
                    pattern = vector.get("reasoning_pattern", "unknown")
                    if pattern not in self.pattern_to_vectors:
                        self.pattern_to_vectors[pattern] = []
                    self.pattern_to_vectors[pattern].append(vector)

            logger.info(f"Loaded {len(self.steering_vectors)} steering vectors")
            logger.info(f"Found {len(self.pattern_to_vectors)} reasoning patterns: {list(self.pattern_to_vectors.keys())}")

            # Log the first vector for debugging
            if self.steering_vectors:
                first_vector = self.steering_vectors[0]
                logger.info(f"First vector sample - pattern: {first_vector.get('reasoning_pattern', 'missing')}")
                if 'pivot_context' in first_vector:
                    context_len = len(first_vector['pivot_context'])
                    logger.info(f"First vector pivot_context length: {context_len}")

        except Exception as e:
            logger.error(f"Error loading steering vectors: {e}")
            self.steering_vectors = []
            self.pattern_to_vectors = {}

    def _process_dataset_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a dataset item into a steering vector.

        Args:
            item: Dataset item

        Returns:
            Processed steering vector or None if invalid
        """
        try:
            # Check if item has the required fields
            required_fields = ["pivot_context", "steering_vector", "reasoning_pattern"]
            if not all(field in item for field in required_fields):
                return None

            # Convert steering_vector to a proper format if it's a string or list
            steering_vector = item["steering_vector"]
            if isinstance(steering_vector, str):
                # Try to parse JSON string
                try:
                    steering_vector = json.loads(steering_vector)
                except json.JSONDecodeError:
                    # Try comma-separated format
                    steering_vector = [float(x) for x in steering_vector.strip("[]").split(",")]

            # Ensure we have a proper list
            if not isinstance(steering_vector, list):
                logger.warning(f"Invalid steering vector format: {type(steering_vector)}")
                return None

            # Create the steering vector dictionary
            vector = {
                "pivot_context": item["pivot_context"],
                "pivot_token": item.get("pivot_token", ""),
                "pivot_token_id": item.get("pivot_token_id", -1),
                "prob_before": item.get("prob_before", 0.0),
                "prob_after": item.get("prob_after", 0.0),
                "prob_delta": item.get("prob_delta", 0.0),
                "model_id": item.get("model_id", ""),
                "task_type": item.get("task_type", "unknown"),
                "steering_vector": steering_vector,
                "cluster_id": item.get("cluster_id", -1),
                "reasoning_pattern": item.get("reasoning_pattern", "unknown"),
                "cluster_vector": item.get("cluster_vector", steering_vector),
                "steering_layer": item.get("steering_layer", self.target_layer),
            }

            return vector

        except Exception as e:
            logger.error(f"Error processing dataset item: {e}")
            return None

    def create_tokenized_contexts(self, tokenizer):
        """
        Pre-tokenize context patterns for efficient matching during generation.

        Args:
            tokenizer: Tokenizer for encoding contexts
        """
        # Get configurations
        max_pts_tokens = 256  # Maximum tokens to store for matching

        count = 0
        for vector in self.steering_vectors:
            # Get the context
            context = vector.get("pivot_context", "")
            if not context:
                continue

            # Pre-tokenize the context for faster matching
            tokenized_context = tokenizer.encode(context, add_special_tokens=False)

            # Keep only up to max_pts_tokens
            if len(tokenized_context) > max_pts_tokens:
                tokenized_context = tokenized_context[-max_pts_tokens:]

            # Store the tokenized context with its vector
            tuple_key = tuple(tokenized_context)
            self.tokenized_contexts[tuple_key] = vector

            # Store additional shorter versions for partial matching
            for suffix_len in [4, 8, 12]:
                if len(tokenized_context) > suffix_len:
                    suffix = tokenized_context[-suffix_len:]
                    suffix_tuple = tuple(suffix)
                    if suffix_tuple not in self.tokenized_contexts:
                        self.tokenized_contexts[suffix_tuple] = vector

            count += 1

        # Log statistics
        logger.info(f"Pre-tokenized {count} contexts into {len(self.tokenized_contexts)} token patterns")

        # Count patterns by length for debugging
        length_counts = {}
        for key in self.tokenized_contexts.keys():
            length = len(key)
            if length not in length_counts:
                length_counts[length] = 0
            length_counts[length] += 1

        logger.info(f"Token pattern length distribution: {sorted(length_counts.items())}")

    def get_steering_strength(self, pattern: str) -> float:
        """
        Get the steering strength for a specific pattern.

        Args:
            pattern: The reasoning pattern

        Returns:
            The steering strength
        """
        return self.pattern_strengths.get(pattern, self.default_strength)

    def set_steering_strength(self, pattern: str, strength: float):
        """
        Set the steering strength for a specific pattern.

        Args:
            pattern: The reasoning pattern
            strength: The steering strength
        """
        self.pattern_strengths[pattern] = strength
        logger.info(f"Set strength for {pattern} to {strength}")

    def get_pattern_vectors(self, pattern: str) -> List[Dict[str, Any]]:
        """
        Get all steering vectors for a specific reasoning pattern.

        Args:
            pattern: The reasoning pattern

        Returns:
            List of steering vectors
        """
        return self.pattern_to_vectors.get(pattern, [])


class SteeringHook:
    """Hook for applying steering vectors during generation."""

    def __init__(self, manager: SteeringVectorManager, layer_num: int, tokenizer=None):
        """
        Initialize the steering hook.

        Args:
            manager: The steering vector manager
            layer_num: The layer number to apply steering to
            tokenizer: Tokenizer for token-based matching
        """
        self.manager = manager
        self.layer_num = layer_num
        self.tokenizer = tokenizer

        # For token-based matching
        self.token_history = []  # Store token IDs for matching
        self.max_history = 256  # Maximum tokens to keep in history

        # State tracking
        self.match_found = False
        self.current_vector = None
        self.last_pattern = None

        # Single pattern for entire request
        self.active_pattern = None  # Currently active pattern
        self.generation_started = False

        logger.info(f"Initialized hook for layer {layer_num}")

    def __call__(self, module, input_tensors, output):
        """
        Apply steering to the output of a layer.

        Args:
            module: The module being hooked
            input_tensors: The input tensors
            output: The output tensor

        Returns:
            Modified output tensor
        """
        try:
            # Skip if no active pattern is set
            if not self.active_pattern:
                return output

            # Apply steering vector if available
            if self.current_vector is not None:
                # Get the appropriate steering strength
                pattern = self.current_vector.get("reasoning_pattern", "unknown")
                strength = self.manager.get_steering_strength(pattern)

                # Keep strength within safe bounds
                safe_strength = min(max(strength, 0.1), 2.0)

                # Log when pattern changes
                if pattern != self.last_pattern:
                    logger.info(f"Switching to {pattern} reasoning pattern with strength {safe_strength}")
                    self.last_pattern = pattern

                # Apply the steering vector
                try:
                    if isinstance(output, tuple):
                        # Some models return a tuple
                        hidden_states = output[0]
                        modified_hidden_states = self._apply_steering_vector(hidden_states, self.current_vector, safe_strength)

                        # Validate the result
                        if modified_hidden_states.shape == hidden_states.shape:
                            return (modified_hidden_states,) + output[1:]
                        else:
                            logger.error(f"Modified hidden states have wrong shape. Expected {hidden_states.shape}, got {modified_hidden_states.shape}")
                            return output
                    else:
                        # Direct tensor output
                        return self._apply_steering_vector(output, self.current_vector, safe_strength)

                except Exception as e:
                    logger.error(f"Error applying steering: {e}")
                    return output

            return output
        except Exception as e:
            logger.error(f"Critical error in hook: {e}")
            return output

    def _apply_steering_vector(self, hidden_states: torch.Tensor,
                               steering_vector: Dict[str, Any],
                               scaling_factor: float = 2.0) -> torch.Tensor:
        """
        Apply a steering vector to hidden states.

        Args:
            hidden_states: The hidden states tensor
            steering_vector: Dictionary with steering vector data
            scaling_factor: Factor to scale the steering vector by

        Returns:
            Modified hidden states tensor
        """
        try:
            # Make a deep clone
            hidden_states_clone = hidden_states.clone().detach()

            # Check what kind of vector we're using
            vector_data = None
            if "steering_vector" in steering_vector:
                vector_data = steering_vector["steering_vector"]
                vector_type = "steering_vector"
            elif "cluster_vector" in steering_vector:
                vector_data = steering_vector["cluster_vector"]
                vector_type = "cluster_vector"
            else:
                logger.warning("No valid vector found in steering data")
                return hidden_states

            # Convert vector to tensor
            vector = torch.tensor(vector_data,
                                  dtype=hidden_states.dtype,
                                  device=hidden_states.device)

            # Log vector info
            pattern = steering_vector.get("reasoning_pattern", "unknown")
            logger.debug(f"Applying {vector_type} for pattern '{pattern}' with scaling {scaling_factor}")

            # Apply scaling based on prob_delta if available
            if "prob_delta" in steering_vector:
                prob_delta = abs(steering_vector["prob_delta"])
                prob_delta_capped = min(max(prob_delta, 0.1), 2.0)
                scaling_factor *= prob_delta_capped

            # Check if the token is positive or negative
            is_positive = steering_vector.get("is_positive", True)

            # Verify shapes are compatible
            hs_shape = hidden_states.shape
            vector_shape = vector.shape

            if len(vector_shape) != 1 or vector_shape[0] != hs_shape[-1]:
                logger.error(f"Shape mismatch - hidden_states: {hs_shape}, vector: {vector_shape}")
                return hidden_states

            # Bound scaling factor for safety
            safe_scaling = min(max(scaling_factor, 0.0), 3.0)

            # Apply steering
            if len(hs_shape) >= 3 and hs_shape[0] > 0 and hs_shape[1] > 0:
                # Apply to the last token's representation
                if is_positive:
                    # Normalize vector to prevent numerical instability
                    vector_norm = torch.nn.functional.normalize(vector, dim=0)
                    hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] + safe_scaling * vector_norm
                else:
                    vector_norm = torch.nn.functional.normalize(vector, dim=0)
                    hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] - safe_scaling * vector_norm

                # Check for NaN or inf values
                if torch.isnan(hidden_states_clone).any() or torch.isinf(hidden_states_clone).any():
                    logger.error("NaN or inf values detected after applying vector, reverting to original")
                    return hidden_states
            else:
                logger.error(f"Hidden states shape not suitable for steering: {hs_shape}")
                return hidden_states

            return hidden_states_clone
        except Exception as e:
            logger.error(f"Unexpected error applying steering vector: {e}")
            return hidden_states

    def update_token_history(self, new_tokens: List[int]):
        """
        Update the token history with new tokens.

        Args:
            new_tokens: New token IDs to add
        """
        # Add to token history
        self.token_history.extend(new_tokens)

        # Trim history if needed
        if len(self.token_history) > self.max_history:
            self.token_history = self.token_history[-self.max_history:]

        # Log token updates periodically
        if random.random() < 0.01:
            logger.debug(f"Token history updated, now has {len(self.token_history)} tokens")

    def try_match(self) -> bool:
        """
        Try to match the current context with a steering vector.

        Returns:
            Boolean indicating if a match was found
        """
        # If we already have an active pattern, don't try to match again
        if self.generation_started and self.active_pattern:
            return False

        # Only attempt pattern matching at the beginning of generation
        self.generation_started = True

        # Try token-based matching
        match_result = self._try_token_match()

        # If a match is found, set this as the permanent pattern for this generation
        if match_result and self.current_vector:
            new_pattern = self.current_vector.get("reasoning_pattern", "unknown")
            self.active_pattern = new_pattern
            logger.info(f"Selected '{new_pattern}' pattern for this request")

        return match_result

    def _try_token_match(self) -> bool:
        """
        Try to match using token-based context.

        Returns:
            Boolean indicating if a match was found
        """
        # Ensure we have enough tokens
        if len(self.token_history) < 4:
            return False

        # Track best match
        best_match = {
            'length': 0,
            'vector': None,
            'is_partial': True
        }

        # Check for matches in tokenized contexts
        for tokenized_context, vector in self.manager.tokenized_contexts.items():
            token_list = list(tokenized_context)
            token_len = len(token_list)

            # Try partial matching for shorter contexts
            if len(self.token_history) < token_len:
                # Only try partial matching if we have enough context tokens
                if len(self.token_history) >= 4:
                    # Calculate how many tokens to match
                    match_len = min(len(self.token_history), max(4, token_len // 2))
                    # Try to match the end of the token sequence
                    if self.token_history[-match_len:] == token_list[-match_len:]:
                        # Track this match - prefer longer matches
                        if match_len > best_match['length']:
                            best_match = {
                                'length': match_len,
                                'vector': vector,
                                'is_partial': True,
                                'match_len': match_len,
                                'token_len': token_len
                            }
            else:
                # Full matching when we have enough tokens
                if self.token_history[-token_len:] == token_list:
                    # Track this match - full matches are preferred
                    if token_len >= best_match['length']:
                        best_match = {
                            'length': token_len,
                            'vector': vector,
                            'is_partial': False,
                            'match_len': token_len,
                            'token_len': token_len
                        }

        # Apply best match if found
        if best_match['vector'] is not None:
            match_type = "PARTIAL" if best_match['is_partial'] else "FULL"
            self.match_found = True
            self.current_vector = best_match['vector']
            pattern = best_match['vector'].get("reasoning_pattern", "unknown")
            logger.info(f"Found {match_type} token match ({best_match['match_len']}/{best_match['token_len']} tokens) for {pattern} pattern")
            return True

        return False

    def reset(self):
        """Reset the hook state."""
        self.match_found = False
        self.current_vector = None
        self.token_history = []
        self.last_pattern = None
        self.active_pattern = None
        self.generation_started = False


def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None) -> List[Tuple]:
    """
    Install steering hooks on a model.

    Args:
        model: The model to install hooks on
        manager: The steering vector manager
        tokenizer: Tokenizer for token-based matching

    Returns:
        List of installed hooks
    """
    hooks = []

    # Target layer is specified in the manager
    layer_num = manager.target_layer
    logger.info(f"Attempting to install hook on layer {layer_num}")

    # First, log model structure to help with debugging
    model_type = type(model).__name__
    logger.info(f"Model type is {model_type}")

    # Find the appropriate module - depends on model architecture
    module = None
    if hasattr(model, 'transformer'):
        logger.info("Model has 'transformer' attribute")
        if hasattr(model.transformer, 'h') and layer_num < len(model.transformer.h):
            module = model.transformer.h[layer_num]
            logger.info(f"Using transformer.h[{layer_num}]")
    elif hasattr(model, 'model'):
        logger.info("Model has 'model' attribute")
        if hasattr(model.model, 'layers') and layer_num < len(model.model.layers):
            module = model.model.layers[layer_num]
            logger.info(f"Using model.layers[{layer_num}]")
        elif hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers') and layer_num < len(model.model.decoder.layers):
            module = model.model.decoder.layers[layer_num]
            logger.info(f"Using model.decoder.layers[{layer_num}]")
    elif hasattr(model, 'layers') and layer_num < len(model.layers):
        module = model.layers[layer_num]
        logger.info(f"Using layers[{layer_num}]")

    if module is None:
        logger.error(f"Could not find appropriate module for layer {layer_num}")
        logger.error("Model structure not compatible with current hook installation logic")
        return []

    # Create and register hook
    hook = SteeringHook(manager, layer_num, tokenizer)
    handle = module.register_forward_hook(hook)

    # Return both hook object and handle for later removal
    hooks.append((hook, handle))

    logger.info(f"Installed hook on layer {layer_num} successfully")

    return hooks


def remove_steering_hooks(hooks):
    """
    Remove steering hooks from a model.

    Args:
        hooks: List of (hook, handle) tuples
    """
    for _, handle in hooks:
        handle.remove()

    logger.info(f"Removed {len(hooks)} hooks")
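For reference, a sketch of how these helpers fit together outside the processor (assumes `model` and `tokenizer` are an already-loaded Hugging Face causal LM and its tokenizer; all other names are as defined in the file above):

```python
from optillm.autothink.steering import (
    SteeringVectorManager, install_steering_hooks, remove_steering_hooks
)

manager = SteeringVectorManager(
    dataset_name="codelion/Qwen3-0.6B-pts-steering-vectors",
    target_layer=19,
)
manager.create_tokenized_contexts(tokenizer)

hooks = install_steering_hooks(model, manager, tokenizer)
try:
    # run generation here; each hook matches the prompt against stored pivot
    # contexts and then steers layer 19 toward the selected reasoning pattern
    ...
finally:
    remove_steering_hooks(hooks)
```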
@@ -20,6 +20,7 @@ import traceback

from optillm.cot_decoding import cot_decode
from optillm.entropy_decoding import entropy_decode
from optillm.thinkdeeper import thinkdeeper_decode
from optillm.autothink import autothink_decode

# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -1467,6 +1468,34 @@ class InferenceClient:

        responses = [result]
        logprobs_results = [None]
        completion_tokens = len(pipeline.tokenizer.encode(result))
    elif decoding == "autothink":
        # Get steering dataset configuration
        steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
        target_layer = kwargs.get("target_layer", 19)

        # Prepare AutoThink configuration
        autothink_config = {
            "steering_dataset": steering_dataset,
            "target_layer": target_layer,
            "pattern_strengths": kwargs.get("pattern_strengths", {
                "depth_and_thoroughness": 2.5,
                "numerical_accuracy": 2.0,
                "self_correction": 3.0,
                "exploration": 2.0,
                "organization": 1.5
            })
        }

        # Process with AutoThink
        result = autothink_decode(
            pipeline.current_model,
            pipeline.tokenizer,
            messages,
            autothink_config
        )
        responses = [result]
        logprobs_results = [None]
        completion_tokens = len(pipeline.tokenizer.encode(result))
    else:
        raise ValueError(f"Unknown specialized decoding approach: {decoding}")
@@ -27,4 +27,5 @@ spacy<3.8.0

cerebras_cloud_sdk
outlines[transformers]
sentencepiece
adaptive-classifier
mcp