LongCodeZip/experiments/repo-qa/code_compressor.py
import torch
import numpy as np
from typing import List, Union, Tuple, Dict, Optional
import re
import math
import zlib
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from tqdm import tqdm
import logging
import copy
import bisect
import json
# set up the logger
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("CodeCompressor")
class CodeCompressor:
def __init__(
self,
model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4",
device_map: str = "cuda",
model_config: dict = {},
):
"""
Initialize the CodeCompressor with a language model for compression.
Args:
model_name: The name of the model to load from HuggingFace
device_map: Device to load the model on
model_config: Additional configuration for the model
"""
self.model_name = model_name
self.device = device_map
self.model_config = model_config
self.load_model(model_name, device_map, model_config)
# Add caching system for model outputs and token information
self.cache = {
"token_length": {}, # Cache for token length by text
"encodings": {}, # Cache for tokenizer encodings
"perplexity": {}, # Cache for perplexity calculations
"conditional_ppl": {}, # Cache for conditional perplexity
"context_rankings": {}, # Cache for context rankings
}
self.max_cache_size = 1000 # Limit cache size to prevent memory issues
# set up the max position embeddings and cache bos num
self.max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 4096)
self.cache_bos_num = 10
self.prefix_bos_num = 100
self.context_idxs = []
def load_model(
self, model_name: str, device_map: str = "cuda", model_config: dict = {}
):
"""
Load the language model and tokenizer.
Args:
model_name: The name of the model to load
device_map: Device to load the model on
model_config: Additional configuration for the model
"""
logger.debug(f"Loading model {model_name} on {device_map}")
torch_dtype = torch.float16 if "torch_dtype" not in model_config else model_config["torch_dtype"]
model_kwargs = {"device_map": device_map, "torch_dtype": torch_dtype}
for k, v in model_config.items():
if k != "torch_dtype":
model_kwargs[k] = v
self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"
self.tokenizer_is_gpt = "gpt" in model_name.lower()
logger.debug("Model and tokenizer loaded successfully")
def _manage_cache_size(self, cache_type):
"""
Manage cache size by removing oldest entries when cache exceeds max size.
Args:
cache_type: The type of cache to manage
"""
if len(self.cache[cache_type]) > self.max_cache_size:
# Remove 20% of the oldest entries
remove_count = int(self.max_cache_size * 0.2)
keys_to_remove = list(self.cache[cache_type].keys())[:remove_count]
for key in keys_to_remove:
del self.cache[cache_type][key]
def get_token_length(
self,
text: str,
add_special_tokens: bool = True,
):
"""
Get the number of tokens in the given text.
Args:
text: The text to tokenize
add_special_tokens: Whether to count special tokens
Returns:
The number of tokens
"""
# Create a cache key based on text and parameters
cache_key = f"{text}_{add_special_tokens}"
# Check if result is in cache
if cache_key in self.cache["token_length"]:
return self.cache["token_length"][cache_key]
# Calculate token length if not in cache
token_length = len(self.tokenizer.encode(text, add_special_tokens=add_special_tokens))
# Store in cache
self.cache["token_length"][cache_key] = token_length
self._manage_cache_size("token_length")
return token_length
def get_ppl(
self,
text: str,
granularity: str = "line",
input_ids=None,
attention_mask=None,
past_key_values=None,
return_kv=False,
end=None,
condition_mode: str = "none",
condition_pos_id: int = 0,
):
"""
        Calculate perplexity for the given text at the requested granularity.
Args:
text: The text to calculate perplexity for
granularity: The granularity of perplexity calculation (line, token, chunk)
input_ids, attention_mask, past_key_values: Optional pre-processed inputs
return_kv: Whether to return key-values
end: End position for calculation
condition_mode: Mode for conditional perplexity (none, prefix)
condition_pos_id: Position ID for condition
Returns:
A dictionary with perplexity scores and processing information
"""
# Create a cache key for this specific perplexity calculation
cache_key = f"{text}_{granularity}_{condition_mode}_{condition_pos_id}"
if past_key_values is None and not return_kv and cache_key in self.cache["perplexity"]:
return self.cache["perplexity"][cache_key]
# Initialize input processing
if input_ids is None:
encoding_key = text
if encoding_key in self.cache["encodings"]:
cached_encoding = self.cache["encodings"][encoding_key]
input_ids = cached_encoding["input_ids"]
attention_mask = cached_encoding["attention_mask"]
else:
encoding = self.tokenizer(
text,
return_tensors="pt",
padding=True
)
input_ids = encoding["input_ids"].to(self.model.device)
attention_mask = encoding["attention_mask"].to(self.model.device)
# Cache the encoding
self.cache["encodings"][encoding_key] = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
self._manage_cache_size("encodings")
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
else:
past_length = 0
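        # Clamp the window so that at most max_position_embeddings new tokens
        # are processed beyond the cached prefix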
if end is None:
end = input_ids.shape[1]
end = min(end, past_length + self.max_position_embeddings)
with torch.no_grad():
outputs = self.model(
input_ids=input_ids[:, past_length:end],
attention_mask=attention_mask[:, :end],
past_key_values=past_key_values,
return_dict=True,
output_hidden_states=True,
use_cache=True,
)
# Get logits and shift
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., past_length+1:end].contiguous()
# Flatten tokens for loss calculation
active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
active_labels = shift_labels.view(-1)[active]
# Calculate loss
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(active_logits, active_labels)
# Apply condition filtering if required
if condition_mode == "prefix":
loss = loss[condition_pos_id:]
# Process based on granularity
if granularity == "token":
result_loss = loss
else:
result_loss = loss.mean()
# Split text into lines for line-level granularity
if granularity == "line" and text:
segments = text.split("\n")
segments = [seg for seg in segments if seg.strip()]
lines_info = self.__get_lines_info(segments, input_ids[0], loss)
else:
segments = [text] if text else []
lines_info = []
# Calculate mean perplexity
mean_loss = loss.mean() if len(loss) > 0 else torch.tensor(0.0)
ppl = torch.exp(mean_loss).item() if mean_loss.item() != float('inf') else float('inf')
result = {
"loss": loss,
"input_ids": input_ids,
"attention_mask": attention_mask,
"lines_info": lines_info,
"segments": segments,
"ppl": ppl,
}
if return_kv:
result["past_key_values"] = outputs.past_key_values
else:
# Cache the result if we're not returning KV cache
self.cache["perplexity"][cache_key] = result
self._manage_cache_size("perplexity")
return result
def __get_lines_info(self, lines, input_ids, loss):
"""
Get information about each line including start/end positions and importance.
Args:
lines: List of lines in the text
input_ids: Token IDs for the entire text
loss: Per-token loss values
Returns:
List of dictionaries with line information
"""
line_info = []
cumulative_tokens = 0
input_ids_list = input_ids.cpu().tolist()
for i, line in enumerate(lines):
if not line.strip():
continue
# Encode each line to find its token length
line_tokens = self.tokenizer.encode(line, add_special_tokens=False)
line_length = len(line_tokens)
# Find position in the tokenized text
start_pos = cumulative_tokens
end_pos = start_pos + line_length
# Calculate mean loss (importance) for this line
# Loss might be shorter than the token IDs due to shifting
if isinstance(loss, torch.Tensor) and start_pos < len(loss) and end_pos <= len(loss):
line_loss = loss[start_pos:end_pos].mean().item()
else:
# Handle edge cases
line_loss = float("inf")
line_info.append({
"line": line,
"start": start_pos,
"end": end_pos,
"importance": line_loss,
"tokens": line_length
})
cumulative_tokens += line_length
return line_info
def get_prefix_length(self, prefix: str, text: str):
"""
Calculate the length of a prefix in tokens when concatenated with a text.
Args:
prefix: The prefix text
text: The main text
Returns:
Length of the prefix in tokens
"""
possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
full_input_ids = self.tokenizer(prefix + text[:100], add_special_tokens=False).input_ids
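        # Decode progressively longer token prefixes until the decoded string
        # matches `prefix` exactly; this handles tokens that merge across the
        # prefix/text boundary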
for i in range(possible_prefix_token, len(full_input_ids)):
cur_prefix = self.tokenizer.decode(full_input_ids[:i])
if cur_prefix == prefix:
break
return i
def get_condition_ppl(
self,
text: str,
question: str,
condition_in_question: str = "none",
granularity: str = "line",
):
"""
Calculate perplexity change of a question when given context text.
A positive change means the context helps reduce question perplexity.
Args:
text: The context text
question: The question to evaluate
condition_in_question: Conditioning mode (none, prefix)
granularity: Granularity for perplexity calculation
        Returns:
            Perplexity of the text when condition_in_question is "none"; otherwise the question's perplexity reduction when conditioned on the text
"""
# Create a cache key for this conditional perplexity calculation
cache_key = f"{text}_{question}_{condition_in_question}_{granularity}"
if cache_key in self.cache["conditional_ppl"]:
return self.cache["conditional_ppl"][cache_key]
if condition_in_question == "none":
# Just return the perplexity of the text
result = self.get_ppl(
text=text, granularity=granularity, condition_mode="none"
)
ppl_value = result["ppl"]
else:
# First calculate question perplexity without context
question_ppl_without_context = self.get_ppl(
text=question,
granularity=granularity
)["ppl"]
# Then calculate question perplexity with context
question_ppl_with_context = self.get_ppl(
text=text + "\n\n" + question,
granularity=granularity,
condition_mode="prefix",
condition_pos_id=self.get_token_length(text + "\n\n", add_special_tokens=True)
)["ppl"]
# Calculate the change (positive means context helps)
ppl_value = question_ppl_without_context - question_ppl_with_context
# Cache the result
self.cache["conditional_ppl"][cache_key] = ppl_value
self._manage_cache_size("conditional_ppl")
return ppl_value
def get_estimate_threshold_base_distribution(
self, ppl_values, ratio: float, condition_flag: bool = False
):
"""
Estimate threshold value for compression based on perplexity distribution.
Args:
ppl_values: Perplexity values for tokens or lines
ratio: Compression ratio (0.0-1.0)
condition_flag: Whether values are conditional (affecting sorting direction)
Returns:
Threshold value for filtering
"""
if ratio >= 1.0:
return float("-inf")
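        # The threshold sits at the `ratio` quantile of the finite values: with
        # condition_flag False the values are sorted in descending order, so
        # roughly a `ratio` fraction of them lies at or above the returned value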
if isinstance(ppl_values, torch.Tensor):
# Filter out extreme values that might skew the threshold
valid_values = ppl_values[ppl_values != float('inf')]
valid_values = valid_values[valid_values != -float('inf')]
valid_values = valid_values[~torch.isnan(valid_values)]
if len(valid_values) == 0:
return 0.0
            # Calculate the target position for the percentile
            target_idx = max(0, min(len(valid_values) - 1, int(len(valid_values) * ratio) - 1))
            # Sort values based on condition_flag and get threshold
            sort_values = valid_values.sort(descending=not condition_flag).values
            if target_idx < len(sort_values):
                return sort_values[target_idx].item()
return 0.0
else:
# Handle non-tensor inputs (lists, numpy arrays)
valid_values = [v for v in ppl_values if v != float('inf') and v != -float('inf') and not math.isnan(v)]
if not valid_values:
return 0.0
# Calculate the target position for the percentile
target_idx = max(0, min(len(valid_values) - 1, int(len(valid_values) * ratio) - 1))
# Sort values and get threshold
sorted_values = sorted(valid_values, reverse=not condition_flag)
if target_idx < len(sorted_values):
return sorted_values[target_idx]
return 0.0
def get_dynamic_compression_ratio(
self,
context: list,
target_token: float,
iterative_size: int,
dynamic_ratio: list,
start: int,
):
"""
Calculate dynamic compression ratios for iterative compression.
Args:
context: List of context strings
target_token: Target number of tokens
iterative_size: Size of each iteration
dynamic_ratio: List of dynamic ratio adjustments
start: Start position for processing
Returns:
List of ratios for each iteration chunk
"""
def get_ratio(base: float, delta: float):
return max(min(1, base + delta), 0)
context_length = [self.get_token_length(ii, False) + 2 for ii in context]
if start:
context_length = context_length[1:]
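        # Base keep ratio: the fraction of the total context tokens that fits
        # in the target budget (+1 guards against division by zero)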
tau = target_token / (sum(context_length) + 1)
res, idx, last, last_target = [], 0, 1, []
while idx < len(context_length):
if last + context_length[idx] >= iterative_size:
last_target.append(
(iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
)
res.append(last_target)
last = last + context_length[idx] - iterative_size
if last > iterative_size:
k = last // iterative_size
res.extend(
[[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
)
last -= k * iterative_size
last_target = (
[(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
)
else:
last += context_length[idx]
last_target.append(
(context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
)
idx += 1
if last_target:
res.append(last_target)
return res
def iterative_compress_prompt(
self,
context: List[str],
target_token: float,
iterative_size: int = 200,
keep_lines: bool = True,
start: int = 0,
dynamic_ratio: list = None,
condition_compare: bool = False,
):
"""
Iteratively compress text using a sliding window approach with KV caching.
Args:
context: List of text contexts to compress
target_token: Target number of tokens after compression
iterative_size: Size of each iteration window
keep_lines: Whether to keep line structure
start: Start position for processing
dynamic_ratio: List of dynamic compression ratios
condition_compare: Whether to use conditional comparison
Returns:
Compressed input IDs and attention mask
"""
# Calculate dynamic compression ratios for each iteration
iterative_ratios = self.get_dynamic_compression_ratio(
context, target_token, iterative_size, dynamic_ratio, start
)
# Join contexts and tokenize
context_joined = "\n\n".join(context)
tokenized_text = self.tokenizer(
context_joined, return_tensors="pt", add_special_tokens=False
)
input_ids = tokenized_text["input_ids"].to(self.model.device)
attention_mask = tokenized_text["attention_mask"].to(self.model.device)
# Initialize working variables
compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
end = min(iterative_size + start, compressed_input_ids.shape[1])
threshold, keep_flag = None, None
if keep_lines:
# Build a keep flag for important line tokens (e.g., indentation patterns)
input_ids_numpy = input_ids.cpu().detach().numpy()[0]
N = len(input_ids_numpy)
# Identify line break patterns to preserve
newline_ids = set(self.tokenizer.encode("\n", add_special_tokens=False))
keep_flag = torch.zeros(N, dtype=torch.bool).to(self.model.device)
# Mark tokens that represent indentation to be preserved
for i in range(1, N):
if input_ids_numpy[i-1] in newline_ids:
# Check if this token is whitespace (indentation)
token = self.tokenizer.decode([input_ids_numpy[i]])
if token.isspace():
keep_flag[i] = True
# Initialize processing state
past_key_values, past_loss, ready_end = None, None, 0
pop_compressed_input_ids = None
idx = 0
# Process text in chunks
while end <= compressed_input_ids.shape[1]:
# Handle KV-cache window sliding for long texts
if end > self.max_position_embeddings and past_key_values is not None:
# KV-Cache Compression
e, s = end - self.max_position_embeddings, min(
self.cache_bos_num + start, self.max_position_embeddings
)
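                # e: number of tokens that overflow the context window and are
                # moved out of the active window; s: number of leading (BOS-area)
                # positions preserved in the KV cache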
if pop_compressed_input_ids is None:
pop_compressed_input_ids = compressed_input_ids[:, :e]
else:
pop_compressed_input_ids = torch.cat(
[pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
)
compressed_input_ids = compressed_input_ids[:, e:]
compressed_attention_mask = compressed_attention_mask[:, e:]
# Update KV cache - keep beginning tokens and skip processed tokens
past_key_values = [
[
torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
]
for k, v in past_key_values
]
if keep_flag is not None:
keep_flag = keep_flag[e:]
end, ready_end = end - e, ready_end - e
# Calculate perplexity for current window
result = self.get_ppl(
"",
"token",
compressed_input_ids,
compressed_attention_mask,
past_key_values=past_key_values,
return_kv=True,
end=end if idx else None,
)
loss, past_key_values = result["loss"], result["past_key_values"]
if loss.shape[0] == 0:
break
# Merge with previous loss calculations
if past_loss is not None:
if end - 1 > len(past_loss):
past_loss = torch.cat(
[past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
)
past_loss[ready_end : end - 1] = loss
loss = past_loss
else:
past_loss = loss
# Slide the KV cache window
if idx:
past_key_values = [
[k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
for k, v in past_key_values
]
else:
past_key_values = None
# Apply compression for each chunk in the current window
for delta_end, ratio in iterative_ratios[idx]:
loss = past_loss
# Calculate threshold for token filtering
threshold = self.get_estimate_threshold_base_distribution(
loss, ratio, False
)
# Filter tokens using the calculated threshold
compressed_input_ids, compressed_attention_mask, keep_flag, end, past_loss = self.get_compressed_input(
loss,
compressed_input_ids,
compressed_attention_mask,
end - iterative_size + delta_end,
iterative_size=delta_end,
threshold=threshold,
keep_flag=keep_flag,
start=start,
)
end += iterative_size
ready_end = end - iterative_size if not (start and idx == 0) else 0
idx += 1
# Concatenate saved tokens with final compressed tokens
if pop_compressed_input_ids is not None:
compressed_input_ids = torch.cat(
[pop_compressed_input_ids, compressed_input_ids], dim=-1
)
return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
def iterative_compress_prompt_line(
self,
context: List[str],
target_token: float,
dynamic_ratio: list = None,
):
"""
Compress text by evaluating and filtering entire lines based on importance.
This is a line-level alternative to the token-level iterative_compress_prompt.
Args:
context: List of text contexts to compress
target_token: Target number of tokens after compression
dynamic_ratio: List of dynamic compression ratios for each context
Returns:
Compressed input IDs and attention mask
"""
# Join contexts
context_joined = "\n\n".join(context)
# Split text into lines
lines = context_joined.split("\n")
# Get perplexity for the entire text at line level
ppl_result = self.get_ppl(context_joined, granularity="line")
lines_info = ppl_result["lines_info"]
# Calculate token count for each line
line_tokens = [(i, info["tokens"], info["importance"])
for i, info in enumerate(lines_info)]
# Apply dynamic ratio adjustments if provided
if dynamic_ratio and len(dynamic_ratio) > 0:
# Create dynamic ratios for each line based on context dynamic ratios
# We'll infer which context each line belongs to
line_contexts = []
context_idx = 0
line_count = 0
            # Map each line to its corresponding context by counting how many
            # non-empty lines each context contributes (lines_info only
            # contains non-empty lines)
            context_line_counts = [
                len([ln for ln in ctx.split("\n") if ln.strip()]) for ctx in context
            ]
            for i, info in enumerate(lines_info):
                line_contexts.append(min(context_idx, len(dynamic_ratio) - 1))
                line_count += 1
                # Advance to the next context once its lines are exhausted
                if context_idx < len(context) - 1 and line_count >= context_line_counts[context_idx]:
                    context_idx += 1
                    line_count = 0
# Apply dynamic ratio adjustments to line importance scores
for i in range(len(line_tokens)):
if i < len(line_contexts):
context_idx = line_contexts[i]
if context_idx < len(dynamic_ratio):
# Adjust importance using dynamic ratio
# Lower importance score means higher priority (will be kept)
adjustment = dynamic_ratio[context_idx]
line_tokens[i] = (
line_tokens[i][0],
line_tokens[i][1],
line_tokens[i][2] - adjustment # Lower importance means keep
)
# Sort lines by importance (lower score is more important)
sorted_lines = sorted(line_tokens, key=lambda x: x[2])
# Select lines to keep within token budget
tokens_so_far = 0
lines_to_keep = set()
        for line_idx, line_token_count, _ in sorted_lines:
            if tokens_so_far + line_token_count <= target_token:
                lines_to_keep.add(line_idx)
                tokens_so_far += line_token_count
else:
# Stop if we've reached our target
break
# Create compressed text with only the selected lines
compressed_lines = [lines_info[i]["line"] for i in sorted(lines_to_keep)]
compressed_text = "\n".join(compressed_lines)
# Tokenize the compressed text
tokenized_text = self.tokenizer(
compressed_text, return_tensors="pt", add_special_tokens=False
)
compressed_input_ids = tokenized_text["input_ids"].to(self.model.device)
compressed_attention_mask = tokenized_text["attention_mask"].to(self.model.device)
return compressed_input_ids, compressed_attention_mask
def get_compressed_input(
self,
loss,
input_ids,
attention_mask,
end=200,
iterative_size=200,
threshold=0.5,
keep_flag=None,
start: int = 0,
):
"""
Filter input tokens based on loss values and thresholds.
Args:
loss: Loss values for each token
input_ids: Input token IDs
attention_mask: Attention mask
end: End position for processing
iterative_size: Size of each iteration
threshold: Threshold value for filtering
keep_flag: Flags for tokens to always keep
start: Start position for processing
Returns:
Compressed inputs and updated state
"""
# Determine which tokens to keep based on loss values and threshold
need_idx = torch.concat([loss > threshold, loss[:1] > 0])
# Ensure we keep tokens at positions outside our current window
need_idx[end:] = 1
need_idx[: end - iterative_size] = 1
# Get filtered loss
loss = loss[need_idx[:-1]]
# Ensure need_idx matches input_ids length
if need_idx.shape[0] < input_ids.shape[1]:
need_idx = torch.cat(
[
need_idx,
torch.ones(
input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
).to(need_idx.device),
]
)
elif need_idx.shape[0] > input_ids.shape[1]:
need_idx = need_idx[: input_ids.shape[1]]
# Enforce keeping tokens marked in keep_flag
if keep_flag is not None:
need_idx[keep_flag] = 1
# Optionally apply line break preservation logic
# Find tokens representing newlines and always keep one of consecutive newlines
tokens = input_ids[0]
newline_ids = set(self.tokenizer.encode("\n", add_special_tokens=False))
last_kept_newline = False
for ii in range(max(0, end - iterative_size), end):
if need_idx[ii] == 0:
continue
token_id = tokens[ii].item()
# Handle newline logic - avoid consecutive newlines unless marked important
if token_id in newline_ids:
                if last_kept_newline and (keep_flag is None or keep_flag[ii].item() == 0):
need_idx[ii] = 0
else:
last_kept_newline = True
else:
last_kept_newline = False
# Apply the filtering to get compressed tokens
compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
compressed_attention_mask = attention_mask[attention_mask == 1][need_idx].unsqueeze(0)
# Update the end position based on how many tokens we removed
end -= (need_idx[:end] == 0).sum()
return compressed_input_ids, compressed_attention_mask, keep_flag, end, loss
def compress_code(
self,
code: str,
query: str = "",
instruction: str = "",
rate: float = 0.5,
target_token: int = -1,
use_line_level_filter: bool = True,
use_iterative_compression: bool = True,
iterative_size: int = 200,
dynamic_compression_ratio: float = 0.2,
):
"""
Compress code by removing less important lines based on query relevance.
Args:
code: The code to compress
query: Query to prioritize relevant lines
instruction: Additional instruction to guide compression
rate: Compression rate (0.0-1.0), where 1.0 means no compression
target_token: Target number of tokens (alternative to rate)
use_line_level_filter: Whether to use line-level filtering
use_iterative_compression: Whether to use token-level iterative compression
iterative_size: Size of each iteration for token-level compression
dynamic_compression_ratio: Ratio for dynamic compression (0.0-1.0)
Returns:
Compressed code and statistics
"""
logger.debug(f"Starting code compression with rate={rate}, target_token={target_token}")
start_time = time.time()
# Calculate total tokens in the code
total_tokens = self.get_token_length(code)
logger.debug(f"Total tokens in code: {total_tokens}")
# Determine target tokens
if target_token <= 0:
target_token = int(total_tokens * rate)
logger.debug(f"Target tokens: {target_token}")
if rate >= 1.0 or target_token >= total_tokens:
# No compression needed
return {
"original_code": code,
"compressed_code": code,
"output": code,
"original_tokens": total_tokens,
"compressed_tokens": total_tokens,
"final_compressed_tokens": total_tokens,
"compression_ratio": 1.0,
"kept_lines": list(range(len(code.split("\n")))),
}
# For very small code snippets, skip iterative compression
if total_tokens < 100:
use_iterative_compression = False
if use_line_level_filter:
# Split code into lines for line-level filtering
lines = code.split("\n")
non_empty_lines = [line for line in lines if line.strip()]
logger.debug(f"Split code into {len(non_empty_lines)} non-empty lines")
# Get perplexity for entire code
ppl_result = self.get_ppl(code, granularity="line")
lines_info = ppl_result["lines_info"]
            # If a query is provided, rank lines by relevance to it
if query:
logger.debug("Ranking lines by relevance to query")
# Get conditional perplexity for each line
line_importances = []
for i, line_info in tqdm(enumerate(lines_info), total=len(lines_info), desc="Calculating line importance"):
# First calculate the perplexity of the query without the line
query_ppl_without_context = self.get_ppl(query, granularity="line")["ppl"]
# Then calculate the perplexity of the query with the line as context
query_ppl_with_context = self.get_ppl(
line_info["line"] + "\n\n" + query,
granularity="line",
condition_mode="prefix",
condition_pos_id=self.get_token_length(line_info["line"] + "\n\n", add_special_tokens=True)
)["ppl"]
# Calculate the perplexity change (lower value means context is more helpful)
ppl_change = query_ppl_without_context - query_ppl_with_context
                    # Length adjustment term (currently a no-op because of the trailing * 0 factor)
                    line_importances.append((i, -ppl_change - line_info["tokens"] * 2 / 250 * 0))
# Sort by importance (higher perplexity reduction = more relevant to query)
sorted_lines = sorted(line_importances, key=lambda x: x[1])
else:
# Sort lines by importance (lower loss = more important)
line_importances = [(i, info["importance"]) for i, info in enumerate(lines_info)]
sorted_lines = sorted(line_importances, key=lambda x: x[1])
# Apply dynamic compression ratio if specified
if dynamic_compression_ratio > 0:
N = len(sorted_lines)
# This creates a gradient of compression rates from higher to lower importance
dynamic_ratios = [
i * (dynamic_compression_ratio / (N - 1)) if N > 1 else 0
for i in range(-(N - 1), N, 2)
]
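                # The resulting ratios step evenly from -dynamic_compression_ratio
                # to +dynamic_compression_ratio across the N importance-ranked lines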
# Assign dynamic ratios to lines based on their importance rank
sorted_indices = [idx for idx, _ in sorted_lines]
dynamic_ratio_map = {idx: ratio for idx, ratio in zip(sorted_indices, dynamic_ratios)}
else:
dynamic_ratio_map = {i: 0 for i in range(len(lines_info))}
# Determine which lines to keep based on token budget
tokens_so_far = 0
lines_to_keep = set()
# First pass - keep most important lines within budget
for line_idx, _ in sorted_lines:
if line_idx >= len(lines_info):
continue
line_info = lines_info[line_idx]
line_tokens = line_info["tokens"]
if tokens_so_far + line_tokens <= target_token:
lines_to_keep.add(line_idx)
tokens_so_far += line_tokens
else:
# Stop if we've reached our target
break
logger.debug(f"Selected {len(lines_to_keep)} lines to keep out of {len(lines_info)}")
# Construct code with only the selected lines
preserved_code = "\n".join([lines_info[i]["line"] for i in sorted(lines_to_keep)])
# If we need iterative token-level compression
if use_iterative_compression:
logger.debug("Applying iterative line-level compression")
# Create dynamic ratios for iterative compression
dynamic_ratios = [dynamic_ratio_map.get(i, 0.0) for i in sorted(lines_to_keep)]
# Convert to list for iterative compression
context = [preserved_code]
# Apply line-level compression instead of token-level compression
compressed_ids, compressed_mask = self.iterative_compress_prompt_line(
context,
target_token=target_token,
dynamic_ratio=dynamic_ratios,
)
# Convert back to text
compressed_code = self.tokenizer.decode(compressed_ids[0])
else:
compressed_code = preserved_code
else:
# Without line-level filter, apply iterative compression directly
if use_iterative_compression:
logger.debug("Applying iterative line-level compression without line filtering")
# Apply line-level compression to the entire code
compressed_ids, _ = self.iterative_compress_prompt_line(
[code],
target_token=target_token,
dynamic_ratio=[0.0], # No dynamic ratio adjustment for single context
)
# Convert back to text
compressed_code = self.tokenizer.decode(compressed_ids[0])
else:
# Simple truncation
logger.debug("No compression methods selected, using simple truncation")
encoded = self.tokenizer.encode(code, add_special_tokens=False)
truncated = encoded[:target_token]
compressed_code = self.tokenizer.decode(truncated)
# Construct final output with instruction and query
output = ""
if instruction:
output += f"{instruction}\n\n"
output += compressed_code
if query:
output += f"\n\n{query}"
# Calculate compression statistics
compressed_tokens = self.get_token_length(compressed_code)
final_compressed_tokens = self.get_token_length(output)
compression_ratio = compressed_tokens / total_tokens if total_tokens > 0 else 1.0
end_time = time.time()
logger.debug(f"Code compression completed in {end_time - start_time:.2f} seconds")
logger.debug(f"Compression ratio: {compression_ratio:.2f}")
# For line-level filtering, include which lines were kept
if use_line_level_filter:
kept_lines = sorted(lines_to_keep)
else:
# Approximate which lines were kept based on content
original_lines = code.split("\n")
compressed_lines = compressed_code.split("\n")
kept_lines = []
for i, line in enumerate(original_lines):
if line in compressed_lines:
kept_lines.append(i)
return {
"original_code": code,
"compressed_code": compressed_code,
"output": output,
"original_tokens": total_tokens,
"compressed_tokens": compressed_tokens,
"final_compressed_tokens": final_compressed_tokens,
"compression_ratio": compression_ratio,
"kept_lines": kept_lines,
}
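    # A minimal usage sketch for compress_code (the variable names below are
    # illustrative, not part of the API):
    #   compressor = CodeCompressor()
    #   result = compressor.compress_code(code=some_code, query="How is the cache managed?", rate=0.3)
    #   print(result["compressed_code"], result["compression_ratio"])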
def control_context_budget(
self,
context_list: List[str],
target_token: float,
question: str = "",
reorder_context: str = "original",
condition_in_question: str = "none",
force_context_ids: List[int] = None,
force_context_number: int = None,
context_budget: str = "+100",
dynamic_context_compression_ratio: float = 0.0,
):
"""
Control token budget for contexts based on relevance ranking, following LongLLMLingua.
Args:
context_list: List of contexts
target_token: Target number of tokens
question: Question for relevance ranking
reorder_context: How to reorder contexts ("original", "importance", "two_stage")
condition_in_question: Mode for conditional ranking
force_context_ids: List of context IDs to always include
force_context_number: Number of contexts to forcibly include
context_budget: String expression to modify target token budget
dynamic_context_compression_ratio: Ratio for dynamic compression (0.0-1.0)
        Returns:
            Selected contexts, their indices, dynamic compression ratios, and the relevance-sorted (index, score) pairs
"""
logger.debug(f"Controlling context budget with target_token={target_token}")
start_time = time.time()
if not context_list:
return [], [], []
# Get token counts for each context
logger.debug("Calculating token lengths for contexts")
context_tokens_length = [self.get_token_length(context) for context in context_list]
# If total tokens already fit within budget, return all contexts
total_tokens = sum(context_tokens_length)
if total_tokens <= target_token:
logger.debug(f"All contexts fit within budget ({total_tokens} <= {target_token})")
end_time = time.time()
logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds")
return context_list, list(range(len(context_list))), [0.0] * len(context_list)
# Rank contexts by relevance if question is provided
logger.debug("Ranking contexts by relevance")
if question:
# Get perplexity change for each context with the question
context_ppl_changes = []
for d, dl in zip(context_list, context_tokens_length):
# Calculate how much this context reduces question perplexity
ppl_change = self.get_condition_ppl(
d,
question,
condition_in_question,
)
                # Length adjustment factor (currently a no-op because of the trailing * 0 factor)
                context_ppl_changes.append(ppl_change - dl * 2 / 250 * 0)
# Sort by perplexity change - higher is better (more reduction in question perplexity)
demonstrations_sort = sorted(enumerate(context_ppl_changes), key=lambda x: -x[1])
else:
# Without question, use default ordering
demonstrations_sort = [(i, 0) for i in range(len(context_list))]
# Extract ranking for later use
        self.context_idxs.append([x for x, _ in demonstrations_sort])
# Calculate the target token budget with context_budget expression
if target_token < 0:
target_token = 100
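        # Apply the budget expression (e.g. the default "+100") to the target token count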
target_token = eval("target_token" + context_budget)
# Initialize selected context tracking
used = force_context_ids if force_context_ids is not None else []
# Select contexts until we reach the token budget
for idx, _ in demonstrations_sort:
if idx >= len(context_tokens_length):
continue
target_token -= context_tokens_length[idx]
if idx not in used:
used.append(idx)
if target_token < 0 or (
force_context_number is not None and len(used) >= force_context_number
):
break
# Store original selection order
original_used = used.copy()
# Reorder contexts if requested
if reorder_context == "original":
used = sorted(used)
elif reorder_context == "two_stage":
l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
_ for idx, _ in enumerate(used) if idx % 2 == 1
]
used = l + r[::-1]
# Calculate dynamic compression ratios if requested
if dynamic_context_compression_ratio > 0:
N = len(used)
dynamic_ratio = [
i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
for i in range(-(N - 1), N, 2)
][::-1]
dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
dynamic_ratio = [dynamic_ratio_map[i] for i in used]
else:
dynamic_ratio = [0.0] * len(used)
# Build list of selected contexts
selected_contexts = [context_list[idx] for idx in used if idx < len(context_list)]
end_time = time.time()
logger.debug(f"Selected {len(selected_contexts)} contexts out of {len(context_list)}")
logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds")
return selected_contexts, used, dynamic_ratio, demonstrations_sort
def compress_code_file(
self,
code: str,
query: str = "",
instruction: str = "",
rate: float = 0.5,
target_token: float = -1,
language: str = "python",
use_iterative_compression: bool = True,
iterative_size: int = 200,
dynamic_compression_ratio: float = 0.2,
context_budget: str = "+100",
rank_only: bool = False,
):
"""
Compress a code file by first splitting it into function-based chunks and then compressing.
Functions are prioritized based on query relevance, similar to LongLLMLingua.
Args:
code: The code to compress
query: Query to prioritize relevant functions
instruction: Additional instruction to guide compression
rate: Compression rate (0.0-1.0)
target_token: Target number of tokens (alternative to rate)
language: Programming language of the code
use_iterative_compression: Whether to use iterative compression
iterative_size: Size of each iteration for iterative compression
dynamic_compression_ratio: Ratio for dynamic compression
context_budget: String expression to modify token budget
rank_only: If True, just rank and select contexts without fine-grained compression
Returns:
Compressed code and statistics
"""
logger.debug(f"Starting code file compression with rate={rate}, target_token={target_token}, language={language}")
start_time = time.time()
# Split code into function-based chunks
logger.debug("Splitting code into function-based chunks")
code_chunks = self.split_code_by_functions(code, language=language)
logger.debug(f"Split code into {len(code_chunks)} chunks")
# Calculate total tokens
logger.debug("Calculating total tokens")
total_tokens = sum(self.get_token_length(chunk) for chunk in code_chunks)
logger.debug(f"Total tokens: {total_tokens}")
# If target token is not provided, use rate
if target_token <= 0:
target_token = int(total_tokens * rate)
logger.debug(f"Target tokens: {target_token}")
# Use context budget control to select important functions
logger.debug("Selecting important functions using context budget control")
selected_contexts, selected_indices, dynamic_ratios, demonstrations_sort = self.control_context_budget(
code_chunks,
target_token=target_token,
question=query,
reorder_context="original", # Keep original order to maintain code structure
condition_in_question="prefix",
context_budget=context_budget,
dynamic_context_compression_ratio=dynamic_compression_ratio,
)
# If rank_only is True, just use the selected contexts without further compression
if rank_only:
logger.debug("Using rank-only mode: selecting top functions without fine-grained compression")
compressed_chunks = []
compressed_tokens = 0
function_compressions = {}
# Just keep the selected contexts as is
for i, chunk in enumerate(code_chunks):
if i in selected_indices:
compressed_chunks.append(chunk)
chunk_tokens = self.get_token_length(chunk)
compressed_tokens += chunk_tokens
# Store compression info - no actual compression in this mode
function_compressions[i] = {
"original_tokens": chunk_tokens,
"compressed_tokens": chunk_tokens,
"compression_ratio": 1.0,
}
else:
# Skip this function completely
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
omission_text = f"{comment_marker} ... "
compressed_chunks.append(omission_text)
compressed_tokens += self.get_token_length(omission_text)
# Combine compressed chunks
compressed_code = "\n\n".join(compressed_chunks)
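            # Assemble the final prompt; the instruction is placed both before the code and after the query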
output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}"
# Calculate actual compressed tokens
final_compressed_tokens = self.get_token_length(output)
end_time = time.time()
logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds")
logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}")
return {
"original_code": code,
"compressed_code": compressed_code,
"compressed_prompt": output,
"original_tokens": total_tokens,
"compressed_tokens": compressed_tokens,
"final_compressed_tokens": final_compressed_tokens,
"compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0,
"function_compressions": function_compressions,
"selected_functions": selected_indices,
"demonstrations_sort": demonstrations_sort,
}
# Compress each function according to its importance
logger.debug("Compressing selected functions")
compressed_chunks = []
compressed_tokens = 0
function_compressions = {}
# Allocate tokens proportionally based on importance
importance_scores = {}
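        # Linear rank weights: the top-ranked function gets weight N, the next
        # N - 1, and so on, where N = len(selected_indices)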
for i, idx in enumerate(selected_indices):
# Higher importance for functions mentioned early in ranking
importance_scores[idx] = len(selected_indices) - i
# Calculate total importance
total_importance = sum(importance_scores.values()) if importance_scores else 1
# Allocate tokens based on importance
token_allocation = {}
for idx, importance in importance_scores.items():
allocation = max(10, int(target_token * importance / total_importance))
token_allocation[idx] = min(allocation, self.get_token_length(code_chunks[idx]))
# Adjust allocations to fit target
logger.debug("Adjusting token allocations to fit target")
while sum(token_allocation.values()) > target_token:
max_idx = max(token_allocation, key=token_allocation.get)
token_allocation[max_idx] = max(0, token_allocation[max_idx] - 10)
# Show the allocation
logger.debug(f"Token allocation: {token_allocation}")
# Process each chunk
for i, chunk in tqdm(enumerate(code_chunks), total=len(code_chunks), desc="Compressing functions"):
if i in token_allocation and token_allocation[i] > 0:
# Calculate compression rate for this chunk
chunk_tokens = self.get_token_length(chunk)
chunk_rate = token_allocation[i] / chunk_tokens
# Apply dynamic compression ratio based on importance
dynamic_ratio = dynamic_ratios[selected_indices.index(i)] if i in selected_indices else 0.0
# Compress the chunk using line-level compression if requested
if use_iterative_compression and chunk_tokens > 50:
compressed_input_ids, _ = self.iterative_compress_prompt_line(
[chunk],
target_token=token_allocation[i],
dynamic_ratio=[dynamic_ratio],
)
compressed_chunk = self.tokenizer.decode(compressed_input_ids[0])
else:
# Use simple line-level compression for smaller chunks
compress_result = self.compress_code(
code=chunk,
query=query,
rate=chunk_rate,
use_iterative_compression=False
)
compressed_chunk = compress_result["compressed_code"]
compressed_chunks.append(compressed_chunk)
chunk_compressed_tokens = self.get_token_length(compressed_chunk)
compressed_tokens += chunk_compressed_tokens
# Store compression info for this function
function_compressions[i] = {
"original_tokens": chunk_tokens,
"compressed_tokens": chunk_compressed_tokens,
"compression_ratio": chunk_compressed_tokens / chunk_tokens if chunk_tokens > 0 else 1.0,
}
else:
# Skip this function completely
comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//"
# omission_text = f"{comment_marker} ... function omitted ..."
omission_text = f"{comment_marker} ... "
compressed_chunks.append(omission_text)
compressed_tokens += self.get_token_length(omission_text)
# Combine compressed chunks
logger.debug("Combining compressed chunks")
compressed_code = "\n\n".join(compressed_chunks)
# # If instruction is provided, add it to the final output
# output = ""
# if instruction:
# output += f"{instruction}\n\n"
# output += compressed_code
# if query:
# output += f"\n\n{query}"
output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}"
# Calculate actual compressed tokens including instruction and query
final_compressed_tokens = self.get_token_length(output)
end_time = time.time()
logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds")
logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}")
return {
"original_code": code,
"compressed_code": compressed_code,
"compressed_prompt": output,
"original_tokens": total_tokens,
"compressed_tokens": compressed_tokens,
"final_compressed_tokens": final_compressed_tokens,
"compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0,
"function_compressions": function_compressions,
"selected_functions": selected_indices,
"demonstrations_sort": demonstrations_sort,
}
def split_code_by_functions(self, code: str, language: str = "python") -> List[str]:
"""
Split code into chunks based on function and class definitions for various languages.
Args:
code: The code to split
language: Programming language of the code (python, cpp, java, typescript, rust, go)
Returns:
List of code chunks, each containing a function, class, or class method
"""
logger.debug(f"Splitting code by functions and classes for language: {language}")
start_time = time.time()
# Define regex patterns for different languages
patterns = {
# Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end
"python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*',
# C++: Improved to better handle multi-line declarations
"cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
# Java: Improved for multi-line method declarations
"java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
# TypeScript: Enhanced to handle multi-line methods and arrow functions
"typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?',
# Rust: Improved for multi-line function declarations
"rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
# Go: Improved for multi-line function declarations
"go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
}
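        # Note: these patterns are lightweight heuristics for locating definition
        # boundaries; they are not full language parsers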
# Use default Python pattern if language not supported
if language.lower() not in patterns:
language = "python"
function_pattern = re.compile(patterns[language.lower()], re.MULTILINE)
# Find all function and class definitions
matches = list(function_pattern.finditer(code))
logger.debug(f"Found {len(matches)} function and class definitions")
# If no functions or classes found, return the whole code as one chunk
if not matches:
logger.debug("No functions or classes found, returning entire code as one chunk")
end_time = time.time()
logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds")
return [code]
# Extract chunks that include function and class definitions
chunks = []
# Add imports and other code before the first function or class
if matches[0].start() > 0:
chunks.append(code[:matches[0].start()])
# Process each function or class match
for i, match in enumerate(matches):
# Get the current function or class
start = match.start()
# Determine end position (either the start of the next function/class or the end of the code)
if i < len(matches) - 1:
end = matches[i + 1].start()
else:
end = len(code)
# Extract the function/class and its body
chunks.append(code[start:end])
end_time = time.time()
logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds")
logger.debug(f"Split code into {len(chunks)} chunks")
return chunks
def load_examples(language: Optional[str] = None) -> List[Dict]:
"""Load examples from the results file, optionally filtered by language"""
with open("../results/ntoken_16384/Qwen_slash_Qwen2.5-7B-Instruct.jsonl", "r") as f:
# with open("../results/ntoken_16384/Qwen_slash_Qwen2.5-7B-Instruct-GPTQ-Int4.jsonl", "r") as f:
data = [json.loads(line) for line in f]
    if language:
        filtered = [example for example in data if example["language"] == language]
        if not filtered:
            # Report the languages present in the unfiltered data
            available_languages = set(ex["language"] for ex in data)
            raise ValueError(f"No examples found for language '{language}'. Available languages: {available_languages}")
        data = filtered
    return data
# Simple test code
if __name__ == "__main__":
# Load real examples from the dataset
examples = load_examples(language="python")
example = examples[0] # Use the first example
sample_code = example["code_context"]
query = example["description"]
language = example["language"]
print(f"Using example with language: {language}")
print(f"Query: {query}")
# Initialize compressor
print("Initializing compressor...")
compressor = CodeCompressor()
# Test function-based code file compression with query
print("\nTesting function-based code file compression with query...")
start_time = time.time()
file_result = compressor.compress_code_file(
code=sample_code,
query=query,
rate=0.1,
language=language
)
end_time = time.time()
print(f"File compression with query completed in {end_time - start_time:.2f} seconds")
print(f"Original tokens: {file_result['original_tokens']}")
print(f"Compressed tokens: {file_result['compressed_tokens']}")
print(f"Final compressed tokens (with query): {file_result['final_compressed_tokens']}")
print(f"Compression ratio: {file_result['compression_ratio']:.2f}")
print(f"Kept function IDs: {file_result['selected_functions']}")
print(f"Demonstrations sort: {file_result['demonstrations_sort']}")
chunk_ppl_scores = {idx: score for idx, score in file_result['demonstrations_sort']}
    top_5_score = sorted(chunk_ppl_scores.values(), reverse=True)[min(4, len(chunk_ppl_scores) - 1)]
# Split into chunks and show the chunks
chunks = compressor.split_code_by_functions(sample_code, language=language)
print(f"Split code into {len(chunks)} chunks")
# show the chunk with corresponding ppl score
for i, chunk in enumerate(chunks):
print(f"==========Chunk {i+1} with demonstration sort: {chunk_ppl_scores[i]}==========")
if chunk_ppl_scores[i] >= top_5_score:
print(chunk)
print("\n")
else:
# only show some lines and then use ... to indicate the rest
print(chunk[:100])
print("...")
print(chunk[-100:])
print("\n")
print("\nCompressed Code File with Query:")
print("-------------------")
print(file_result['compressed_code'])