# LongCodeZip/longcodezip/__init__.py
import torch
import numpy as np
from typing import List, Union, Tuple, Dict, Optional
import re
import math
import zlib
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from tqdm import tqdm
import copy
import bisect
import json
from loguru import logger
import sys
# Show only INFO-level messages and above; suppress debug output
logger.remove()
logger.add(sys.stderr, level="INFO")
class EntropyChunking:
def __init__(self, model_name="Qwen/Qwen2.5-Coder-0.5B-Instruct"):
"""Entropy-based text chunking implementation"""
logger.debug(f"Loading Entropy chunking model: {model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
logger.debug(f"Entropy chunking model loaded on device: {self.device}")
def split_into_sentences(self, text: str) -> List[str]:
"""Split text into sentences, inserting empty lines for double newlines"""
# First replace double newlines with a special marker
text_with_markers = text.replace('\n\n', '\n__EMPTY_LINE__\n')
# Split by single newlines
lines = text_with_markers.split('\n')
        # Process lines: replace the markers with a single-space placeholder, keep original lines unchanged
sentences = []
for line in lines:
if line == '__EMPTY_LINE__':
                sentences.append(' ')  # Whitespace-only placeholder for a blank line created by a double newline
else:
sentences.append(line) # Keep original line with indentation
return sentences
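    # Illustrative sketch (not part of the original code): what split_into_sentences produces.
    # For the input "a = 1\n\nb = 2" the marker pass gives "a = 1\n__EMPTY_LINE__\nb = 2",
    # and the final result is ['a = 1', ' ', 'b = 2'] -- the single-space entry preserves the
    # blank line so the original layout survives when lines are re-joined with "\n".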
def calculate_sentence_ppl(self, sentences: List[str]) -> List[float]:
"""Calculate perplexity for each sentence based on preceding context"""
ppls = []
for i, sentence in enumerate(sentences):
if i == 0:
context = ""
target = sentence
else:
context = "\n".join(sentences[:i])
target = sentence
ppl = self._compute_ppl(context, target)
ppls.append(ppl)
return ppls
def _compute_ppl(self, context: str, target: str) -> float:
"""Compute perplexity of target text given context"""
# Handle empty target lines
if not target:
return 0.0 # Assign zero perplexity to empty lines
if context:
full_text = context + "\n" + target
context_tokens = self.tokenizer(context + "\n", return_tensors="pt", add_special_tokens=True)
context_length = context_tokens.input_ids.shape[1]
else:
full_text = target
context_length = 0
inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
if context_length > 0:
target_logits = logits[0, context_length-1:-1]
target_labels = inputs.input_ids[0, context_length:]
else:
target_logits = logits[0, :-1]
target_labels = inputs.input_ids[0, 1:]
if len(target_labels) > 0:
log_probs = torch.log_softmax(target_logits, dim=-1)
token_log_probs = log_probs[torch.arange(len(target_labels)), target_labels]
avg_log_prob = token_log_probs.mean().item()
ppl = math.exp(-avg_log_prob)
else:
ppl = float('inf')
# take log2 of ppl
ppl = math.log2(ppl)
return ppl
def calculate_adaptive_thresholds(self, ppls: List[float], k: float = 0.2) -> dict:
"""Calculate adaptive thresholds using different statistical methods"""
# Filter out infinite and NaN values
valid_ppls = [p for p in ppls if not math.isinf(p) and not math.isnan(p) and p > 0]
if len(valid_ppls) < 3:
# Fallback to fixed threshold if not enough valid data
return {
'std': 0.5,
'robust_std': 0.5,
'iqr': 0.5,
'mad': 0.5
}
valid_ppls = np.array(valid_ppls)
# Method 1: Standard deviation based
mean_ppl = np.mean(valid_ppls)
std_ppl = np.std(valid_ppls)
threshold_std = mean_ppl + k * std_ppl
# Method 2: Robust standard deviation (using median and MAD)
median_ppl = np.median(valid_ppls)
mad = np.median(np.abs(valid_ppls - median_ppl))
robust_std = mad * 1.4826 # Convert MAD to robust std estimate
threshold_robust_std = median_ppl + k * robust_std
# Method 3: IQR based (Interquartile Range)
q25 = np.percentile(valid_ppls, 25)
q75 = np.percentile(valid_ppls, 75)
iqr = q75 - q25
threshold_iqr = q75 + k * iqr
# Method 4: MAD based (Median Absolute Deviation)
threshold_mad = median_ppl + k * mad
return {
'std': threshold_std,
'robust_std': threshold_robust_std,
'iqr': threshold_iqr,
'mad': threshold_mad
}
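    # Worked example (illustrative, values assumed): for log2-PPL values [1.0, 1.2, 1.1, 3.0, 1.3]
    # and k = 0.2 the four adaptive thresholds come out roughly as
    #   std        : mean 1.52 + 0.2 * std 0.75          ~= 1.67
    #   robust_std : median 1.2 + 0.2 * (MAD 0.1 * 1.4826) ~= 1.23
    #   iqr        : q75 1.3 + 0.2 * IQR 0.2              =  1.34
    #   mad        : median 1.2 + 0.2 * MAD 0.1           =  1.22
    # The robust variants are far less sensitive to the 3.0 outlier than the mean/std version.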
def find_ppl_spikes_adaptive(self, values: List[float], method: str = 'std', k: float = 0.2) -> tuple:
"""Find PPL spikes using adaptive threshold based on statistical method"""
thresholds = self.calculate_adaptive_thresholds(values, k)
threshold = thresholds[method]
spike_indices = []
for i in range(1, len(values) - 1):
current = values[i]
left = values[i - 1]
right = values[i + 1]
# Skip infinite or NaN values
if math.isinf(current) or math.isnan(current):
continue
if math.isinf(left) or math.isnan(left):
left = current
if math.isinf(right) or math.isnan(right):
right = current
            # Compare the current line's PPL against both neighbors
left_diff = current - left
right_diff = current - right
            # Condition: current PPL is no lower than either neighbor and exceeds at least one of them by the adaptive threshold
if (left_diff >= threshold or right_diff >= threshold) and (left_diff >= 0 and right_diff >= 0):
spike_indices.append(i)
return spike_indices, threshold
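    # Illustrative sketch: with values [1.0, 1.2, 2.5, 1.1, 1.0] and an adaptive threshold of
    # roughly 0.5, index 2 is a spike (2.5 - 1.2 = 1.3 and 2.5 - 1.1 = 1.4, both >= 0 and at
    # least one >= threshold), so chunk_text_adaptive starts a new chunk right after that line.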
def chunk_text_adaptive(self, text: str, method: str = 'std', k: float = 0.2) -> tuple:
"""Perform PPL-based text chunking using adaptive spike detection"""
sentences = self.split_into_sentences(text)
ppls = self.calculate_sentence_ppl(sentences)
spike_indices, threshold = self.find_ppl_spikes_adaptive(ppls, method, k)
chunks = []
# Split at spike points (after the spike line)
split_points = [0] + [idx + 1 for idx in spike_indices] + [len(sentences)]
for i in range(len(split_points) - 1):
start = split_points[i]
end = split_points[i + 1]
chunk_sentences = sentences[start:end]
chunk_text = "\n".join(chunk_sentences)
chunks.append(chunk_text)
return chunks, sentences, ppls, spike_indices
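# A minimal usage sketch of EntropyChunking (illustrative only: it downloads the
# Qwen2.5-Coder-0.5B chunking model on first use, so the helper is defined but never called).
def _example_entropy_chunking():
    chunker = EntropyChunking()
    code = "def add(a, b):\n    return a + b\n\ndef sub(a, b):\n    return a - b\n"
    chunks, sentences, ppls, spikes = chunker.chunk_text_adaptive(code, method="std", k=0.2)
    for idx, chunk in enumerate(chunks):
        print(f"--- chunk {idx} ---")
        print(chunk)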
class LongCodeZip:
def __init__(
self,
model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4",
device_map: str = "cuda",
model_config: dict = {},
):
"""
Initialize the LongCodeZip with a language model for compression.
Args:
model_name: The name of the model to load from HuggingFace
device_map: Device to load the model on
model_config: Additional configuration for the model
"""
self.model_name = model_name
self.device = device_map
self.model_config = model_config
self.load_model(model_name, device_map, model_config)
# Initialize Entropy chunking with smaller model
logger.debug("Initializing Entropy chunking...")
self.entropy_chunking = EntropyChunking()
# Add caching system for model outputs and token information
self.cache = {
"token_length": {}, # Cache for token length by text
"encodings": {}, # Cache for tokenizer encodings
"perplexity": {}, # Cache for perplexity calculations
"conditional_ppl": {}, # Cache for conditional perplexity
"context_rankings": {}, # Cache for context rankings
}
self.max_cache_size = 1000 # Limit cache size to prevent memory issues
# set up the max position embeddings and cache bos num
self.max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 4096)
self.cache_bos_num = 10
self.prefix_bos_num = 100
self.context_idxs = []
def load_model(
self, model_name: str, device_map: str = "cuda", model_config: dict = {}
):
"""
Load the language model and tokenizer.
Args:
model_name: The name of the model to load
device_map: Device to load the model on
model_config: Additional configuration for the model
"""
logger.debug(f"Loading model {model_name} on {device_map}")
        torch_dtype = torch.bfloat16 if "torch_dtype" not in model_config else model_config["torch_dtype"]
        model_kwargs = {"device_map": device_map, "torch_dtype": torch_dtype, "trust_remote_code": True}
for k, v in model_config.items():
model_kwargs[k] = v
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"
logger.debug("Model and tokenizer loaded successfully")
def _manage_cache_size(self, cache_type):
"""
Manage cache size by removing oldest entries when cache exceeds max size.
Args:
cache_type: The type of cache to manage
"""
if len(self.cache[cache_type]) > self.max_cache_size:
# Remove 20% of the oldest entries
remove_count = int(self.max_cache_size * 0.2)
keys_to_remove = list(self.cache[cache_type].keys())[:remove_count]
for key in keys_to_remove:
del self.cache[cache_type][key]
def get_token_length(
self,
text: str,
add_special_tokens: bool = True,
):
"""
Get the number of tokens in the given text.
Args:
text: The text to tokenize
add_special_tokens: Whether to count special tokens
Returns:
The number of tokens
"""
# Create a cache key based on text and parameters
cache_key = f"{text}_{add_special_tokens}"
# Check if result is in cache
if cache_key in self.cache["token_length"]:
return self.cache["token_length"][cache_key]
# Calculate token length if not in cache
token_length = len(self.tokenizer.encode(text, add_special_tokens=add_special_tokens))
# Store in cache
self.cache["token_length"][cache_key] = token_length
self._manage_cache_size("token_length")
return token_length
def get_ppl(
self,
text: str,
granularity: str = "line",
input_ids=None,
attention_mask=None,
past_key_values=None,
return_kv=False,
end=None,
condition_mode: str = "none",
condition_pos_id: int = 0,
):
"""
Calculate perplexity for the given text at line level.
Args:
text: The text to calculate perplexity for
granularity: The granularity of perplexity calculation (line, token, chunk)
input_ids, attention_mask, past_key_values: Optional pre-processed inputs
return_kv: Whether to return key-values
end: End position for calculation
condition_mode: Mode for conditional perplexity (none, prefix)
condition_pos_id: Position ID for condition
Returns:
A dictionary with perplexity scores and processing information
"""
# Create a cache key for this specific perplexity calculation
cache_key = f"{text}_{granularity}_{condition_mode}_{condition_pos_id}"
if past_key_values is None and not return_kv and cache_key in self.cache["perplexity"]:
return self.cache["perplexity"][cache_key]
# Initialize input processing
if input_ids is None:
encoding_key = text
if encoding_key in self.cache["encodings"]:
cached_encoding = self.cache["encodings"][encoding_key]
input_ids = cached_encoding["input_ids"]
attention_mask = cached_encoding["attention_mask"]
else:
encoding = self.tokenizer(
text,
return_tensors="pt",
padding=True
)
input_ids = encoding["input_ids"].to(self.model.device)
attention_mask = encoding["attention_mask"].to(self.model.device)
# Cache the encoding
self.cache["encodings"][encoding_key] = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
self._manage_cache_size("encodings")
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
else:
past_length = 0
if end is None:
end = input_ids.shape[1]
end = min(end, past_length + self.max_position_embeddings)
with torch.no_grad():
outputs = self.model(
input_ids=input_ids[:, past_length:end],
attention_mask=attention_mask[:, :end],
past_key_values=past_key_values,
return_dict=True,
output_hidden_states=True,
use_cache=True,
)
# Get logits and shift
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., past_length+1:end].contiguous()
# Flatten tokens for loss calculation
active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
active_labels = shift_labels.view(-1)[active]
# Calculate loss
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(active_logits, active_labels)
# Apply condition filtering if required
if condition_mode == "prefix":
loss = loss[condition_pos_id:]
segments = [text] if text else []
lines_info = []
# Calculate mean perplexity
mean_loss = loss.mean() if len(loss) > 0 else torch.tensor(0.0)
ppl = torch.exp(mean_loss).item() if mean_loss.item() != float('inf') else float('inf')
result = {
"loss": loss,
"input_ids": input_ids,
"attention_mask": attention_mask,
"lines_info": lines_info,
"segments": segments,
"ppl": ppl,
}
if return_kv:
result["past_key_values"] = outputs.past_key_values
else:
# Cache the result if we're not returning KV cache
self.cache["perplexity"][cache_key] = result
self._manage_cache_size("perplexity")
return result
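    # Illustrative sketch of the prefix-conditioned path of get_ppl (the names below are the
    # caller's own, not new API): to score a question given a context, the caller passes the
    # concatenated text and tells get_ppl how many leading tokens belong to the context, e.g.
    #
    #   prefix_len = self.get_token_length(context + "\n\n", add_special_tokens=True)
    #   result = self.get_ppl(text=context + "\n\n" + question,
    #                         condition_mode="prefix", condition_pos_id=prefix_len)
    #
    # Only the per-token losses after condition_pos_id survive, so result["ppl"] approximates
    # PPL(question | context).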
def __get_lines_info(self, lines, input_ids, loss):
"""
Get information about each line including start/end positions and importance.
Args:
lines: List of lines in the text
input_ids: Token IDs for the entire text
loss: Per-token loss values
Returns:
List of dictionaries with line information
"""
line_info = []
cumulative_tokens = 0
input_ids_list = input_ids.cpu().tolist()
for i, line in enumerate(lines):
if not line.strip():
continue
# Encode each line to find its token length
line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
line_length = len(line_tokens)
# Find position in the tokenized text
start_pos = cumulative_tokens
end_pos = start_pos + line_length
# Calculate mean loss (importance) for this line
# Loss might be shorter than the token IDs due to shifting
if isinstance(loss, torch.Tensor) and start_pos < len(loss) and end_pos <= len(loss):
line_loss = loss[start_pos:end_pos].mean().item()
else:
# Handle edge cases
line_loss = float("inf")
line_info.append({
"line": line,
"start": start_pos,
"end": end_pos,
"importance": line_loss,
"tokens": line_length
})
cumulative_tokens += line_length
return line_info
def get_prefix_length(self, prefix: str, text: str):
"""
Calculate the length of a prefix in tokens when concatenated with a text.
Args:
prefix: The prefix text
text: The main text
Returns:
Length of the prefix in tokens
"""
possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
full_input_ids = self.tokenizer(prefix + text[:100], add_special_tokens=False).input_ids
        for i in range(possible_prefix_token, len(full_input_ids)):
            cur_prefix = self.tokenizer.decode(full_input_ids[:i])
            if cur_prefix == prefix:
                return i
        # Fallback: no exact decode boundary found; approximate with the raw prefix token count
        return self.get_token_length(prefix, add_special_tokens=False)
def get_condition_ppl(
self,
text: str,
question: str,
condition_in_question: str = "none",
granularity: str = "line",
):
"""
Calculate perplexity change of a question when given context text.
A positive change means the context helps reduce question perplexity.
Args:
text: The context text
question: The question to evaluate
condition_in_question: Conditioning mode (none, prefix)
granularity: Granularity for perplexity calculation
Returns:
Perplexity change for the question with/without context
"""
# Create a cache key for this conditional perplexity calculation
cache_key = f"{text}_{question}_{condition_in_question}_{granularity}"
if cache_key in self.cache["conditional_ppl"]:
return self.cache["conditional_ppl"][cache_key]
if condition_in_question == "none":
# Just return the perplexity of the text
result = self.get_ppl(
text=text, granularity=granularity, condition_mode="none"
)
ppl_value = result["ppl"]
else:
# First calculate question perplexity without context
question_ppl_without_context = self.get_ppl(
text=question,
granularity=granularity
)["ppl"]
# Then calculate question perplexity with context
question_ppl_with_context = self.get_ppl(
text=text + "\n\n" + question,
granularity=granularity,
condition_mode="prefix",
condition_pos_id=self.get_token_length(text + "\n\n", add_special_tokens=True)
)["ppl"]
# Calculate the change (positive means context helps)
ppl_value = question_ppl_without_context - question_ppl_with_context
# Cache the result
self.cache["conditional_ppl"][cache_key] = ppl_value
self._manage_cache_size("conditional_ppl")
return ppl_value
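    # Worked example (numbers assumed): if PPL(question) = 8.0 without context and
    # PPL(question | context) = 5.5 with the context prepended, get_condition_ppl returns
    # 8.0 - 5.5 = 2.5. A larger positive value means the context is more helpful for the
    # question, which is how contexts are ranked in control_context_budget below.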
def control_context_budget(
self,
context_list: List[str],
target_token: float,
question: str = "",
reorder_context: str = "original",
condition_in_question: str = "none",
force_context_ids: List[int] = None,
force_context_number: int = None,
context_budget: str = "+100",
dynamic_context_compression_ratio: float = 0.0,
):
"""
Control token budget for contexts based on relevance ranking, following LongLLMLingua.
Args:
context_list: List of contexts
target_token: Target number of tokens
question: Question for relevance ranking
reorder_context: How to reorder contexts ("original", "importance", "two_stage")
condition_in_question: Mode for conditional ranking
force_context_ids: List of context IDs to always include
force_context_number: Number of contexts to forcibly include
context_budget: String expression to modify target token budget
dynamic_context_compression_ratio: Ratio for dynamic compression (0.0-1.0)
Returns:
            Selected contexts, their indices, dynamic compression ratios, and the relevance ranking (demonstrations_sort)
"""
logger.debug(f"Controlling context budget with target_token={target_token}")
start_time = time.time()
if not context_list:
return [], [], []
# Get token counts for each context
logger.debug("Calculating token lengths for contexts")
context_tokens_length = [self.get_token_length(context) for context in context_list]
# If total tokens already fit within budget, return all contexts
total_tokens = sum(context_tokens_length)
if total_tokens <= target_token:
logger.debug(f"All contexts fit within budget ({total_tokens} <= {target_token})")
end_time = time.time()
logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds")
return context_list, list(range(len(context_list))), [0.0] * len(context_list)
# Rank contexts by relevance if question is provided
logger.debug("Ranking contexts by relevance")
if question:
# Get perplexity change for each context with the question
context_ppl_changes = []
for d, dl in zip(context_list, context_tokens_length):
# Calculate how much this context reduces question perplexity
ppl_change = self.get_condition_ppl(
d,
question,
condition_in_question,
)
                # Length-based adjustment term is currently disabled (the factor is multiplied by 0)
                context_ppl_changes.append(ppl_change - dl * 2 / 250 * 0)
# Sort by perplexity change - higher is better (more reduction in question perplexity)
demonstrations_sort = sorted(enumerate(context_ppl_changes), key=lambda x: -x[1])
else:
# Without question, use default ordering
demonstrations_sort = [(i, 0) for i in range(len(context_list))]
# Extract ranking for later use
self.context_idxs.append([x for idx, (x, _) in enumerate(demonstrations_sort)])
# Calculate the target token budget with context_budget expression
if target_token < 0:
target_token = 100
target_token = eval("target_token" + context_budget)
# Initialize selected context tracking
used = force_context_ids if force_context_ids is not None else []
# Select contexts until we reach the token budget
for idx, _ in demonstrations_sort:
if idx >= len(context_tokens_length):
continue
target_token -= context_tokens_length[idx]
if idx not in used:
used.append(idx)
if target_token < 0 or (
force_context_number is not None and len(used) >= force_context_number
):
break
# Store original selection order
original_used = used.copy()
# Reorder contexts if requested
if reorder_context == "original":
used = sorted(used)
elif reorder_context == "two_stage":
l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
_ for idx, _ in enumerate(used) if idx % 2 == 1
]
used = l + r[::-1]
# Calculate dynamic compression ratios if requested
if dynamic_context_compression_ratio > 0:
N = len(used)
dynamic_ratio = [
i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
for i in range(-(N - 1), N, 2)
][::-1]
dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
dynamic_ratio = [dynamic_ratio_map[i] for i in used]
else:
dynamic_ratio = [0.0] * len(used)
# Build list of selected contexts
selected_contexts = [context_list[idx] for idx in used if idx < len(context_list)]
end_time = time.time()
logger.debug(f"Selected {len(selected_contexts)} contexts out of {len(context_list)}")
logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds")
return selected_contexts, used, dynamic_ratio, demonstrations_sort
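    # Illustrative walk-through (token counts assumed): with contexts of [300, 500, 200] tokens,
    # target_token=600 and context_budget="+100", the effective budget is 700. If the ranking by
    # PPL change is [1, 0, 2], context 1 (500 tokens) is taken first (budget 700 -> 200), then
    # context 0 (300 tokens) drives the budget negative (-100) and selection stops;
    # reorder_context="original" finally returns the kept chunks in source order [0, 1].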
def compress_code_file(
self,
code: str,
query: str = "",
instruction: str = "",
rate: float = 0.5,
target_token: float = -1,
language: str = "python",
dynamic_compression_ratio: float = 0.2,
context_budget: str = "+100",
rank_only: bool = False,
fine_ratio: float = None,
fine_grained_importance_method: str = "conditional_ppl",
min_lines_for_fine_grained: int = 5,
importance_beta: float = 0.5,
use_knapsack: bool = True,
):
"""
Compress a code file by first splitting it into function-based chunks and then compressing.
Functions are prioritized based on query relevance, similar to LongLLMLingua.
Args:
code: The code to compress
query: Query to prioritize relevant functions
instruction: Additional instruction to guide compression
rate: Compression rate for coarse-grained (function level) compression (0.0-1.0)
target_token: Target number of tokens (alternative to rate)
language: Programming language of the code
dynamic_compression_ratio: Ratio for dynamic compression
context_budget: String expression to modify token budget
rank_only: If True, just rank and select contexts without fine-grained compression
fine_ratio: Ratio for fine-grained line selection (0.0-1.0). If None, uses `rate`.
fine_grained_importance_method: Method for scoring line importance ('contrastive_perplexity' or 'conditional_ppl'). Defaults to 'conditional_ppl'.
min_lines_for_fine_grained: Minimum number of lines a function must have to undergo fine-grained compression (otherwise kept fully).
importance_beta: Controls how much function importance affects its individual compression rate during fine-grained compression.
use_knapsack: Whether to use knapsack algorithm for block selection (True) or greedy line-by-line approach (False).
Returns:
Compressed code and statistics with the following structure:
{
"original_code": Original uncompressed code,
"compressed_code": Compressed code,
"compressed_prompt": Complete compressed prompt with instruction and query,
"original_tokens": Number of tokens in original code,
"compressed_tokens": Number of tokens in compressed code,
"final_compressed_tokens": Number of tokens in final compressed prompt,
"compression_ratio": Ratio of compressed to original tokens,
"function_compressions": Details about compression for each function,
"selected_functions": Indices of selected functions,
"demonstrations_sort": Ranking of functions by importance,
"compressed_chunks": List of compressed code chunks
"fine_grained_method_used": The method used for fine-grained importance scoring.
}
"""
logger.debug(f"Starting code file compression with rate={rate}, target_token={target_token}, language={language}")
start_time = time.time()
# Split code into function-based chunks
logger.debug("Splitting code into function-based chunks")
code_chunks = self.split_code_by_functions(code, language=language)
logger.debug(f"Split code into {len(code_chunks)} chunks")
# Calculate total tokens
logger.debug("Calculating total tokens")
total_tokens = sum(self.get_token_length(chunk) for chunk in code_chunks)
logger.debug(f"Total tokens: {total_tokens}")
# Determine target_token based on rate if not specified
original_target_token = target_token # Store original value if provided
if target_token <= 0:
if rate <= 0:
# Default target if both rate and target_token are invalid
target_token = int(total_tokens * 0.5)
logger.warning(f"Rate and target_token invalid, defaulting target_token to {target_token}")
else:
target_token = int(total_tokens * rate)
logger.debug(f"Coarse Target tokens: {target_token}")
# Use context budget control to select important functions
logger.debug("Selecting important functions using context budget control")
selected_contexts, selected_indices, dynamic_ratios, demonstrations_sort = self.control_context_budget(
code_chunks,
target_token=target_token,
question=query,
reorder_context="original", # Keep original order to maintain code structure
condition_in_question="prefix",
context_budget=context_budget,
dynamic_context_compression_ratio=dynamic_compression_ratio,
)
        # Build the coarse-grained result by keeping selected functions verbatim; it is returned
        # directly when rank_only is True, otherwise fine-grained compression follows below
        logger.debug("Building coarse-grained selection of top functions")
compressed_chunks = []
compressed_tokens = 0
function_compressions = {}
# Just keep the selected contexts as is
for i, chunk in enumerate(code_chunks):
if i in selected_indices:
compressed_chunks.append(chunk)
chunk_tokens = self.get_token_length(chunk)
compressed_tokens += chunk_tokens
# Store compression info - no actual compression in this mode
function_compressions[i] = {
"original_tokens": chunk_tokens,
"compressed_tokens": chunk_tokens,
"compression_ratio": 1.0,
}
else:
# Skip this function completely
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
omission_text = f"{comment_marker} ... "
compressed_chunks.append(omission_text)
compressed_tokens += self.get_token_length(omission_text)
# Combine compressed chunks
compressed_code = "\n\n".join(compressed_chunks)
# --- Post-join cleanup for consecutive omission markers ---
logger.debug("Cleaning up consecutive omission markers after joining...")
lines = compressed_code.split("\n")
cleaned_lines = []
last_non_empty_line_was_omission = False
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
omission_marker_content = f"{comment_marker} ...".strip() # Content to check against
for line in lines:
stripped_line = line.strip()
if not stripped_line:
# Keep empty lines
cleaned_lines.append(line)
# Don't reset the flag here, wait for a non-empty line
elif stripped_line == omission_marker_content:
if last_non_empty_line_was_omission:
# Skip this consecutive omission marker line
logger.debug(f"Skipping line: '{line}' (consecutive omission)")
continue
else:
# Keep the first omission marker line
cleaned_lines.append(line)
last_non_empty_line_was_omission = True
else:
# Regular code line
cleaned_lines.append(line)
last_non_empty_line_was_omission = False
compressed_code = "\n".join(cleaned_lines)
logger.debug("Cleanup finished.")
# --- End post-join cleanup ---
output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}"
# Calculate actual compressed tokens
final_compressed_tokens = self.get_token_length(output)
end_time = time.time()
logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds")
logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}")
if rank_only:
return {
"original_code": code,
"compressed_code": compressed_code,
"compressed_prompt": output,
"original_tokens": total_tokens,
"compressed_tokens": compressed_tokens,
"final_compressed_tokens": final_compressed_tokens,
"compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0,
"function_compressions": function_compressions,
"selected_functions": selected_indices,
"demonstrations_sort": demonstrations_sort,
"compressed_chunks": compressed_chunks,
"fine_grained_method_used": None,
}
else:
# enter fine-grained compression
logger.debug(f"Starting fine-grained compression on selected functions using method: {fine_grained_importance_method}")
# --- Dynamic Fine-grained Rate Allocation ---
logger.debug("Calculating dynamic fine-grained compression rates...")
# 1. Collect data for selected functions
selected_functions_data = []
importance_map = {idx: score for idx, score in demonstrations_sort} # Map index to score
total_lines_selected = 0
for i in selected_indices:
if i < len(code_chunks):
chunk = code_chunks[i]
# Use simple line splitting for allocation efficiency
lines = chunk.split("\n")
line_count = len(lines)
score = importance_map.get(i, 0.0) # Default score 0 if not found
selected_functions_data.append({
"index": i,
"lines": lines,
"line_count": line_count,
"score": score
})
total_lines_selected += line_count
else:
logger.warning(f"Selected index {i} is out of bounds for code_chunks (length {len(code_chunks)})")
# 2. Calculate overall fine-grained target lines
current_fine_ratio = fine_ratio if fine_ratio is not None else rate # Use rate if fine_ratio not set
if original_target_token > 0: # If target_token was set explicitly, derive ratio from it for fine-grained stage
# Estimate target lines based on the ratio of selected tokens to total tokens, then apply fine_ratio
selected_tokens = sum(self.get_token_length(code_chunks[d['index']]) for d in selected_functions_data)
effective_coarse_rate = selected_tokens / total_tokens if total_tokens > 0 else 1.0
# Use the user-provided fine_ratio, or fall back to rate/coarse target estimate
fine_target_rate = current_fine_ratio
logger.debug(f"Using fine_ratio={fine_target_rate} for fine-grained target calculation.")
target_total_lines = int(total_lines_selected * fine_target_rate)
else: # Calculate target based on fine_ratio/rate directly applied to selected lines
target_total_lines = int(total_lines_selected * current_fine_ratio)
logger.debug(f"Using current_fine_ratio={current_fine_ratio} for fine-grained target calculation.")
logger.debug(f"Total lines in selected functions: {total_lines_selected}")
logger.debug(f"Target total lines after fine-grained compression: {target_total_lines}")
# 3. Separate small and large functions
small_functions = []
large_functions = []
lines_in_small_functions = 0
lines_in_large_functions = 0
for data in selected_functions_data:
if data["line_count"] < min_lines_for_fine_grained:
small_functions.append(data)
lines_in_small_functions += data["line_count"]
else:
large_functions.append(data)
lines_in_large_functions += data["line_count"]
logger.debug(f"Found {len(small_functions)} small functions (< {min_lines_for_fine_grained} lines) with {lines_in_small_functions} total lines (will be kept).")
logger.debug(f"Found {len(large_functions)} large functions (>= {min_lines_for_fine_grained} lines) with {lines_in_large_functions} total lines.")
# 4. Calculate target lines for large functions
target_lines_for_large = max(0, target_total_lines - lines_in_small_functions)
logger.debug(f"Target lines to keep from large functions: {target_lines_for_large}")
# 5. Allocate rates for large functions
function_fine_ratios = {} # Map: index -> individual_fine_ratio
if not large_functions or lines_in_large_functions == 0:
logger.debug("No large functions to compress further or zero lines in large functions.")
global_rate_for_large = 1.0 if lines_in_large_functions > 0 else 0.0 # Should be 0 if lines_in_large_functions is 0
elif target_lines_for_large <= 0:
logger.debug("Target lines for large functions is <= 0. Setting rates to 0.")
global_rate_for_large = 0.0
elif target_lines_for_large >= lines_in_large_functions:
logger.debug("Target lines for large functions >= total lines. Setting rates to 1.0.")
global_rate_for_large = 1.0
else:
global_rate_for_large = target_lines_for_large / lines_in_large_functions
logger.debug(f"Global target rate for large functions: {global_rate_for_large:.4f}")
# Normalize scores for weighting (MinMax scaling)
scores = [d["score"] for d in large_functions]
valid_scores = [s for s in scores if not math.isinf(s) and not math.isnan(s)]
if not valid_scores or max(valid_scores) == min(valid_scores):
logger.debug("Scores are uniform or invalid, using global rate for all large functions.")
for data in large_functions:
function_fine_ratios[data["index"]] = global_rate_for_large
else:
min_score = min(valid_scores)
max_score = max(valid_scores)
normalized_scores = [(s - min_score) / (max_score - min_score) if not math.isinf(s) and not math.isnan(s) else 0.0 for s in scores] # Normalize to [0, 1], default 0 for invalid
# Calculate initial biased rates
initial_rates = []
for norm_score in normalized_scores:
# Bias rate: higher score -> higher rate (closer to 1)
# Beta controls sensitivity. beta=0 -> uniform rate. beta=1 -> max sensitivity.
biased_rate = global_rate_for_large * (1 + importance_beta * (norm_score - 0.5) * 2) # Scale norm_score diff to [-beta, beta]
clamped_rate = max(0.0, min(1.0, biased_rate)) # Clamp to [0, 1]
initial_rates.append(clamped_rate)
# Calculate actual lines kept with initial rates
actual_lines_kept = sum(initial_rates[i] * large_functions[i]["line_count"] for i in range(len(large_functions)))
logger.debug(f"Initial biased rates calculated. Estimated lines kept: {actual_lines_kept:.1f}")
# Adjust rates proportionally to meet target
if actual_lines_kept > 0 and abs(actual_lines_kept - target_lines_for_large) > 1: # Adjust if difference is significant
adjustment_factor = target_lines_for_large / actual_lines_kept
logger.debug(f"Adjusting rates by factor: {adjustment_factor:.4f}")
final_rates = [max(0.0, min(1.0, r * adjustment_factor)) for r in initial_rates] # Adjust and clamp again
else:
logger.debug("Initial rates are close enough or actual_lines_kept is 0, no adjustment needed.")
final_rates = initial_rates
for i, data in enumerate(large_functions):
function_fine_ratios[data["index"]] = final_rates[i]
# Set rate 1.0 for small functions
for data in small_functions:
function_fine_ratios[data["index"]] = 1.0
# --- End Dynamic Allocation ---
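            # Worked example of the rate allocation above (numbers assumed): with
            # global_rate_for_large = 0.5 and importance_beta = 0.5, a function whose normalized
            # score is 1.0 gets an initial rate of 0.5 * (1 + 0.5 * (1.0 - 0.5) * 2) = 0.75,
            # while a score of 0.0 gives 0.5 * (1 + 0.5 * (0.0 - 0.5) * 2) = 0.25. If the
            # resulting line estimate misses target_lines_for_large by more than one line, every
            # rate is rescaled by target / estimate and clamped back into [0, 1].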
# Apply fine-grained compression to each selected function
fine_compressed_chunks = []
compressed_tokens = 0
function_compressions = {}
# Define a smoothing window size for moving average
smoothing_window = 5
# fine_ratio = fine_ratio if fine_ratio is not None else rate # Use the same ratio by default if fine_ratio not specified # Removed, using individual ratios now
# Process each chunk in the original order
# Use tqdm.auto for compatibility
fine_grained_pbar = tqdm(enumerate(code_chunks), total=len(code_chunks), desc="Fine-Grained Compression", leave=False)
for i, chunk in fine_grained_pbar:
# for i, chunk in enumerate(code_chunks):
if i in selected_indices:
# This function was selected during coarse-grained compression
individual_fine_ratio = function_fine_ratios.get(i) # Get dynamically assigned ratio
if individual_fine_ratio is None:
logger.error(f"Missing fine-grained ratio for selected function index {i}. Skipping fine-grained compression for this chunk.")
individual_fine_ratio = 1.0 # Fallback to keep the chunk
# Use Entropy chunking for fine-grained compression instead of simple line splitting
chunks, sentences, ppls, spike_indices = self.entropy_chunking.chunk_text_adaptive(
code_chunks[i], method='std', k=0.2
)
# Use chunks as lines, but preserve all chunks including empty ones to maintain formatting
chunk_lines = chunks # Keep all chunks to preserve \n\n and formatting
chunk_line_count = len([chunk for chunk in chunk_lines if chunk.strip()]) # Count only non-empty for logic
chunk_score = importance_map.get(i, float('nan')) # Get score
logger.debug(f"Processing Func {i}: Entropy Chunks={len(chunk_lines)}, Non-empty={chunk_line_count}, Score={chunk_score:.4f}, Assigned FineRatio={individual_fine_ratio:.4f}")
# Skip fine-grained compression if ratio is 1.0 (or close) or function is small
if individual_fine_ratio >= 0.999 or chunk_line_count < min_lines_for_fine_grained:
note = "Kept (Ratio=1.0)" if individual_fine_ratio >= 0.999 else f"Kept (Small Func < {min_lines_for_fine_grained} lines)"
logger.debug(f" - {note}")
fine_compressed_chunks.append(chunk)
chunk_tokens = self.get_token_length(chunk)
compressed_tokens += chunk_tokens
function_compressions[i] = {
"original_tokens": chunk_tokens,
"compressed_tokens": chunk_tokens,
"compression_ratio": 1.0,
"individual_fine_ratio": individual_fine_ratio,
"note": note,
"importance_method": None # No line importance calculation needed
}
continue # Move to next chunk
# Apply fine-grained compression only if the function is large enough
# and we're not in rank-only mode (already checked) and ratio < 1.0
if chunk_line_count >= min_lines_for_fine_grained and individual_fine_ratio < 0.999:
logger.debug(f" - Applying fine-grained compression with ratio {individual_fine_ratio:.4f}")
fine_grained_pbar.set_description(f"Fine-Grained Compressing Func {i}")
# Calculate target tokens for this function
original_func_tokens = self.get_token_length(chunk)
target_func_tokens = int(original_func_tokens * individual_fine_ratio)
# Calculate importance for each block based on the chosen method
block_importances = []
importance_calculation_start = time.time()
if fine_grained_importance_method == "conditional_ppl":
# Calculate conditional PPL importance for each block
if not query or not query.strip():
logger.warning(f"Query is empty for func {i}, cannot calculate conditional PPL. Assigning 0 importance.")
block_importances = [0.0] * len(chunk_lines)
else:
query_ppl_result = self.get_ppl(query, granularity="line")
query_ppl_without_context = query_ppl_result["ppl"]
if math.isinf(query_ppl_without_context):
logger.warning(f"Base query PPL is infinite for func {i}. Assigning 0 importance to blocks.")
block_importances = [0.0] * len(chunk_lines)
else:
pbar_cond = tqdm(enumerate(chunk_lines), total=len(chunk_lines), desc=f"Func {i} Block CondPPL", leave=False)
for block_idx, block in pbar_cond:
if not block.strip():
block_importances.append(-float('inf')) # Low score for empty blocks
continue
conditional_text = block + "\n\n" + query
prefix_len_text = block + "\n\n"
prefix_len = self.get_token_length(prefix_len_text, add_special_tokens=True)
cond_ppl_result = self.get_ppl(
text=conditional_text,
granularity="line",
condition_mode="prefix",
condition_pos_id=prefix_len - 1
)
ppl_with_context = cond_ppl_result["ppl"]
if math.isinf(ppl_with_context):
ppl_change = -float('inf')
else:
ppl_change = query_ppl_without_context - ppl_with_context
block_importances.append(ppl_change)
pbar_cond.set_description(f"Func {i} Block CondPPL (B{block_idx}: {ppl_change:.2f})")
elif fine_grained_importance_method == "contrastive_perplexity":
# Calculate contrastive PPL importance for each block
fine_grained_pbar.set_description(f"Fine-Grained ContrastivePPL Func {i}")
with torch.no_grad():
pbar = tqdm(enumerate(chunk_lines), total=len(chunk_lines), desc="Block Contrastive PPL", leave=False)
for block_idx, block in pbar:
if not block.strip():
block_importances.append(-float('inf'))
continue
# Build context from previous blocks
prev_context = "\n\n".join(chunk_lines[:block_idx]) if block_idx > 0 else ""
# 1. PPL(Block | prev_blocks)
regular_ppl_condition = prev_context + "\n\n" if prev_context else None
regular_ppl = self._calculate_perplexity_for_contrastive(block, condition_text=regular_ppl_condition)
# 2. PPL(Block | query, prev_blocks)
question_context_parts = [query]
if prev_context:
question_context_parts.append(prev_context)
question_context = "\n\n".join(filter(None, question_context_parts))
cond_ppl_condition = question_context + "\n\n"
cond_ppl = self._calculate_perplexity_for_contrastive(block, condition_text=cond_ppl_condition)
# 3. Importance = PPL(Block|prev) - PPL(Block|Q,prev)
if math.isinf(regular_ppl) or math.isinf(cond_ppl):
importance = -float('inf')
else:
importance = regular_ppl - cond_ppl
block_importances.append(importance)
pbar.set_description(f"Block {block_idx}: {importance:.2f}")
else:
raise ValueError(f"Unsupported fine_grained_importance_method: {fine_grained_importance_method}")
importance_calculation_end = time.time()
logger.debug(f" - Block importance calculation took {importance_calculation_end - importance_calculation_start:.2f}s")
# Identify preserved blocks (function signature, comments, returns)
preserved_block_indices = set()
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
# Find blocks containing function signature
for block_idx, block in enumerate(chunk_lines):
block_lines = block.split('\n')
for line in block_lines:
if line.strip():
# Check for function/class definitions
if any(keyword in line for keyword in ['def ', 'class ', 'function ', 'fn ', 'func ']):
preserved_block_indices.add(block_idx)
break
# Check for function-level comments
if line.strip().startswith(comment_marker):
preserved_block_indices.add(block_idx)
break
# Check for return statements
if 'return ' in line:
preserved_block_indices.add(block_idx)
break
break # Only check first non-empty line of each block
# Choose selection method based on use_knapsack parameter
processing_start = time.time()
if use_knapsack:
# Use knapsack algorithm to select blocks
logger.debug(f" - Using knapsack algorithm for block selection")
selected_block_indices, selection_info = self._knapsack_block_selection(
blocks=chunk_lines,
block_importances=block_importances,
target_tokens=target_func_tokens,
preserved_block_indices=preserved_block_indices,
language=language
)
# Build compressed chunk from selected blocks
compressed_blocks = []
# Determine base indentation for omission markers
base_indentation = ""
for block in chunk_lines:
for line in block.split('\n'):
if line.strip():
match = re.match(r"^(\s*)", line)
if match:
base_indentation = match.group(1)
break
if base_indentation:
break
omission_marker = f"{base_indentation}{comment_marker} ... "
# Build output with omission markers for gaps
last_selected_idx = -1
for block_idx in sorted(selected_block_indices):
# Add omission marker if there's a gap
if last_selected_idx != -1 and block_idx > last_selected_idx + 1:
if not compressed_blocks or compressed_blocks[-1] != omission_marker:
compressed_blocks.append(omission_marker)
compressed_blocks.append(chunk_lines[block_idx])
last_selected_idx = block_idx
# Handle trailing omission if needed
if last_selected_idx != -1 and last_selected_idx < len(chunk_lines) - 1:
if not compressed_blocks or compressed_blocks[-1] != omission_marker:
compressed_blocks.append(omission_marker)
# Join blocks with double newlines to preserve Entropy chunk structure
compressed_chunk = "\n\n".join(compressed_blocks)
else:
# Use original greedy line-by-line approach with smoothing
logger.debug(f" - Using original greedy line-by-line approach")
# Convert block importances to line importances for compatibility
lines = []
line_importances = []
line_indices = []
for block_idx, (block, block_importance) in enumerate(zip(chunk_lines, block_importances)):
block_lines = block.split('\n')
for line_idx_in_block, line in enumerate(block_lines):
global_line_idx = len(lines)
lines.append(line)
line_importances.append(block_importance) # Use block importance for all lines in block
line_indices.append(global_line_idx)
# Apply original processing logic with smoothing
full_line_scores = [float('nan')] * len(lines)
for score_idx, original_line_idx in enumerate(line_indices):
if score_idx < len(line_importances):
full_line_scores[original_line_idx] = line_importances[score_idx]
# Replace NaN/Inf with min valid score for consistent processing
valid_scores = [s for s in full_line_scores if not math.isnan(s) and not math.isinf(s)]
if valid_scores:
min_valid_score = min(valid_scores)
if min_valid_score == float('inf') or min_valid_score == -float('inf') or math.isnan(min_valid_score):
min_replacement_score = 0.0
else:
min_replacement_score = min_valid_score
processed_line_scores = []
for s in full_line_scores:
if math.isnan(s) or s == -float('inf'):
processed_line_scores.append(min_replacement_score)
elif s == float('inf'):
processed_line_scores.append(min_replacement_score)
else:
processed_line_scores.append(s)
else:
processed_line_scores = [0.0] * len(lines)
# Apply smoothing using moving average
smoothing_window = 5
smoothed_importances = processed_line_scores.copy()
num_processed_scores = len(processed_line_scores)
for j in range(num_processed_scores):
window_start = max(0, j - smoothing_window // 2)
window_end = min(num_processed_scores, j + smoothing_window // 2 + 1)
window = processed_line_scores[window_start:window_end]
valid_window_scores = [s for s in window if not math.isnan(s) and not math.isinf(s)]
if valid_window_scores:
smoothed_importances[j] = sum(valid_window_scores) / len(valid_window_scores)
# Find preserved lines (convert block indices to line indices)
preserved_line_indices = set()
line_offset = 0
for block_idx, block in enumerate(chunk_lines):
block_lines = block.split('\n')
if block_idx in preserved_block_indices:
for line_idx_in_block in range(len(block_lines)):
preserved_line_indices.add(line_offset + line_idx_in_block)
line_offset += len(block_lines)
# Sort remaining lines by importance
sortable_lines = []
for idx in range(len(lines)):
if idx not in preserved_line_indices:
if idx < len(line_indices) and idx < len(line_importances):
original_score = line_importances[idx]
if not math.isnan(original_score) and not math.isinf(original_score):
smoothed_score = smoothed_importances[idx]
sortable_lines.append((idx, smoothed_score))
# Sort descending by score
sorted_line_indices = sorted(sortable_lines, key=lambda x: -x[1])
# Calculate target number of lines to keep
total_lines = len(lines)
preserved_count = len(preserved_line_indices)
target_lines = max(preserved_count, int(total_lines * individual_fine_ratio))
# Select top lines by importance up to target
selected_lines = set(preserved_line_indices)
for idx, score in sorted_line_indices:
if len(selected_lines) >= target_lines:
break
selected_lines.add(idx)
# Build compressed chunk from selected lines
compressed_chunks = []
base_indentation = ""
if lines:
for line in lines:
if line.strip():
match = re.match(r"^(\s*)", line)
if match:
base_indentation = match.group(1)
break
omission_marker_line = f"{base_indentation}{comment_marker} ... "
last_added_line_idx = -1
for j in range(len(lines)):
if j in selected_lines:
if last_added_line_idx != -1 and j > last_added_line_idx + 1:
if not compressed_chunks or compressed_chunks[-1] != omission_marker_line:
compressed_chunks.append(omission_marker_line)
compressed_chunks.append(lines[j])
last_added_line_idx = j
if last_added_line_idx != -1 and last_added_line_idx < len(lines) - 1:
if not compressed_chunks or compressed_chunks[-1] != omission_marker_line:
compressed_chunks.append(omission_marker_line)
compressed_chunk = "\n".join(compressed_chunks)
# Create selection info for compatibility
selection_info = {
"method": "greedy_line_by_line",
"preserved_lines": len(preserved_line_indices),
"selected_lines": len(selected_lines),
"total_lines": len(lines),
"smoothing_applied": True
}
selected_block_indices = preserved_block_indices # For compatibility
processing_end = time.time()
method_name = "knapsack" if use_knapsack else "greedy"
logger.debug(f" - {method_name} selection took {processing_end - processing_start:.2f}s")
if use_knapsack:
logger.debug(f" - Selected {len(selected_block_indices)}/{len(chunk_lines)} blocks")
else:
logger.debug(f" - Selected {len(selected_lines)}/{len(lines)} lines")
# Update token count and store compression info
fine_compressed_chunks.append(compressed_chunk)
compressed_chunk_tokens = self.get_token_length(compressed_chunk)
compressed_tokens += compressed_chunk_tokens
# Store compression info
actual_compression_ratio = compressed_chunk_tokens / original_func_tokens if original_func_tokens > 0 else 1.0
function_compressions[i] = {
"original_tokens": original_func_tokens,
"compressed_tokens": compressed_chunk_tokens,
"compression_ratio": actual_compression_ratio,
"individual_fine_ratio": individual_fine_ratio,
"preserved_blocks": list(preserved_block_indices),
"selected_blocks": list(selected_block_indices),
"selection_info": selection_info,
"importance_method": fine_grained_importance_method,
"selection_method": "knapsack" if use_knapsack else "greedy_line_by_line"
}
logger.debug(f" - Compressed func {i}: {original_func_tokens} -> {compressed_chunk_tokens} tokens (Ratio: {actual_compression_ratio:.3f})")
else:
# This case should now be handled by the check at the beginning of the loop
logger.warning(f"Reached unexpected state for func {i}. Keeping chunk as is.")
fine_compressed_chunks.append(chunk)
chunk_tokens = self.get_token_length(chunk)
compressed_tokens += chunk_tokens
function_compressions[i] = {
"original_tokens": chunk_tokens,
"compressed_tokens": chunk_tokens,
"compression_ratio": 1.0,
"individual_fine_ratio": individual_fine_ratio,
"note": "Unexpected state, kept function.",
"importance_method": None
}
else:
# This function was not selected during coarse-grained compression
# Add a placeholder
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
omission_text = f"{comment_marker} ... "
fine_compressed_chunks.append(omission_text)
compressed_tokens += self.get_token_length(omission_text)
# Log skipped chunk
# logger.debug(f"Skipped Func {i} (not selected in coarse stage)")
# Combine fine-grained compressed chunks
compressed_code = "\n\n".join(fine_compressed_chunks)
# --- Post-join cleanup for consecutive omission markers ---
logger.debug("Cleaning up consecutive omission markers after joining...")
lines = compressed_code.split("\n")
cleaned_lines = []
last_non_empty_line_was_omission = False
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
omission_marker_content = f"{comment_marker} ...".strip() # Content to check against
for line in lines:
stripped_line = line.strip()
if not stripped_line:
# Keep empty lines
cleaned_lines.append(line)
# Don't reset the flag here, wait for a non-empty line
elif stripped_line == omission_marker_content:
if last_non_empty_line_was_omission:
# Skip this consecutive omission marker line
logger.debug(f"Skipping line: '{line}' (consecutive omission)")
continue
else:
# Keep the first omission marker line
cleaned_lines.append(line)
last_non_empty_line_was_omission = True
else:
# Regular code line
cleaned_lines.append(line)
last_non_empty_line_was_omission = False
compressed_code = "\n".join(cleaned_lines)
logger.debug("Cleanup finished.")
# --- End post-join cleanup ---
# Ensure instruction/query parts are handled correctly, maybe use a template
prompt_parts = []
if instruction and instruction.strip():
prompt_parts.append(instruction.strip())
if compressed_code.strip():
prompt_parts.append(compressed_code) # Already has newlines handled
if query and query.strip():
# Add query, potentially repeating instruction based on original logic
prompt_parts.append(query.strip())
# Decide if instruction should be repeated after query based on original implementation's needs
# if instruction and instruction.strip(): # Repeat instruction if needed
# prompt_parts.append(instruction.strip())
output = "\n\n".join(prompt_parts) # Use double newline separation
# Calculate final compressed tokens
final_compressed_tokens = self.get_token_length(output)
end_time = time.time()
logger.debug(f"Fine-grained compression processing completed in {end_time - start_time:.2f} seconds")
final_compression_ratio = compressed_tokens / total_tokens if total_tokens > 0 else 1.0
logger.debug(f"Final Compression ratio (fine-grained tokens / total original tokens): {final_compression_ratio:.4f}")
return {
"original_code": code,
"compressed_code": compressed_code,
"compressed_prompt": output,
"original_tokens": total_tokens,
"compressed_tokens": compressed_tokens,
"final_compressed_tokens": final_compressed_tokens,
"compression_ratio": final_compression_ratio,
"function_compressions": function_compressions,
"selected_functions": selected_indices,
"demonstrations_sort": demonstrations_sort,
"compressed_chunks": fine_compressed_chunks,
"fine_grained_method_used": fine_grained_importance_method,
}
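    # A minimal usage sketch of the main entry point (illustrative; the query and file name are
    # assumptions, and loading the 7B model requires a GPU with enough memory):
    #
    #   compressor = LongCodeZip(model_name="Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4")
    #   result = compressor.compress_code_file(
    #       code=open("example.py").read(),
    #       query="How is the cache size limited?",
    #       rate=0.5,        # keep roughly half of the function-level tokens
    #       fine_ratio=0.5,  # then keep roughly half of the lines inside kept functions
    #       language="python",
    #       rank_only=False,
    #   )
    #   print(result["compressed_prompt"])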
def split_code_by_functions(self, code: str, language: str = "python", custom_separator: str = "# --CHUNK_SEPARATOR-- #") -> List[str]:
"""
Split code into chunks based on function and class definitions for various languages.
Also splits on custom separator if provided.
Args:
code: The code to split
language: Programming language of the code (python, cpp, java, typescript, rust, go)
custom_separator: Optional custom separator string to also split on
Returns:
List of code chunks, each containing a function, class, or class method
"""
logger.debug(f"Splitting code by functions and classes for language: {language}")
start_time = time.time()
# Define regex patterns for different languages
patterns = {
# Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end
"python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*',
# C++: Improved to better handle multi-line declarations
"cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
# Java: Improved for multi-line method declarations
"java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
# TypeScript: Enhanced to handle multi-line methods and arrow functions
"typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?',
# Rust: Improved for multi-line function declarations
"rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
# Go: Improved for multi-line function declarations
"go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
}
# Use default Python pattern if language not supported
if language.lower() not in patterns:
language = "python"
# First check if we need to split by custom separator
separator_chunks = []
if custom_separator and custom_separator in code:
logger.debug(f"Custom separator '{custom_separator}' found, first splitting by separator")
separator_chunks = [chunk for chunk in code.split(custom_separator) if chunk.strip()]
else:
separator_chunks = [code] # Just one chunk - the entire code
# Function to split a single chunk by functions/classes
def split_chunk_by_pattern(chunk_code):
function_pattern = re.compile(patterns[language.lower()], re.MULTILINE)
matches = list(function_pattern.finditer(chunk_code))
if not matches:
return [chunk_code] # No matches, return whole chunk
result_chunks = []
# Add code before first match
if matches[0].start() > 0:
result_chunks.append(chunk_code[:matches[0].start()])
# Process each match
for i, match in enumerate(matches):
start = match.start()
# End is either start of next match or end of code
if i < len(matches) - 1:
end = matches[i + 1].start()
else:
end = len(chunk_code)
result_chunks.append(chunk_code[start:end])
return result_chunks
# Now apply function/class splitting to each separator chunk
final_chunks = []
for chunk in separator_chunks:
function_chunks = split_chunk_by_pattern(chunk)
final_chunks.extend(function_chunks)
end_time = time.time()
logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds")
logger.debug(f"Split code into {len(final_chunks)} chunks (using both separator and patterns)")
return final_chunks
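    # Illustrative sketch of the splitter on a tiny Python snippet (output shown approximately;
    # the regex keeps the newlines preceding each definition with that definition's chunk):
    #
    #   self.split_code_by_functions("import os\n\ndef f():\n    pass\n\ndef g():\n    pass\n")
    #   # -> ["import os", "\n\ndef f():\n    pass", "\n\ndef g():\n    pass\n"]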
def _calculate_perplexity_for_contrastive(self, text, condition_text=None):
"""Helper to calculate perplexity of text, optionally conditioned on condition_text"""
if condition_text:
full_text = condition_text + text
inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(self.device) # Use add_special_tokens=True for consistency
condition_input_ids = self.tokenizer(condition_text, return_tensors="pt", add_special_tokens=True).input_ids
condition_length = condition_input_ids.size(1)
# Handle potential edge case where condition length might exceed max length or input length
if condition_length >= inputs.input_ids.size(1):
logger.warning(f"Condition length ({condition_length}) >= input length ({inputs.input_ids.size(1)}). Cannot calculate conditional PPL.")
return float('inf')
with torch.no_grad():
outputs = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) # Pass attention_mask
# Logits for the 'text' part, labels are the 'text' part shifted
logits = outputs.logits[0, condition_length-1:-1]
labels = inputs.input_ids[0, condition_length:]
if logits.size(0) == 0 or labels.size(0) == 0 or logits.size(0) != labels.size(0):
logger.warning(f"Logits/Labels shape mismatch or empty in _calculate_perplexity_for_contrastive (cond). Logits: {logits.shape}, Labels: {labels.shape}. Returning inf.")
return float('inf') # Return inf if shapes mismatch or empty
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
mean_loss = loss.mean().item()
perplexity = math.exp(mean_loss) if not math.isnan(mean_loss) and not math.isinf(mean_loss) else float('inf')
else:
# Calculate unconditional perplexity
inputs = self.tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device) # Use add_special_tokens=True
with torch.no_grad():
outputs = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) # Pass attention_mask
# Logits for all tokens except last, labels are all tokens except first
logits = outputs.logits[0, :-1]
labels = inputs.input_ids[0, 1:]
if logits.size(0) == 0 or labels.size(0) == 0 or logits.size(0) != labels.size(0):
logger.warning(f"Logits/Labels shape mismatch or empty in _calculate_perplexity_for_contrastive (uncond). Logits: {logits.shape}, Labels: {labels.shape}. Returning inf.")
return float('inf') # Return inf if shapes mismatch or empty
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
mean_loss = loss.mean().item()
perplexity = math.exp(mean_loss) if not math.isnan(mean_loss) and not math.isinf(mean_loss) else float('inf')
return perplexity
def _calculate_contrastive_perplexity(self, code_lines: List[str], question: str):
"""
Calculate contrastive perplexity-based importance for each line of code.
s_i = perplexity(x_i | x_{<i}) - perplexity(x_i | x^{que}, x_{<i})
Higher score means the question helps predict the line more.
Args:
code_lines: List of code lines to analyze
question: The query/question text
Returns:
Tuple of (line_scores, scored_indices)
"""
logger.debug("Calculating contrastive perplexity-based line importance...")
line_scores = []
scored_indices = []
with torch.no_grad():
# Use tqdm.auto for better compatibility
pbar = tqdm(enumerate(code_lines), total=len(code_lines), desc="Contrastive PPL", leave=False)
for i, line in pbar:
                if not line.strip():
                    # Nothing to score for empty or whitespace-only lines
                    continue
# 1. PPL(L_i | L_<i)
prev_context = "\n".join(code_lines[:i])
# Add newline only if previous context exists
regular_ppl_condition = prev_context + "\n" if prev_context else None
regular_ppl = self._calculate_perplexity_for_contrastive(line, condition_text=regular_ppl_condition)
# 2. PPL(L_i | Q, L_<i)
# Combine question and previous context carefully
question_context_parts = [question]
if prev_context:
question_context_parts.append(prev_context)
# Join with double newline between Q and prev_context if both exist
question_context = "\n\n".join(filter(None, question_context_parts))
# Add trailing newline before the target line
cond_ppl_condition = question_context + "\n"
cond_ppl = self._calculate_perplexity_for_contrastive(line, condition_text=cond_ppl_condition)
# 3. Importance = PPL(L|prev) - PPL(L|Q,prev)
if math.isinf(regular_ppl) or math.isinf(cond_ppl):
# If either is infinite, the difference isn't well-defined for ranking.
# Assign a very low score, potentially based on which one is inf.
# If regular_ppl is inf, question might still help (cond_ppl could be finite).
# If cond_ppl is inf, question made it worse or impossible to predict.
# Let's assign -inf for simplicity, meaning "least important".
importance = -float('inf')
logger.debug(f"Line {i}: Inf PPL detected. Regular: {regular_ppl}, Conditional: {cond_ppl}. Importance set to -inf")
else:
importance = regular_ppl - cond_ppl
logger.debug(f"Line {i}: PPL(L|prev)={regular_ppl:.4f}, PPL(L|Q,prev)={cond_ppl:.4f}, Importance={importance:.4f}")
line_scores.append(importance)
scored_indices.append(i)
# Update tqdm description if needed, e.g., with last score
# pbar.set_description(f"Contrastive PPL (L{i}: {importance:.2f})")
logger.debug(f"Finished calculating contrastive PPL for {len(line_scores)} lines.")
return line_scores, scored_indices
def _knapsack_block_selection(
self,
blocks: List[str],
block_importances: List[float],
target_tokens: int,
preserved_block_indices: set = None,
language: str = "python"
) -> Tuple[set, Dict]:
"""
Use knapsack algorithm to select blocks that maximize total importance within token budget.
Args:
blocks: List of code blocks (Entropy chunks)
block_importances: Importance scores for each block
target_tokens: Target number of tokens to keep
preserved_block_indices: Set of block indices that must be preserved
language: Programming language for omission markers
Returns:
Tuple of (selected_block_indices, selection_info)
"""
logger.debug(f"Running knapsack block selection with target_tokens={target_tokens}")
if not blocks:
return set(), {}
# Calculate token weights for each block
block_weights = [self.get_token_length(block) for block in blocks]
# Handle preserved blocks
if preserved_block_indices is None:
preserved_block_indices = set()
# Calculate tokens already used by preserved blocks
preserved_tokens = sum(block_weights[i] for i in preserved_block_indices)
remaining_budget = max(0, target_tokens - preserved_tokens)
logger.debug(f"Preserved blocks: {len(preserved_block_indices)}, tokens: {preserved_tokens}")
logger.debug(f"Remaining budget for knapsack: {remaining_budget}")
# If no remaining budget, just return preserved blocks
if remaining_budget <= 0:
return preserved_block_indices, {
"method": "knapsack",
"preserved_only": True,
"total_value": sum(block_importances[i] for i in preserved_block_indices),
"total_weight": preserved_tokens
}
# Prepare items for knapsack (excluding preserved blocks)
knapsack_items = []
for i, (weight, value) in enumerate(zip(block_weights, block_importances)):
if i not in preserved_block_indices:
# Handle invalid importance scores
if math.isnan(value) or math.isinf(value):
value = 0.0
knapsack_items.append((i, weight, value))
        # Sort by value-to-weight ratio; the DP is order-independent, but the greedy fallback below relies on this ordering
knapsack_items.sort(key=lambda x: x[2] / max(x[1], 1), reverse=True)
# Use dynamic programming for exact knapsack solution
# For efficiency, limit to reasonable problem size
if len(knapsack_items) <= 100 and remaining_budget <= 2000:
selected_indices = self._solve_knapsack_dp(knapsack_items, remaining_budget)
else:
# Use greedy approximation for large problems
logger.debug("Using greedy approximation for large knapsack problem")
selected_indices = self._solve_knapsack_greedy(knapsack_items, remaining_budget)
# Combine with preserved blocks
final_selection = preserved_block_indices.union(selected_indices)
# Calculate selection statistics
total_value = sum(block_importances[i] for i in final_selection)
total_weight = sum(block_weights[i] for i in final_selection)
selection_info = {
"method": "knapsack",
"preserved_blocks": len(preserved_block_indices),
"selected_blocks": len(selected_indices),
"total_blocks": len(final_selection),
"total_value": total_value,
"total_weight": total_weight,
"target_weight": target_tokens,
"efficiency": total_value / max(total_weight, 1)
}
logger.debug(f"Knapsack selection: {len(final_selection)}/{len(blocks)} blocks, "
f"value={total_value:.2f}, weight={total_weight}/{target_tokens}")
return final_selection, selection_info
def _solve_knapsack_dp(self, items: List[Tuple[int, int, float]], capacity: int) -> set:
"""
Solve knapsack problem using dynamic programming.
Args:
items: List of (index, weight, value) tuples
capacity: Maximum weight capacity
Returns:
Set of selected item indices
"""
n = len(items)
if n == 0 or capacity <= 0:
return set()
# DP table: dp[i][w] = maximum value using first i items with weight limit w
dp = [[0.0 for _ in range(capacity + 1)] for _ in range(n + 1)]
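        # O(n * capacity) time and space; _knapsack_block_selection only routes problems
        # with <= 100 items and <= 2000 capacity here, so the table stays small.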
# Fill DP table
for i in range(1, n + 1):
idx, weight, value = items[i - 1]
for w in range(capacity + 1):
# Don't take item i
dp[i][w] = dp[i - 1][w]
# Take item i if it fits
if weight <= w:
dp[i][w] = max(dp[i][w], dp[i - 1][w - weight] + value)
# Backtrack to find selected items
selected = set()
w = capacity
for i in range(n, 0, -1):
if dp[i][w] != dp[i - 1][w]:
idx, weight, value = items[i - 1]
selected.add(idx)
w -= weight
return selected
def _solve_knapsack_greedy(self, items: List[Tuple[int, int, float]], capacity: int) -> set:
"""
Solve knapsack problem using greedy approximation (by value/weight ratio).
Args:
items: List of (index, weight, value) tuples (should be pre-sorted by ratio)
capacity: Maximum weight capacity
Returns:
Set of selected item indices
"""
selected = set()
current_weight = 0
for idx, weight, value in items:
if current_weight + weight <= capacity:
selected.add(idx)
current_weight += weight
return selected
if __name__ == "__main__":
context = """
def add(a, b):
return a + b
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
return quick_sort(left) + middle + quick_sort(right)
def search_with_binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
"""
question = "How to write a quick sort algorithm?"
# Initialize compressor
logger.info("Initializing compressor...")
model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
compressor = LongCodeZip(model_name=model_name)
# Test function-based code file compression with query
logger.info("\nTesting function-based code file compression with query...")
original_tokens = len(compressor.tokenizer.encode(context))
target_token = 64
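    # Compression rate = token budget / original token count, clamped to [0, 1]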
target_ratio = min(1.0, max(0.0, target_token / original_tokens))
logger.info(f"LongCodeZip: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
result = compressor.compress_code_file(
code=context,
        query=question,  # Use the natural-language question as the compression query
instruction="Complete the following code function given the context.",
rate=target_ratio,
rank_only=True, # Only use coarse-grained compression
        fine_grained_importance_method="conditional_ppl",  # Default fine-grained method (only takes effect when rank_only=False)
min_lines_for_fine_grained=5, # Min number of lines for fine-grained compression
importance_beta=0.5, # Sensitivity to importance score
use_knapsack=True,
)
# show the compressed code
logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
    logger.info(f"Query: \n{question}")
# final prompt
final_prompt = result['compressed_prompt']
# get the completion
tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
# Increase max_new_tokens for potentially longer completions
completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
# Decode only the generated part, skipping special tokens
completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
# Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
completion = completion.strip()
# More robust cleanup: Find the first meaningful line if generation includes noise
completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal
cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found
logger.info(f"Cleaned Completion: {cleaned_completion}")
    # Second run: enable fine-grained (line-level) compression with conditional_ppl
logger.info("\nTesting fine-grained compression with conditional_ppl...")
result_cond = compressor.compress_code_file(
code=context,
query=question,
instruction="Complete the following code function given the context.",
rate=target_ratio,
rank_only=False,
fine_grained_importance_method="conditional_ppl",
min_lines_for_fine_grained=5,
importance_beta=0.5
)
logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")