diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..532e16d --- /dev/null +++ b/.gitignore @@ -0,0 +1,188 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +crosscodeeval/ +temp/ +*.jsonl +*.out +*.txt +datasets/ +repositories/ +*.ipynb +cache/ +output*/ +*.pdf +*.json +*.jsonl +old_scripts/ +*cache*/ +*.zip +*.tar +*.tar.gz +*.tar.xz +*.tar.bz2 +*.tar.lzma +*.tar.lz4 +*.tar.zstd +*.tar.lz +*.html \ No newline at end of file diff --git a/README.md b/README.md index 9635efa..5477bce 100644 --- a/README.md +++ b/README.md @@ -1 +1,125 @@ # LongCodeZip + +This repository is the official implementation of LongCodeZip, a novel two-stage long code compression method. + + +## Method Overview + +![Overview](assets/overview.png) + +LongCodeZip introduces a two-stage code compression framework specifically designed for code LLMs: + +1. **Coarse-grained Compression**: Function-based chunking and ranking using conditional perplexity with respect to the query to select the most relevant functions. + +2. **Fine-grained Compression**: Entropy-based block detection combined with 0/1 knapsack optimization to maximize relevance within adaptive token budgets. + +The method is plug-and-play and can be integrated with existing code LLMs to achieve significant compression ratios while maintaining or improving task performance. + +## Repository Structure + +This repository contains implementations and experiments for three code-related tasks: + +``` +LongCodeZip/ +├── repoqa/ # Code Retrieval Task +│ ├── main.py # Main evaluation script +│ ├── run.sh # Experiment runner +│ ├── code_compressor.py # Core compression implementation +│ ├── compute_score.py # Evaluation metrics +│ └── ... +├── long-code-completion/ # Code Completion Task +│ ├── main.py # Main evaluation script +│ ├── run.sh # Experiment runner +│ ├── code_compressor.py # Core compression implementation +│ ├── utils.py # Utility functions +│ └── ... +├── module_summarization/ # Code Summarization Task +│ ├── main.py # Main evaluation script +│ ├── run.sh # Experiment runner +│ ├── code_compressor.py # Core compression implementation +│ ├── utils.py # Utility functions +│ └── ... +└── README.md +``` + +## Installation + +```bash +pip install -r requirements.txt +``` + +## Usage + +### Quick Start + +Each task directory contains a `run.sh` script for easy experimentation. Simply navigate to the desired task directory and run: + +```bash +cd +bash run.sh +``` + +### Code Retrieval (RepoQA) + +Navigate to the `repoqa` directory and run experiments with different compression ratios: + +```bash +cd repoqa +bash run.sh +``` + +The script will evaluate LongCodeZip on the RepoQA dataset with compression ratios of 0.1, 0.2, 0.3, and 0.4, running experiments in parallel on multiple GPUs. + +**Key Parameters:** +- `--compression-ratio`: Controls the compression level (0.1-0.4) +- `--model`: Specifies the base LLM model +- `--backend`: Backend for model inference (vllm) + +### Code Completion + +Navigate to the `long-code-completion` directory: + +```bash +cd long-code-completion +bash run.sh +``` + +This evaluates LongCodeZip on long-context code completion tasks with various configurations including different target token limits, fine-grained compression ratios, and importance beta values. 
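+For illustration, the compression itself is implemented by the `CodeCompressor` class in each task's `code_compressor.py`, so it can also be invoked directly from Python. The sketch below is a minimal, illustrative example only, not a documented entry point: the parameter names follow the `compress_code_file` signature in this repository, while the model choice, file path, and query string are placeholder assumptions.
+
+```python
+from code_compressor import CodeCompressor  # run from a task directory, e.g. long-code-completion/
+
+# Placeholder model; the repository defaults to a Qwen2.5-Coder checkpoint
+compressor = CodeCompressor(model_name="Qwen/Qwen2.5-Coder-7B-Instruct")
+
+with open("example_context.py") as f:  # placeholder repository context
+    code_context = f.read()
+
+result = compressor.compress_code_file(
+    code=code_context,
+    query="def sort_records(records):",  # placeholder query / completion prefix
+    target_token=2048,                   # coarse-grained token budget
+    fine_ratio=0.5,                      # fine-grained compression ratio
+    importance_beta=0.5,                 # importance weighting across functions
+    language="python",
+)
+
+print(result["compression_ratio"])   # compressed tokens / original tokens
+print(result["compressed_prompt"])   # instruction + compressed code + query
+```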
+ +**Key Parameters:** +- `--code_compressor_target_token`: Target token budget (2048, 4096) +- `--code_compressor_fine_ratio`: Fine-grained compression ratio (0.5, 0.8) +- `--importance_beta`: Importance weighting parameter (0.0, 0.5) + +### Code Summarization + +Navigate to the `module_summarization` directory: + +```bash +cd module_summarization +bash run.sh +``` + +This runs code summarization experiments with fine-grained compression and various beta values for importance weighting. + +**Key Parameters:** +- `--code_compressor_target_token`: Target token budget +- `--code_compressor_fine_ratio`: Fine-grained compression ratio +- `--importance_beta`: Importance weighting parameter + +## Configuration + +Each task can be customized by modifying the respective `run.sh` file or by directly calling the main scripts with custom parameters. Key configuration options include: + +- **Model Selection**: Compatible with various code LLMs (default: Qwen2.5-Coder-7B-Instruct) +- **Compression Ratios**: Adjustable compression levels for different use cases +- **Token Budgets**: Configurable target token limits +- **GPU Configuration**: Multi-GPU support for parallel experiments + +## Performance + +LongCodeZip achieves up to **5.6× compression ratio** without sacrificing task performance across code completion, summarization, and retrieval tasks. And even when using a 0.5B Qwen model as the compressor, it can also achieve competitive performance. + +## Contact + +Please feel free to contact us if you have any questions. \ No newline at end of file diff --git a/assets/overview.png b/assets/overview.png new file mode 100644 index 0000000..023ca5b Binary files /dev/null and b/assets/overview.png differ diff --git a/long-code-completion/code_compressor.py b/long-code-completion/code_compressor.py new file mode 100644 index 0000000..b39b106 --- /dev/null +++ b/long-code-completion/code_compressor.py @@ -0,0 +1,1889 @@ +import torch +import numpy as np +from typing import List, Union, Tuple, Dict, Optional +import re +import math +import zlib +from transformers import AutoModelForCausalLM, AutoTokenizer +import time +from tqdm import tqdm +import copy +import bisect +import json +from llmlingua import PromptCompressor +from loguru import logger + +class EntropyChunking: + def __init__(self, model_name="Qwen/Qwen2.5-Coder-0.5B-Instruct"): + """Entropy-based text chunking implementation""" + logger.debug(f"Loading Entropy chunking model: {model_name}") + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + logger.debug(f"Entropy chunking model loaded on device: {self.device}") + + def split_into_sentences(self, text: str) -> List[str]: + """Split text into sentences, inserting empty lines for double newlines""" + # First replace double newlines with a special marker + text_with_markers = text.replace('\n\n', '\n__EMPTY_LINE__\n') + + # Split by single newlines + lines = text_with_markers.split('\n') + + # Process lines: replace markers with empty strings, keep original lines + sentences = [] + for line in lines: + if line == '__EMPTY_LINE__': + sentences.append(' ') # Empty line for double newline breaks + else: + sentences.append(line) # Keep original line with indentation + + return sentences + + def 
calculate_sentence_ppl(self, sentences: List[str]) -> List[float]: + """Calculate perplexity for each sentence based on preceding context""" + ppls = [] + + for i, sentence in enumerate(sentences): + if i == 0: + context = "" + target = sentence + else: + context = "\n".join(sentences[:i]) + target = sentence + + ppl = self._compute_ppl(context, target) + ppls.append(ppl) + + return ppls + + def _compute_ppl(self, context: str, target: str) -> float: + """Compute perplexity of target text given context""" + # Handle empty target lines + if not target: + return 0.0 # Assign zero perplexity to empty lines + + if context: + full_text = context + "\n" + target + context_tokens = self.tokenizer(context + "\n", return_tensors="pt", add_special_tokens=True) + context_length = context_tokens.input_ids.shape[1] + else: + full_text = target + context_length = 0 + + inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(self.device) + + with torch.no_grad(): + outputs = self.model(**inputs) + logits = outputs.logits + + if context_length > 0: + target_logits = logits[0, context_length-1:-1] + target_labels = inputs.input_ids[0, context_length:] + else: + target_logits = logits[0, :-1] + target_labels = inputs.input_ids[0, 1:] + + if len(target_labels) > 0: + log_probs = torch.log_softmax(target_logits, dim=-1) + token_log_probs = log_probs[torch.arange(len(target_labels)), target_labels] + avg_log_prob = token_log_probs.mean().item() + ppl = math.exp(-avg_log_prob) + else: + ppl = float('inf') + + # take log2 of ppl + ppl = math.log2(ppl) + + return ppl + + def calculate_adaptive_thresholds(self, ppls: List[float], k: float = 0.2) -> dict: + """Calculate adaptive thresholds using different statistical methods""" + # Filter out infinite and NaN values + valid_ppls = [p for p in ppls if not math.isinf(p) and not math.isnan(p) and p > 0] + + if len(valid_ppls) < 3: + # Fallback to fixed threshold if not enough valid data + return { + 'std': 0.5, + 'robust_std': 0.5, + 'iqr': 0.5, + 'mad': 0.5 + } + + valid_ppls = np.array(valid_ppls) + + # Method 1: Standard deviation based + mean_ppl = np.mean(valid_ppls) + std_ppl = np.std(valid_ppls) + threshold_std = mean_ppl + k * std_ppl + + # Method 2: Robust standard deviation (using median and MAD) + median_ppl = np.median(valid_ppls) + mad = np.median(np.abs(valid_ppls - median_ppl)) + robust_std = mad * 1.4826 # Convert MAD to robust std estimate + threshold_robust_std = median_ppl + k * robust_std + + # Method 3: IQR based (Interquartile Range) + q25 = np.percentile(valid_ppls, 25) + q75 = np.percentile(valid_ppls, 75) + iqr = q75 - q25 + threshold_iqr = q75 + k * iqr + + # Method 4: MAD based (Median Absolute Deviation) + threshold_mad = median_ppl + k * mad + + return { + 'std': threshold_std, + 'robust_std': threshold_robust_std, + 'iqr': threshold_iqr, + 'mad': threshold_mad + } + + def find_ppl_spikes_adaptive(self, values: List[float], method: str = 'std', k: float = 0.2) -> tuple: + """Find PPL spikes using adaptive threshold based on statistical method""" + thresholds = self.calculate_adaptive_thresholds(values, k) + threshold = thresholds[method] + + spike_indices = [] + + for i in range(1, len(values) - 1): + current = values[i] + left = values[i - 1] + right = values[i + 1] + + # Skip infinite or NaN values + if math.isinf(current) or math.isnan(current): + continue + if math.isinf(left) or math.isnan(left): + left = current + if math.isinf(right) or math.isnan(right): + right = current + + # Check if current PPL is 
significantly higher than both neighbors + left_diff = current - left + right_diff = current - right + + # Condition: Current PPL is higher than both neighbors with adaptive threshold + if (left_diff >= threshold or right_diff >= threshold) and (left_diff >= 0 and right_diff >= 0): + spike_indices.append(i) + + return spike_indices, threshold + + def chunk_text_adaptive(self, text: str, method: str = 'std', k: float = 0.2) -> tuple: + """Perform PPL-based text chunking using adaptive spike detection""" + sentences = self.split_into_sentences(text) + ppls = self.calculate_sentence_ppl(sentences) + spike_indices, threshold = self.find_ppl_spikes_adaptive(ppls, method, k) + + chunks = [] + # Split at spike points (after the spike line) + split_points = [0] + [idx + 1 for idx in spike_indices] + [len(sentences)] + + for i in range(len(split_points) - 1): + start = split_points[i] + end = split_points[i + 1] + chunk_sentences = sentences[start:end] + chunk_text = "\n".join(chunk_sentences) + chunks.append(chunk_text) + + return chunks, sentences, ppls, spike_indices + +class CodeCompressor: + def __init__( + self, + model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4", + device_map: str = "cuda", + model_config: dict = {}, + ): + """ + Initialize the CodeCompressor with a language model for compression. + + Args: + model_name: The name of the model to load from HuggingFace + device_map: Device to load the model on + model_config: Additional configuration for the model + """ + self.model_name = model_name + self.device = device_map + self.model_config = model_config + self.load_model(model_name, device_map, model_config) + + # Initialize Entropy chunking with smaller model + logger.debug("Initializing Entropy chunking...") + self.entropy_chunking = EntropyChunking() + + # Add caching system for model outputs and token information + self.cache = { + "token_length": {}, # Cache for token length by text + "encodings": {}, # Cache for tokenizer encodings + "perplexity": {}, # Cache for perplexity calculations + "conditional_ppl": {}, # Cache for conditional perplexity + "context_rankings": {}, # Cache for context rankings + } + self.max_cache_size = 1000 # Limit cache size to prevent memory issues + + # set up the max position embeddings and cache bos num + self.max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 4096) + self.cache_bos_num = 10 + self.prefix_bos_num = 100 + self.context_idxs = [] + + def load_model( + self, model_name: str, device_map: str = "cuda", model_config: dict = {} + ): + """ + Load the language model and tokenizer. 
+ + Args: + model_name: The name of the model to load + device_map: Device to load the model on + model_config: Additional configuration for the model + """ + logger.debug(f"Loading model {model_name} on {device_map}") + torch_dtype = torch.bfloat16 if "torch_dtype" not in model_config else model_config["torch_dtype"] + # model_kwargs = {"device_map": device_map, "torch_dtype": torch_dtype, "trust_remote_code": True} + model_kwargs = {"device_map": device_map, "torch_dtype": torch_dtype, "trust_remote_code": True} + + for k, v in model_config.items(): + model_kwargs[k] = v + + self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" + logger.debug("Model and tokenizer loaded successfully") + + def _manage_cache_size(self, cache_type): + """ + Manage cache size by removing oldest entries when cache exceeds max size. + + Args: + cache_type: The type of cache to manage + """ + if len(self.cache[cache_type]) > self.max_cache_size: + # Remove 20% of the oldest entries + remove_count = int(self.max_cache_size * 0.2) + keys_to_remove = list(self.cache[cache_type].keys())[:remove_count] + for key in keys_to_remove: + del self.cache[cache_type][key] + + def get_token_length( + self, + text: str, + add_special_tokens: bool = True, + ): + """ + Get the number of tokens in the given text. + + Args: + text: The text to tokenize + add_special_tokens: Whether to count special tokens + + Returns: + The number of tokens + """ + # Create a cache key based on text and parameters + cache_key = f"{text}_{add_special_tokens}" + + # Check if result is in cache + if cache_key in self.cache["token_length"]: + return self.cache["token_length"][cache_key] + + # Calculate token length if not in cache + token_length = len(self.tokenizer.encode(text, add_special_tokens=add_special_tokens)) + + # Store in cache + self.cache["token_length"][cache_key] = token_length + self._manage_cache_size("token_length") + + return token_length + + def get_ppl( + self, + text: str, + granularity: str = "line", + input_ids=None, + attention_mask=None, + past_key_values=None, + return_kv=False, + end=None, + condition_mode: str = "none", + condition_pos_id: int = 0, + ): + """ + Calculate perplexity for the given text at line level. 
+ + Args: + text: The text to calculate perplexity for + granularity: The granularity of perplexity calculation (line, token, chunk) + input_ids, attention_mask, past_key_values: Optional pre-processed inputs + return_kv: Whether to return key-values + end: End position for calculation + condition_mode: Mode for conditional perplexity (none, prefix) + condition_pos_id: Position ID for condition + + Returns: + A dictionary with perplexity scores and processing information + """ + # Create a cache key for this specific perplexity calculation + cache_key = f"{text}_{granularity}_{condition_mode}_{condition_pos_id}" + if past_key_values is None and not return_kv and cache_key in self.cache["perplexity"]: + return self.cache["perplexity"][cache_key] + + # Initialize input processing + if input_ids is None: + encoding_key = text + if encoding_key in self.cache["encodings"]: + cached_encoding = self.cache["encodings"][encoding_key] + input_ids = cached_encoding["input_ids"] + attention_mask = cached_encoding["attention_mask"] + else: + encoding = self.tokenizer( + text, + return_tensors="pt", + padding=True + ) + input_ids = encoding["input_ids"].to(self.model.device) + attention_mask = encoding["attention_mask"].to(self.model.device) + + # Cache the encoding + self.cache["encodings"][encoding_key] = { + "input_ids": input_ids, + "attention_mask": attention_mask + } + self._manage_cache_size("encodings") + + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + else: + past_length = 0 + + if end is None: + end = input_ids.shape[1] + end = min(end, past_length + self.max_position_embeddings) + + with torch.no_grad(): + outputs = self.model( + input_ids=input_ids[:, past_length:end], + attention_mask=attention_mask[:, :end], + past_key_values=past_key_values, + return_dict=True, + output_hidden_states=True, + use_cache=True, + ) + + # Get logits and shift + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., past_length+1:end].contiguous() + + # Flatten tokens for loss calculation + active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1) + active_logits = shift_logits.view(-1, shift_logits.size(-1))[active] + active_labels = shift_labels.view(-1)[active] + + # Calculate loss + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(active_logits, active_labels) + + # Apply condition filtering if required + if condition_mode == "prefix": + loss = loss[condition_pos_id:] + + segments = [text] if text else [] + lines_info = [] + + # Calculate mean perplexity + mean_loss = loss.mean() if len(loss) > 0 else torch.tensor(0.0) + ppl = torch.exp(mean_loss).item() if mean_loss.item() != float('inf') else float('inf') + + result = { + "loss": loss, + "input_ids": input_ids, + "attention_mask": attention_mask, + "lines_info": lines_info, + "segments": segments, + "ppl": ppl, + } + + if return_kv: + result["past_key_values"] = outputs.past_key_values + else: + # Cache the result if we're not returning KV cache + self.cache["perplexity"][cache_key] = result + self._manage_cache_size("perplexity") + + return result + + def __get_lines_info(self, lines, input_ids, loss): + """ + Get information about each line including start/end positions and importance. 
+ + Args: + lines: List of lines in the text + input_ids: Token IDs for the entire text + loss: Per-token loss values + + Returns: + List of dictionaries with line information + """ + line_info = [] + cumulative_tokens = 0 + + input_ids_list = input_ids.cpu().tolist() + + for i, line in enumerate(lines): + if not line.strip(): + continue + + # Encode each line to find its token length + line_tokens = self.tokenizer.encode(line, add_special_tokens=False) + line_length = len(line_tokens) + + # Find position in the tokenized text + start_pos = cumulative_tokens + end_pos = start_pos + line_length + + # Calculate mean loss (importance) for this line + # Loss might be shorter than the token IDs due to shifting + if isinstance(loss, torch.Tensor) and start_pos < len(loss) and end_pos <= len(loss): + line_loss = loss[start_pos:end_pos].mean().item() + else: + # Handle edge cases + line_loss = float("inf") + + line_info.append({ + "line": line, + "start": start_pos, + "end": end_pos, + "importance": line_loss, + "tokens": line_length + }) + + cumulative_tokens += line_length + + return line_info + + def get_prefix_length(self, prefix: str, text: str): + """ + Calculate the length of a prefix in tokens when concatenated with a text. + + Args: + prefix: The prefix text + text: The main text + + Returns: + Length of the prefix in tokens + """ + possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1) + full_input_ids = self.tokenizer(prefix + text[:100], add_special_tokens=False).input_ids + + for i in range(possible_prefix_token, len(full_input_ids)): + cur_prefix = self.tokenizer.decode(full_input_ids[:i]) + if cur_prefix == prefix: + break + + return i + + def get_condition_ppl( + self, + text: str, + question: str, + condition_in_question: str = "none", + granularity: str = "line", + ): + """ + Calculate perplexity change of a question when given context text. + A positive change means the context helps reduce question perplexity. 
+ + Args: + text: The context text + question: The question to evaluate + condition_in_question: Conditioning mode (none, prefix) + granularity: Granularity for perplexity calculation + + Returns: + Perplexity change for the question with/without context + """ + # Create a cache key for this conditional perplexity calculation + cache_key = f"{text}_{question}_{condition_in_question}_{granularity}" + + if cache_key in self.cache["conditional_ppl"]: + return self.cache["conditional_ppl"][cache_key] + + if condition_in_question == "none": + # Just return the perplexity of the text + result = self.get_ppl( + text=text, granularity=granularity, condition_mode="none" + ) + ppl_value = result["ppl"] + else: + # First calculate question perplexity without context + question_ppl_without_context = self.get_ppl( + text=question, + granularity=granularity + )["ppl"] + + # Then calculate question perplexity with context + question_ppl_with_context = self.get_ppl( + text=text + "\n\n" + question, + granularity=granularity, + condition_mode="prefix", + condition_pos_id=self.get_token_length(text + "\n\n", add_special_tokens=True) + )["ppl"] + + # Calculate the change (positive means context helps) + ppl_value = question_ppl_without_context - question_ppl_with_context + + # Cache the result + self.cache["conditional_ppl"][cache_key] = ppl_value + self._manage_cache_size("conditional_ppl") + + return ppl_value + + def control_context_budget( + self, + context_list: List[str], + target_token: float, + question: str = "", + reorder_context: str = "original", + condition_in_question: str = "none", + force_context_ids: List[int] = None, + force_context_number: int = None, + context_budget: str = "+100", + dynamic_context_compression_ratio: float = 0.0, + ): + """ + Control token budget for contexts based on relevance ranking, following LongLLMLingua. 
+ + Args: + context_list: List of contexts + target_token: Target number of tokens + question: Question for relevance ranking + reorder_context: How to reorder contexts ("original", "importance", "two_stage") + condition_in_question: Mode for conditional ranking + force_context_ids: List of context IDs to always include + force_context_number: Number of contexts to forcibly include + context_budget: String expression to modify target token budget + dynamic_context_compression_ratio: Ratio for dynamic compression (0.0-1.0) + + Returns: + Selected contexts, their indices, and dynamic ratios + """ + logger.debug(f"Controlling context budget with target_token={target_token}") + start_time = time.time() + + if not context_list: + return [], [], [] + + # Get token counts for each context + logger.debug("Calculating token lengths for contexts") + context_tokens_length = [self.get_token_length(context) for context in context_list] + + # If total tokens already fit within budget, return all contexts + total_tokens = sum(context_tokens_length) + if total_tokens <= target_token: + logger.debug(f"All contexts fit within budget ({total_tokens} <= {target_token})") + end_time = time.time() + logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds") + return context_list, list(range(len(context_list))), [0.0] * len(context_list) + + # Rank contexts by relevance if question is provided + logger.debug("Ranking contexts by relevance") + if question: + # Get perplexity change for each context with the question + context_ppl_changes = [] + for d, dl in zip(context_list, context_tokens_length): + # Calculate how much this context reduces question perplexity + ppl_change = self.get_condition_ppl( + d, + question, + condition_in_question, + ) + # Apply length adjustment factor similar to before + context_ppl_changes.append(ppl_change - dl * 2 / 250 * 0) + + # Sort by perplexity change - higher is better (more reduction in question perplexity) + demonstrations_sort = sorted(enumerate(context_ppl_changes), key=lambda x: -x[1]) + else: + # Without question, use default ordering + demonstrations_sort = [(i, 0) for i in range(len(context_list))] + + # Extract ranking for later use + self.context_idxs.append([x for idx, (x, _) in enumerate(demonstrations_sort)]) + + # Calculate the target token budget with context_budget expression + if target_token < 0: + target_token = 100 + target_token = eval("target_token" + context_budget) + + # Initialize selected context tracking + used = force_context_ids if force_context_ids is not None else [] + + # Select contexts until we reach the token budget + for idx, _ in demonstrations_sort: + if idx >= len(context_tokens_length): + continue + target_token -= context_tokens_length[idx] + if idx not in used: + used.append(idx) + if target_token < 0 or ( + force_context_number is not None and len(used) >= force_context_number + ): + break + + # Store original selection order + original_used = used.copy() + + # Reorder contexts if requested + if reorder_context == "original": + used = sorted(used) + elif reorder_context == "two_stage": + l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [ + _ for idx, _ in enumerate(used) if idx % 2 == 1 + ] + used = l + r[::-1] + + # Calculate dynamic compression ratios if requested + if dynamic_context_compression_ratio > 0: + N = len(used) + dynamic_ratio = [ + i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0 + for i in range(-(N - 1), N, 2) + ][::-1] + dynamic_ratio_map = {i: j for i, j 
in zip(original_used, dynamic_ratio)} + dynamic_ratio = [dynamic_ratio_map[i] for i in used] + else: + dynamic_ratio = [0.0] * len(used) + + # Build list of selected contexts + selected_contexts = [context_list[idx] for idx in used if idx < len(context_list)] + + end_time = time.time() + logger.debug(f"Selected {len(selected_contexts)} contexts out of {len(context_list)}") + logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds") + + return selected_contexts, used, dynamic_ratio, demonstrations_sort + + def compress_code_file( + self, + code: str, + query: str = "", + instruction: str = "", + rate: float = 0.5, + target_token: float = -1, + language: str = "python", + use_iterative_compression: bool = True, + iterative_size: int = 200, + dynamic_compression_ratio: float = 0.2, + context_budget: str = "+100", + rank_only: bool = False, + fine_ratio: float = None, + fine_grained_importance_method: str = "conditional_ppl", + min_lines_for_fine_grained: int = 5, + importance_beta: float = 0.5, + use_knapsack: bool = True, + ): + """ + Compress a code file by first splitting it into function-based chunks and then compressing. + Functions are prioritized based on query relevance, similar to LongLLMLingua. + + Args: + code: The code to compress + query: Query to prioritize relevant functions + instruction: Additional instruction to guide compression + rate: Compression rate for coarse-grained (function level) compression (0.0-1.0) + target_token: Target number of tokens (alternative to rate) + language: Programming language of the code + use_iterative_compression: Whether to use iterative compression + iterative_size: Size of each iteration for iterative compression + dynamic_compression_ratio: Ratio for dynamic compression + context_budget: String expression to modify token budget + rank_only: If True, just rank and select contexts without fine-grained compression + fine_ratio: Ratio for fine-grained line selection (0.0-1.0). If None, uses `rate`. + fine_grained_importance_method: Method for scoring line importance ('contrastive_perplexity' or 'conditional_ppl'). Defaults to 'conditional_ppl'. + min_lines_for_fine_grained: Minimum number of lines a function must have to undergo fine-grained compression (otherwise kept fully). + importance_beta: Controls how much function importance affects its individual compression rate during fine-grained compression. + use_knapsack: Whether to use knapsack algorithm for block selection (True) or greedy line-by-line approach (False). + + Returns: + Compressed code and statistics with the following structure: + { + "original_code": Original uncompressed code, + "compressed_code": Compressed code, + "compressed_prompt": Complete compressed prompt with instruction and query, + "original_tokens": Number of tokens in original code, + "compressed_tokens": Number of tokens in compressed code, + "final_compressed_tokens": Number of tokens in final compressed prompt, + "compression_ratio": Ratio of compressed to original tokens, + "function_compressions": Details about compression for each function, + "selected_functions": Indices of selected functions, + "demonstrations_sort": Ranking of functions by importance, + "compressed_chunks": List of compressed code chunks + "fine_grained_method_used": The method used for fine-grained importance scoring. 
+ } + """ + logger.debug(f"Starting code file compression with rate={rate}, target_token={target_token}, language={language}") + start_time = time.time() + + # Split code into function-based chunks + logger.debug("Splitting code into function-based chunks") + code_chunks = self.split_code_by_functions(code, language=language) + logger.debug(f"Split code into {len(code_chunks)} chunks") + + # Calculate total tokens + logger.debug("Calculating total tokens") + total_tokens = sum(self.get_token_length(chunk) for chunk in code_chunks) + logger.debug(f"Total tokens: {total_tokens}") + + # Determine target_token based on rate if not specified + original_target_token = target_token # Store original value if provided + if target_token <= 0: + if rate <= 0: + # Default target if both rate and target_token are invalid + target_token = int(total_tokens * 0.5) + logger.warning(f"Rate and target_token invalid, defaulting target_token to {target_token}") + else: + target_token = int(total_tokens * rate) + logger.debug(f"Coarse Target tokens: {target_token}") + + # Use context budget control to select important functions + logger.debug("Selecting important functions using context budget control") + selected_contexts, selected_indices, dynamic_ratios, demonstrations_sort = self.control_context_budget( + code_chunks, + target_token=target_token, + question=query, + reorder_context="original", # Keep original order to maintain code structure + condition_in_question="prefix", + context_budget=context_budget, + dynamic_context_compression_ratio=dynamic_compression_ratio, + ) + + # If rank_only is True, just use the selected contexts without further compression + logger.debug("Using rank-only mode: selecting top functions without fine-grained compression") + compressed_chunks = [] + compressed_tokens = 0 + function_compressions = {} + + # Just keep the selected contexts as is + for i, chunk in enumerate(code_chunks): + if i in selected_indices: + compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + + # Store compression info - no actual compression in this mode + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + } + else: + # Skip this function completely + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_text = f"{comment_marker} ... 
" + compressed_chunks.append(omission_text) + compressed_tokens += self.get_token_length(omission_text) + + # Combine compressed chunks + compressed_code = "\n\n".join(compressed_chunks) + + # --- Post-join cleanup for consecutive omission markers --- + logger.debug("Cleaning up consecutive omission markers after joining...") + lines = compressed_code.split("\n") + cleaned_lines = [] + last_non_empty_line_was_omission = False + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_marker_content = f"{comment_marker} ...".strip() # Content to check against + + for line in lines: + stripped_line = line.strip() + if not stripped_line: + # Keep empty lines + cleaned_lines.append(line) + # Don't reset the flag here, wait for a non-empty line + elif stripped_line == omission_marker_content: + if last_non_empty_line_was_omission: + # Skip this consecutive omission marker line + logger.debug(f"Skipping line: '{line}' (consecutive omission)") + continue + else: + # Keep the first omission marker line + cleaned_lines.append(line) + last_non_empty_line_was_omission = True + else: + # Regular code line + cleaned_lines.append(line) + last_non_empty_line_was_omission = False + + compressed_code = "\n".join(cleaned_lines) + logger.debug("Cleanup finished.") + # --- End post-join cleanup --- + + + output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}" + + # Calculate actual compressed tokens + final_compressed_tokens = self.get_token_length(output) + + end_time = time.time() + logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}") + + if rank_only: + return { + "original_code": code, + "compressed_code": compressed_code, + "compressed_prompt": output, + "original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0, + "function_compressions": function_compressions, + "selected_functions": selected_indices, + "demonstrations_sort": demonstrations_sort, + "compressed_chunks": compressed_chunks, + "fine_grained_method_used": None, + } + else: + # enter fine-grained compression + logger.debug(f"Starting fine-grained compression on selected functions using method: {fine_grained_importance_method}") + + # --- Dynamic Fine-grained Rate Allocation --- + logger.debug("Calculating dynamic fine-grained compression rates...") + + # 1. Collect data for selected functions + selected_functions_data = [] + importance_map = {idx: score for idx, score in demonstrations_sort} # Map index to score + total_lines_selected = 0 + for i in selected_indices: + if i < len(code_chunks): + chunk = code_chunks[i] + # Use simple line splitting for allocation efficiency + lines = chunk.split("\n") + line_count = len(lines) + score = importance_map.get(i, 0.0) # Default score 0 if not found + selected_functions_data.append({ + "index": i, + "lines": lines, + "line_count": line_count, + "score": score + }) + total_lines_selected += line_count + else: + logger.warning(f"Selected index {i} is out of bounds for code_chunks (length {len(code_chunks)})") + + + # 2. 
Calculate overall fine-grained target lines + current_fine_ratio = fine_ratio if fine_ratio is not None else rate # Use rate if fine_ratio not set + if original_target_token > 0: # If target_token was set explicitly, derive ratio from it for fine-grained stage + # Estimate target lines based on the ratio of selected tokens to total tokens, then apply fine_ratio + selected_tokens = sum(self.get_token_length(code_chunks[d['index']]) for d in selected_functions_data) + effective_coarse_rate = selected_tokens / total_tokens if total_tokens > 0 else 1.0 + # Use the user-provided fine_ratio, or fall back to rate/coarse target estimate + fine_target_rate = current_fine_ratio + logger.debug(f"Using fine_ratio={fine_target_rate} for fine-grained target calculation.") + target_total_lines = int(total_lines_selected * fine_target_rate) + + else: # Calculate target based on fine_ratio/rate directly applied to selected lines + target_total_lines = int(total_lines_selected * current_fine_ratio) + logger.debug(f"Using current_fine_ratio={current_fine_ratio} for fine-grained target calculation.") + + logger.debug(f"Total lines in selected functions: {total_lines_selected}") + logger.debug(f"Target total lines after fine-grained compression: {target_total_lines}") + + # 3. Separate small and large functions + small_functions = [] + large_functions = [] + lines_in_small_functions = 0 + lines_in_large_functions = 0 + + for data in selected_functions_data: + if data["line_count"] < min_lines_for_fine_grained: + small_functions.append(data) + lines_in_small_functions += data["line_count"] + else: + large_functions.append(data) + lines_in_large_functions += data["line_count"] + + logger.debug(f"Found {len(small_functions)} small functions (< {min_lines_for_fine_grained} lines) with {lines_in_small_functions} total lines (will be kept).") + logger.debug(f"Found {len(large_functions)} large functions (>= {min_lines_for_fine_grained} lines) with {lines_in_large_functions} total lines.") + + # 4. Calculate target lines for large functions + target_lines_for_large = max(0, target_total_lines - lines_in_small_functions) + logger.debug(f"Target lines to keep from large functions: {target_lines_for_large}") + + # 5. Allocate rates for large functions + function_fine_ratios = {} # Map: index -> individual_fine_ratio + + if not large_functions or lines_in_large_functions == 0: + logger.debug("No large functions to compress further or zero lines in large functions.") + global_rate_for_large = 1.0 if lines_in_large_functions > 0 else 0.0 # Should be 0 if lines_in_large_functions is 0 + elif target_lines_for_large <= 0: + logger.debug("Target lines for large functions is <= 0. Setting rates to 0.") + global_rate_for_large = 0.0 + elif target_lines_for_large >= lines_in_large_functions: + logger.debug("Target lines for large functions >= total lines. 
Setting rates to 1.0.") + global_rate_for_large = 1.0 + else: + global_rate_for_large = target_lines_for_large / lines_in_large_functions + logger.debug(f"Global target rate for large functions: {global_rate_for_large:.4f}") + + # Normalize scores for weighting (MinMax scaling) + scores = [d["score"] for d in large_functions] + valid_scores = [s for s in scores if not math.isinf(s) and not math.isnan(s)] + + if not valid_scores or max(valid_scores) == min(valid_scores): + logger.debug("Scores are uniform or invalid, using global rate for all large functions.") + for data in large_functions: + function_fine_ratios[data["index"]] = global_rate_for_large + else: + min_score = min(valid_scores) + max_score = max(valid_scores) + normalized_scores = [(s - min_score) / (max_score - min_score) if not math.isinf(s) and not math.isnan(s) else 0.0 for s in scores] # Normalize to [0, 1], default 0 for invalid + + # Calculate initial biased rates + initial_rates = [] + for norm_score in normalized_scores: + # Bias rate: higher score -> higher rate (closer to 1) + # Beta controls sensitivity. beta=0 -> uniform rate. beta=1 -> max sensitivity. + biased_rate = global_rate_for_large * (1 + importance_beta * (norm_score - 0.5) * 2) # Scale norm_score diff to [-beta, beta] + clamped_rate = max(0.0, min(1.0, biased_rate)) # Clamp to [0, 1] + initial_rates.append(clamped_rate) + + # Calculate actual lines kept with initial rates + actual_lines_kept = sum(initial_rates[i] * large_functions[i]["line_count"] for i in range(len(large_functions))) + logger.debug(f"Initial biased rates calculated. Estimated lines kept: {actual_lines_kept:.1f}") + + # Adjust rates proportionally to meet target + if actual_lines_kept > 0 and abs(actual_lines_kept - target_lines_for_large) > 1: # Adjust if difference is significant + adjustment_factor = target_lines_for_large / actual_lines_kept + logger.debug(f"Adjusting rates by factor: {adjustment_factor:.4f}") + final_rates = [max(0.0, min(1.0, r * adjustment_factor)) for r in initial_rates] # Adjust and clamp again + else: + logger.debug("Initial rates are close enough or actual_lines_kept is 0, no adjustment needed.") + final_rates = initial_rates + + for i, data in enumerate(large_functions): + function_fine_ratios[data["index"]] = final_rates[i] + + # Set rate 1.0 for small functions + for data in small_functions: + function_fine_ratios[data["index"]] = 1.0 + + # --- End Dynamic Allocation --- + + + # Apply fine-grained compression to each selected function + fine_compressed_chunks = [] + compressed_tokens = 0 + function_compressions = {} + + # Define a smoothing window size for moving average + smoothing_window = 5 + # fine_ratio = fine_ratio if fine_ratio is not None else rate # Use the same ratio by default if fine_ratio not specified # Removed, using individual ratios now + + # Process each chunk in the original order + # Use tqdm.auto for compatibility + fine_grained_pbar = tqdm(enumerate(code_chunks), total=len(code_chunks), desc="Fine-Grained Compression", leave=False) + for i, chunk in fine_grained_pbar: + # for i, chunk in enumerate(code_chunks): + if i in selected_indices: + # This function was selected during coarse-grained compression + individual_fine_ratio = function_fine_ratios.get(i) # Get dynamically assigned ratio + if individual_fine_ratio is None: + logger.error(f"Missing fine-grained ratio for selected function index {i}. 
Skipping fine-grained compression for this chunk.") + individual_fine_ratio = 1.0 # Fallback to keep the chunk + + # Use Entropy chunking for fine-grained compression instead of simple line splitting + chunks, sentences, ppls, spike_indices = self.entropy_chunking.chunk_text_adaptive( + code_chunks[i], method='std', k=0.2 + ) + # Use chunks as lines, but preserve all chunks including empty ones to maintain formatting + chunk_lines = chunks # Keep all chunks to preserve \n\n and formatting + chunk_line_count = len([chunk for chunk in chunk_lines if chunk.strip()]) # Count only non-empty for logic + chunk_score = importance_map.get(i, float('nan')) # Get score + + logger.debug(f"Processing Func {i}: Entropy Chunks={len(chunk_lines)}, Non-empty={chunk_line_count}, Score={chunk_score:.4f}, Assigned FineRatio={individual_fine_ratio:.4f}") + + + # Skip fine-grained compression if ratio is 1.0 (or close) or function is small + if individual_fine_ratio >= 0.999 or chunk_line_count < min_lines_for_fine_grained: + note = "Kept (Ratio=1.0)" if individual_fine_ratio >= 0.999 else f"Kept (Small Func < {min_lines_for_fine_grained} lines)" + logger.debug(f" - {note}") + fine_compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + "individual_fine_ratio": individual_fine_ratio, + "note": note, + "importance_method": None # No line importance calculation needed + } + continue # Move to next chunk + + + # Apply fine-grained compression only if the function is large enough + # and we're not in rank-only mode (already checked) and ratio < 1.0 + if chunk_line_count >= min_lines_for_fine_grained and individual_fine_ratio < 0.999: + logger.debug(f" - Applying fine-grained compression with ratio {individual_fine_ratio:.4f}") + fine_grained_pbar.set_description(f"Fine-Grained Compressing Func {i}") + + # Calculate target tokens for this function + original_func_tokens = self.get_token_length(chunk) + target_func_tokens = int(original_func_tokens * individual_fine_ratio) + + # Calculate importance for each block based on the chosen method + block_importances = [] + importance_calculation_start = time.time() + + if fine_grained_importance_method == "conditional_ppl": + # Calculate conditional PPL importance for each block + if not query or not query.strip(): + logger.warning(f"Query is empty for func {i}, cannot calculate conditional PPL. Assigning 0 importance.") + block_importances = [0.0] * len(chunk_lines) + else: + query_ppl_result = self.get_ppl(query, granularity="line") + query_ppl_without_context = query_ppl_result["ppl"] + + if math.isinf(query_ppl_without_context): + logger.warning(f"Base query PPL is infinite for func {i}. 
Assigning 0 importance to blocks.") + block_importances = [0.0] * len(chunk_lines) + else: + pbar_cond = tqdm(enumerate(chunk_lines), total=len(chunk_lines), desc=f"Func {i} Block CondPPL", leave=False) + for block_idx, block in pbar_cond: + if not block.strip(): + block_importances.append(-float('inf')) # Low score for empty blocks + continue + + conditional_text = block + "\n\n" + query + prefix_len_text = block + "\n\n" + prefix_len = self.get_token_length(prefix_len_text, add_special_tokens=True) + + cond_ppl_result = self.get_ppl( + text=conditional_text, + granularity="line", + condition_mode="prefix", + condition_pos_id=prefix_len - 1 + ) + ppl_with_context = cond_ppl_result["ppl"] + + if math.isinf(ppl_with_context): + ppl_change = -float('inf') + else: + ppl_change = query_ppl_without_context - ppl_with_context + + block_importances.append(ppl_change) + pbar_cond.set_description(f"Func {i} Block CondPPL (B{block_idx}: {ppl_change:.2f})") + + elif fine_grained_importance_method == "contrastive_perplexity": + # Calculate contrastive PPL importance for each block + fine_grained_pbar.set_description(f"Fine-Grained ContrastivePPL Func {i}") + + with torch.no_grad(): + pbar = tqdm(enumerate(chunk_lines), total=len(chunk_lines), desc="Block Contrastive PPL", leave=False) + for block_idx, block in pbar: + if not block.strip(): + block_importances.append(-float('inf')) + continue + + # Build context from previous blocks + prev_context = "\n\n".join(chunk_lines[:block_idx]) if block_idx > 0 else "" + + # 1. PPL(Block | prev_blocks) + regular_ppl_condition = prev_context + "\n\n" if prev_context else None + regular_ppl = self._calculate_perplexity_for_contrastive(block, condition_text=regular_ppl_condition) + + # 2. PPL(Block | query, prev_blocks) + question_context_parts = [query] + if prev_context: + question_context_parts.append(prev_context) + question_context = "\n\n".join(filter(None, question_context_parts)) + cond_ppl_condition = question_context + "\n\n" + cond_ppl = self._calculate_perplexity_for_contrastive(block, condition_text=cond_ppl_condition) + + # 3. 
Importance = PPL(Block|prev) - PPL(Block|Q,prev) + if math.isinf(regular_ppl) or math.isinf(cond_ppl): + importance = -float('inf') + else: + importance = regular_ppl - cond_ppl + + block_importances.append(importance) + pbar.set_description(f"Block {block_idx}: {importance:.2f}") + + else: + raise ValueError(f"Unsupported fine_grained_importance_method: {fine_grained_importance_method}") + + importance_calculation_end = time.time() + logger.debug(f" - Block importance calculation took {importance_calculation_end - importance_calculation_start:.2f}s") + + # Identify preserved blocks (function signature, comments, returns) + preserved_block_indices = set() + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + + # Find blocks containing function signature + for block_idx, block in enumerate(chunk_lines): + block_lines = block.split('\n') + for line in block_lines: + if line.strip(): + # Check for function/class definitions + if any(keyword in line for keyword in ['def ', 'class ', 'function ', 'fn ', 'func ']): + preserved_block_indices.add(block_idx) + break + # Check for function-level comments + if line.strip().startswith(comment_marker): + preserved_block_indices.add(block_idx) + break + # Check for return statements + if 'return ' in line: + preserved_block_indices.add(block_idx) + break + break # Only check first non-empty line of each block + + # Choose selection method based on use_knapsack parameter + processing_start = time.time() + + if use_knapsack: + # Use knapsack algorithm to select blocks + logger.debug(f" - Using knapsack algorithm for block selection") + selected_block_indices, selection_info = self._knapsack_block_selection( + blocks=chunk_lines, + block_importances=block_importances, + target_tokens=target_func_tokens, + preserved_block_indices=preserved_block_indices, + language=language + ) + + # Build compressed chunk from selected blocks + compressed_blocks = [] + + # Determine base indentation for omission markers + base_indentation = "" + for block in chunk_lines: + for line in block.split('\n'): + if line.strip(): + match = re.match(r"^(\s*)", line) + if match: + base_indentation = match.group(1) + break + if base_indentation: + break + + omission_marker = f"{base_indentation}{comment_marker} ... 
" + + # Build output with omission markers for gaps + last_selected_idx = -1 + for block_idx in sorted(selected_block_indices): + # Add omission marker if there's a gap + if last_selected_idx != -1 and block_idx > last_selected_idx + 1: + if not compressed_blocks or compressed_blocks[-1] != omission_marker: + compressed_blocks.append(omission_marker) + + compressed_blocks.append(chunk_lines[block_idx]) + last_selected_idx = block_idx + + # Handle trailing omission if needed + if last_selected_idx != -1 and last_selected_idx < len(chunk_lines) - 1: + if not compressed_blocks or compressed_blocks[-1] != omission_marker: + compressed_blocks.append(omission_marker) + + # Join blocks with double newlines to preserve Entropy chunk structure + compressed_chunk = "\n\n".join(compressed_blocks) + + else: + # Use original greedy line-by-line approach with smoothing + logger.debug(f" - Using original greedy line-by-line approach") + + # Convert block importances to line importances for compatibility + lines = [] + line_importances = [] + line_indices = [] + + for block_idx, (block, block_importance) in enumerate(zip(chunk_lines, block_importances)): + block_lines = block.split('\n') + for line_idx_in_block, line in enumerate(block_lines): + global_line_idx = len(lines) + lines.append(line) + line_importances.append(block_importance) # Use block importance for all lines in block + line_indices.append(global_line_idx) + + # Apply original processing logic with smoothing + full_line_scores = [float('nan')] * len(lines) + for score_idx, original_line_idx in enumerate(line_indices): + if score_idx < len(line_importances): + full_line_scores[original_line_idx] = line_importances[score_idx] + + # Replace NaN/Inf with min valid score for consistent processing + valid_scores = [s for s in full_line_scores if not math.isnan(s) and not math.isinf(s)] + if valid_scores: + min_valid_score = min(valid_scores) + if min_valid_score == float('inf') or min_valid_score == -float('inf') or math.isnan(min_valid_score): + min_replacement_score = 0.0 + else: + min_replacement_score = min_valid_score + + processed_line_scores = [] + for s in full_line_scores: + if math.isnan(s) or s == -float('inf'): + processed_line_scores.append(min_replacement_score) + elif s == float('inf'): + processed_line_scores.append(min_replacement_score) + else: + processed_line_scores.append(s) + else: + processed_line_scores = [0.0] * len(lines) + + # Apply smoothing using moving average + smoothing_window = 5 + smoothed_importances = processed_line_scores.copy() + num_processed_scores = len(processed_line_scores) + for j in range(num_processed_scores): + window_start = max(0, j - smoothing_window // 2) + window_end = min(num_processed_scores, j + smoothing_window // 2 + 1) + window = processed_line_scores[window_start:window_end] + valid_window_scores = [s for s in window if not math.isnan(s) and not math.isinf(s)] + if valid_window_scores: + smoothed_importances[j] = sum(valid_window_scores) / len(valid_window_scores) + + # Find preserved lines (convert block indices to line indices) + preserved_line_indices = set() + line_offset = 0 + for block_idx, block in enumerate(chunk_lines): + block_lines = block.split('\n') + if block_idx in preserved_block_indices: + for line_idx_in_block in range(len(block_lines)): + preserved_line_indices.add(line_offset + line_idx_in_block) + line_offset += len(block_lines) + + # Sort remaining lines by importance + sortable_lines = [] + for idx in range(len(lines)): + if idx not in preserved_line_indices: + if 
idx < len(line_indices) and idx < len(line_importances): + original_score = line_importances[idx] + if not math.isnan(original_score) and not math.isinf(original_score): + smoothed_score = smoothed_importances[idx] + sortable_lines.append((idx, smoothed_score)) + + # Sort descending by score + sorted_line_indices = sorted(sortable_lines, key=lambda x: -x[1]) + + # Calculate target number of lines to keep + total_lines = len(lines) + preserved_count = len(preserved_line_indices) + target_lines = max(preserved_count, int(total_lines * individual_fine_ratio)) + + # Select top lines by importance up to target + selected_lines = set(preserved_line_indices) + for idx, score in sorted_line_indices: + if len(selected_lines) >= target_lines: + break + selected_lines.add(idx) + + # Build compressed chunk from selected lines + compressed_chunks = [] + base_indentation = "" + if lines: + for line in lines: + if line.strip(): + match = re.match(r"^(\s*)", line) + if match: + base_indentation = match.group(1) + break + + omission_marker_line = f"{base_indentation}{comment_marker} ... " + + last_added_line_idx = -1 + for j in range(len(lines)): + if j in selected_lines: + if last_added_line_idx != -1 and j > last_added_line_idx + 1: + if not compressed_chunks or compressed_chunks[-1] != omission_marker_line: + compressed_chunks.append(omission_marker_line) + compressed_chunks.append(lines[j]) + last_added_line_idx = j + + if last_added_line_idx != -1 and last_added_line_idx < len(lines) - 1: + if not compressed_chunks or compressed_chunks[-1] != omission_marker_line: + compressed_chunks.append(omission_marker_line) + + compressed_chunk = "\n".join(compressed_chunks) + + # Create selection info for compatibility + selection_info = { + "method": "greedy_line_by_line", + "preserved_lines": len(preserved_line_indices), + "selected_lines": len(selected_lines), + "total_lines": len(lines), + "smoothing_applied": True + } + selected_block_indices = preserved_block_indices # For compatibility + + processing_end = time.time() + method_name = "knapsack" if use_knapsack else "greedy" + logger.debug(f" - {method_name} selection took {processing_end - processing_start:.2f}s") + + if use_knapsack: + logger.debug(f" - Selected {len(selected_block_indices)}/{len(chunk_lines)} blocks") + else: + logger.debug(f" - Selected {len(selected_lines)}/{len(lines)} lines") + + # Update token count and store compression info + fine_compressed_chunks.append(compressed_chunk) + compressed_chunk_tokens = self.get_token_length(compressed_chunk) + compressed_tokens += compressed_chunk_tokens + + # Store compression info + actual_compression_ratio = compressed_chunk_tokens / original_func_tokens if original_func_tokens > 0 else 1.0 + function_compressions[i] = { + "original_tokens": original_func_tokens, + "compressed_tokens": compressed_chunk_tokens, + "compression_ratio": actual_compression_ratio, + "individual_fine_ratio": individual_fine_ratio, + "preserved_blocks": list(preserved_block_indices), + "selected_blocks": list(selected_block_indices), + "selection_info": selection_info, + "importance_method": fine_grained_importance_method, + "selection_method": "knapsack" if use_knapsack else "greedy_line_by_line" + } + logger.debug(f" - Compressed func {i}: {original_func_tokens} -> {compressed_chunk_tokens} tokens (Ratio: {actual_compression_ratio:.3f})") + else: + # This case should now be handled by the check at the beginning of the loop + logger.warning(f"Reached unexpected state for func {i}. 
Keeping chunk as is.") + fine_compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + "individual_fine_ratio": individual_fine_ratio, + "note": "Unexpected state, kept function.", + "importance_method": None + } + + else: + # This function was not selected during coarse-grained compression + # Add a placeholder + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_text = f"{comment_marker} ... " + fine_compressed_chunks.append(omission_text) + compressed_tokens += self.get_token_length(omission_text) + # Log skipped chunk + # logger.debug(f"Skipped Func {i} (not selected in coarse stage)") + + + # Combine fine-grained compressed chunks + compressed_code = "\n\n".join(fine_compressed_chunks) + + # --- Post-join cleanup for consecutive omission markers --- + logger.debug("Cleaning up consecutive omission markers after joining...") + lines = compressed_code.split("\n") + cleaned_lines = [] + last_non_empty_line_was_omission = False + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_marker_content = f"{comment_marker} ...".strip() # Content to check against + + for line in lines: + stripped_line = line.strip() + if not stripped_line: + # Keep empty lines + cleaned_lines.append(line) + # Don't reset the flag here, wait for a non-empty line + elif stripped_line == omission_marker_content: + if last_non_empty_line_was_omission: + # Skip this consecutive omission marker line + logger.debug(f"Skipping line: '{line}' (consecutive omission)") + continue + else: + # Keep the first omission marker line + cleaned_lines.append(line) + last_non_empty_line_was_omission = True + else: + # Regular code line + cleaned_lines.append(line) + last_non_empty_line_was_omission = False + + compressed_code = "\n".join(cleaned_lines) + logger.debug("Cleanup finished.") + # --- End post-join cleanup --- + + + # Ensure instruction/query parts are handled correctly, maybe use a template + prompt_parts = [] + if instruction and instruction.strip(): + prompt_parts.append(instruction.strip()) + if compressed_code.strip(): + prompt_parts.append(compressed_code) # Already has newlines handled + if query and query.strip(): + # Add query, potentially repeating instruction based on original logic + prompt_parts.append(query.strip()) + # Decide if instruction should be repeated after query based on original implementation's needs + # if instruction and instruction.strip(): # Repeat instruction if needed + # prompt_parts.append(instruction.strip()) + + output = "\n\n".join(prompt_parts) # Use double newline separation + + # Calculate final compressed tokens + final_compressed_tokens = self.get_token_length(output) + + end_time = time.time() + logger.debug(f"Fine-grained compression processing completed in {end_time - start_time:.2f} seconds") + final_compression_ratio = compressed_tokens / total_tokens if total_tokens > 0 else 1.0 + logger.debug(f"Final Compression ratio (fine-grained tokens / total original tokens): {final_compression_ratio:.4f}") + + + return { + "original_code": code, + "compressed_code": compressed_code, + "compressed_prompt": output, + "original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": final_compression_ratio, + 
"function_compressions": function_compressions, + "selected_functions": selected_indices, + "demonstrations_sort": demonstrations_sort, + "compressed_chunks": fine_compressed_chunks, + "fine_grained_method_used": fine_grained_importance_method, + } + + def split_code_by_functions(self, code: str, language: str = "python", custom_separator: str = "# --CHUNK_SEPARATOR-- #") -> List[str]: + """ + Split code into chunks based on function and class definitions for various languages. + Also splits on custom separator if provided. + + Args: + code: The code to split + language: Programming language of the code (python, cpp, java, typescript, rust, go) + custom_separator: Optional custom separator string to also split on + + Returns: + List of code chunks, each containing a function, class, or class method + """ + logger.debug(f"Splitting code by functions and classes for language: {language}") + start_time = time.time() + + # Define regex patterns for different languages + patterns = { + # Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end + "python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*', + # C++: Improved to better handle multi-line declarations + "cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Java: Improved for multi-line method declarations + "java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # TypeScript: Enhanced to handle multi-line methods and arrow functions + "typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?', + # Rust: Improved for multi-line function declarations + "rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Go: Improved for multi-line function declarations + "go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + } + + # Use default Python pattern if language not supported + if language.lower() not in patterns: + language = "python" + + # First check if we need to split by custom separator + separator_chunks = [] + if custom_separator and custom_separator in code: + logger.debug(f"Custom separator '{custom_separator}' found, first splitting by separator") + separator_chunks = [chunk for chunk in code.split(custom_separator) if chunk.strip()] + else: + separator_chunks = [code] # Just one chunk - the entire code + + # Function to split a single chunk by functions/classes + def split_chunk_by_pattern(chunk_code): + function_pattern = re.compile(patterns[language.lower()], re.MULTILINE) + matches = list(function_pattern.finditer(chunk_code)) + + if not matches: + return [chunk_code] # No matches, return 
whole chunk + + result_chunks = [] + + # Add code before first match + if matches[0].start() > 0: + result_chunks.append(chunk_code[:matches[0].start()]) + + # Process each match + for i, match in enumerate(matches): + start = match.start() + + # End is either start of next match or end of code + if i < len(matches) - 1: + end = matches[i + 1].start() + else: + end = len(chunk_code) + + result_chunks.append(chunk_code[start:end]) + + return result_chunks + + # Now apply function/class splitting to each separator chunk + final_chunks = [] + for chunk in separator_chunks: + function_chunks = split_chunk_by_pattern(chunk) + final_chunks.extend(function_chunks) + + end_time = time.time() + logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Split code into {len(final_chunks)} chunks (using both separator and patterns)") + + return final_chunks + + def _calculate_perplexity_for_contrastive(self, text, condition_text=None): + """Helper to calculate perplexity of text, optionally conditioned on condition_text""" + if condition_text: + full_text = condition_text + text + inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(self.device) # Use add_special_tokens=True for consistency + + condition_input_ids = self.tokenizer(condition_text, return_tensors="pt", add_special_tokens=True).input_ids + condition_length = condition_input_ids.size(1) + + # Handle potential edge case where condition length might exceed max length or input length + if condition_length >= inputs.input_ids.size(1): + logger.warning(f"Condition length ({condition_length}) >= input length ({inputs.input_ids.size(1)}). Cannot calculate conditional PPL.") + return float('inf') + + with torch.no_grad(): + outputs = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) # Pass attention_mask + + # Logits for the 'text' part, labels are the 'text' part shifted + logits = outputs.logits[0, condition_length-1:-1] + labels = inputs.input_ids[0, condition_length:] + + if logits.size(0) == 0 or labels.size(0) == 0 or logits.size(0) != labels.size(0): + logger.warning(f"Logits/Labels shape mismatch or empty in _calculate_perplexity_for_contrastive (cond). Logits: {logits.shape}, Labels: {labels.shape}. Returning inf.") + return float('inf') # Return inf if shapes mismatch or empty + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + mean_loss = loss.mean().item() + perplexity = math.exp(mean_loss) if not math.isnan(mean_loss) and not math.isinf(mean_loss) else float('inf') + + else: + # Calculate unconditional perplexity + inputs = self.tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device) # Use add_special_tokens=True + with torch.no_grad(): + outputs = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) # Pass attention_mask + + # Logits for all tokens except last, labels are all tokens except first + logits = outputs.logits[0, :-1] + labels = inputs.input_ids[0, 1:] + + if logits.size(0) == 0 or labels.size(0) == 0 or logits.size(0) != labels.size(0): + logger.warning(f"Logits/Labels shape mismatch or empty in _calculate_perplexity_for_contrastive (uncond). Logits: {logits.shape}, Labels: {labels.shape}. 
Returning inf.") + return float('inf') # Return inf if shapes mismatch or empty + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + mean_loss = loss.mean().item() + perplexity = math.exp(mean_loss) if not math.isnan(mean_loss) and not math.isinf(mean_loss) else float('inf') + + return perplexity + + def _calculate_contrastive_perplexity(self, code_lines: List[str], question: str): + """ + Calculate contrastive perplexity-based importance for each line of code. + s_i = perplexity(x_i | x_{ Tuple[set, Dict]: + """ + Use knapsack algorithm to select blocks that maximize total importance within token budget. + + Args: + blocks: List of code blocks (Entropy chunks) + block_importances: Importance scores for each block + target_tokens: Target number of tokens to keep + preserved_block_indices: Set of block indices that must be preserved + language: Programming language for omission markers + + Returns: + Tuple of (selected_block_indices, selection_info) + """ + logger.debug(f"Running knapsack block selection with target_tokens={target_tokens}") + + if not blocks: + return set(), {} + + # Calculate token weights for each block + block_weights = [self.get_token_length(block) for block in blocks] + + # Handle preserved blocks + if preserved_block_indices is None: + preserved_block_indices = set() + + # Calculate tokens already used by preserved blocks + preserved_tokens = sum(block_weights[i] for i in preserved_block_indices) + remaining_budget = max(0, target_tokens - preserved_tokens) + + logger.debug(f"Preserved blocks: {len(preserved_block_indices)}, tokens: {preserved_tokens}") + logger.debug(f"Remaining budget for knapsack: {remaining_budget}") + + # If no remaining budget, just return preserved blocks + if remaining_budget <= 0: + return preserved_block_indices, { + "method": "knapsack", + "preserved_only": True, + "total_value": sum(block_importances[i] for i in preserved_block_indices), + "total_weight": preserved_tokens + } + + # Prepare items for knapsack (excluding preserved blocks) + knapsack_items = [] + for i, (weight, value) in enumerate(zip(block_weights, block_importances)): + if i not in preserved_block_indices: + # Handle invalid importance scores + if math.isnan(value) or math.isinf(value): + value = 0.0 + knapsack_items.append((i, weight, value)) + + # Sort by value-to-weight ratio for efficiency (greedy approximation first) + knapsack_items.sort(key=lambda x: x[2] / max(x[1], 1), reverse=True) + + # Use dynamic programming for exact knapsack solution + # For efficiency, limit to reasonable problem size + if len(knapsack_items) <= 100 and remaining_budget <= 2000: + selected_indices = self._solve_knapsack_dp(knapsack_items, remaining_budget) + else: + # Use greedy approximation for large problems + logger.debug("Using greedy approximation for large knapsack problem") + selected_indices = self._solve_knapsack_greedy(knapsack_items, remaining_budget) + + # Combine with preserved blocks + final_selection = preserved_block_indices.union(selected_indices) + + # Calculate selection statistics + total_value = sum(block_importances[i] for i in final_selection) + total_weight = sum(block_weights[i] for i in final_selection) + + selection_info = { + "method": "knapsack", + "preserved_blocks": len(preserved_block_indices), + "selected_blocks": len(selected_indices), + "total_blocks": len(final_selection), + "total_value": total_value, + "total_weight": total_weight, + "target_weight": target_tokens, + 
"efficiency": total_value / max(total_weight, 1) + } + + logger.debug(f"Knapsack selection: {len(final_selection)}/{len(blocks)} blocks, " + f"value={total_value:.2f}, weight={total_weight}/{target_tokens}") + + return final_selection, selection_info + + def _solve_knapsack_dp(self, items: List[Tuple[int, int, float]], capacity: int) -> set: + """ + Solve knapsack problem using dynamic programming. + + Args: + items: List of (index, weight, value) tuples + capacity: Maximum weight capacity + + Returns: + Set of selected item indices + """ + n = len(items) + if n == 0 or capacity <= 0: + return set() + + # DP table: dp[i][w] = maximum value using first i items with weight limit w + dp = [[0.0 for _ in range(capacity + 1)] for _ in range(n + 1)] + + # Fill DP table + for i in range(1, n + 1): + idx, weight, value = items[i - 1] + for w in range(capacity + 1): + # Don't take item i + dp[i][w] = dp[i - 1][w] + + # Take item i if it fits + if weight <= w: + dp[i][w] = max(dp[i][w], dp[i - 1][w - weight] + value) + + # Backtrack to find selected items + selected = set() + w = capacity + for i in range(n, 0, -1): + if dp[i][w] != dp[i - 1][w]: + idx, weight, value = items[i - 1] + selected.add(idx) + w -= weight + + return selected + + def _solve_knapsack_greedy(self, items: List[Tuple[int, int, float]], capacity: int) -> set: + """ + Solve knapsack problem using greedy approximation (by value/weight ratio). + + Args: + items: List of (index, weight, value) tuples (should be pre-sorted by ratio) + capacity: Maximum weight capacity + + Returns: + Set of selected item indices + """ + selected = set() + current_weight = 0 + + for idx, weight, value in items: + if current_weight + weight <= capacity: + selected.add(idx) + current_weight += weight + + return selected + +if __name__ == "__main__": + # Load real examples from the dataset + # with open("exp-cur50lines-bg5000tokens/results/deepseek-coder-6.7b-instruct/method_code_compressor_t2048_rankonly/deepseek-ai_slash_deepseek-coder-6.7b-instruct.jsonl", "r") as f: + with open("exp-cur50lines-bg5000tokens-500examples/results/mistral-7b-instruct/method_code_compressor_t512_rankonly/mistralai_slash_Mistral-7B-Instruct-v0.3.jsonl", "r") as f: + data = [json.loads(line) for line in f] + + example = data[190] + # print(example.keys()) # dict_keys(['id', 'gt', 'original_background_context', 'original_current_function_context', 'language', 'prompt', 'output', 'es', 'em']) + + context = example["original_background_context"] + question = example["original_current_function_context"] + ground_truth = example["gt"] + + # Initialize compressor + logger.info("Initializing compressor...") + model_name = "Qwen/Qwen2.5-Coder-7B-Instruct" + compressor = CodeCompressor(model_name=model_name) + + # Test function-based code file compression with query + logger.info("\nTesting function-based code file compression with query...") + + original_tokens = len(compressor.tokenizer.encode(context)) + target_token = 512 + target_ratio = min(1.0, max(0.0, target_token / original_tokens)) + logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}") + + result = compressor.compress_code_file( + code=context, + query=question, # Using current function context as query focus + instruction="Complete the following code function given the context.", + rate=target_ratio, + rank_only=False, # Test fine-grained compression + fine_grained_importance_method="contrastive_perplexity", # Explicitly test default + 
min_lines_for_fine_grained=5, # New parameter + importance_beta=0.5, # Sensitivity to importance score + use_knapsack=True, + ) + + # show the compressed code + logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}") + logger.info(f"Current function context: \n{question}") + # final prompt + final_prompt = result['compressed_prompt'] + # get the completion + try: + tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device) + # Increase max_new_tokens for potentially longer completions + completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id) + # Decode only the generated part, skipping special tokens + completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True) + + # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed + completion = completion.strip() + # More robust cleanup: Find the first meaningful line if generation includes noise + completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal + cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found + + except Exception as e: + logger.error(f"Error during generation or decoding: {e}") + cleaned_completion = "[ERROR DURING GENERATION]" + + logger.info(f"Cleaned Completion: {cleaned_completion}") + logger.info(f"Ground truth: {ground_truth}") + + # Optional: Test with conditional_ppl method + logger.info("\nTesting fine-grained compression with conditional_ppl...") + result_cond = compressor.compress_code_file( + code=context, + query=question, + instruction="Complete the following code function given the context.", + rate=target_ratio, + rank_only=False, + fine_grained_importance_method="conditional_ppl", + min_lines_for_fine_grained=5, + importance_beta=0.5 + ) + logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}") \ No newline at end of file diff --git a/long-code-completion/compare_empty_line_handling.py b/long-code-completion/compare_empty_line_handling.py new file mode 100644 index 0000000..d2290af --- /dev/null +++ b/long-code-completion/compare_empty_line_handling.py @@ -0,0 +1,190 @@ +import torch +import math +from typing import List +from transformers import AutoModelForCausalLM, AutoTokenizer + +def compare_empty_line_handling(): + """Compare original vs corrected empty line handling in PPL chunking""" + + code_to_be_analyzed = """def evaluate_blind(self, code, **kwargs): + + suffix = kwargs.get('suffix', self.get('suffix', '')) + blind = kwargs.get('blind', False) + + action = self.actions.get('evaluate_blind', {}) + payload_action = action.get('evaluate_blind') + call_name = action.get('call', 'inject') + + # Skip if something is missing or call function is not set + if not action or not payload_action or not call_name or not hasattr(self, call_name): + return + + expected_delay = self._get_expected_delay() + + if '%(code_b64)s' in payload_action: + log.debug('[b64 encoding] %s' % code) + execution_code = payload_action % ({ + 'code_b64' : base64.urlsafe_b64encode(code), + 'delay' : expected_delay + }) + else: + execution_code = payload_action % ({ + 'code' : code, + 'delay' : expected_delay + }) + + return getattr(self, call_name)( + code = 
execution_code, + prefix = prefix, + suffix = suffix, + blind=True + )""" + + print("="*80) + print("COMPARISON: Empty Line Handling in PPL Chunking") + print("="*80) + + lines = code_to_be_analyzed.split('\n') + + # Simulate original approach (includes empty lines in smoothing) + def original_smoothing(values, window_size=3): + """Original smoothing that includes empty lines""" + smoothed = [] + for i in range(len(values)): + start_idx = max(0, i - window_size // 2) + end_idx = min(len(values), i + window_size // 2 + 1) + + window_values = [] + for j in range(start_idx, end_idx): + if not math.isinf(values[j]) and not math.isnan(values[j]): + window_values.append(values[j]) + + if window_values: + smoothed.append(sum(window_values) / len(window_values)) + else: + smoothed.append(values[i]) + + return smoothed + + # Simulate corrected approach (excludes empty lines from smoothing) + def corrected_smoothing(values, lines, window_size=3): + """Corrected smoothing that excludes empty lines""" + smoothed = [] + + # Identify non-empty line indices + non_empty_indices = [i for i, line in enumerate(lines) if line.strip() != ''] + + for i in range(len(values)): + if lines[i].strip() == '': # Empty line + smoothed.append(values[i]) # Keep original value + else: + # Find position in non-empty indices + try: + pos_in_non_empty = non_empty_indices.index(i) + except ValueError: + smoothed.append(values[i]) + continue + + # Get window around this position in non-empty lines + start_pos = max(0, pos_in_non_empty - window_size // 2) + end_pos = min(len(non_empty_indices), pos_in_non_empty + window_size // 2 + 1) + + # Get values from non-empty lines in the window + window_values = [] + for j in range(start_pos, end_pos): + idx = non_empty_indices[j] + val = values[idx] + if not math.isinf(val) and not math.isnan(val) and val > 0: + window_values.append(val) + + if window_values: + smoothed.append(sum(window_values) / len(window_values)) + else: + smoothed.append(values[i]) + + return smoothed + + # Create sample PPL values (simulated) + sample_ppls = [] + for i, line in enumerate(lines): + if line.strip() == '': + sample_ppls.append(1.0) # Empty line PPL + else: + # Simulate varying PPL values + if 'def ' in line: + sample_ppls.append(101.65) + elif 'return' in line and len(line.strip()) < 20: + sample_ppls.append(1.50) + elif line.strip().startswith('#'): + sample_ppls.append(17.72) + elif 'kwargs.get' in line: + sample_ppls.append(8.39) + elif 'action' in line: + sample_ppls.append(8.17) + elif 'if ' in line: + sample_ppls.append(12.41) + elif 'else:' in line: + sample_ppls.append(1.36) + elif line.strip().startswith("'"): + sample_ppls.append(2.52) + else: + sample_ppls.append(5.0 + (i % 10)) # Varying values + + # Apply both smoothing approaches + original_smoothed = original_smoothing(sample_ppls, window_size=3) + corrected_smoothed = corrected_smoothing(sample_ppls, lines, window_size=3) + + print(f"Total lines: {len(lines)}") + print(f"Empty lines: {len([l for l in lines if l.strip() == ''])}") + print(f"Non-empty lines: {len([l for l in lines if l.strip() != ''])}") + print() + + print("Line-by-line comparison:") + print(f"{'Line':>4} {'Empty':>5} {'Original':>10} {'Corrected':>10} {'Difference':>10} {'Content'}") + print("-" * 80) + + for i, (line, orig_ppl, orig_smooth, corr_smooth) in enumerate(zip(lines, sample_ppls, original_smoothed, corrected_smoothed)): + is_empty = line.strip() == '' + diff = abs(orig_smooth - corr_smooth) + content = repr(line[:40] + "..." 
if len(line) > 40 else line) + + print(f"{i:4d} {'Yes' if is_empty else 'No':>5} {orig_smooth:10.4f} {corr_smooth:10.4f} {diff:10.4f} {content}") + + print("\n" + "="*80) + print("KEY DIFFERENCES:") + print("="*80) + + # Find lines where smoothing differs significantly + significant_diffs = [] + for i, (orig_smooth, corr_smooth) in enumerate(zip(original_smoothed, corrected_smoothed)): + diff = abs(orig_smooth - corr_smooth) + if diff > 0.1 and lines[i].strip() != '': # Non-empty lines with significant difference + significant_diffs.append((i, diff, orig_smooth, corr_smooth)) + + print(f"\nLines with significant smoothing differences (> 0.1):") + for line_idx, diff, orig, corr in significant_diffs: + print(f"Line {line_idx}: Original={orig:.4f}, Corrected={corr:.4f}, Diff={diff:.4f}") + print(f" Content: {repr(lines[line_idx])}") + + # Show impact on empty lines + empty_line_indices = [i for i, line in enumerate(lines) if line.strip() == ''] + print(f"\nEmpty line smoothing values:") + for idx in empty_line_indices: + print(f"Line {idx}: Original={original_smoothed[idx]:.4f}, Corrected={corrected_smoothed[idx]:.4f}") + + print("\n" + "="*80) + print("SUMMARY:") + print("="*80) + print("Original approach:") + print("- Includes empty lines (PPL=1.0) in smoothing windows") + print("- Can artificially lower smoothed PPL values near empty lines") + print("- May create false local minimums") + + print("\nCorrected approach:") + print("- Excludes empty lines from smoothing calculations") + print("- Only considers non-empty lines for smoothing windows") + print("- Preserves original line indices for visualization") + print("- More accurate representation of code complexity patterns") + +if __name__ == "__main__": + compare_empty_line_handling() \ No newline at end of file diff --git a/long-code-completion/main.py b/long-code-completion/main.py new file mode 100644 index 0000000..df911f4 --- /dev/null +++ b/long-code-completion/main.py @@ -0,0 +1,750 @@ +import os +import json +from tqdm import tqdm +import torch +from transformers import AutoTokenizer, AutoModel +from llmlingua import PromptCompressor +import fire +from utils import load_data, compute_EM, compute_ES +from vllm import LLM, SamplingParams +from loguru import logger +from code_compressor import CodeCompressor +import gc +from typing import List +import re + + +# Helper function for splitting code by functions (standalone version) +def split_code_by_functions_standalone(code: str, language: str = "python") -> List[str]: + """ + Split code into chunks based on function and class definitions for various languages. + Standalone version that doesn't require CodeCompressor instance. 
+ + Args: + code: The code to split + language: Programming language of the code (python, cpp, java, typescript, rust, go) + + Returns: + List of code chunks, each containing a function, class, or class method + """ + # Define regex patterns for different languages + patterns = { + # Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end + "python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*', + # C++: Improved to better handle multi-line declarations + "cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Java: Improved for multi-line method declarations + "java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # TypeScript: Enhanced to handle multi-line methods and arrow functions + "typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?', + # Rust: Improved for multi-line function declarations + "rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Go: Improved for multi-line function declarations + "go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + } + + # Use default Python pattern if language not supported + if language.lower() not in patterns: + language = "python" + + function_pattern = re.compile(patterns[language.lower()], re.MULTILINE) + matches = list(function_pattern.finditer(code)) + + if not matches: + return [code] if code.strip() else [] # No matches, return whole code if not empty + + result_chunks = [] + + # Add code before first match if exists + if matches[0].start() > 0: + pre_code = code[:matches[0].start()].strip() + if pre_code: + result_chunks.append(pre_code) + + # Process each match + for i, match in enumerate(matches): + start = match.start() + + # End is either start of next match or end of code + if i < len(matches) - 1: + end = matches[i + 1].start() + else: + end = len(code) + + chunk = code[start:end].strip() + if chunk: + result_chunks.append(chunk) + + return result_chunks + + +# Helper function for function-level RAG retrieval +def function_rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, language: str, top_k: int) -> str: + """Uses function-level chunking and retrieves top_k similar functions.""" + if not background_code.strip(): + return "" # Return empty if no background context + + # Split code into function-based chunks + chunks = split_code_by_functions_standalone(background_code, language) + if not chunks: + return "" # Return empty if chunking results in nothing + + query_embedding = compute_embedding(query_code, model, tokenizer, device) + + chunk_embeddings = [] + 
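# A minimal sketch of the similarity ranking performed below (illustrative only;
# the embedding size of 768 is an assumption matching typical encoder models,
# not taken from this repository):
#     q = torch.randn(768)                                        # query embedding
#     C = torch.randn(10, 768)                                    # one row per candidate chunk
#     sims = torch.cosine_similarity(q.unsqueeze(0), C, dim=1)    # cosine scores, shape (10,)
#     best = torch.topk(sims, k=min(3, C.size(0))).indices        # indices of the top-k chunks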
valid_chunks = [] + for chunk in chunks: + if chunk.strip(): + chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device)) + valid_chunks.append(chunk) + + if not valid_chunks: + return "" + + # Stack embeddings for efficient similarity calculation + chunk_embeddings_tensor = torch.stack(chunk_embeddings) + + # Compute cosine similarity + similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1) + + # Get top_k indices + top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices + + # Retrieve relevant chunks + retrieved_chunks = [valid_chunks[i] for i in top_k_indices.tolist()] + + # Combine relevant chunks (maintain order by similarity score) + combined_code = "\n\n".join(retrieved_chunks) + + return combined_code + + +# Helper function for sliding window chunking +def chunk_sliding_window(code: str, window_size: int, overlap: int) -> list[str]: + """Splits code into overlapping chunks using a sliding window.""" + lines = code.splitlines() + if not lines: + return [] + + chunks = [] + start = 0 + stride = window_size - overlap + if stride <= 0: + raise ValueError("Overlap size must be smaller than window size.") + + while True: + end = min(start + window_size, len(lines)) + chunk_lines = lines[start:end] + if not chunk_lines: # Should not happen if lines is not empty, but safety check + break + chunks.append("\n".join(chunk_lines)) + if end == len(lines): + break # Exit loop if we reached the end + next_start = start + stride + # If the next window would go past the end, break + if next_start >= len(lines): + # Add the final overlapping chunk if needed + final_start = max(0, len(lines) - window_size) + if final_start > start: # Ensure it's a new chunk not already added + final_chunk_lines = lines[final_start:] + chunks.append("\n".join(final_chunk_lines)) + break + start = next_start + + # Handle case where code is shorter than window size + if not chunks and lines: + return ["\n".join(lines)] + + # Remove duplicates while preserving order (important for RAG) + seen = set() + unique_chunks = [] + for chunk in chunks: + if chunk not in seen: + seen.add(chunk) + unique_chunks.append(chunk) + + return unique_chunks + + +# Helper function to compute embeddings (using mean pooling) +def compute_embedding(text: str, model, tokenizer, device) -> torch.Tensor: + """Computes sentence embedding for a text using the provided model.""" + if not text.strip(): # Handle empty strings + return torch.zeros(model.config.hidden_size).to(device) + inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device) + with torch.no_grad(): + outputs = model(**inputs) + # Mean pool the last hidden state + embedding = outputs.last_hidden_state.mean(dim=1).squeeze() + return embedding + +# Helper function for RAG retrieval + + +def rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, window_size: int, overlap: int, top_k: int) -> str: + """Chunks background, embeds chunks and query, retrieves top_k similar chunks.""" + if not background_code.strip(): + return "" # Return empty if no background context + + chunks = chunk_sliding_window(background_code, window_size, overlap) + if not chunks: + return "" # Return empty if chunking results in nothing + + query_embedding = compute_embedding(query_code, model, tokenizer, device) + + chunk_embeddings = [] + valid_chunks = [] + for chunk in chunks: + if chunk.strip(): + chunk_embeddings.append(compute_embedding(chunk, 
model, tokenizer, device)) + valid_chunks.append(chunk) + + if not valid_chunks: + return "" + + # Stack embeddings for efficient similarity calculation + chunk_embeddings_tensor = torch.stack(chunk_embeddings) + + # Compute cosine similarity + similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1) + + # Get top_k indices + top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices + + # Retrieve and sort chunks by their original position + relevant_chunks_with_indices = [] + original_indices_map = {chunk_content: idx for idx, chunk_content in enumerate(chunks)} # Map content back to original index + + retrieved_chunk_contents = [valid_chunks[i] for i in top_k_indices.tolist()] + + # Find original start lines to sort chronologically (approximate) + chunk_start_lines = {} + current_line = 0 + lines = background_code.splitlines() + chunk_map_from_sliding = chunk_sliding_window(background_code, window_size, overlap) # Re-chunk to get consistent indexing if needed + start_line_num = 0 + stride = window_size - overlap + for i, chunk_content in enumerate(chunk_map_from_sliding): + # This assumes the chunking function returns chunks in order + chunk_start_lines[chunk_content] = start_line_num + start_line_num += stride + # Rough approximation, doesn't perfectly handle edge cases/final chunks + + sorted_relevant_chunks = sorted( + retrieved_chunk_contents, + key=lambda content: chunk_start_lines.get(content, float('inf')) # Sort by approximate start line + ) + + # Combine relevant chunks + # Original implementation joined with \n, let's keep it simple + combined_code = "\n\n".join(sorted_relevant_chunks) # Separate chunks by double newline for clarity + + return combined_code + + +# Helper function for LLMLingua compression +def compress_llmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str) -> str: + """Compresses context using LLMLingua.""" + if not context.strip(): + return "" + try: + # Ensure no "<|endoftext|>" + context_clean = context.replace("<|endoftext|>", "") + compressed = compressor.compress_prompt( + context_clean, + instruction=instruction, + question=query + "\n" + instruction, # Combine query and instruction for question + target_token=target_token + ) + # Ensure result exists and is string + result = compressed.get('compressed_prompt', '') + return result if isinstance(result, str) else "" + except Exception as e: + logger.error(f"LLMLingua compression failed: {e}") + # Fallback: Truncate based on target tokens (approximate) + tokens = compressor.tokenizer.encode(context_clean) + if len(tokens) > target_token: + return compressor.tokenizer.decode(tokens[:target_token]) + return context_clean + + +# Helper function for LongLLMLingua compression +def compress_longllmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str, chunk_size: int, overlap: int) -> str: + """Compresses context using LongLLMLingua with sliding window chunks.""" + if not context.strip(): + return "" + try: + # Ensure no "<|endoftext|>" + context_clean = context.replace("<|endoftext|>", "") + # Use our sliding window chunker + chunks = chunk_sliding_window(context_clean, chunk_size, overlap) + if not chunks: + return "" # Handle case where context is too short or chunking fails + + compressed = compressor.compress_prompt( + chunks, + instruction=instruction, + question=query + "\n" + instruction, # Combine query and instruction for question + 
target_token=target_token, + rank_method="longllmlingua" # Use the specified rank method + ) + # Ensure result exists and is string + result = compressed.get('compressed_prompt', '') + return result if isinstance(result, str) else "" + except Exception as e: + logger.error(f"LongLLMLingua compression failed: {e}") + # Fallback: Truncate based on target tokens (approximate) + tokens = compressor.tokenizer.encode(context_clean) + if len(tokens) > target_token: + return compressor.tokenizer.decode(tokens[:target_token]) + return context_clean + +# Helper function for CodeCompressor (Rank Only or Fine-grained) + + +def compress_code_compressor(context: str, query: str, compressor: CodeCompressor, target_token: int, instruction: str, language: str, rank_only: bool, fine_ratio: float, importance_beta: float) -> str: + """Compresses context using CodeCompressor based on target tokens and rank_only flag.""" + if not context.strip(): + return "" + try: + # Ensure no "<|endoftext|>" + context_clean = context.replace("<|endoftext|>", "") + if not context_clean.strip(): + return "" # Return empty if clean context is empty + + # Tokenize to get original length + # Use the compressor's tokenizer + original_tokens = len(compressor.tokenizer.encode(context_clean)) + if original_tokens == 0: + return "" # Avoid division by zero + + # Calculate target ratio + target_ratio = min(1.0, max(0.0, target_token / original_tokens)) + logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}") + + # Pass rank_only and fine_ratio + # Assuming compressor is already initialized with the correct model + compressed_result = compressor.compress_code_file( + code=context_clean, + query=query, # Using current function context as query focus + instruction=instruction, + rate=target_ratio, + language=language, + rank_only=rank_only, # Ensure rank_only mode is set + fine_ratio=fine_ratio if not rank_only else None, # Pass fine_ratio only if not rank_only + importance_beta=importance_beta if not rank_only else None, # Pass importance_beta only if not rank_only + ) + + # Extract compressed content - check both possible keys + compressed_context = compressed_result.get("compressed_code") + + if not isinstance(compressed_context, str): + logger.error(f"CodeCompressor returned non-string: {type(compressed_context)}") + compressed_context = "" # Fallback + + # Log results + compressed_tokens_count = len(compressor.tokenizer.encode(compressed_context)) + final_ratio = (compressed_tokens_count / original_tokens) if original_tokens > 0 else 0 + logger.info(f"CodeCompressor: Compressed tokens={compressed_tokens_count}, Actual ratio={final_ratio:.4f}") + + return compressed_context + + except Exception as e: + logger.error(f"CodeCompressor compression failed: {e}", exc_info=True) + # Fallback: Truncate approximately based on target tokens (less ideal for rank_only) + tokens = compressor.tokenizer.encode(context_clean) + if len(tokens) > target_token: + logger.warning(f"CodeCompressor falling back to simple truncation.") + return compressor.tokenizer.decode(tokens[:target_token]) + return context_clean + +# Function to save scores + + +def save_json(data: dict, file_path: str): + """Saves dictionary data to a JSON file.""" + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w') as f: + json.dump(data, f, indent=4) + + +def generate_completions(llm, batch_prompts, max_new_tokens=128): + # Generate completions for batch + sampling_params = 
SamplingParams( + temperature=0, + top_p=0.95, + max_tokens=max_new_tokens + ) + + batch_outputs = llm.generate( + batch_prompts, + sampling_params, + use_tqdm=False + ) + + return [x.outputs[0].text for x in batch_outputs] + + +def evaluate_completion( + model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct", + method: str = "full", + result_dir: str = "results/completion_baselines", + embed_model_name: str = "microsoft/unixcoder-base", + compression_model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct", + dataset_path: str = "microsoft/LCC_python", + dataset_split: str = "test", + num_examples: int = 200, + max_new_tokens: int = 128, + batch_size: int = 16, + # RAG params + rag_window_size: int = 80, + rag_overlap: int = 40, + rag_top_k: int = 3, + # Function RAG params + function_rag_language: str = "python", + function_rag_top_k: int = 3, + # LLMLingua params + lingua_target_token: int = 500, + lingua_instruction: str = "Complete the following code function given the context.", + # LongLLMLingua params + longlingua_chunk_size: int = 80, + longlingua_overlap: int = 40, + # CodeCompressor params (New) + code_compressor_target_token: int = 500, + # vLLM params + tensor_parallel_size: int = 1, + trust_remote_code: bool = True, + gpu_memory_utilization: float = 0.9, + filter_current_lines_max: int = 50, + filter_background_tokens_min: int = 3000, + # New CodeCompressor fine-grained param + code_compressor_fine_ratio: float = 1.0, # Default 1.0 means rank_only=True + # New CodeCompressor importance beta param + importance_beta: float = 0.0, # Default beta is 0.0 +): + """Evaluates code completion baselines with a specified context preparation method.""" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # --- 1. Load Data --- + # Assuming python for now, might need modification if dataset has multiple languages + # Note: Language info might be needed for CodeCompressor if not always python + dataset, _ = load_data(path=dataset_path, split=dataset_split, num_examples=num_examples, + filter_current_lines_max=filter_current_lines_max, filter_background_tokens_min=filter_background_tokens_min) + logger.info(f"Loaded {len(dataset)} examples from {dataset_path} ({dataset_split} split)") + + # --- 2. 
Initialize Models --- + embed_model = None + embed_tokenizer = None + if method == "rag" or method == "function_rag": + logger.info(f"Initializing embedding model: {embed_model_name}") + embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name) + embed_model = AutoModel.from_pretrained(embed_model_name).to(device) + embed_model.eval() # Set to evaluation mode + logger.info(f"Embedding model {embed_model_name} initialized.") + + lingua_compressor = None + if method == "llmlingua" or method == "longllmlingua": + logger.info(f"Initializing LLMLingua compressor: {compression_model_name}") + lingua_compressor = PromptCompressor(model_name=compression_model_name, device_map="auto") + logger.info(f"LLMLingua compressor {compression_model_name} initialized.") + + code_compressor_instance = None # Renamed to avoid conflict + if method == "code_compressor": + logger.info(f"Initializing CodeCompressor: {compression_model_name}") + # Assuming CodeCompressor takes model name and potentially device + # Pass device explicitly if needed by your CodeCompressor implementation + code_compressor_instance = CodeCompressor(compression_model_name) + logger.info(f"CodeCompressor {compression_model_name} initialized.") + + if method in ["full", "no_context"]: + tokenizer = AutoTokenizer.from_pretrained(model_name) + # try to compress a dummy prompt to avoid cuda error when initializing the vllm (strange bug) + code_compressor_instance = PromptCompressor(model_name=compression_model_name, device_map="auto") + logger.info(f"CodeCompressor {compression_model_name} initialized.") + dummy_prompt = "def hello_world():\n print('Hello, World!')"*100 + compressed_prompt = code_compressor_instance.compress_prompt(dummy_prompt, instruction="Complete the following code function given the context.", question="Complete the following code function given the context.", target_token=500) + logger.info(f"Compressed prompt: {compressed_prompt}") + + # --- 3. 
Process the Specified Method --- + logger.info(f"--- Processing Method: {method} ---") + + # Modify result directory based on method and parameters + method_suffix = f"method_{method}" + if method == "rag": + method_suffix += f"_w{rag_window_size}_o{rag_overlap}_k{rag_top_k}" + elif method == "function_rag": + method_suffix += f"_lang{function_rag_language}_k{function_rag_top_k}" + elif method == "llmlingua": + method_suffix += f"_t{lingua_target_token}" + elif method == "longllmlingua": + method_suffix += f"_t{lingua_target_token}_cs{longlingua_chunk_size}_o{longlingua_overlap}" + elif method == "code_compressor": + # Determine if rank_only based on fine_ratio + rank_only_for_suffix = (code_compressor_fine_ratio == 1.0) + suffix_detail = "_rankonly" if rank_only_for_suffix else f"fr{code_compressor_fine_ratio}" + # Add importance_beta to suffix + if importance_beta > 0: + suffix_detail += f"_b{importance_beta}" + # Use code_compressor_target_token for consistency + method_suffix += f"_t{code_compressor_target_token}{suffix_detail}" # Updated suffix + + method_result_dir = os.path.join(result_dir, method_suffix) + os.makedirs(method_result_dir, exist_ok=True) + + model_output_path = os.path.join( + method_result_dir, + f"{model_name.replace('/', '_slash_')}.jsonl", + ) + score_output_path = os.path.join( + method_result_dir, + f"{model_name.replace('/', '_slash_')}-SCORES.json", + ) + + all_prompts = [] + original_data = [] # Store original data to merge with results + + # Prepare prompts based on method + for i, example in enumerate(tqdm(dataset, desc=f"Preparing prompts for {method}")): + background_ctx = example['background_context'] + current_func_ctx = example['current_function_context'] # This is the prefix + ground_truth = example['gt'] # This is the completion target + # Determine language - assuming python for now based on dataset path + language = "python" # IMPORTANT: Make dynamic if dataset contains multiple languages + + context_for_prompt = "" + try: + if method == "full": + context_for_prompt = background_ctx + "\n\n" + current_func_ctx + + # some models have max context length of 32768, so we truncate the context (from the head) if it exceeds that + tokenized_context = tokenizer.encode(context_for_prompt) + if len(tokenized_context) > 32768-256: + logger.warning(f"Context length exceeds 32768, truncating from the head. 
Original length: {len(tokenized_context)}, Truncated length: 32768") + context_for_prompt = tokenizer.decode(tokenized_context[-(32768-256):]) + elif method == "rag": + if not embed_model or not embed_tokenizer: + raise ValueError("RAG method selected but embedding model not initialized.") + retrieved_ctx = rag_retrieve( + background_ctx, current_func_ctx, + embed_model, embed_tokenizer, device, + rag_window_size, rag_overlap, rag_top_k + ) + context_for_prompt = retrieved_ctx + "\n\n" + current_func_ctx + elif method == "function_rag": + if not embed_model or not embed_tokenizer: + raise ValueError("Function RAG method selected but embedding model not initialized.") + retrieved_ctx = function_rag_retrieve( + background_ctx, current_func_ctx, + embed_model, embed_tokenizer, device, + function_rag_language, function_rag_top_k + ) + context_for_prompt = retrieved_ctx + "\n\n" + current_func_ctx + elif method == "llmlingua": + if not lingua_compressor: + raise ValueError("LLMLingua method selected but compressor not initialized.") + compressed_ctx = compress_llmlingua( + background_ctx, current_func_ctx, + lingua_compressor, lingua_target_token, lingua_instruction + ) + context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx + elif method == "longllmlingua": + if not lingua_compressor: + raise ValueError("LongLLMLingua method selected but compressor not initialized.") + compressed_ctx = compress_longllmlingua( + background_ctx, current_func_ctx, + lingua_compressor, lingua_target_token, lingua_instruction, + longlingua_chunk_size, longlingua_overlap + ) + context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx + elif method == "code_compressor": + if not code_compressor_instance: + raise ValueError("CodeCompressor method selected but compressor not initialized.") + # Determine rank_only based on fine_ratio + rank_only = (code_compressor_fine_ratio == 1.0) + logger.info(f"CodeCompressor mode: {'Rank Only' if rank_only else f'Fine-grained (ratio={code_compressor_fine_ratio})'}") + # Use current_func_ctx as the query for CodeCompressor to focus retrieval + compressed_ctx = compress_code_compressor( + context=background_ctx, + query=current_func_ctx, # Query is the current function prefix + compressor=code_compressor_instance, + target_token=code_compressor_target_token, + instruction=lingua_instruction, # Reusing lingua instruction + language=language, + rank_only=rank_only, # Pass determined rank_only flag + fine_ratio=code_compressor_fine_ratio, # Pass fine_ratio + importance_beta=importance_beta, # Pass importance_beta + ) + # Combine the compressed background context with the original current function context + context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx + elif method == "no_context": + context_for_prompt = current_func_ctx + else: + raise ValueError(f"Unknown method: {method}") + + prompt = context_for_prompt.strip() + all_prompts.append(prompt) + original_data.append({ + "id": example.get("id", i), + "gt": ground_truth, + "original_background_context": background_ctx, + "original_current_function_context": current_func_ctx, + "language": language # Store language if needed later + }) + except Exception as e: + logger.warning(f"Error processing example {i} (ID: {example.get('id', 'N/A')}) for method {method}: {e}", exc_info=True) + continue # Skip this example + + # --- 4. 
Clean up Compression/Embedding Models --- + logger.info("Freeing up GPU memory from compression/embedding models") + if embed_model: + del embed_model + if embed_tokenizer: + del embed_tokenizer + if lingua_compressor: + del lingua_compressor + if code_compressor_instance: + del code_compressor_instance # Clean up CodeCompressor + torch.cuda.empty_cache() + gc.collect() + logger.info("GPU memory freed") + + # --- 5. Initialize Generation LLM --- + # Check if there are any prompts to process before initializing LLM + if not all_prompts: + logger.error(f"No valid prompts were prepared for method {method}. Skipping generation and scoring.") + return + + logger.info(f"Initializing generation LLM: {model_name}") + llm = LLM( + model=model_name, + trust_remote_code=trust_remote_code, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tensor_parallel_size, + max_model_len=32768 + ) + logger.info(f"Generation LLM {model_name} initialized.") + + # --- 6. Generate Completions --- + all_outputs = [] + logger.info(f"Generating completions for {len(all_prompts)} prompts...") + for i in tqdm(range(0, len(all_prompts), batch_size), desc=f"Generating completions for {method}"): + batch_prompts = all_prompts[i:i + batch_size] + if not batch_prompts: + continue + + try: + batch_outputs = generate_completions(llm, batch_prompts, max_new_tokens=max_new_tokens) + all_outputs.extend(batch_outputs) + except Exception as e: + logger.error(f"Error during generation for batch starting at index {i}: {e}") + all_outputs.extend(["ERROR_GENERATING"] * len(batch_prompts)) + + # --- 7. Evaluate and Save Results --- + model_outputs_data = [] + total_es = 0 + total_em = 0 + valid_scores = 0 + + if len(all_outputs) != len(original_data): + logger.warning(f"Warning: Mismatch between generated outputs ({len(all_outputs)}) and original data ({len(original_data)}). Scores might be inaccurate.") + min_len = min(len(all_outputs), len(original_data)) + all_outputs = all_outputs[:min_len] + original_data = original_data[:min_len] + all_prompts = all_prompts[:min_len] + + logger.info(f"Calculating scores and saving results for {len(all_outputs)} examples...") + # make sure that the path exists + os.makedirs(os.path.dirname(model_output_path), exist_ok=True) + with open(model_output_path, "w") as f_out: + for i in range(len(all_outputs)): + output = all_outputs[i] + # Ensure index is valid for original_data and all_prompts + if i >= len(original_data) or i >= len(all_prompts): + logger.error(f"Index {i} out of bounds after potential mismatch alignment. 
Stopping result processing.") + break + orig_data = original_data[i] + prompt = all_prompts[i] + gt = orig_data['gt'] + + result = { + **orig_data, + "prompt": prompt, + "output": output, + } + + es = 0 + em = 0 + if output != "ERROR_GENERATING" and gt is not None: + try: + es = compute_ES(gt, output) + em = compute_EM(gt, output) + total_es += es + total_em += em + valid_scores += 1 + except Exception as e: + logger.error(f"Error scoring example {orig_data.get('id', i)}: {e}") + + result['es'] = es + result['em'] = em + model_outputs_data.append(result) + f_out.write(json.dumps(result) + "\n") + + logger.info(f"Raw results saved to {model_output_path}") + + avg_es = (total_es / valid_scores) if valid_scores > 0 else 0 + avg_em = (total_em / valid_scores) if valid_scores > 0 else 0 + + # Update the parameters dictionary in scores + scores = { + "model_name": model_name, + "method": method, + "num_examples_scored": valid_scores, + "num_examples_total": len(original_data), # Use length of original_data before potential alignment issues + "average_es": avg_es, + "average_em": avg_em, + "parameters": { + "dataset_path": dataset_path, + "dataset_split": dataset_split, + "filter_current_lines_max": filter_current_lines_max, + "filter_background_tokens_min": filter_background_tokens_min, + "embed_model_name": embed_model_name if method == "rag" or method == "function_rag" else None, + # Combine compression model name reporting + "compression_model_name": compression_model_name if method in ["llmlingua", "longllmlingua", "code_compressor"] else None, + "max_new_tokens": max_new_tokens, + "batch_size": batch_size, + # RAG specific params + "rag_window_size": rag_window_size if method == "rag" else None, + "rag_overlap": rag_overlap if method == "rag" else None, + "rag_top_k": rag_top_k if method == "rag" else None, + # Function RAG params + "function_rag_language": function_rag_language if method == "function_rag" else None, + "function_rag_top_k": function_rag_top_k if method == "function_rag" else None, + # Lingua specific params (shared target token name) + "lingua_target_token": lingua_target_token if method == "llmlingua" or method == "longllmlingua" else None, + # LongLingua specific params + "longlingua_chunk_size": longlingua_chunk_size if method == "longllmlingua" else None, + "longlingua_overlap": longlingua_overlap if method == "longllmlingua" else None, + # CodeCompressor specific params + "code_compressor_target_token": code_compressor_target_token if method == "code_compressor" else None, # Added parameter + "code_compressor_rank_only": (code_compressor_fine_ratio == 1.0) if method == "code_compressor" else None, # Determined by fine_ratio + "code_compressor_fine_ratio": code_compressor_fine_ratio if method == "code_compressor" else None, # Added parameter + "importance_beta": importance_beta if method == "code_compressor" else None, # Added parameter + } + } + + logger.info(f"Method {method}: Avg ES = {avg_es:.2f}, Avg EM = {avg_em:.2f} ({valid_scores}/{len(original_data)} scored)") + save_json(scores, score_output_path) + logger.info(f"Scores saved to {score_output_path}") + + logger.info("Evaluation complete.") + # Clean up LLM explicitly + if 'llm' in locals() and llm is not None: + del llm + logger.info("Generation LLM deleted.") + torch.cuda.empty_cache() + gc.collect() + +if __name__ == "__main__": + fire.Fire(evaluate_completion) diff --git a/long-code-completion/run.sh b/long-code-completion/run.sh new file mode 100644 index 0000000..a214d6c --- /dev/null +++ 
b/long-code-completion/run.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=0 + +MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct" +MODEL_PATH_NAME="qwencoder-7b-instruct" +BASE_RESULT_DIR="results/${MODEL_PATH_NAME}" +BASE_LOG_DIR="logs/${MODEL_PATH_NAME}" + +mkdir -p ${BASE_LOG_DIR} +mkdir -p ${BASE_RESULT_DIR} + +echo "Starting experiments for ${MODEL_NAME} on GPU ${CUDA_VISIBLE_DEVICES}" + +# --- CodeCompressor Method Configuration --- +TARGET_TOKENS=(2048 4096) +FINE_RATIOS=(0.5 0.8) +BETAS=(0.0 0.5) + +echo "--- Running CodeCompressor with various configurations ---" +for tokens in "${TARGET_TOKENS[@]}"; do + for ratio in "${FINE_RATIOS[@]}"; do + for beta in "${BETAS[@]}"; do + echo "Running CodeCompressor: target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}" + python main.py \ + --model_name ${MODEL_NAME} \ + --compression_model_name ${MODEL_NAME} \ + --method code_compressor \ + --filter_background_tokens_min 5000 \ + --result_dir "${BASE_RESULT_DIR}" \ + --num_examples 500 \ + --code_compressor_target_token ${tokens} \ + --code_compressor_fine_ratio ${ratio} \ + --importance_beta ${beta} > "${BASE_LOG_DIR}/code_compressor_t${tokens}_fr${ratio}_b${beta}.log" 2>&1 + echo "Finished CodeCompressor: target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}" + done + done +done + +echo "--- Finished CodeCompressor ---" diff --git a/long-code-completion/utils.py b/long-code-completion/utils.py new file mode 100644 index 0000000..ed5543f --- /dev/null +++ b/long-code-completion/utils.py @@ -0,0 +1,288 @@ +import datasets +import editdistance +import numpy as np +import matplotlib.pyplot as plt +from transformers import AutoTokenizer +import re +from tqdm import tqdm + +def compute_ES(target, prediction): + """Compute edit similarity score""" + target_lines = [line.strip() for line in target.splitlines() if line.strip()] + target_str = '\n'.join(target_lines) + prediction_lines = [line.strip() for line in prediction.splitlines() + if line.strip() and not line.strip().startswith("#")][:len(target_lines)] + prediction_str = '\n'.join(prediction_lines) + + return (1 - (editdistance.eval(target_str, prediction_str) / + max(len(target_str), len(prediction_str))))*100 + + +def compute_EM(target, prediction): + """Compute exact match score""" + target_lines = [line.strip() for line in target.splitlines() if line.strip()] + prediction_lines = [line.strip() for line in prediction.splitlines() + if line.strip() and not line.strip().startswith("#")][:len(target_lines)] + + if len(target_lines) != len(prediction_lines): + return 0 + return (int(target_lines == prediction_lines))*100 + + +def load_data(path="microsoft/LCC_python", split="test", num_examples=500, filter_current_lines_max=50, filter_background_tokens_min=5000): + """ + Loads the dataset, processes it to split contexts, filters it based on context lengths, + and returns the filtered dataset along with the tokenizer used. 
+ """ + print(f"Loading initial {num_examples} examples from {path} ({split} split)...") + dataset = datasets.load_dataset(path, split=split) + # keep 5 times of num_examples for testing + dataset = dataset.select(range(num_examples*10)) + original_size = len(dataset) # Size before filtering + + # Initialize tokenizer here for filtering and potential later use + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") + print("Tokenizer Qwen/Qwen2.5-Coder-7B-Instruct initialized.") + + # Process dataset to add split contexts first + print("Splitting context into background and current function...") + def add_split_context(example): + background, current_func = split_context_ast(example['context']) + example['background_context'] = background + example['current_function_context'] = current_func + return example + + processed_dataset = dataset.map(add_split_context, num_proc=4) # Use multiple processors if available + + # --- Filter the dataset --- + filtered_dataset_list = [] + print(f"Filtering dataset: Keeping examples where current func lines <= {filter_current_lines_max} and background tokens >= {filter_background_tokens_min}.") + + for example in tqdm(processed_dataset): + curr_ctx = example['current_function_context'] + bg_ctx = example['background_context'] + + curr_line_count = len(curr_ctx.splitlines()) + + # Check if background context is non-empty before tokenizing + bg_token_count = 0 + if bg_ctx and bg_ctx.strip(): # Check if bg_ctx is not None and not just whitespace + # Use truncation=True and max_length to prevent overly long sequences if needed, though for filtering, just length is fine. + bg_token_count = len(tokenizer.encode(bg_ctx, add_special_tokens=False)) # Usually better to exclude special tokens for length calculation + + if curr_line_count <= filter_current_lines_max and bg_token_count >= filter_background_tokens_min: + filtered_dataset_list.append(example) + + filtered_dataset = datasets.Dataset.from_list(filtered_dataset_list) + if num_examples > len(filtered_dataset): + selected_dataset = filtered_dataset + else: + selected_dataset = filtered_dataset.select(range(num_examples)) + + print(f"Filtering complete. Original size: {original_size}, Filtered size: {len(filtered_dataset)}. Retaining {min(num_examples, len(filtered_dataset))} examples.") # Adjusted print statement + + # Return both the filtered dataset and the tokenizer + return selected_dataset, tokenizer + + +def find_last_func_or_class_start(code_string): + """ + Finds the starting line of the last top-level function or class definition + using line-based heuristics, robust to syntax errors. + Accounts for decorators. + Returns the 1-based line number or None if not found. 
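As a quick illustration of the heuristic described above, the sketch below feeds a small hand-written snippet to `find_last_func_or_class_start` and `split_context_ast`. It assumes `long-code-completion/utils.py` is importable as `utils`; the input string is invented.

```python
# Illustrative only: assumes this runs next to long-code-completion/utils.py.
from utils import find_last_func_or_class_start, split_context_ast

code = (
    "import os\n"
    "\n"
    "def helper(x):\n"
    "    return x + 1\n"
    "\n"
    "@my_decorator\n"
    "def target(y):\n"
    "    return helper(y)\n"
)

# The last top-level definition is `def target` on line 7; its decorator on
# line 6 is included, so the heuristic returns 6 (1-based).
print(find_last_func_or_class_start(code))  # -> 6

background, current = split_context_ast(code)
print(background)  # everything before the decorator
print(current)     # the decorator plus the `def target` body
```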
+ """ + lines = code_string.splitlines() + if not lines: + return None + last_def_line_index = -1 + + # Iterate backwards to find the last line starting with def/async def/class + # We use lstrip() to handle indentation + for i in range(len(lines) - 1, -1, -1): + stripped_line = lines[i].lstrip() + # Using regex for potentially more robust matching (e.g., def func():) + # Matches lines starting with 'def', 'async def', or 'class' followed by space + if re.match(r'^(def|async\s+def|class)\s+', stripped_line): + last_def_line_index = i + break + + if last_def_line_index != -1: + # Found a potential start, now check for decorators above it + start_line_index = last_def_line_index + for i in range(last_def_line_index - 1, -1, -1): + stripped_line = lines[i].lstrip() + if stripped_line.startswith('@'): + start_line_index = i + elif stripped_line == '' or stripped_line.startswith('#'): # Skip blank lines and comments + continue + else: + # Found a non-decorator, non-empty, non-comment line, stop searching upwards + break + return start_line_index + 1 # Return 1-based line number + else: + # Heuristic failed, maybe return the start of the last non-empty block + # or just None if no definitions found at all + return None # No function or class definition found + +def split_context_ast(code_string): + """ + Splits the code context into background and current function/class context using AST. + """ + lines = code_string.splitlines() + split_line_1_based = find_last_func_or_class_start(code_string) + + if split_line_1_based is not None and split_line_1_based > 0: + # split_line_1_based is the start of the function/class + # Background is lines *before* that line + background_lines = lines[:split_line_1_based - 1] + current_func_lines = lines[split_line_1_based - 1:] + return '\n'.join(background_lines), '\n'.join(current_func_lines) + else: + # If no function/class found or parse error, treat all as current + return "", code_string + +def analyze_dataset(dataset, tokenizer): # Added tokenizer parameter + """Analyzes and plots context length distributions, including function counts and token ratios.""" + # --- Analysis (Optional: Recalculate stats on the filtered dataset) --- + background_lines = [] + current_func_lines = [] + background_tokens = [] + current_func_tokens = [] + background_func_counts = [] # Added list for function counts + bg_curr_token_ratios = [] # Added list for token ratios + + + # Ensure tokenizer is available - it's passed as an argument now + # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") # Removed: Use passed tokenizer + + print(f"\nAnalyzing {len(dataset)} examples...") # Add count here + for example in tqdm(dataset): # Use tqdm here for progress + bg_ctx = example.get('background_context', '') # Use .get for safety + curr_ctx = example.get('current_function_context', '') + + bg_token_count = 0 + curr_token_count = 0 + func_count = 0 + + # Proceed only if contexts exist + if bg_ctx: + bg_lines = bg_ctx.splitlines() + bg_line_count = len(bg_lines) + background_lines.append(bg_line_count) + # Use truncation for safety, exclude special tokens for consistency + bg_token_count = len(tokenizer.encode(bg_ctx, add_special_tokens=False)) + background_tokens.append(bg_token_count) + # Count functions in background context + for line in bg_lines: + if re.match(r'^\s*def\s+', line): # Count lines starting with 'def ' after stripping leading whitespace + func_count += 1 + background_func_counts.append(func_count) + + if curr_ctx: + curr_line_count = 
len(curr_ctx.splitlines()) + current_func_lines.append(curr_line_count) + curr_token_count = len(tokenizer.encode(curr_ctx, add_special_tokens=False)) + current_func_tokens.append(curr_token_count) + + # Calculate ratio, handle division by zero + if bg_token_count > 0 and curr_token_count > 0: + # Add a small epsilon to avoid potential issues with very small token counts if needed, though direct ratio is fine here. + bg_curr_token_ratios.append(bg_token_count / curr_token_count) + elif bg_token_count > 0 and curr_token_count == 0: + bg_curr_token_ratios.append(np.inf) # Or some large number, or skip - np.inf might break histograms, let's use a large number + # Alternatively, filter these out or handle them specifically during plotting/stats + pass # Let's skip infinity for plotting simplicity + # else: ratio is 0 or undefined, skip + + + # --- Plotting --- + # Check if *any* data exists before proceeding + if not any([background_lines, current_func_lines, background_tokens, current_func_tokens, background_func_counts, bg_curr_token_ratios]): + print("No data points found for analysis after filtering. Skipping plot generation.") + return # Exit if no data to plot + + fig, axs = plt.subplots(3, 2, figsize=(12, 15)) # Changed to 3x2 grid + # Use tokenizer name in titles dynamically if possible, or keep generic + tokenizer_name = tokenizer.name_or_path if hasattr(tokenizer, 'name_or_path') else "Tokenizer" + fig.suptitle(f'Context Analysis (Filtered LCC Python Dataset - {len(dataset)} examples, Tokenizer: {tokenizer_name})') + + # Row 1: Background + # Background Lines + if background_lines: + axs[0, 0].hist(background_lines, bins=50, color='skyblue', edgecolor='black') + print(f"Background Lines: Min={np.min(background_lines)}, Max={np.max(background_lines)}, Avg={np.mean(background_lines):.2f}, Median={np.median(background_lines)}") + else: + axs[0,0].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[0,0].transAxes) + axs[0, 0].set_title('Background Context (Lines)') + axs[0, 0].set_ylabel('Count') + + # Background Tokens + if background_tokens: + axs[0, 1].hist(background_tokens, bins=50, color='skyblue', edgecolor='black') + print(f"Background Tokens: Min={np.min(background_tokens)}, Max={np.max(background_tokens)}, Avg={np.mean(background_tokens):.2f}, Median={np.median(background_tokens)}") + else: + axs[0,1].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[0,1].transAxes) + axs[0, 1].set_title('Background Context (Tokens)') + axs[0, 1].set_ylabel('Count') + + + # Row 2: Background Func Count & Ratio + # Background Function Count + if background_func_counts: + # Use more bins if the range is small, decide based on max count? 
+ max_funcs = np.max(background_func_counts) if background_func_counts else 0 + bins = min(50, max(1, max_funcs + 1)) # Adjust bins based on max count, ensure at least 1 bin + axs[1, 0].hist(background_func_counts, bins=bins, color='lightgreen', edgecolor='black') + print(f"Background Func Count: Min={np.min(background_func_counts)}, Max={max_funcs}, Avg={np.mean(background_func_counts):.2f}, Median={np.median(background_func_counts)}") + else: + axs[1,0].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[1,0].transAxes) + axs[1, 0].set_title('Background Function Count') + axs[1, 0].set_ylabel('Count') + + # Background/Current Token Ratio + if bg_curr_token_ratios: + # Ratios can have a large range, consider log scale or clipping? + # Let's cap the ratio for visualization if it gets too extreme, e.g., at 50 + # ratios_to_plot = [min(r, 50) for r in bg_curr_token_ratios] # Cap ratio at 50 for plot + ratios_to_plot = bg_curr_token_ratios + axs[1, 1].hist(ratios_to_plot, bins=50, color='gold', edgecolor='black') + # Calculate stats on original ratios before clipping for plot + print(f"BG/Current Token Ratio: Min={np.min(bg_curr_token_ratios):.2f}, Max={np.max(bg_curr_token_ratios):.2f}, Avg={np.mean(bg_curr_token_ratios):.2f}, Median={np.median(bg_curr_token_ratios):.2f}") + axs[1, 1].set_title('BG/Current Token Ratio') + + else: + axs[1,1].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[1,1].transAxes) + axs[1, 1].set_ylabel('Count') + + + # Row 3: Current Function + # Current Function Lines + if current_func_lines: + axs[2, 0].hist(current_func_lines, bins=50, color='lightcoral', edgecolor='black') + print(f"Current Func Lines: Min={np.min(current_func_lines)}, Max={np.max(current_func_lines)}, Avg={np.mean(current_func_lines):.2f}, Median={np.median(current_func_lines)}") + else: + axs[2,0].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[2,0].transAxes) + axs[2, 0].set_title('Current Function Context (Lines)') + axs[2, 0].set_xlabel('Number of Lines') + axs[2, 0].set_ylabel('Count') + + # Current Function Tokens + if current_func_tokens: + axs[2, 1].hist(current_func_tokens, bins=50, color='lightcoral', edgecolor='black') + print(f"Current Func Tokens: Min={np.min(current_func_tokens)}, Max={np.max(current_func_tokens)}, Avg={np.mean(current_func_tokens):.2f}, Median={np.median(current_func_tokens)}") + else: + axs[2,1].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[2,1].transAxes) + axs[2, 1].set_title('Current Function Context (Tokens)') + axs[2, 1].set_xlabel('Number of Tokens') + axs[2, 1].set_ylabel('Count') + + + plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap + plt.savefig('context_analysis_distributions_filtered.png') # Save with a new descriptive name + + +if __name__ == "__main__": + # Load data, which now includes filtering and returns the tokenizer + filtered_dataset, tokenizer = load_data(num_examples=2000, filter_current_lines_max=50, filter_background_tokens_min=5000) + analyze_dataset(filtered_dataset, tokenizer) # Pass tokenizer to analyze_dataset \ No newline at end of file diff --git a/module_summarization/code_compressor.py b/module_summarization/code_compressor.py new file mode 100644 index 0000000..76e160b --- /dev/null +++ b/module_summarization/code_compressor.py @@ -0,0 +1,1887 @@ +import torch +import numpy as np +from typing import 
List, Union, Tuple, Dict, Optional +import re +import math +import zlib +from transformers import AutoModelForCausalLM, AutoTokenizer +import time +from tqdm import tqdm +import copy +import bisect +import json +from llmlingua import PromptCompressor +from loguru import logger + +class EntropyChunking: + def __init__(self, model_name="Qwen/Qwen2.5-Coder-0.5B-Instruct"): + """Entropy-based text chunking implementation""" + logger.debug(f"Loading Entropy chunking model: {model_name}") + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + logger.debug(f"Entropy chunking model loaded on device: {self.device}") + + def split_into_sentences(self, text: str) -> List[str]: + """Split text into sentences, inserting empty lines for double newlines""" + # First replace double newlines with a special marker + text_with_markers = text.replace('\n\n', '\n__EMPTY_LINE__\n') + + # Split by single newlines + lines = text_with_markers.split('\n') + + # Process lines: replace markers with empty strings, keep original lines + sentences = [] + for line in lines: + if line == '__EMPTY_LINE__': + sentences.append(' ') # Empty line for double newline breaks + else: + sentences.append(line) # Keep original line with indentation + + return sentences + + def calculate_sentence_ppl(self, sentences: List[str]) -> List[float]: + """Calculate perplexity for each sentence based on preceding context""" + ppls = [] + + for i, sentence in enumerate(sentences): + if i == 0: + context = "" + target = sentence + else: + context = "\n".join(sentences[:i]) + target = sentence + + ppl = self._compute_ppl(context, target) + ppls.append(ppl) + + return ppls + + def _compute_ppl(self, context: str, target: str) -> float: + """Compute perplexity of target text given context""" + # Handle empty target lines + if not target: + return 0.0 # Assign zero perplexity to empty lines + + if context: + full_text = context + "\n" + target + context_tokens = self.tokenizer(context + "\n", return_tensors="pt", add_special_tokens=True) + context_length = context_tokens.input_ids.shape[1] + else: + full_text = target + context_length = 0 + + inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(self.device) + + with torch.no_grad(): + outputs = self.model(**inputs) + logits = outputs.logits + + if context_length > 0: + target_logits = logits[0, context_length-1:-1] + target_labels = inputs.input_ids[0, context_length:] + else: + target_logits = logits[0, :-1] + target_labels = inputs.input_ids[0, 1:] + + if len(target_labels) > 0: + log_probs = torch.log_softmax(target_logits, dim=-1) + token_log_probs = log_probs[torch.arange(len(target_labels)), target_labels] + avg_log_prob = token_log_probs.mean().item() + ppl = math.exp(-avg_log_prob) + else: + ppl = float('inf') + + # take log2 of ppl + ppl = math.log2(ppl) + + return ppl + + def calculate_adaptive_thresholds(self, ppls: List[float], k: float = 1.0) -> dict: + """Calculate adaptive thresholds using different statistical methods""" + # Filter out infinite and NaN values + valid_ppls = [p for p in ppls if not math.isinf(p) and not math.isnan(p) and p > 0] + + if len(valid_ppls) < 3: + # Fallback to fixed threshold if not enough valid data + return { + 'std': 0.5, + 
'robust_std': 0.5, + 'iqr': 0.5, + 'mad': 0.5 + } + + valid_ppls = np.array(valid_ppls) + + # Method 1: Standard deviation based + mean_ppl = np.mean(valid_ppls) + std_ppl = np.std(valid_ppls) + threshold_std = k * std_ppl + + # Method 2: Robust standard deviation (using median and MAD) + median_ppl = np.median(valid_ppls) + mad = np.median(np.abs(valid_ppls - median_ppl)) + robust_std = mad * 1.4826 # Convert MAD to robust std estimate + threshold_robust_std = median_ppl + k * robust_std + + # Method 3: IQR based (Interquartile Range) + q25 = np.percentile(valid_ppls, 25) + q75 = np.percentile(valid_ppls, 75) + iqr = q75 - q25 + threshold_iqr = q75 + k * iqr + + # Method 4: MAD based (Median Absolute Deviation) + threshold_mad = median_ppl + k * mad + + return { + 'std': threshold_std, + 'robust_std': threshold_robust_std, + 'iqr': threshold_iqr, + 'mad': threshold_mad + } + + def find_ppl_spikes_adaptive(self, values: List[float], method: str = 'std', k: float = 1.0) -> tuple: + """Find PPL spikes using adaptive threshold based on statistical method""" + thresholds = self.calculate_adaptive_thresholds(values, k) + threshold = thresholds[method] + + spike_indices = [] + + for i in range(1, len(values) - 1): + current = values[i] + left = values[i - 1] + right = values[i + 1] + + # Skip infinite or NaN values + if math.isinf(current) or math.isnan(current): + continue + if math.isinf(left) or math.isnan(left): + left = current + if math.isinf(right) or math.isnan(right): + right = current + + # Check if current PPL is significantly higher than both neighbors + left_diff = current - left + right_diff = current - right + + # Condition: Current PPL is higher than both neighbors with adaptive threshold + if (left_diff >= threshold or right_diff >= threshold) and (left_diff >= 0 and right_diff >= 0): + spike_indices.append(i) + + return spike_indices, threshold + + def chunk_text_adaptive(self, text: str, method: str = 'std', k: float = 1.0) -> tuple: + """Perform PPL-based text chunking using adaptive spike detection""" + sentences = self.split_into_sentences(text) + ppls = self.calculate_sentence_ppl(sentences) + spike_indices, threshold = self.find_ppl_spikes_adaptive(ppls, method, k) + + chunks = [] + # Split at spike points (after the spike line) + split_points = [0] + [idx + 1 for idx in spike_indices] + [len(sentences)] + + for i in range(len(split_points) - 1): + start = split_points[i] + end = split_points[i + 1] + chunk_sentences = sentences[start:end] + chunk_text = "\n".join(chunk_sentences) + chunks.append(chunk_text) + + return chunks, sentences, ppls, spike_indices + +class CodeCompressor: + def __init__( + self, + model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4", + device_map: str = "cuda", + model_config: dict = {}, + ): + """ + Initialize the CodeCompressor with a language model for compression. 
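A hedged sketch of how the `EntropyChunking` class defined above can be exercised on its own; it assumes this file is importable as `code_compressor` and that the default Qwen2.5-Coder-0.5B checkpoint can be downloaded. The input snippet is a toy example.

```python
# Sketch: run entropy-based chunking on a toy snippet (toy input, real API).
from code_compressor import EntropyChunking

chunker = EntropyChunking()  # loads Qwen/Qwen2.5-Coder-0.5B-Instruct by default

code = (
    "def add(a, b):\n"
    "    return a + b\n"
    "\n"
    "def fetch(url, timeout=10):\n"
    "    response = requests.get(url, timeout=timeout)\n"
    "    return response.json()\n"
)

# 'std' thresholding with k=1.0 matches the defaults used later by CodeCompressor.
chunks, sentences, ppls, spikes = chunker.chunk_text_adaptive(code, method="std", k=1.0)
for i, chunk in enumerate(chunks):
    print(f"--- chunk {i} ---\n{chunk}")
# `ppls` holds the log2 perplexity of each line given the preceding lines;
# `spikes` lists the line indices where a spike was detected (the text is
# split right after those lines).
```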
+ + Args: + model_name: The name of the model to load from HuggingFace + device_map: Device to load the model on + model_config: Additional configuration for the model + """ + self.model_name = model_name + self.device = device_map + self.model_config = model_config + self.load_model(model_name, device_map, model_config) + + logger.debug("Initializing Entropy chunking...") + self.entropy_chunking = EntropyChunking() + self.ppl_chunking = self.entropy_chunking # alias so call sites using either attribute name resolve + + # Add caching system for model outputs and token information + self.cache = { + "token_length": {}, # Cache for token length by text + "encodings": {}, # Cache for tokenizer encodings + "perplexity": {}, # Cache for perplexity calculations + "conditional_ppl": {}, # Cache for conditional perplexity + "context_rankings": {}, # Cache for context rankings + } + self.max_cache_size = 1000 # Limit cache size to prevent memory issues + + # set up the max position embeddings and cache bos num + self.max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 4096) + self.cache_bos_num = 10 + self.prefix_bos_num = 100 + self.context_idxs = [] + + def load_model( + self, model_name: str, device_map: str = "cuda", model_config: dict = {} + ): + """ + Load the language model and tokenizer. + + Args: + model_name: The name of the model to load + device_map: Device to load the model on + model_config: Additional configuration for the model + """ + logger.debug(f"Loading model {model_name} on {device_map}") + torch_dtype = torch.bfloat16 if "torch_dtype" not in model_config else model_config["torch_dtype"] + model_kwargs = {"device_map": device_map, "torch_dtype": torch_dtype, "trust_remote_code": True} + + for k, v in model_config.items(): + model_kwargs[k] = v + + self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" + logger.debug("Model and tokenizer loaded successfully") + + def _manage_cache_size(self, cache_type): + """ + Manage cache size by removing oldest entries when cache exceeds max size. + + Args: + cache_type: The type of cache to manage + """ + if len(self.cache[cache_type]) > self.max_cache_size: + # Remove 20% of the oldest entries + remove_count = int(self.max_cache_size * 0.2) + keys_to_remove = list(self.cache[cache_type].keys())[:remove_count] + for key in keys_to_remove: + del self.cache[cache_type][key] + + def get_token_length( + self, + text: str, + add_special_tokens: bool = True, + ): + """ + Get the number of tokens in the given text.
+ + Args: + text: The text to tokenize + add_special_tokens: Whether to count special tokens + + Returns: + The number of tokens + """ + # Create a cache key based on text and parameters + cache_key = f"{text}_{add_special_tokens}" + + # Check if result is in cache + if cache_key in self.cache["token_length"]: + return self.cache["token_length"][cache_key] + + # Calculate token length if not in cache + token_length = len(self.tokenizer.encode(text, add_special_tokens=add_special_tokens)) + + # Store in cache + self.cache["token_length"][cache_key] = token_length + self._manage_cache_size("token_length") + + return token_length + + def get_ppl( + self, + text: str, + granularity: str = "line", + input_ids=None, + attention_mask=None, + past_key_values=None, + return_kv=False, + end=None, + condition_mode: str = "none", + condition_pos_id: int = 0, + ): + """ + Calculate perplexity for the given text at line level. + + Args: + text: The text to calculate perplexity for + granularity: The granularity of perplexity calculation (line, token, chunk) + input_ids, attention_mask, past_key_values: Optional pre-processed inputs + return_kv: Whether to return key-values + end: End position for calculation + condition_mode: Mode for conditional perplexity (none, prefix) + condition_pos_id: Position ID for condition + + Returns: + A dictionary with perplexity scores and processing information + """ + # Create a cache key for this specific perplexity calculation + cache_key = f"{text}_{granularity}_{condition_mode}_{condition_pos_id}" + if past_key_values is None and not return_kv and cache_key in self.cache["perplexity"]: + return self.cache["perplexity"][cache_key] + + # Initialize input processing + if input_ids is None: + encoding_key = text + if encoding_key in self.cache["encodings"]: + cached_encoding = self.cache["encodings"][encoding_key] + input_ids = cached_encoding["input_ids"] + attention_mask = cached_encoding["attention_mask"] + else: + encoding = self.tokenizer( + text, + return_tensors="pt", + padding=True + ) + input_ids = encoding["input_ids"].to(self.model.device) + attention_mask = encoding["attention_mask"].to(self.model.device) + + # Cache the encoding + self.cache["encodings"][encoding_key] = { + "input_ids": input_ids, + "attention_mask": attention_mask + } + self._manage_cache_size("encodings") + + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + else: + past_length = 0 + + if end is None: + end = input_ids.shape[1] + end = min(end, past_length + self.max_position_embeddings) + + with torch.no_grad(): + outputs = self.model( + input_ids=input_ids[:, past_length:end], + attention_mask=attention_mask[:, :end], + past_key_values=past_key_values, + return_dict=True, + output_hidden_states=True, + use_cache=True, + ) + + # Get logits and shift + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., past_length+1:end].contiguous() + + # Flatten tokens for loss calculation + active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1) + active_logits = shift_logits.view(-1, shift_logits.size(-1))[active] + active_labels = shift_labels.view(-1)[active] + + # Calculate loss + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(active_logits, active_labels) + + # Apply condition filtering if required + if condition_mode == "prefix": + loss = loss[condition_pos_id:] + + segments = [text] if text else [] + lines_info = [] + + # Calculate mean perplexity + mean_loss = loss.mean() if len(loss) > 0 
else torch.tensor(0.0) + ppl = torch.exp(mean_loss).item() if mean_loss.item() != float('inf') else float('inf') + + result = { + "loss": loss, + "input_ids": input_ids, + "attention_mask": attention_mask, + "lines_info": lines_info, + "segments": segments, + "ppl": ppl, + } + + if return_kv: + result["past_key_values"] = outputs.past_key_values + else: + # Cache the result if we're not returning KV cache + self.cache["perplexity"][cache_key] = result + self._manage_cache_size("perplexity") + + return result + + def __get_lines_info(self, lines, input_ids, loss): + """ + Get information about each line including start/end positions and importance. + + Args: + lines: List of lines in the text + input_ids: Token IDs for the entire text + loss: Per-token loss values + + Returns: + List of dictionaries with line information + """ + line_info = [] + cumulative_tokens = 0 + + input_ids_list = input_ids.cpu().tolist() + + for i, line in enumerate(lines): + if not line.strip(): + continue + + # Encode each line to find its token length + line_tokens = self.tokenizer.encode(line, add_special_tokens=False) + line_length = len(line_tokens) + + # Find position in the tokenized text + start_pos = cumulative_tokens + end_pos = start_pos + line_length + + # Calculate mean loss (importance) for this line + # Loss might be shorter than the token IDs due to shifting + if isinstance(loss, torch.Tensor) and start_pos < len(loss) and end_pos <= len(loss): + line_loss = loss[start_pos:end_pos].mean().item() + else: + # Handle edge cases + line_loss = float("inf") + + line_info.append({ + "line": line, + "start": start_pos, + "end": end_pos, + "importance": line_loss, + "tokens": line_length + }) + + cumulative_tokens += line_length + + return line_info + + def get_prefix_length(self, prefix: str, text: str): + """ + Calculate the length of a prefix in tokens when concatenated with a text. + + Args: + prefix: The prefix text + text: The main text + + Returns: + Length of the prefix in tokens + """ + possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1) + full_input_ids = self.tokenizer(prefix + text[:100], add_special_tokens=False).input_ids + + for i in range(possible_prefix_token, len(full_input_ids)): + cur_prefix = self.tokenizer.decode(full_input_ids[:i]) + if cur_prefix == prefix: + break + + return i + + def get_condition_ppl( + self, + text: str, + question: str, + condition_in_question: str = "none", + granularity: str = "line", + ): + """ + Calculate perplexity change of a question when given context text. + A positive change means the context helps reduce question perplexity. 
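A hedged sketch of how this relevance signal can be queried directly; the chunk and query strings are invented, and constructing `CodeCompressor` downloads the underlying model.

```python
# Sketch: score one candidate chunk against a query via conditional perplexity.
from code_compressor import CodeCompressor

compressor = CodeCompressor(model_name="Qwen/Qwen2.5-Coder-7B-Instruct")  # heavyweight

chunk = "def parse_config(path):\n    with open(path) as f:\n        return json.load(f)\n"
query = "How is the configuration file parsed?"  # made-up query

# score = PPL(query) - PPL(query | chunk): positive means the chunk makes the
# query easier to predict, i.e. the chunk is relevant to the query.
score = compressor.get_condition_ppl(
    text=chunk,
    question=query,
    condition_in_question="prefix",
)
print(f"perplexity drop when conditioning on this chunk: {score:.3f}")
```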
+ + Args: + text: The context text + question: The question to evaluate + condition_in_question: Conditioning mode (none, prefix) + granularity: Granularity for perplexity calculation + + Returns: + Perplexity change for the question with/without context + """ + # Create a cache key for this conditional perplexity calculation + cache_key = f"{text}_{question}_{condition_in_question}_{granularity}" + + if cache_key in self.cache["conditional_ppl"]: + return self.cache["conditional_ppl"][cache_key] + + if condition_in_question == "none": + # Just return the perplexity of the text + result = self.get_ppl( + text=text, granularity=granularity, condition_mode="none" + ) + ppl_value = result["ppl"] + else: + # First calculate question perplexity without context + question_ppl_without_context = self.get_ppl( + text=question, + granularity=granularity + )["ppl"] + + # Then calculate question perplexity with context + question_ppl_with_context = self.get_ppl( + text=text + "\n\n" + question, + granularity=granularity, + condition_mode="prefix", + condition_pos_id=self.get_token_length(text + "\n\n", add_special_tokens=True) + )["ppl"] + + # Calculate the change (positive means context helps) + ppl_value = question_ppl_without_context - question_ppl_with_context + + # Cache the result + self.cache["conditional_ppl"][cache_key] = ppl_value + self._manage_cache_size("conditional_ppl") + + return ppl_value + + def control_context_budget( + self, + context_list: List[str], + target_token: float, + question: str = "", + reorder_context: str = "original", + condition_in_question: str = "none", + force_context_ids: List[int] = None, + force_context_number: int = None, + context_budget: str = "+100", + dynamic_context_compression_ratio: float = 0.0, + ): + """ + Control token budget for contexts based on relevance ranking, following LongLLMLingua. 
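The core of this step is a greedy fill against a token budget. Restated below as a standalone sketch with made-up scores and token counts (the real scores come from `get_condition_ppl` above):

```python
# Standalone sketch of budget-controlled selection: rank chunks by relevance,
# then take them in that order until the token budget is exhausted. As in the
# method below, the chunk that first overruns the budget is still kept.
def select_within_budget(token_lengths, scores, target_tokens):
    order = sorted(range(len(scores)), key=lambda i: -scores[i])  # best first
    kept, budget = [], target_tokens
    for idx in order:
        budget -= token_lengths[idx]
        kept.append(idx)
        if budget < 0:
            break
    return sorted(kept)  # restore source order, as reorder_context="original" does

# Four chunks with toy token counts and relevance scores, 600-token budget.
print(select_within_budget([300, 120, 500, 80], [2.1, -0.3, 1.4, 0.9], 600))
# -> [0, 2]
```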
+ + Args: + context_list: List of contexts + target_token: Target number of tokens + question: Question for relevance ranking + reorder_context: How to reorder contexts ("original", "importance", "two_stage") + condition_in_question: Mode for conditional ranking + force_context_ids: List of context IDs to always include + force_context_number: Number of contexts to forcibly include + context_budget: String expression to modify target token budget + dynamic_context_compression_ratio: Ratio for dynamic compression (0.0-1.0) + + Returns: + Selected contexts, their indices, and dynamic ratios + """ + logger.debug(f"Controlling context budget with target_token={target_token}") + start_time = time.time() + + if not context_list: + return [], [], [] + + # Get token counts for each context + logger.debug("Calculating token lengths for contexts") + context_tokens_length = [self.get_token_length(context) for context in context_list] + + # If total tokens already fit within budget, return all contexts + total_tokens = sum(context_tokens_length) + if total_tokens <= target_token: + logger.debug(f"All contexts fit within budget ({total_tokens} <= {target_token})") + end_time = time.time() + logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds") + return context_list, list(range(len(context_list))), [0.0] * len(context_list) + + # Rank contexts by relevance if question is provided + logger.debug("Ranking contexts by relevance") + if question: + # Get perplexity change for each context with the question + context_ppl_changes = [] + for d, dl in zip(context_list, context_tokens_length): + # Calculate how much this context reduces question perplexity + ppl_change = self.get_condition_ppl( + d, + question, + condition_in_question, + ) + # Apply length adjustment factor similar to before + context_ppl_changes.append(ppl_change - dl * 2 / 250 * 0) + + # Sort by perplexity change - higher is better (more reduction in question perplexity) + demonstrations_sort = sorted(enumerate(context_ppl_changes), key=lambda x: -x[1]) + else: + # Without question, use default ordering + demonstrations_sort = [(i, 0) for i in range(len(context_list))] + + # Extract ranking for later use + self.context_idxs.append([x for idx, (x, _) in enumerate(demonstrations_sort)]) + + # Calculate the target token budget with context_budget expression + if target_token < 0: + target_token = 100 + target_token = eval("target_token" + context_budget) + + # Initialize selected context tracking + used = force_context_ids if force_context_ids is not None else [] + + # Select contexts until we reach the token budget + for idx, _ in demonstrations_sort: + if idx >= len(context_tokens_length): + continue + target_token -= context_tokens_length[idx] + if idx not in used: + used.append(idx) + if target_token < 0 or ( + force_context_number is not None and len(used) >= force_context_number + ): + break + + # Store original selection order + original_used = used.copy() + + # Reorder contexts if requested + if reorder_context == "original": + used = sorted(used) + elif reorder_context == "two_stage": + l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [ + _ for idx, _ in enumerate(used) if idx % 2 == 1 + ] + used = l + r[::-1] + + # Calculate dynamic compression ratios if requested + if dynamic_context_compression_ratio > 0: + N = len(used) + dynamic_ratio = [ + i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0 + for i in range(-(N - 1), N, 2) + ][::-1] + dynamic_ratio_map = {i: j for i, j 
in zip(original_used, dynamic_ratio)} + dynamic_ratio = [dynamic_ratio_map[i] for i in used] + else: + dynamic_ratio = [0.0] * len(used) + + # Build list of selected contexts + selected_contexts = [context_list[idx] for idx in used if idx < len(context_list)] + + end_time = time.time() + logger.debug(f"Selected {len(selected_contexts)} contexts out of {len(context_list)}") + logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds") + + return selected_contexts, used, dynamic_ratio, demonstrations_sort + + def compress_code_file( + self, + code: str, + query: str = "", + instruction: str = "", + rate: float = 0.5, + target_token: float = -1, + language: str = "python", + use_iterative_compression: bool = True, + iterative_size: int = 200, + dynamic_compression_ratio: float = 0.2, + context_budget: str = "+100", + rank_only: bool = False, + fine_ratio: float = None, + fine_grained_importance_method: str = "conditional_ppl", + min_lines_for_fine_grained: int = 5, + importance_beta: float = 0.5, + use_knapsack: bool = True, + ): + """ + Compress a code file by first splitting it into function-based chunks and then compressing. + Functions are prioritized based on query relevance, similar to LongLLMLingua. + + Args: + code: The code to compress + query: Query to prioritize relevant functions + instruction: Additional instruction to guide compression + rate: Compression rate for coarse-grained (function level) compression (0.0-1.0) + target_token: Target number of tokens (alternative to rate) + language: Programming language of the code + use_iterative_compression: Whether to use iterative compression + iterative_size: Size of each iteration for iterative compression + dynamic_compression_ratio: Ratio for dynamic compression + context_budget: String expression to modify token budget + rank_only: If True, just rank and select contexts without fine-grained compression + fine_ratio: Ratio for fine-grained line selection (0.0-1.0). If None, uses `rate`. + fine_grained_importance_method: Method for scoring line importance ('contrastive_perplexity' or 'conditional_ppl'). Defaults to 'conditional_ppl'. + min_lines_for_fine_grained: Minimum number of lines a function must have to undergo fine-grained compression (otherwise kept fully). + importance_beta: Controls how much function importance affects its individual compression rate during fine-grained compression. + use_knapsack: Whether to use knapsack algorithm for block selection (True) or greedy line-by-line approach (False). + + Returns: + Compressed code and statistics with the following structure: + { + "original_code": Original uncompressed code, + "compressed_code": Compressed code, + "compressed_prompt": Complete compressed prompt with instruction and query, + "original_tokens": Number of tokens in original code, + "compressed_tokens": Number of tokens in compressed code, + "final_compressed_tokens": Number of tokens in final compressed prompt, + "compression_ratio": Ratio of compressed to original tokens, + "function_compressions": Details about compression for each function, + "selected_functions": Indices of selected functions, + "demonstrations_sort": Ranking of functions by importance, + "compressed_chunks": List of compressed code chunks + "fine_grained_method_used": The method used for fine-grained importance scoring. 
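For orientation, a hedged end-to-end call of this method might look as follows; the input file name and query are invented, and the ratios are only examples.

```python
# Illustrative end-to-end usage of compress_code_file (model download required).
from code_compressor import CodeCompressor

compressor = CodeCompressor(model_name="Qwen/Qwen2.5-Coder-7B-Instruct")

with open("example_module.py") as f:  # hypothetical input file
    code = f.read()

result = compressor.compress_code_file(
    code,
    query="How does the cache eviction policy work?",  # made-up query
    rate=0.3,        # coarse stage: keep roughly 30% of tokens at the function level
    fine_ratio=0.5,  # fine stage: keep roughly 50% of lines inside selected functions
    language="python",
    rank_only=False,
    fine_grained_importance_method="conditional_ppl",
)

print(result["original_tokens"], "->", result["compressed_tokens"])
print(result["compressed_code"])
```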
+ } + """ + logger.debug(f"Starting code file compression with rate={rate}, target_token={target_token}, language={language}") + start_time = time.time() + + # Split code into function-based chunks + logger.debug("Splitting code into function-based chunks") + code_chunks = self.split_code_by_functions(code, language=language) + logger.debug(f"Split code into {len(code_chunks)} chunks") + + # Calculate total tokens + logger.debug("Calculating total tokens") + total_tokens = sum(self.get_token_length(chunk) for chunk in code_chunks) + logger.debug(f"Total tokens: {total_tokens}") + + # Determine target_token based on rate if not specified + original_target_token = target_token # Store original value if provided + if target_token <= 0: + if rate <= 0: + # Default target if both rate and target_token are invalid + target_token = int(total_tokens * 0.5) + logger.warning(f"Rate and target_token invalid, defaulting target_token to {target_token}") + else: + target_token = int(total_tokens * rate) + logger.debug(f"Coarse Target tokens: {target_token}") + + # Use context budget control to select important functions + logger.debug("Selecting important functions using context budget control") + selected_contexts, selected_indices, dynamic_ratios, demonstrations_sort = self.control_context_budget( + code_chunks, + target_token=target_token, + question=query, + reorder_context="original", # Keep original order to maintain code structure + condition_in_question="prefix", + context_budget=context_budget, + dynamic_context_compression_ratio=dynamic_compression_ratio, + ) + + # If rank_only is True, just use the selected contexts without further compression + logger.debug("Using rank-only mode: selecting top functions without fine-grained compression") + compressed_chunks = [] + compressed_tokens = 0 + function_compressions = {} + + # Just keep the selected contexts as is + for i, chunk in enumerate(code_chunks): + if i in selected_indices: + compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + + # Store compression info - no actual compression in this mode + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + } + else: + # Skip this function completely + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_text = f"{comment_marker} ... 
" + compressed_chunks.append(omission_text) + compressed_tokens += self.get_token_length(omission_text) + + # Combine compressed chunks + compressed_code = "\n\n".join(compressed_chunks) + + # --- Post-join cleanup for consecutive omission markers --- + logger.debug("Cleaning up consecutive omission markers after joining...") + lines = compressed_code.split("\n") + cleaned_lines = [] + last_non_empty_line_was_omission = False + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_marker_content = f"{comment_marker} ...".strip() # Content to check against + + for line in lines: + stripped_line = line.strip() + if not stripped_line: + # Keep empty lines + cleaned_lines.append(line) + # Don't reset the flag here, wait for a non-empty line + elif stripped_line == omission_marker_content: + if last_non_empty_line_was_omission: + # Skip this consecutive omission marker line + logger.debug(f"Skipping line: '{line}' (consecutive omission)") + continue + else: + # Keep the first omission marker line + cleaned_lines.append(line) + last_non_empty_line_was_omission = True + else: + # Regular code line + cleaned_lines.append(line) + last_non_empty_line_was_omission = False + + compressed_code = "\n".join(cleaned_lines) + logger.debug("Cleanup finished.") + # --- End post-join cleanup --- + + + output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}" + + # Calculate actual compressed tokens + final_compressed_tokens = self.get_token_length(output) + + end_time = time.time() + logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}") + + if rank_only: + return { + "original_code": code, + "compressed_code": compressed_code, + "compressed_prompt": output, + "original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0, + "function_compressions": function_compressions, + "selected_functions": selected_indices, + "demonstrations_sort": demonstrations_sort, + "compressed_chunks": compressed_chunks, + "fine_grained_method_used": None, + } + else: + # enter fine-grained compression + logger.debug(f"Starting fine-grained compression on selected functions using method: {fine_grained_importance_method}") + + # --- Dynamic Fine-grained Rate Allocation --- + logger.debug("Calculating dynamic fine-grained compression rates...") + + # 1. Collect data for selected functions + selected_functions_data = [] + importance_map = {idx: score for idx, score in demonstrations_sort} # Map index to score + total_lines_selected = 0 + for i in selected_indices: + if i < len(code_chunks): + chunk = code_chunks[i] + # Use simple line splitting for allocation efficiency + lines = chunk.split("\n") + line_count = len(lines) + score = importance_map.get(i, 0.0) # Default score 0 if not found + selected_functions_data.append({ + "index": i, + "lines": lines, + "line_count": line_count, + "score": score + }) + total_lines_selected += line_count + else: + logger.warning(f"Selected index {i} is out of bounds for code_chunks (length {len(code_chunks)})") + + + # 2. 
Calculate overall fine-grained target lines + current_fine_ratio = fine_ratio if fine_ratio is not None else rate # Use rate if fine_ratio not set + if original_target_token > 0: # If target_token was set explicitly, derive ratio from it for fine-grained stage + # Estimate target lines based on the ratio of selected tokens to total tokens, then apply fine_ratio + selected_tokens = sum(self.get_token_length(code_chunks[d['index']]) for d in selected_functions_data) + effective_coarse_rate = selected_tokens / total_tokens if total_tokens > 0 else 1.0 + # Use the user-provided fine_ratio, or fall back to rate/coarse target estimate + fine_target_rate = current_fine_ratio + logger.debug(f"Using fine_ratio={fine_target_rate} for fine-grained target calculation.") + target_total_lines = int(total_lines_selected * fine_target_rate) + + else: # Calculate target based on fine_ratio/rate directly applied to selected lines + target_total_lines = int(total_lines_selected * current_fine_ratio) + logger.debug(f"Using current_fine_ratio={current_fine_ratio} for fine-grained target calculation.") + + logger.debug(f"Total lines in selected functions: {total_lines_selected}") + logger.debug(f"Target total lines after fine-grained compression: {target_total_lines}") + + # 3. Separate small and large functions + small_functions = [] + large_functions = [] + lines_in_small_functions = 0 + lines_in_large_functions = 0 + + for data in selected_functions_data: + if data["line_count"] < min_lines_for_fine_grained: + small_functions.append(data) + lines_in_small_functions += data["line_count"] + else: + large_functions.append(data) + lines_in_large_functions += data["line_count"] + + logger.debug(f"Found {len(small_functions)} small functions (< {min_lines_for_fine_grained} lines) with {lines_in_small_functions} total lines (will be kept).") + logger.debug(f"Found {len(large_functions)} large functions (>= {min_lines_for_fine_grained} lines) with {lines_in_large_functions} total lines.") + + # 4. Calculate target lines for large functions + target_lines_for_large = max(0, target_total_lines - lines_in_small_functions) + logger.debug(f"Target lines to keep from large functions: {target_lines_for_large}") + + # 5. Allocate rates for large functions + function_fine_ratios = {} # Map: index -> individual_fine_ratio + + if not large_functions or lines_in_large_functions == 0: + logger.debug("No large functions to compress further or zero lines in large functions.") + global_rate_for_large = 1.0 if lines_in_large_functions > 0 else 0.0 # Should be 0 if lines_in_large_functions is 0 + elif target_lines_for_large <= 0: + logger.debug("Target lines for large functions is <= 0. Setting rates to 0.") + global_rate_for_large = 0.0 + elif target_lines_for_large >= lines_in_large_functions: + logger.debug("Target lines for large functions >= total lines. 
Setting rates to 1.0.") + global_rate_for_large = 1.0 + else: + global_rate_for_large = target_lines_for_large / lines_in_large_functions + logger.debug(f"Global target rate for large functions: {global_rate_for_large:.4f}") + + # Normalize scores for weighting (MinMax scaling) + scores = [d["score"] for d in large_functions] + valid_scores = [s for s in scores if not math.isinf(s) and not math.isnan(s)] + + if not valid_scores or max(valid_scores) == min(valid_scores): + logger.debug("Scores are uniform or invalid, using global rate for all large functions.") + for data in large_functions: + function_fine_ratios[data["index"]] = global_rate_for_large + else: + min_score = min(valid_scores) + max_score = max(valid_scores) + normalized_scores = [(s - min_score) / (max_score - min_score) if not math.isinf(s) and not math.isnan(s) else 0.0 for s in scores] # Normalize to [0, 1], default 0 for invalid + + # Calculate initial biased rates + initial_rates = [] + for norm_score in normalized_scores: + # Bias rate: higher score -> higher rate (closer to 1) + # Beta controls sensitivity. beta=0 -> uniform rate. beta=1 -> max sensitivity. + biased_rate = global_rate_for_large * (1 + importance_beta * (norm_score - 0.5) * 2) # Scale norm_score diff to [-beta, beta] + clamped_rate = max(0.0, min(1.0, biased_rate)) # Clamp to [0, 1] + initial_rates.append(clamped_rate) + + # Calculate actual lines kept with initial rates + actual_lines_kept = sum(initial_rates[i] * large_functions[i]["line_count"] for i in range(len(large_functions))) + logger.debug(f"Initial biased rates calculated. Estimated lines kept: {actual_lines_kept:.1f}") + + # Adjust rates proportionally to meet target + if actual_lines_kept > 0 and abs(actual_lines_kept - target_lines_for_large) > 1: # Adjust if difference is significant + adjustment_factor = target_lines_for_large / actual_lines_kept + logger.debug(f"Adjusting rates by factor: {adjustment_factor:.4f}") + final_rates = [max(0.0, min(1.0, r * adjustment_factor)) for r in initial_rates] # Adjust and clamp again + else: + logger.debug("Initial rates are close enough or actual_lines_kept is 0, no adjustment needed.") + final_rates = initial_rates + + for i, data in enumerate(large_functions): + function_fine_ratios[data["index"]] = final_rates[i] + + # Set rate 1.0 for small functions + for data in small_functions: + function_fine_ratios[data["index"]] = 1.0 + + # --- End Dynamic Allocation --- + + + # Apply fine-grained compression to each selected function + fine_compressed_chunks = [] + compressed_tokens = 0 + function_compressions = {} + + # Define a smoothing window size for moving average + smoothing_window = 5 + # fine_ratio = fine_ratio if fine_ratio is not None else rate # Use the same ratio by default if fine_ratio not specified # Removed, using individual ratios now + + # Process each chunk in the original order + # Use tqdm.auto for compatibility + fine_grained_pbar = tqdm(enumerate(code_chunks), total=len(code_chunks), desc="Fine-Grained Compression", leave=False) + for i, chunk in fine_grained_pbar: + # for i, chunk in enumerate(code_chunks): + if i in selected_indices: + # This function was selected during coarse-grained compression + individual_fine_ratio = function_fine_ratios.get(i) # Get dynamically assigned ratio + if individual_fine_ratio is None: + logger.error(f"Missing fine-grained ratio for selected function index {i}. 
Skipping fine-grained compression for this chunk.") + individual_fine_ratio = 1.0 # Fallback to keep the chunk + + # Use Entropy chunking for fine-grained compression instead of simple line splitting + chunks, sentences, ppls, spike_indices = self.entropy_chunking.chunk_text_adaptive( + code_chunks[i], method='std', k=1.0 + ) + # Use chunks as lines, but preserve all chunks including empty ones to maintain formatting + chunk_lines = chunks # Keep all chunks to preserve \n\n and formatting + chunk_line_count = len([chunk for chunk in chunk_lines if chunk.strip()]) # Count only non-empty for logic + chunk_score = importance_map.get(i, float('nan')) # Get score + + logger.debug(f"Processing Func {i}: Entropy Chunks={len(chunk_lines)}, Non-empty={chunk_line_count}, Score={chunk_score:.4f}, Assigned FineRatio={individual_fine_ratio:.4f}") + + + # Skip fine-grained compression if ratio is 1.0 (or close) or function is small + if individual_fine_ratio >= 0.999 or chunk_line_count < min_lines_for_fine_grained: + note = "Kept (Ratio=1.0)" if individual_fine_ratio >= 0.999 else f"Kept (Small Func < {min_lines_for_fine_grained} lines)" + logger.debug(f" - {note}") + fine_compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + "individual_fine_ratio": individual_fine_ratio, + "note": note, + "importance_method": None # No line importance calculation needed + } + continue # Move to next chunk + + + # Apply fine-grained compression only if the function is large enough + # and we're not in rank-only mode (already checked) and ratio < 1.0 + if chunk_line_count >= min_lines_for_fine_grained and individual_fine_ratio < 0.999: + logger.debug(f" - Applying fine-grained compression with ratio {individual_fine_ratio:.4f}") + fine_grained_pbar.set_description(f"Fine-Grained Compressing Func {i}") + + # Calculate target tokens for this function + original_func_tokens = self.get_token_length(chunk) + target_func_tokens = int(original_func_tokens * individual_fine_ratio) + + # Calculate importance for each block based on the chosen method + block_importances = [] + importance_calculation_start = time.time() + + if fine_grained_importance_method == "conditional_ppl": + # Calculate conditional PPL importance for each block + if not query or not query.strip(): + logger.warning(f"Query is empty for func {i}, cannot calculate conditional PPL. Assigning 0 importance.") + block_importances = [0.0] * len(chunk_lines) + else: + query_ppl_result = self.get_ppl(query, granularity="line") + query_ppl_without_context = query_ppl_result["ppl"] + + if math.isinf(query_ppl_without_context): + logger.warning(f"Base query PPL is infinite for func {i}. 
Assigning 0 importance to blocks.") + block_importances = [0.0] * len(chunk_lines) + else: + pbar_cond = tqdm(enumerate(chunk_lines), total=len(chunk_lines), desc=f"Func {i} Block CondPPL", leave=False) + for block_idx, block in pbar_cond: + if not block.strip(): + block_importances.append(-float('inf')) # Low score for empty blocks + continue + + conditional_text = block + "\n\n" + query + prefix_len_text = block + "\n\n" + prefix_len = self.get_token_length(prefix_len_text, add_special_tokens=True) + + cond_ppl_result = self.get_ppl( + text=conditional_text, + granularity="line", + condition_mode="prefix", + condition_pos_id=prefix_len - 1 + ) + ppl_with_context = cond_ppl_result["ppl"] + + if math.isinf(ppl_with_context): + ppl_change = -float('inf') + else: + ppl_change = query_ppl_without_context - ppl_with_context + + block_importances.append(ppl_change) + pbar_cond.set_description(f"Func {i} Block CondPPL (B{block_idx}: {ppl_change:.2f})") + + elif fine_grained_importance_method == "contrastive_perplexity": + # Calculate contrastive PPL importance for each block + fine_grained_pbar.set_description(f"Fine-Grained ContrastivePPL Func {i}") + + with torch.no_grad(): + pbar = tqdm(enumerate(chunk_lines), total=len(chunk_lines), desc="Block Contrastive PPL", leave=False) + for block_idx, block in pbar: + if not block.strip(): + block_importances.append(-float('inf')) + continue + + # Build context from previous blocks + prev_context = "\n\n".join(chunk_lines[:block_idx]) if block_idx > 0 else "" + + # 1. PPL(Block | prev_blocks) + regular_ppl_condition = prev_context + "\n\n" if prev_context else None + regular_ppl = self._calculate_perplexity_for_contrastive(block, condition_text=regular_ppl_condition) + + # 2. PPL(Block | query, prev_blocks) + question_context_parts = [query] + if prev_context: + question_context_parts.append(prev_context) + question_context = "\n\n".join(filter(None, question_context_parts)) + cond_ppl_condition = question_context + "\n\n" + cond_ppl = self._calculate_perplexity_for_contrastive(block, condition_text=cond_ppl_condition) + + # 3. 
Importance = PPL(Block|prev) - PPL(Block|Q,prev) + if math.isinf(regular_ppl) or math.isinf(cond_ppl): + importance = -float('inf') + else: + importance = regular_ppl - cond_ppl + + block_importances.append(importance) + pbar.set_description(f"Block {block_idx}: {importance:.2f}") + + else: + raise ValueError(f"Unsupported fine_grained_importance_method: {fine_grained_importance_method}") + + importance_calculation_end = time.time() + logger.debug(f" - Block importance calculation took {importance_calculation_end - importance_calculation_start:.2f}s") + + # Identify preserved blocks (function signature, comments, returns) + preserved_block_indices = set() + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + + # Find blocks containing function signature + for block_idx, block in enumerate(chunk_lines): + block_lines = block.split('\n') + for line in block_lines: + if line.strip(): + # Check for function/class definitions + if any(keyword in line for keyword in ['def ', 'class ', 'function ', 'fn ', 'func ']): + preserved_block_indices.add(block_idx) + break + # Check for function-level comments + if line.strip().startswith(comment_marker): + preserved_block_indices.add(block_idx) + break + # Check for return statements + if 'return ' in line: + preserved_block_indices.add(block_idx) + break + break # Only check first non-empty line of each block + + # Choose selection method based on use_knapsack parameter + processing_start = time.time() + + if use_knapsack: + # Use knapsack algorithm to select blocks + logger.debug(f" - Using knapsack algorithm for block selection") + selected_block_indices, selection_info = self._knapsack_block_selection( + blocks=chunk_lines, + block_importances=block_importances, + target_tokens=target_func_tokens, + preserved_block_indices=preserved_block_indices, + language=language + ) + + # Build compressed chunk from selected blocks + compressed_blocks = [] + + # Determine base indentation for omission markers + base_indentation = "" + for block in chunk_lines: + for line in block.split('\n'): + if line.strip(): + match = re.match(r"^(\s*)", line) + if match: + base_indentation = match.group(1) + break + if base_indentation: + break + + omission_marker = f"{base_indentation}{comment_marker} ... 
" + + # Build output with omission markers for gaps + last_selected_idx = -1 + for block_idx in sorted(selected_block_indices): + # Add omission marker if there's a gap + if last_selected_idx != -1 and block_idx > last_selected_idx + 1: + if not compressed_blocks or compressed_blocks[-1] != omission_marker: + compressed_blocks.append(omission_marker) + + compressed_blocks.append(chunk_lines[block_idx]) + last_selected_idx = block_idx + + # Handle trailing omission if needed + if last_selected_idx != -1 and last_selected_idx < len(chunk_lines) - 1: + if not compressed_blocks or compressed_blocks[-1] != omission_marker: + compressed_blocks.append(omission_marker) + + # Join blocks with double newlines to preserve Entropy chunk structure + compressed_chunk = "\n\n".join(compressed_blocks) + + else: + # Use original greedy line-by-line approach with smoothing + logger.debug(f" - Using original greedy line-by-line approach") + + # Convert block importances to line importances for compatibility + lines = [] + line_importances = [] + line_indices = [] + + for block_idx, (block, block_importance) in enumerate(zip(chunk_lines, block_importances)): + block_lines = block.split('\n') + for line_idx_in_block, line in enumerate(block_lines): + global_line_idx = len(lines) + lines.append(line) + line_importances.append(block_importance) # Use block importance for all lines in block + line_indices.append(global_line_idx) + + # Apply original processing logic with smoothing + full_line_scores = [float('nan')] * len(lines) + for score_idx, original_line_idx in enumerate(line_indices): + if score_idx < len(line_importances): + full_line_scores[original_line_idx] = line_importances[score_idx] + + # Replace NaN/Inf with min valid score for consistent processing + valid_scores = [s for s in full_line_scores if not math.isnan(s) and not math.isinf(s)] + if valid_scores: + min_valid_score = min(valid_scores) + if min_valid_score == float('inf') or min_valid_score == -float('inf') or math.isnan(min_valid_score): + min_replacement_score = 0.0 + else: + min_replacement_score = min_valid_score + + processed_line_scores = [] + for s in full_line_scores: + if math.isnan(s) or s == -float('inf'): + processed_line_scores.append(min_replacement_score) + elif s == float('inf'): + processed_line_scores.append(min_replacement_score) + else: + processed_line_scores.append(s) + else: + processed_line_scores = [0.0] * len(lines) + + # Apply smoothing using moving average + smoothing_window = 5 + smoothed_importances = processed_line_scores.copy() + num_processed_scores = len(processed_line_scores) + for j in range(num_processed_scores): + window_start = max(0, j - smoothing_window // 2) + window_end = min(num_processed_scores, j + smoothing_window // 2 + 1) + window = processed_line_scores[window_start:window_end] + valid_window_scores = [s for s in window if not math.isnan(s) and not math.isinf(s)] + if valid_window_scores: + smoothed_importances[j] = sum(valid_window_scores) / len(valid_window_scores) + + # Find preserved lines (convert block indices to line indices) + preserved_line_indices = set() + line_offset = 0 + for block_idx, block in enumerate(chunk_lines): + block_lines = block.split('\n') + if block_idx in preserved_block_indices: + for line_idx_in_block in range(len(block_lines)): + preserved_line_indices.add(line_offset + line_idx_in_block) + line_offset += len(block_lines) + + # Sort remaining lines by importance + sortable_lines = [] + for idx in range(len(lines)): + if idx not in preserved_line_indices: + if 
idx < len(line_indices) and idx < len(line_importances): + original_score = line_importances[idx] + if not math.isnan(original_score) and not math.isinf(original_score): + smoothed_score = smoothed_importances[idx] + sortable_lines.append((idx, smoothed_score)) + + # Sort descending by score + sorted_line_indices = sorted(sortable_lines, key=lambda x: -x[1]) + + # Calculate target number of lines to keep + total_lines = len(lines) + preserved_count = len(preserved_line_indices) + target_lines = max(preserved_count, int(total_lines * individual_fine_ratio)) + + # Select top lines by importance up to target + selected_lines = set(preserved_line_indices) + for idx, score in sorted_line_indices: + if len(selected_lines) >= target_lines: + break + selected_lines.add(idx) + + # Build compressed chunk from selected lines + compressed_chunks = [] + base_indentation = "" + if lines: + for line in lines: + if line.strip(): + match = re.match(r"^(\s*)", line) + if match: + base_indentation = match.group(1) + break + + omission_marker_line = f"{base_indentation}{comment_marker} ... " + + last_added_line_idx = -1 + for j in range(len(lines)): + if j in selected_lines: + if last_added_line_idx != -1 and j > last_added_line_idx + 1: + if not compressed_chunks or compressed_chunks[-1] != omission_marker_line: + compressed_chunks.append(omission_marker_line) + compressed_chunks.append(lines[j]) + last_added_line_idx = j + + if last_added_line_idx != -1 and last_added_line_idx < len(lines) - 1: + if not compressed_chunks or compressed_chunks[-1] != omission_marker_line: + compressed_chunks.append(omission_marker_line) + + compressed_chunk = "\n".join(compressed_chunks) + + # Create selection info for compatibility + selection_info = { + "method": "greedy_line_by_line", + "preserved_lines": len(preserved_line_indices), + "selected_lines": len(selected_lines), + "total_lines": len(lines), + "smoothing_applied": True + } + selected_block_indices = preserved_block_indices # For compatibility + + processing_end = time.time() + method_name = "knapsack" if use_knapsack else "greedy" + logger.debug(f" - {method_name} selection took {processing_end - processing_start:.2f}s") + + if use_knapsack: + logger.debug(f" - Selected {len(selected_block_indices)}/{len(chunk_lines)} blocks") + else: + logger.debug(f" - Selected {len(selected_lines)}/{len(lines)} lines") + + # Update token count and store compression info + fine_compressed_chunks.append(compressed_chunk) + compressed_chunk_tokens = self.get_token_length(compressed_chunk) + compressed_tokens += compressed_chunk_tokens + + # Store compression info + actual_compression_ratio = compressed_chunk_tokens / original_func_tokens if original_func_tokens > 0 else 1.0 + function_compressions[i] = { + "original_tokens": original_func_tokens, + "compressed_tokens": compressed_chunk_tokens, + "compression_ratio": actual_compression_ratio, + "individual_fine_ratio": individual_fine_ratio, + "preserved_blocks": list(preserved_block_indices), + "selected_blocks": list(selected_block_indices), + "selection_info": selection_info, + "importance_method": fine_grained_importance_method, + "selection_method": "knapsack" if use_knapsack else "greedy_line_by_line" + } + logger.debug(f" - Compressed func {i}: {original_func_tokens} -> {compressed_chunk_tokens} tokens (Ratio: {actual_compression_ratio:.3f})") + else: + # This case should now be handled by the check at the beginning of the loop + logger.warning(f"Reached unexpected state for func {i}. 
Keeping chunk as is.") + fine_compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + "individual_fine_ratio": individual_fine_ratio, + "note": "Unexpected state, kept function.", + "importance_method": None + } + + else: + # This function was not selected during coarse-grained compression + # Add a placeholder + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_text = f"{comment_marker} ... " + fine_compressed_chunks.append(omission_text) + compressed_tokens += self.get_token_length(omission_text) + # Log skipped chunk + # logger.debug(f"Skipped Func {i} (not selected in coarse stage)") + + + # Combine fine-grained compressed chunks + compressed_code = "\n\n".join(fine_compressed_chunks) + + # --- Post-join cleanup for consecutive omission markers --- + logger.debug("Cleaning up consecutive omission markers after joining...") + lines = compressed_code.split("\n") + cleaned_lines = [] + last_non_empty_line_was_omission = False + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_marker_content = f"{comment_marker} ...".strip() # Content to check against + + for line in lines: + stripped_line = line.strip() + if not stripped_line: + # Keep empty lines + cleaned_lines.append(line) + # Don't reset the flag here, wait for a non-empty line + elif stripped_line == omission_marker_content: + if last_non_empty_line_was_omission: + # Skip this consecutive omission marker line + logger.debug(f"Skipping line: '{line}' (consecutive omission)") + continue + else: + # Keep the first omission marker line + cleaned_lines.append(line) + last_non_empty_line_was_omission = True + else: + # Regular code line + cleaned_lines.append(line) + last_non_empty_line_was_omission = False + + compressed_code = "\n".join(cleaned_lines) + logger.debug("Cleanup finished.") + # --- End post-join cleanup --- + + + # Ensure instruction/query parts are handled correctly, maybe use a template + prompt_parts = [] + if instruction and instruction.strip(): + prompt_parts.append(instruction.strip()) + if compressed_code.strip(): + prompt_parts.append(compressed_code) # Already has newlines handled + if query and query.strip(): + # Add query, potentially repeating instruction based on original logic + prompt_parts.append(query.strip()) + # Decide if instruction should be repeated after query based on original implementation's needs + # if instruction and instruction.strip(): # Repeat instruction if needed + # prompt_parts.append(instruction.strip()) + + output = "\n\n".join(prompt_parts) # Use double newline separation + + # Calculate final compressed tokens + final_compressed_tokens = self.get_token_length(output) + + end_time = time.time() + logger.debug(f"Fine-grained compression processing completed in {end_time - start_time:.2f} seconds") + final_compression_ratio = compressed_tokens / total_tokens if total_tokens > 0 else 1.0 + logger.debug(f"Final Compression ratio (fine-grained tokens / total original tokens): {final_compression_ratio:.4f}") + + + return { + "original_code": code, + "compressed_code": compressed_code, + "compressed_prompt": output, + "original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": final_compression_ratio, + 
"function_compressions": function_compressions, + "selected_functions": selected_indices, + "demonstrations_sort": demonstrations_sort, + "compressed_chunks": fine_compressed_chunks, + "fine_grained_method_used": fine_grained_importance_method, + } + + def split_code_by_functions(self, code: str, language: str = "python", custom_separator: str = "# --CHUNK_SEPARATOR-- #") -> List[str]: + """ + Split code into chunks based on function and class definitions for various languages. + Also splits on custom separator if provided. + + Args: + code: The code to split + language: Programming language of the code (python, cpp, java, typescript, rust, go) + custom_separator: Optional custom separator string to also split on + + Returns: + List of code chunks, each containing a function, class, or class method + """ + logger.debug(f"Splitting code by functions and classes for language: {language}") + start_time = time.time() + + # Define regex patterns for different languages + patterns = { + # Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end + "python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*', + # C++: Improved to better handle multi-line declarations + "cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Java: Improved for multi-line method declarations + "java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # TypeScript: Enhanced to handle multi-line methods and arrow functions + "typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?', + # Rust: Improved for multi-line function declarations + "rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Go: Improved for multi-line function declarations + "go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + } + + # Use default Python pattern if language not supported + if language.lower() not in patterns: + language = "python" + + # First check if we need to split by custom separator + separator_chunks = [] + if custom_separator and custom_separator in code: + logger.debug(f"Custom separator '{custom_separator}' found, first splitting by separator") + separator_chunks = [chunk for chunk in code.split(custom_separator) if chunk.strip()] + else: + separator_chunks = [code] # Just one chunk - the entire code + + # Function to split a single chunk by functions/classes + def split_chunk_by_pattern(chunk_code): + function_pattern = re.compile(patterns[language.lower()], re.MULTILINE) + matches = list(function_pattern.finditer(chunk_code)) + + if not matches: + return [chunk_code] # No matches, return 
whole chunk + + result_chunks = [] + + # Add code before first match + if matches[0].start() > 0: + result_chunks.append(chunk_code[:matches[0].start()]) + + # Process each match + for i, match in enumerate(matches): + start = match.start() + + # End is either start of next match or end of code + if i < len(matches) - 1: + end = matches[i + 1].start() + else: + end = len(chunk_code) + + result_chunks.append(chunk_code[start:end]) + + return result_chunks + + # Now apply function/class splitting to each separator chunk + final_chunks = [] + for chunk in separator_chunks: + function_chunks = split_chunk_by_pattern(chunk) + final_chunks.extend(function_chunks) + + end_time = time.time() + logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Split code into {len(final_chunks)} chunks (using both separator and patterns)") + + return final_chunks + + def _calculate_perplexity_for_contrastive(self, text, condition_text=None): + """Helper to calculate perplexity of text, optionally conditioned on condition_text""" + if condition_text: + full_text = condition_text + text + inputs = self.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(self.device) # Use add_special_tokens=True for consistency + + condition_input_ids = self.tokenizer(condition_text, return_tensors="pt", add_special_tokens=True).input_ids + condition_length = condition_input_ids.size(1) + + # Handle potential edge case where condition length might exceed max length or input length + if condition_length >= inputs.input_ids.size(1): + logger.warning(f"Condition length ({condition_length}) >= input length ({inputs.input_ids.size(1)}). Cannot calculate conditional PPL.") + return float('inf') + + with torch.no_grad(): + outputs = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) # Pass attention_mask + + # Logits for the 'text' part, labels are the 'text' part shifted + logits = outputs.logits[0, condition_length-1:-1] + labels = inputs.input_ids[0, condition_length:] + + if logits.size(0) == 0 or labels.size(0) == 0 or logits.size(0) != labels.size(0): + logger.warning(f"Logits/Labels shape mismatch or empty in _calculate_perplexity_for_contrastive (cond). Logits: {logits.shape}, Labels: {labels.shape}. Returning inf.") + return float('inf') # Return inf if shapes mismatch or empty + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + mean_loss = loss.mean().item() + perplexity = math.exp(mean_loss) if not math.isnan(mean_loss) and not math.isinf(mean_loss) else float('inf') + + else: + # Calculate unconditional perplexity + inputs = self.tokenizer(text, return_tensors="pt", add_special_tokens=True).to(self.device) # Use add_special_tokens=True + with torch.no_grad(): + outputs = self.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) # Pass attention_mask + + # Logits for all tokens except last, labels are all tokens except first + logits = outputs.logits[0, :-1] + labels = inputs.input_ids[0, 1:] + + if logits.size(0) == 0 or labels.size(0) == 0 or logits.size(0) != labels.size(0): + logger.warning(f"Logits/Labels shape mismatch or empty in _calculate_perplexity_for_contrastive (uncond). Logits: {logits.shape}, Labels: {labels.shape}. 
Returning inf.") + return float('inf') # Return inf if shapes mismatch or empty + + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + mean_loss = loss.mean().item() + perplexity = math.exp(mean_loss) if not math.isnan(mean_loss) and not math.isinf(mean_loss) else float('inf') + + return perplexity + + def _calculate_contrastive_perplexity(self, code_lines: List[str], question: str): + """ + Calculate contrastive perplexity-based importance for each line of code. + s_i = perplexity(x_i | x_{ Tuple[set, Dict]: + """ + Use knapsack algorithm to select blocks that maximize total importance within token budget. + + Args: + blocks: List of code blocks (Entropy chunks) + block_importances: Importance scores for each block + target_tokens: Target number of tokens to keep + preserved_block_indices: Set of block indices that must be preserved + language: Programming language for omission markers + + Returns: + Tuple of (selected_block_indices, selection_info) + """ + logger.debug(f"Running knapsack block selection with target_tokens={target_tokens}") + + if not blocks: + return set(), {} + + # Calculate token weights for each block + block_weights = [self.get_token_length(block) for block in blocks] + + # Handle preserved blocks + if preserved_block_indices is None: + preserved_block_indices = set() + + # Calculate tokens already used by preserved blocks + preserved_tokens = sum(block_weights[i] for i in preserved_block_indices) + remaining_budget = max(0, target_tokens - preserved_tokens) + + logger.debug(f"Preserved blocks: {len(preserved_block_indices)}, tokens: {preserved_tokens}") + logger.debug(f"Remaining budget for knapsack: {remaining_budget}") + + # If no remaining budget, just return preserved blocks + if remaining_budget <= 0: + return preserved_block_indices, { + "method": "knapsack", + "preserved_only": True, + "total_value": sum(block_importances[i] for i in preserved_block_indices), + "total_weight": preserved_tokens + } + + # Prepare items for knapsack (excluding preserved blocks) + knapsack_items = [] + for i, (weight, value) in enumerate(zip(block_weights, block_importances)): + if i not in preserved_block_indices: + # Handle invalid importance scores + if math.isnan(value) or math.isinf(value): + value = 0.0 + knapsack_items.append((i, weight, value)) + + # Sort by value-to-weight ratio for efficiency (greedy approximation first) + knapsack_items.sort(key=lambda x: x[2] / max(x[1], 1), reverse=True) + + # Use dynamic programming for exact knapsack solution + # For efficiency, limit to reasonable problem size + if len(knapsack_items) <= 100 and remaining_budget <= 2000: + selected_indices = self._solve_knapsack_dp(knapsack_items, remaining_budget) + else: + # Use greedy approximation for large problems + logger.debug("Using greedy approximation for large knapsack problem") + selected_indices = self._solve_knapsack_greedy(knapsack_items, remaining_budget) + + # Combine with preserved blocks + final_selection = preserved_block_indices.union(selected_indices) + + # Calculate selection statistics + total_value = sum(block_importances[i] for i in final_selection) + total_weight = sum(block_weights[i] for i in final_selection) + + selection_info = { + "method": "knapsack", + "preserved_blocks": len(preserved_block_indices), + "selected_blocks": len(selected_indices), + "total_blocks": len(final_selection), + "total_value": total_value, + "total_weight": total_weight, + "target_weight": target_tokens, + 
"efficiency": total_value / max(total_weight, 1) + } + + logger.debug(f"Knapsack selection: {len(final_selection)}/{len(blocks)} blocks, " + f"value={total_value:.2f}, weight={total_weight}/{target_tokens}") + + return final_selection, selection_info + + def _solve_knapsack_dp(self, items: List[Tuple[int, int, float]], capacity: int) -> set: + """ + Solve knapsack problem using dynamic programming. + + Args: + items: List of (index, weight, value) tuples + capacity: Maximum weight capacity + + Returns: + Set of selected item indices + """ + n = len(items) + if n == 0 or capacity <= 0: + return set() + + # DP table: dp[i][w] = maximum value using first i items with weight limit w + dp = [[0.0 for _ in range(capacity + 1)] for _ in range(n + 1)] + + # Fill DP table + for i in range(1, n + 1): + idx, weight, value = items[i - 1] + for w in range(capacity + 1): + # Don't take item i + dp[i][w] = dp[i - 1][w] + + # Take item i if it fits + if weight <= w: + dp[i][w] = max(dp[i][w], dp[i - 1][w - weight] + value) + + # Backtrack to find selected items + selected = set() + w = capacity + for i in range(n, 0, -1): + if dp[i][w] != dp[i - 1][w]: + idx, weight, value = items[i - 1] + selected.add(idx) + w -= weight + + return selected + + def _solve_knapsack_greedy(self, items: List[Tuple[int, int, float]], capacity: int) -> set: + """ + Solve knapsack problem using greedy approximation (by value/weight ratio). + + Args: + items: List of (index, weight, value) tuples (should be pre-sorted by ratio) + capacity: Maximum weight capacity + + Returns: + Set of selected item indices + """ + selected = set() + current_weight = 0 + + for idx, weight, value in items: + if current_weight + weight <= capacity: + selected.add(idx) + current_weight += weight + + return selected + +if __name__ == "__main__": + # Load real examples from the dataset + with open("data/data.jsonl", "r") as f: + data = [json.loads(line) for line in f] + + example = data[190] + # print(example.keys()) # dict_keys(['id', 'gt', 'original_background_context', 'original_current_function_context', 'language', 'prompt', 'output', 'es', 'em']) + + context = example["original_background_context"] + question = example["original_current_function_context"] + ground_truth = example["gt"] + + # Initialize compressor + logger.info("Initializing compressor...") + model_name = "Qwen/Qwen2.5-Coder-7B-Instruct" + compressor = CodeCompressor(model_name=model_name) + + # Test function-based code file compression with query + logger.info("\nTesting function-based code file compression with query...") + + original_tokens = len(compressor.tokenizer.encode(context)) + target_token = 512 + target_ratio = min(1.0, max(0.0, target_token / original_tokens)) + logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}") + + result = compressor.compress_code_file( + code=context, + query=question, # Using current function context as query focus + instruction="Complete the following code function given the context.", + rate=target_ratio, + rank_only=False, # Test fine-grained compression + fine_grained_importance_method="contrastive_perplexity", # Explicitly test default + min_lines_for_fine_grained=5, # New parameter + importance_beta=0.5, # Sensitivity to importance score + use_knapsack=True, + ) + + # show the compressed code + logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}") + logger.info(f"Current function context: \n{question}") + 
# final prompt + final_prompt = result['compressed_prompt'] + # get the completion + try: + tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device) + # Increase max_new_tokens for potentially longer completions + completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id) + # Decode only the generated part, skipping special tokens + completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True) + + # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed + completion = completion.strip() + # More robust cleanup: Find the first meaningful line if generation includes noise + completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal + cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found + + except Exception as e: + logger.error(f"Error during generation or decoding: {e}") + cleaned_completion = "[ERROR DURING GENERATION]" + + logger.info(f"Cleaned Completion: {cleaned_completion}") + logger.info(f"Ground truth: {ground_truth}") + + # Optional: Test with conditional_ppl method + logger.info("\nTesting fine-grained compression with conditional_ppl...") + result_cond = compressor.compress_code_file( + code=context, + query=question, + instruction="Complete the following code function given the context.", + rate=target_ratio, + rank_only=False, + fine_grained_importance_method="conditional_ppl", + min_lines_for_fine_grained=5, + importance_beta=0.5 + ) + logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}") \ No newline at end of file diff --git a/module_summarization/main.py b/module_summarization/main.py new file mode 100644 index 0000000..260c26b --- /dev/null +++ b/module_summarization/main.py @@ -0,0 +1,1318 @@ +import argparse +import os +import torch +import numpy as np +import gc +import json +from tqdm import tqdm +from vllm import LLM, EngineArgs, SamplingParams +from transformers import AutoTokenizer, AutoModel +from loguru import logger +from openai import OpenAI +from utils import truncate_text, load_dataset_samples +import fire +from llmlingua import PromptCompressor +from code_compressor import CodeCompressor +import asyncio +from itertools import cycle + +class LLMGenerator: + def __init__(self, model_name, device, **model_args): + # Create a vllm LLM instance + engine_args = EngineArgs(model=model_name, device=device, **model_args) + self.model = LLM(**vars(engine_args)) + self.model_name = model_name + self.device = device + # Use the tokenizer from the model to ensure consistency + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + def generate(self, prompt, max_tokens=2048, temperature=0.0): + logger.debug(f"Generation input prompt: {truncate_text(prompt)}") + + # Convert to chat format + conversation = [ + {"role": "system", "content": "You are a documentation generating assistant specialized in code understanding."}, + {"role": "user", "content": prompt} + ] + + sampling_params = SamplingParams( + max_tokens=max_tokens, + temperature=temperature, + top_p=1.0, + top_k=50, + ) + + outputs = self.model.chat(conversation, sampling_params, use_tqdm=False) + result = outputs[0].outputs[0].text + + logger.debug(f"Generation output: 
{truncate_text(result)}") + return result + + def free_memory(self): + """Release model resources to free GPU memory""" + del self.model + torch.cuda.empty_cache() + gc.collect() + + +class LLMScorer: + def __init__(self, model_name, device, **model_args): + # Create a vllm LLM instance + engine_args = EngineArgs(model=model_name, device=device, **model_args) + self.model = LLM(**vars(engine_args)) + self.model_name = model_name + self.device = device + # Use the tokenizer from the model to ensure consistency + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + + def score_options(self, query, options): + # Convert to a chat format query + conversation = [ + {"role": "system", "content": "You are a code quality assessing engine."}, + {"role": "user", "content": query} + ] + + logger.debug(f"Scoring input query: {truncate_text(query)}") + logger.debug(f"Scoring options: {options}") + + sampling_params = SamplingParams( + max_tokens=1, + temperature=0.3, + logprobs=20, + ) + + # Get the completion with logprobs + outputs = self.model.chat(conversation, sampling_params, use_tqdm=False) + output = outputs[0].outputs[0] + + # Debug output structure + logger.debug(f"Output structure: {type(output)}") + logger.debug(f"Logprobs structure: {type(output.logprobs)}") + + # Extract logprobs for the options + logprobs = torch.zeros(len(options)) + found_options = set() + + # Convert options to lowercase for case-insensitive matching + option_map = {opt.lower(): i for i, opt in enumerate(options)} + + # Extract logprobs from the output + for token_dict in output.logprobs: + # Each item is a dictionary with token_id -> Logprob object + for _, logprob_obj in token_dict.items(): + try: + # Directly access the token and logprob attributes + token = logprob_obj.decoded_token.strip().lower() + logprob_value = logprob_obj.logprob + + # Check if this token matches one of our options + if token in option_map and option_map[token] not in found_options: + logprobs[option_map[token]] = logprob_value + found_options.add(option_map[token]) + logger.debug(f"Found option: {token} with logprob: {logprob_value}") + + except AttributeError: + # If the object doesn't have the expected attributes, skip it + continue + except Exception as e: + logger.error(f"Error processing token: {e}") + continue + + # Special case for options A and B + if not found_options and len(output.logprobs) > 0: + for token_dict in output.logprobs: + for _, logprob_obj in token_dict.items(): + try: + # Check specifically for A or B tokens + token = logprob_obj.decoded_token.strip().lower() + + if token in ['a', 'b'] and option_map.get(token) not in found_options: + logprobs[option_map[token]] = logprob_obj.logprob + found_options.add(option_map[token]) + logger.debug(f"Found exact option: {token.upper()} with logprob: {logprob_obj.logprob}") + except Exception as e: + logger.error(f"Error processing token for A/B check: {e}") + continue + + # If some options weren't found, assign a very low logprob + min_prob = logprobs[list(found_options)].min().item() if found_options else -100 + for i in range(len(options)): + if i not in found_options: + logprobs[i] = min_prob - 2.3 # approximately 10 times less + + logger.debug(f"Final scoring output logprobs: {logprobs}") + + return logprobs + + def free_memory(self): + """Release model resources to free GPU memory""" + del self.model + torch.cuda.empty_cache() + gc.collect() + + +class GPTScorer: + def __init__(self, model_name="gpt-4o-mini", **model_args): + self.model_name = model_name + # Use 
transformers tokenizer instead of tiktoken + self.tokenizer = AutoTokenizer.from_pretrained("gpt2") # Using gpt2 tokenizer as a good approximation + + # Array of API tokens for rotation + self.api_tokens = [ + "your_api_key" + ] + self.token_iterator = cycle(self.api_tokens) + + # Initialize OpenAI client with the first token + self.current_token = next(self.token_iterator) + self.client = OpenAI( + api_key=self.current_token + ) + logger.debug(f"Initialized GPTScorer with model: {model_name}") + + def rotate_token(self): + """Rotate to the next API token""" + self.current_token = next(self.token_iterator) + self.client = OpenAI( + api_key=self.current_token + ) + logger.debug(f"Rotated to next API token") + + def score_options(self, query, options): + logger.debug(f"Scoring input query: {truncate_text(query)}") + logger.debug(f"Scoring options: {options}") + + # Create logit bias to prioritize the option tokens + logit_bias = dict() + for opt in options: + # Use transformers tokenizer + tok_ids = self.tokenizer.encode(opt, add_special_tokens=False) + if len(tok_ids) == 1: + logit_bias[tok_ids[0]] = 100 + else: + logger.warning(f"Option '{opt}' encodes to multiple tokens {tok_ids}, using first token only") + logit_bias[tok_ids[0]] = 100 + + # Try up to 3 times with token rotation on failure + for attempt in range(3): + try: + # Call the OpenAI API + completion = self.client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "system", "content": "You are a code quality assessing engine."}, + {"role": "user", "content": query}, + ], + max_tokens=1, + temperature=0.3, + n=1, + logprobs=True, + top_logprobs=20, + logit_bias=logit_bias + ) + + # Process the results + logprobs = np.full(len(options), np.nan) + choice = completion.choices[0] + logger.debug(f"Choice: {choice}") + opt_to_idx = {t: n for n, t in enumerate(options)} + min_lp = 0 + + try: + for logprob_item in choice.logprobs.content[0].top_logprobs: + tok = logprob_item.token + lp = logprob_item.logprob + min_lp = min(min_lp, lp) + if tok in opt_to_idx: + logprobs[opt_to_idx[tok]] = lp + + # If any options weren't found, assign them a low probability + logprobs[np.isnan(logprobs)] = min_lp - 2.3 + assert not np.isnan(logprobs).any() + break # Success, exit retry loop + except Exception as e: + logger.error(f"Error processing logprobs: {e}") + # Return equal logprobs in case of error + return torch.zeros(len(options)) + + except Exception as e: + logger.warning(f"API call failed (attempt {attempt+1}/3): {e}") + # Rotate token on failure + self.rotate_token() + if attempt == 2: # Last attempt failed + logger.error("All API attempts failed") + return torch.zeros(len(options)) + + logger.debug(f"Final scoring output logprobs: {logprobs}") + return torch.from_numpy(logprobs) + + async def async_score_options(self, query, options): + """Asynchronous version of score_options that runs in a thread pool""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self.score_options, query, options + ) + + def free_memory(self): + """Release any resources""" + # Nothing to free for API-based model + pass + + +# Helper function for sliding window chunking +def chunk_sliding_window(code: str, window_size: int, overlap: int) -> list[str]: + """Splits code into overlapping chunks using a sliding window.""" + lines = code.splitlines() + if not lines: + return [] + + chunks = [] + start = 0 + stride = window_size - overlap + if stride <= 0: + raise ValueError("Overlap size must be smaller than window 
size.") + + while True: + end = min(start + window_size, len(lines)) + chunk_lines = lines[start:end] + if not chunk_lines: # Should not happen if lines is not empty, but safety check + break + chunks.append("\n".join(chunk_lines)) + if end == len(lines): + break # Exit loop if we reached the end + next_start = start + stride + # If the next window would go past the end, break + if next_start >= len(lines): + # Add the final overlapping chunk if needed + final_start = max(0, len(lines) - window_size) + if final_start > start: # Ensure it's a new chunk not already added + final_chunk_lines = lines[final_start:] + chunks.append("\n".join(final_chunk_lines)) + break + start = next_start + + # Handle case where code is shorter than window size + if not chunks and lines: + return ["\n".join(lines)] + + # Remove duplicates while preserving order (important for RAG) + seen = set() + unique_chunks = [] + for chunk in chunks: + if chunk not in seen: + seen.add(chunk) + unique_chunks.append(chunk) + + return unique_chunks + + +# Helper function to compute embeddings (using mean pooling) +def compute_embedding(text: str, model, tokenizer, device) -> torch.Tensor: + """Computes sentence embedding for a text using the provided model.""" + if not text.strip(): # Handle empty strings + return torch.zeros(model.config.hidden_size).to(device) + inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device) + with torch.no_grad(): + outputs = model(**inputs) + # Mean pool the last hidden state + embedding = outputs.last_hidden_state.mean(dim=1).squeeze() + return embedding + + +# Helper function for RAG retrieval +def rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, window_size: int, overlap: int, top_k: int) -> str: + """Chunks background, embeds chunks and query, retrieves top_k similar chunks.""" + if not background_code.strip(): + return "" # Return empty if no background context + + chunks = chunk_sliding_window(background_code, window_size, overlap) + if not chunks: + return "" # Return empty if chunking results in nothing + + query_embedding = compute_embedding(query_code, model, tokenizer, device) + + chunk_embeddings = [] + valid_chunks = [] + for chunk in chunks: + if chunk.strip(): + chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device)) + valid_chunks.append(chunk) + + if not valid_chunks: + return "" + + # Stack embeddings for efficient similarity calculation + chunk_embeddings_tensor = torch.stack(chunk_embeddings) + + # Compute cosine similarity + similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1) + + # Get top_k indices + top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices + + # Retrieve and sort chunks by their original position + relevant_chunks_with_indices = [] + original_indices_map = {chunk_content: idx for idx, chunk_content in enumerate(chunks)} # Map content back to original index + + retrieved_chunk_contents = [valid_chunks[i] for i in top_k_indices.tolist()] + + # Find original start lines to sort chronologically (approximate) + chunk_start_lines = {} + current_line = 0 + lines = background_code.splitlines() + chunk_map_from_sliding = chunk_sliding_window(background_code, window_size, overlap) # Re-chunk to get consistent indexing if needed + start_line_num = 0 + stride = window_size - overlap + for i, chunk_content in enumerate(chunk_map_from_sliding): + # This assumes the chunking function returns chunks 
in order + chunk_start_lines[chunk_content] = start_line_num + start_line_num += stride + # Rough approximation, doesn't perfectly handle edge cases/final chunks + + sorted_relevant_chunks = sorted( + retrieved_chunk_contents, + key=lambda content: chunk_start_lines.get(content, float('inf')) # Sort by approximate start line + ) + + # Combine relevant chunks + # Original implementation joined with \n, let's keep it simple + combined_code = "\n\n".join(sorted_relevant_chunks) # Separate chunks by double newline for clarity + + return combined_code + + +# Helper function for LLMLingua compression +def compress_llmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str) -> str: + """Compresses context using LLMLingua.""" + if not context.strip(): + return "" + try: + # Ensure no "<|endoftext|>" + context_clean = context.replace("<|endoftext|>", "") + compressed = compressor.compress_prompt( + context_clean, + instruction=instruction, + question=query + "\n" + instruction, # Combine query and instruction for question + target_token=target_token + ) + # Ensure result exists and is string + result = compressed.get('compressed_prompt', '') + return result if isinstance(result, str) else "" + except Exception as e: + logger.error(f"LLMLingua compression failed: {e}") + # Fallback: Truncate based on target tokens (approximate) + tokens = compressor.tokenizer.encode(context_clean) + if len(tokens) > target_token: + return compressor.tokenizer.decode(tokens[:target_token]) + return context_clean + + +# Helper function for LongLLMLingua compression +def compress_longllmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str, chunk_size: int, overlap: int) -> str: + """Compresses context using LongLLMLingua with sliding window chunks.""" + if not context.strip(): + return "" + try: + # Ensure no "<|endoftext|>" + context_clean = context.replace("<|endoftext|>", "") + # Use our sliding window chunker + chunks = chunk_sliding_window(context_clean, chunk_size, overlap) + if not chunks: + return "" # Handle case where context is too short or chunking fails + + compressed = compressor.compress_prompt( + chunks, + instruction=instruction, + question=query + "\n" + instruction, # Combine query and instruction for question + target_token=target_token, + rank_method="longllmlingua" # Use the specified rank method + ) + # Ensure result exists and is string + result = compressed.get('compressed_prompt', '') + return result if isinstance(result, str) else "" + except Exception as e: + logger.error(f"LongLLMLingua compression failed: {e}") + # Fallback: Truncate based on target tokens (approximate) + tokens = compressor.tokenizer.encode(context_clean) + if len(tokens) > target_token: + return compressor.tokenizer.decode(tokens[:target_token]) + return context_clean + + +# Helper function for CodeCompressor +def compress_code_compressor(context: str, query: str, compressor, target_token: int, instruction: str, language: str, rank_only: bool, fine_ratio: float, importance_beta: float) -> str: + """Compresses context using CodeCompressor based on target tokens and rank_only flag.""" + if not context.strip(): + return "" + try: + # Ensure no "<|endoftext|>" + context_clean = context.replace("<|endoftext|>", "") + if not context_clean.strip(): + return "" # Return empty if clean context is empty + + # Tokenize to get original length + # Use the compressor's tokenizer + original_tokens = len(compressor.tokenizer.encode(context_clean)) + 
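The RAG baseline assembled by `chunk_sliding_window`, `compute_embedding`, and `rag_retrieve` above comes down to three steps: slide a fixed-size line window over the background code, mean-pool encoder states for each chunk and for the query, and keep the top-k chunks by cosine similarity. The sketch below illustrates that pipeline under the script's default `microsoft/unixcoder-base` embedder; the helper names (`window_chunks`, `embed`, `retrieve`) are simplified stand-ins, not the module's API.

```python
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/unixcoder-base")
enc = AutoModel.from_pretrained("microsoft/unixcoder-base").eval()

def window_chunks(code: str, window: int = 80, overlap: int = 40) -> list[str]:
    """Overlapping line windows; stride = window - overlap."""
    lines = code.splitlines()
    stride = window - overlap
    return ["\n".join(lines[s:s + window])
            for s in range(0, max(len(lines) - overlap, 1), stride)]

def embed(text: str) -> torch.Tensor:
    """Mean-pooled last hidden state as a single vector."""
    batch = tok(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        return enc(**batch).last_hidden_state.mean(dim=1).squeeze(0)

def retrieve(background: str, query: str, top_k: int = 3) -> str:
    chunks = [c for c in window_chunks(background) if c.strip()]
    if not chunks:
        return ""
    q = embed(query)
    sims = torch.stack([torch.cosine_similarity(q, embed(c), dim=0) for c in chunks])
    keep = torch.topk(sims, k=min(top_k, len(chunks))).indices.tolist()
    # Re-join kept windows in their original order to preserve code structure.
    return "\n\n".join(chunks[i] for i in sorted(keep))

print(retrieve(background="\n".join(f"line {i}" for i in range(300)),
               query="def helper(): ..."))
```

Joining the kept windows in source order mirrors the approximate chronological re-sorting the module performs before concatenating retrieved chunks.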
if original_tokens == 0: + return "" # Avoid division by zero + + # Calculate target ratio + target_ratio = min(1.0, max(0.0, target_token / original_tokens)) + logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}") + + # Pass rank_only and fine_ratio + # Assuming compressor is already initialized with the correct model + compressed_result = compressor.compress_code_file( + code=context_clean, + query=query, # Using current function context as query focus + instruction=instruction, + rate=target_ratio, + language=language, + rank_only=rank_only, # Ensure rank_only mode is set + fine_ratio=fine_ratio if not rank_only else None, # Pass fine_ratio only if not rank_only + importance_beta=importance_beta if not rank_only else None, # Pass importance_beta only if not rank_only + ) + + # Extract compressed content - check both possible keys + compressed_context = compressed_result.get("compressed_code") + + if not isinstance(compressed_context, str): + logger.error(f"CodeCompressor returned non-string: {type(compressed_context)}") + compressed_context = "" # Fallback + + # Log results + compressed_tokens_count = len(compressor.tokenizer.encode(compressed_context)) + final_ratio = (compressed_tokens_count / original_tokens) if original_tokens > 0 else 0 + logger.info(f"CodeCompressor: Compressed tokens={compressed_tokens_count}, Actual ratio={final_ratio:.4f}") + + return compressed_context + + except Exception as e: + logger.error(f"CodeCompressor compression failed: {e}", exc_info=True) + # Fallback: Truncate approximately based on target tokens (less ideal for rank_only) + tokens = compressor.tokenizer.encode(context_clean) + if len(tokens) > target_token: + logger.warning(f"CodeCompressor falling back to simple truncation.") + return compressor.tokenizer.decode(tokens[:target_token]) + return context_clean + + +# Helper function for splitting code by functions (from main_lcc.py) +def split_code_by_functions_standalone(code: str, language: str = "python") -> list[str]: + """ + Split code into chunks based on function and class definitions for various languages. + Standalone version that doesn't require CodeCompressor instance. 
+ + Args: + code: The code to split + language: Programming language of the code (python, cpp, java, typescript, rust, go) + + Returns: + List of code chunks, each containing a function, class, or class method + """ + import re + + # Define regex patterns for different languages + patterns = { + # Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end + "python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*', + # C++: Improved to better handle multi-line declarations + "cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Java: Improved for multi-line method declarations + "java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # TypeScript: Enhanced to handle multi-line methods and arrow functions + "typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?', + # Rust: Improved for multi-line function declarations + "rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Go: Improved for multi-line function declarations + "go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + } + + # Use default Python pattern if language not supported + if language.lower() not in patterns: + language = "python" + + function_pattern = re.compile(patterns[language.lower()], re.MULTILINE) + matches = list(function_pattern.finditer(code)) + + if not matches: + return [code] if code.strip() else [] # No matches, return whole code if not empty + + result_chunks = [] + + # Add code before first match if exists + if matches[0].start() > 0: + pre_code = code[:matches[0].start()].strip() + if pre_code: + result_chunks.append(pre_code) + + # Process each match + for i, match in enumerate(matches): + start = match.start() + + # End is either start of next match or end of code + if i < len(matches) - 1: + end = matches[i + 1].start() + else: + end = len(code) + + chunk = code[start:end].strip() + if chunk: + result_chunks.append(chunk) + + return result_chunks + + +# Helper function for function-level RAG retrieval with budget +def function_rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, language: str, budget: int) -> str: + """Uses function-level chunking and retrieves functions within the specified token budget.""" + if not background_code.strip(): + return "" # Return empty if no background context + + # Split code into function-based chunks + chunks = split_code_by_functions_standalone(background_code, language) + if not chunks: + return "" # Return empty if chunking results in nothing + + query_embedding = compute_embedding(query_code, model, 
tokenizer, device) + + chunk_embeddings = [] + valid_chunks = [] + for chunk in chunks: + if chunk.strip(): + chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device)) + valid_chunks.append(chunk) + + if not valid_chunks: + return "" + + # Stack embeddings for efficient similarity calculation + chunk_embeddings_tensor = torch.stack(chunk_embeddings) + + # Compute cosine similarity + similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1) + + # Sort chunks by similarity score (descending) + sorted_indices = torch.argsort(similarities, descending=True) + + # Select chunks within budget + selected_chunks = [] + current_tokens = 0 + + for idx in sorted_indices: + chunk = valid_chunks[idx.item()] + + # Calculate tokens for this chunk + chunk_tokens = len(tokenizer.encode(chunk, add_special_tokens=False)) + + # Check if adding this chunk would exceed budget + if current_tokens + chunk_tokens <= budget: + selected_chunks.append((chunk, similarities[idx].item())) + current_tokens += chunk_tokens + else: + # Try to partially include the chunk if there's remaining budget + remaining_budget = budget - current_tokens + if remaining_budget > 50: # Only include if we have at least 50 tokens left + chunk_tokens_list = tokenizer.encode(chunk, add_special_tokens=False) + if len(chunk_tokens_list) > remaining_budget: + # Truncate the chunk to fit the remaining budget + truncated_tokens = chunk_tokens_list[:remaining_budget] + truncated_chunk = tokenizer.decode(truncated_tokens, skip_special_tokens=True) + selected_chunks.append((truncated_chunk, similarities[idx].item())) + current_tokens = budget + break + + # Stop if we've reached the budget + if current_tokens >= budget: + break + + if not selected_chunks: + return "" + + # Sort selected chunks by their original position in the code to maintain structure + # We'll use the similarity score for this approximation since we don't have direct position info + selected_chunks.sort(key=lambda x: x[1], reverse=True) # Keep similarity order for now + + # Combine selected chunks + combined_code = "\n\n".join([chunk for chunk, _ in selected_chunks]) + + logger.info(f"Function RAG: Selected {len(selected_chunks)} functions using {current_tokens}/{budget} tokens") + + return combined_code + + +async def async_get_metric(scorer, intent, code_context, gold_doc, pred_doc): + logger.debug(f"Evaluating intent: {intent}") + logger.debug(f"Gold doc: {truncate_text(gold_doc)}") + logger.debug(f"Pred doc: {truncate_text(pred_doc)}") + logger.debug(f"Gold doc length: {len(gold_doc)}, Pred doc length: {len(pred_doc)}") + + prompt = f'I have 2 different documentations about {intent}. Decide which documentation is better: documentation A or documentation B.\n\n' + prompt += f'My code:\n\n{code_context}\n\n\n\n' + prompt += f'Documentation A:\n\n{gold_doc}\n\n\n\n' + prompt += f'Documentation B:\n\n{pred_doc}\n\n\n\n' + prompt += 'Please directly return the option that is better (A or B) without any other text.' + + options = ["A", "B"] + unnorm_logprobs = await scorer.async_score_options(prompt, options) + norm_probs1 = torch.exp(torch.log_softmax(unnorm_logprobs, dim=0)) + + prompt = f'I have 2 different documentations about {intent}. 
Decide which documentation is better: documentation A or documentation B.\n\n' + prompt += f'My code:\n\n{code_context}\n\n\n\n' + prompt += f'Documentation A:\n\n{pred_doc}\n\n\n\n' + prompt += f'Documentation B:\n\n{gold_doc}\n\n\n\n' + prompt += 'Please directly return the option that is better (A or B) without any other text.' + unnorm_logprobs = await scorer.async_score_options(prompt, options) + norm_probs2 = torch.exp(torch.log_softmax(unnorm_logprobs, dim=0)) + + p_better1 = (norm_probs1[1] + norm_probs2[0]) / 2 + logger.debug(f"First evaluation: {norm_probs1}, Second evaluation: {norm_probs2}, Final score: {p_better1}") + + return float(p_better1) + + +def get_metric(scorer, intent, code_context, gold_doc, pred_doc): + logger.debug(f"Evaluating intent: {intent}") + logger.debug(f"Gold doc: {truncate_text(gold_doc)}") + logger.debug(f"Pred doc: {truncate_text(pred_doc)}") + logger.debug(f"Gold doc length: {len(gold_doc)}, Pred doc length: {len(pred_doc)}") + + prompt = f'I have 2 different documentations about {intent}. Decide which documentation is better: documentation A or documentation B.\n\n' + prompt += f'My code:\n\n{code_context}\n\n\n\n' + prompt += f'Documentation A:\n\n{gold_doc}\n\n\n\n' + prompt += f'Documentation B:\n\n{pred_doc}\n\n\n\n' + prompt += 'Please directly return the option that is better (A or B) without any other text.' + + options = ["A", "B"] + unnorm_logprobs = scorer.score_options(prompt, options) + norm_probs1 = torch.exp(torch.log_softmax(unnorm_logprobs, dim=0)) + + prompt = f'I have 2 different documentations about {intent}. Decide which documentation is better: documentation A or documentation B.\n\n' + prompt += f'My code:\n\n{code_context}\n\n\n\n' + prompt += f'Documentation A:\n\n{pred_doc}\n\n\n\n' + prompt += f'Documentation B:\n\n{gold_doc}\n\n\n\n' + prompt += 'Please directly return the option that is better (A or B) without any other text.' + unnorm_logprobs = scorer.score_options(prompt, options) + norm_probs2 = torch.exp(torch.log_softmax(unnorm_logprobs, dim=0)) + + p_better1 = (norm_probs1[1] + norm_probs2[0]) / 2 + logger.debug(f"First evaluation: {norm_probs1}, Second evaluation: {norm_probs2}, Final score: {p_better1}") + + return float(p_better1) + + +async def evaluate_batch(batch_data, scorer, samples_data, method, is_async=True): + """Evaluate a batch of samples""" + results = [] + + for item in batch_data: + idx, row = item + gold_doc = row['target_text'] + + # Skip if sample data doesn't exist + if idx >= len(samples_data) or not samples_data[idx]: + logger.warning(f"Sample data not found for sample {idx}. 
Skipping evaluation.") + continue + + # Get sample data + sample_data = samples_data[idx] + pred_doc = sample_data.get('generated_text', '') + + code_context = row['relevant_code_context'] + + # Use the appropriate metric function based on whether the scorer is async + if is_async: + metric = await async_get_metric(scorer, row['intent'], code_context, gold_doc, pred_doc) + else: + # For synchronous scorers, run in an executor to not block the event loop + loop = asyncio.get_running_loop() + metric = await loop.run_in_executor( + None, get_metric, scorer, row['intent'], code_context, gold_doc, pred_doc + ) + + # Update sample data with evaluation score + sample_data['generation_score'] = float(metric) + + results.append((idx, metric, sample_data)) + + return results + + +async def run_parallel_evaluation(dataset, scorer, samples_data, method, num_processes=4, is_async=True): + """Run evaluation in parallel using specified number of processes""" + # Prepare data with indices + indexed_data = list(enumerate(dataset)) + + # Split data into chunks for each process + chunk_size = len(indexed_data) // num_processes + if chunk_size == 0: + chunk_size = 1 + + batches = [indexed_data[i:i+chunk_size] for i in range(0, len(indexed_data), chunk_size)] + + # Ensure we don't create more batches than needed + batches = batches[:num_processes] + + # Create tasks for each batch + tasks = [evaluate_batch(batch, scorer, samples_data, method, is_async) for batch in batches] + + # Run all batches concurrently and collect results + batch_results = await asyncio.gather(*tasks) + + # Flatten results + all_results = [] + for batch in batch_results: + all_results.extend(batch) + + # Sort by sample index + all_results.sort(key=lambda x: x[0]) + + # Extract metrics and metadata + metrics = [r[1] for r in all_results] + detailed_results = [r[2] for r in all_results] + + return metrics, detailed_results + + +def run_documentation_task( + # Generation model parameters + gen_model: str = "Qwen/Qwen2.5-Coder-7B-Instruct", + compress_model: str = None, + model_name: str = None, + # Evaluation model parameters + eval_model: str = "gpt-4o-mini", + # Common model parameters + device: str = "cuda", + tensor_parallel_size: int = 1, + # Generation parameters + max_tokens: int = 2048, + temperature: float = 0.0, + # Context method parameters + method: str = "full", + # RAG parameters + rag_window_size: int = 80, + rag_overlap: int = 40, + rag_top_k: int = 3, + embed_model_name: str = "microsoft/unixcoder-base", + # Function RAG parameters + function_rag_language: str = "python", + function_rag_budget: int = 1024, + # LLMLingua parameters + lingua_target_token: int = 500, + lingua_instruction: str = "Generate documentation based on this code.", + # LongLLMLingua parameters + longlingua_chunk_size: int = 80, + longlingua_overlap: int = 40, + # CodeCompressor parameters + code_compressor_target_token: int = 500, + code_compressor_fine_ratio: float = 1.0, + importance_beta: float = 0.0, + # Task parameters + mode: str = "both", + save_dir: str = "./predictions", + hf_api_key: str = None, + max_examples: int = None, + use_llm_scorer: bool = False, + # Parallel evaluation parameters + num_eval_processes: int = 4 +): + """Run documentation generation and evaluation with the specified parameters.""" + + # Get model short name from argument or extract from model path + model_short_name = model_name + if model_short_name is None: + # Extract model name from path - use last component after / or use the whole string + model_short_name = 
gen_model.split('/')[-1] if '/' in gen_model else gen_model + + if compress_model is None: + compress_model = gen_model + logger.info(f"Using generation model for compression: {compress_model}") + + # Create method-specific suffix for results directory + method_suffix = f"method_{method}" + if method == "rag": + method_suffix += f"_w{rag_window_size}_o{rag_overlap}_k{rag_top_k}" + elif method == "function_rag": + method_suffix += f"_lang{function_rag_language}_b{function_rag_budget}" + elif method == "llmlingua": + method_suffix += f"_t{lingua_target_token}" + elif method == "longllmlingua": + method_suffix += f"_t{lingua_target_token}_cs{longlingua_chunk_size}_o{longlingua_overlap}" + elif method == "code_compressor": + # Determine if rank_only based on fine_ratio + rank_only_for_suffix = (code_compressor_fine_ratio == 1.0) + suffix_detail = "_rankonly" if rank_only_for_suffix else f"_fr{code_compressor_fine_ratio}" + # Add importance_beta to suffix + if importance_beta > 0: + suffix_detail += f"_b{importance_beta}" + method_suffix += f"_t{code_compressor_target_token}{suffix_detail}" + + # Create method-specific directory + model_save_dir = os.path.join(save_dir, method_suffix, model_short_name) + if not os.path.exists(model_save_dir): + os.makedirs(model_save_dir) + + # Path to our single results JSON file + results_json_path = os.path.join(model_save_dir, "detailed_results.json") + + # Load dataset + print("Loading dataset") + dataset = load_dataset_samples( + max_examples=max_examples, + hf_api_key=hf_api_key + ) + + # Common model args for both generator and scorer + model_args = { + "tensor_parallel_size": tensor_parallel_size, + } + + # Initialize or load samples data + samples_data = [] + + # Check if results file exists (for continuing an interrupted run) + if os.path.exists(results_json_path): + try: + with open(results_json_path, 'r') as f: + existing_results = json.load(f) + # Extract existing samples data + samples_data = existing_results.get('samples', []) + # Ensure we have enough entries for all samples + if len(samples_data) < len(dataset): + samples_data.extend([None] * (len(dataset) - len(samples_data))) + print(f"Loaded existing results with {len(samples_data)} samples.") + except Exception as e: + print(f"Error loading existing results: {e}. 
Starting fresh.") + samples_data = [None] * len(dataset) + else: + # Initialize with empty slots for each sample + samples_data = [None] * len(dataset) + + # Generation phase - split into compression and generation steps + if mode in ["generate", "both"]: + # Step 1: Compress contexts first if needed + if method not in ["full", "no_context"]: + print(f"Step 1: Preparing contexts using {method} method...") + + # Initialize context preparation models based on method + embed_model = None + embed_tokenizer = None + lingua_compressor = None + code_compressor_instance = None + + if method == "rag": + print(f"Initializing embedding model: {embed_model_name}") + embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name) + embed_model = AutoModel.from_pretrained(embed_model_name).to(device) + embed_model.eval() # Set to evaluation mode + + if method == "function_rag": + print(f"Initializing embedding model for function RAG: {embed_model_name}") + embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name) + embed_model = AutoModel.from_pretrained(embed_model_name).to(device) + embed_model.eval() # Set to evaluation mode + + if method == "llmlingua" or method == "longllmlingua": + print(f"Initializing LLMLingua compressor") + lingua_compressor = PromptCompressor(model_name=gen_model, device_map="auto") + + if method == "code_compressor": + try: + print(f"Initializing CodeCompressor") + code_compressor_instance = CodeCompressor(gen_model) + except Exception as e: + print(f"Failed to initialize CodeCompressor: {e}. Falling back to full context.") + method = "full" + + # Process and compress all contexts + for idx, row in tqdm(enumerate(dataset), total=len(dataset), desc="Compressing contexts"): + # If sample already has context processing, skip + if samples_data[idx] and 'processed_context' in samples_data[idx]: + continue + + # Get the context + code_context = row['relevant_code_context'] + + # Process context based on the selected method + processed_context = code_context + language = "python" # Default language, could be determined dynamically + + try: + if method == "rag": + # Split the context and retrieve relevant parts + background_ctx = code_context + query_ctx = f"Generate documentation for {row['docfile_name']} about {row['intent']}" + processed_context = rag_retrieve( + background_ctx, query_ctx, + embed_model, embed_tokenizer, device, + rag_window_size, rag_overlap, rag_top_k + ) + elif method == "function_rag": + # Split the context and retrieve relevant functions within budget + background_ctx = code_context + query_ctx = f"Generate documentation for {row['docfile_name']} about {row['intent']}" + processed_context = function_rag_retrieve( + background_ctx, query_ctx, + embed_model, embed_tokenizer, device, + function_rag_language, function_rag_budget + ) + elif method == "llmlingua": + background_ctx = code_context + query_ctx = f"Generate documentation for {row['docfile_name']} about {row['intent']}" + processed_context = compress_llmlingua( + background_ctx, query_ctx, + lingua_compressor, lingua_target_token, lingua_instruction + ) + elif method == "longllmlingua": + background_ctx = code_context + query_ctx = f"Generate documentation for {row['docfile_name']} about {row['intent']}" + processed_context = compress_longllmlingua( + background_ctx, query_ctx, + lingua_compressor, lingua_target_token, lingua_instruction, + longlingua_chunk_size, longlingua_overlap + ) + elif method == "code_compressor": + # Determine rank_only based on fine_ratio + rank_only = 
(code_compressor_fine_ratio == 1.0) + logger.info(f"CodeCompressor mode: {'Rank Only' if rank_only else f'Fine-grained (ratio={code_compressor_fine_ratio})'}") + background_ctx = code_context + query_ctx = f"Generate documentation for {row['docfile_name']} about {row['intent']}" + processed_context = compress_code_compressor( + context=background_ctx, + query=query_ctx, + compressor=code_compressor_instance, + target_token=code_compressor_target_token, + instruction=lingua_instruction, + language=language, + rank_only=rank_only, + fine_ratio=code_compressor_fine_ratio, + importance_beta=importance_beta + ) + except Exception as e: + logger.error(f"Error during context preparation with method {method}: {e}", exc_info=True) + # Fallback to full context in case of error + processed_context = code_context + + # Create or update sample data + sample_data = samples_data[idx] or {} + sample_data.update({ + 'sample_id': idx, + 'intent': row['intent'], + 'docfile_name': row['docfile_name'], + 'target_text': row['target_text'], + 'original_context': code_context, + 'processed_context': processed_context, + 'context_compression': { + 'method': method, + 'original_length': len(code_context), + 'processed_length': len(processed_context), + 'compression_ratio': len(processed_context) / len(code_context) if len(code_context) > 0 else 1.0, + 'method_params': { + 'type': method, + 'rag_params': { + 'window_size': rag_window_size, + 'overlap': rag_overlap, + 'top_k': rag_top_k + } if method == "rag" else None, + 'function_rag_params': { + 'language': function_rag_language, + 'budget': function_rag_budget + } if method == "function_rag" else None, + 'llmlingua_params': { + 'target_token': lingua_target_token + } if method == "llmlingua" else None, + 'longllmlingua_params': { + 'target_token': lingua_target_token, + 'chunk_size': longlingua_chunk_size, + 'overlap': longlingua_overlap + } if method == "longllmlingua" else None, + 'code_compressor_params': { + 'target_token': code_compressor_target_token, + 'rank_only': (code_compressor_fine_ratio == 1.0), + 'fine_ratio': code_compressor_fine_ratio, + 'importance_beta': importance_beta + } if method == "code_compressor" else None + } + } + }) + + # Update samples data + samples_data[idx] = sample_data + + # Save the updated results file periodically + if idx % 10 == 0 or idx == len(dataset) - 1: + results_data = { + 'model': model_short_name, + 'method': method, + 'method_params': { + 'type': method, + 'rag_params': { + 'window_size': rag_window_size, + 'overlap': rag_overlap, + 'top_k': rag_top_k + } if method == "rag" else None, + 'function_rag_params': { + 'language': function_rag_language, + 'budget': function_rag_budget + } if method == "function_rag" else None, + 'llmlingua_params': { + 'target_token': lingua_target_token + } if method == "llmlingua" else None, + 'longllmlingua_params': { + 'target_token': lingua_target_token, + 'chunk_size': longlingua_chunk_size, + 'overlap': longlingua_overlap + } if method == "longllmlingua" else None, + 'code_compressor_params': { + 'target_token': code_compressor_target_token, + 'rank_only': (code_compressor_fine_ratio == 1.0), + 'fine_ratio': code_compressor_fine_ratio, + 'importance_beta': importance_beta + } if method == "code_compressor" else None + }, + 'average_score': None, # Will be filled during evaluation + 'samples': samples_data + } + # make sure the results_json_path is a valid path + if not os.path.exists(os.path.dirname(results_json_path)): + os.makedirs(os.path.dirname(results_json_path)) + with 
open(results_json_path, 'w') as f: + json.dump(results_data, f, indent=2) + + # Free up context preparation resources + print("Cleaning up context preparation resources...") + if embed_model: + del embed_model + if embed_tokenizer: + del embed_tokenizer + if lingua_compressor: + del lingua_compressor + if code_compressor_instance: + del code_compressor_instance + torch.cuda.empty_cache() + gc.collect() + + # Step 2: Generate documentation using the (potentially compressed) contexts + print(f"Step 2: Initializing generation model: {gen_model}") + generator = LLMGenerator(gen_model, device, **model_args) + + # Define a token limit for context when method is "full" + MAX_CONTEXT_TOKENS_FOR_FULL_METHOD = 30000 + + print(f"Generating documentation...") + for idx, row in tqdm(enumerate(dataset), total=len(dataset), desc="Generating documentation"): + # Skip if this sample already has generated text + if samples_data[idx] and 'generated_text' in samples_data[idx]: + continue + + # Create or load sample data + sample_data = samples_data[idx] or {} + + # Determine context to use + if method not in ["full", "no_context"] and sample_data.get('processed_context'): + context = sample_data.get('processed_context') + else: + # For full or no_context methods + context = row['relevant_code_context'] + + # For no_context, use minimal information + if method == "no_context": + context = f"Generate documentation for {row['docfile_name']} about {row['intent']}" + + # Update sample data with context info if not already there + if 'original_context' not in sample_data: + sample_data.update({ + 'sample_id': idx, + 'intent': row['intent'], + 'docfile_name': row['docfile_name'], + 'target_text': row['target_text'], + 'original_context': row['relevant_code_context'], + 'processed_context': None if method == "full" else context + }) + + # Truncate context if method is "full" and context is too long + if method == "full": + context_tokens = generator.tokenizer.encode(context) + if len(context_tokens) > MAX_CONTEXT_TOKENS_FOR_FULL_METHOD: + logger.warning(f"Sample {idx}: Context for 'full' method was too long ({len(context_tokens)} tokens). Truncating to {MAX_CONTEXT_TOKENS_FOR_FULL_METHOD} tokens.") + truncated_tokens = context_tokens[:MAX_CONTEXT_TOKENS_FOR_FULL_METHOD] + context = generator.tokenizer.decode(truncated_tokens, skip_special_tokens=True) + # Update context length in sample_data if it was already populated + if 'context_length' in sample_data: + sample_data['context_length'] = len(context) + + # Generate documentation + prompt = f'Using the code provided, generate documentation for {row["docfile_name"]} about {row["intent"]}.\n\n' + prompt += f'Code:\n\n{context}' + prompt += f'\n\n\nReturn only the documentation text for {row["docfile_name"]} about {row["intent"]}. Do not include instructions or explanations.' 
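+            # Note: "context" here is either the method-compressed context prepared in Step 1 or
+            # the full/no-context fallback (truncated to the token cap for the "full" method).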
+ + generated_doc = generator.generate(prompt, max_tokens, temperature) + + # Update sample data with generated text + sample_data.update({ + 'generated_text': generated_doc, + 'target_text_length': len(row['target_text']), + 'generated_text_length': len(generated_doc), + 'context_length': len(context), + 'method': method, + 'method_params': { + 'type': method, + 'rag_params': { + 'window_size': rag_window_size, + 'overlap': rag_overlap, + 'top_k': rag_top_k + } if method == "rag" else None, + 'function_rag_params': { + 'language': function_rag_language, + 'budget': function_rag_budget + } if method == "function_rag" else None, + 'llmlingua_params': { + 'target_token': lingua_target_token + } if method == "llmlingua" else None, + 'longllmlingua_params': { + 'target_token': lingua_target_token, + 'chunk_size': longlingua_chunk_size, + 'overlap': longlingua_overlap + } if method == "longllmlingua" else None, + 'code_compressor_params': { + 'target_token': code_compressor_target_token, + 'rank_only': (code_compressor_fine_ratio == 1.0), + 'fine_ratio': code_compressor_fine_ratio, + 'importance_beta': importance_beta + } if method == "code_compressor" else None + } + }) + + # Update samples data + samples_data[idx] = sample_data + + # Save the updated results file periodically + if idx % 10 == 0 or idx == len(dataset) - 1: + results_data = { + 'model': model_short_name, + 'method': method, + 'method_params': { + 'type': method, + 'rag_params': { + 'window_size': rag_window_size, + 'overlap': rag_overlap, + 'top_k': rag_top_k + } if method == "rag" else None, + 'function_rag_params': { + 'language': function_rag_language, + 'budget': function_rag_budget + } if method == "function_rag" else None, + 'llmlingua_params': { + 'target_token': lingua_target_token + } if method == "llmlingua" else None, + 'longllmlingua_params': { + 'target_token': lingua_target_token, + 'chunk_size': longlingua_chunk_size, + 'overlap': longlingua_overlap + } if method == "longllmlingua" else None, + 'code_compressor_params': { + 'target_token': code_compressor_target_token, + 'rank_only': (code_compressor_fine_ratio == 1.0), + 'fine_ratio': code_compressor_fine_ratio, + 'importance_beta': importance_beta + } if method == "code_compressor" else None + }, + 'average_score': None, # Will be filled during evaluation + 'samples': samples_data + } + # make sure the results_json_path is a valid path + if not os.path.exists(os.path.dirname(results_json_path)): + os.makedirs(os.path.dirname(results_json_path)) + with open(results_json_path, 'w') as f: + json.dump(results_data, f, indent=2) + + # Free up memory after generation + print("Freeing generator memory...") + generator.free_memory() + del generator + torch.cuda.empty_cache() + gc.collect() + + # Evaluation phase + if mode in ["evaluate", "both"]: + # Initialize the scorer based on the model type + if use_llm_scorer: + print(f"Initializing LLM evaluation model: {eval_model}") + scorer = LLMScorer(eval_model, device, **model_args) + is_async = False + else: + print(f"Initializing GPT evaluation model: {eval_model}") + scorer = GPTScorer(eval_model, **model_args) + is_async = True + + print(f"Evaluating documentation with {num_eval_processes} parallel processes...") + + # Use asyncio to run evaluation in parallel + metrics, detailed_results = asyncio.run( + run_parallel_evaluation(dataset, scorer, samples_data, method, num_eval_processes, is_async) + ) + + # Update samples data with evaluation scores + for idx, result in enumerate(detailed_results): + if idx < 
len(samples_data) and samples_data[idx]: + # Update with evaluation score + samples_data[idx]['generation_score'] = result.get('generation_score') + + average_metric = np.mean([s.get('generation_score', 0) for s in samples_data if s and 'generation_score' in s]) + print(f"Average evaluation metric: {average_metric:.4f}") + + # Save evaluation results + if not os.path.exists(os.path.dirname(results_json_path)): + os.makedirs(os.path.dirname(results_json_path)) + with open(os.path.join(model_save_dir, "metrics.txt"), 'w') as f: + f.write(f"Average metric: {average_metric:.4f}\n") + f.write("Individual metrics:\n") + for idx, sample in enumerate(samples_data): + if sample and 'generation_score' in sample: + f.write(f"Sample {idx}: {sample['generation_score']:.4f}\n") + + # Save final detailed results + if not os.path.exists(os.path.dirname(results_json_path)): + os.makedirs(os.path.dirname(results_json_path)) + with open(results_json_path, 'w') as f: + results_data = { + 'model': model_short_name, + 'method': method, + 'method_params': { + 'type': method, + 'rag_params': { + 'window_size': rag_window_size, + 'overlap': rag_overlap, + 'top_k': rag_top_k + } if method == "rag" else None, + 'function_rag_params': { + 'language': function_rag_language, + 'budget': function_rag_budget + } if method == "function_rag" else None, + 'llmlingua_params': { + 'target_token': lingua_target_token + } if method == "llmlingua" else None, + 'longllmlingua_params': { + 'target_token': lingua_target_token, + 'chunk_size': longlingua_chunk_size, + 'overlap': longlingua_overlap + } if method == "longllmlingua" else None, + 'code_compressor_params': { + 'target_token': code_compressor_target_token, + 'rank_only': (code_compressor_fine_ratio == 1.0), + 'fine_ratio': code_compressor_fine_ratio, + 'importance_beta': importance_beta + } if method == "code_compressor" else None + }, + 'average_score': float(average_metric), + 'samples': samples_data + } + json.dump(results_data, f, indent=2) + + # Free up scorer memory + scorer.free_memory() + + +if __name__ == "__main__": + + fire.Fire(run_documentation_task) \ No newline at end of file diff --git a/module_summarization/run.sh b/module_summarization/run.sh new file mode 100644 index 0000000..460e30c --- /dev/null +++ b/module_summarization/run.sh @@ -0,0 +1,51 @@ +export CUDA_VISIBLE_DEVICES=0 + +MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct" +MODEL_PATH_NAME="qwencoder-7b-instruct" +BASE_RESULT_DIR="results/${MODEL_PATH_NAME}" +BASE_LOG_DIR="logs/${MODEL_PATH_NAME}" + +mkdir -p ${BASE_LOG_DIR} +mkdir -p ${BASE_RESULT_DIR} + +echo "Starting experiments for ${MODEL_NAME} on GPU ${CUDA_VISIBLE_DEVICES}" + +# --- CodeCompressor Method (Fine-grained with Beta) --- +TARGET_TOKENS=(4096) +FINE_RATIOS=(0.5) +BETAS=(0.5) + +echo "--- Running CodeCompressor (Fine-grained with various Beta values) ---" +for ratio in "${FINE_RATIOS[@]}"; do + for tokens in "${TARGET_TOKENS[@]}"; do + if [[ "${ratio}" == "1.0" ]]; then + # If fine_ratio is 1.0, only use default beta 0.0 + beta=0.0 + echo "Running CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}" + python main.py \ + --gen_model ${MODEL_NAME} \ + --model_name ${MODEL_PATH_NAME} \ + --method code_compressor \ + --save_dir "${BASE_RESULT_DIR}" \ + --code_compressor_target_token ${tokens} \ + --code_compressor_fine_ratio ${ratio} \ + --importance_beta ${beta} > "${BASE_LOG_DIR}/code_compressor_t${tokens}_fr${ratio}_b${beta}.log" 2>&1 + echo "Finished CodeCompressor (Fine-grained): 
target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}" + else + # For other fine_ratios, test different beta values + for beta in "${BETAS[@]}"; do + echo "Running CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}" + python main.py \ + --gen_model ${MODEL_NAME} \ + --model_name ${MODEL_PATH_NAME} \ + --method code_compressor \ + --save_dir "${BASE_RESULT_DIR}" \ + --code_compressor_target_token ${tokens} \ + --code_compressor_fine_ratio ${ratio} \ + --importance_beta ${beta} > "${BASE_LOG_DIR}/code_compressor_t${tokens}_fr${ratio}_b${beta}.log" 2>&1 + echo "Finished CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}" + done + fi + done +done +echo "--- Finished CodeCompressor ---" \ No newline at end of file diff --git a/module_summarization/utils.py b/module_summarization/utils.py new file mode 100644 index 0000000..73ba2a1 --- /dev/null +++ b/module_summarization/utils.py @@ -0,0 +1,142 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +from datasets import load_dataset +from loguru import logger +from typing import List, Dict, Optional, Tuple, Any +from transformers import AutoTokenizer +from tqdm import tqdm + + +def truncate_text(text, max_len=512): + """Helper function to truncate long text for logging.""" + if len(text) <= max_len: + return text + return text[:max_len//2] + "\n...\n" + text[-max_len//2:] + + +def load_dataset_samples(dataset_name="JetBrains-Research/lca-module-summarization", split="test", + max_examples=None, hf_api_key=None, max_tokens=32768, min_tokens=1024): + """Load dataset samples with optional limiting and filtering of long examples.""" + dataset = load_dataset(dataset_name, token=hf_api_key)[split] + if max_examples is not None: + dataset = dataset.select(range(min(max_examples, len(dataset)))) + + # Filter out examples with extremely long code + if max_tokens > 0: + filtered_indices = [] + skipped_count_long = 0 + skipped_count_short = 0 + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") + + for i, row in enumerate(tqdm(dataset, desc="Filtering long examples")): + code = row['relevant_code_context'] + if len(code) > max_tokens*10: + logger.warning(f"Skipping example {i} because it exceeds {max_tokens*10} characters ({len(code)}/{max_tokens*10})") + skipped_count_long += 1 + continue + tokens = tokenizer.encode(code, truncation=False) + if len(tokens) > max_tokens: + logger.warning(f"Skipping example {i} because it exceeds {max_tokens} tokens ({len(tokens)}/{max_tokens})") + skipped_count_long += 1 + continue + if len(tokens) < min_tokens: + logger.warning(f"Skipping example {i} because it is too short ({len(tokens)}/{min_tokens})") + skipped_count_short += 1 + continue + filtered_indices.append(i) + + if skipped_count_long > 0: + logger.info(f"Skipped {skipped_count_long} examples that exceeded token limit of {max_tokens}") + if skipped_count_short > 0: + logger.info(f"Skipped {skipped_count_short} examples that are too short ({min_tokens} tokens)") + + dataset = dataset.select(filtered_indices) + + return dataset + + +def get_actual_token_lengths(dataset, tokenizer, output_path="./analysis"): + """ + Calculate actual token lengths using the specified tokenizer. 
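+    Besides returning the statistics, this saves a histogram plot (code_length_actual_tokens.png)
+    and a token_length_stats.txt summary under output_path.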
+ + Args: + dataset: The dataset containing code samples + tokenizer: The tokenizer to use for counting tokens + output_path: Path to save analysis results and plots + + Returns: + Dict with statistics about token lengths + """ + # Create output directory if it doesn't exist + os.makedirs(output_path, exist_ok=True) + + # Extract actual token counts + token_lengths = [] + + # Print intent for each file + logger.info("\nIntent for each example:") + logger.info("======================") + + for i, row in enumerate(tqdm(dataset, desc="Calculating token lengths")): + code = row['relevant_code_context'] + tokens = tokenizer.encode(code, truncation=False) if hasattr(tokenizer, 'encode') else [] + token_len = len(tokens) + token_lengths.append(token_len) + + # Print the intent for each file + docfile_name = row.get('docfile_name', f'file_{i}') + intent = row.get('intent', 'unknown') + logger.info(f" Example {i}: {docfile_name} - Intent: {intent} - Token Length: {token_len}") + + # Calculate statistics + stats = { + 'min': min(token_lengths), + 'max': max(token_lengths), + 'mean': np.mean(token_lengths), + 'median': np.median(token_lengths), + 'p90': np.percentile(token_lengths, 90), + 'p95': np.percentile(token_lengths, 95), + 'p99': np.percentile(token_lengths, 99), + } + + # Plot token length histogram + plt.figure(figsize=(10, 6)) + plt.hist(token_lengths, bins=50, alpha=0.7) + plt.axvline(stats['mean'], color='red', linestyle='dashed', linewidth=1, label=f"Mean: {stats['mean']:.0f}") + plt.axvline(stats['median'], color='green', linestyle='dashed', linewidth=1, label=f"Median: {stats['median']:.0f}") + plt.axvline(stats['p90'], color='orange', linestyle='dashed', linewidth=1, label=f"90th %: {stats['p90']:.0f}") + plt.axvline(stats['p95'], color='purple', linestyle='dashed', linewidth=1, label=f"95th %: {stats['p95']:.0f}") + plt.title('Actual Code Length Distribution (Tokens)') + plt.xlabel('Tokens') + plt.ylabel('Count') + plt.legend() + plt.savefig(os.path.join(output_path, 'code_length_actual_tokens.png')) + + # Save statistics to a text file + with open(os.path.join(output_path, 'token_length_stats.txt'), 'w') as f: + f.write("Token Length Statistics\n") + f.write("=====================\n\n") + + for key, value in stats.items(): + f.write(f" {key}: {value:.2f}\n") + + # Return the statistics for further use + return stats + + +if __name__ == "__main__": + + dataset = load_dataset_samples(dataset_name="JetBrains-Research/lca-module-summarization", split="test", max_examples=1000) + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") + token_stats = get_actual_token_lengths(dataset, tokenizer, "./analysis") + + # Print summary of findings using logger + logger.info("\nSummary of Code Length Analysis:") + logger.info("================================") + logger.info(f"Number of examples analyzed: {len(dataset)}") + + logger.info("\nActual token-based statistics:") + logger.info(f" Mean length: {token_stats['mean']:.0f} tokens") + logger.info(f" Median length: {token_stats['median']:.0f} tokens") + logger.info(f" 95th percentile: {token_stats['p95']:.0f} tokens") \ No newline at end of file diff --git a/repoqa/__init__.py b/repoqa/__init__.py new file mode 100644 index 0000000..9312c01 --- /dev/null +++ b/repoqa/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +try: + from repoqa._version import __version__, __version_tuple__ +except ImportError: + __version__ = "local-dev" diff --git 
a/repoqa/code_compressor.py b/repoqa/code_compressor.py new file mode 100644 index 0000000..4a23d69 --- /dev/null +++ b/repoqa/code_compressor.py @@ -0,0 +1,1544 @@ +import torch +import numpy as np +from typing import List, Union, Tuple, Dict, Optional +import re +import math +import zlib +from transformers import AutoModelForCausalLM, AutoTokenizer +import time +from tqdm import tqdm +import logging +import copy +import bisect +import json +from llmlingua import PromptCompressor + +# set up the logger +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger("CodeCompressor") + +class CodeCompressor: + def __init__( + self, + model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4", + device_map: str = "cuda", + model_config: dict = {}, + ): + """ + Initialize the CodeCompressor with a language model for compression. + + Args: + model_name: The name of the model to load from HuggingFace + device_map: Device to load the model on + model_config: Additional configuration for the model + """ + self.model_name = model_name + self.device = device_map + self.model_config = model_config + self.load_model(model_name, device_map, model_config) + + # Add caching system for model outputs and token information + self.cache = { + "token_length": {}, # Cache for token length by text + "encodings": {}, # Cache for tokenizer encodings + "perplexity": {}, # Cache for perplexity calculations + "conditional_ppl": {}, # Cache for conditional perplexity + "context_rankings": {}, # Cache for context rankings + } + self.max_cache_size = 1000 # Limit cache size to prevent memory issues + + # set up the max position embeddings and cache bos num + self.max_position_embeddings = getattr(self.model.config, "max_position_embeddings", 4096) + self.cache_bos_num = 10 + self.prefix_bos_num = 100 + self.context_idxs = [] + + def load_model( + self, model_name: str, device_map: str = "cuda", model_config: dict = {} + ): + """ + Load the language model and tokenizer. + + Args: + model_name: The name of the model to load + device_map: Device to load the model on + model_config: Additional configuration for the model + """ + logger.debug(f"Loading model {model_name} on {device_map}") + torch_dtype = torch.float16 if "torch_dtype" not in model_config else model_config["torch_dtype"] + model_kwargs = {"device_map": device_map, "torch_dtype": torch_dtype} + + for k, v in model_config.items(): + if k != "torch_dtype": + model_kwargs[k] = v + + self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" + + self.tokenizer_is_gpt = "gpt" in model_name.lower() + logger.debug("Model and tokenizer loaded successfully") + + def _manage_cache_size(self, cache_type): + """ + Manage cache size by removing oldest entries when cache exceeds max size. + + Args: + cache_type: The type of cache to manage + """ + if len(self.cache[cache_type]) > self.max_cache_size: + # Remove 20% of the oldest entries + remove_count = int(self.max_cache_size * 0.2) + keys_to_remove = list(self.cache[cache_type].keys())[:remove_count] + for key in keys_to_remove: + del self.cache[cache_type][key] + + def get_token_length( + self, + text: str, + add_special_tokens: bool = True, + ): + """ + Get the number of tokens in the given text. 
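+        Results are cached per (text, add_special_tokens) pair, so repeated inputs are not re-tokenized.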
+ + Args: + text: The text to tokenize + add_special_tokens: Whether to count special tokens + + Returns: + The number of tokens + """ + # Create a cache key based on text and parameters + cache_key = f"{text}_{add_special_tokens}" + + # Check if result is in cache + if cache_key in self.cache["token_length"]: + return self.cache["token_length"][cache_key] + + # Calculate token length if not in cache + token_length = len(self.tokenizer.encode(text, add_special_tokens=add_special_tokens)) + + # Store in cache + self.cache["token_length"][cache_key] = token_length + self._manage_cache_size("token_length") + + return token_length + + def get_ppl( + self, + text: str, + granularity: str = "line", + input_ids=None, + attention_mask=None, + past_key_values=None, + return_kv=False, + end=None, + condition_mode: str = "none", + condition_pos_id: int = 0, + ): + """ + Calculate perplexity for the given text at line level. + + Args: + text: The text to calculate perplexity for + granularity: The granularity of perplexity calculation (line, token, chunk) + input_ids, attention_mask, past_key_values: Optional pre-processed inputs + return_kv: Whether to return key-values + end: End position for calculation + condition_mode: Mode for conditional perplexity (none, prefix) + condition_pos_id: Position ID for condition + + Returns: + A dictionary with perplexity scores and processing information + """ + # Create a cache key for this specific perplexity calculation + cache_key = f"{text}_{granularity}_{condition_mode}_{condition_pos_id}" + if past_key_values is None and not return_kv and cache_key in self.cache["perplexity"]: + return self.cache["perplexity"][cache_key] + + # Initialize input processing + if input_ids is None: + encoding_key = text + if encoding_key in self.cache["encodings"]: + cached_encoding = self.cache["encodings"][encoding_key] + input_ids = cached_encoding["input_ids"] + attention_mask = cached_encoding["attention_mask"] + else: + encoding = self.tokenizer( + text, + return_tensors="pt", + padding=True + ) + input_ids = encoding["input_ids"].to(self.model.device) + attention_mask = encoding["attention_mask"].to(self.model.device) + + # Cache the encoding + self.cache["encodings"][encoding_key] = { + "input_ids": input_ids, + "attention_mask": attention_mask + } + self._manage_cache_size("encodings") + + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + else: + past_length = 0 + + if end is None: + end = input_ids.shape[1] + end = min(end, past_length + self.max_position_embeddings) + + with torch.no_grad(): + outputs = self.model( + input_ids=input_ids[:, past_length:end], + attention_mask=attention_mask[:, :end], + past_key_values=past_key_values, + return_dict=True, + output_hidden_states=True, + use_cache=True, + ) + + # Get logits and shift + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., past_length+1:end].contiguous() + + # Flatten tokens for loss calculation + active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1) + active_logits = shift_logits.view(-1, shift_logits.size(-1))[active] + active_labels = shift_labels.view(-1)[active] + + # Calculate loss + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct(active_logits, active_labels) + + # Apply condition filtering if required + if condition_mode == "prefix": + loss = loss[condition_pos_id:] + + # Process based on granularity + if granularity == "token": + result_loss = loss + else: + result_loss = loss.mean() + + # 
Split text into lines for line-level granularity + if granularity == "line" and text: + segments = text.split("\n") + segments = [seg for seg in segments if seg.strip()] + lines_info = self.__get_lines_info(segments, input_ids[0], loss) + else: + segments = [text] if text else [] + lines_info = [] + + # Calculate mean perplexity + mean_loss = loss.mean() if len(loss) > 0 else torch.tensor(0.0) + ppl = torch.exp(mean_loss).item() if mean_loss.item() != float('inf') else float('inf') + + result = { + "loss": loss, + "input_ids": input_ids, + "attention_mask": attention_mask, + "lines_info": lines_info, + "segments": segments, + "ppl": ppl, + } + + if return_kv: + result["past_key_values"] = outputs.past_key_values + else: + # Cache the result if we're not returning KV cache + self.cache["perplexity"][cache_key] = result + self._manage_cache_size("perplexity") + + return result + + def __get_lines_info(self, lines, input_ids, loss): + """ + Get information about each line including start/end positions and importance. + + Args: + lines: List of lines in the text + input_ids: Token IDs for the entire text + loss: Per-token loss values + + Returns: + List of dictionaries with line information + """ + line_info = [] + cumulative_tokens = 0 + + input_ids_list = input_ids.cpu().tolist() + + for i, line in enumerate(lines): + if not line.strip(): + continue + + # Encode each line to find its token length + line_tokens = self.tokenizer.encode(line, add_special_tokens=False) + line_length = len(line_tokens) + + # Find position in the tokenized text + start_pos = cumulative_tokens + end_pos = start_pos + line_length + + # Calculate mean loss (importance) for this line + # Loss might be shorter than the token IDs due to shifting + if isinstance(loss, torch.Tensor) and start_pos < len(loss) and end_pos <= len(loss): + line_loss = loss[start_pos:end_pos].mean().item() + else: + # Handle edge cases + line_loss = float("inf") + + line_info.append({ + "line": line, + "start": start_pos, + "end": end_pos, + "importance": line_loss, + "tokens": line_length + }) + + cumulative_tokens += line_length + + return line_info + + def get_prefix_length(self, prefix: str, text: str): + """ + Calculate the length of a prefix in tokens when concatenated with a text. + + Args: + prefix: The prefix text + text: The main text + + Returns: + Length of the prefix in tokens + """ + possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1) + full_input_ids = self.tokenizer(prefix + text[:100], add_special_tokens=False).input_ids + + for i in range(possible_prefix_token, len(full_input_ids)): + cur_prefix = self.tokenizer.decode(full_input_ids[:i]) + if cur_prefix == prefix: + break + + return i + + def get_condition_ppl( + self, + text: str, + question: str, + condition_in_question: str = "none", + granularity: str = "line", + ): + """ + Calculate perplexity change of a question when given context text. + A positive change means the context helps reduce question perplexity. 
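+        For example, with condition_in_question="prefix", if the question alone has perplexity 12.0
+        and perplexity 8.5 when prefixed by the context, the returned value is 3.5.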
+ + Args: + text: The context text + question: The question to evaluate + condition_in_question: Conditioning mode (none, prefix) + granularity: Granularity for perplexity calculation + + Returns: + Perplexity change for the question with/without context + """ + # Create a cache key for this conditional perplexity calculation + cache_key = f"{text}_{question}_{condition_in_question}_{granularity}" + + if cache_key in self.cache["conditional_ppl"]: + return self.cache["conditional_ppl"][cache_key] + + if condition_in_question == "none": + # Just return the perplexity of the text + result = self.get_ppl( + text=text, granularity=granularity, condition_mode="none" + ) + ppl_value = result["ppl"] + else: + # First calculate question perplexity without context + question_ppl_without_context = self.get_ppl( + text=question, + granularity=granularity + )["ppl"] + + # Then calculate question perplexity with context + question_ppl_with_context = self.get_ppl( + text=text + "\n\n" + question, + granularity=granularity, + condition_mode="prefix", + condition_pos_id=self.get_token_length(text + "\n\n", add_special_tokens=True) + )["ppl"] + + # Calculate the change (positive means context helps) + ppl_value = question_ppl_without_context - question_ppl_with_context + + # Cache the result + self.cache["conditional_ppl"][cache_key] = ppl_value + self._manage_cache_size("conditional_ppl") + + return ppl_value + + def get_estimate_threshold_base_distribution( + self, ppl_values, ratio: float, condition_flag: bool = False + ): + """ + Estimate threshold value for compression based on perplexity distribution. + + Args: + ppl_values: Perplexity values for tokens or lines + ratio: Compression ratio (0.0-1.0) + condition_flag: Whether values are conditional (affecting sorting direction) + + Returns: + Threshold value for filtering + """ + if ratio >= 1.0: + return float("-inf") + + if isinstance(ppl_values, torch.Tensor): + # Filter out extreme values that might skew the threshold + valid_values = ppl_values[ppl_values != float('inf')] + valid_values = valid_values[valid_values != -float('inf')] + valid_values = valid_values[~torch.isnan(valid_values)] + + if len(valid_values) == 0: + return 0.0 + + # Calculate the target position for the percentile + target_token = max(0, min(len(valid_values) - 1, int(len(valid_values) * ratio) - 1)) + + # Sort values based on condition_flag and get threshold + sort_values = valid_values.sort(descending=not condition_flag).values + if target_token < len(sort_values): + return sort_values[target_token].item() + return 0.0 + else: + # Handle non-tensor inputs (lists, numpy arrays) + valid_values = [v for v in ppl_values if v != float('inf') and v != -float('inf') and not math.isnan(v)] + + if not valid_values: + return 0.0 + + # Calculate the target position for the percentile + target_idx = max(0, min(len(valid_values) - 1, int(len(valid_values) * ratio) - 1)) + + # Sort values and get threshold + sorted_values = sorted(valid_values, reverse=not condition_flag) + if target_idx < len(sorted_values): + return sorted_values[target_idx] + return 0.0 + + def get_dynamic_compression_ratio( + self, + context: list, + target_token: float, + iterative_size: int, + dynamic_ratio: list, + start: int, + ): + """ + Calculate dynamic compression ratios for iterative compression. 
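+        The base keep ratio is tau = target_token / (total context token length + 1); each chunk's
+        ratio is tau shifted by its dynamic_ratio entry and clamped to [0, 1].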
+ + Args: + context: List of context strings + target_token: Target number of tokens + iterative_size: Size of each iteration + dynamic_ratio: List of dynamic ratio adjustments + start: Start position for processing + + Returns: + List of ratios for each iteration chunk + """ + def get_ratio(base: float, delta: float): + return max(min(1, base + delta), 0) + + context_length = [self.get_token_length(ii, False) + 2 for ii in context] + if start: + context_length = context_length[1:] + + tau = target_token / (sum(context_length) + 1) + res, idx, last, last_target = [], 0, 1, [] + + while idx < len(context_length): + if last + context_length[idx] >= iterative_size: + last_target.append( + (iterative_size - last, get_ratio(tau, dynamic_ratio[idx])) + ) + res.append(last_target) + last = last + context_length[idx] - iterative_size + + if last > iterative_size: + k = last // iterative_size + res.extend( + [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k + ) + last -= k * iterative_size + + last_target = ( + [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else [] + ) + else: + last += context_length[idx] + last_target.append( + (context_length[idx], get_ratio(tau, dynamic_ratio[idx])) + ) + idx += 1 + + if last_target: + res.append(last_target) + + return res + + def iterative_compress_prompt( + self, + context: List[str], + target_token: float, + iterative_size: int = 200, + keep_lines: bool = True, + start: int = 0, + dynamic_ratio: list = None, + condition_compare: bool = False, + ): + """ + Iteratively compress text using a sliding window approach with KV caching. + + Args: + context: List of text contexts to compress + target_token: Target number of tokens after compression + iterative_size: Size of each iteration window + keep_lines: Whether to keep line structure + start: Start position for processing + dynamic_ratio: List of dynamic compression ratios + condition_compare: Whether to use conditional comparison + + Returns: + Compressed input IDs and attention mask + """ + # Calculate dynamic compression ratios for each iteration + iterative_ratios = self.get_dynamic_compression_ratio( + context, target_token, iterative_size, dynamic_ratio, start + ) + + # Join contexts and tokenize + context_joined = "\n\n".join(context) + tokenized_text = self.tokenizer( + context_joined, return_tensors="pt", add_special_tokens=False + ) + input_ids = tokenized_text["input_ids"].to(self.model.device) + attention_mask = tokenized_text["attention_mask"].to(self.model.device) + + # Initialize working variables + compressed_input_ids, compressed_attention_mask = input_ids, attention_mask + end = min(iterative_size + start, compressed_input_ids.shape[1]) + threshold, keep_flag = None, None + + if keep_lines: + # Build a keep flag for important line tokens (e.g., indentation patterns) + input_ids_numpy = input_ids.cpu().detach().numpy()[0] + N = len(input_ids_numpy) + # Identify line break patterns to preserve + newline_ids = set(self.tokenizer.encode("\n", add_special_tokens=False)) + keep_flag = torch.zeros(N, dtype=torch.bool).to(self.model.device) + + # Mark tokens that represent indentation to be preserved + for i in range(1, N): + if input_ids_numpy[i-1] in newline_ids: + # Check if this token is whitespace (indentation) + token = self.tokenizer.decode([input_ids_numpy[i]]) + if token.isspace(): + keep_flag[i] = True + + # Initialize processing state + past_key_values, past_loss, ready_end = None, None, 0 + pop_compressed_input_ids = None + idx = 0 + + # Process text in chunks + while end 
<= compressed_input_ids.shape[1]: + # Handle KV-cache window sliding for long texts + if end > self.max_position_embeddings and past_key_values is not None: + # KV-Cache Compression + e, s = end - self.max_position_embeddings, min( + self.cache_bos_num + start, self.max_position_embeddings + ) + if pop_compressed_input_ids is None: + pop_compressed_input_ids = compressed_input_ids[:, :e] + else: + pop_compressed_input_ids = torch.cat( + [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1 + ) + compressed_input_ids = compressed_input_ids[:, e:] + compressed_attention_mask = compressed_attention_mask[:, e:] + + # Update KV cache - keep beginning tokens and skip processed tokens + past_key_values = [ + [ + torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2), + torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2), + ] + for k, v in past_key_values + ] + + if keep_flag is not None: + keep_flag = keep_flag[e:] + + end, ready_end = end - e, ready_end - e + + # Calculate perplexity for current window + result = self.get_ppl( + "", + "token", + compressed_input_ids, + compressed_attention_mask, + past_key_values=past_key_values, + return_kv=True, + end=end if idx else None, + ) + + loss, past_key_values = result["loss"], result["past_key_values"] + + if loss.shape[0] == 0: + break + + # Merge with previous loss calculations + if past_loss is not None: + if end - 1 > len(past_loss): + past_loss = torch.cat( + [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]] + ) + past_loss[ready_end : end - 1] = loss + loss = past_loss + else: + past_loss = loss + + # Slide the KV cache window + if idx: + past_key_values = [ + [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]] + for k, v in past_key_values + ] + else: + past_key_values = None + + # Apply compression for each chunk in the current window + for delta_end, ratio in iterative_ratios[idx]: + loss = past_loss + # Calculate threshold for token filtering + threshold = self.get_estimate_threshold_base_distribution( + loss, ratio, False + ) + + # Filter tokens using the calculated threshold + compressed_input_ids, compressed_attention_mask, keep_flag, end, past_loss = self.get_compressed_input( + loss, + compressed_input_ids, + compressed_attention_mask, + end - iterative_size + delta_end, + iterative_size=delta_end, + threshold=threshold, + keep_flag=keep_flag, + start=start, + ) + + end += iterative_size + + ready_end = end - iterative_size if not (start and idx == 0) else 0 + idx += 1 + + # Concatenate saved tokens with final compressed tokens + if pop_compressed_input_ids is not None: + compressed_input_ids = torch.cat( + [pop_compressed_input_ids, compressed_input_ids], dim=-1 + ) + + return compressed_input_ids[:, start:], compressed_attention_mask[:, start:] + + def iterative_compress_prompt_line( + self, + context: List[str], + target_token: float, + dynamic_ratio: list = None, + ): + """ + Compress text by evaluating and filtering entire lines based on importance. + This is a line-level alternative to the token-level iterative_compress_prompt. 
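+        Lines are scored by their mean per-token loss (optionally shifted by dynamic_ratio) and kept
+        greedily, most important first, until the token budget is exhausted.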
+ + Args: + context: List of text contexts to compress + target_token: Target number of tokens after compression + dynamic_ratio: List of dynamic compression ratios for each context + + Returns: + Compressed input IDs and attention mask + """ + # Join contexts + context_joined = "\n\n".join(context) + + # Split text into lines + lines = context_joined.split("\n") + + # Get perplexity for the entire text at line level + ppl_result = self.get_ppl(context_joined, granularity="line") + lines_info = ppl_result["lines_info"] + + # Calculate token count for each line + line_tokens = [(i, info["tokens"], info["importance"]) + for i, info in enumerate(lines_info)] + + # Apply dynamic ratio adjustments if provided + if dynamic_ratio and len(dynamic_ratio) > 0: + # Create dynamic ratios for each line based on context dynamic ratios + # We'll infer which context each line belongs to + line_contexts = [] + context_idx = 0 + line_count = 0 + + # Map each line to its corresponding context + for i, info in enumerate(lines_info): + line_contexts.append(min(context_idx, len(dynamic_ratio) - 1)) + line_count += 1 + + # Check if we've reached the end of a context + if line_count >= lines.count("\n") + 1 and context_idx < len(context) - 1: + context_idx += 1 + line_count = 0 + + # Apply dynamic ratio adjustments to line importance scores + for i in range(len(line_tokens)): + if i < len(line_contexts): + context_idx = line_contexts[i] + if context_idx < len(dynamic_ratio): + # Adjust importance using dynamic ratio + # Lower importance score means higher priority (will be kept) + adjustment = dynamic_ratio[context_idx] + line_tokens[i] = ( + line_tokens[i][0], + line_tokens[i][1], + line_tokens[i][2] - adjustment # Lower importance means keep + ) + + # Sort lines by importance (lower score is more important) + sorted_lines = sorted(line_tokens, key=lambda x: x[2]) + + # Select lines to keep within token budget + tokens_so_far = 0 + lines_to_keep = set() + + for line_idx, line_tokens, _ in sorted_lines: + if tokens_so_far + line_tokens <= target_token: + lines_to_keep.add(line_idx) + tokens_so_far += line_tokens + else: + # Stop if we've reached our target + break + + # Create compressed text with only the selected lines + compressed_lines = [lines_info[i]["line"] for i in sorted(lines_to_keep)] + compressed_text = "\n".join(compressed_lines) + + # Tokenize the compressed text + tokenized_text = self.tokenizer( + compressed_text, return_tensors="pt", add_special_tokens=False + ) + compressed_input_ids = tokenized_text["input_ids"].to(self.model.device) + compressed_attention_mask = tokenized_text["attention_mask"].to(self.model.device) + + return compressed_input_ids, compressed_attention_mask + + def get_compressed_input( + self, + loss, + input_ids, + attention_mask, + end=200, + iterative_size=200, + threshold=0.5, + keep_flag=None, + start: int = 0, + ): + """ + Filter input tokens based on loss values and thresholds. 
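+        Tokens whose loss exceeds the threshold are kept, tokens flagged in keep_flag are always kept,
+        and runs of consecutive newline tokens are collapsed to a single newline.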
+ + Args: + loss: Loss values for each token + input_ids: Input token IDs + attention_mask: Attention mask + end: End position for processing + iterative_size: Size of each iteration + threshold: Threshold value for filtering + keep_flag: Flags for tokens to always keep + start: Start position for processing + + Returns: + Compressed inputs and updated state + """ + # Determine which tokens to keep based on loss values and threshold + need_idx = torch.concat([loss > threshold, loss[:1] > 0]) + + # Ensure we keep tokens at positions outside our current window + need_idx[end:] = 1 + need_idx[: end - iterative_size] = 1 + + # Get filtered loss + loss = loss[need_idx[:-1]] + + # Ensure need_idx matches input_ids length + if need_idx.shape[0] < input_ids.shape[1]: + need_idx = torch.cat( + [ + need_idx, + torch.ones( + input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool + ).to(need_idx.device), + ] + ) + elif need_idx.shape[0] > input_ids.shape[1]: + need_idx = need_idx[: input_ids.shape[1]] + + # Enforce keeping tokens marked in keep_flag + if keep_flag is not None: + need_idx[keep_flag] = 1 + + # Optionally apply line break preservation logic + # Find tokens representing newlines and always keep one of consecutive newlines + tokens = input_ids[0] + newline_ids = set(self.tokenizer.encode("\n", add_special_tokens=False)) + last_kept_newline = False + + for ii in range(max(0, end - iterative_size), end): + if need_idx[ii] == 0: + continue + + token_id = tokens[ii].item() + + # Handle newline logic - avoid consecutive newlines unless marked important + if token_id in newline_ids: + if last_kept_newline and keep_flag[ii].item() == 0: + need_idx[ii] = 0 + else: + last_kept_newline = True + else: + last_kept_newline = False + + # Apply the filtering to get compressed tokens + compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0) + compressed_attention_mask = attention_mask[attention_mask == 1][need_idx].unsqueeze(0) + + # Update the end position based on how many tokens we removed + end -= (need_idx[:end] == 0).sum() + + return compressed_input_ids, compressed_attention_mask, keep_flag, end, loss + + def compress_code( + self, + code: str, + query: str = "", + instruction: str = "", + rate: float = 0.5, + target_token: int = -1, + use_line_level_filter: bool = True, + use_iterative_compression: bool = True, + iterative_size: int = 200, + dynamic_compression_ratio: float = 0.2, + ): + """ + Compress code by removing less important lines based on query relevance. 
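+        When a query is provided, lines are ranked by how much they reduce the query's perplexity;
+        otherwise, lines are ranked by their own perplexity (lower loss means more important).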
+ + Args: + code: The code to compress + query: Query to prioritize relevant lines + instruction: Additional instruction to guide compression + rate: Compression rate (0.0-1.0), where 1.0 means no compression + target_token: Target number of tokens (alternative to rate) + use_line_level_filter: Whether to use line-level filtering + use_iterative_compression: Whether to use token-level iterative compression + iterative_size: Size of each iteration for token-level compression + dynamic_compression_ratio: Ratio for dynamic compression (0.0-1.0) + + Returns: + Compressed code and statistics + """ + logger.debug(f"Starting code compression with rate={rate}, target_token={target_token}") + start_time = time.time() + + # Calculate total tokens in the code + total_tokens = self.get_token_length(code) + logger.debug(f"Total tokens in code: {total_tokens}") + + # Determine target tokens + if target_token <= 0: + target_token = int(total_tokens * rate) + logger.debug(f"Target tokens: {target_token}") + + if rate >= 1.0 or target_token >= total_tokens: + # No compression needed + return { + "original_code": code, + "compressed_code": code, + "output": code, + "original_tokens": total_tokens, + "compressed_tokens": total_tokens, + "final_compressed_tokens": total_tokens, + "compression_ratio": 1.0, + "kept_lines": list(range(len(code.split("\n")))), + } + + # For very small code snippets, skip iterative compression + if total_tokens < 100: + use_iterative_compression = False + + if use_line_level_filter: + # Split code into lines for line-level filtering + lines = code.split("\n") + non_empty_lines = [line for line in lines if line.strip()] + logger.debug(f"Split code into {len(non_empty_lines)} non-empty lines") + + # Get perplexity for entire code + ppl_result = self.get_ppl(code, granularity="line") + lines_info = ppl_result["lines_info"] + + # For query is provided, rank lines by relevance + if query: + logger.debug("Ranking lines by relevance to query") + # Get conditional perplexity for each line + line_importances = [] + for i, line_info in tqdm(enumerate(lines_info), total=len(lines_info), desc="Calculating line importance"): + # First calculate the perplexity of the query without the line + query_ppl_without_context = self.get_ppl(query, granularity="line")["ppl"] + + # Then calculate the perplexity of the query with the line as context + query_ppl_with_context = self.get_ppl( + line_info["line"] + "\n\n" + query, + granularity="line", + condition_mode="prefix", + condition_pos_id=self.get_token_length(line_info["line"] + "\n\n", add_special_tokens=True) + )["ppl"] + + # Calculate the perplexity change (lower value means context is more helpful) + ppl_change = query_ppl_without_context - query_ppl_with_context + + # Add length adjustment similar to before + line_importances.append((i, -ppl_change - line_info["tokens"] * 2 / 250 * 0)) + + # Sort by importance (higher perplexity reduction = more relevant to query) + sorted_lines = sorted(line_importances, key=lambda x: x[1]) + else: + # Sort lines by importance (lower loss = more important) + line_importances = [(i, info["importance"]) for i, info in enumerate(lines_info)] + sorted_lines = sorted(line_importances, key=lambda x: x[1]) + + # Apply dynamic compression ratio if specified + if dynamic_compression_ratio > 0: + N = len(sorted_lines) + # This creates a gradient of compression rates from higher to lower importance + dynamic_ratios = [ + i * (dynamic_compression_ratio / (N - 1)) if N > 1 else 0 + for i in range(-(N - 1), N, 2) + ] + + # 
Assign dynamic ratios to lines based on their importance rank + sorted_indices = [idx for idx, _ in sorted_lines] + dynamic_ratio_map = {idx: ratio for idx, ratio in zip(sorted_indices, dynamic_ratios)} + else: + dynamic_ratio_map = {i: 0 for i in range(len(lines_info))} + + # Determine which lines to keep based on token budget + tokens_so_far = 0 + lines_to_keep = set() + + # First pass - keep most important lines within budget + for line_idx, _ in sorted_lines: + if line_idx >= len(lines_info): + continue + + line_info = lines_info[line_idx] + line_tokens = line_info["tokens"] + + if tokens_so_far + line_tokens <= target_token: + lines_to_keep.add(line_idx) + tokens_so_far += line_tokens + else: + # Stop if we've reached our target + break + + logger.debug(f"Selected {len(lines_to_keep)} lines to keep out of {len(lines_info)}") + + # Construct code with only the selected lines + preserved_code = "\n".join([lines_info[i]["line"] for i in sorted(lines_to_keep)]) + + # If we need iterative token-level compression + if use_iterative_compression: + logger.debug("Applying iterative line-level compression") + + # Create dynamic ratios for iterative compression + dynamic_ratios = [dynamic_ratio_map.get(i, 0.0) for i in sorted(lines_to_keep)] + + # Convert to list for iterative compression + context = [preserved_code] + + # Apply line-level compression instead of token-level compression + compressed_ids, compressed_mask = self.iterative_compress_prompt_line( + context, + target_token=target_token, + dynamic_ratio=dynamic_ratios, + ) + + # Convert back to text + compressed_code = self.tokenizer.decode(compressed_ids[0]) + else: + compressed_code = preserved_code + else: + # Without line-level filter, apply iterative compression directly + if use_iterative_compression: + logger.debug("Applying iterative line-level compression without line filtering") + + # Apply line-level compression to the entire code + compressed_ids, _ = self.iterative_compress_prompt_line( + [code], + target_token=target_token, + dynamic_ratio=[0.0], # No dynamic ratio adjustment for single context + ) + + # Convert back to text + compressed_code = self.tokenizer.decode(compressed_ids[0]) + else: + # Simple truncation + logger.debug("No compression methods selected, using simple truncation") + encoded = self.tokenizer.encode(code, add_special_tokens=False) + truncated = encoded[:target_token] + compressed_code = self.tokenizer.decode(truncated) + + # Construct final output with instruction and query + output = "" + if instruction: + output += f"{instruction}\n\n" + output += compressed_code + if query: + output += f"\n\n{query}" + + # Calculate compression statistics + compressed_tokens = self.get_token_length(compressed_code) + final_compressed_tokens = self.get_token_length(output) + compression_ratio = compressed_tokens / total_tokens if total_tokens > 0 else 1.0 + + end_time = time.time() + logger.debug(f"Code compression completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Compression ratio: {compression_ratio:.2f}") + + # For line-level filtering, include which lines were kept + if use_line_level_filter: + kept_lines = sorted(lines_to_keep) + else: + # Approximate which lines were kept based on content + original_lines = code.split("\n") + compressed_lines = compressed_code.split("\n") + kept_lines = [] + for i, line in enumerate(original_lines): + if line in compressed_lines: + kept_lines.append(i) + + return { + "original_code": code, + "compressed_code": compressed_code, + "output": output, + 
"original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": compression_ratio, + "kept_lines": kept_lines, + } + + def control_context_budget( + self, + context_list: List[str], + target_token: float, + question: str = "", + reorder_context: str = "original", + condition_in_question: str = "none", + force_context_ids: List[int] = None, + force_context_number: int = None, + context_budget: str = "+100", + dynamic_context_compression_ratio: float = 0.0, + ): + """ + Control token budget for contexts based on relevance ranking, following LongLLMLingua. + + Args: + context_list: List of contexts + target_token: Target number of tokens + question: Question for relevance ranking + reorder_context: How to reorder contexts ("original", "importance", "two_stage") + condition_in_question: Mode for conditional ranking + force_context_ids: List of context IDs to always include + force_context_number: Number of contexts to forcibly include + context_budget: String expression to modify target token budget + dynamic_context_compression_ratio: Ratio for dynamic compression (0.0-1.0) + + Returns: + Selected contexts, their indices, and dynamic ratios + """ + logger.debug(f"Controlling context budget with target_token={target_token}") + start_time = time.time() + + if not context_list: + return [], [], [] + + # Get token counts for each context + logger.debug("Calculating token lengths for contexts") + context_tokens_length = [self.get_token_length(context) for context in context_list] + + # If total tokens already fit within budget, return all contexts + total_tokens = sum(context_tokens_length) + if total_tokens <= target_token: + logger.debug(f"All contexts fit within budget ({total_tokens} <= {target_token})") + end_time = time.time() + logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds") + return context_list, list(range(len(context_list))), [0.0] * len(context_list) + + # Rank contexts by relevance if question is provided + logger.debug("Ranking contexts by relevance") + if question: + # Get perplexity change for each context with the question + context_ppl_changes = [] + for d, dl in zip(context_list, context_tokens_length): + # Calculate how much this context reduces question perplexity + ppl_change = self.get_condition_ppl( + d, + question, + condition_in_question, + ) + # Apply length adjustment factor similar to before + context_ppl_changes.append(ppl_change - dl * 2 / 250 * 0) + + # Sort by perplexity change - higher is better (more reduction in question perplexity) + demonstrations_sort = sorted(enumerate(context_ppl_changes), key=lambda x: -x[1]) + else: + # Without question, use default ordering + demonstrations_sort = [(i, 0) for i in range(len(context_list))] + + # Extract ranking for later use + self.context_idxs.append([x for idx, (x, _) in enumerate(demonstrations_sort)]) + + # Calculate the target token budget with context_budget expression + if target_token < 0: + target_token = 100 + target_token = eval("target_token" + context_budget) + + # Initialize selected context tracking + used = force_context_ids if force_context_ids is not None else [] + + # Select contexts until we reach the token budget + for idx, _ in demonstrations_sort: + if idx >= len(context_tokens_length): + continue + target_token -= context_tokens_length[idx] + if idx not in used: + used.append(idx) + if target_token < 0 or ( + force_context_number is not None and len(used) >= 
force_context_number + ): + break + + # Store original selection order + original_used = used.copy() + + # Reorder contexts if requested + if reorder_context == "original": + used = sorted(used) + elif reorder_context == "two_stage": + l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [ + _ for idx, _ in enumerate(used) if idx % 2 == 1 + ] + used = l + r[::-1] + + # Calculate dynamic compression ratios if requested + if dynamic_context_compression_ratio > 0: + N = len(used) + dynamic_ratio = [ + i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0 + for i in range(-(N - 1), N, 2) + ][::-1] + dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)} + dynamic_ratio = [dynamic_ratio_map[i] for i in used] + else: + dynamic_ratio = [0.0] * len(used) + + # Build list of selected contexts + selected_contexts = [context_list[idx] for idx in used if idx < len(context_list)] + + end_time = time.time() + logger.debug(f"Selected {len(selected_contexts)} contexts out of {len(context_list)}") + logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds") + + return selected_contexts, used, dynamic_ratio, demonstrations_sort + + def compress_code_file( + self, + code: str, + query: str = "", + instruction: str = "", + rate: float = 0.5, + target_token: float = -1, + language: str = "python", + use_iterative_compression: bool = True, + iterative_size: int = 200, + dynamic_compression_ratio: float = 0.2, + context_budget: str = "+100", + rank_only: bool = False, + ): + """ + Compress a code file by first splitting it into function-based chunks and then compressing. + Functions are prioritized based on query relevance, similar to LongLLMLingua. + + Args: + code: The code to compress + query: Query to prioritize relevant functions + instruction: Additional instruction to guide compression + rate: Compression rate (0.0-1.0) + target_token: Target number of tokens (alternative to rate) + language: Programming language of the code + use_iterative_compression: Whether to use iterative compression + iterative_size: Size of each iteration for iterative compression + dynamic_compression_ratio: Ratio for dynamic compression + context_budget: String expression to modify token budget + rank_only: If True, just rank and select contexts without fine-grained compression + + Returns: + Compressed code and statistics + """ + logger.debug(f"Starting code file compression with rate={rate}, target_token={target_token}, language={language}") + start_time = time.time() + + # Split code into function-based chunks + logger.debug("Splitting code into function-based chunks") + code_chunks = self.split_code_by_functions(code, language=language) + logger.debug(f"Split code into {len(code_chunks)} chunks") + + # Calculate total tokens + logger.debug("Calculating total tokens") + total_tokens = sum(self.get_token_length(chunk) for chunk in code_chunks) + logger.debug(f"Total tokens: {total_tokens}") + + # If target token is not provided, use rate + if target_token <= 0: + target_token = int(total_tokens * rate) + logger.debug(f"Target tokens: {target_token}") + + # Use context budget control to select important functions + logger.debug("Selecting important functions using context budget control") + selected_contexts, selected_indices, dynamic_ratios, demonstrations_sort = self.control_context_budget( + code_chunks, + target_token=target_token, + question=query, + reorder_context="original", # Keep original order to maintain code structure + 
condition_in_question="prefix", + context_budget=context_budget, + dynamic_context_compression_ratio=dynamic_compression_ratio, + ) + + # If rank_only is True, just use the selected contexts without further compression + if rank_only: + logger.debug("Using rank-only mode: selecting top functions without fine-grained compression") + compressed_chunks = [] + compressed_tokens = 0 + function_compressions = {} + + # Just keep the selected contexts as is + for i, chunk in enumerate(code_chunks): + if i in selected_indices: + compressed_chunks.append(chunk) + chunk_tokens = self.get_token_length(chunk) + compressed_tokens += chunk_tokens + + # Store compression info - no actual compression in this mode + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_tokens, + "compression_ratio": 1.0, + } + else: + # Skip this function completely + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + omission_text = f"{comment_marker} ... " + compressed_chunks.append(omission_text) + compressed_tokens += self.get_token_length(omission_text) + + # Combine compressed chunks + compressed_code = "\n\n".join(compressed_chunks) + output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}" + + # Calculate actual compressed tokens + final_compressed_tokens = self.get_token_length(output) + + end_time = time.time() + logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}") + + return { + "original_code": code, + "compressed_code": compressed_code, + "compressed_prompt": output, + "original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0, + "function_compressions": function_compressions, + "selected_functions": selected_indices, + "demonstrations_sort": demonstrations_sort, + } + + # Compress each function according to its importance + logger.debug("Compressing selected functions") + compressed_chunks = [] + compressed_tokens = 0 + function_compressions = {} + + # Allocate tokens proportionally based on importance + importance_scores = {} + for i, idx in enumerate(selected_indices): + # Higher importance for functions mentioned early in ranking + importance_scores[idx] = len(selected_indices) - i + + # Calculate total importance + total_importance = sum(importance_scores.values()) if importance_scores else 1 + + # Allocate tokens based on importance + token_allocation = {} + for idx, importance in importance_scores.items(): + allocation = max(10, int(target_token * importance / total_importance)) + token_allocation[idx] = min(allocation, self.get_token_length(code_chunks[idx])) + + # Adjust allocations to fit target + logger.debug("Adjusting token allocations to fit target") + while sum(token_allocation.values()) > target_token: + max_idx = max(token_allocation, key=token_allocation.get) + token_allocation[max_idx] = max(0, token_allocation[max_idx] - 10) + # Show the allocation + logger.debug(f"Token allocation: {token_allocation}") + + # Process each chunk + for i, chunk in tqdm(enumerate(code_chunks), total=len(code_chunks), desc="Compressing functions"): + if i in token_allocation and token_allocation[i] > 0: + # Calculate compression rate for this chunk + chunk_tokens = self.get_token_length(chunk) + chunk_rate = token_allocation[i] / 
chunk_tokens + + # Apply dynamic compression ratio based on importance + dynamic_ratio = dynamic_ratios[selected_indices.index(i)] if i in selected_indices else 0.0 + + # Compress the chunk using line-level compression if requested + if use_iterative_compression and chunk_tokens > 50: + compressed_input_ids, _ = self.iterative_compress_prompt_line( + [chunk], + target_token=token_allocation[i], + dynamic_ratio=[dynamic_ratio], + ) + compressed_chunk = self.tokenizer.decode(compressed_input_ids[0]) + else: + # Use simple line-level compression for smaller chunks + compress_result = self.compress_code( + code=chunk, + query=query, + rate=chunk_rate, + use_iterative_compression=False + ) + compressed_chunk = compress_result["compressed_code"] + + compressed_chunks.append(compressed_chunk) + chunk_compressed_tokens = self.get_token_length(compressed_chunk) + compressed_tokens += chunk_compressed_tokens + + # Store compression info for this function + function_compressions[i] = { + "original_tokens": chunk_tokens, + "compressed_tokens": chunk_compressed_tokens, + "compression_ratio": chunk_compressed_tokens / chunk_tokens if chunk_tokens > 0 else 1.0, + } + else: + # Skip this function completely + comment_marker = "#" if language.lower() in ["python", "typescript", "rust"] else "//" + # omission_text = f"{comment_marker} ... function omitted ..." + omission_text = f"{comment_marker} ... " + compressed_chunks.append(omission_text) + compressed_tokens += self.get_token_length(omission_text) + + # Combine compressed chunks + logger.debug("Combining compressed chunks") + compressed_code = "\n\n".join(compressed_chunks) + + # # If instruction is provided, add it to the final output + # output = "" + # if instruction: + # output += f"{instruction}\n\n" + # output += compressed_code + # if query: + # output += f"\n\n{query}" + output = f"{instruction}\n\n{compressed_code}\n\n{query}\n{instruction}" + + # Calculate actual compressed tokens including instruction and query + final_compressed_tokens = self.get_token_length(output) + + end_time = time.time() + logger.debug(f"Code file compression completed in {end_time - start_time:.2f} seconds") + logger.debug(f"Compression ratio: {compressed_tokens / total_tokens if total_tokens > 0 else 1.0:.2f}") + + return { + "original_code": code, + "compressed_code": compressed_code, + "compressed_prompt": output, + "original_tokens": total_tokens, + "compressed_tokens": compressed_tokens, + "final_compressed_tokens": final_compressed_tokens, + "compression_ratio": compressed_tokens / total_tokens if total_tokens > 0 else 1.0, + "function_compressions": function_compressions, + "selected_functions": selected_indices, + "demonstrations_sort": demonstrations_sort, + } + + def split_code_by_functions(self, code: str, language: str = "python") -> List[str]: + """ + Split code into chunks based on function and class definitions for various languages. 
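+        For example (illustrative only), a Python module made of an import
+        block followed by two top-level `def` statements is returned as three
+        chunks: the leading imports, then one chunk per function, in source order.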
+ + Args: + code: The code to split + language: Programming language of the code (python, cpp, java, typescript, rust, go) + + Returns: + List of code chunks, each containing a function, class, or class method + """ + logger.debug(f"Splitting code by functions and classes for language: {language}") + start_time = time.time() + + # Define regex patterns for different languages + patterns = { + # Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end + "python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*', + # C++: Improved to better handle multi-line declarations + "cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Java: Improved for multi-line method declarations + "java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?', + # TypeScript: Enhanced to handle multi-line methods and arrow functions + "typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?', + # Rust: Improved for multi-line function declarations + "rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + # Go: Improved for multi-line function declarations + "go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?', + } + + + # Use default Python pattern if language not supported + if language.lower() not in patterns: + language = "python" + + function_pattern = re.compile(patterns[language.lower()], re.MULTILINE) + + # Find all function and class definitions + matches = list(function_pattern.finditer(code)) + logger.debug(f"Found {len(matches)} function and class definitions") + + # If no functions or classes found, return the whole code as one chunk + if not matches: + logger.debug("No functions or classes found, returning entire code as one chunk") + end_time = time.time() + logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds") + return [code] + + # Extract chunks that include function and class definitions + chunks = [] + + # Add imports and other code before the first function or class + if matches[0].start() > 0: + chunks.append(code[:matches[0].start()]) + + # Process each function or class match + for i, match in enumerate(matches): + # Get the current function or class + start = match.start() + + # Determine end position (either the start of the next function/class or the end of the code) + if i < len(matches) - 1: + end = matches[i + 1].start() + else: + end = len(code) + + # Extract the function/class and its body + chunks.append(code[start:end]) + + end_time = time.time() + logger.debug(f"Code splitting completed in {end_time - start_time:.2f} seconds") + 
logger.debug(f"Split code into {len(chunks)} chunks") + + return chunks + +def load_examples(language: Optional[str] = None) -> List[Dict]: + """Load examples from the results file, optionally filtered by language""" + with open("../results/ntoken_16384/Qwen_slash_Qwen2.5-7B-Instruct.jsonl", "r") as f: + # with open("../results/ntoken_16384/Qwen_slash_Qwen2.5-7B-Instruct-GPTQ-Int4.jsonl", "r") as f: + data = [json.loads(line) for line in f] + + if language: + data = [example for example in data if example["language"] == language] + if not data: + available_languages = set(ex["language"] for ex in data) + raise ValueError(f"No examples found for language '{language}'. Available languages: {available_languages}") + + return data + +# Simple test code +if __name__ == "__main__": + # Load real examples from the dataset + examples = load_examples(language="python") + example = examples[0] # Use the first example + sample_code = example["code_context"] + query = example["description"] + language = example["language"] + + print(f"Using example with language: {language}") + print(f"Query: {query}") + + # Initialize compressor + print("Initializing compressor...") + compressor = CodeCompressor() + + # Test function-based code file compression with query + print("\nTesting function-based code file compression with query...") + + start_time = time.time() + file_result = compressor.compress_code_file( + code=sample_code, + query=query, + rate=0.1, + language=language + ) + end_time = time.time() + + print(f"File compression with query completed in {end_time - start_time:.2f} seconds") + print(f"Original tokens: {file_result['original_tokens']}") + print(f"Compressed tokens: {file_result['compressed_tokens']}") + print(f"Final compressed tokens (with query): {file_result['final_compressed_tokens']}") + print(f"Compression ratio: {file_result['compression_ratio']:.2f}") + print(f"Kept function IDs: {file_result['selected_functions']}") + print(f"Demonstrations sort: {file_result['demonstrations_sort']}") + + chunk_ppl_scores = {idx: score for idx, score in file_result['demonstrations_sort']} + top_5_score = sorted(chunk_ppl_scores.values(), reverse=True)[5] + # Split into chunks and show the chunks + chunks = compressor.split_code_by_functions(sample_code, language=language) + print(f"Split code into {len(chunks)} chunks") + # show the chunk with corresponding ppl score + for i, chunk in enumerate(chunks): + print(f"==========Chunk {i+1} with demonstration sort: {chunk_ppl_scores[i]}==========") + if chunk_ppl_scores[i] >= top_5_score: + print(chunk) + print("\n") + else: + # only show some lines and then use ... to indicate the rest + print(chunk[:100]) + print("...") + print(chunk[-100:]) + print("\n") + + print("\nCompressed Code File with Query:") + print("-------------------") + print(file_result['compressed_code']) \ No newline at end of file diff --git a/repoqa/code_segment_extractor.py b/repoqa/code_segment_extractor.py new file mode 100644 index 0000000..b25c1fa --- /dev/null +++ b/repoqa/code_segment_extractor.py @@ -0,0 +1,349 @@ +import re +from typing import List, Dict, Optional +from loguru import logger +import json +import os + +def extract_code_segments(code: str, language: str = "python") -> List[Dict]: + """ + Break down code into a hierarchical structure based on language-specific patterns. + Supports Python, C++, Java, TypeScript, Rust, and Go. 
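+
+    A sketch of one returned entry for a top-level function (the keys mirror
+    the dictionaries built below; the concrete values are illustrative):
+
+        {"type": "function", "name": "my_function", "start_line": 14,
+         "end_line": 16, "code": "def my_function(): ...",
+         "token_length": 5, "indent_level": 0}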
+ + Args: + code: Original code string + language: Programming language of the code (python, cpp, java, typescript, rust, go) + + Returns: + List of code segments, each containing type, content, position, etc. + """ + language = language.lower() + + # Language-specific patterns + patterns = { + "python": { + "class": r"^class\s+(\w+)", + "function": r"^def\s+(\w+)", + "import": r"^(import|from)\s+", + "comment": r"^#", + "docstring": r'^("""|\'\'\')', + "docstring_end": r'("""|\'\'\')$', + "indent": lambda line: len(line) - len(line.lstrip()), + "block_start": lambda line: line.rstrip().endswith(":"), + "block_end": lambda line, indent: len(line) - len(line.lstrip()) <= indent + }, + "cpp": { + "class": r"^(class|struct)\s+(\w+)", + "function": r"^(void|int|bool|string|char|float|double|auto|template\s*<.*>)\s+(\w+)", + "import": r"^#include\s+", + "comment": r"^//|^/\*", + "docstring": r"^/\*\*", + "docstring_end": r"\*/$", + "indent": lambda line: len(line) - len(line.lstrip()), + "block_start": lambda line: line.rstrip().endswith("{"), + "block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent + }, + "java": { + "class": r"^(public|private|protected)?\s*(class|interface)\s+(\w+)", + "function": r"^(public|private|protected)?\s*(void|int|boolean|String|char|float|double)\s+(\w+)", + "import": r"^import\s+", + "comment": r"^//|^/\*", + "docstring": r"^/\*\*", + "docstring_end": r"\*/$", + "indent": lambda line: len(line) - len(line.lstrip()), + "block_start": lambda line: line.rstrip().endswith("{"), + "block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent + }, + "typescript": { + "class": r"^(export\s+)?(class|interface)\s+(\w+)", + "function": r"^(export\s+)?(function|const|let|var)\s+(\w+)\s*=", + "import": r"^import\s+", + "comment": r"^//|^/\*", + "docstring": r"^/\*\*", + "docstring_end": r"\*/$", + "indent": lambda line: len(line) - len(line.lstrip()), + "block_start": lambda line: line.rstrip().endswith("{"), + "block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent + }, + "rust": { + "class": r"^(pub\s+)?(struct|enum|trait)\s+(\w+)", + "function": r"^(pub\s+)?fn\s+(\w+)", + "import": r"^use\s+", + "comment": r"^//|^/\*", + "docstring": r"^//!|^/\*\*", + "docstring_end": r"\*/$", + "indent": lambda line: len(line) - len(line.lstrip()), + "block_start": lambda line: line.rstrip().endswith("{"), + "block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent + }, + "go": { + "class": r"^type\s+(\w+)\s+(struct|interface)", + "function": r"^func\s+(\w+)", + "import": r"^import\s+", + "comment": r"^//", + "docstring": r"^//", + "docstring_end": None, # Go doesn't have multi-line docstrings + "indent": lambda line: len(line) - len(line.lstrip()), + "block_start": lambda line: line.rstrip().endswith("{"), + "block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent + } + } + + if language not in patterns: + raise ValueError(f"Unsupported language: {language}. 
Supported languages: {', '.join(patterns.keys())}") + + def get_token_length(text: str) -> int: + """Simple approximation of token length by splitting by whitespace""" + if not text: + return 0 + return len(text.split()) + + lines = code.split('\n') + segments = [] + lang_patterns = patterns[language] + + i = 0 + while i < len(lines): + line = lines[i].strip() + indent_level = lang_patterns["indent"](lines[i]) + + # Skip empty lines + if not line: + i += 1 + continue + + # Process class/struct/enum/trait definitions + class_match = re.match(lang_patterns["class"], line) + if class_match and indent_level == 0: + class_start = i + class_name = class_match.group(1) if language == "python" else class_match.group(2) + class_indent = indent_level + + # Save class header (signature and docstring) separately + class_header_start = i + + # Skip to class body + i += 1 + + # Skip whitespace and comments to find the start of class body + while i < len(lines) and (not lines[i].strip() or re.match(lang_patterns["comment"], lines[i].strip())): + i += 1 + + # Check for docstring + if i < len(lines) and lang_patterns["docstring"] and re.match(lang_patterns["docstring"], lines[i].strip()): + docstring_start = i + i += 1 + # Find the end of the docstring + while i < len(lines): + if lang_patterns["docstring_end"] and re.search(lang_patterns["docstring_end"], lines[i]): + i += 1 + break + i += 1 + + class_header_end = i + class_header_code = '\n'.join(lines[class_header_start:class_header_end]) + + # Continue processing the rest of the class body + class_body_start = i + + # Extract methods/functions within the class + while i < len(lines): + if i >= len(lines) or (lines[i].strip() and lang_patterns["indent"](lines[i]) <= class_indent): + break + + line = lines[i].strip() + current_indent = lang_patterns["indent"](lines[i]) + + # Check for method/function definition + method_indent = class_indent + (4 if language == "python" else 2) + if re.match(lang_patterns["function"], line) and current_indent == method_indent: + method_start = i + method_name = re.match(lang_patterns["function"], line).group(1) + + # Find where method ends + i += 1 + while i < len(lines): + if i < len(lines) and lines[i].strip() and lang_patterns["indent"](lines[i]) <= current_indent: + break + i += 1 + + method_end = i + method_code = '\n'.join(lines[method_start:method_end]) + + segments.append({ + "type": "method", + "name": method_name, + "class_name": class_name, + "start_line": method_start, + "end_line": method_end, + "code": method_code, + "token_length": get_token_length(method_code), + "indent_level": current_indent + }) + + continue + else: + # Process non-method code (class attributes, etc.) 
+ i += 1 + + class_end = i + class_code = '\n'.join(lines[class_start:class_end]) + + # Add the class header segment + segments.append({ + "type": "class_header", + "name": class_name, + "start_line": class_header_start, + "end_line": class_header_end, + "code": class_header_code, + "token_length": get_token_length(class_header_code), + "indent_level": class_indent + }) + + continue + + # Process function definitions + func_match = re.match(lang_patterns["function"], line) + if func_match and indent_level == 0: + func_start = i + func_name = func_match.group(1) + func_indent = indent_level + + # Find the end of the function + i += 1 + while i < len(lines): + current_line = lines[i].strip() + current_indent = lang_patterns["indent"](lines[i]) + + # If we hit another function or class at same or higher level, stop + if (re.match(lang_patterns["function"], current_line) or re.match(lang_patterns["class"], current_line)) and current_indent <= func_indent: + break + + i += 1 + + func_end = i + func_code = '\n'.join(lines[func_start:func_end]) + + segments.append({ + "type": "function", + "name": func_name, + "start_line": func_start, + "end_line": func_end, + "code": func_code, + "token_length": get_token_length(func_code), + "indent_level": 0 + }) + + continue + + # Process imports + if re.match(lang_patterns["import"], line) and indent_level == 0: + import_start = i + + # Check if import statement spans multiple lines + while i + 1 < len(lines) and (re.match(lang_patterns["import"], lines[i+1].strip()) or + lines[i+1].lstrip().startswith('\\')): + i += 1 + + import_end = i + 1 + import_code = '\n'.join(lines[import_start:import_end]) + + segments.append({ + "type": "import", + "start_line": import_start, + "end_line": import_end, + "code": import_code, + "token_length": get_token_length(import_code), + "indent_level": 0 + }) + + i += 1 + continue + + # Other top-level statements + elif indent_level == 0: + stmt_start = i + + # Find the end of the statement + i += 1 + while i < len(lines) and (not lines[i].strip() or lang_patterns["indent"](lines[i]) > 0): + i += 1 + + stmt_end = i + stmt_code = '\n'.join(lines[stmt_start:stmt_end]) + + segments.append({ + "type": "statement", + "start_line": stmt_start, + "end_line": stmt_end, + "code": stmt_code, + "token_length": get_token_length(stmt_code), + "indent_level": 0 + }) + + continue + + # If nothing matched, move to next line + i += 1 + + return segments + + +# Example usage +if __name__ == "__main__": + # Example Python code + python_code = """ +import os +import sys + +class MyClass: + \"\"\"This is a docstring.\"\"\" + + def __init__(self, name): + self.name = name + + def my_method(self): + print(f"Hello, {self.name}!") + +def my_function(): + return "Hello, world!" + +# This is a comment +x = 10 +y = 20 +z = x + y + """ + + # Example C++ code + cpp_code = """ +#include +#include + +class MyClass { +public: + MyClass(const std::string& name) : name_(name) {} + + void myMethod() { + std::cout << "Hello, " << name_ << "!" 
<< std::endl; + } + +private: + std::string name_; +}; + +int myFunction() { + return 42; +} + +// This is a comment +int x = 10; +int y = 20; +int z = x + y; + """ + + # Test with Python + python_segments = extract_code_segments(python_code, language="python") + print(f"Python segments: {len(python_segments)}") + + # Test with C++ + cpp_segments = extract_code_segments(cpp_code, language="cpp") + print(f"C++ segments: {len(cpp_segments)}") \ No newline at end of file diff --git a/repoqa/compute_score.py b/repoqa/compute_score.py new file mode 100644 index 0000000..b75bb76 --- /dev/null +++ b/repoqa/compute_score.py @@ -0,0 +1,426 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import json +import os +import re +from collections import defaultdict +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Dict, List, Tuple, Union + +import numpy as np +import tempdir +from rich.console import Console +from rich.table import Table +from transformers import AutoConfig +from tree_sitter_languages import get_language, get_parser + +from repoqa.data import get_repoqa_data +from repoqa.metric import compute_function_similarity +from repoqa.utility import COMMENT_QUERY, FUNCTION_QUERY, progress + +LANGUAGES = list(FUNCTION_QUERY.keys()) +THRESHOLDS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + + +class Result(Enum): + BEST_MATCH = "best_match" + FAIL_MATCH = "fail_match" + + +# unbiased estimator from https://github.com/openai/human-eval +def estimate_pass_at_k( + num_samples: Union[int, List[int], np.ndarray], + num_correct: Union[List[int], np.ndarray], + k: int, +) -> np.ndarray: + """ + Estimates pass@k of each problem and returns them in an array. + """ + + def estimator(n: int, c: int, k: int) -> float: + """ + Calculates 1 - comb(n - c, k) / comb(n, k). 
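+        Equivalently, via the product form used below (which avoids large
+        binomials): pass@k = 1 - prod_{i = n-c+1}^{n} (1 - k / i).
+        Worked example: n=5, c=2, k=1 gives 1 - (1 - 1/4)(1 - 1/5) = 0.4 = c/n.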
+ """ + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + if isinstance(num_samples, int): + num_samples_it = itertools.repeat(num_samples, len(num_correct)) + else: + assert len(num_samples) == len(num_correct) + num_samples_it = iter(num_samples) + + return np.array( + [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)] + ) + + +def remove_comments(source_code: str, lang: str) -> str: + source_bytes = bytes(source_code, "utf8") + parser = get_parser(lang) + tree = parser.parse(source_bytes) + root_node = tree.root_node + + # Remove comments from source code + capture_list = [] + for query_str in COMMENT_QUERY[lang]: + comment_query = get_language(lang).query(query_str) + capture_list += comment_query.captures(root_node) + + capture_list.sort(key=lambda cap: cap[0].start_byte, reverse=True) + + for node, _ in capture_list: + source_bytes = source_bytes[: node.start_byte] + source_bytes[node.end_byte :] + + return source_bytes.decode("utf-8") + + +def sanitize_output(model_output: str, lang: str) -> str: + model_output = model_output.strip() + search_pattern = r"^```(?:\w+)?\s*\n(.*?)(?=^```)```" + code_blocks = re.findall(search_pattern, model_output, re.DOTALL | re.MULTILINE) + + parser = get_parser(lang) + fn_query = get_language(lang).query(FUNCTION_QUERY[lang]) + + # If not code blocks found, simply return model output + if not code_blocks: + return model_output + + processed_blocks = [] + for block in code_blocks: + processed_blocks.append(block) + + # Try to use tree-sitter to parse if possible + try: + block_bytes = bytes(block, "utf8") + tree = parser.parse(block_bytes) + for capture in fn_query.captures(tree.root_node): + node, _ = capture + function_content = block_bytes[node.start_byte : node.end_byte] + return function_content.decode("utf8") + except: + pass + + # no valid functions found by tree-sitter approach return first block + return processed_blocks[0] + + +def print_result_table(model_name, pass_results): + # Printing scores in a table + table = Table(title=f"Scores (%) of {model_name} at different thresholds") + table.add_column("Threshold", justify="center", style="bold magenta") + for threshold in THRESHOLDS: + table.add_column(f"{threshold}", justify="center") + + # Prepare data to determine the maximum values for each threshold + threshold_scores = {threshold: [] for threshold in THRESHOLDS} + for lang_results in pass_results.values(): + for thresh, value in lang_results.items(): + try: + threshold_scores[eval(thresh)].append(value["pass@1"]) + except: + threshold_scores[thresh].append(value["pass@1"]) + + # Calculate the maximum score for each threshold + max_scores = { + threshold: max(scores) for threshold, scores in threshold_scores.items() + } + min_scores = { + threshold: min(scores) for threshold, scores in threshold_scores.items() + } + + # Fill the table rows + for language, lang_results in pass_results.items(): + row = [("⭐" if language == "all" else "") + language] + for threshold, value in lang_results.items(): + score = value["pass@1"] + formatted_score = f"{100 * score:.1f}" + try: + if max_scores[eval(threshold)] - score < 0.01: + formatted_score = f"[bold green]{formatted_score}[/]" + elif score - min_scores[eval(threshold)] < 0.01: + formatted_score = f"[bold red]{formatted_score}[/]" + except: + if max_scores[threshold] - score < 0.01: + formatted_score = f"[bold green]{formatted_score}[/]" + elif score - min_scores[threshold] < 0.01: + formatted_score = f"[bold 
red]{formatted_score}[/]" + row.append(formatted_score) + if language == "all": + row = [f"[bold yellow]{r}[/]" for r in row] + table.add_row(*row) + + Console(width=120).print(table) + +def needle_evaluator( + model_output: str, + ground_truth: str, + repo_info: Dict, + lang: str, + ignore_comments: bool, +) -> Tuple[Result, str, float]: + contents = repo_info["content"] + needles = repo_info["needles"] + + best_target = None + best_similarity = 0 + sanitized_output = sanitize_output(model_output, lang) + if ignore_comments: + sanitized_output = remove_comments(sanitized_output, lang) + for needle in needles: + current_path = needle["path"] + current_name = needle["name"] + current_func = "\n".join( + contents[current_path].split("\n")[ + needle["start_line"] : needle["end_line"] + ] + ) + if ignore_comments: + current_func = remove_comments(current_func, lang) + + current_similarity = compute_function_similarity(sanitized_output, current_func) + if current_similarity > best_similarity: + best_similarity = current_similarity + best_target = current_name + + if best_target == ground_truth: + verdict = Result.BEST_MATCH + else: + verdict = Result.FAIL_MATCH + return verdict, best_target, best_similarity + + +def _get_repo(lang_data: Dict, repo_name: str) -> Dict: + for repo in lang_data: + if repo["repo"] == repo_name: + return repo + + +def compute_language_results(evaluation_result: Dict, all_results: Dict) -> None: + for language, lang_results in evaluation_result.items(): + current_result = {} + total = np.array([1 for _ in lang_results]) + + for threshold in THRESHOLDS: + correct_result = [] + for res in lang_results: + bc = 0 + if res["is_best_similar"] and res["best_similar_score"] >= threshold: + bc = 1 + correct_result.append(bc) + correct_result = np.array(correct_result) + + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean() + for k in [1, 10, 100] + if total.min() >= k + } + current_result[threshold] = pass_at_k + all_results[language] = current_result + + +def fetch_hf_context(model_name: str) -> str: + # Retrieved from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1073 + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + try: + with tempdir.TempDir() as temp_dir: + config = AutoConfig.from_pretrained( + model_name, + cache_dir=temp_dir, + force_download=True, + trust_remote_code=True, + ).to_dict() + longest_context = 0 + for key in possible_keys: + if key in config: + longest_context = max(config[key], longest_context) + if not (longest_context): + return "N/A" + return str(int(longest_context / 1024)) + "k" + except Exception as err: + print(f"fetching failed... 
Reason:\n{err}") + return "N/A" + + +def compute_score( + model_name: str, dataset: Dict, model_output: List[Dict], ignore_comments: bool, result_dir: str +) -> Dict: + evaluation_result = defaultdict(list) + + # if the scores already exist, load them, print the scores and exit + try: + if os.path.exists(result_dir): + with open(result_dir, "r") as f: + output_json = json.load(f) + print_result_table(model_name, output_json[model_name]["scores"]) + return output_json + except Exception as e: + print(f"Error loading scores from {result_dir}: {e}, continuing...") + + with progress(f"Scoring {model_name}") as pbar: + for result in pbar.track(model_output): + lang = result["language"] + repo_name = result["repo"] + model_outputs = result["output"] + ground_truth = result["name"] + repo_info = _get_repo(dataset[lang], repo_name) + + model_output = model_outputs[0] + verdict, best_target, best_similarity = needle_evaluator( + model_output, ground_truth, repo_info, lang, ignore_comments + ) + + is_best_similar = False + if verdict == Result.BEST_MATCH: + is_best_similar = True + + current_task = { + "repo": repo_name, + "name": ground_truth, + "needle_position": result["position_ratio"], + "is_best_similar": is_best_similar, + "best_similar_score": best_similarity, + "best_target": best_target, + "position": { + "token_start": result["needle_token_start"], + "token_end": result["needle_token_end"], + }, + } + evaluation_result[lang].append(current_task) + + # Calculate pass@k + pass_results = {} + + all_langs = [] + for lang in evaluation_result: + all_langs += evaluation_result[lang] + total = np.array([1 for _ in all_langs]) + + pass_results["all"] = {} + for threshold in THRESHOLDS: + correct_result = [] + for res in all_langs: + bc = 0 + if res["is_best_similar"] and res["best_similar_score"] >= threshold: + bc = 1 + correct_result.append(bc) + correct_result = np.array(correct_result) + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean() + for k in [1, 10, 100] + if total.min() >= k + } + pass_results["all"][threshold] = pass_at_k + + compute_language_results(evaluation_result, pass_results) + print_result_table(model_name, pass_results) + + output_json = {} + model_json = {} + model_json["eval_date"] = str(datetime.now()) + + # hardcode paid models + if "/" in model_name: + if model_name.startswith("bigcode/starcoder2"): + train_context = "16k" + else: + train_context = fetch_hf_context(model_name) + elif model_name.startswith("gpt-4-turbo") or model_name.startswith("gpt-4o-"): + train_context = "128k" + elif model_name.startswith("gpt-3.5-"): + train_context = "16k" + elif model_name.startswith("gemini-1.5-pro") or model_name.startswith( + "gemini-1.5-flash" + ): + train_context = "1000k" + elif model_name.startswith("gemini-1.0-pro"): + train_context = "32k" + elif model_name.startswith("claude-3-"): + train_context = "200k" + else: + train_context = "N/A" + model_json["train_size"] = train_context + model_json["scores"] = pass_results + model_json["results"] = evaluation_result + + output_json[model_name] = model_json + + return output_json + + +def get_model_name(output_path: str) -> str: + file_name = Path(output_path).stem + segments = file_name.split("_") + output_name = "" + for segment in segments: + if segment == "slash": + output_name += "/" + else: + output_name += segment + return output_name + + +def save_json(output_json, result_path) -> None: + if os.path.isfile(result_path): + decision = "" + while decision.lower() not in ["y", "n"]: + 
print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...") + # decision = input() + decision = "y" + + if not os.path.isfile(result_path): + with open(result_path, "w") as f: + json.dump(output_json, f) + + +def compute_main( + model_output_path: str, ignore_comments: bool = False, dataset_path: str = None +): + if dataset_path is None: + dataset = get_repoqa_data() + else: + with open(dataset_path, "r") as dataset_f: + dataset = json.load(dataset_f) + + model_outputs = [] + with open(model_output_path, "r") as output_f: + for line in output_f: + model_outputs.append(json.loads(line)) + + file_base, _ = os.path.splitext(model_output_path) + result_path = file_base + "-SCORES.json" + model_name = get_model_name(model_output_path) + output_json = compute_score(model_name, dataset, model_outputs, ignore_comments) + save_json(output_json, result_path) + + +def main(): + from fire import Fire + + Fire(compute_main) + + +if __name__ == "__main__": + main() diff --git a/repoqa/data.py b/repoqa/data.py new file mode 100644 index 0000000..d0ef747 --- /dev/null +++ b/repoqa/data.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import gzip +import json +import os + +import tempdir +import wget +from appdirs import user_cache_dir + +CACHE_DIR = user_cache_dir("repoqa") + +REPOQA_DATA_OVERRIDE_PATH = os.getenv("REPOQA_DATA_OVERRIDE_PATH", None) +REPOQA_DATA_VERSION = os.getenv("REPOQA_DATA_VERSION", "2024-06-23") + + +def _get_repoqa_data_ready_path() -> str: + if REPOQA_DATA_OVERRIDE_PATH: + assert os.path.exists( + REPOQA_DATA_OVERRIDE_PATH + ), f"File not found: {REPOQA_DATA_OVERRIDE_PATH}" + return REPOQA_DATA_OVERRIDE_PATH + + gzip_url = f"https://github.com/evalplus/repoqa_release/releases/download/{REPOQA_DATA_VERSION}/repoqa-{REPOQA_DATA_VERSION}.json.gz" + cache_path = os.path.join(CACHE_DIR, f"repoqa-{REPOQA_DATA_VERSION}.json") + # Check if human eval file exists in CACHE_DIR + if not os.path.exists(cache_path): + # Install HumanEval dataset and parse as json + print(f"Downloading dataset from {gzip_url}") + with tempdir.TempDir() as tmpdir: + gzip_path = os.path.join(tmpdir, f"data.json.gz") + wget.download(gzip_url, gzip_path) + + with gzip.open(gzip_path, "rb") as f: + repoqa_data = f.read().decode("utf-8") + + # create CACHE_DIR if not exists + os.makedirs(CACHE_DIR, exist_ok=True) + # Write the original human eval file to CACHE_DIR + with open(cache_path, "w") as f: + f.write(repoqa_data) + + return cache_path + + +def get_repoqa_data(): + with open(_get_repoqa_data_ready_path(), "r") as f: + return json.load(f) diff --git a/repoqa/main.py b/repoqa/main.py new file mode 100644 index 0000000..9b8fe2a --- /dev/null +++ b/repoqa/main.py @@ -0,0 +1,784 @@ +from repoqa.code_compressor import CodeCompressor +from repoqa.mgcode_compressor import CodeCompressor as MGCodeCompressor +from repoqa.utility import COMMENT_QUERY, progress +from repoqa.data import CACHE_DIR, get_repoqa_data +from repoqa.compute_score import compute_score, save_json +from llmlingua import PromptCompressor +from loguru import logger +from tree_sitter_languages import get_language, get_parser +import torch +from transformers import AutoTokenizer, AutoModel +from tqdm import tqdm +import json +import os +from enum import Enum +from typing import List, Tuple, Dict +import warnings +from dataclasses import dataclass +import sys + +class ChunkStrategy(Enum): + FUNCTION_BASED = "function_based" + SLIDING_WINDOW = "sliding_window" + + +# 
Language-specific chunk markers +CHUNK_MARKERS = { + "python": ["class", "def"], + "cpp": ["class", "struct", "void", "int", "bool", "double", "float", "char", "auto"], + "java": ["class", "interface", "void", "int", "boolean", "double", "float", "char"], + "typescript": ["class", "interface", "function", "const", "let", "var"], + "rust": ["fn", "struct", "impl", "trait", "enum"], + "go": ["func", "type", "struct", "interface"] +} + +# all languages +# ALL_LANGUAGES = ["python", "cpp", "java", "typescript", "rust", "go"] + +# Model context template +TEMPLATE = "instruction\ncode_context\ndescription\ninstruction" + +INSTRUCTION = ( + "Based on the function description and code context," + " please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:" +) + + +@dataclass +class CodeChunk: + """Represents a chunk of code with its embedding""" + content: str + start_line: int + end_line: int + embedding: torch.Tensor = None + + +class CodeChunker: + def __init__(self, language: str, strategy: ChunkStrategy = ChunkStrategy.FUNCTION_BASED, + window_size: int = 20, overlap_size: int = 10): + self.language = language + self.parser = get_parser(language) + self.strategy = strategy + self.window_size = window_size + self.overlap_size = overlap_size + + def _is_function_or_class_start(self, line: str) -> bool: + """Check if line starts a new function or class definition""" + line = line.strip() + return any(line.startswith(marker) for marker in CHUNK_MARKERS[self.language]) + + def _chunk_by_function(self, lines: List[str]) -> List[CodeChunk]: + """Split code into chunks based on function/class definitions""" + chunks = [] + current_chunk_lines = [] + current_start = 0 + + for i, line in enumerate(lines): + if self._is_function_or_class_start(line) and current_chunk_lines: + # Store previous chunk + chunk_content = '\n'.join(current_chunk_lines) + chunks.append(CodeChunk(chunk_content, current_start, i-1)) + current_chunk_lines = [] + current_start = i + current_chunk_lines.append(line) + + # Add final chunk + if current_chunk_lines: + chunk_content = '\n'.join(current_chunk_lines) + chunks.append(CodeChunk(chunk_content, current_start, len(lines)-1)) + + return chunks + + def _chunk_by_sliding_window(self, lines: List[str]) -> List[CodeChunk]: + """Split code into chunks using sliding window approach""" + chunks = [] + + # Handle case when code is shorter than window size + if len(lines) <= self.window_size: + return [CodeChunk('\n'.join(lines), 0, len(lines)-1)] + + # Create overlapping chunks + start = 0 + while start < len(lines): + end = min(start + self.window_size, len(lines)) + chunk_content = '\n'.join(lines[start:end]) + chunks.append(CodeChunk(chunk_content, start, end-1)) + + # Move start position by (window_size - overlap_size) + start += self.window_size - self.overlap_size + + # If remaining lines are less than window_size, adjust start to include them in last chunk + if len(lines) - start < self.window_size: + if len(lines) - start > self.overlap_size: # Only if there's enough new content + chunk_content = '\n'.join(lines[start:]) + chunks.append(CodeChunk(chunk_content, start, len(lines)-1)) + break + + return chunks + + def chunk_code(self, code: str) -> List[CodeChunk]: + """Split code into chunks based on selected strategy""" + lines = code.split('\n') + + if self.strategy == ChunkStrategy.FUNCTION_BASED: + return self._chunk_by_function(lines) + elif self.strategy == ChunkStrategy.SLIDING_WINDOW: + return 
self._chunk_by_sliding_window(lines) + else: + raise ValueError(f"Unknown chunking strategy: {self.strategy}") + + +class RAGCompressor: + def __init__(self, model_name: str = "microsoft/unixcoder-base"): + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModel.from_pretrained(model_name) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + + def compute_embeddings(self, chunks: List[CodeChunk]) -> List[CodeChunk]: + """Compute embeddings for code chunks""" + for chunk in chunks: + inputs = self.tokenizer(chunk.content, return_tensors="pt", + truncation=True, max_length=512).to(self.device) + with torch.no_grad(): + outputs = self.model(**inputs) + # Use mean pooling + chunk.embedding = outputs.last_hidden_state.mean(dim=1).squeeze() + return chunks + + def get_relevant_chunks(self, + query_embedding: torch.Tensor, + chunks: List[CodeChunk], + top_k: int = 5) -> List[CodeChunk]: + """Get most relevant chunks based on cosine similarity""" + similarities = [] + for chunk in chunks: + if chunk.embedding is None: + continue + sim = torch.cosine_similarity(query_embedding, chunk.embedding, dim=0) + similarities.append((sim.item(), chunk)) + + # Sort by similarity and take top k + similarities.sort(key=lambda x: x[0], reverse=True) + return [chunk for _, chunk in similarities[:top_k]] + + +def compress_context(code_context: str, + target_function: str, + language: str, + rag_compressor: RAGCompressor, + chunker: CodeChunker) -> str: + """Compress code context using RAG approach""" + # Split into chunks + chunks = chunker.chunk_code(code_context) + + # Get original token count + original_tokens = len(rag_compressor.tokenizer.encode(code_context)) + + # Log original context size + logger.info(f"Original context: {code_context}") + logger.info(f"Original token count: {original_tokens}") + logger.info(f"Number of chunks: {len(chunks)}") + + # Compute embeddings for all chunks + chunks = rag_compressor.compute_embeddings(chunks) + + # Get embedding for target function + target_embedding = rag_compressor.model( + **rag_compressor.tokenizer(target_function, return_tensors="pt", + truncation=True, max_length=512).to(rag_compressor.device) + ).last_hidden_state.mean(dim=1).squeeze() + + # Get most relevant chunks + relevant_chunks = rag_compressor.get_relevant_chunks(target_embedding, chunks) + + # Combine relevant chunks + compressed_context = "\n".join(chunk.content for chunk in relevant_chunks) + + # Get compressed token count + compressed_tokens = len(rag_compressor.tokenizer.encode(compressed_context)) + + # Log compression results + logger.info(f"Compressed token count: {compressed_tokens}") + logger.info(f"Token compression ratio: {compressed_tokens/original_tokens:.2%}") + logger.info("Selected chunks:") + for i, chunk in enumerate(relevant_chunks): + logger.info(f"Chunk {i+1} (lines {chunk.start_line}-{chunk.end_line}):\n{chunk.content}\n") + + return compressed_context + + +def compress_context_llm_lingua(compressor: PromptCompressor, + code_context: str, + target_function: str, + language: str, + target_token: int = 1000) -> str: + """Compress code context using LLMLingua approach""" + # Get original token count using LLMLingua's tokenizer + original_tokens = len(compressor.tokenizer.encode(code_context)) + + # replace the "<|endoftext|>" in the code if there is any + if "<|endoftext|>" in code_context: + logger.warning(f"Removing <|endoftext|> in code context: {code_context}") + code_context = 
code_context.replace("<|endoftext|>", "") + + # Compress the prompt + logger.info(f"Compressing prompt with instruction: \n{INSTRUCTION}") + logger.info(f"Code context: \n{code_context}") + logger.info(f"Description: \n{target_function}") + compressed = compressor.compress_prompt( + code_context, + instruction=INSTRUCTION, + question=target_function + INSTRUCTION, + target_token=target_token + ) + + compressed_prompt = compressed['compressed_prompt'] + logger.info(f"Compressed prompt: \n{compressed_prompt}") + + # Get compressed token count + compressed_tokens = len(compressor.tokenizer.encode(compressed_prompt)) + + # Log compression results + logger.info(f"Original token count: {original_tokens}") + logger.info(f"LLMLingua compressed token count: {compressed_tokens}") + logger.info(f"Token compression ratio: {compressed_tokens/original_tokens:.2%}") + + return compressed_prompt + + +def compress_context_longllmlingua_chunks(compressor: PromptCompressor, + code_context: str, + target_function: str, + language: str, + target_token: int = 1000, + chunk_size: int = 80, + overlap: int = 40) -> str: + """Compress code context using LongLLMLingua chunks approach""" + # Get original token count using LLMLingua's tokenizer + original_tokens = len(compressor.tokenizer.encode(code_context)) + + # replace the "<|endoftext|>" in the code if there is any + if "<|endoftext|>" in code_context: + logger.warning(f"Removing <|endoftext|> in code context: {code_context}") + code_context = code_context.replace("<|endoftext|>", "") + + # Split code into chunks for longllmlingua_chunks method + lines = code_context.split('\n') + chunks = [] + for i in range(0, len(lines), chunk_size - overlap): + chunk = lines[i:i + chunk_size] + if chunk: + chunks.append('\n'.join(chunk)) + + # Compress the prompt using chunks + compressed = compressor.compress_prompt( + chunks, + instruction=INSTRUCTION, + question=target_function + INSTRUCTION, + target_token=target_token, + rank_method="longllmlingua" + ) + + compressed_prompt = compressed['compressed_prompt'] + logger.info(f"Compressed prompt: \n{compressed_prompt}") + + # Get compressed token count + compressed_tokens = len(compressor.tokenizer.encode(compressed_prompt)) + + # Log compression results + logger.info(f"Original token count: {original_tokens}") + logger.info(f"LongLLMLingua chunks compressed token count: {compressed_tokens}") + logger.info(f"Token compression ratio: {compressed_tokens/original_tokens:.2%}") + + return compressed_prompt + + +def compress_context_code_compressor(compressor: CodeCompressor, + code_context: str, + target_function: str, + language: str, + target_ratio: float = 0.5, + ppl_strategy: str = "default", + condition_in_question: str = "default", + rank_only: bool = False, + use_iterative_compression: bool = True, + use_line_level_filter: bool = True) -> str: + """Compress code context using CodeCompressor approach + + Args: + compressor: The CodeCompressor instance + code_context: The code to compress + target_function: The function description/query + language: The programming language + target_ratio: Compression ratio (0.0-1.0) + ppl_strategy: Strategy for perplexity calculation + condition_in_question: Conditioning mode for perplexity + rank_only: If True, only rank and select functions without fine-grained compression + use_iterative_compression: Whether to use token-level iterative compression + use_line_level_filter: Whether to use line-level filtering + """ + # replace the "<|endoftext|>" in the code if there is any + if 
"<|endoftext|>" in code_context: + logger.warning(f"Removing <|endoftext|> in code context: {code_context}") + code_context = code_context.replace("<|endoftext|>", "") + + # Compress the code using CodeCompressor + if rank_only: + # When rank_only is True, we'll use the compress_code_file method + logger.info("===== Rank-only mode =====") + compressed = compressor.compress_code_file( + code=code_context, + query=target_function, + instruction=INSTRUCTION, + rate=target_ratio, + language=language, + rank_only=True + ) + else: + # For non-function chunk processing, use compress_code if not splitting by functions + if not use_line_level_filter and not use_iterative_compression: + logger.info("===== Simple truncation mode =====") + # Simple truncation mode + compressed = compressor.compress_code( + code=code_context, + query=target_function, + instruction=INSTRUCTION, + rate=target_ratio, + use_line_level_filter=False, + use_iterative_compression=False + ) + elif use_line_level_filter and not use_iterative_compression: + logger.info("===== Line-level filtering only =====") + # Line-level filtering only + compressed = compressor.compress_code( + code=code_context, + query=target_function, + instruction=INSTRUCTION, + rate=target_ratio, + use_line_level_filter=True, + use_iterative_compression=False + ) + elif not use_line_level_filter and use_iterative_compression: + logger.info("===== Token-level iterative compression only =====") + # Token-level iterative compression only + compressed = compressor.compress_code( + code=code_context, + query=target_function, + instruction=INSTRUCTION, + rate=target_ratio, + use_line_level_filter=False, + use_iterative_compression=True + ) + else: + # Full function-based splitting and compression + logger.info("===== Full function-based splitting and compression =====") + compressed = compressor.compress_code_file( + code=code_context, + query=target_function, + instruction=INSTRUCTION, + rate=target_ratio, + language=language, + rank_only=False, + use_iterative_compression=use_iterative_compression + ) + + # Get compressed prompt from results + if "compressed_prompt" in compressed: + compressed_prompt = compressed["compressed_prompt"] + else: + compressed_prompt = compressed["output"] + + # Log compression results + logger.info(f"Original token count: {compressed['original_tokens']}") + logger.info(f"CodeCompressor compressed token count: {compressed['compressed_tokens']}") + logger.info(f"Token compression ratio: {compressed['compressed_tokens']/compressed['original_tokens']:.2%}") + + return compressed_prompt + + +def compress_context_mgcode_compressor(compressor: MGCodeCompressor, + code_context: str, + target_function: str, + language: str, + target_ratio: float = 0.5, + compression_mode: str = "balanced") -> str: + """Compress code context using MG CodeCompressor approach""" + # replace the "<|endoftext|>" in the code if there is any + if "<|endoftext|>" in code_context: + logger.warning(f"Removing <|endoftext|> in code context: {code_context}") + code_context = code_context.replace("<|endoftext|>", "") + + # Compress the code using MG CodeCompressor + compressed = compressor.compress_code( + code=code_context, + query=target_function, + instruction=INSTRUCTION, + target_ratio=target_ratio, + compression_mode=compression_mode, + enable_fine_compression=True, + max_iterations=3, + preserve_top_functions=True, + language=language + ) + + compressed_prompt = compressed["compressed_prompt"] + # logger.info(f"Compressed prompt: \n{compressed_prompt}") + + # 
Log compression results + logger.info(f"Original token count: {compressed['original_tokens']}") + logger.info(f"MG CodeCompressor compressed token count: {compressed['compressed_tokens']}") + logger.info(f"Token compression ratio: {compressed['compressed_tokens']/compressed['original_tokens']:.2%}") + + return compressed_prompt + + +def evaluate_model_rag( + model: str, + code_context_size: int = 16 * 1024, + max_new_tokens: int = 1024, + result_dir: str = "results/rag_compressed_v1", + languages: List[str] = None, + tensor_parallel_size: int = 1, + trust_remote_code: bool = True, + chunk_strategy: str = "function_based", + window_size: int = 20, + overlap_size: int = 10, + dataset_path: str = None, + compression_method: str = "rag", + llm_lingua_target_token: int = 1000, + compression_ratio: float = 0.5, + backend: str = "vllm", + ppl_strategy: str = "default", + condition_in_question: str = "default", + compression_mode: str = "function_focus", + rank_only: bool = False, + use_iterative_compression: bool = False, + use_line_level_filter: bool = False, + compression_model: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4" +): + # show the parameters of rank_only, use_iterative_compression, use_line_level_filter + logger.info(f"Rank-only: {rank_only}") + logger.info(f"Use iterative compression: {use_iterative_compression}") + logger.info(f"Use line-level filter: {use_line_level_filter}") + + """Main evaluation function with compression method selection + + Args: + model: Model name or path + code_context_size: Model context size in tokens + max_new_tokens: Maximum tokens to generate + result_dir: Directory to save results + languages: List of languages to evaluate + tensor_parallel_size: Tensor parallel size for vLLM + trust_remote_code: Trust remote code for tokenizer and model + chunk_strategy: Chunking strategy ("function_based" or "sliding_window") + window_size: Window size for sliding window strategy + overlap_size: Overlap size for sliding window strategy + dataset_path: Path to dataset file + compression_method: Compression method + ("rag", "llm_lingua", "longllmlingua_chunks", "code_compressor", "mgcode_compressor", "original") + llm_lingua_target_token: Target token count for LLMLingua + compression_ratio: Compression ratio for CodeCompressor + backend: Backend for inference ("vllm") + ppl_strategy: Perplexity strategy for CodeCompressor + condition_in_question: Condition in question for CodeCompressor + compression_mode: Compression mode for MGCodeCompressor + rank_only: If True, only rank and select functions without fine-grained compression + use_iterative_compression: Whether to use token-level iterative compression for code_compressor + use_line_level_filter: Whether to apply line-level filtering for code_compressor + compression_model: Model name for LLMLingua and CodeCompressor + """ + # Create result directory + os.makedirs(result_dir, exist_ok=True) + + # Add strategy to the output directory name + strategy_str = f"_{compression_method}" + if compression_method == "llm_lingua": + strategy_str += f"_t{llm_lingua_target_token}" + elif compression_method == "longllmlingua_chunks": + strategy_str += f"_t{llm_lingua_target_token}_w{window_size}_o{overlap_size}" + elif compression_method == "code_compressor": + # Create a compression mode string based on settings + cc_mode = [] + if rank_only: + cc_mode.append("rank_only") + else: + if use_iterative_compression: + cc_mode.append("iter") + if use_line_level_filter: + cc_mode.append("line") + + mode_str = "_".join(cc_mode) if 
cc_mode else "simple" + strategy_str += f"_t{compression_ratio}_mode_{mode_str}" + elif compression_method == "mgcode_compressor": + strategy_str += f"_t{compression_ratio}_m{compression_mode}" + + if chunk_strategy == "sliding_window": + strategy_str += f"_w{window_size}_o{overlap_size}" + + context_size_dir = os.path.join(result_dir, f"ntoken_{code_context_size}{strategy_str}") + os.makedirs(context_size_dir, exist_ok=True) + + model_output_path = os.path.join( + context_size_dir, + f"{model.replace('/', '_slash_')}.jsonl", + ) + + # Intermediate file to store compressed contexts + compressed_contexts_path = os.path.join( + context_size_dir, + f"compressed_contexts_{model.replace('/', '_slash_')}.jsonl", + ) + + # Load cache from Qwen results + cache_file = os.path.join("results/ntoken_16384", "Qwen_slash_Qwen2.5-7B-Instruct.jsonl") + # cache_file = os.path.join("results/ntoken_16384", "Qwen_slash_Qwen2.5-7B-Instruct-GPTQ-Int4.jsonl") + if not os.path.exists(cache_file): + raise FileNotFoundError(f"Cache file not found: {cache_file}") + + with open(cache_file) as f: + cache = [json.loads(line) for line in f] + + logger.info(f"Loaded {len(cache)} examples from {cache_file}") + logger.info(f"Using chunking strategy: {chunk_strategy}") + if chunk_strategy == "sliding_window": + logger.info(f"Window size: {window_size}, Overlap size: {overlap_size}") + if compression_method == "llm_lingua": + logger.info(f"Using LLMLingua compression with target tokens: {llm_lingua_target_token}") + elif compression_method == "longllmlingua_chunks": + logger.info(f"Using LongLLMLingua chunks compression with:") + logger.info(f" - Target tokens: {llm_lingua_target_token}") + logger.info(f" - Chunk size: {window_size}") + logger.info(f" - Overlap: {overlap_size}") + elif compression_method == "code_compressor": + logger.info(f"Using CodeCompressor with ratio: {compression_ratio}") + logger.info(f"CodeCompressor settings:") + logger.info(f" - rank_only: {rank_only}") + logger.info(f" - use_iterative_compression: {use_iterative_compression}") + logger.info(f" - use_line_level_filter: {use_line_level_filter}") + + # Filter by languages if specified + if languages: + cache = [c for c in cache if c["language"] in languages] + + if dataset_path is not None: + with open(dataset_path) as f: + dataset = json.load(f) + else: + dataset = get_repoqa_data() + + # If results already exist, load and evaluate + if os.path.exists(model_output_path) and os.path.getsize(model_output_path) > 0: + logger.info(f"Loading {model_output_path} and evaluating") + model_outputs = [json.loads(line) for line in open(model_output_path)] + file_base, _ = os.path.splitext(model_output_path) + result_path = file_base + "-SCORES.json" + output_json = compute_score( + model, + dataset, + model_outputs, + True, # Ignore comments since we're using compressed context + result_dir=result_dir, + ) + save_json(output_json, result_path) + return + + # PHASE 1: Compress all contexts + compressed_tasks = [] + + # Initialize appropriate compressor based on compression method + if compression_method in ["rag", "original"]: + rag_compressor = RAGCompressor() + else: + rag_compressor = None + + # Initialize compressors if needed + llm_lingua_compressor = None + code_compressor = None + mgcode_compressor = None + if compression_method in ["llm_lingua", "longllmlingua_chunks"]: + llm_lingua_compressor = PromptCompressor(compression_model) + elif compression_method == "code_compressor": + code_compressor = CodeCompressor(compression_model) + elif 
compression_method == "mgcode_compressor": + mgcode_compressor = MGCodeCompressor(compression_model) + + # Convert string strategy to enum + try: + chunk_strategy_enum = ChunkStrategy(chunk_strategy) + except ValueError: + raise ValueError(f"Invalid chunk strategy: {chunk_strategy}. " + f"Must be one of {[s.value for s in ChunkStrategy]}") + + # Check if compressed contexts already exist + if os.path.exists(compressed_contexts_path) and os.path.getsize(compressed_contexts_path) > 0: + logger.info(f"Loading pre-compressed contexts from {compressed_contexts_path}") + with open(compressed_contexts_path) as f: + compressed_tasks = [json.loads(line) for line in f] + else: + logger.info(f"Starting compression phase for {len(cache)} examples") + # Process and compress each task + for i, task in enumerate(tqdm(cache, desc="Compressing contexts")): + # Make a copy of the original task + compressed_task = dict(task) + + try: + # Compression logic based on selected method + if compression_method == "rag": + chunker = CodeChunker( + task["language"], + strategy=chunk_strategy_enum, + window_size=window_size, + overlap_size=overlap_size + ) + compressed_context = compress_context( + task["code_context"], + task["description"], + task["language"], + rag_compressor, + chunker=chunker + ) + elif compression_method == "llm_lingua": + compressed_context = compress_context_llm_lingua( + compressor=llm_lingua_compressor, + code_context=task["code_context"], + target_function=task["description"], + language=task["language"], + target_token=llm_lingua_target_token + ) + elif compression_method == "longllmlingua_chunks": + compressed_context = compress_context_longllmlingua_chunks( + compressor=llm_lingua_compressor, + code_context=task["code_context"], + target_function=task["description"], + language=task["language"], + target_token=llm_lingua_target_token, + chunk_size=window_size, + overlap=overlap_size + ) + elif compression_method == "code_compressor": + compressed_context = compress_context_code_compressor( + compressor=code_compressor, + code_context=task["code_context"], + target_function=task["description"], + language=task["language"], + target_ratio=compression_ratio, + ppl_strategy=ppl_strategy, + condition_in_question=condition_in_question, + rank_only=rank_only, + use_iterative_compression=use_iterative_compression, + use_line_level_filter=use_line_level_filter + ) + elif compression_method == "mgcode_compressor": + compressed_context = compress_context_mgcode_compressor( + compressor=mgcode_compressor, + code_context=task["code_context"], + target_function=task["description"], + language=task["language"], + target_ratio=compression_ratio, + compression_mode=compression_mode + ) + elif compression_method == "original": + compressed_context = task["code_context"] + else: + raise ValueError(f"Invalid compression method: {compression_method}") + + # Update task with compressed context + compressed_task["code_context"] = compressed_context + + # Generate prompt + if compression_method == "code_compressor": + compressed_task["prompt"] = compressed_context + else: + prompt = "" + for key in task["template"].split("\n"): + prompt += compressed_task[key] + compressed_task["prompt"] = prompt + + except Exception as e: + logger.error(f"Error compressing item {i} of {len(cache)}: {e}") + # Use original context if compression fails + compressed_task["code_context"] = task["code_context"] + prompt = "" + for key in task["template"].split("\n"): + prompt += compressed_task[key] + compressed_task["prompt"] = 
prompt + + compressed_tasks.append(compressed_task) + + # Save intermediate results periodically + if (i + 1) % 10 == 0 or i == len(cache) - 1: + with open(compressed_contexts_path, "w") as f_out: + for t in compressed_tasks: + f_out.write(json.dumps(t) + "\n") + f_out.flush() + logger.info(f"Saved {i+1}/{len(cache)} compressed contexts") + + # Clean up compressor objects to free memory + del rag_compressor + del llm_lingua_compressor + del code_compressor + del mgcode_compressor + + # Force garbage collection to free GPU memory + import gc + gc.collect() + + # Clear CUDA cache if torch is available + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info("Cleared GPU memory cache") + + # PHASE 2: Generate responses with vLLM + logger.info("Starting response generation phase") + + # Initialize vLLM provider + from repoqa.provider.vllm import VllmProvider + engine = VllmProvider( + model, + tensor_parallel_size=tensor_parallel_size, + max_model_len=int(code_context_size * 1.5), + trust_remote_code=trust_remote_code, + gpu_memory_utilization=0.8 # Can use higher utilization now + ) + + # Generate responses for all compressed tasks + model_outputs = [] + for i, task in enumerate(tqdm(compressed_tasks, desc="Generating responses")): + # Generate reply + replies = engine.generate_reply( + task["prompt"], n=1, max_tokens=max_new_tokens + ) + + # Save result + result = {**task, "output": replies} + model_outputs.append(result) + + # Save all model outputs + with open(model_output_path, "w") as f_out: + for r in model_outputs: + f_out.write(json.dumps(r) + "\n") + f_out.flush() + logger.info(f"Saved {len(model_outputs)} responses") + + # Compute and save scores + file_base, _ = os.path.splitext(model_output_path) + result_path = file_base + "-SCORES.json" + output_json = compute_score( + model, + dataset, + model_outputs, + True, # Ignore comments since we're using compressed context + result_dir=result_dir, + ) + save_json(output_json, result_path) + + +def main(): + from fire import Fire + Fire(evaluate_model_rag) + + +if __name__ == "__main__": + main() diff --git a/repoqa/metric.py b/repoqa/metric.py new file mode 100644 index 0000000..6ec875e --- /dev/null +++ b/repoqa/metric.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import re + +from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + + +def compute_function_similarity( + candidate_function: str, reference_function: str +) -> float: + candidate_tokens = [item for item in re.split("\s+", candidate_function.strip())] + + reference_tokens = [item for item in re.split("\s+", reference_function.strip())] + + chencherry = SmoothingFunction() + + return sentence_bleu( + [reference_tokens], candidate_tokens, smoothing_function=chencherry.method4 + ) diff --git a/repoqa/provider/__init__.py b/repoqa/provider/__init__.py new file mode 100644 index 0000000..6f2edf5 --- /dev/null +++ b/repoqa/provider/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +from repoqa.provider.base import BaseProvider diff --git a/repoqa/provider/anthropic.py b/repoqa/provider/anthropic.py new file mode 100644 index 0000000..b09c08f --- /dev/null +++ b/repoqa/provider/anthropic.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List + +from anthropic import Client + +from repoqa.provider.base import 
BaseProvider +from repoqa.provider.request.anthropic import make_auto_request + + +class AnthropicProvider(BaseProvider): + def __init__(self, model): + self.model = model + self.client = Client(api_key=os.getenv("ANTHROPIC_KEY")) + + def generate_reply( + self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None + ) -> List[str]: + assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" + replies = [] + for _ in range(n): + reply = make_auto_request( + self.client, + message=question, + model=self.model, + temperature=temperature, + max_tokens=max_tokens, + system_msg=system_msg, + ) + replies.append(reply.content[0].text) + + return replies diff --git a/repoqa/provider/base.py b/repoqa/provider/base.py new file mode 100644 index 0000000..374f2ea --- /dev/null +++ b/repoqa/provider/base.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import List + + +class BaseProvider(ABC): + @abstractmethod + def generate_reply( + self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None + ) -> List[str]: + ... diff --git a/repoqa/provider/google.py b/repoqa/provider/google.py new file mode 100644 index 0000000..63ec7b4 --- /dev/null +++ b/repoqa/provider/google.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List + +import google.generativeai as genai + +from repoqa.provider.base import BaseProvider +from repoqa.provider.request.google import make_auto_request + + +class GoogleProvider(BaseProvider): + def __init__(self, model): + genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + self.model = model + self.client = genai.GenerativeModel(model) + + def generate_reply( + self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None + ) -> List[str]: + assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" + replies = make_auto_request( + self.client, + question, + self.model, + n=n, + max_tokens=max_tokens, + temperature=temperature, + system_msg=system_msg, + ) + + if len(replies.candidates) != n: + print(f"[WARNING] # replies = {len(replies.candidates)} != {n = }") + + ret_texts = [] + for candidate in replies.candidates: + parts = candidate.content.parts + if parts: + ret_texts.append(parts[0].text) + else: + print("Empty response!") + ret_texts.append("") + print(f"{candidate.safety_ratings = }") + + return ret_texts + [""] * (n - len(ret_texts)) diff --git a/repoqa/provider/hf.py b/repoqa/provider/hf.py new file mode 100644 index 0000000..22b66ff --- /dev/null +++ b/repoqa/provider/hf.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from repoqa.provider.base import BaseProvider +from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq + + +class HfProvider(BaseProvider): + def __init__(self, model, trust_remote_code=False, attn_implementation=None): + self.tokenizer = AutoTokenizer.from_pretrained( + model, trust_remote_code=trust_remote_code + ) + self.hf_model = AutoModelForCausalLM.from_pretrained( + model, + trust_remote_code=trust_remote_code, + attn_implementation=attn_implementation, + torch_dtype="auto", + ).cuda() + self.stop_seq = [] + if self.tokenizer.chat_template: + 
self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer)) + + @torch.inference_mode() + def generate_reply( + self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None + ) -> List[str]: + assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" + + prompt_tokens = self.tokenizer.apply_chat_template( + construct_message_list(question, system_msg), + return_tensors="pt", + add_generation_prompt=True, + ).cuda() + input_length = prompt_tokens.size(-1) + + gen_args = {"do_sample": False} + if temperature > 0: + gen_args["do_sample"] = True + gen_args["temperature"] = temperature + + output_text = self.hf_model.generate( + input_ids=prompt_tokens, + max_new_tokens=max_tokens, + num_return_sequences=n, + pad_token_id=self.tokenizer.eos_token_id, + use_cache=True, + stop_strings=self.stop_seq, + tokenizer=self.tokenizer, + **gen_args, + ) + + gen_strs = [ + self.tokenizer.decode( + x[input_length:], + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + for x in output_text + ] + return gen_strs diff --git a/repoqa/provider/openai.py b/repoqa/provider/openai.py new file mode 100644 index 0000000..8d59e9b --- /dev/null +++ b/repoqa/provider/openai.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List + +from openai import Client +from transformers import AutoTokenizer + +from repoqa.provider.base import BaseProvider +from repoqa.provider.request import hacky_assistant_stop_seq +from repoqa.provider.request.openai import make_auto_request + + +class OpenAIProvider(BaseProvider): + def __init__(self, model, base_url: str = None): + self.model = model + self.client = Client( + api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url + ) + self.stop_seq = [] + try: + tokenizer = AutoTokenizer.from_pretrained(model) + if tokenizer.chat_template: + self.stop_seq.append(hacky_assistant_stop_seq(tokenizer)) + print("Using stop sequence: ", self.stop_seq) + except: + print("Failed to automatically fetch stop tokens from HuggingFace.") + + def generate_reply( + self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None + ) -> List[str]: + assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" + replies = make_auto_request( + self.client, + message=question, + model=self.model, + temperature=temperature, + n=n, + max_tokens=max_tokens, + system_msg=system_msg, + stop=self.stop_seq, + ) + + return [reply.message.content for reply in replies.choices] diff --git a/repoqa/provider/request/__init__.py b/repoqa/provider/request/__init__.py new file mode 100644 index 0000000..76bdb35 --- /dev/null +++ b/repoqa/provider/request/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + + +def construct_message_list(message, system_message=None): + msglist = [{"role": "user", "content": message}] + if system_message: + msglist.insert(0, {"role": "system", "content": system_message}) + return msglist + + +def hacky_assistant_stop_seq(tokenizer) -> str: + _magic_string_ = "&==NowOrNever==&Accelerate!!!==&" + return tokenizer.apply_chat_template( + [ + {"role": "user", "content": ""}, + {"role": "assistant", "content": _magic_string_}, + ], + tokenize=False, + ).split(_magic_string_)[-1] diff --git a/repoqa/provider/request/anthropic.py b/repoqa/provider/request/anthropic.py new file mode 100644 index 0000000..26e2d9a --- /dev/null +++ b/repoqa/provider/request/anthropic.py @@ -0,0 
+1,71 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import signal +import time + +import anthropic +from anthropic.types import Message + +from repoqa.provider.request import construct_message_list + + +def make_request( + client: anthropic.Client, + message: str, + model: str, + max_tokens: int = 512, + temperature: float = 1, + system_msg="You are a helpful assistant good at coding.", + **kwargs, +) -> Message: + return client.messages.create( + model=model, + messages=construct_message_list(message, system_message=system_msg), + max_tokens=max_tokens, + temperature=temperature, + **kwargs, + ) + + +def handler(signum, frame): + # swallow signum and frame + raise Exception("end of time") + + +def make_auto_request(client: anthropic.Client, *args, **kwargs) -> Message: + ret = None + while ret is None: + try: + signal.signal(signal.SIGALRM, handler) + signal.alarm(100) + ret = make_request(client, *args, **kwargs) + signal.alarm(0) + except anthropic.RateLimitError: + print("Rate limit exceeded. Waiting...") + signal.alarm(0) + time.sleep(10) + except anthropic.APIConnectionError: + print("API connection error. Waiting...") + signal.alarm(0) + time.sleep(5) + except anthropic.InternalServerError: + print("Internal server error. Waiting...") + signal.alarm(0) + time.sleep(5) + except anthropic.APIError as e: + print("Unknown API error") + print(e) + if ( + e.body["error"]["message"] + == "Output blocked by content filtering policy" + ): + raise Exception("Content filtering policy blocked output") + signal.alarm(0) + except Exception as e: + print("Unknown error. Waiting...") + print(e) + signal.alarm(0) + time.sleep(1) + return ret diff --git a/repoqa/provider/request/google.py b/repoqa/provider/request/google.py new file mode 100644 index 0000000..39103e9 --- /dev/null +++ b/repoqa/provider/request/google.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import signal +import time + +import google.generativeai as genai +from google.api_core.exceptions import GoogleAPICallError, ResourceExhausted + +from repoqa.provider.request import construct_message_list + + +def make_request( + client: genai.GenerativeModel, + message: str, + model: str, + max_tokens: int = 512, + temperature: float = 1, + n: int = 1, + system_msg="You are a helpful assistant good at coding.", + **kwargs, +) -> genai.types.GenerateContentResponse: + messages = [] + if system_msg: + messages.append({"role": "system", "parts": [system_msg]}) + messages.append({"role": "user", "parts": [message]}) + return client.generate_content( + messages, + generation_config=genai.types.GenerationConfig( + candidate_count=n, max_output_tokens=max_tokens, temperature=temperature + ), + **kwargs, + ) + + +def handler(signum, frame): + # swallow signum and frame + raise Exception("end of time") + + +def make_auto_request(*args, **kwargs) -> genai.types.GenerateContentResponse: + ret = None + while ret is None: + try: + signal.signal(signal.SIGALRM, handler) + signal.alarm(100) + ret = make_request(*args, **kwargs) + signal.alarm(0) + except ResourceExhausted as e: + print("Rate limit exceeded. Waiting...", e.message) + signal.alarm(0) + time.sleep(10) + except GoogleAPICallError as e: + print(e.message) + signal.alarm(0) + time.sleep(1) + except Exception as e: + print("Unknown error. 
Waiting...") + print(e) + signal.alarm(0) + time.sleep(1) + return ret diff --git a/repoqa/provider/request/openai.py b/repoqa/provider/request/openai.py new file mode 100644 index 0000000..6f6e213 --- /dev/null +++ b/repoqa/provider/request/openai.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +import signal +import time + +import openai +from openai.types.chat import ChatCompletion + +from repoqa.provider.request import construct_message_list + + +def make_request( + client: openai.Client, + message: str, + model: str, + max_tokens: int = 512, + temperature: float = 1, + n: int = 1, + system_msg="You are a helpful assistant good at coding.", + **kwargs, +) -> ChatCompletion: + return client.chat.completions.create( + model=model, + messages=construct_message_list(message, system_message=system_msg), + max_tokens=max_tokens, + temperature=temperature, + n=n, + **kwargs, + ) + + +def handler(signum, frame): + # swallow signum and frame + raise Exception("end of time") + + +def make_auto_request(*args, **kwargs) -> ChatCompletion: + ret = None + while ret is None: + try: + signal.signal(signal.SIGALRM, handler) + signal.alarm(100) + ret = make_request(*args, **kwargs) + signal.alarm(0) + except openai.RateLimitError: + print("Rate limit exceeded. Waiting...") + signal.alarm(0) + time.sleep(10) + except openai.APIConnectionError: + print("API connection error. Waiting...") + signal.alarm(0) + time.sleep(5) + except openai.APIError as e: + print(e) + signal.alarm(0) + except Exception as e: + print("Unknown error. Waiting...") + print(e) + signal.alarm(0) + time.sleep(1) + return ret diff --git a/repoqa/provider/vllm.py b/repoqa/provider/vllm.py new file mode 100644 index 0000000..6843d06 --- /dev/null +++ b/repoqa/provider/vllm.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + +from repoqa.provider.base import BaseProvider +from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq + + +class VllmProvider(BaseProvider): + def __init__( + self, model, tensor_parallel_size, max_model_len=None, trust_remote_code=False, gpu_memory_utilization=0.9 + ): + self.tokenizer = AutoTokenizer.from_pretrained( + model, trust_remote_code=trust_remote_code + ) + self.llm = LLM( + model=model, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + trust_remote_code=trust_remote_code, + gpu_memory_utilization=gpu_memory_utilization, + ) + self.stop_seq = [] + if self.tokenizer.chat_template: + self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer)) + + def generate_reply( + self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None + ) -> List[str]: + assert temperature != 0 or n == 1, "n must be 1 when temperature is 0" + + prompt = self.tokenizer.apply_chat_template( + construct_message_list(question, system_msg), + tokenize=False, + add_generation_prompt=True, + ) + vllm_outputs = self.llm.generate( + [prompt], + SamplingParams( + temperature=temperature, + max_tokens=max_tokens, + stop=self.stop_seq, + ), + use_tqdm=False, + ) + + gen_strs = [x.outputs[0].text for x in vllm_outputs] + return gen_strs diff --git a/repoqa/run.sh b/repoqa/run.sh new file mode 100644 index 0000000..d48f92e --- /dev/null +++ b/repoqa/run.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct" 
+MODEL_PATH_NAME="qwencoder-7b-instruct" +BACKEND="vllm" +COMPRESSION_METHOD="code_compressor" +BASE_RESULT_DIR="code_compressor_exp_results" +BASE_LOG_DIR="logs-combinations" + +mkdir -p ${BASE_LOG_DIR} +mkdir -p ${BASE_RESULT_DIR} + +echo "Starting experiments for ${MODEL_NAME}" + +# Configuration arrays +COMPRESSION_RATIOS=(0.1 0.2 0.3 0.4) +GPU_IDS=(0 1 2 3) + +echo "--- Running CodeCompressor with various compression ratios ---" +for i in "${!COMPRESSION_RATIOS[@]}"; do + ratio="${COMPRESSION_RATIOS[$i]}" + gpu_id="${GPU_IDS[$i]}" + + echo "Running CodeCompressor: compression_ratio=${ratio} on GPU ${gpu_id}" + CUDA_VISIBLE_DEVICES=${gpu_id} nohup python main.py \ + --model ${MODEL_NAME} \ + --backend ${BACKEND} \ + --compression-method ${COMPRESSION_METHOD} \ + --compression-ratio ${ratio} \ + --result-dir ${BASE_RESULT_DIR} \ + --rank-only > "${BASE_LOG_DIR}/7B_code_compressor_${ratio}_rank_only_true.log" 2>&1 & + echo "Started CodeCompressor: compression_ratio=${ratio} on GPU ${gpu_id}" +done + +echo "--- All CodeCompressor experiments started ---" \ No newline at end of file diff --git a/repoqa/utility.py b/repoqa/utility.py new file mode 100644 index 0000000..862306d --- /dev/null +++ b/repoqa/utility.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team +# +# SPDX-License-Identifier: Apache-2.0 + +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + TextColumn, + TimeElapsedColumn, +) + +FUNCTION_QUERY = { + "python": "(function_definition name: (_)) @fdef", + "java": "(method_declaration name: (_)) @fdef", + "typescript": "(function_declaration name: (_)) @fdef", + "rust": "(function_item name: (_)) @fdef", + "cpp": "(function_definition declarator: (function_declarator declarator: (identifier))) @fdef", + "go": "(function_declaration name: (_)) @fdef", +} + +COMMENT_QUERY = { + "python": [ + "(block (expression_statement (string) @docstring))", + "(comment) @comment", + ], + "java": ["(line_comment) @comment", "(block_comment) @comment"], + "cpp": ["(comment) @comment"], + "rust": ["(line_comment) @comment", "(block_comment) @comment"], + "typescript": ["(comment) @comment"], + "go": ["(comment) @comment"], +} + +FUNCTION_NAME_QUERY = { + "python": """ + ((function_definition + name: (identifier) @function_name)) + """, + "java": """ + (method_declaration + name: (identifier) @method_name) + """, + "typescript": """ + (function_declaration + name: (identifier) @function_name) + """, + "rust": """ + (function_item + name: (identifier) @function_name) + """, + "cpp": """ + (function_definition + name: (identifier) @function_name) + """, +} + + +def topological_sort(graph): + # Stack to store the topological order + stack = [] + # Set to keep track of visited nodes + visited = set() + + # Recursive function to process nodes + def dfs(node): + # Mark the current node as visited + visited.add(node) + # Recurse for all the vertices adjacent to this vertex + for neighbour in graph.get(node, []): + if neighbour not in visited: + dfs(neighbour) + # Push current vertex to stack which stores the result + stack.append(node) + + # Call the recursive helper function to store the topological sort starting from all vertices one by one + for node in graph: + if node not in visited: + dfs(node) + + return stack + + +def progress(note: str = "processing"): + return Progress( + TextColumn(f"{note} •" + "[progress.percentage]{task.percentage:>3.0f}%"), + BarColumn(), + MofNCompleteColumn(), + TextColumn("•"), + TimeElapsedColumn(), + )
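The `run.sh` above launches four rank-only CodeCompressor runs in parallel across GPUs. For a single-GPU run with fine-grained compression enabled instead, a minimal sketch of the same entry point (`evaluate_model_rag`, exposed through `fire.Fire` in `main.py`) could look like the following; the hyphenated flag names assume Python Fire's usual argument mapping (as already used for `--compression-ratio` and `--rank-only` in `run.sh`), and the model, ratio, and result directory are illustrative placeholders rather than values taken from the experiments.

```bash
# Hypothetical single-GPU invocation: full function-based splitting and
# compression (rank_only off, line-level filter and iterative compression on).
CUDA_VISIBLE_DEVICES=0 python main.py \
    --model "Qwen/Qwen2.5-Coder-7B-Instruct" \
    --backend vllm \
    --compression-method code_compressor \
    --compression-ratio 0.2 \
    --use-line-level-filter \
    --use-iterative-compression \
    --result-dir code_compressor_exp_results
```

With these settings, `evaluate_model_rag` encodes the configuration into the output directory name (for example `ntoken_16384_code_compressor_t0.2_mode_iter_line`), writes the compressed contexts and model outputs as JSONL files there, and then computes scores into a `-SCORES.json` file. Note that it expects the cached baseline outputs under `results/ntoken_16384/` to exist and raises `FileNotFoundError` otherwise.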