import os
import torch
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass
from collections import OrderedDict, defaultdict
import torch.nn.functional as F
import torch.nn as nn
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from peft import PeftModel, PeftConfig
import bitsandbytes as bnb
from scipy.stats import entropy
from functools import lru_cache
import time
import threading
import traceback

from optillm.cot_decoding import cot_decode
from optillm.entropy_decoding import entropy_decode
from optillm.thinkdeeper import thinkdeeper_decode
from optillm.autothink import autothink_decode

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class ModelConfig:
    base_model_id: str
    adapter_ids: Optional[List[str]] = None
    batch_size: int = 32
    max_cache_size: int = 5
    quantization_bits: int = 4
    device_preference: Optional[str] = None
    # Default generation parameters
    max_new_tokens: int = 4096
    do_sample: bool = True
    top_p: float = 0.9
    top_k: int = 50
    temperature: float = 0.7
    num_return_sequences: int = 1
    repetition_penalty: float = 1.0
    pad_token_id: Optional[int] = None
    logprobs: bool = False
    # Advanced parameters
    use_memory_efficient_attention: bool = True
    enable_prompt_caching: bool = True
    dynamic_temperature: bool = False

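# Illustrative usage sketch (not executed on import): the model ID below is a
# placeholder; any Hugging Face text-generation checkpoint can be used instead.
#
#   config = ModelConfig(
#       base_model_id="HuggingFaceTB/SmolLM-135M-Instruct",
#       quantization_bits=0,
#       max_new_tokens=512,
#       temperature=0.7,
#   )
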
@dataclass
|
|
class LogProbsResult:
|
|
"""Container for logprobs calculation results"""
|
|
tokens: List[str]
|
|
token_logprobs: List[float]
|
|
top_logprobs: List[Dict[str, float]]
|
|
bytes_per_token: List[List[int]]
|
|
|
|
class LogProbsCalculator:
|
|
"""Handles calculation of log probabilities for generated tokens"""
|
|
|
|
def __init__(self, tokenizer, model):
|
|
self.tokenizer = tokenizer
|
|
self.model = model
|
|
|
|
def _get_bytes_for_token(self, token: str) -> List[int]:
|
|
"""Get UTF-8 bytes for a token"""
|
|
try:
|
|
return list(token.encode('utf-8'))
|
|
except UnicodeEncodeError:
|
|
return []
|
|
|
|
def _get_top_alternatives(
|
|
self,
|
|
logits: torch.Tensor,
|
|
actual_token_id: int,
|
|
num_alternatives: int
|
|
) -> Dict[str, float]:
|
|
"""Calculate top alternative tokens and their logprobs"""
|
|
probs = F.softmax(logits, dim=-1)
|
|
logprobs = torch.log(probs)
|
|
|
|
# Get top tokens excluding the actual token
|
|
top_values, top_indices = torch.topk(logprobs, k=num_alternatives + 1)
|
|
|
|
alternatives = {}
|
|
for value, idx in zip(top_values, top_indices):
|
|
token = self.tokenizer.decode([idx])
|
|
if idx != actual_token_id: # Skip the actual token
|
|
alternatives[token] = value.item()
|
|
if len(alternatives) >= num_alternatives:
|
|
break
|
|
|
|
return alternatives
|
|
|
|
def calculate_logprobs(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
generated_ids: torch.Tensor,
|
|
attention_mask: torch.Tensor,
|
|
num_alternatives: int = 5
|
|
) -> LogProbsResult:
|
|
"""Calculate log probabilities for a sequence of tokens"""
|
|
self.model.eval()
|
|
|
|
with torch.no_grad():
|
|
# Get model outputs for the entire sequence
|
|
outputs = self.model(
|
|
input_ids=input_ids,
|
|
attention_mask=attention_mask,
|
|
return_dict=True
|
|
)
|
|
logits = outputs.logits
|
|
|
|
# Calculate softmax and log probabilities
|
|
probs = F.softmax(logits, dim=-1)
|
|
logprobs = torch.log(probs)
|
|
|
|
# Process each position
|
|
all_tokens = []
|
|
all_token_logprobs = []
|
|
all_top_logprobs = []
|
|
all_bytes = []
|
|
|
|
sequence_length = generated_ids.shape[-1]
|
|
|
|
for pos in range(sequence_length - 1): # -1 because we look at next token
|
|
next_token_id = generated_ids[0, pos + 1]
|
|
current_logits = logits[0, pos]
|
|
|
|
# Get token and its logprob
|
|
token = self.tokenizer.decode([next_token_id])
|
|
token_logprob = logprobs[0, pos, next_token_id].item()
|
|
|
|
# Get top alternative tokens
|
|
top_logprobs = self._get_top_alternatives(
|
|
current_logits,
|
|
next_token_id,
|
|
num_alternatives
|
|
)
|
|
|
|
# Get bytes for token
|
|
token_bytes = self._get_bytes_for_token(token)
|
|
|
|
all_tokens.append(token)
|
|
all_token_logprobs.append(token_logprob)
|
|
all_top_logprobs.append(top_logprobs)
|
|
all_bytes.append(token_bytes)
|
|
|
|
# Add None for the last token
|
|
all_tokens.append(self.tokenizer.decode([generated_ids[0, -1]]))
|
|
all_token_logprobs.append(None)
|
|
all_top_logprobs.append(None)
|
|
all_bytes.append(self._get_bytes_for_token(all_tokens[-1]))
|
|
|
|
return LogProbsResult(
|
|
tokens=all_tokens,
|
|
token_logprobs=all_token_logprobs,
|
|
top_logprobs=all_top_logprobs,
|
|
bytes_per_token=all_bytes
|
|
)
|
|
|
|
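# Illustrative usage sketch (not executed on import): assumes `model` and `tokenizer`
# are already loaded and `sequence` is a 1-D tensor of generated token IDs, mirroring
# how InferencePipeline.generate() calls LogProbsCalculator.
#
#   calculator = LogProbsCalculator(tokenizer, model)
#   result = calculator.calculate_logprobs(
#       input_ids=sequence.unsqueeze(0),
#       generated_ids=sequence.unsqueeze(0),
#       attention_mask=torch.ones_like(sequence).unsqueeze(0),
#       num_alternatives=5,
#   )
#   print(result.tokens[:3], result.token_logprobs[:3])
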
class MemoryEfficientAttention(nn.Module):
|
|
"""
|
|
Memory-efficient attention using linear attention mechanism.
|
|
Supports automatic fallback to optimized implementations when available.
|
|
"""
|
|
def __init__(
|
|
self,
|
|
hidden_size: int,
|
|
num_attention_heads: int,
|
|
dropout: float = 0.1,
|
|
):
|
|
super().__init__()
|
|
self.hidden_size = hidden_size
|
|
self.num_attention_heads = num_attention_heads
|
|
self.head_dim = hidden_size // num_attention_heads
|
|
self.scale = self.head_dim ** -0.5
|
|
|
|
# Standard projections with bias
|
|
self.q_proj = nn.Linear(hidden_size, hidden_size)
|
|
self.k_proj = nn.Linear(hidden_size, hidden_size)
|
|
self.v_proj = nn.Linear(hidden_size, hidden_size)
|
|
self.o_proj = nn.Linear(hidden_size, hidden_size)
|
|
|
|
self.dropout = nn.Dropout(dropout)
|
|
|
|
# Try to import optimized attention implementations
|
|
self.optimized_attention = None
|
|
|
|
        # Try Flash Attention
        try:
            from flash_attn import flash_attn_func
            self.optimized_attention = flash_attn_func
            logger.info("Using Flash Attention")
        except ImportError:
            pass

        # Try xFormers
        if self.optimized_attention is None:
            try:
                import xformers.ops as xops
                self.optimized_attention = xops.memory_efficient_attention
                logger.info("Using xFormers attention")
            except ImportError:
                pass
|
|
|
|
def _linear_attention(
|
|
self,
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
causal: bool = False,
|
|
) -> torch.Tensor:
|
|
"""Implements linear attention for more memory efficiency"""
|
|
# Scale query
|
|
q = q * self.scale
|
|
|
|
# Apply mask if provided
|
|
if attention_mask is not None:
|
|
# Convert boolean mask to float mask if needed
|
|
if attention_mask.dtype == torch.bool:
|
|
attention_mask = attention_mask.float()
|
|
k = k * attention_mask.unsqueeze(-1)
|
|
|
|
if causal:
|
|
# Handle causal attention
|
|
batch_size, num_heads, seq_length, head_dim = q.shape
|
|
positions = torch.arange(seq_length, device=q.device)
|
|
causal_mask = positions.view(1, 1, -1, 1) <= positions.view(1, 1, 1, -1)
|
|
k = k * causal_mask.float()
|
|
|
|
# Linear attention computation
|
|
context = torch.matmul(k.transpose(-2, -1), v)
|
|
out = torch.matmul(q, context)
|
|
|
|
return out
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
causal: bool = False,
|
|
) -> torch.Tensor:
|
|
batch_size, seq_length, _ = hidden_states.size()
|
|
|
|
# Project to q, k, v
|
|
q = self.q_proj(hidden_states)
|
|
k = self.k_proj(hidden_states)
|
|
v = self.v_proj(hidden_states)
|
|
|
|
# Reshape for attention
|
|
q = q.view(batch_size, seq_length, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
|
k = k.view(batch_size, seq_length, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
|
v = v.view(batch_size, seq_length, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
|
|
|
# Try using optimized attention if available
|
|
if self.optimized_attention is not None and hidden_states.device.type == 'cuda':
|
|
# Handle attention mask for optimized implementations
|
|
if attention_mask is not None:
|
|
if attention_mask.dtype != torch.bool:
|
|
attention_mask = attention_mask > 0
|
|
attention_mask = attention_mask.view(batch_size, 1, 1, seq_length)
|
|
|
|
try:
|
|
attn_output = self.optimized_attention(
|
|
q, k, v,
|
|
attn_mask=attention_mask,
|
|
causal=causal,
|
|
scale=self.scale
|
|
)
|
|
except Exception as e:
|
|
print(f"Optimized attention failed, falling back to linear attention: {e}")
|
|
attn_output = self._linear_attention(q, k, v, attention_mask, causal)
|
|
else:
|
|
# Use linear attention for CPU/MPS or when optimized attention is not available
|
|
attn_output = self._linear_attention(q, k, v, attention_mask, causal)
|
|
|
|
# Reshape and project back
|
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
attn_output = attn_output.view(batch_size, seq_length, self.hidden_size)
|
|
attn_output = self.o_proj(attn_output)
|
|
|
|
return attn_output
|
|
|
|
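# Illustrative usage sketch (not executed on import): runs the drop-in attention
# module above on random activations; the sizes are arbitrary examples.
#
#   attn = MemoryEfficientAttention(hidden_size=768, num_attention_heads=12)
#   hidden = torch.randn(2, 16, 768)        # (batch, seq_len, hidden_size)
#   out = attn(hidden, causal=True)         # -> (2, 16, 768)
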
class PromptCache:
|
|
"""Advanced caching system for frequent prompts and responses"""
|
|
|
|
def __init__(self, max_size: int = 1000):
|
|
self.max_size = max_size
|
|
self.cache = OrderedDict()
|
|
self.prompt_stats = defaultdict(lambda: {"count": 0, "success_rate": 0.0})
|
|
|
|
@lru_cache(maxsize=128)
|
|
def _compute_prompt_signature(self, prompt: str) -> str:
|
|
"""Compute a signature for semantic similarity matching"""
|
|
# Simple but effective signature based on key content words
|
|
words = set(prompt.lower().split())
|
|
return " ".join(sorted(list(words)))
|
|
|
|
def get_cached_response(self, prompt: str, temperature: float, top_p: float) -> Optional[str]:
|
|
"""Get cached response with fuzzy matching"""
|
|
signature = self._compute_prompt_signature(prompt)
|
|
|
|
if signature in self.cache:
|
|
cached_item = self.cache[signature]
|
|
if abs(cached_item["temperature"] - temperature) < 0.1 and abs(cached_item["top_p"] - top_p) < 0.1:
|
|
self.prompt_stats[signature]["count"] += 1
|
|
return cached_item["response"]
|
|
|
|
return None
|
|
|
|
def add_to_cache(self, prompt: str, response: str, temperature: float, top_p: float):
|
|
"""Add response to cache with metadata"""
|
|
signature = self._compute_prompt_signature(prompt)
|
|
|
|
self.cache[signature] = {
|
|
"response": response,
|
|
"temperature": temperature,
|
|
"top_p": top_p,
|
|
"timestamp": torch.cuda.current_timestamp() if torch.cuda.is_available() else 0
|
|
}
|
|
|
|
if len(self.cache) > self.max_size:
|
|
self.cache.popitem(last=False)
|
|
|
|
def update_stats(self, prompt: str, success: bool):
|
|
"""Update prompt success statistics"""
|
|
signature = self._compute_prompt_signature(prompt)
|
|
stats = self.prompt_stats[signature]
|
|
stats["count"] += 1
|
|
stats["success_rate"] = (stats["success_rate"] * (stats["count"] - 1) + float(success)) / stats["count"]
|
|
|
|
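# Illustrative usage sketch (not executed on import): a simple round trip through
# the prompt cache; the prompt and response strings are placeholders.
#
#   cache = PromptCache(max_size=100)
#   cache.add_to_cache("What is 2+2?", "4", temperature=0.7, top_p=0.9)
#   cache.get_cached_response("What is 2+2?", temperature=0.7, top_p=0.9)   # -> "4"
#   cache.get_cached_response("What is 2+2?", temperature=1.5, top_p=0.9)   # -> None (params differ)
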
class DynamicTemperature:
|
|
"""Implements dynamic temperature scaling based on input characteristics"""
|
|
|
|
def __init__(self):
|
|
self.token_entropy_cache = {}
|
|
|
|
def _compute_token_entropy(self, tokens: List[int]) -> float:
|
|
"""Compute token distribution entropy"""
|
|
token_counts = np.bincount(tokens)
|
|
probabilities = token_counts / len(tokens)
|
|
return entropy(probabilities)
|
|
|
|
def get_optimal_temperature(self, prompt: str, tokenizer: AutoTokenizer, base_temperature: float) -> float:
|
|
"""Calculate optimal temperature based on prompt characteristics"""
|
|
tokens = tokenizer.encode(prompt)
|
|
|
|
# Calculate entropy-based scaling
|
|
token_entropy = self._compute_token_entropy(tokens)
|
|
|
|
# Scale temperature based on prompt entropy and length
|
|
length_factor = np.clip(len(tokens) / 100, 0.5, 2.0)
|
|
entropy_factor = np.clip(token_entropy / 4.0, 0.5, 1.5)
|
|
|
|
optimal_temperature = base_temperature * length_factor * entropy_factor
|
|
return np.clip(optimal_temperature, 0.1, 2.0)
|
|
|
|
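# Illustrative usage sketch (not executed on import): assumes a `tokenizer` is
# already loaded; the prompt is a placeholder.
#
#   scaler = DynamicTemperature()
#   temp = scaler.get_optimal_temperature("Summarize this article...", tokenizer, base_temperature=0.7)
#   # Returns a value clipped to [0.1, 2.0] based on prompt length and token entropy.
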
class CacheManager:
|
|
"""
|
|
Singleton cache manager for models and tokenizers.
|
|
Thread-safe but minimizes lock contention.
|
|
"""
|
|
_instance = None
|
|
_lock = threading.Lock()
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
if cls._instance is None:
|
|
with cls._lock:
|
|
if cls._instance is None:
|
|
instance = super().__new__(cls)
|
|
instance._initialized = False
|
|
cls._instance = instance
|
|
return cls._instance
|
|
|
|
    def __init__(self, max_size: int = 5):
        if self._initialized:
            return

        with self._lock:
            if not self._initialized:
                logger.info("Initializing CacheManager singleton")
                self.max_size = max_size
                self.model_cache = OrderedDict()
                self.tokenizer_cache = OrderedDict()
                self.adapter_cache = OrderedDict()
                self.model_adapter_map = {}  # Maps model ID to list of loaded adapter IDs
                self.cache_stats = defaultdict(lambda: {"hits": 0, "misses": 0})
                self.prompt_cache = PromptCache()  # Shared prompt cache used by InferencePipeline.process_batch
                self._initialized = True
                logger.info("CacheManager singleton initialized")
|
|
|
|
def get_or_load_model(self, model_key: str, loader_fn) -> Tuple[Any, Any]:
|
|
"""Get or load model and tokenizer with minimal locking."""
|
|
cached_model = cached_tokenizer = None
|
|
cache_hit = False
|
|
|
|
with self._lock:
|
|
if model_key in self.model_cache and model_key in self.tokenizer_cache:
|
|
cached_model = self.model_cache[model_key]
|
|
cached_tokenizer = self.tokenizer_cache[model_key]
|
|
self.model_cache.move_to_end(model_key)
|
|
self.tokenizer_cache.move_to_end(model_key)
|
|
self.cache_stats[model_key]["hits"] += 1
|
|
cache_hit = True
|
|
logger.debug(f"Cache hit for model: {model_key}")
|
|
|
|
if cache_hit:
|
|
return cached_model, cached_tokenizer
|
|
|
|
logger.info(f"Loading model and tokenizer: {model_key}")
|
|
model, tokenizer = loader_fn()
|
|
|
|
with self._lock:
|
|
if model_key in self.model_cache and model_key in self.tokenizer_cache:
|
|
cached_model = self.model_cache[model_key]
|
|
cached_tokenizer = self.tokenizer_cache[model_key]
|
|
self.cache_stats[model_key]["hits"] += 1
|
|
logger.debug(f"Using already cached model: {model_key}")
|
|
return cached_model, cached_tokenizer
|
|
|
|
self.model_cache[model_key] = model
|
|
self.tokenizer_cache[model_key] = tokenizer
|
|
self.cache_stats[model_key]["misses"] += 1
|
|
self.model_adapter_map[model_key] = [] # Initialize empty adapter list for new model
|
|
|
|
self._cleanup_caches()
|
|
|
|
logger.info(f"Successfully cached model and tokenizer: {model_key}")
|
|
return model, tokenizer
|
|
|
|
def get_or_load_adapter(self, model_key: str, adapter_key: str, loader_fn):
|
|
"""Get or load adapter with enhanced caching."""
|
|
cache_key = f"{model_key}_{adapter_key}"
|
|
|
|
with self._lock:
|
|
if cache_key in self.adapter_cache:
|
|
adapter = self.adapter_cache[cache_key]
|
|
self.adapter_cache.move_to_end(cache_key)
|
|
logger.debug(f"Cache hit for adapter: {cache_key}")
|
|
return adapter
|
|
|
|
adapter = loader_fn()
|
|
|
|
with self._lock:
|
|
self.adapter_cache[cache_key] = adapter
|
|
if model_key not in self.model_adapter_map:
|
|
self.model_adapter_map[model_key] = []
|
|
if adapter_key not in self.model_adapter_map[model_key]:
|
|
self.model_adapter_map[model_key].append(adapter_key)
|
|
self._cleanup_caches()
|
|
logger.info(f"Successfully cached adapter: {cache_key}")
|
|
return adapter
|
|
|
|
def get_model_adapters(self, model_key: str) -> List[str]:
|
|
"""Get list of adapter IDs loaded for a specific model."""
|
|
with self._lock:
|
|
return self.model_adapter_map.get(model_key, [])
|
|
|
|
def _cleanup_caches(self):
|
|
"""Clean up caches if they exceed max size."""
|
|
while len(self.model_cache) > self.max_size:
|
|
model_key, model = self.model_cache.popitem(last=False)
|
|
if hasattr(model, 'cpu'):
|
|
model.cpu()
|
|
# Clean up associated adapters
|
|
if model_key in self.model_adapter_map:
|
|
for adapter_id in self.model_adapter_map[model_key]:
|
|
cache_key = f"{model_key}_{adapter_id}"
|
|
if cache_key in self.adapter_cache:
|
|
self.adapter_cache.pop(cache_key)
|
|
self.model_adapter_map.pop(model_key)
|
|
|
|
while len(self.tokenizer_cache) > self.max_size:
|
|
self.tokenizer_cache.popitem(last=False)
|
|
|
|
# Cleanup orphaned adapters
|
|
valid_cache_keys = {
|
|
f"{model_key}_{adapter_id}"
|
|
for model_key, adapter_ids in self.model_adapter_map.items()
|
|
for adapter_id in adapter_ids
|
|
}
|
|
|
|
orphaned_adapters = [
|
|
key for key in self.adapter_cache.keys()
|
|
if key not in valid_cache_keys
|
|
]
|
|
|
|
for key in orphaned_adapters:
|
|
adapter = self.adapter_cache.pop(key)
|
|
if hasattr(adapter, 'cpu'):
|
|
adapter.cpu()
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
@classmethod
|
|
def get_instance(cls, max_size: int = 5) -> 'CacheManager':
|
|
"""Alternative way to get the singleton instance."""
|
|
if cls._instance is None:
|
|
return cls(max_size)
|
|
return cls._instance
|
|
|
|
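# Illustrative usage sketch (not executed on import): the loader function is a
# hypothetical example and must return a (model, tokenizer) pair, as expected by
# get_or_load_model().
#
#   cache_manager = CacheManager.get_instance(max_size=4)
#
#   def _loader():
#       tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
#       mdl = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
#       return mdl, tok
#
#   model, tokenizer = cache_manager.get_or_load_model("HuggingFaceTB/SmolLM-135M-Instruct", _loader)
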
class DeviceManager:
|
|
def __init__(self):
|
|
self.available_devices = self._detect_devices()
|
|
self.device_stats = {device: {'memory_used': 0, 'active_models': 0} for device in self.available_devices}
|
|
|
|
def _detect_devices(self) -> List[str]:
|
|
devices = ['cpu']
|
|
if torch.cuda.is_available():
|
|
devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
|
|
if torch.backends.mps.is_available():
|
|
devices.append('mps')
|
|
return devices
|
|
|
|
def get_optimal_device(self, model_size: int = 0) -> str:
|
|
if not self.available_devices:
|
|
return 'cpu'
|
|
|
|
# Prefer CUDA devices if available
|
|
cuda_devices = [d for d in self.available_devices if 'cuda' in d]
|
|
if cuda_devices:
|
|
# Find CUDA device with most free memory
|
|
max_free_memory = 0
|
|
optimal_device = cuda_devices[0]
|
|
|
|
for device in cuda_devices:
|
|
idx = int(device.split(':')[1])
|
|
free_memory = torch.cuda.get_device_properties(idx).total_memory - torch.cuda.memory_allocated(idx)
|
|
if free_memory > max_free_memory:
|
|
max_free_memory = free_memory
|
|
optimal_device = device
|
|
|
|
return optimal_device
|
|
|
|
# Fall back to MPS if available
|
|
if 'mps' in self.available_devices:
|
|
return 'mps'
|
|
|
|
return 'cpu'
|
|
|
|
def track_device_usage(self, device: str, memory_delta: int):
|
|
if device in self.device_stats:
|
|
self.device_stats[device]['memory_used'] += memory_delta
|
|
|
|
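# Illustrative usage sketch (not executed on import):
#
#   device_manager = DeviceManager()
#   device_manager.available_devices      # e.g. ['cpu', 'cuda:0'] or ['cpu', 'mps']
#   device_manager.get_optimal_device()   # prefers the CUDA device with the most free memory
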
class ModelManager:
|
|
def __init__(self, cache_manager: CacheManager, device_manager: DeviceManager):
|
|
self.cache_manager = cache_manager
|
|
self.device_manager = device_manager
|
|
|
|
def quantize_model(self, model):
|
|
"""Quantize model to 4-bit precision using bitsandbytes"""
|
|
def _replace_linear_layers(module):
|
|
for name, child in module.named_children():
|
|
if isinstance(child, torch.nn.Linear):
|
|
setattr(module, name, bnb.nn.Linear4bit(
|
|
child.in_features,
|
|
child.out_features,
|
|
bias=child.bias is not None,
|
|
compute_dtype=torch.float16
|
|
))
|
|
else:
|
|
_replace_linear_layers(child)
|
|
|
|
_replace_linear_layers(model)
|
|
return model
|
|
|
|
def load_base_model(self, model_id: str, quantize: bool = True) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
|
|
def _load_model():
|
|
logger.info(f"Loading base model: {model_id}")
|
|
|
|
device = self.device_manager.get_optimal_device()
|
|
logger.info(f"Using device: {device}")
|
|
|
|
# Load tokenizer
|
|
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
|
|
|
# Base kwargs for model loading
|
|
model_kwargs = {
|
|
"trust_remote_code": True,
|
|
"device_map": "auto" if 'cuda' in device else device
|
|
}
|
|
|
|
# Configure device-specific optimizations
|
|
if 'cuda' in device:
|
|
compute_capability = torch.cuda.get_device_capability(0)
|
|
if compute_capability[0] >= 8:
|
|
model_kwargs["torch_dtype"] = torch.bfloat16
|
|
elif compute_capability[0] >= 7:
|
|
model_kwargs["torch_dtype"] = torch.float16
|
|
|
|
# Check for flash attention availability
|
|
try:
|
|
import flash_attn
|
|
has_flash_attn = True
|
|
logger.info("Flash Attention 2 is available")
|
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
except ImportError:
|
|
has_flash_attn = False
|
|
logger.info("Flash Attention 2 is not installed - falling back to default attention")
|
|
|
|
elif 'mps' in device:
|
|
# MPS supports FP16
|
|
model_kwargs["torch_dtype"] = torch.float16
|
|
# model_kwargs["torch_dtype"] = torch.float32
|
|
logger.info("Using MPS device with float16 precision")
|
|
else:
|
|
# CPU can use FP16 if available
|
|
if hasattr(torch.cpu, 'has_fp16') and torch.cpu.has_fp16:
|
|
model_kwargs["torch_dtype"] = torch.float16
|
|
logger.info("Using CPU device with float16 precision")
|
|
else:
|
|
model_kwargs["torch_dtype"] = torch.float32
|
|
logger.info("Using CPU device with float32 precision - FP16 not supported")
|
|
|
|
# Load model with configured optimizations
|
|
try:
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_id,
|
|
**model_kwargs
|
|
)
|
|
            except Exception as e:
                if "attn_implementation" in model_kwargs:
                    logger.warning(f"Failed to load model with Flash Attention: {e}")
                    logger.info("Retrying without Flash Attention...")
                    model_kwargs.pop("attn_implementation")
                    model = AutoModelForCausalLM.from_pretrained(
                        model_id,
                        **model_kwargs
                    )
                elif model_kwargs.get("torch_dtype") == torch.float16:
                    # If FP16 fails, fall back to FP32
                    logger.warning(f"Failed to load model with FP16: {e}")
                    logger.info("Falling back to FP32...")
                    model_kwargs["torch_dtype"] = torch.float32
                    model = AutoModelForCausalLM.from_pretrained(
                        model_id,
                        **model_kwargs
                    )
                else:
                    # No applicable fallback; surface the original error
                    raise

            logger.info(f"Model loaded successfully with dtype: {model_kwargs.get('torch_dtype', 'model default')}")
|
|
|
|
# Only apply quantization for CUDA devices when not using mixed precision
|
|
if quantize and 'cuda' in device and model_kwargs["torch_dtype"] == torch.float32:
|
|
model = self.quantize_model(model)
|
|
|
|
return model, tokenizer
|
|
|
|
return self.cache_manager.get_or_load_model(model_id, _load_model)
|
|
|
|
class LoRAManager:
|
|
"""LoRA manager with enhanced error handling and caching"""
|
|
|
|
def __init__(self, cache_manager: CacheManager):
|
|
self.cache_manager = cache_manager
|
|
self.loaded_adapters = {}
|
|
self.adapter_names = {} # Maps adapter_id to valid adapter name
|
|
|
|
def _get_adapter_name(self, adapter_id: str) -> str:
|
|
"""Create a valid adapter name from adapter_id."""
|
|
if adapter_id in self.adapter_names:
|
|
return self.adapter_names[adapter_id]
|
|
|
|
name = adapter_id.replace('.', '_').replace('-', '_')
|
|
name = ''.join(c if c.isalnum() or c == '_' else '' for c in name)
|
|
        if not name or name[0].isdigit():
            name = f"adapter_{name}"
|
|
|
|
self.adapter_names[adapter_id] = name
|
|
return name
|
|
|
|
def validate_adapter(self, adapter_id: str) -> bool:
|
|
"""Validate if adapter exists and is compatible"""
|
|
try:
|
|
config = PeftConfig.from_pretrained(
|
|
adapter_id,
|
|
trust_remote_code=True,
|
|
use_auth_token=os.getenv("HF_TOKEN")
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error validating adapter {adapter_id}: {str(e)}")
|
|
return False
|
|
|
|
def load_adapter(self, base_model: PreTrainedModel, adapter_id: str) -> PreTrainedModel:
|
|
"""Load a LoRA adapter with enhanced caching"""
|
|
model_key = base_model.config._name_or_path
|
|
|
|
def _load_adapter():
|
|
logger.info(f"Loading LoRA adapter: {adapter_id}")
|
|
|
|
if not self.validate_adapter(adapter_id):
|
|
error_msg = f"Adapter {adapter_id} not found or is not compatible"
|
|
logger.error(error_msg)
|
|
raise ValueError(error_msg)
|
|
|
|
try:
|
|
adapter_name = self._get_adapter_name(adapter_id)
|
|
|
|
config = PeftConfig.from_pretrained(
|
|
adapter_id,
|
|
trust_remote_code=True,
|
|
use_auth_token=os.getenv("HF_TOKEN")
|
|
)
|
|
|
|
model = base_model
|
|
model.add_adapter(
|
|
config,
|
|
adapter_name = adapter_name,
|
|
)
|
|
|
|
if model not in self.loaded_adapters:
|
|
self.loaded_adapters[model] = []
|
|
if adapter_id not in self.loaded_adapters[model]:
|
|
self.loaded_adapters[model].append(adapter_id)
|
|
|
|
return model
|
|
|
|
except Exception as e:
|
|
error_msg = f"Failed to load adapter {adapter_id}: {str(e)}"
|
|
logger.error(error_msg)
|
|
raise RuntimeError(error_msg) from e
|
|
|
|
return self.cache_manager.get_or_load_adapter(model_key, adapter_id, _load_adapter)
|
|
|
|
def set_active_adapter(self, model: PeftModel, adapter_id: str = None) -> bool:
|
|
"""Set a specific adapter as active with error handling"""
|
|
if not isinstance(model, PeftModel):
|
|
logger.warning("Model is not a PeftModel, cannot set active adapter")
|
|
return False
|
|
|
|
available_adapters = self.loaded_adapters.get(model, [])
|
|
|
|
if not available_adapters:
|
|
logger.warning("No adapters loaded in model")
|
|
return False
|
|
|
|
if adapter_id is None:
|
|
adapter_id = available_adapters[-1]
|
|
|
|
if adapter_id in available_adapters:
|
|
try:
|
|
model.set_adapter(self._get_adapter_name(adapter_id))
|
|
logger.info(f"Successfully set active adapter to: {adapter_id}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error setting adapter {adapter_id}: {str(e)}")
|
|
return False
|
|
else:
|
|
logger.warning(f"Requested adapter {adapter_id} not loaded. Available adapters: {available_adapters}")
|
|
return False
|
|
|
|
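# Illustrative usage sketch (not executed on import): the adapter ID below is a
# placeholder for any PEFT/LoRA adapter on the Hugging Face Hub that matches the
# base model.
#
#   lora_manager = LoRAManager(CacheManager.get_instance())
#   model = lora_manager.load_adapter(base_model, "some-org/some-lora-adapter")
#   lora_manager.set_active_adapter(model, "some-org/some-lora-adapter")
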
class InferencePipeline:
|
|
def __init__(self, model_config: ModelConfig, cache_manager, device_manager, model_manager, lora_manager):
|
|
self.model_config = model_config
|
|
self.cache_manager = cache_manager
|
|
self.device_manager = device_manager
|
|
self.model_manager = model_manager
|
|
self.lora_manager = lora_manager
|
|
self.last_used = time.time()
|
|
|
|
try:
|
|
self.base_model, self.tokenizer = self.model_manager.load_base_model(
|
|
model_config.base_model_id,
|
|
quantize=model_config.quantization_bits == 4
|
|
)
|
|
|
|
self.tokenizer = self.setup_tokenizer(self.tokenizer)
|
|
|
|
if self.base_model.get_input_embeddings().num_embeddings != len(self.tokenizer):
|
|
self.base_model.resize_token_embeddings(len(self.tokenizer))
|
|
|
|
self.current_model = self.base_model
|
|
|
|
if model_config.adapter_ids:
|
|
for adapter_id in model_config.adapter_ids:
|
|
try:
|
|
self.current_model = self.lora_manager.load_adapter(
|
|
self.current_model, adapter_id
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error loading adapter {adapter_id}: {e}")
|
|
|
|
# Set active adapter and verify it's set correctly
|
|
if isinstance(self.current_model, PeftModel):
|
|
success = self.lora_manager.set_active_adapter(self.current_model)
|
|
if not success:
|
|
logger.error("Failed to set active adapter")
|
|
|
|
self.dtype = self.current_model.dtype
|
|
self.optimal_batch_size = self._find_optimal_batch_size()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Pipeline initialization error: {str(e)}")
|
|
logger.error(f"Error traceback: {traceback.format_exc()}")
|
|
raise
|
|
|
|
def setup_tokenizer(self, tokenizer: AutoTokenizer) -> AutoTokenizer:
|
|
"""Use tokenizer with its default configuration for inference"""
|
|
logger.debug(" a. Starting tokenizer setup")
|
|
|
|
# Just use existing special tokens without modification
|
|
logger.debug(f" b. Using tokenizer with vocab size: {len(tokenizer)}")
|
|
logger.debug(f" c. Special tokens: PAD={tokenizer.pad_token_id}, "
|
|
f"EOS={tokenizer.eos_token_id}, BOS={tokenizer.bos_token_id}")
|
|
|
|
return tokenizer
|
|
|
|
def get_optimized_generation_config(self, generation_params: Optional[Dict[str, Any]] = None) -> Dict:
|
|
"""Get optimized generation config"""
|
|
config = {
|
|
"max_new_tokens": generation_params.get("max_new_tokens", 4096),
|
|
"do_sample": generation_params.get("temperature", 1.0) > 0,
|
|
"temperature": generation_params.get("temperature", 1.0),
|
|
"top_p": generation_params.get("top_p", 0.95),
|
|
"num_return_sequences": generation_params.get("num_return_sequences", 1),
|
|
"pad_token_id": self.tokenizer.pad_token_id,
|
|
"eos_token_id": self.tokenizer.eos_token_id,
|
|
"return_dict_in_generate": True,
|
|
"output_scores": generation_params.get("logprobs", False),
|
|
"use_cache": True,
|
|
"return_legacy_cache": True, # To avoid warning
|
|
}
|
|
return config
|
|
|
|
def generate(
|
|
self,
|
|
prompt: str,
|
|
generation_params: Optional[Dict[str, Any]] = None
|
|
    ) -> Tuple[List[str], List[int], List[Optional[Dict[str, Any]]]]:
|
|
"""Generate completions with optional logprobs"""
|
|
start_time = time.time()
|
|
|
|
# First: Set pad token if needed
|
|
if self.tokenizer.pad_token is None:
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
|
|
# Tokenize with batching disabled for single prompts
|
|
tokenize_start = time.time()
|
|
inputs = self.tokenizer(
|
|
prompt,
|
|
padding=True,
|
|
truncation=True,
|
|
return_tensors="pt",
|
|
).to(self.current_model.device)
|
|
logger.info(f"Tokenization time: {time.time() - tokenize_start:.2f}s")
|
|
|
|
        # Extract logprobs parameters (treat a missing params dict as empty)
        generation_params = generation_params or {}
        calculate_logprobs = generation_params.get("logprobs", False)
        top_logprobs = generation_params.get("top_logprobs", 0)
|
|
|
|
if top_logprobs and not calculate_logprobs:
|
|
raise ValueError("logprobs must be true when top_logprobs is specified")
|
|
|
|
if top_logprobs and not (0 <= top_logprobs <= 20):
|
|
raise ValueError("top_logprobs must be between 0 and 20")
|
|
|
|
# Get optimized generation config
|
|
gen_config = self.get_optimized_generation_config(generation_params)
|
|
|
|
# Add optional parameters
|
|
if generation_params:
|
|
if generation_params.get("presence_penalty", 0) != 0:
|
|
gen_config["presence_penalty"] = generation_params["presence_penalty"]
|
|
if generation_params.get("frequency_penalty", 0) != 0:
|
|
gen_config["repetition_penalty"] = 1.0 + generation_params["frequency_penalty"]
|
|
if generation_params.get("stop_sequences"):
|
|
gen_config["stopping_criteria"] = self._create_stopping_criteria(
|
|
generation_params["stop_sequences"],
|
|
inputs['input_ids'].shape[1]
|
|
)
|
|
if generation_params.get("seed") is not None:
|
|
torch.manual_seed(generation_params["seed"])
|
|
if torch.cuda.is_available():
|
|
torch.cuda.manual_seed(generation_params["seed"])
|
|
|
|
# Generate responses
|
|
generate_start = time.time()
|
|
with torch.inference_mode(): # Faster than no_grad
|
|
outputs = self.current_model.generate(
|
|
**inputs,
|
|
**gen_config
|
|
)
|
|
logger.info(f"Generation time: {time.time() - generate_start:.2f}s")
|
|
|
|
generated_sequences = outputs.sequences
|
|
input_length = inputs['input_ids'].shape[1]
|
|
|
|
# Process outputs
|
|
process_start = time.time()
|
|
responses = []
|
|
token_counts = []
|
|
logprobs_results = []
|
|
|
|
# Process each generated sequence
|
|
for sequence in generated_sequences:
|
|
response_tokens = sequence[input_length:]
|
|
response_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True)
|
|
responses.append(response_text)
|
|
token_counts.append(len(response_tokens))
|
|
|
|
# Calculate logprobs if requested
|
|
if calculate_logprobs:
|
|
calculator = LogProbsCalculator(self.tokenizer, self.current_model)
|
|
logprobs_result = calculator.calculate_logprobs(
|
|
input_ids=sequence.unsqueeze(0),
|
|
generated_ids=sequence.unsqueeze(0),
|
|
attention_mask=torch.ones_like(sequence).unsqueeze(0),
|
|
num_alternatives=top_logprobs or 5
|
|
)
|
|
logprobs_results.append({
|
|
"content": [{
|
|
"token": token,
|
|
"logprob": logprob,
|
|
"bytes": bytes_,
|
|
"top_logprobs": top_logprobs
|
|
} for token, logprob, bytes_, top_logprobs in zip(
|
|
logprobs_result.tokens[input_length:],
|
|
logprobs_result.token_logprobs[input_length:],
|
|
logprobs_result.bytes_per_token[input_length:],
|
|
logprobs_result.top_logprobs[input_length:]
|
|
)]
|
|
})
|
|
else:
|
|
logprobs_results.append(None)
|
|
|
|
logger.info(f"Post-processing time: {time.time() - process_start:.2f}s")
|
|
logger.info(f"Total generation time: {time.time() - start_time:.2f}s")
|
|
|
|
return responses, token_counts, logprobs_results
|
|
|
|
def setup_efficient_attention(self):
|
|
"""Replace standard attention with memory-efficient version"""
|
|
if hasattr(self.current_model, 'config') and hasattr(self.current_model.config, 'hidden_size'):
|
|
hidden_size = self.current_model.config.hidden_size
|
|
num_attention_heads = self.current_model.config.num_attention_heads
|
|
self.efficient_attention = MemoryEfficientAttention(hidden_size, num_attention_heads)
|
|
|
|
# Monkey patch attention computation if possible
|
|
if hasattr(self.current_model, 'encoder') and hasattr(self.current_model.encoder, 'layer'):
|
|
for layer in self.current_model.encoder.layer:
|
|
if hasattr(layer, 'attention'):
|
|
layer.attention.self = self.efficient_attention
|
|
logger.info("Memory-efficient attention mechanism enabled")
|
|
|
|
def _find_optimal_batch_size(self, initial_batch_size: int = 1, max_batch_size: int = 128) -> int:
|
|
"""Find optimal batch size through binary search with memory monitoring"""
|
|
if not torch.cuda.is_available():
|
|
return initial_batch_size
|
|
|
|
device = self.current_model.device
|
|
if 'cuda' not in str(device):
|
|
return initial_batch_size
|
|
|
|
left, right = initial_batch_size, max_batch_size
|
|
optimal_size = initial_batch_size
|
|
|
|
sample_text = "Sample input text for batch size optimization."
|
|
|
|
while left <= right:
|
|
mid = (left + right) // 2
|
|
try:
|
|
torch.cuda.empty_cache()
|
|
|
|
inputs = self.tokenizer([sample_text] * mid,
|
|
padding=True,
|
|
truncation=True,
|
|
return_tensors="pt").to(device)
|
|
|
|
with torch.amp.autocast('cuda',dtype=self.dtype):
|
|
with torch.no_grad():
|
|
_ = self.current_model.generate(
|
|
**inputs,
|
|
max_new_tokens=1,
|
|
num_return_sequences=1,
|
|
pad_token_id=self.tokenizer.pad_token_id
|
|
)
|
|
|
|
optimal_size = mid
|
|
left = mid + 1
|
|
|
|
memory_used = torch.cuda.memory_allocated(device)
|
|
total_memory = torch.cuda.get_device_properties(device).total_memory
|
|
|
|
if memory_used > 0.9 * total_memory:
|
|
break
|
|
|
|
except torch.cuda.OutOfMemoryError:
|
|
right = mid - 1
|
|
torch.cuda.empty_cache()
|
|
|
|
return max(1, int(optimal_size * 0.9))
|
|
|
|
def optimize_generation_params(self, prompt: str) -> Dict[str, Any]:
|
|
"""Optimize generation parameters based on prompt characteristics"""
|
|
base_params = {
|
|
"max_new_tokens": self.model_config.max_new_tokens,
|
|
"do_sample": self.model_config.do_sample,
|
|
"top_p": self.model_config.top_p,
|
|
"top_k": self.model_config.top_k,
|
|
"temperature": self.model_config.temperature,
|
|
"num_return_sequences": self.model_config.num_return_sequences,
|
|
"repetition_penalty": self.model_config.repetition_penalty,
|
|
"pad_token_id": self.model_config.pad_token_id or self.tokenizer.pad_token_id
|
|
}
|
|
|
|
        if self.model_config.dynamic_temperature:
            # The pipeline does not construct this helper by default, so create it lazily
            if not hasattr(self, "dynamic_temperature"):
                self.dynamic_temperature = DynamicTemperature()
            base_params["temperature"] = self.dynamic_temperature.get_optimal_temperature(
                prompt, self.tokenizer, base_params["temperature"]
            )
|
|
|
|
return base_params
|
|
|
|
def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
|
|
"""Format the prompt according to model's chat template"""
|
|
if hasattr(self.tokenizer, 'apply_chat_template'):
|
|
# Use the model's built-in chat template if available
|
|
messages = [
|
|
{"role": "system", "content": system_prompt},
|
|
{"role": "user", "content": user_prompt}
|
|
]
|
|
return self.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
else:
|
|
# Fallback to a generic template
|
|
return f"<|system|>{system_prompt}</s><|user|>{user_prompt}</s><|assistant|>"
|
|
|
|
def _create_stopping_criteria(self, stop_sequences: List[str], input_length: int):
|
|
"""Create stopping criteria for generation"""
|
|
from transformers import StoppingCriteria, StoppingCriteriaList
|
|
|
|
class StopSequenceCriteria(StoppingCriteria):
|
|
def __init__(self, tokenizer, stop_sequences, input_length):
|
|
self.tokenizer = tokenizer
|
|
self.stop_ids = [
|
|
self.tokenizer.encode(seq, add_special_tokens=False)
|
|
for seq in stop_sequences
|
|
]
|
|
self.input_length = input_length
|
|
|
|
def __call__(self, input_ids, scores, **kwargs):
|
|
for stop_ids in self.stop_ids:
|
|
if input_ids[0, -len(stop_ids):].tolist() == stop_ids:
|
|
return True
|
|
return False
|
|
|
|
return StoppingCriteriaList([
|
|
StopSequenceCriteria(
|
|
self.tokenizer,
|
|
stop_sequences,
|
|
input_length=input_length
|
|
)
|
|
])
|
|
|
|
def process_batch(
|
|
self,
|
|
system_prompts: List[str],
|
|
user_prompts: List[str],
|
|
generation_params: Optional[Dict[str, Any]] = None,
|
|
active_adapter: str = None,
|
|
return_token_count: bool = True
|
|
) -> Tuple[List[str], List[int]]:
|
|
"""Process a batch of prompts with all optimizations"""
|
|
|
|
# Set the requested adapter if specified
|
|
if isinstance(self.current_model, PeftModel) and active_adapter is not None:
|
|
self.lora_manager.set_active_adapter(self.current_model, active_adapter)
|
|
|
|
all_responses = []
|
|
token_counts = []
|
|
|
|
# Format all prompts using chat template
|
|
formatted_prompts = [
|
|
self.format_chat_prompt(system_prompt, user_prompt)
|
|
for system_prompt, user_prompt in zip(system_prompts, user_prompts)
|
|
]
|
|
|
|
# Get number of completions requested
|
|
n = generation_params.get("num_return_sequences", 1) if generation_params else 1
|
|
|
|
for i in range(0, len(formatted_prompts), self.optimal_batch_size):
|
|
batch_prompts = formatted_prompts[i:i + self.optimal_batch_size]
|
|
batch_system = system_prompts[i:i + self.optimal_batch_size]
|
|
batch_user = user_prompts[i:i + self.optimal_batch_size]
|
|
|
|
            # Check cache first if enabled; initialize defaults so the merge logic
            # below also works when prompt caching is disabled
            cached_responses = []
            uncached_indices = list(range(len(batch_prompts)))
            if self.model_config.enable_prompt_caching:
                uncached_indices = []
|
|
|
|
for idx, prompt in enumerate(batch_prompts):
|
|
temp = generation_params.get("temperature", self.model_config.temperature) if generation_params else self.model_config.temperature
|
|
top_p = generation_params.get("top_p", self.model_config.top_p) if generation_params else self.model_config.top_p
|
|
|
|
cached_response = self.cache_manager.prompt_cache.get_cached_response(
|
|
prompt,
|
|
temp,
|
|
top_p
|
|
)
|
|
if cached_response is not None:
|
|
# For cached responses, replicate n times if multiple completions requested
|
|
cached_responses.extend([cached_response] * n)
|
|
else:
|
|
uncached_indices.append(idx)
|
|
|
|
if uncached_indices:
|
|
batch_prompts = [batch_prompts[i] for i in uncached_indices]
|
|
else:
|
|
batch_prompts = []
|
|
|
|
            batch_responses = []
            batch_token_counts = []

            if batch_prompts:  # If there are any uncached prompts
|
|
# Configure generation parameters
|
|
base_params = {
|
|
"max_new_tokens": generation_params.get("max_new_tokens", 4096) if generation_params else self.model_config.max_new_tokens,
|
|
"do_sample": generation_params.get("temperature", 1.0) > 0 if generation_params else self.model_config.do_sample,
|
|
"temperature": generation_params.get("temperature", 1.0) if generation_params else self.model_config.temperature,
|
|
"top_p": generation_params.get("top_p", 1.0) if generation_params else self.model_config.top_p,
|
|
"num_return_sequences": n,
|
|
"pad_token_id": self.tokenizer.pad_token_id,
|
|
"eos_token_id": self.tokenizer.eos_token_id,
|
|
}
|
|
|
|
# Add optional parameters if specified
|
|
if generation_params:
|
|
if generation_params.get("presence_penalty", 0) != 0:
|
|
base_params["presence_penalty"] = generation_params["presence_penalty"]
|
|
if generation_params.get("frequency_penalty", 0) != 0:
|
|
base_params["repetition_penalty"] = 1.0 + generation_params["frequency_penalty"]
|
|
if generation_params.get("logit_bias"):
|
|
base_params["logit_bias"] = generation_params["logit_bias"]
|
|
if generation_params.get("seed") is not None:
|
|
torch.manual_seed(generation_params["seed"])
|
|
if torch.cuda.is_available():
|
|
torch.cuda.manual_seed(generation_params["seed"])
|
|
|
|
# Tokenize inputs
|
|
inputs = self.tokenizer(
|
|
batch_prompts,
|
|
padding=True,
|
|
truncation=True,
|
|
return_tensors="pt"
|
|
).to(self.current_model.device)
|
|
|
|
# Get the length of each input
|
|
input_lengths = inputs['input_ids'].shape[1]
|
|
|
|
# Add stopping criteria if specified
|
|
if generation_params and generation_params.get("stop_sequences"):
|
|
base_params["stopping_criteria"] = self._create_stopping_criteria(
|
|
generation_params["stop_sequences"],
|
|
input_lengths
|
|
)
|
|
|
|
# Generate responses
|
|
with torch.amp.autocast('cuda', dtype=self.dtype):
|
|
with torch.no_grad():
|
|
outputs = self.current_model.generate(
|
|
**inputs,
|
|
**base_params
|
|
)
|
|
|
|
# Decode outputs and remove input portion
|
|
batch_responses = []
|
|
batch_token_counts = []
|
|
|
|
# Handle multiple sequences per input
|
|
num_return_sequences = base_params["num_return_sequences"]
|
|
                for seq_start in range(0, len(outputs), num_return_sequences):
                    sequences = outputs[seq_start:seq_start + num_return_sequences]
|
|
for seq in sequences:
|
|
response_tokens = seq[input_lengths:]
|
|
response_text = self.tokenizer.decode(response_tokens, skip_special_tokens=True)
|
|
batch_responses.append(response_text)
|
|
batch_token_counts.append(len(response_tokens))
|
|
|
|
# Cache new responses if enabled
|
|
if self.model_config.enable_prompt_caching:
|
|
for prompt, response in zip(batch_prompts, batch_responses[::n]): # Only cache first response of each input
|
|
self.cache_manager.prompt_cache.add_to_cache(
|
|
prompt,
|
|
response,
|
|
base_params["temperature"],
|
|
base_params["top_p"]
|
|
)
|
|
|
|
# Merge cached and new responses in correct order
|
|
all_responses.extend(cached_responses)
|
|
if uncached_indices:
|
|
response_idx = 0
|
|
for original_idx in range(len(formatted_prompts[i:i + self.optimal_batch_size])):
|
|
if original_idx in uncached_indices:
|
|
# Add n responses for this uncached prompt
|
|
for _ in range(n):
|
|
while len(all_responses) < original_idx * n + _:
|
|
all_responses.append("")
|
|
if response_idx < len(batch_responses):
|
|
all_responses.append(batch_responses[response_idx])
|
|
response_idx += 1
|
|
|
|
if return_token_count:
|
|
# Count tokens for responses
|
|
token_counts.extend([0] * len(cached_responses)) # 0 for cached responses
|
|
token_counts.extend(batch_token_counts)
|
|
|
|
if return_token_count:
|
|
return all_responses, token_counts
|
|
return all_responses, [0] * len(all_responses)
|
|
|
|
class ChatCompletionMessage:
|
|
def __init__(self, content: str, role: str = "assistant", logprobs: Optional[Dict] = None):
|
|
self.content = content
|
|
self.role = role
|
|
self.logprobs = logprobs
|
|
|
|
class ChatCompletionChoice:
|
|
def __init__(
|
|
self,
|
|
index: int,
|
|
message: Dict[str, Any],
|
|
finish_reason: str = "stop",
|
|
logprobs: Optional[Dict] = None
|
|
):
|
|
self.index = index
|
|
self.message = ChatCompletionMessage(**message)
|
|
self.finish_reason = finish_reason
|
|
if logprobs:
|
|
self.message.logprobs = logprobs
|
|
|
|
class ChatCompletionUsage:
|
|
def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
|
|
self.prompt_tokens = prompt_tokens
|
|
self.completion_tokens = completion_tokens
|
|
self.total_tokens = total_tokens
|
|
|
|
class ChatCompletion:
|
|
def __init__(self, response_dict: Dict):
|
|
self.id = response_dict["id"]
|
|
self.object = response_dict["object"]
|
|
self.created = response_dict["created"]
|
|
self.model = response_dict["model"]
|
|
self.choices = [
|
|
ChatCompletionChoice(
|
|
index=choice["index"],
|
|
message=choice["message"],
|
|
finish_reason=choice["finish_reason"]
|
|
)
|
|
for choice in response_dict["choices"]
|
|
]
|
|
self.usage = ChatCompletionUsage(**response_dict["usage"])
|
|
|
|
def model_dump(self) -> Dict:
|
|
return {
|
|
"id": self.id,
|
|
"object": self.object,
|
|
"created": self.created,
|
|
"model": self.model,
|
|
"choices": [
|
|
{
|
|
"index": choice.index,
|
|
"message": {
|
|
"role": choice.message.role,
|
|
"content": choice.message.content,
|
|
"logprobs": choice.message.logprobs
|
|
} if choice.message.logprobs else {
|
|
"role": choice.message.role,
|
|
"content": choice.message.content
|
|
},
|
|
"finish_reason": choice.finish_reason
|
|
}
|
|
for choice in self.choices
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": self.usage.prompt_tokens,
|
|
"completion_tokens": self.usage.completion_tokens,
|
|
"total_tokens": self.usage.total_tokens
|
|
}
|
|
}
|
|
|
|
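# Illustrative usage sketch (not executed on import): wraps a response dict of the
# shape built in Completions.create() below.
#
#   completion = ChatCompletion(response_dict)
#   completion.choices[0].message.content
#   completion.usage.total_tokens
#   completion.model_dump()   # plain-dict, OpenAI-style form
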
class InferenceClient:
|
|
"""OpenAI SDK Compatible client for local inference with dynamic model support"""
|
|
|
|
def __init__(self):
|
|
self.cache_manager = CacheManager.get_instance(max_size=4)
|
|
self.device_manager = DeviceManager()
|
|
self.model_manager = ModelManager(self.cache_manager, self.device_manager)
|
|
self.lora_manager = LoRAManager(self.cache_manager)
|
|
self.chat = self.Chat(self)
|
|
self.models = self.Models()
|
|
|
|
def get_pipeline(self, model: str) -> 'InferencePipeline':
|
|
model_config = parse_model_string(model)
|
|
return InferencePipeline(
|
|
model_config,
|
|
self.cache_manager,
|
|
self.device_manager,
|
|
self.model_manager,
|
|
self.lora_manager
|
|
)
|
|
|
|
class Chat:
|
|
"""OpenAI-compatible chat interface"""
|
|
def __init__(self, client: 'InferenceClient'):
|
|
self.client = client
|
|
self.completions = self.Completions(client)
|
|
|
|
class Completions:
|
|
def __init__(self, client: 'InferenceClient'):
|
|
self.client = client
|
|
|
|
def create(
|
|
self,
|
|
messages: List[Dict[str, str]],
|
|
model: str,
|
|
temperature: float = 1.0,
|
|
top_p: float = 1.0,
|
|
n: int = 1,
|
|
stream: bool = False,
|
|
stop: Optional[Union[str, List[str]]] = None,
|
|
max_tokens: Optional[int] = None,
|
|
presence_penalty: float = 0,
|
|
frequency_penalty: float = 0,
|
|
logit_bias: Optional[Dict[str, float]] = None,
|
|
seed: Optional[int] = None,
|
|
logprobs: Optional[bool] = None,
|
|
top_logprobs: Optional[int] = None,
|
|
                active_adapter: Optional[str] = None,
|
|
decoding: Optional[str] = None,
|
|
# CoT specific params
|
|
k: int = 10,
|
|
num_beams: int = 1,
|
|
length_penalty: float = 1.0,
|
|
no_repeat_ngram_size: int = 0,
|
|
early_stopping: bool = False,
|
|
aggregate_paths: bool = True,
|
|
# Entropy specific params
|
|
top_k: int = 27,
|
|
min_p: float = 0.03,
|
|
# Thinking specific params
|
|
reasoning_effort: str = "low",
|
|
                thought_switch_tokens: Optional[List[str]] = None,
|
|
min_thinking_tokens: Optional[int] = None,
|
|
max_thinking_tokens: Optional[int] = None,
|
|
max_thoughts: Optional[int] = None,
|
|
prefill: str = "",
|
|
start_think_token: str ="<think>",
|
|
end_think_token: str = "</think>",
|
|
**kwargs
|
|
) -> ChatCompletion:
|
|
"""Create a chat completion with OpenAI-compatible parameters"""
|
|
logger.info("Starting chat completion creation")
|
|
if stream:
|
|
raise NotImplementedError("Streaming is not yet supported")
|
|
|
|
logger.info(f"Getting pipeline for model: {model}")
|
|
pipeline = self.client.get_pipeline(model)
|
|
logger.info("Pipeline acquired")
|
|
|
|
# Set active adapter if specified
|
|
if active_adapter is not None:
|
|
logger.info(f"Setting active adapter to: {active_adapter}")
|
|
pipeline.lora_manager.set_active_adapter(pipeline.current_model, active_adapter)
|
|
|
|
responses = []
|
|
logprobs_results = []
|
|
prompt_tokens = 0
|
|
completion_tokens = 0
|
|
|
|
try:
|
|
# Handle specialized decoding approaches
|
|
if decoding:
|
|
logger.info(f"Using specialized decoding approach: {decoding}")
|
|
|
|
# Ensure model is in eval mode and on correct device
|
|
pipeline.current_model.eval()
|
|
device = pipeline.current_model.device
|
|
|
|
if decoding == "cot_decoding":
|
|
# Use directly available parameters for CoT
|
|
cot_params = {
|
|
"k": k,
|
|
"num_beams": num_beams,
|
|
"max_new_tokens": max_tokens if max_tokens is not None else 512,
|
|
"temperature": temperature,
|
|
"top_p": top_p,
|
|
"repetition_penalty": 1.0,
|
|
"length_penalty": length_penalty,
|
|
"no_repeat_ngram_size": no_repeat_ngram_size,
|
|
"early_stopping": early_stopping,
|
|
"aggregate_paths": aggregate_paths,
|
|
}
|
|
|
|
result, confidence = cot_decode(
|
|
pipeline.current_model,
|
|
pipeline.tokenizer,
|
|
messages,
|
|
**cot_params
|
|
)
|
|
responses = [result]
|
|
logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
|
|
completion_tokens = len(pipeline.tokenizer.encode(result))
|
|
|
|
elif decoding == "entropy_decoding":
|
|
|
|
# Ensure model is using full precision
|
|
original_dtype = pipeline.current_model.dtype
|
|
pipeline.current_model = pipeline.current_model.to(torch.float32)
|
|
|
|
try:
|
|
# Configure generator for entropy decoding
|
|
generator = None
|
|
if seed is not None:
|
|
generator = torch.Generator(device=device)
|
|
generator.manual_seed(seed)
|
|
else:
|
|
generator = torch.Generator(device=device)
|
|
generator.manual_seed(1337) # Default seed as in original implementation
|
|
|
|
# Use directly available parameters for entropy decoding
|
|
entropy_params = {
|
|
"max_new_tokens": max_tokens if max_tokens is not None else 4096,
|
|
"temperature": temperature,
|
|
"top_p": top_p,
|
|
"top_k": top_k,
|
|
"min_p": min_p,
|
|
"generator": generator
|
|
}
|
|
|
|
# Disable autocast and run in full precision
|
|
with torch.amp.autocast('cuda', enabled=False), torch.inference_mode():
|
|
result = entropy_decode(
|
|
pipeline.current_model,
|
|
pipeline.tokenizer,
|
|
messages,
|
|
**entropy_params
|
|
)
|
|
responses = [result]
|
|
logprobs_results = [None]
|
|
completion_tokens = len(pipeline.tokenizer.encode(result))
|
|
|
|
finally:
|
|
# Restore original dtype
|
|
pipeline.current_model = pipeline.current_model.to(original_dtype)
|
|
|
|
elif decoding == "thinkdeeper":
|
|
# Get base config for reasoning effort
|
|
thinkdeeper_config = get_effort_profile(reasoning_effort, max_tokens)
|
|
|
|
# Override with any custom parameters
|
|
custom_config = {
|
|
"min_thinking_tokens": min_thinking_tokens if min_thinking_tokens is not None else thinkdeeper_config["min_thinking_tokens"],
|
|
"max_thinking_tokens": max_thinking_tokens if max_thinking_tokens is not None else thinkdeeper_config["max_thinking_tokens"],
|
|
"max_thoughts": max_thoughts if max_thoughts is not None else thinkdeeper_config["max_thoughts"],
|
|
"thought_switch_tokens": thought_switch_tokens if thought_switch_tokens else thinkdeeper_config["thought_switch_tokens"],
|
|
"prefill": prefill if prefill else thinkdeeper_config["prefill"],
|
|
"start_think_token": start_think_token,
|
|
"end_think_token": end_think_token,
|
|
}
|
|
thinkdeeper_config.update(custom_config)
|
|
|
|
result = thinkdeeper_decode(
|
|
pipeline.current_model,
|
|
pipeline.tokenizer,
|
|
messages,
|
|
thinkdeeper_config
|
|
)
|
|
responses = [result]
|
|
logprobs_results = [None]
|
|
completion_tokens = len(pipeline.tokenizer.encode(result))
|
|
elif decoding == "autothink":
|
|
# Get steering dataset configuration
|
|
steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
|
|
target_layer = kwargs.get("target_layer", 19)
|
|
|
|
# Prepare AutoThink configuration
|
|
autothink_config = {
|
|
"steering_dataset": steering_dataset,
|
|
"target_layer": target_layer,
|
|
"pattern_strengths": kwargs.get("pattern_strengths", {
|
|
"depth_and_thoroughness": 2.5,
|
|
"numerical_accuracy": 2.0,
|
|
"self_correction": 3.0,
|
|
"exploration": 2.0,
|
|
"organization": 1.5
|
|
})
|
|
}
|
|
|
|
# Process with AutoThink
|
|
result = autothink_decode(
|
|
pipeline.current_model,
|
|
pipeline.tokenizer,
|
|
messages,
|
|
autothink_config
|
|
)
|
|
responses = [result]
|
|
logprobs_results = [None]
|
|
completion_tokens = len(pipeline.tokenizer.encode(result))
|
|
else:
|
|
raise ValueError(f"Unknown specialized decoding approach: {decoding}")
|
|
|
|
# Calculate prompt tokens for specialized approaches
|
|
prompt_text = pipeline.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
prompt_tokens = len(pipeline.tokenizer.encode(prompt_text))
|
|
|
|
else:
|
|
# Standard generation
|
|
prompt = pipeline.tokenizer.apply_chat_template(
|
|
messages,
|
|
tokenize=False,
|
|
add_generation_prompt=True
|
|
)
|
|
|
|
# Set generation parameters
|
|
generation_params = {
|
|
"temperature": temperature,
|
|
"top_p": top_p,
|
|
"num_return_sequences": n,
|
|
"max_new_tokens": max_tokens if max_tokens is not None else 4096,
|
|
"presence_penalty": presence_penalty,
|
|
"frequency_penalty": frequency_penalty,
|
|
"stop_sequences": [stop] if isinstance(stop, str) else stop,
|
|
"seed": seed,
|
|
"logit_bias": logit_bias,
|
|
"logprobs": logprobs,
|
|
"top_logprobs": top_logprobs
|
|
}
|
|
|
|
# Generate responses
|
|
responses, token_counts, logprobs_results = pipeline.generate(
|
|
prompt,
|
|
generation_params=generation_params
|
|
)
|
|
|
|
prompt_tokens = len(pipeline.tokenizer.encode(prompt))
|
|
completion_tokens = sum(token_counts)
|
|
|
|
# Create OpenAI-compatible response format
|
|
response_dict = {
|
|
"id": f"chatcmpl-{int(time.time()*1000)}",
|
|
"object": "chat.completion",
|
|
"created": int(time.time()),
|
|
"model": model,
|
|
"choices": [
|
|
{
|
|
"index": idx,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": response,
|
|
**({"logprobs": logprob_result} if logprob_result else {})
|
|
},
|
|
"finish_reason": "stop"
|
|
}
|
|
for idx, (response, logprob_result) in enumerate(zip(responses, logprobs_results))
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": prompt_tokens,
|
|
"completion_tokens": completion_tokens,
|
|
"total_tokens": completion_tokens + prompt_tokens
|
|
}
|
|
}
|
|
|
|
logger.debug(f"Response : {response_dict}")
|
|
return ChatCompletion(response_dict)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in chat completion: {str(e)}")
|
|
raise
|
|
|
|
class Models:
|
|
"""OpenAI-compatible models interface"""
|
|
def list(self):
|
|
"""Return list of supported models"""
|
|
try:
|
|
import requests
|
|
response = requests.get(
|
|
"https://huggingface.co/api/models?sort=downloads&direction=-1&filter=text-generation&limit=20"
|
|
)
|
|
models = response.json()
|
|
model_list = []
|
|
|
|
for model in models:
|
|
if 'pipeline_tag' in model and model['pipeline_tag'] == 'text-generation':
|
|
model_list.append({
|
|
"id": model['id'],
|
|
"object": "model",
|
|
"created": int(time.time()),
|
|
"owned_by": "huggingface",
|
|
})
|
|
|
|
return {"data": model_list, "object": "list"}
|
|
except Exception as e:
|
|
logger.warning(f"Failed to fetch models: {e}")
|
|
return {
|
|
"data": [{
|
|
"id": "HuggingFaceTB/SmolLM-135M-Instruct",
|
|
"object": "model",
|
|
"created": int(time.time()),
|
|
"owned_by": "huggingface",
|
|
}],
|
|
"object": "list"
|
|
}
|
|
|
|
def create_inference_client() -> InferenceClient:
|
|
"""Factory function to create an inference client"""
|
|
return InferenceClient()
|
|
|
|
def parse_model_string(model: str) -> ModelConfig:
|
|
"""Parse the model string to extract base model and adapter IDs"""
|
|
parts = model.split('+')
|
|
base_model_id = parts[0]
|
|
adapter_ids = parts[1:] if len(parts) > 1 else None
|
|
|
|
return ModelConfig(
|
|
base_model_id=base_model_id,
|
|
adapter_ids=adapter_ids,
|
|
use_memory_efficient_attention=False,
|
|
quantization_bits=0,
|
|
enable_prompt_caching=False,
|
|
dynamic_temperature=False,
|
|
)
|
|
|
|
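# Illustrative example (IDs below are placeholders): '+'-separated adapter IDs
# after the base model ID are parsed into ModelConfig.adapter_ids.
#
#   cfg = parse_model_string("some-org/base-model+some-org/my-lora")
#   cfg.base_model_id   # "some-org/base-model"
#   cfg.adapter_ids     # ["some-org/my-lora"]
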
def get_effort_profile(reasoning_effort: str, max_tokens: int = 4096) -> dict:
|
|
"""Get reasoning effort profile based on specified level and max tokens.
|
|
|
|
Args:
|
|
reasoning_effort: 'low', 'medium', or 'high'
|
|
max_tokens: Maximum tokens allowed for generation, defaults to 4096
|
|
|
|
Returns:
|
|
dict: Configuration for the specified reasoning effort level
|
|
"""
|
|
# Base profiles with percentages and thought counts
|
|
profiles = {
|
|
"low": {
|
|
"min_tokens_pct": 0.10,
|
|
"max_tokens_pct": 0.33, # 33% of max_tokens
|
|
"max_thoughts": 64,
|
|
"thought_switch_tokens": [
|
|
"Wait,",
|
|
"Alternatively,",
|
|
"However,",
|
|
"Additionally,",
|
|
],
|
|
"prefill": ""
|
|
},
|
|
"medium": {
|
|
"min_tokens_pct": 0.10,
|
|
"max_tokens_pct": 0.66, # 66% of max_tokens
|
|
"max_thoughts": 256,
|
|
"thought_switch_tokens": [
|
|
"Wait,",
|
|
"Alternatively,",
|
|
"However,",
|
|
"Additionally,",
|
|
],
|
|
"prefill": ""
|
|
},
|
|
"high": {
|
|
"min_tokens_pct": 0.10,
|
|
"max_tokens_pct": 0.90, # 90% of max_tokens
|
|
"max_thoughts": 512,
|
|
"thought_switch_tokens": [
|
|
"Wait,",
|
|
"Alternatively,",
|
|
"However,",
|
|
"Additionally,",
|
|
],
|
|
"prefill": ""
|
|
}
|
|
}
|
|
|
|
    # Get base profile, falling back to the low-effort profile for unknown values
    profile = profiles.get(reasoning_effort.lower(), profiles["low"])
|
|
|
|
# Calculate actual token limits based on max_tokens
|
|
min_thinking_tokens = int(max_tokens * profile["min_tokens_pct"])
|
|
max_thinking_tokens = int(max_tokens * profile["max_tokens_pct"])
|
|
|
|
# Create final config
|
|
config = {
|
|
"min_thinking_tokens": min_thinking_tokens,
|
|
"max_thinking_tokens": max_thinking_tokens,
|
|
"max_thoughts": profile["max_thoughts"],
|
|
"thought_switch_tokens": profile["thought_switch_tokens"],
|
|
"prefill": profile["prefill"]
|
|
}
|
|
|
|
    return config
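
# Illustrative end-to-end usage sketch (not executed on import): the model ID is a
# placeholder; any Hugging Face text-generation checkpoint works the same way.
#
#   client = create_inference_client()
#   completion = client.chat.completions.create(
#       model="HuggingFaceTB/SmolLM-135M-Instruct",
#       messages=[
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello!"},
#       ],
#       max_tokens=128,
#   )
#   print(completion.choices[0].message.content)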