mirror of https://github.com/YerbaPage/LongCodeZip.git
synced 2025-10-22 23:19:46 +03:00

Commit: init

.gitignore (vendored, new file, 188 lines)
@@ -0,0 +1,188 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

crosscodeeval/
temp/
*.jsonl
*.out
*.txt
datasets/
repositories/
*.ipynb
cache/
output*/
*.pdf
*.json
*.jsonl
old_scripts/
*cache*/
*.zip
*.tar
*.tar.gz
*.tar.xz
*.tar.bz2
*.tar.lzma
*.tar.lz4
*.tar.zstd
*.tar.lz
*.html
README.md (124 lines changed)
@@ -1 +1,125 @@
# LongCodeZip

This repository is the official implementation of LongCodeZip, a novel two-stage long code compression method.

## Method Overview

![overview](assets/overview.png)

LongCodeZip introduces a two-stage code compression framework specifically designed for code LLMs:

1. **Coarse-grained Compression**: Function-based chunking and ranking using conditional perplexity with respect to the query to select the most relevant functions.

2. **Fine-grained Compression**: Entropy-based block detection combined with 0/1 knapsack optimization to maximize relevance within adaptive token budgets.

The method is plug-and-play and can be integrated with existing code LLMs to achieve significant compression ratios while maintaining or improving task performance. A simplified sketch of the two stages follows.
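The listing below is a minimal, self-contained sketch of those two stages, not the repository's implementation (see `code_compressor.py` for that). The `Block` structure, the helper names, and the integer token costs are assumptions made for illustration; the `relevance` field stands in for a score derived from conditional perplexity with respect to the query.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    text: str
    tokens: int       # token cost of keeping this block
    relevance: float  # stand-in for a query-conditioned perplexity score


def rank_chunks_by_relevance(chunks: List[Block]) -> List[Block]:
    """Coarse stage (sketch): keep the function chunks that are most
    relevant to the query, i.e. those with the lowest conditional perplexity."""
    return sorted(chunks, key=lambda b: b.relevance, reverse=True)


def select_blocks_knapsack(blocks: List[Block], budget: int) -> List[Block]:
    """Fine stage (sketch): 0/1 knapsack that maximizes total relevance
    under a token budget."""
    n = len(blocks)
    best = [0.0] * (budget + 1)                      # best relevance within t tokens
    keep = [[False] * n for _ in range(budget + 1)]  # which blocks achieve it
    for i, b in enumerate(blocks):
        for t in range(budget, b.tokens - 1, -1):
            candidate = best[t - b.tokens] + b.relevance
            if candidate > best[t]:
                best[t] = candidate
                keep[t] = keep[t - b.tokens][:]
                keep[t][i] = True
    return [b for i, b in enumerate(blocks) if keep[budget][i]]


# Toy usage: rank three function chunks, then pack blocks into a 20-token budget.
chunks = [Block("def a(): ...", 8, 0.9),
          Block("def b(): ...", 12, 0.2),
          Block("def c(): ...", 10, 0.7)]
ranked = rank_chunks_by_relevance(chunks)
print([b.text for b in select_blocks_knapsack(ranked, budget=20)])
```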
## Repository Structure

This repository contains implementations and experiments for three code-related tasks:

```
LongCodeZip/
├── repoqa/                   # Code Retrieval Task
│   ├── main.py               # Main evaluation script
│   ├── run.sh                # Experiment runner
│   ├── code_compressor.py    # Core compression implementation
│   ├── compute_score.py      # Evaluation metrics
│   └── ...
├── long-code-completion/     # Code Completion Task
│   ├── main.py               # Main evaluation script
│   ├── run.sh                # Experiment runner
│   ├── code_compressor.py    # Core compression implementation
│   ├── utils.py              # Utility functions
│   └── ...
├── module_summarization/     # Code Summarization Task
│   ├── main.py               # Main evaluation script
│   ├── run.sh                # Experiment runner
│   ├── code_compressor.py    # Core compression implementation
│   ├── utils.py              # Utility functions
│   └── ...
└── README.md
```
## Installation

```bash
pip install -r requirements.txt
```

## Usage

### Quick Start

Each task directory contains a `run.sh` script for easy experimentation. Simply navigate to the desired task directory and run:

```bash
cd <task_directory>
bash run.sh
```

### Code Retrieval (RepoQA)

Navigate to the `repoqa` directory and run experiments with different compression ratios:

```bash
cd repoqa
bash run.sh
```

The script evaluates LongCodeZip on the RepoQA dataset with compression ratios of 0.1, 0.2, 0.3, and 0.4, running experiments in parallel on multiple GPUs. A minimal Python sweep over these ratios is sketched after the parameter list.

**Key Parameters:**
- `--compression-ratio`: Controls the compression level (0.1-0.4)
- `--model`: Specifies the base LLM model
- `--backend`: Backend for model inference (vllm)
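A scripted sweep over these ratios might look like the snippet below. This is only a sketch: the flag names are taken from the parameter list above, the model value is the default mentioned under Configuration, and the exact argument set accepted by `repoqa/main.py` may differ from what is shown here.

```python
# Hypothetical driver for the RepoQA sweep; see run.sh for the actual commands.
import subprocess

for ratio in (0.1, 0.2, 0.3, 0.4):
    subprocess.run(
        ["python", "main.py",
         "--model", "Qwen/Qwen2.5-Coder-7B-Instruct",
         "--backend", "vllm",
         "--compression-ratio", str(ratio)],
        check=True,
        cwd="repoqa",
    )
```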
### Code Completion

Navigate to the `long-code-completion` directory:

```bash
cd long-code-completion
bash run.sh
```

This evaluates LongCodeZip on long-context code completion tasks with various configurations, including different target token limits, fine-grained compression ratios, and importance beta values. An equivalent Python invocation is sketched below the parameter list.

**Key Parameters:**
- `--code_compressor_target_token`: Target token budget (2048, 4096)
- `--code_compressor_fine_ratio`: Fine-grained compression ratio (0.5, 0.8)
- `--importance_beta`: Importance weighting parameter (0.0, 0.5)
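Because `long-code-completion/main.py` exposes `evaluate_completion` through `fire.Fire`, the same experiment can also be launched from Python. The values below simply mirror one configuration from `run.sh`; treat this as an illustrative sketch rather than a prescribed interface.

```python
# Programmatic equivalent of one run.sh configuration, executed from the
# long-code-completion directory.
from main import evaluate_completion

evaluate_completion(
    model_name="Qwen/Qwen2.5-Coder-7B-Instruct",
    compression_model_name="Qwen/Qwen2.5-Coder-7B-Instruct",
    method="code_compressor",
    result_dir="results/qwencoder-7b-instruct",
    num_examples=500,
    filter_background_tokens_min=5000,
    code_compressor_target_token=2048,
    code_compressor_fine_ratio=0.5,  # 1.0 would switch to rank-only (coarse) compression
    importance_beta=0.5,
)
```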
### Code Summarization

Navigate to the `module_summarization` directory:

```bash
cd module_summarization
bash run.sh
```

This runs code summarization experiments with fine-grained compression and various beta values for importance weighting.

**Key Parameters:**
- `--code_compressor_target_token`: Target token budget
- `--code_compressor_fine_ratio`: Fine-grained compression ratio
- `--importance_beta`: Importance weighting parameter
## Configuration

Each task can be customized by modifying the respective `run.sh` file or by calling the main scripts directly with custom parameters. Key configuration options include:

- **Model Selection**: Compatible with various code LLMs (default: Qwen2.5-Coder-7B-Instruct)
- **Compression Ratios**: Adjustable compression levels for different use cases
- **Token Budgets**: Configurable target token limits (see the sketch after this list)
- **GPU Configuration**: Multi-GPU support for parallel experiments
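On the token-budget option specifically: in the completion task, `compress_code_compressor` in `long-code-completion/main.py` converts the absolute budget into a compression rate before calling the compressor. A minimal sketch of that conversion is shown below; instantiating the tokenizer directly is an illustration only, since the real code reuses the compressor's own tokenizer.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")

def budget_to_rate(context: str, target_token: int) -> float:
    """Clamp target_token / original_tokens into [0, 1], mirroring the logic
    used before CodeCompressor.compress_code_file is called."""
    original_tokens = len(tokenizer.encode(context))
    if original_tokens == 0:
        return 0.0
    return min(1.0, max(0.0, target_token / original_tokens))

print(budget_to_rate("def f(x):\n    return x + 1\n" * 200, target_token=2048))
```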
## Performance

LongCodeZip achieves up to a **5.6× compression ratio** without sacrificing task performance across code completion, summarization, and retrieval tasks. Even when a 0.5B Qwen model is used as the compressor, performance remains competitive.

## Contact

Please feel free to contact us if you have any questions.
assets/overview.png (new binary file, 366 KiB; binary file not shown)
long-code-completion/code_compressor.py (new file, 1889 lines)
File diff suppressed because it is too large.
long-code-completion/compare_empty_line_handling.py (new file, 190 lines)
@@ -0,0 +1,190 @@
|
||||
import torch
|
||||
import math
|
||||
from typing import List
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
def compare_empty_line_handling():
|
||||
"""Compare original vs corrected empty line handling in PPL chunking"""
|
||||
|
||||
code_to_be_analyzed = """def evaluate_blind(self, code, **kwargs):
|
||||
|
||||
suffix = kwargs.get('suffix', self.get('suffix', ''))
|
||||
blind = kwargs.get('blind', False)
|
||||
|
||||
action = self.actions.get('evaluate_blind', {})
|
||||
payload_action = action.get('evaluate_blind')
|
||||
call_name = action.get('call', 'inject')
|
||||
|
||||
# Skip if something is missing or call function is not set
|
||||
if not action or not payload_action or not call_name or not hasattr(self, call_name):
|
||||
return
|
||||
|
||||
expected_delay = self._get_expected_delay()
|
||||
|
||||
if '%(code_b64)s' in payload_action:
|
||||
log.debug('[b64 encoding] %s' % code)
|
||||
execution_code = payload_action % ({
|
||||
'code_b64' : base64.urlsafe_b64encode(code),
|
||||
'delay' : expected_delay
|
||||
})
|
||||
else:
|
||||
execution_code = payload_action % ({
|
||||
'code' : code,
|
||||
'delay' : expected_delay
|
||||
})
|
||||
|
||||
return getattr(self, call_name)(
|
||||
code = execution_code,
|
||||
prefix = prefix,
|
||||
suffix = suffix,
|
||||
blind=True
|
||||
)"""
|
||||
|
||||
print("="*80)
|
||||
print("COMPARISON: Empty Line Handling in PPL Chunking")
|
||||
print("="*80)
|
||||
|
||||
lines = code_to_be_analyzed.split('\n')
|
||||
|
||||
# Simulate original approach (includes empty lines in smoothing)
|
||||
def original_smoothing(values, window_size=3):
|
||||
"""Original smoothing that includes empty lines"""
|
||||
smoothed = []
|
||||
for i in range(len(values)):
|
||||
start_idx = max(0, i - window_size // 2)
|
||||
end_idx = min(len(values), i + window_size // 2 + 1)
|
||||
|
||||
window_values = []
|
||||
for j in range(start_idx, end_idx):
|
||||
if not math.isinf(values[j]) and not math.isnan(values[j]):
|
||||
window_values.append(values[j])
|
||||
|
||||
if window_values:
|
||||
smoothed.append(sum(window_values) / len(window_values))
|
||||
else:
|
||||
smoothed.append(values[i])
|
||||
|
||||
return smoothed
|
||||
|
||||
# Simulate corrected approach (excludes empty lines from smoothing)
|
||||
def corrected_smoothing(values, lines, window_size=3):
|
||||
"""Corrected smoothing that excludes empty lines"""
|
||||
smoothed = []
|
||||
|
||||
# Identify non-empty line indices
|
||||
non_empty_indices = [i for i, line in enumerate(lines) if line.strip() != '']
|
||||
|
||||
for i in range(len(values)):
|
||||
if lines[i].strip() == '': # Empty line
|
||||
smoothed.append(values[i]) # Keep original value
|
||||
else:
|
||||
# Find position in non-empty indices
|
||||
try:
|
||||
pos_in_non_empty = non_empty_indices.index(i)
|
||||
except ValueError:
|
||||
smoothed.append(values[i])
|
||||
continue
|
||||
|
||||
# Get window around this position in non-empty lines
|
||||
start_pos = max(0, pos_in_non_empty - window_size // 2)
|
||||
end_pos = min(len(non_empty_indices), pos_in_non_empty + window_size // 2 + 1)
|
||||
|
||||
# Get values from non-empty lines in the window
|
||||
window_values = []
|
||||
for j in range(start_pos, end_pos):
|
||||
idx = non_empty_indices[j]
|
||||
val = values[idx]
|
||||
if not math.isinf(val) and not math.isnan(val) and val > 0:
|
||||
window_values.append(val)
|
||||
|
||||
if window_values:
|
||||
smoothed.append(sum(window_values) / len(window_values))
|
||||
else:
|
||||
smoothed.append(values[i])
|
||||
|
||||
return smoothed
|
||||
|
||||
# Create sample PPL values (simulated)
|
||||
sample_ppls = []
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip() == '':
|
||||
sample_ppls.append(1.0) # Empty line PPL
|
||||
else:
|
||||
# Simulate varying PPL values
|
||||
if 'def ' in line:
|
||||
sample_ppls.append(101.65)
|
||||
elif 'return' in line and len(line.strip()) < 20:
|
||||
sample_ppls.append(1.50)
|
||||
elif line.strip().startswith('#'):
|
||||
sample_ppls.append(17.72)
|
||||
elif 'kwargs.get' in line:
|
||||
sample_ppls.append(8.39)
|
||||
elif 'action' in line:
|
||||
sample_ppls.append(8.17)
|
||||
elif 'if ' in line:
|
||||
sample_ppls.append(12.41)
|
||||
elif 'else:' in line:
|
||||
sample_ppls.append(1.36)
|
||||
elif line.strip().startswith("'"):
|
||||
sample_ppls.append(2.52)
|
||||
else:
|
||||
sample_ppls.append(5.0 + (i % 10)) # Varying values
|
||||
|
||||
# Apply both smoothing approaches
|
||||
original_smoothed = original_smoothing(sample_ppls, window_size=3)
|
||||
corrected_smoothed = corrected_smoothing(sample_ppls, lines, window_size=3)
|
||||
|
||||
print(f"Total lines: {len(lines)}")
|
||||
print(f"Empty lines: {len([l for l in lines if l.strip() == ''])}")
|
||||
print(f"Non-empty lines: {len([l for l in lines if l.strip() != ''])}")
|
||||
print()
|
||||
|
||||
print("Line-by-line comparison:")
|
||||
print(f"{'Line':>4} {'Empty':>5} {'Original':>10} {'Corrected':>10} {'Difference':>10} {'Content'}")
|
||||
print("-" * 80)
|
||||
|
||||
for i, (line, orig_ppl, orig_smooth, corr_smooth) in enumerate(zip(lines, sample_ppls, original_smoothed, corrected_smoothed)):
|
||||
is_empty = line.strip() == ''
|
||||
diff = abs(orig_smooth - corr_smooth)
|
||||
content = repr(line[:40] + "..." if len(line) > 40 else line)
|
||||
|
||||
print(f"{i:4d} {'Yes' if is_empty else 'No':>5} {orig_smooth:10.4f} {corr_smooth:10.4f} {diff:10.4f} {content}")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("KEY DIFFERENCES:")
|
||||
print("="*80)
|
||||
|
||||
# Find lines where smoothing differs significantly
|
||||
significant_diffs = []
|
||||
for i, (orig_smooth, corr_smooth) in enumerate(zip(original_smoothed, corrected_smoothed)):
|
||||
diff = abs(orig_smooth - corr_smooth)
|
||||
if diff > 0.1 and lines[i].strip() != '': # Non-empty lines with significant difference
|
||||
significant_diffs.append((i, diff, orig_smooth, corr_smooth))
|
||||
|
||||
print(f"\nLines with significant smoothing differences (> 0.1):")
|
||||
for line_idx, diff, orig, corr in significant_diffs:
|
||||
print(f"Line {line_idx}: Original={orig:.4f}, Corrected={corr:.4f}, Diff={diff:.4f}")
|
||||
print(f" Content: {repr(lines[line_idx])}")
|
||||
|
||||
# Show impact on empty lines
|
||||
empty_line_indices = [i for i, line in enumerate(lines) if line.strip() == '']
|
||||
print(f"\nEmpty line smoothing values:")
|
||||
for idx in empty_line_indices:
|
||||
print(f"Line {idx}: Original={original_smoothed[idx]:.4f}, Corrected={corrected_smoothed[idx]:.4f}")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("SUMMARY:")
|
||||
print("="*80)
|
||||
print("Original approach:")
|
||||
print("- Includes empty lines (PPL=1.0) in smoothing windows")
|
||||
print("- Can artificially lower smoothed PPL values near empty lines")
|
||||
print("- May create false local minimums")
|
||||
|
||||
print("\nCorrected approach:")
|
||||
print("- Excludes empty lines from smoothing calculations")
|
||||
print("- Only considers non-empty lines for smoothing windows")
|
||||
print("- Preserves original line indices for visualization")
|
||||
print("- More accurate representation of code complexity patterns")
|
||||
|
||||
if __name__ == "__main__":
|
||||
compare_empty_line_handling()
|
||||
long-code-completion/main.py (new file, 750 lines)
@@ -0,0 +1,750 @@
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from llmlingua import PromptCompressor
|
||||
import fire
|
||||
from utils import load_data, compute_EM, compute_ES
|
||||
from vllm import LLM, SamplingParams
|
||||
from loguru import logger
|
||||
from code_compressor import CodeCompressor
|
||||
import gc
|
||||
from typing import List
|
||||
import re
|
||||
|
||||
|
||||
# Helper function for splitting code by functions (standalone version)
|
||||
def split_code_by_functions_standalone(code: str, language: str = "python") -> List[str]:
|
||||
"""
|
||||
Split code into chunks based on function and class definitions for various languages.
|
||||
Standalone version that doesn't require CodeCompressor instance.
|
||||
|
||||
Args:
|
||||
code: The code to split
|
||||
language: Programming language of the code (python, cpp, java, typescript, rust, go)
|
||||
|
||||
Returns:
|
||||
List of code chunks, each containing a function, class, or class method
|
||||
"""
|
||||
# Define regex patterns for different languages
|
||||
patterns = {
|
||||
# Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end
|
||||
"python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*',
|
||||
# C++: Improved to better handle multi-line declarations
|
||||
"cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
|
||||
# Java: Improved for multi-line method declarations
|
||||
"java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
|
||||
# TypeScript: Enhanced to handle multi-line methods and arrow functions
|
||||
"typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?',
|
||||
# Rust: Improved for multi-line function declarations
|
||||
"rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
|
||||
# Go: Improved for multi-line function declarations
|
||||
"go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
|
||||
}
|
||||
|
||||
# Use default Python pattern if language not supported
|
||||
if language.lower() not in patterns:
|
||||
language = "python"
|
||||
|
||||
function_pattern = re.compile(patterns[language.lower()], re.MULTILINE)
|
||||
matches = list(function_pattern.finditer(code))
|
||||
|
||||
if not matches:
|
||||
return [code] if code.strip() else [] # No matches, return whole code if not empty
|
||||
|
||||
result_chunks = []
|
||||
|
||||
# Add code before first match if exists
|
||||
if matches[0].start() > 0:
|
||||
pre_code = code[:matches[0].start()].strip()
|
||||
if pre_code:
|
||||
result_chunks.append(pre_code)
|
||||
|
||||
# Process each match
|
||||
for i, match in enumerate(matches):
|
||||
start = match.start()
|
||||
|
||||
# End is either start of next match or end of code
|
||||
if i < len(matches) - 1:
|
||||
end = matches[i + 1].start()
|
||||
else:
|
||||
end = len(code)
|
||||
|
||||
chunk = code[start:end].strip()
|
||||
if chunk:
|
||||
result_chunks.append(chunk)
|
||||
|
||||
return result_chunks
|
||||
|
||||
|
||||
# Helper function for function-level RAG retrieval
|
||||
def function_rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, language: str, top_k: int) -> str:
|
||||
"""Uses function-level chunking and retrieves top_k similar functions."""
|
||||
if not background_code.strip():
|
||||
return "" # Return empty if no background context
|
||||
|
||||
# Split code into function-based chunks
|
||||
chunks = split_code_by_functions_standalone(background_code, language)
|
||||
if not chunks:
|
||||
return "" # Return empty if chunking results in nothing
|
||||
|
||||
query_embedding = compute_embedding(query_code, model, tokenizer, device)
|
||||
|
||||
chunk_embeddings = []
|
||||
valid_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device))
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
return ""
|
||||
|
||||
# Stack embeddings for efficient similarity calculation
|
||||
chunk_embeddings_tensor = torch.stack(chunk_embeddings)
|
||||
|
||||
# Compute cosine similarity
|
||||
similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1)
|
||||
|
||||
# Get top_k indices
|
||||
top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices
|
||||
|
||||
# Retrieve relevant chunks
|
||||
retrieved_chunks = [valid_chunks[i] for i in top_k_indices.tolist()]
|
||||
|
||||
# Combine relevant chunks (maintain order by similarity score)
|
||||
combined_code = "\n\n".join(retrieved_chunks)
|
||||
|
||||
return combined_code
|
||||
|
||||
|
||||
# Helper function for sliding window chunking
|
||||
def chunk_sliding_window(code: str, window_size: int, overlap: int) -> list[str]:
|
||||
"""Splits code into overlapping chunks using a sliding window."""
|
||||
lines = code.splitlines()
|
||||
if not lines:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
stride = window_size - overlap
|
||||
if stride <= 0:
|
||||
raise ValueError("Overlap size must be smaller than window size.")
|
||||
|
||||
while True:
|
||||
end = min(start + window_size, len(lines))
|
||||
chunk_lines = lines[start:end]
|
||||
if not chunk_lines: # Should not happen if lines is not empty, but safety check
|
||||
break
|
||||
chunks.append("\n".join(chunk_lines))
|
||||
if end == len(lines):
|
||||
break # Exit loop if we reached the end
|
||||
next_start = start + stride
|
||||
# If the next window would go past the end, break
|
||||
if next_start >= len(lines):
|
||||
# Add the final overlapping chunk if needed
|
||||
final_start = max(0, len(lines) - window_size)
|
||||
if final_start > start: # Ensure it's a new chunk not already added
|
||||
final_chunk_lines = lines[final_start:]
|
||||
chunks.append("\n".join(final_chunk_lines))
|
||||
break
|
||||
start = next_start
|
||||
|
||||
# Handle case where code is shorter than window size
|
||||
if not chunks and lines:
|
||||
return ["\n".join(lines)]
|
||||
|
||||
# Remove duplicates while preserving order (important for RAG)
|
||||
seen = set()
|
||||
unique_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk not in seen:
|
||||
seen.add(chunk)
|
||||
unique_chunks.append(chunk)
|
||||
|
||||
return unique_chunks
|
||||
|
||||
|
||||
# Helper function to compute embeddings (using mean pooling)
|
||||
def compute_embedding(text: str, model, tokenizer, device) -> torch.Tensor:
|
||||
"""Computes sentence embedding for a text using the provided model."""
|
||||
if not text.strip(): # Handle empty strings
|
||||
return torch.zeros(model.config.hidden_size).to(device)
|
||||
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
# Mean pool the last hidden state
|
||||
embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
|
||||
return embedding
|
||||
|
||||
# Helper function for RAG retrieval
|
||||
|
||||
|
||||
def rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, window_size: int, overlap: int, top_k: int) -> str:
|
||||
"""Chunks background, embeds chunks and query, retrieves top_k similar chunks."""
|
||||
if not background_code.strip():
|
||||
return "" # Return empty if no background context
|
||||
|
||||
chunks = chunk_sliding_window(background_code, window_size, overlap)
|
||||
if not chunks:
|
||||
return "" # Return empty if chunking results in nothing
|
||||
|
||||
query_embedding = compute_embedding(query_code, model, tokenizer, device)
|
||||
|
||||
chunk_embeddings = []
|
||||
valid_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device))
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
return ""
|
||||
|
||||
# Stack embeddings for efficient similarity calculation
|
||||
chunk_embeddings_tensor = torch.stack(chunk_embeddings)
|
||||
|
||||
# Compute cosine similarity
|
||||
similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1)
|
||||
|
||||
# Get top_k indices
|
||||
top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices
|
||||
|
||||
# Retrieve and sort chunks by their original position
|
||||
relevant_chunks_with_indices = []
|
||||
original_indices_map = {chunk_content: idx for idx, chunk_content in enumerate(chunks)} # Map content back to original index
|
||||
|
||||
retrieved_chunk_contents = [valid_chunks[i] for i in top_k_indices.tolist()]
|
||||
|
||||
# Find original start lines to sort chronologically (approximate)
|
||||
chunk_start_lines = {}
|
||||
current_line = 0
|
||||
lines = background_code.splitlines()
|
||||
chunk_map_from_sliding = chunk_sliding_window(background_code, window_size, overlap) # Re-chunk to get consistent indexing if needed
|
||||
start_line_num = 0
|
||||
stride = window_size - overlap
|
||||
for i, chunk_content in enumerate(chunk_map_from_sliding):
|
||||
# This assumes the chunking function returns chunks in order
|
||||
chunk_start_lines[chunk_content] = start_line_num
|
||||
start_line_num += stride
|
||||
# Rough approximation, doesn't perfectly handle edge cases/final chunks
|
||||
|
||||
sorted_relevant_chunks = sorted(
|
||||
retrieved_chunk_contents,
|
||||
key=lambda content: chunk_start_lines.get(content, float('inf')) # Sort by approximate start line
|
||||
)
|
||||
|
||||
# Combine relevant chunks
|
||||
# Original implementation joined with \n, let's keep it simple
|
||||
combined_code = "\n\n".join(sorted_relevant_chunks) # Separate chunks by double newline for clarity
|
||||
|
||||
return combined_code
|
||||
|
||||
|
||||
# Helper function for LLMLingua compression
|
||||
def compress_llmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str) -> str:
|
||||
"""Compresses context using LLMLingua."""
|
||||
if not context.strip():
|
||||
return ""
|
||||
try:
|
||||
# Ensure no "<|endoftext|>"
|
||||
context_clean = context.replace("<|endoftext|>", "")
|
||||
compressed = compressor.compress_prompt(
|
||||
context_clean,
|
||||
instruction=instruction,
|
||||
question=query + "\n" + instruction, # Combine query and instruction for question
|
||||
target_token=target_token
|
||||
)
|
||||
# Ensure result exists and is string
|
||||
result = compressed.get('compressed_prompt', '')
|
||||
return result if isinstance(result, str) else ""
|
||||
except Exception as e:
|
||||
logger.error(f"LLMLingua compression failed: {e}")
|
||||
# Fallback: Truncate based on target tokens (approximate)
|
||||
tokens = compressor.tokenizer.encode(context_clean)
|
||||
if len(tokens) > target_token:
|
||||
return compressor.tokenizer.decode(tokens[:target_token])
|
||||
return context_clean
|
||||
|
||||
|
||||
# Helper function for LongLLMLingua compression
|
||||
def compress_longllmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str, chunk_size: int, overlap: int) -> str:
|
||||
"""Compresses context using LongLLMLingua with sliding window chunks."""
|
||||
if not context.strip():
|
||||
return ""
|
||||
try:
|
||||
# Ensure no "<|endoftext|>"
|
||||
context_clean = context.replace("<|endoftext|>", "")
|
||||
# Use our sliding window chunker
|
||||
chunks = chunk_sliding_window(context_clean, chunk_size, overlap)
|
||||
if not chunks:
|
||||
return "" # Handle case where context is too short or chunking fails
|
||||
|
||||
compressed = compressor.compress_prompt(
|
||||
chunks,
|
||||
instruction=instruction,
|
||||
question=query + "\n" + instruction, # Combine query and instruction for question
|
||||
target_token=target_token,
|
||||
rank_method="longllmlingua" # Use the specified rank method
|
||||
)
|
||||
# Ensure result exists and is string
|
||||
result = compressed.get('compressed_prompt', '')
|
||||
return result if isinstance(result, str) else ""
|
||||
except Exception as e:
|
||||
logger.error(f"LongLLMLingua compression failed: {e}")
|
||||
# Fallback: Truncate based on target tokens (approximate)
|
||||
tokens = compressor.tokenizer.encode(context_clean)
|
||||
if len(tokens) > target_token:
|
||||
return compressor.tokenizer.decode(tokens[:target_token])
|
||||
return context_clean
|
||||
|
||||
# Helper function for CodeCompressor (Rank Only or Fine-grained)
|
||||
|
||||
|
||||
def compress_code_compressor(context: str, query: str, compressor: CodeCompressor, target_token: int, instruction: str, language: str, rank_only: bool, fine_ratio: float, importance_beta: float) -> str:
|
||||
"""Compresses context using CodeCompressor based on target tokens and rank_only flag."""
|
||||
if not context.strip():
|
||||
return ""
|
||||
try:
|
||||
# Ensure no "<|endoftext|>"
|
||||
context_clean = context.replace("<|endoftext|>", "")
|
||||
if not context_clean.strip():
|
||||
return "" # Return empty if clean context is empty
|
||||
|
||||
# Tokenize to get original length
|
||||
# Use the compressor's tokenizer
|
||||
original_tokens = len(compressor.tokenizer.encode(context_clean))
|
||||
if original_tokens == 0:
|
||||
return "" # Avoid division by zero
|
||||
|
||||
# Calculate target ratio
|
||||
target_ratio = min(1.0, max(0.0, target_token / original_tokens))
|
||||
logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
|
||||
|
||||
# Pass rank_only and fine_ratio
|
||||
# Assuming compressor is already initialized with the correct model
|
||||
compressed_result = compressor.compress_code_file(
|
||||
code=context_clean,
|
||||
query=query, # Using current function context as query focus
|
||||
instruction=instruction,
|
||||
rate=target_ratio,
|
||||
language=language,
|
||||
rank_only=rank_only, # Ensure rank_only mode is set
|
||||
fine_ratio=fine_ratio if not rank_only else None, # Pass fine_ratio only if not rank_only
|
||||
importance_beta=importance_beta if not rank_only else None, # Pass importance_beta only if not rank_only
|
||||
)
|
||||
|
||||
# Extract compressed content - check both possible keys
|
||||
compressed_context = compressed_result.get("compressed_code")
|
||||
|
||||
if not isinstance(compressed_context, str):
|
||||
logger.error(f"CodeCompressor returned non-string: {type(compressed_context)}")
|
||||
compressed_context = "" # Fallback
|
||||
|
||||
# Log results
|
||||
compressed_tokens_count = len(compressor.tokenizer.encode(compressed_context))
|
||||
final_ratio = (compressed_tokens_count / original_tokens) if original_tokens > 0 else 0
|
||||
logger.info(f"CodeCompressor: Compressed tokens={compressed_tokens_count}, Actual ratio={final_ratio:.4f}")
|
||||
|
||||
return compressed_context
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CodeCompressor compression failed: {e}", exc_info=True)
|
||||
# Fallback: Truncate approximately based on target tokens (less ideal for rank_only)
|
||||
tokens = compressor.tokenizer.encode(context_clean)
|
||||
if len(tokens) > target_token:
|
||||
logger.warning(f"CodeCompressor falling back to simple truncation.")
|
||||
return compressor.tokenizer.decode(tokens[:target_token])
|
||||
return context_clean
|
||||
|
||||
# Function to save scores
|
||||
|
||||
|
||||
def save_json(data: dict, file_path: str):
|
||||
"""Saves dictionary data to a JSON file."""
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
def generate_completions(llm, batch_prompts, max_new_tokens=128):
|
||||
# Generate completions for batch
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
top_p=0.95,
|
||||
max_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
batch_outputs = llm.generate(
|
||||
batch_prompts,
|
||||
sampling_params,
|
||||
use_tqdm=False
|
||||
)
|
||||
|
||||
return [x.outputs[0].text for x in batch_outputs]
|
||||
|
||||
|
||||
def evaluate_completion(
|
||||
model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
|
||||
method: str = "full",
|
||||
result_dir: str = "results/completion_baselines",
|
||||
embed_model_name: str = "microsoft/unixcoder-base",
|
||||
compression_model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
|
||||
dataset_path: str = "microsoft/LCC_python",
|
||||
dataset_split: str = "test",
|
||||
num_examples: int = 200,
|
||||
max_new_tokens: int = 128,
|
||||
batch_size: int = 16,
|
||||
# RAG params
|
||||
rag_window_size: int = 80,
|
||||
rag_overlap: int = 40,
|
||||
rag_top_k: int = 3,
|
||||
# Function RAG params
|
||||
function_rag_language: str = "python",
|
||||
function_rag_top_k: int = 3,
|
||||
# LLMLingua params
|
||||
lingua_target_token: int = 500,
|
||||
lingua_instruction: str = "Complete the following code function given the context.",
|
||||
# LongLLMLingua params
|
||||
longlingua_chunk_size: int = 80,
|
||||
longlingua_overlap: int = 40,
|
||||
# CodeCompressor params (New)
|
||||
code_compressor_target_token: int = 500,
|
||||
# vLLM params
|
||||
tensor_parallel_size: int = 1,
|
||||
trust_remote_code: bool = True,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
filter_current_lines_max: int = 50,
|
||||
filter_background_tokens_min: int = 3000,
|
||||
# New CodeCompressor fine-grained param
|
||||
code_compressor_fine_ratio: float = 1.0, # Default 1.0 means rank_only=True
|
||||
# New CodeCompressor importance beta param
|
||||
importance_beta: float = 0.0, # Default beta is 0.0
|
||||
):
|
||||
"""Evaluates code completion baselines with a specified context preparation method."""
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# --- 1. Load Data ---
|
||||
# Assuming python for now, might need modification if dataset has multiple languages
|
||||
# Note: Language info might be needed for CodeCompressor if not always python
|
||||
dataset, _ = load_data(path=dataset_path, split=dataset_split, num_examples=num_examples,
|
||||
filter_current_lines_max=filter_current_lines_max, filter_background_tokens_min=filter_background_tokens_min)
|
||||
logger.info(f"Loaded {len(dataset)} examples from {dataset_path} ({dataset_split} split)")
|
||||
|
||||
# --- 2. Initialize Models ---
|
||||
embed_model = None
|
||||
embed_tokenizer = None
|
||||
if method == "rag" or method == "function_rag":
|
||||
logger.info(f"Initializing embedding model: {embed_model_name}")
|
||||
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
|
||||
embed_model = AutoModel.from_pretrained(embed_model_name).to(device)
|
||||
embed_model.eval() # Set to evaluation mode
|
||||
logger.info(f"Embedding model {embed_model_name} initialized.")
|
||||
|
||||
lingua_compressor = None
|
||||
if method == "llmlingua" or method == "longllmlingua":
|
||||
logger.info(f"Initializing LLMLingua compressor: {compression_model_name}")
|
||||
lingua_compressor = PromptCompressor(model_name=compression_model_name, device_map="auto")
|
||||
logger.info(f"LLMLingua compressor {compression_model_name} initialized.")
|
||||
|
||||
code_compressor_instance = None # Renamed to avoid conflict
|
||||
if method == "code_compressor":
|
||||
logger.info(f"Initializing CodeCompressor: {compression_model_name}")
|
||||
# Assuming CodeCompressor takes model name and potentially device
|
||||
# Pass device explicitly if needed by your CodeCompressor implementation
|
||||
code_compressor_instance = CodeCompressor(compression_model_name)
|
||||
logger.info(f"CodeCompressor {compression_model_name} initialized.")
|
||||
|
||||
if method in ["full", "no_context"]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
# try to compress a dummy prompt to avoid cuda error when initializing the vllm (strange bug)
|
||||
code_compressor_instance = PromptCompressor(model_name=compression_model_name, device_map="auto")
|
||||
logger.info(f"CodeCompressor {compression_model_name} initialized.")
|
||||
dummy_prompt = "def hello_world():\n print('Hello, World!')"*100
|
||||
compressed_prompt = code_compressor_instance.compress_prompt(dummy_prompt, instruction="Complete the following code function given the context.", question="Complete the following code function given the context.", target_token=500)
|
||||
logger.info(f"Compressed prompt: {compressed_prompt}")
|
||||
|
||||
# --- 3. Process the Specified Method ---
|
||||
logger.info(f"--- Processing Method: {method} ---")
|
||||
|
||||
# Modify result directory based on method and parameters
|
||||
method_suffix = f"method_{method}"
|
||||
if method == "rag":
|
||||
method_suffix += f"_w{rag_window_size}_o{rag_overlap}_k{rag_top_k}"
|
||||
elif method == "function_rag":
|
||||
method_suffix += f"_lang{function_rag_language}_k{function_rag_top_k}"
|
||||
elif method == "llmlingua":
|
||||
method_suffix += f"_t{lingua_target_token}"
|
||||
elif method == "longllmlingua":
|
||||
method_suffix += f"_t{lingua_target_token}_cs{longlingua_chunk_size}_o{longlingua_overlap}"
|
||||
elif method == "code_compressor":
|
||||
# Determine if rank_only based on fine_ratio
|
||||
rank_only_for_suffix = (code_compressor_fine_ratio == 1.0)
|
||||
suffix_detail = "_rankonly" if rank_only_for_suffix else f"fr{code_compressor_fine_ratio}"
|
||||
# Add importance_beta to suffix
|
||||
if importance_beta > 0:
|
||||
suffix_detail += f"_b{importance_beta}"
|
||||
# Use code_compressor_target_token for consistency
|
||||
method_suffix += f"_t{code_compressor_target_token}{suffix_detail}" # Updated suffix
|
||||
|
||||
method_result_dir = os.path.join(result_dir, method_suffix)
|
||||
os.makedirs(method_result_dir, exist_ok=True)
|
||||
|
||||
model_output_path = os.path.join(
|
||||
method_result_dir,
|
||||
f"{model_name.replace('/', '_slash_')}.jsonl",
|
||||
)
|
||||
score_output_path = os.path.join(
|
||||
method_result_dir,
|
||||
f"{model_name.replace('/', '_slash_')}-SCORES.json",
|
||||
)
|
||||
|
||||
all_prompts = []
|
||||
original_data = [] # Store original data to merge with results
|
||||
|
||||
# Prepare prompts based on method
|
||||
for i, example in enumerate(tqdm(dataset, desc=f"Preparing prompts for {method}")):
|
||||
background_ctx = example['background_context']
|
||||
current_func_ctx = example['current_function_context'] # This is the prefix
|
||||
ground_truth = example['gt'] # This is the completion target
|
||||
# Determine language - assuming python for now based on dataset path
|
||||
language = "python" # IMPORTANT: Make dynamic if dataset contains multiple languages
|
||||
|
||||
context_for_prompt = ""
|
||||
try:
|
||||
if method == "full":
|
||||
context_for_prompt = background_ctx + "\n\n" + current_func_ctx
|
||||
|
||||
# some models have max context length of 32768, so we truncate the context (from the head) if it exceeds that
|
||||
tokenized_context = tokenizer.encode(context_for_prompt)
|
||||
if len(tokenized_context) > 32768-256:
|
||||
logger.warning(f"Context length exceeds 32768, truncating from the head. Original length: {len(tokenized_context)}, Truncated length: 32768")
|
||||
context_for_prompt = tokenizer.decode(tokenized_context[-(32768-256):])
|
||||
elif method == "rag":
|
||||
if not embed_model or not embed_tokenizer:
|
||||
raise ValueError("RAG method selected but embedding model not initialized.")
|
||||
retrieved_ctx = rag_retrieve(
|
||||
background_ctx, current_func_ctx,
|
||||
embed_model, embed_tokenizer, device,
|
||||
rag_window_size, rag_overlap, rag_top_k
|
||||
)
|
||||
context_for_prompt = retrieved_ctx + "\n\n" + current_func_ctx
|
||||
elif method == "function_rag":
|
||||
if not embed_model or not embed_tokenizer:
|
||||
raise ValueError("Function RAG method selected but embedding model not initialized.")
|
||||
retrieved_ctx = function_rag_retrieve(
|
||||
background_ctx, current_func_ctx,
|
||||
embed_model, embed_tokenizer, device,
|
||||
function_rag_language, function_rag_top_k
|
||||
)
|
||||
context_for_prompt = retrieved_ctx + "\n\n" + current_func_ctx
|
||||
elif method == "llmlingua":
|
||||
if not lingua_compressor:
|
||||
raise ValueError("LLMLingua method selected but compressor not initialized.")
|
||||
compressed_ctx = compress_llmlingua(
|
||||
background_ctx, current_func_ctx,
|
||||
lingua_compressor, lingua_target_token, lingua_instruction
|
||||
)
|
||||
context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx
|
||||
elif method == "longllmlingua":
|
||||
if not lingua_compressor:
|
||||
raise ValueError("LongLLMLingua method selected but compressor not initialized.")
|
||||
compressed_ctx = compress_longllmlingua(
|
||||
background_ctx, current_func_ctx,
|
||||
lingua_compressor, lingua_target_token, lingua_instruction,
|
||||
longlingua_chunk_size, longlingua_overlap
|
||||
)
|
||||
context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx
|
||||
elif method == "code_compressor":
|
||||
if not code_compressor_instance:
|
||||
raise ValueError("CodeCompressor method selected but compressor not initialized.")
|
||||
# Determine rank_only based on fine_ratio
|
||||
rank_only = (code_compressor_fine_ratio == 1.0)
|
||||
logger.info(f"CodeCompressor mode: {'Rank Only' if rank_only else f'Fine-grained (ratio={code_compressor_fine_ratio})'}")
|
||||
# Use current_func_ctx as the query for CodeCompressor to focus retrieval
|
||||
compressed_ctx = compress_code_compressor(
|
||||
context=background_ctx,
|
||||
query=current_func_ctx, # Query is the current function prefix
|
||||
compressor=code_compressor_instance,
|
||||
target_token=code_compressor_target_token,
|
||||
instruction=lingua_instruction, # Reusing lingua instruction
|
||||
language=language,
|
||||
rank_only=rank_only, # Pass determined rank_only flag
|
||||
fine_ratio=code_compressor_fine_ratio, # Pass fine_ratio
|
||||
importance_beta=importance_beta, # Pass importance_beta
|
||||
)
|
||||
# Combine the compressed background context with the original current function context
|
||||
context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx
|
||||
elif method == "no_context":
|
||||
context_for_prompt = current_func_ctx
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
prompt = context_for_prompt.strip()
|
||||
all_prompts.append(prompt)
|
||||
original_data.append({
|
||||
"id": example.get("id", i),
|
||||
"gt": ground_truth,
|
||||
"original_background_context": background_ctx,
|
||||
"original_current_function_context": current_func_ctx,
|
||||
"language": language # Store language if needed later
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing example {i} (ID: {example.get('id', 'N/A')}) for method {method}: {e}", exc_info=True)
|
||||
continue # Skip this example
|
||||
|
||||
# --- 4. Clean up Compression/Embedding Models ---
|
||||
logger.info("Freeing up GPU memory from compression/embedding models")
|
||||
if embed_model:
|
||||
del embed_model
|
||||
if embed_tokenizer:
|
||||
del embed_tokenizer
|
||||
if lingua_compressor:
|
||||
del lingua_compressor
|
||||
if code_compressor_instance:
|
||||
del code_compressor_instance # Clean up CodeCompressor
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
logger.info("GPU memory freed")
|
||||
|
||||
# --- 5. Initialize Generation LLM ---
|
||||
# Check if there are any prompts to process before initializing LLM
|
||||
if not all_prompts:
|
||||
logger.error(f"No valid prompts were prepared for method {method}. Skipping generation and scoring.")
|
||||
return
|
||||
|
||||
logger.info(f"Initializing generation LLM: {model_name}")
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
max_model_len=32768
|
||||
)
|
||||
logger.info(f"Generation LLM {model_name} initialized.")
|
||||
|
||||
# --- 6. Generate Completions ---
|
||||
all_outputs = []
|
||||
logger.info(f"Generating completions for {len(all_prompts)} prompts...")
|
||||
for i in tqdm(range(0, len(all_prompts), batch_size), desc=f"Generating completions for {method}"):
|
||||
batch_prompts = all_prompts[i:i + batch_size]
|
||||
if not batch_prompts:
|
||||
continue
|
||||
|
||||
try:
|
||||
batch_outputs = generate_completions(llm, batch_prompts, max_new_tokens=max_new_tokens)
|
||||
all_outputs.extend(batch_outputs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during generation for batch starting at index {i}: {e}")
|
||||
all_outputs.extend(["ERROR_GENERATING"] * len(batch_prompts))
|
||||
|
||||
# --- 7. Evaluate and Save Results ---
|
||||
model_outputs_data = []
|
||||
total_es = 0
|
||||
total_em = 0
|
||||
valid_scores = 0
|
||||
|
||||
if len(all_outputs) != len(original_data):
|
||||
logger.warning(f"Warning: Mismatch between generated outputs ({len(all_outputs)}) and original data ({len(original_data)}). Scores might be inaccurate.")
|
||||
min_len = min(len(all_outputs), len(original_data))
|
||||
all_outputs = all_outputs[:min_len]
|
||||
original_data = original_data[:min_len]
|
||||
all_prompts = all_prompts[:min_len]
|
||||
|
||||
logger.info(f"Calculating scores and saving results for {len(all_outputs)} examples...")
|
||||
# make sure that the path exists
|
||||
os.makedirs(os.path.dirname(model_output_path), exist_ok=True)
|
||||
with open(model_output_path, "w") as f_out:
|
||||
for i in range(len(all_outputs)):
|
||||
output = all_outputs[i]
|
||||
# Ensure index is valid for original_data and all_prompts
|
||||
if i >= len(original_data) or i >= len(all_prompts):
|
||||
logger.error(f"Index {i} out of bounds after potential mismatch alignment. Stopping result processing.")
|
||||
break
|
||||
orig_data = original_data[i]
|
||||
prompt = all_prompts[i]
|
||||
gt = orig_data['gt']
|
||||
|
||||
result = {
|
||||
**orig_data,
|
||||
"prompt": prompt,
|
||||
"output": output,
|
||||
}
|
||||
|
||||
es = 0
|
||||
em = 0
|
||||
if output != "ERROR_GENERATING" and gt is not None:
|
||||
try:
|
||||
es = compute_ES(gt, output)
|
||||
em = compute_EM(gt, output)
|
||||
total_es += es
|
||||
total_em += em
|
||||
valid_scores += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error scoring example {orig_data.get('id', i)}: {e}")
|
||||
|
||||
result['es'] = es
|
||||
result['em'] = em
|
||||
model_outputs_data.append(result)
|
||||
f_out.write(json.dumps(result) + "\n")
|
||||
|
||||
logger.info(f"Raw results saved to {model_output_path}")
|
||||
|
||||
avg_es = (total_es / valid_scores) if valid_scores > 0 else 0
|
||||
avg_em = (total_em / valid_scores) if valid_scores > 0 else 0
|
||||
|
||||
# Update the parameters dictionary in scores
|
||||
scores = {
|
||||
"model_name": model_name,
|
||||
"method": method,
|
||||
"num_examples_scored": valid_scores,
|
||||
"num_examples_total": len(original_data), # Use length of original_data before potential alignment issues
|
||||
"average_es": avg_es,
|
||||
"average_em": avg_em,
|
||||
"parameters": {
|
||||
"dataset_path": dataset_path,
|
||||
"dataset_split": dataset_split,
|
||||
"filter_current_lines_max": filter_current_lines_max,
|
||||
"filter_background_tokens_min": filter_background_tokens_min,
|
||||
"embed_model_name": embed_model_name if method == "rag" or method == "function_rag" else None,
|
||||
# Combine compression model name reporting
|
||||
"compression_model_name": compression_model_name if method in ["llmlingua", "longllmlingua", "code_compressor"] else None,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"batch_size": batch_size,
|
||||
# RAG specific params
|
||||
"rag_window_size": rag_window_size if method == "rag" else None,
|
||||
"rag_overlap": rag_overlap if method == "rag" else None,
|
||||
"rag_top_k": rag_top_k if method == "rag" else None,
|
||||
# Function RAG params
|
||||
"function_rag_language": function_rag_language if method == "function_rag" else None,
|
||||
"function_rag_top_k": function_rag_top_k if method == "function_rag" else None,
|
||||
# Lingua specific params (shared target token name)
|
||||
"lingua_target_token": lingua_target_token if method == "llmlingua" or method == "longllmlingua" else None,
|
||||
# LongLingua specific params
|
||||
"longlingua_chunk_size": longlingua_chunk_size if method == "longllmlingua" else None,
|
||||
"longlingua_overlap": longlingua_overlap if method == "longllmlingua" else None,
|
||||
# CodeCompressor specific params
|
||||
"code_compressor_target_token": code_compressor_target_token if method == "code_compressor" else None, # Added parameter
|
||||
"code_compressor_rank_only": (code_compressor_fine_ratio == 1.0) if method == "code_compressor" else None, # Determined by fine_ratio
|
||||
"code_compressor_fine_ratio": code_compressor_fine_ratio if method == "code_compressor" else None, # Added parameter
|
||||
"importance_beta": importance_beta if method == "code_compressor" else None, # Added parameter
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"Method {method}: Avg ES = {avg_es:.2f}, Avg EM = {avg_em:.2f} ({valid_scores}/{len(original_data)} scored)")
|
||||
save_json(scores, score_output_path)
|
||||
logger.info(f"Scores saved to {score_output_path}")
|
||||
|
||||
logger.info("Evaluation complete.")
|
||||
# Clean up LLM explicitly
|
||||
if 'llm' in locals() and llm is not None:
|
||||
del llm
|
||||
logger.info("Generation LLM deleted.")
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(evaluate_completion)
|
||||
long-code-completion/run.sh (new file, 40 lines)
@@ -0,0 +1,40 @@

#!/bin/bash

export CUDA_VISIBLE_DEVICES=0

MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct"
MODEL_PATH_NAME="qwencoder-7b-instruct"
BASE_RESULT_DIR="results/${MODEL_PATH_NAME}"
BASE_LOG_DIR="logs/${MODEL_PATH_NAME}"

mkdir -p ${BASE_LOG_DIR}
mkdir -p ${BASE_RESULT_DIR}

echo "Starting experiments for ${MODEL_NAME} on GPU ${CUDA_VISIBLE_DEVICES}"

# --- CodeCompressor Method Configuration ---
TARGET_TOKENS=(2048 4096)
FINE_RATIOS=(0.5 0.8)
BETAS=(0.0 0.5)

echo "--- Running CodeCompressor with various configurations ---"
for tokens in "${TARGET_TOKENS[@]}"; do
    for ratio in "${FINE_RATIOS[@]}"; do
        for beta in "${BETAS[@]}"; do
            echo "Running CodeCompressor: target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}"
            python main.py \
                --model_name ${MODEL_NAME} \
                --compression_model_name ${MODEL_NAME} \
                --method code_compressor \
                --filter_background_tokens_min 5000 \
                --result_dir "${BASE_RESULT_DIR}" \
                --num_examples 500 \
                --code_compressor_target_token ${tokens} \
                --code_compressor_fine_ratio ${ratio} \
                --importance_beta ${beta} > "${BASE_LOG_DIR}/code_compressor_t${tokens}_fr${ratio}_b${beta}.log" 2>&1
            echo "Finished CodeCompressor: target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}"
        done
    done
done

echo "--- Finished CodeCompressor ---"
long-code-completion/utils.py (new file, 288 lines)
@@ -0,0 +1,288 @@
|
||||
import datasets
|
||||
import editdistance
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from transformers import AutoTokenizer
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
|
||||
def compute_ES(target, prediction):
|
||||
"""Compute edit similarity score"""
|
||||
target_lines = [line.strip() for line in target.splitlines() if line.strip()]
|
||||
target_str = '\n'.join(target_lines)
|
||||
prediction_lines = [line.strip() for line in prediction.splitlines()
|
||||
if line.strip() and not line.strip().startswith("#")][:len(target_lines)]
|
||||
prediction_str = '\n'.join(prediction_lines)
|
||||
|
||||
return (1 - (editdistance.eval(target_str, prediction_str) /
|
||||
max(len(target_str), len(prediction_str))))*100
|
||||
|
||||
|
||||
def compute_EM(target, prediction):
|
||||
"""Compute exact match score"""
|
||||
target_lines = [line.strip() for line in target.splitlines() if line.strip()]
|
||||
prediction_lines = [line.strip() for line in prediction.splitlines()
|
||||
if line.strip() and not line.strip().startswith("#")][:len(target_lines)]
|
||||
|
||||
if len(target_lines) != len(prediction_lines):
|
||||
return 0
|
||||
return (int(target_lines == prediction_lines))*100
|
||||
|
||||
|
||||
def load_data(path="microsoft/LCC_python", split="test", num_examples=500, filter_current_lines_max=50, filter_background_tokens_min=5000):
|
||||
"""
|
||||
Loads the dataset, processes it to split contexts, filters it based on context lengths,
|
||||
and returns the filtered dataset along with the tokenizer used.
|
||||
"""
|
||||
print(f"Loading initial {num_examples} examples from {path} ({split} split)...")
|
||||
dataset = datasets.load_dataset(path, split=split)
|
||||
# keep 5 times of num_examples for testing
|
||||
dataset = dataset.select(range(num_examples*10))
|
||||
original_size = len(dataset) # Size before filtering
|
||||
|
||||
# Initialize tokenizer here for filtering and potential later use
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
|
||||
print("Tokenizer Qwen/Qwen2.5-Coder-7B-Instruct initialized.")
|
||||
|
||||
# Process dataset to add split contexts first
|
||||
print("Splitting context into background and current function...")
|
||||
def add_split_context(example):
|
||||
background, current_func = split_context_ast(example['context'])
|
||||
example['background_context'] = background
|
||||
example['current_function_context'] = current_func
|
||||
return example
|
||||
|
||||
processed_dataset = dataset.map(add_split_context, num_proc=4) # Use multiple processors if available
|
||||
|
||||
# --- Filter the dataset ---
|
||||
filtered_dataset_list = []
|
||||
print(f"Filtering dataset: Keeping examples where current func lines <= {filter_current_lines_max} and background tokens >= {filter_background_tokens_min}.")
|
||||
|
||||
for example in tqdm(processed_dataset):
|
||||
curr_ctx = example['current_function_context']
|
||||
bg_ctx = example['background_context']
|
||||
|
||||
curr_line_count = len(curr_ctx.splitlines())
|
||||
|
||||
# Check if background context is non-empty before tokenizing
|
||||
bg_token_count = 0
|
||||
if bg_ctx and bg_ctx.strip(): # Check if bg_ctx is not None and not just whitespace
|
||||
# Use truncation=True and max_length to prevent overly long sequences if needed, though for filtering, just length is fine.
|
||||
bg_token_count = len(tokenizer.encode(bg_ctx, add_special_tokens=False)) # Usually better to exclude special tokens for length calculation
|
||||
|
||||
if curr_line_count <= filter_current_lines_max and bg_token_count >= filter_background_tokens_min:
|
||||
filtered_dataset_list.append(example)
|
||||
|
||||
filtered_dataset = datasets.Dataset.from_list(filtered_dataset_list)
|
||||
if num_examples > len(filtered_dataset):
|
||||
selected_dataset = filtered_dataset
|
||||
else:
|
||||
selected_dataset = filtered_dataset.select(range(num_examples))
|
||||
|
||||
print(f"Filtering complete. Original size: {original_size}, Filtered size: {len(filtered_dataset)}. Retaining {min(num_examples, len(filtered_dataset))} examples.") # Adjusted print statement
|
||||
|
||||
# Return both the filtered dataset and the tokenizer
|
||||
return selected_dataset, tokenizer
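
# Usage sketch for load_data (illustrative; the argument values simply restate the
# defaults). It returns the filtered split together with the tokenizer used for the
# background-token filter, so downstream code can reuse the same token counts.
# dataset, tokenizer = load_data(
#     path="microsoft/LCC_python",
#     split="test",
#     num_examples=500,
#     filter_current_lines_max=50,
#     filter_background_tokens_min=5000,
# )
# # each example now has 'background_context' and 'current_function_context' fields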
|
||||
|
||||
|
||||
def find_last_func_or_class_start(code_string):
|
||||
"""
|
||||
Finds the starting line of the last top-level function or class definition
|
||||
using line-based heuristics, robust to syntax errors.
|
||||
Accounts for decorators.
|
||||
Returns the 1-based line number or None if not found.
|
||||
"""
|
||||
lines = code_string.splitlines()
|
||||
if not lines:
|
||||
return None
|
||||
last_def_line_index = -1
|
||||
|
||||
# Iterate backwards to find the last line starting with def/async def/class
|
||||
# We use lstrip() to handle indentation
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
stripped_line = lines[i].lstrip()
|
||||
# Using regex for potentially more robust matching (e.g., def func():)
|
||||
# Matches lines starting with 'def', 'async def', or 'class' followed by space
|
||||
if re.match(r'^(def|async\s+def|class)\s+', stripped_line):
|
||||
last_def_line_index = i
|
||||
break
|
||||
|
||||
if last_def_line_index != -1:
|
||||
# Found a potential start, now check for decorators above it
|
||||
start_line_index = last_def_line_index
|
||||
for i in range(last_def_line_index - 1, -1, -1):
|
||||
stripped_line = lines[i].lstrip()
|
||||
if stripped_line.startswith('@'):
|
||||
start_line_index = i
|
||||
elif stripped_line == '' or stripped_line.startswith('#'): # Skip blank lines and comments
|
||||
continue
|
||||
else:
|
||||
# Found a non-decorator, non-empty, non-comment line, stop searching upwards
|
||||
break
|
||||
return start_line_index + 1 # Return 1-based line number
|
||||
else:
|
||||
# Heuristic failed, maybe return the start of the last non-empty block
|
||||
# or just None if no definitions found at all
|
||||
return None # No function or class definition found
|
||||
|
||||
def split_context_ast(code_string):
|
||||
"""
|
||||
Splits the code context into background and current function/class context using AST.
|
||||
"""
|
||||
lines = code_string.splitlines()
|
||||
split_line_1_based = find_last_func_or_class_start(code_string)
|
||||
|
||||
if split_line_1_based is not None and split_line_1_based > 0:
|
||||
# split_line_1_based is the start of the function/class
|
||||
# Background is lines *before* that line
|
||||
background_lines = lines[:split_line_1_based - 1]
|
||||
current_func_lines = lines[split_line_1_based - 1:]
|
||||
return '\n'.join(background_lines), '\n'.join(current_func_lines)
|
||||
else:
|
||||
# If no function/class found or parse error, treat all as current
|
||||
return "", code_string
|
||||
|
||||
def analyze_dataset(dataset, tokenizer): # Added tokenizer parameter
|
||||
"""Analyzes and plots context length distributions, including function counts and token ratios."""
|
||||
# --- Analysis (Optional: Recalculate stats on the filtered dataset) ---
|
||||
background_lines = []
|
||||
current_func_lines = []
|
||||
background_tokens = []
|
||||
current_func_tokens = []
|
||||
background_func_counts = [] # Added list for function counts
|
||||
bg_curr_token_ratios = [] # Added list for token ratios
|
||||
|
||||
|
||||
# Ensure tokenizer is available - it's passed as an argument now
|
||||
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") # Removed: Use passed tokenizer
|
||||
|
||||
print(f"\nAnalyzing {len(dataset)} examples...") # Add count here
|
||||
for example in tqdm(dataset): # Use tqdm here for progress
|
||||
bg_ctx = example.get('background_context', '') # Use .get for safety
|
||||
curr_ctx = example.get('current_function_context', '')
|
||||
|
||||
bg_token_count = 0
|
||||
curr_token_count = 0
|
||||
func_count = 0
|
||||
|
||||
# Proceed only if contexts exist
|
||||
if bg_ctx:
|
||||
bg_lines = bg_ctx.splitlines()
|
||||
bg_line_count = len(bg_lines)
|
||||
background_lines.append(bg_line_count)
|
||||
# Use truncation for safety, exclude special tokens for consistency
|
||||
bg_token_count = len(tokenizer.encode(bg_ctx, add_special_tokens=False))
|
||||
background_tokens.append(bg_token_count)
|
||||
# Count functions in background context
|
||||
for line in bg_lines:
|
||||
if re.match(r'^\s*def\s+', line): # Count lines starting with 'def ' after stripping leading whitespace
|
||||
func_count += 1
|
||||
background_func_counts.append(func_count)
|
||||
|
||||
if curr_ctx:
|
||||
curr_line_count = len(curr_ctx.splitlines())
|
||||
current_func_lines.append(curr_line_count)
|
||||
curr_token_count = len(tokenizer.encode(curr_ctx, add_special_tokens=False))
|
||||
current_func_tokens.append(curr_token_count)
|
||||
|
||||
# Calculate ratio, handle division by zero
|
||||
if bg_token_count > 0 and curr_token_count > 0:
|
||||
# Add a small epsilon to avoid potential issues with very small token counts if needed, though direct ratio is fine here.
|
||||
bg_curr_token_ratios.append(bg_token_count / curr_token_count)
|
||||
elif bg_token_count > 0 and curr_token_count == 0:
# Skip infinite ratios (empty current function) so the histogram and summary stats stay well-defined
pass
|
||||
# else: ratio is 0 or undefined, skip
|
||||
|
||||
|
||||
# --- Plotting ---
|
||||
# Check if *any* data exists before proceeding
|
||||
if not any([background_lines, current_func_lines, background_tokens, current_func_tokens, background_func_counts, bg_curr_token_ratios]):
|
||||
print("No data points found for analysis after filtering. Skipping plot generation.")
|
||||
return # Exit if no data to plot
|
||||
|
||||
fig, axs = plt.subplots(3, 2, figsize=(12, 15)) # Changed to 3x2 grid
|
||||
# Use tokenizer name in titles dynamically if possible, or keep generic
|
||||
tokenizer_name = tokenizer.name_or_path if hasattr(tokenizer, 'name_or_path') else "Tokenizer"
|
||||
fig.suptitle(f'Context Analysis (Filtered LCC Python Dataset - {len(dataset)} examples, Tokenizer: {tokenizer_name})')
|
||||
|
||||
# Row 1: Background
|
||||
# Background Lines
|
||||
if background_lines:
|
||||
axs[0, 0].hist(background_lines, bins=50, color='skyblue', edgecolor='black')
|
||||
print(f"Background Lines: Min={np.min(background_lines)}, Max={np.max(background_lines)}, Avg={np.mean(background_lines):.2f}, Median={np.median(background_lines)}")
|
||||
else:
|
||||
axs[0,0].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[0,0].transAxes)
|
||||
axs[0, 0].set_title('Background Context (Lines)')
|
||||
axs[0, 0].set_ylabel('Count')
|
||||
|
||||
# Background Tokens
|
||||
if background_tokens:
|
||||
axs[0, 1].hist(background_tokens, bins=50, color='skyblue', edgecolor='black')
|
||||
print(f"Background Tokens: Min={np.min(background_tokens)}, Max={np.max(background_tokens)}, Avg={np.mean(background_tokens):.2f}, Median={np.median(background_tokens)}")
|
||||
else:
|
||||
axs[0,1].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[0,1].transAxes)
|
||||
axs[0, 1].set_title('Background Context (Tokens)')
|
||||
axs[0, 1].set_ylabel('Count')
|
||||
|
||||
|
||||
# Row 2: Background Func Count & Ratio
|
||||
# Background Function Count
|
||||
if background_func_counts:
|
||||
# Use more bins if the range is small, decide based on max count?
|
||||
max_funcs = np.max(background_func_counts) if background_func_counts else 0
|
||||
bins = min(50, max(1, max_funcs + 1)) # Adjust bins based on max count, ensure at least 1 bin
|
||||
axs[1, 0].hist(background_func_counts, bins=bins, color='lightgreen', edgecolor='black')
|
||||
print(f"Background Func Count: Min={np.min(background_func_counts)}, Max={max_funcs}, Avg={np.mean(background_func_counts):.2f}, Median={np.median(background_func_counts)}")
|
||||
else:
|
||||
axs[1,0].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[1,0].transAxes)
|
||||
axs[1, 0].set_title('Background Function Count')
|
||||
axs[1, 0].set_ylabel('Count')
|
||||
|
||||
# Background/Current Token Ratio
|
||||
if bg_curr_token_ratios:
|
||||
# Ratios can have a large range, consider log scale or clipping?
|
||||
# Let's cap the ratio for visualization if it gets too extreme, e.g., at 50
|
||||
# ratios_to_plot = [min(r, 50) for r in bg_curr_token_ratios] # Cap ratio at 50 for plot
|
||||
ratios_to_plot = bg_curr_token_ratios
|
||||
axs[1, 1].hist(ratios_to_plot, bins=50, color='gold', edgecolor='black')
|
||||
# Calculate stats on original ratios before clipping for plot
|
||||
print(f"BG/Current Token Ratio: Min={np.min(bg_curr_token_ratios):.2f}, Max={np.max(bg_curr_token_ratios):.2f}, Avg={np.mean(bg_curr_token_ratios):.2f}, Median={np.median(bg_curr_token_ratios):.2f}")
|
||||
axs[1, 1].set_title('BG/Current Token Ratio')
|
||||
|
||||
else:
|
||||
axs[1,1].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[1,1].transAxes)
|
||||
axs[1, 1].set_ylabel('Count')
|
||||
|
||||
|
||||
# Row 3: Current Function
|
||||
# Current Function Lines
|
||||
if current_func_lines:
|
||||
axs[2, 0].hist(current_func_lines, bins=50, color='lightcoral', edgecolor='black')
|
||||
print(f"Current Func Lines: Min={np.min(current_func_lines)}, Max={np.max(current_func_lines)}, Avg={np.mean(current_func_lines):.2f}, Median={np.median(current_func_lines)}")
|
||||
else:
|
||||
axs[2,0].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[2,0].transAxes)
|
||||
axs[2, 0].set_title('Current Function Context (Lines)')
|
||||
axs[2, 0].set_xlabel('Number of Lines')
|
||||
axs[2, 0].set_ylabel('Count')
|
||||
|
||||
# Current Function Tokens
|
||||
if current_func_tokens:
|
||||
axs[2, 1].hist(current_func_tokens, bins=50, color='lightcoral', edgecolor='black')
|
||||
print(f"Current Func Tokens: Min={np.min(current_func_tokens)}, Max={np.max(current_func_tokens)}, Avg={np.mean(current_func_tokens):.2f}, Median={np.median(current_func_tokens)}")
|
||||
else:
|
||||
axs[2,1].text(0.5, 0.5, 'No Data', horizontalalignment='center', verticalalignment='center', transform=axs[2,1].transAxes)
|
||||
axs[2, 1].set_title('Current Function Context (Tokens)')
|
||||
axs[2, 1].set_xlabel('Number of Tokens')
|
||||
axs[2, 1].set_ylabel('Count')
|
||||
|
||||
|
||||
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
|
||||
plt.savefig('context_analysis_distributions_filtered.png') # Save with a new descriptive name
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Load data, which now includes filtering and returns the tokenizer
|
||||
filtered_dataset, tokenizer = load_data(num_examples=2000, filter_current_lines_max=50, filter_background_tokens_min=5000)
|
||||
analyze_dataset(filtered_dataset, tokenizer) # Pass tokenizer to analyze_dataset
|
||||
1887
module_summarization/code_compressor.py
Normal file
File diff suppressed because it is too large
1318
module_summarization/main.py
Normal file
File diff suppressed because it is too large
51
module_summarization/run.sh
Normal file
@@ -0,0 +1,51 @@
export CUDA_VISIBLE_DEVICES=0

MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct"
MODEL_PATH_NAME="qwencoder-7b-instruct"
BASE_RESULT_DIR="results/${MODEL_PATH_NAME}"
BASE_LOG_DIR="logs/${MODEL_PATH_NAME}"

mkdir -p ${BASE_LOG_DIR}
mkdir -p ${BASE_RESULT_DIR}

echo "Starting experiments for ${MODEL_NAME} on GPU ${CUDA_VISIBLE_DEVICES}"

# --- CodeCompressor Method (Fine-grained with Beta) ---
TARGET_TOKENS=(4096)
FINE_RATIOS=(0.5)
BETAS=(0.5)

echo "--- Running CodeCompressor (Fine-grained with various Beta values) ---"
for ratio in "${FINE_RATIOS[@]}"; do
    for tokens in "${TARGET_TOKENS[@]}"; do
        if [[ "${ratio}" == "1.0" ]]; then
            # If fine_ratio is 1.0, only use default beta 0.0
            beta=0.0
            echo "Running CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}"
            python main.py \
                --gen_model ${MODEL_NAME} \
                --model_name ${MODEL_PATH_NAME} \
                --method code_compressor \
                --save_dir "${BASE_RESULT_DIR}" \
                --code_compressor_target_token ${tokens} \
                --code_compressor_fine_ratio ${ratio} \
                --importance_beta ${beta} > "${BASE_LOG_DIR}/code_compressor_t${tokens}_fr${ratio}_b${beta}.log" 2>&1
            echo "Finished CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}"
        else
            # For other fine_ratios, test different beta values
            for beta in "${BETAS[@]}"; do
                echo "Running CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}"
                python main.py \
                    --gen_model ${MODEL_NAME} \
                    --model_name ${MODEL_PATH_NAME} \
                    --method code_compressor \
                    --save_dir "${BASE_RESULT_DIR}" \
                    --code_compressor_target_token ${tokens} \
                    --code_compressor_fine_ratio ${ratio} \
                    --importance_beta ${beta} > "${BASE_LOG_DIR}/code_compressor_t${tokens}_fr${ratio}_b${beta}.log" 2>&1
                echo "Finished CodeCompressor (Fine-grained): target_tokens=${tokens}, fine_ratio=${ratio}, beta=${beta}"
            done
        fi
    done
done
echo "--- Finished CodeCompressor ---"
142
module_summarization/utils.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from loguru import logger
|
||||
from typing import List, Dict, Optional, Tuple, Any
|
||||
from transformers import AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def truncate_text(text, max_len=512):
|
||||
"""Helper function to truncate long text for logging."""
|
||||
if len(text) <= max_len:
|
||||
return text
|
||||
return text[:max_len//2] + "\n...\n" + text[-max_len//2:]
|
||||
|
||||
|
||||
def load_dataset_samples(dataset_name="JetBrains-Research/lca-module-summarization", split="test",
|
||||
max_examples=None, hf_api_key=None, max_tokens=32768, min_tokens=1024):
|
||||
"""Load dataset samples with optional limiting and filtering of long examples."""
|
||||
dataset = load_dataset(dataset_name, token=hf_api_key)[split]
|
||||
if max_examples is not None:
|
||||
dataset = dataset.select(range(min(max_examples, len(dataset))))
|
||||
|
||||
# Filter out examples with extremely long code
|
||||
if max_tokens > 0:
|
||||
filtered_indices = []
|
||||
skipped_count_long = 0
|
||||
skipped_count_short = 0
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
|
||||
|
||||
for i, row in enumerate(tqdm(dataset, desc="Filtering long examples")):
|
||||
code = row['relevant_code_context']
|
||||
if len(code) > max_tokens*10:
|
||||
logger.warning(f"Skipping example {i} because it exceeds {max_tokens*10} characters ({len(code)}/{max_tokens*10})")
|
||||
skipped_count_long += 1
|
||||
continue
|
||||
tokens = tokenizer.encode(code, truncation=False)
|
||||
if len(tokens) > max_tokens:
|
||||
logger.warning(f"Skipping example {i} because it exceeds {max_tokens} tokens ({len(tokens)}/{max_tokens})")
|
||||
skipped_count_long += 1
|
||||
continue
|
||||
if len(tokens) < min_tokens:
|
||||
logger.warning(f"Skipping example {i} because it is too short ({len(tokens)}/{min_tokens})")
|
||||
skipped_count_short += 1
|
||||
continue
|
||||
filtered_indices.append(i)
|
||||
|
||||
if skipped_count_long > 0:
|
||||
logger.info(f"Skipped {skipped_count_long} examples that exceeded token limit of {max_tokens}")
|
||||
if skipped_count_short > 0:
|
||||
logger.info(f"Skipped {skipped_count_short} examples that are too short ({min_tokens} tokens)")
|
||||
|
||||
dataset = dataset.select(filtered_indices)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_actual_token_lengths(dataset, tokenizer, output_path="./analysis"):
|
||||
"""
|
||||
Calculate actual token lengths using the specified tokenizer.
|
||||
|
||||
Args:
|
||||
dataset: The dataset containing code samples
|
||||
tokenizer: The tokenizer to use for counting tokens
|
||||
output_path: Path to save analysis results and plots
|
||||
|
||||
Returns:
|
||||
Dict with statistics about token lengths
|
||||
"""
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
# Extract actual token counts
|
||||
token_lengths = []
|
||||
|
||||
# Print intent for each file
|
||||
logger.info("\nIntent for each example:")
|
||||
logger.info("======================")
|
||||
|
||||
for i, row in enumerate(tqdm(dataset, desc="Calculating token lengths")):
|
||||
code = row['relevant_code_context']
|
||||
tokens = tokenizer.encode(code, truncation=False) if hasattr(tokenizer, 'encode') else []
|
||||
token_len = len(tokens)
|
||||
token_lengths.append(token_len)
|
||||
|
||||
# Print the intent for each file
|
||||
docfile_name = row.get('docfile_name', f'file_{i}')
|
||||
intent = row.get('intent', 'unknown')
|
||||
logger.info(f" Example {i}: {docfile_name} - Intent: {intent} - Token Length: {token_len}")
|
||||
|
||||
# Calculate statistics
|
||||
stats = {
|
||||
'min': min(token_lengths),
|
||||
'max': max(token_lengths),
|
||||
'mean': np.mean(token_lengths),
|
||||
'median': np.median(token_lengths),
|
||||
'p90': np.percentile(token_lengths, 90),
|
||||
'p95': np.percentile(token_lengths, 95),
|
||||
'p99': np.percentile(token_lengths, 99),
|
||||
}
|
||||
|
||||
# Plot token length histogram
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.hist(token_lengths, bins=50, alpha=0.7)
|
||||
plt.axvline(stats['mean'], color='red', linestyle='dashed', linewidth=1, label=f"Mean: {stats['mean']:.0f}")
|
||||
plt.axvline(stats['median'], color='green', linestyle='dashed', linewidth=1, label=f"Median: {stats['median']:.0f}")
|
||||
plt.axvline(stats['p90'], color='orange', linestyle='dashed', linewidth=1, label=f"90th %: {stats['p90']:.0f}")
|
||||
plt.axvline(stats['p95'], color='purple', linestyle='dashed', linewidth=1, label=f"95th %: {stats['p95']:.0f}")
|
||||
plt.title('Actual Code Length Distribution (Tokens)')
|
||||
plt.xlabel('Tokens')
|
||||
plt.ylabel('Count')
|
||||
plt.legend()
|
||||
plt.savefig(os.path.join(output_path, 'code_length_actual_tokens.png'))
|
||||
|
||||
# Save statistics to a text file
|
||||
with open(os.path.join(output_path, 'token_length_stats.txt'), 'w') as f:
|
||||
f.write("Token Length Statistics\n")
|
||||
f.write("=====================\n\n")
|
||||
|
||||
for key, value in stats.items():
|
||||
f.write(f" {key}: {value:.2f}\n")
|
||||
|
||||
# Return the statistics for further use
|
||||
return stats
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dataset = load_dataset_samples(dataset_name="JetBrains-Research/lca-module-summarization", split="test", max_examples=1000)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
|
||||
token_stats = get_actual_token_lengths(dataset, tokenizer, "./analysis")
|
||||
|
||||
# Print summary of findings using logger
|
||||
logger.info("\nSummary of Code Length Analysis:")
|
||||
logger.info("================================")
|
||||
logger.info(f"Number of examples analyzed: {len(dataset)}")
|
||||
|
||||
logger.info("\nActual token-based statistics:")
|
||||
logger.info(f" Mean length: {token_stats['mean']:.0f} tokens")
|
||||
logger.info(f" Median length: {token_stats['median']:.0f} tokens")
|
||||
logger.info(f" 95th percentile: {token_stats['p95']:.0f} tokens")
|
||||
8
repoqa/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
try:
|
||||
from repoqa._version import __version__, __version_tuple__
|
||||
except ImportError:
|
||||
__version__ = "local-dev"
|
||||
1544
repoqa/code_compressor.py
Normal file
File diff suppressed because it is too large
349
repoqa/code_segment_extractor.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
from loguru import logger
|
||||
import json
|
||||
import os
|
||||
|
||||
def extract_code_segments(code: str, language: str = "python") -> List[Dict]:
|
||||
"""
|
||||
Break down code into a hierarchical structure based on language-specific patterns.
|
||||
Supports Python, C++, Java, TypeScript, Rust, and Go.
|
||||
|
||||
Args:
|
||||
code: Original code string
|
||||
language: Programming language of the code (python, cpp, java, typescript, rust, go)
|
||||
|
||||
Returns:
|
||||
List of code segments, each containing type, content, position, etc.
|
||||
"""
|
||||
language = language.lower()
|
||||
|
||||
# Language-specific patterns
|
||||
patterns = {
|
||||
"python": {
|
||||
"class": r"^class\s+(\w+)",
|
||||
"function": r"^def\s+(\w+)",
|
||||
"import": r"^(import|from)\s+",
|
||||
"comment": r"^#",
|
||||
"docstring": r'^("""|\'\'\')',
|
||||
"docstring_end": r'("""|\'\'\')$',
|
||||
"indent": lambda line: len(line) - len(line.lstrip()),
|
||||
"block_start": lambda line: line.rstrip().endswith(":"),
|
||||
"block_end": lambda line, indent: len(line) - len(line.lstrip()) <= indent
|
||||
},
|
||||
"cpp": {
|
||||
"class": r"^(class|struct)\s+(\w+)",
|
||||
"function": r"^(void|int|bool|string|char|float|double|auto|template\s*<.*>)\s+(\w+)",
|
||||
"import": r"^#include\s+",
|
||||
"comment": r"^//|^/\*",
|
||||
"docstring": r"^/\*\*",
|
||||
"docstring_end": r"\*/$",
|
||||
"indent": lambda line: len(line) - len(line.lstrip()),
|
||||
"block_start": lambda line: line.rstrip().endswith("{"),
|
||||
"block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent
|
||||
},
|
||||
"java": {
|
||||
"class": r"^(public|private|protected)?\s*(class|interface)\s+(\w+)",
|
||||
"function": r"^(public|private|protected)?\s*(void|int|boolean|String|char|float|double)\s+(\w+)",
|
||||
"import": r"^import\s+",
|
||||
"comment": r"^//|^/\*",
|
||||
"docstring": r"^/\*\*",
|
||||
"docstring_end": r"\*/$",
|
||||
"indent": lambda line: len(line) - len(line.lstrip()),
|
||||
"block_start": lambda line: line.rstrip().endswith("{"),
|
||||
"block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent
|
||||
},
|
||||
"typescript": {
|
||||
"class": r"^(export\s+)?(class|interface)\s+(\w+)",
|
||||
"function": r"^(export\s+)?(function|const|let|var)\s+(\w+)\s*=",
|
||||
"import": r"^import\s+",
|
||||
"comment": r"^//|^/\*",
|
||||
"docstring": r"^/\*\*",
|
||||
"docstring_end": r"\*/$",
|
||||
"indent": lambda line: len(line) - len(line.lstrip()),
|
||||
"block_start": lambda line: line.rstrip().endswith("{"),
|
||||
"block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent
|
||||
},
|
||||
"rust": {
|
||||
"class": r"^(pub\s+)?(struct|enum|trait)\s+(\w+)",
|
||||
"function": r"^(pub\s+)?fn\s+(\w+)",
|
||||
"import": r"^use\s+",
|
||||
"comment": r"^//|^/\*",
|
||||
"docstring": r"^//!|^/\*\*",
|
||||
"docstring_end": r"\*/$",
|
||||
"indent": lambda line: len(line) - len(line.lstrip()),
|
||||
"block_start": lambda line: line.rstrip().endswith("{"),
|
||||
"block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent
|
||||
},
|
||||
"go": {
|
||||
"class": r"^type\s+(\w+)\s+(struct|interface)",
|
||||
"function": r"^func\s+(\w+)",
|
||||
"import": r"^import\s+",
|
||||
"comment": r"^//",
|
||||
"docstring": r"^//",
|
||||
"docstring_end": None, # Go doesn't have multi-line docstrings
|
||||
"indent": lambda line: len(line) - len(line.lstrip()),
|
||||
"block_start": lambda line: line.rstrip().endswith("{"),
|
||||
"block_end": lambda line, indent: line.rstrip() == "}" and len(line) - len(line.lstrip()) <= indent
|
||||
}
|
||||
}
|
||||
|
||||
if language not in patterns:
|
||||
raise ValueError(f"Unsupported language: {language}. Supported languages: {', '.join(patterns.keys())}")
|
||||
|
||||
def get_token_length(text: str) -> int:
|
||||
"""Simple approximation of token length by splitting by whitespace"""
|
||||
if not text:
|
||||
return 0
|
||||
return len(text.split())
|
||||
|
||||
lines = code.split('\n')
|
||||
segments = []
|
||||
lang_patterns = patterns[language]
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
indent_level = lang_patterns["indent"](lines[i])
|
||||
|
||||
# Skip empty lines
|
||||
if not line:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Process class/struct/enum/trait definitions
|
||||
class_match = re.match(lang_patterns["class"], line)
|
||||
if class_match and indent_level == 0:
|
||||
class_start = i
|
||||
class_name = class_match.group(1) if language == "python" else class_match.group(2)
|
||||
class_indent = indent_level
|
||||
|
||||
# Save class header (signature and docstring) separately
|
||||
class_header_start = i
|
||||
|
||||
# Skip to class body
|
||||
i += 1
|
||||
|
||||
# Skip whitespace and comments to find the start of class body
|
||||
while i < len(lines) and (not lines[i].strip() or re.match(lang_patterns["comment"], lines[i].strip())):
|
||||
i += 1
|
||||
|
||||
# Check for docstring
|
||||
if i < len(lines) and lang_patterns["docstring"] and re.match(lang_patterns["docstring"], lines[i].strip()):
|
||||
docstring_start = i
|
||||
i += 1
|
||||
# Find the end of the docstring
|
||||
while i < len(lines):
|
||||
if lang_patterns["docstring_end"] and re.search(lang_patterns["docstring_end"], lines[i]):
|
||||
i += 1
|
||||
break
|
||||
i += 1
|
||||
|
||||
class_header_end = i
|
||||
class_header_code = '\n'.join(lines[class_header_start:class_header_end])
|
||||
|
||||
# Continue processing the rest of the class body
|
||||
class_body_start = i
|
||||
|
||||
# Extract methods/functions within the class
|
||||
while i < len(lines):
|
||||
if i >= len(lines) or (lines[i].strip() and lang_patterns["indent"](lines[i]) <= class_indent):
|
||||
break
|
||||
|
||||
line = lines[i].strip()
|
||||
current_indent = lang_patterns["indent"](lines[i])
|
||||
|
||||
# Check for method/function definition
|
||||
method_indent = class_indent + (4 if language == "python" else 2)
|
||||
if re.match(lang_patterns["function"], line) and current_indent == method_indent:
|
||||
method_start = i
|
||||
method_name = re.match(lang_patterns["function"], line).group(1)
|
||||
|
||||
# Find where method ends
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
if i < len(lines) and lines[i].strip() and lang_patterns["indent"](lines[i]) <= current_indent:
|
||||
break
|
||||
i += 1
|
||||
|
||||
method_end = i
|
||||
method_code = '\n'.join(lines[method_start:method_end])
|
||||
|
||||
segments.append({
|
||||
"type": "method",
|
||||
"name": method_name,
|
||||
"class_name": class_name,
|
||||
"start_line": method_start,
|
||||
"end_line": method_end,
|
||||
"code": method_code,
|
||||
"token_length": get_token_length(method_code),
|
||||
"indent_level": current_indent
|
||||
})
|
||||
|
||||
continue
|
||||
else:
|
||||
# Process non-method code (class attributes, etc.)
|
||||
i += 1
|
||||
|
||||
class_end = i
|
||||
class_code = '\n'.join(lines[class_start:class_end])
|
||||
|
||||
# Add the class header segment
|
||||
segments.append({
|
||||
"type": "class_header",
|
||||
"name": class_name,
|
||||
"start_line": class_header_start,
|
||||
"end_line": class_header_end,
|
||||
"code": class_header_code,
|
||||
"token_length": get_token_length(class_header_code),
|
||||
"indent_level": class_indent
|
||||
})
|
||||
|
||||
continue
|
||||
|
||||
# Process function definitions
|
||||
func_match = re.match(lang_patterns["function"], line)
|
||||
if func_match and indent_level == 0:
|
||||
func_start = i
|
||||
func_name = func_match.group(1)
|
||||
func_indent = indent_level
|
||||
|
||||
# Find the end of the function
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
current_line = lines[i].strip()
|
||||
current_indent = lang_patterns["indent"](lines[i])
|
||||
|
||||
# If we hit another function or class at same or higher level, stop
|
||||
if (re.match(lang_patterns["function"], current_line) or re.match(lang_patterns["class"], current_line)) and current_indent <= func_indent:
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
func_end = i
|
||||
func_code = '\n'.join(lines[func_start:func_end])
|
||||
|
||||
segments.append({
|
||||
"type": "function",
|
||||
"name": func_name,
|
||||
"start_line": func_start,
|
||||
"end_line": func_end,
|
||||
"code": func_code,
|
||||
"token_length": get_token_length(func_code),
|
||||
"indent_level": 0
|
||||
})
|
||||
|
||||
continue
|
||||
|
||||
# Process imports
|
||||
if re.match(lang_patterns["import"], line) and indent_level == 0:
|
||||
import_start = i
|
||||
|
||||
# Check if import statement spans multiple lines
|
||||
while i + 1 < len(lines) and (re.match(lang_patterns["import"], lines[i+1].strip()) or
|
||||
lines[i+1].lstrip().startswith('\\')):
|
||||
i += 1
|
||||
|
||||
import_end = i + 1
|
||||
import_code = '\n'.join(lines[import_start:import_end])
|
||||
|
||||
segments.append({
|
||||
"type": "import",
|
||||
"start_line": import_start,
|
||||
"end_line": import_end,
|
||||
"code": import_code,
|
||||
"token_length": get_token_length(import_code),
|
||||
"indent_level": 0
|
||||
})
|
||||
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Other top-level statements
|
||||
elif indent_level == 0:
|
||||
stmt_start = i
|
||||
|
||||
# Find the end of the statement
|
||||
i += 1
|
||||
while i < len(lines) and (not lines[i].strip() or lang_patterns["indent"](lines[i]) > 0):
|
||||
i += 1
|
||||
|
||||
stmt_end = i
|
||||
stmt_code = '\n'.join(lines[stmt_start:stmt_end])
|
||||
|
||||
segments.append({
|
||||
"type": "statement",
|
||||
"start_line": stmt_start,
|
||||
"end_line": stmt_end,
|
||||
"code": stmt_code,
|
||||
"token_length": get_token_length(stmt_code),
|
||||
"indent_level": 0
|
||||
})
|
||||
|
||||
continue
|
||||
|
||||
# If nothing matched, move to next line
|
||||
i += 1
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Example Python code
|
||||
python_code = """
|
||||
import os
|
||||
import sys
|
||||
|
||||
class MyClass:
|
||||
\"\"\"This is a docstring.\"\"\"
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def my_method(self):
|
||||
print(f"Hello, {self.name}!")
|
||||
|
||||
def my_function():
|
||||
return "Hello, world!"
|
||||
|
||||
# This is a comment
|
||||
x = 10
|
||||
y = 20
|
||||
z = x + y
|
||||
"""
|
||||
|
||||
# Example C++ code
|
||||
cpp_code = """
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
class MyClass {
|
||||
public:
|
||||
MyClass(const std::string& name) : name_(name) {}
|
||||
|
||||
void myMethod() {
|
||||
std::cout << "Hello, " << name_ << "!" << std::endl;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
int myFunction() {
|
||||
return 42;
|
||||
}
|
||||
|
||||
// This is a comment
|
||||
int x = 10;
|
||||
int y = 20;
|
||||
int z = x + y;
|
||||
"""
|
||||
|
||||
# Test with Python
|
||||
python_segments = extract_code_segments(python_code, language="python")
|
||||
print(f"Python segments: {len(python_segments)}")
|
||||
|
||||
# Test with C++
|
||||
cpp_segments = extract_code_segments(cpp_code, language="cpp")
|
||||
print(f"C++ segments: {len(cpp_segments)}")
|
||||
426
repoqa/compute_score.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tempdir
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from transformers import AutoConfig
|
||||
from tree_sitter_languages import get_language, get_parser
|
||||
|
||||
from repoqa.data import get_repoqa_data
|
||||
from repoqa.metric import compute_function_similarity
|
||||
from repoqa.utility import COMMENT_QUERY, FUNCTION_QUERY, progress
|
||||
|
||||
LANGUAGES = list(FUNCTION_QUERY.keys())
|
||||
THRESHOLDS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
|
||||
|
||||
class Result(Enum):
|
||||
BEST_MATCH = "best_match"
|
||||
FAIL_MATCH = "fail_match"
|
||||
|
||||
|
||||
# unbiased estimator from https://github.com/openai/human-eval
|
||||
def estimate_pass_at_k(
|
||||
num_samples: Union[int, List[int], np.ndarray],
|
||||
num_correct: Union[List[int], np.ndarray],
|
||||
k: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Estimates pass@k of each problem and returns them in an array.
|
||||
"""
|
||||
|
||||
def estimator(n: int, c: int, k: int) -> float:
|
||||
"""
|
||||
Calculates 1 - comb(n - c, k) / comb(n, k).
|
||||
"""
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
|
||||
|
||||
if isinstance(num_samples, int):
|
||||
num_samples_it = itertools.repeat(num_samples, len(num_correct))
|
||||
else:
|
||||
assert len(num_samples) == len(num_correct)
|
||||
num_samples_it = iter(num_samples)
|
||||
|
||||
return np.array(
|
||||
[estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
|
||||
)
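
# Worked example (illustrative): with n=5 samples per task, pass@1 reduces to c/n,
# so estimate_pass_at_k(5, [2, 0, 5], 1) -> array([0.4, 0.0, 1.0]). For k > 1 the
# product form above evaluates 1 - C(n-c, k) / C(n, k) without forming large binomials.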
|
||||
|
||||
|
||||
def remove_comments(source_code: str, lang: str) -> str:
|
||||
source_bytes = bytes(source_code, "utf8")
|
||||
parser = get_parser(lang)
|
||||
tree = parser.parse(source_bytes)
|
||||
root_node = tree.root_node
|
||||
|
||||
# Remove comments from source code
|
||||
capture_list = []
|
||||
for query_str in COMMENT_QUERY[lang]:
|
||||
comment_query = get_language(lang).query(query_str)
|
||||
capture_list += comment_query.captures(root_node)
|
||||
|
||||
capture_list.sort(key=lambda cap: cap[0].start_byte, reverse=True)
|
||||
|
||||
for node, _ in capture_list:
|
||||
source_bytes = source_bytes[: node.start_byte] + source_bytes[node.end_byte :]
|
||||
|
||||
return source_bytes.decode("utf-8")
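
# Usage sketch (illustrative): strip comments before computing similarity so that
# commentary does not skew the match score.
# cleaned = remove_comments("def f():\n    # note\n    return 1\n", "python")
# # -> "def f():\n    \n    return 1\n"  (comment text removed, code left intact)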
|
||||
|
||||
|
||||
def sanitize_output(model_output: str, lang: str) -> str:
|
||||
model_output = model_output.strip()
|
||||
search_pattern = r"^```(?:\w+)?\s*\n(.*?)(?=^```)```"
|
||||
code_blocks = re.findall(search_pattern, model_output, re.DOTALL | re.MULTILINE)
|
||||
|
||||
parser = get_parser(lang)
|
||||
fn_query = get_language(lang).query(FUNCTION_QUERY[lang])
|
||||
|
||||
# If no code blocks are found, simply return the model output
|
||||
if not code_blocks:
|
||||
return model_output
|
||||
|
||||
processed_blocks = []
|
||||
for block in code_blocks:
|
||||
processed_blocks.append(block)
|
||||
|
||||
# Try to use tree-sitter to parse if possible
|
||||
try:
|
||||
block_bytes = bytes(block, "utf8")
|
||||
tree = parser.parse(block_bytes)
|
||||
for capture in fn_query.captures(tree.root_node):
|
||||
node, _ = capture
|
||||
function_content = block_bytes[node.start_byte : node.end_byte]
|
||||
return function_content.decode("utf8")
|
||||
except:
|
||||
pass
|
||||
|
||||
# No valid function found by the tree-sitter approach; fall back to the first block
|
||||
return processed_blocks[0]
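
# Illustrative example: the regex pulls fenced blocks out of a chat-style answer and
# tree-sitter then narrows the result to the function definition inside the block.
# answer = "Sure, here it is:\n```python\ndef add(a, b):\n    return a + b\n```\nHope this helps!"
# sanitize_output(answer, "python")  # -> "def add(a, b):\n    return a + b"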
|
||||
|
||||
|
||||
def print_result_table(model_name, pass_results):
|
||||
# Printing scores in a table
|
||||
table = Table(title=f"Scores (%) of {model_name} at different thresholds")
|
||||
table.add_column("Threshold", justify="center", style="bold magenta")
|
||||
for threshold in THRESHOLDS:
|
||||
table.add_column(f"{threshold}", justify="center")
|
||||
|
||||
# Prepare data to determine the maximum values for each threshold
|
||||
threshold_scores = {threshold: [] for threshold in THRESHOLDS}
|
||||
for lang_results in pass_results.values():
|
||||
for thresh, value in lang_results.items():
|
||||
try:
|
||||
threshold_scores[eval(thresh)].append(value["pass@1"])
|
||||
except:
|
||||
threshold_scores[thresh].append(value["pass@1"])
|
||||
|
||||
# Calculate the maximum score for each threshold
|
||||
max_scores = {
|
||||
threshold: max(scores) for threshold, scores in threshold_scores.items()
|
||||
}
|
||||
min_scores = {
|
||||
threshold: min(scores) for threshold, scores in threshold_scores.items()
|
||||
}
|
||||
|
||||
# Fill the table rows
|
||||
for language, lang_results in pass_results.items():
|
||||
row = [("⭐" if language == "all" else "") + language]
|
||||
for threshold, value in lang_results.items():
|
||||
score = value["pass@1"]
|
||||
formatted_score = f"{100 * score:.1f}"
|
||||
try:
|
||||
if max_scores[eval(threshold)] - score < 0.01:
|
||||
formatted_score = f"[bold green]{formatted_score}[/]"
|
||||
elif score - min_scores[eval(threshold)] < 0.01:
|
||||
formatted_score = f"[bold red]{formatted_score}[/]"
|
||||
except:
|
||||
if max_scores[threshold] - score < 0.01:
|
||||
formatted_score = f"[bold green]{formatted_score}[/]"
|
||||
elif score - min_scores[threshold] < 0.01:
|
||||
formatted_score = f"[bold red]{formatted_score}[/]"
|
||||
row.append(formatted_score)
|
||||
if language == "all":
|
||||
row = [f"[bold yellow]{r}[/]" for r in row]
|
||||
table.add_row(*row)
|
||||
|
||||
Console(width=120).print(table)
|
||||
|
||||
def needle_evaluator(
|
||||
model_output: str,
|
||||
ground_truth: str,
|
||||
repo_info: Dict,
|
||||
lang: str,
|
||||
ignore_comments: bool,
|
||||
) -> Tuple[Result, str, float]:
|
||||
contents = repo_info["content"]
|
||||
needles = repo_info["needles"]
|
||||
|
||||
best_target = None
|
||||
best_similarity = 0
|
||||
sanitized_output = sanitize_output(model_output, lang)
|
||||
if ignore_comments:
|
||||
sanitized_output = remove_comments(sanitized_output, lang)
|
||||
for needle in needles:
|
||||
current_path = needle["path"]
|
||||
current_name = needle["name"]
|
||||
current_func = "\n".join(
|
||||
contents[current_path].split("\n")[
|
||||
needle["start_line"] : needle["end_line"]
|
||||
]
|
||||
)
|
||||
if ignore_comments:
|
||||
current_func = remove_comments(current_func, lang)
|
||||
|
||||
current_similarity = compute_function_similarity(sanitized_output, current_func)
|
||||
if current_similarity > best_similarity:
|
||||
best_similarity = current_similarity
|
||||
best_target = current_name
|
||||
|
||||
if best_target == ground_truth:
|
||||
verdict = Result.BEST_MATCH
|
||||
else:
|
||||
verdict = Result.FAIL_MATCH
|
||||
return verdict, best_target, best_similarity
|
||||
|
||||
|
||||
def _get_repo(lang_data: Dict, repo_name: str) -> Dict:
|
||||
for repo in lang_data:
|
||||
if repo["repo"] == repo_name:
|
||||
return repo
|
||||
|
||||
|
||||
def compute_language_results(evaluation_result: Dict, all_results: Dict) -> None:
|
||||
for language, lang_results in evaluation_result.items():
|
||||
current_result = {}
|
||||
total = np.array([1 for _ in lang_results])
|
||||
|
||||
for threshold in THRESHOLDS:
|
||||
correct_result = []
|
||||
for res in lang_results:
|
||||
bc = 0
|
||||
if res["is_best_similar"] and res["best_similar_score"] >= threshold:
|
||||
bc = 1
|
||||
correct_result.append(bc)
|
||||
correct_result = np.array(correct_result)
|
||||
|
||||
pass_at_k = {
|
||||
f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean()
|
||||
for k in [1, 10, 100]
|
||||
if total.min() >= k
|
||||
}
|
||||
current_result[threshold] = pass_at_k
|
||||
all_results[language] = current_result
|
||||
|
||||
|
||||
def fetch_hf_context(model_name: str) -> str:
|
||||
# Retrieved from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1073
|
||||
possible_keys = [
|
||||
# OPT
|
||||
"max_position_embeddings",
|
||||
# GPT-2
|
||||
"n_positions",
|
||||
# MPT
|
||||
"max_seq_len",
|
||||
# ChatGLM2
|
||||
"seq_length",
|
||||
# Command-R
|
||||
"model_max_length",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
try:
|
||||
with tempdir.TempDir() as temp_dir:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=temp_dir,
|
||||
force_download=True,
|
||||
trust_remote_code=True,
|
||||
).to_dict()
|
||||
longest_context = 0
|
||||
for key in possible_keys:
|
||||
if key in config:
|
||||
longest_context = max(config[key], longest_context)
|
||||
if not (longest_context):
|
||||
return "N/A"
|
||||
return str(int(longest_context / 1024)) + "k"
|
||||
except Exception as err:
|
||||
print(f"fetching failed... Reason:\n{err}")
|
||||
return "N/A"
|
||||
|
||||
|
||||
def compute_score(
|
||||
model_name: str, dataset: Dict, model_output: List[Dict], ignore_comments: bool, result_dir: str
|
||||
) -> Dict:
|
||||
evaluation_result = defaultdict(list)
|
||||
|
||||
# if the scores already exist, load them, print the scores and exit
|
||||
try:
|
||||
if os.path.exists(result_dir):
|
||||
with open(result_dir, "r") as f:
|
||||
output_json = json.load(f)
|
||||
print_result_table(model_name, output_json[model_name]["scores"])
|
||||
return output_json
|
||||
except Exception as e:
|
||||
print(f"Error loading scores from {result_dir}: {e}, continuing...")
|
||||
|
||||
with progress(f"Scoring {model_name}") as pbar:
|
||||
for result in pbar.track(model_output):
|
||||
lang = result["language"]
|
||||
repo_name = result["repo"]
|
||||
model_outputs = result["output"]
|
||||
ground_truth = result["name"]
|
||||
repo_info = _get_repo(dataset[lang], repo_name)
|
||||
|
||||
model_output = model_outputs[0]
|
||||
verdict, best_target, best_similarity = needle_evaluator(
|
||||
model_output, ground_truth, repo_info, lang, ignore_comments
|
||||
)
|
||||
|
||||
is_best_similar = False
|
||||
if verdict == Result.BEST_MATCH:
|
||||
is_best_similar = True
|
||||
|
||||
current_task = {
|
||||
"repo": repo_name,
|
||||
"name": ground_truth,
|
||||
"needle_position": result["position_ratio"],
|
||||
"is_best_similar": is_best_similar,
|
||||
"best_similar_score": best_similarity,
|
||||
"best_target": best_target,
|
||||
"position": {
|
||||
"token_start": result["needle_token_start"],
|
||||
"token_end": result["needle_token_end"],
|
||||
},
|
||||
}
|
||||
evaluation_result[lang].append(current_task)
|
||||
|
||||
# Calculate pass@k
|
||||
pass_results = {}
|
||||
|
||||
all_langs = []
|
||||
for lang in evaluation_result:
|
||||
all_langs += evaluation_result[lang]
|
||||
total = np.array([1 for _ in all_langs])
|
||||
|
||||
pass_results["all"] = {}
|
||||
for threshold in THRESHOLDS:
|
||||
correct_result = []
|
||||
for res in all_langs:
|
||||
bc = 0
|
||||
if res["is_best_similar"] and res["best_similar_score"] >= threshold:
|
||||
bc = 1
|
||||
correct_result.append(bc)
|
||||
correct_result = np.array(correct_result)
|
||||
pass_at_k = {
|
||||
f"pass@{k}": estimate_pass_at_k(total, correct_result, k).mean()
|
||||
for k in [1, 10, 100]
|
||||
if total.min() >= k
|
||||
}
|
||||
pass_results["all"][threshold] = pass_at_k
|
||||
|
||||
compute_language_results(evaluation_result, pass_results)
|
||||
print_result_table(model_name, pass_results)
|
||||
|
||||
output_json = {}
|
||||
model_json = {}
|
||||
model_json["eval_date"] = str(datetime.now())
|
||||
|
||||
# hardcode paid models
|
||||
if "/" in model_name:
|
||||
if model_name.startswith("bigcode/starcoder2"):
|
||||
train_context = "16k"
|
||||
else:
|
||||
train_context = fetch_hf_context(model_name)
|
||||
elif model_name.startswith("gpt-4-turbo") or model_name.startswith("gpt-4o-"):
|
||||
train_context = "128k"
|
||||
elif model_name.startswith("gpt-3.5-"):
|
||||
train_context = "16k"
|
||||
elif model_name.startswith("gemini-1.5-pro") or model_name.startswith(
|
||||
"gemini-1.5-flash"
|
||||
):
|
||||
train_context = "1000k"
|
||||
elif model_name.startswith("gemini-1.0-pro"):
|
||||
train_context = "32k"
|
||||
elif model_name.startswith("claude-3-"):
|
||||
train_context = "200k"
|
||||
else:
|
||||
train_context = "N/A"
|
||||
model_json["train_size"] = train_context
|
||||
model_json["scores"] = pass_results
|
||||
model_json["results"] = evaluation_result
|
||||
|
||||
output_json[model_name] = model_json
|
||||
|
||||
return output_json
|
||||
|
||||
|
||||
def get_model_name(output_path: str) -> str:
|
||||
file_name = Path(output_path).stem
|
||||
segments = file_name.split("_")
|
||||
output_name = ""
|
||||
for segment in segments:
|
||||
if segment == "slash":
|
||||
output_name += "/"
|
||||
else:
|
||||
output_name += segment
|
||||
return output_name
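
# Illustrative mapping (the file name is hypothetical): a "slash" segment in the
# output file stem is folded back into "/" when reconstructing the model name;
# note that the other underscores in the stem are dropped as well.
# get_model_name("results/Qwen_slash_Qwen2.5-Coder-7B-Instruct.jsonl")
# # -> "Qwen/Qwen2.5-Coder-7B-Instruct"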
|
||||
|
||||
|
||||
def save_json(output_json, result_path) -> None:
|
||||
if os.path.isfile(result_path):
|
||||
decision = ""
|
||||
while decision.lower() not in ["y", "n"]:
|
||||
print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
|
||||
# decision = input()
|
||||
decision = "y"
|
||||
|
||||
if not os.path.isfile(result_path):
|
||||
with open(result_path, "w") as f:
|
||||
json.dump(output_json, f)
|
||||
|
||||
|
||||
def compute_main(
|
||||
model_output_path: str, ignore_comments: bool = False, dataset_path: str = None
|
||||
):
|
||||
if dataset_path is None:
|
||||
dataset = get_repoqa_data()
|
||||
else:
|
||||
with open(dataset_path, "r") as dataset_f:
|
||||
dataset = json.load(dataset_f)
|
||||
|
||||
model_outputs = []
|
||||
with open(model_output_path, "r") as output_f:
|
||||
for line in output_f:
|
||||
model_outputs.append(json.loads(line))
|
||||
|
||||
file_base, _ = os.path.splitext(model_output_path)
|
||||
result_path = file_base + "-SCORES.json"
|
||||
model_name = get_model_name(model_output_path)
|
||||
output_json = compute_score(model_name, dataset, model_outputs, ignore_comments, result_path)
|
||||
save_json(output_json, result_path)
|
||||
|
||||
|
||||
def main():
|
||||
from fire import Fire
|
||||
|
||||
Fire(compute_main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
50
repoqa/data.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
|
||||
import tempdir
|
||||
import wget
|
||||
from appdirs import user_cache_dir
|
||||
|
||||
CACHE_DIR = user_cache_dir("repoqa")
|
||||
|
||||
REPOQA_DATA_OVERRIDE_PATH = os.getenv("REPOQA_DATA_OVERRIDE_PATH", None)
|
||||
REPOQA_DATA_VERSION = os.getenv("REPOQA_DATA_VERSION", "2024-06-23")
|
||||
|
||||
|
||||
def _get_repoqa_data_ready_path() -> str:
|
||||
if REPOQA_DATA_OVERRIDE_PATH:
|
||||
assert os.path.exists(
|
||||
REPOQA_DATA_OVERRIDE_PATH
|
||||
), f"File not found: {REPOQA_DATA_OVERRIDE_PATH}"
|
||||
return REPOQA_DATA_OVERRIDE_PATH
|
||||
|
||||
gzip_url = f"https://github.com/evalplus/repoqa_release/releases/download/{REPOQA_DATA_VERSION}/repoqa-{REPOQA_DATA_VERSION}.json.gz"
|
||||
cache_path = os.path.join(CACHE_DIR, f"repoqa-{REPOQA_DATA_VERSION}.json")
|
||||
# Check if the RepoQA data file already exists in CACHE_DIR
if not os.path.exists(cache_path):
# Download the RepoQA dataset release and parse it as JSON
print(f"Downloading dataset from {gzip_url}")
with tempdir.TempDir() as tmpdir:
gzip_path = os.path.join(tmpdir, "data.json.gz")
|
||||
wget.download(gzip_url, gzip_path)
|
||||
|
||||
with gzip.open(gzip_path, "rb") as f:
|
||||
repoqa_data = f.read().decode("utf-8")
|
||||
|
||||
# create CACHE_DIR if not exists
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
# Write the downloaded RepoQA data to CACHE_DIR
|
||||
with open(cache_path, "w") as f:
|
||||
f.write(repoqa_data)
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def get_repoqa_data():
|
||||
with open(_get_repoqa_data_ready_path(), "r") as f:
|
||||
return json.load(f)
|
||||
784
repoqa/main.py
Normal file
@@ -0,0 +1,784 @@
|
||||
from repoqa.code_compressor import CodeCompressor
|
||||
from repoqa.mgcode_compressor import CodeCompressor as MGCodeCompressor
|
||||
from repoqa.utility import COMMENT_QUERY, progress
|
||||
from repoqa.data import CACHE_DIR, get_repoqa_data
|
||||
from repoqa.compute_score import compute_score, save_json
|
||||
from llmlingua import PromptCompressor
|
||||
from loguru import logger
|
||||
from tree_sitter_languages import get_language, get_parser
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import List, Tuple, Dict
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
import sys
|
||||
|
||||
class ChunkStrategy(Enum):
|
||||
FUNCTION_BASED = "function_based"
|
||||
SLIDING_WINDOW = "sliding_window"
|
||||
|
||||
|
||||
# Language-specific chunk markers
|
||||
CHUNK_MARKERS = {
|
||||
"python": ["class", "def"],
|
||||
"cpp": ["class", "struct", "void", "int", "bool", "double", "float", "char", "auto"],
|
||||
"java": ["class", "interface", "void", "int", "boolean", "double", "float", "char"],
|
||||
"typescript": ["class", "interface", "function", "const", "let", "var"],
|
||||
"rust": ["fn", "struct", "impl", "trait", "enum"],
|
||||
"go": ["func", "type", "struct", "interface"]
|
||||
}
|
||||
|
||||
# all languages
|
||||
# ALL_LANGUAGES = ["python", "cpp", "java", "typescript", "rust", "go"]
|
||||
|
||||
# Model context template
|
||||
TEMPLATE = "instruction\ncode_context\ndescription\ninstruction"
|
||||
|
||||
INSTRUCTION = (
|
||||
"Based on the function description and code context,"
|
||||
" please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeChunk:
|
||||
"""Represents a chunk of code with its embedding"""
|
||||
content: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
embedding: torch.Tensor = None
|
||||
|
||||
|
||||
class CodeChunker:
|
||||
def __init__(self, language: str, strategy: ChunkStrategy = ChunkStrategy.FUNCTION_BASED,
|
||||
window_size: int = 20, overlap_size: int = 10):
|
||||
self.language = language
|
||||
self.parser = get_parser(language)
|
||||
self.strategy = strategy
|
||||
self.window_size = window_size
|
||||
self.overlap_size = overlap_size
|
||||
|
||||
def _is_function_or_class_start(self, line: str) -> bool:
|
||||
"""Check if line starts a new function or class definition"""
|
||||
line = line.strip()
|
||||
return any(line.startswith(marker) for marker in CHUNK_MARKERS[self.language])
|
||||
|
||||
def _chunk_by_function(self, lines: List[str]) -> List[CodeChunk]:
|
||||
"""Split code into chunks based on function/class definitions"""
|
||||
chunks = []
|
||||
current_chunk_lines = []
|
||||
current_start = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if self._is_function_or_class_start(line) and current_chunk_lines:
|
||||
# Store previous chunk
|
||||
chunk_content = '\n'.join(current_chunk_lines)
|
||||
chunks.append(CodeChunk(chunk_content, current_start, i-1))
|
||||
current_chunk_lines = []
|
||||
current_start = i
|
||||
current_chunk_lines.append(line)
|
||||
|
||||
# Add final chunk
|
||||
if current_chunk_lines:
|
||||
chunk_content = '\n'.join(current_chunk_lines)
|
||||
chunks.append(CodeChunk(chunk_content, current_start, len(lines)-1))
|
||||
|
||||
return chunks
|
||||
|
||||
def _chunk_by_sliding_window(self, lines: List[str]) -> List[CodeChunk]:
|
||||
"""Split code into chunks using sliding window approach"""
|
||||
chunks = []
|
||||
|
||||
# Handle case when code is shorter than window size
|
||||
if len(lines) <= self.window_size:
|
||||
return [CodeChunk('\n'.join(lines), 0, len(lines)-1)]
|
||||
|
||||
# Create overlapping chunks
|
||||
start = 0
|
||||
while start < len(lines):
|
||||
end = min(start + self.window_size, len(lines))
|
||||
chunk_content = '\n'.join(lines[start:end])
|
||||
chunks.append(CodeChunk(chunk_content, start, end-1))
|
||||
|
||||
# Move start position by (window_size - overlap_size)
|
||||
start += self.window_size - self.overlap_size
|
||||
|
||||
# If remaining lines are less than window_size, adjust start to include them in last chunk
|
||||
if len(lines) - start < self.window_size:
|
||||
if len(lines) - start > self.overlap_size: # Only if there's enough new content
|
||||
chunk_content = '\n'.join(lines[start:])
|
||||
chunks.append(CodeChunk(chunk_content, start, len(lines)-1))
|
||||
break
|
||||
|
||||
return chunks
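
# Worked example (illustrative): with window_size=20 and overlap_size=10, a 50-line
# file produces chunks covering lines 0-19, 10-29, 20-39 and 30-49; each new chunk
# starts window_size - overlap_size = 10 lines after the previous one, and a short
# tail is only emitted separately when it adds more than overlap_size new lines.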
|
||||
|
||||
def chunk_code(self, code: str) -> List[CodeChunk]:
|
||||
"""Split code into chunks based on selected strategy"""
|
||||
lines = code.split('\n')
|
||||
|
||||
if self.strategy == ChunkStrategy.FUNCTION_BASED:
|
||||
return self._chunk_by_function(lines)
|
||||
elif self.strategy == ChunkStrategy.SLIDING_WINDOW:
|
||||
return self._chunk_by_sliding_window(lines)
|
||||
else:
|
||||
raise ValueError(f"Unknown chunking strategy: {self.strategy}")
|
||||
|
||||
|
||||
class RAGCompressor:
|
||||
def __init__(self, model_name: str = "microsoft/unixcoder-base"):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModel.from_pretrained(model_name)
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model.to(self.device)
|
||||
|
||||
def compute_embeddings(self, chunks: List[CodeChunk]) -> List[CodeChunk]:
|
||||
"""Compute embeddings for code chunks"""
|
||||
for chunk in chunks:
|
||||
inputs = self.tokenizer(chunk.content, return_tensors="pt",
|
||||
truncation=True, max_length=512).to(self.device)
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
# Use mean pooling
|
||||
chunk.embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
|
||||
return chunks
|
||||
|
||||
def get_relevant_chunks(self,
|
||||
query_embedding: torch.Tensor,
|
||||
chunks: List[CodeChunk],
|
||||
top_k: int = 5) -> List[CodeChunk]:
|
||||
"""Get most relevant chunks based on cosine similarity"""
|
||||
similarities = []
|
||||
for chunk in chunks:
|
||||
if chunk.embedding is None:
|
||||
continue
|
||||
sim = torch.cosine_similarity(query_embedding, chunk.embedding, dim=0)
|
||||
similarities.append((sim.item(), chunk))
|
||||
|
||||
# Sort by similarity and take top k
|
||||
similarities.sort(key=lambda x: x[0], reverse=True)
|
||||
return [chunk for _, chunk in similarities[:top_k]]
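
# Usage sketch (illustrative; variable names are hypothetical): embed every chunk,
# embed the target function description the same way, then keep the top-k chunks.
# compressor = RAGCompressor()
# chunks = compressor.compute_embeddings(CodeChunker("python").chunk_code(code_context))
# best = compressor.get_relevant_chunks(query_embedding, chunks, top_k=5)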
|
||||
|
||||
|
||||
def compress_context(code_context: str,
                     target_function: str,
                     language: str,
                     rag_compressor: RAGCompressor,
                     chunker: CodeChunker) -> str:
    """Compress code context using RAG approach"""
    # Split into chunks
    chunks = chunker.chunk_code(code_context)

    # Get original token count
    original_tokens = len(rag_compressor.tokenizer.encode(code_context))

    # Log original context size
    logger.info(f"Original context: {code_context}")
    logger.info(f"Original token count: {original_tokens}")
    logger.info(f"Number of chunks: {len(chunks)}")

    # Compute embeddings for all chunks
    chunks = rag_compressor.compute_embeddings(chunks)

    # Get embedding for target function
    target_embedding = rag_compressor.model(
        **rag_compressor.tokenizer(target_function, return_tensors="pt",
                                   truncation=True, max_length=512).to(rag_compressor.device)
    ).last_hidden_state.mean(dim=1).squeeze()

    # Get most relevant chunks
    relevant_chunks = rag_compressor.get_relevant_chunks(target_embedding, chunks)

    # Combine relevant chunks
    compressed_context = "\n".join(chunk.content for chunk in relevant_chunks)

    # Get compressed token count
    compressed_tokens = len(rag_compressor.tokenizer.encode(compressed_context))

    # Log compression results
    logger.info(f"Compressed token count: {compressed_tokens}")
    logger.info(f"Token compression ratio: {compressed_tokens/original_tokens:.2%}")
    logger.info("Selected chunks:")
    for i, chunk in enumerate(relevant_chunks):
        logger.info(f"Chunk {i+1} (lines {chunk.start_line}-{chunk.end_line}):\n{chunk.content}\n")

    return compressed_context


def compress_context_llm_lingua(compressor: PromptCompressor,
                                code_context: str,
                                target_function: str,
                                language: str,
                                target_token: int = 1000) -> str:
    """Compress code context using LLMLingua approach"""
    # Get original token count using LLMLingua's tokenizer
    original_tokens = len(compressor.tokenizer.encode(code_context))

    # replace the "<|endoftext|>" in the code if there is any
    if "<|endoftext|>" in code_context:
        logger.warning(f"Removing <|endoftext|> in code context: {code_context}")
        code_context = code_context.replace("<|endoftext|>", "")

    # Compress the prompt
    logger.info(f"Compressing prompt with instruction: \n{INSTRUCTION}")
    logger.info(f"Code context: \n{code_context}")
    logger.info(f"Description: \n{target_function}")
    compressed = compressor.compress_prompt(
        code_context,
        instruction=INSTRUCTION,
        question=target_function + INSTRUCTION,
        target_token=target_token
    )

    compressed_prompt = compressed['compressed_prompt']
    logger.info(f"Compressed prompt: \n{compressed_prompt}")

    # Get compressed token count
    compressed_tokens = len(compressor.tokenizer.encode(compressed_prompt))

    # Log compression results
    logger.info(f"Original token count: {original_tokens}")
    logger.info(f"LLMLingua compressed token count: {compressed_tokens}")
    logger.info(f"Token compression ratio: {compressed_tokens/original_tokens:.2%}")

    return compressed_prompt


def compress_context_longllmlingua_chunks(compressor: PromptCompressor,
                                          code_context: str,
                                          target_function: str,
                                          language: str,
                                          target_token: int = 1000,
                                          chunk_size: int = 80,
                                          overlap: int = 40) -> str:
    """Compress code context using LongLLMLingua chunks approach"""
    # Get original token count using LLMLingua's tokenizer
    original_tokens = len(compressor.tokenizer.encode(code_context))

    # replace the "<|endoftext|>" in the code if there is any
    if "<|endoftext|>" in code_context:
        logger.warning(f"Removing <|endoftext|> in code context: {code_context}")
        code_context = code_context.replace("<|endoftext|>", "")

    # Split code into chunks for longllmlingua_chunks method
    lines = code_context.split('\n')
    chunks = []
    for i in range(0, len(lines), chunk_size - overlap):
        chunk = lines[i:i + chunk_size]
        if chunk:
            chunks.append('\n'.join(chunk))

    # Compress the prompt using chunks
    compressed = compressor.compress_prompt(
        chunks,
        instruction=INSTRUCTION,
        question=target_function + INSTRUCTION,
        target_token=target_token,
        rank_method="longllmlingua"
    )

    compressed_prompt = compressed['compressed_prompt']
    logger.info(f"Compressed prompt: \n{compressed_prompt}")

    # Get compressed token count
    compressed_tokens = len(compressor.tokenizer.encode(compressed_prompt))

    # Log compression results
    logger.info(f"Original token count: {original_tokens}")
    logger.info(f"LongLLMLingua chunks compressed token count: {compressed_tokens}")
    logger.info(f"Token compression ratio: {compressed_tokens/original_tokens:.2%}")

    return compressed_prompt


def compress_context_code_compressor(compressor: CodeCompressor,
                                     code_context: str,
                                     target_function: str,
                                     language: str,
                                     target_ratio: float = 0.5,
                                     ppl_strategy: str = "default",
                                     condition_in_question: str = "default",
                                     rank_only: bool = False,
                                     use_iterative_compression: bool = True,
                                     use_line_level_filter: bool = True) -> str:
    """Compress code context using CodeCompressor approach

    Args:
        compressor: The CodeCompressor instance
        code_context: The code to compress
        target_function: The function description/query
        language: The programming language
        target_ratio: Compression ratio (0.0-1.0)
        ppl_strategy: Strategy for perplexity calculation
        condition_in_question: Conditioning mode for perplexity
        rank_only: If True, only rank and select functions without fine-grained compression
        use_iterative_compression: Whether to use token-level iterative compression
        use_line_level_filter: Whether to use line-level filtering
    """
    # replace the "<|endoftext|>" in the code if there is any
    if "<|endoftext|>" in code_context:
        logger.warning(f"Removing <|endoftext|> in code context: {code_context}")
        code_context = code_context.replace("<|endoftext|>", "")

    # Compress the code using CodeCompressor
    if rank_only:
        # When rank_only is True, we'll use the compress_code_file method
        logger.info("===== Rank-only mode =====")
        compressed = compressor.compress_code_file(
            code=code_context,
            query=target_function,
            instruction=INSTRUCTION,
            rate=target_ratio,
            language=language,
            rank_only=True
        )
    else:
        # For non-function chunk processing, use compress_code if not splitting by functions
        if not use_line_level_filter and not use_iterative_compression:
            logger.info("===== Simple truncation mode =====")
            # Simple truncation mode
            compressed = compressor.compress_code(
                code=code_context,
                query=target_function,
                instruction=INSTRUCTION,
                rate=target_ratio,
                use_line_level_filter=False,
                use_iterative_compression=False
            )
        elif use_line_level_filter and not use_iterative_compression:
            logger.info("===== Line-level filtering only =====")
            # Line-level filtering only
            compressed = compressor.compress_code(
                code=code_context,
                query=target_function,
                instruction=INSTRUCTION,
                rate=target_ratio,
                use_line_level_filter=True,
                use_iterative_compression=False
            )
        elif not use_line_level_filter and use_iterative_compression:
            logger.info("===== Token-level iterative compression only =====")
            # Token-level iterative compression only
            compressed = compressor.compress_code(
                code=code_context,
                query=target_function,
                instruction=INSTRUCTION,
                rate=target_ratio,
                use_line_level_filter=False,
                use_iterative_compression=True
            )
        else:
            # Full function-based splitting and compression
            logger.info("===== Full function-based splitting and compression =====")
            compressed = compressor.compress_code_file(
                code=code_context,
                query=target_function,
                instruction=INSTRUCTION,
                rate=target_ratio,
                language=language,
                rank_only=False,
                use_iterative_compression=use_iterative_compression
            )

    # Get compressed prompt from results
    if "compressed_prompt" in compressed:
        compressed_prompt = compressed["compressed_prompt"]
    else:
        compressed_prompt = compressed["output"]

    # Log compression results
    logger.info(f"Original token count: {compressed['original_tokens']}")
    logger.info(f"CodeCompressor compressed token count: {compressed['compressed_tokens']}")
    logger.info(f"Token compression ratio: {compressed['compressed_tokens']/compressed['original_tokens']:.2%}")

    return compressed_prompt


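# Illustrative call (not part of the original file; repo_code and description
# are placeholder variables): a rank-only CodeCompressor run keeping roughly
# 20% of the context might look like
#
#   compressor = CodeCompressor("Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4")
#   compressed = compress_context_code_compressor(
#       compressor=compressor,
#       code_context=repo_code,          # placeholder: the long code context
#       target_function=description,     # placeholder: the query/description
#       language="python",
#       target_ratio=0.2,
#       rank_only=True,
#   )
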
def compress_context_mgcode_compressor(compressor: MGCodeCompressor,
                                       code_context: str,
                                       target_function: str,
                                       language: str,
                                       target_ratio: float = 0.5,
                                       compression_mode: str = "balanced") -> str:
    """Compress code context using MG CodeCompressor approach"""
    # replace the "<|endoftext|>" in the code if there is any
    if "<|endoftext|>" in code_context:
        logger.warning(f"Removing <|endoftext|> in code context: {code_context}")
        code_context = code_context.replace("<|endoftext|>", "")

    # Compress the code using MG CodeCompressor
    compressed = compressor.compress_code(
        code=code_context,
        query=target_function,
        instruction=INSTRUCTION,
        target_ratio=target_ratio,
        compression_mode=compression_mode,
        enable_fine_compression=True,
        max_iterations=3,
        preserve_top_functions=True,
        language=language
    )

    compressed_prompt = compressed["compressed_prompt"]
    # logger.info(f"Compressed prompt: \n{compressed_prompt}")

    # Log compression results
    logger.info(f"Original token count: {compressed['original_tokens']}")
    logger.info(f"MG CodeCompressor compressed token count: {compressed['compressed_tokens']}")
    logger.info(f"Token compression ratio: {compressed['compressed_tokens']/compressed['original_tokens']:.2%}")

    return compressed_prompt


def evaluate_model_rag(
    model: str,
    code_context_size: int = 16 * 1024,
    max_new_tokens: int = 1024,
    result_dir: str = "results/rag_compressed_v1",
    languages: List[str] = None,
    tensor_parallel_size: int = 1,
    trust_remote_code: bool = True,
    chunk_strategy: str = "function_based",
    window_size: int = 20,
    overlap_size: int = 10,
    dataset_path: str = None,
    compression_method: str = "rag",
    llm_lingua_target_token: int = 1000,
    compression_ratio: float = 0.5,
    backend: str = "vllm",
    ppl_strategy: str = "default",
    condition_in_question: str = "default",
    compression_mode: str = "function_focus",
    rank_only: bool = False,
    use_iterative_compression: bool = False,
    use_line_level_filter: bool = False,
    compression_model: str = "Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4"
):
    """Main evaluation function with compression method selection

    Args:
        model: Model name or path
        code_context_size: Model context size in tokens
        max_new_tokens: Maximum tokens to generate
        result_dir: Directory to save results
        languages: List of languages to evaluate
        tensor_parallel_size: Tensor parallel size for vLLM
        trust_remote_code: Trust remote code for tokenizer and model
        chunk_strategy: Chunking strategy ("function_based" or "sliding_window")
        window_size: Window size for sliding window strategy
        overlap_size: Overlap size for sliding window strategy
        dataset_path: Path to dataset file
        compression_method: Compression method
            ("rag", "llm_lingua", "longllmlingua_chunks", "code_compressor", "mgcode_compressor", "original")
        llm_lingua_target_token: Target token count for LLMLingua
        compression_ratio: Compression ratio for CodeCompressor
        backend: Backend for inference ("vllm")
        ppl_strategy: Perplexity strategy for CodeCompressor
        condition_in_question: Condition in question for CodeCompressor
        compression_mode: Compression mode for MGCodeCompressor
        rank_only: If True, only rank and select functions without fine-grained compression
        use_iterative_compression: Whether to use token-level iterative compression for code_compressor
        use_line_level_filter: Whether to apply line-level filtering for code_compressor
        compression_model: Model name for LLMLingua and CodeCompressor
    """
    # show the parameters of rank_only, use_iterative_compression, use_line_level_filter
    logger.info(f"Rank-only: {rank_only}")
    logger.info(f"Use iterative compression: {use_iterative_compression}")
    logger.info(f"Use line-level filter: {use_line_level_filter}")

    # Create result directory
    os.makedirs(result_dir, exist_ok=True)

    # Add strategy to the output directory name
    strategy_str = f"_{compression_method}"
    if compression_method == "llm_lingua":
        strategy_str += f"_t{llm_lingua_target_token}"
    elif compression_method == "longllmlingua_chunks":
        strategy_str += f"_t{llm_lingua_target_token}_w{window_size}_o{overlap_size}"
    elif compression_method == "code_compressor":
        # Create a compression mode string based on settings
        cc_mode = []
        if rank_only:
            cc_mode.append("rank_only")
        else:
            if use_iterative_compression:
                cc_mode.append("iter")
            if use_line_level_filter:
                cc_mode.append("line")

        mode_str = "_".join(cc_mode) if cc_mode else "simple"
        strategy_str += f"_t{compression_ratio}_mode_{mode_str}"
    elif compression_method == "mgcode_compressor":
        strategy_str += f"_t{compression_ratio}_m{compression_mode}"

    if chunk_strategy == "sliding_window":
        strategy_str += f"_w{window_size}_o{overlap_size}"

    context_size_dir = os.path.join(result_dir, f"ntoken_{code_context_size}{strategy_str}")
    os.makedirs(context_size_dir, exist_ok=True)

    model_output_path = os.path.join(
        context_size_dir,
        f"{model.replace('/', '_slash_')}.jsonl",
    )

    # Intermediate file to store compressed contexts
    compressed_contexts_path = os.path.join(
        context_size_dir,
        f"compressed_contexts_{model.replace('/', '_slash_')}.jsonl",
    )

    # Load cache from Qwen results
    cache_file = os.path.join("results/ntoken_16384", "Qwen_slash_Qwen2.5-7B-Instruct.jsonl")
    # cache_file = os.path.join("results/ntoken_16384", "Qwen_slash_Qwen2.5-7B-Instruct-GPTQ-Int4.jsonl")
    if not os.path.exists(cache_file):
        raise FileNotFoundError(f"Cache file not found: {cache_file}")

    with open(cache_file) as f:
        cache = [json.loads(line) for line in f]

    logger.info(f"Loaded {len(cache)} examples from {cache_file}")
    logger.info(f"Using chunking strategy: {chunk_strategy}")
    if chunk_strategy == "sliding_window":
        logger.info(f"Window size: {window_size}, Overlap size: {overlap_size}")
    if compression_method == "llm_lingua":
        logger.info(f"Using LLMLingua compression with target tokens: {llm_lingua_target_token}")
    elif compression_method == "longllmlingua_chunks":
        logger.info("Using LongLLMLingua chunks compression with:")
        logger.info(f" - Target tokens: {llm_lingua_target_token}")
        logger.info(f" - Chunk size: {window_size}")
        logger.info(f" - Overlap: {overlap_size}")
    elif compression_method == "code_compressor":
        logger.info(f"Using CodeCompressor with ratio: {compression_ratio}")
        logger.info("CodeCompressor settings:")
        logger.info(f" - rank_only: {rank_only}")
        logger.info(f" - use_iterative_compression: {use_iterative_compression}")
        logger.info(f" - use_line_level_filter: {use_line_level_filter}")

    # Filter by languages if specified
    if languages:
        cache = [c for c in cache if c["language"] in languages]

    if dataset_path is not None:
        with open(dataset_path) as f:
            dataset = json.load(f)
    else:
        dataset = get_repoqa_data()

    # If results already exist, load and evaluate
    if os.path.exists(model_output_path) and os.path.getsize(model_output_path) > 0:
        logger.info(f"Loading {model_output_path} and evaluating")
        model_outputs = [json.loads(line) for line in open(model_output_path)]
        file_base, _ = os.path.splitext(model_output_path)
        result_path = file_base + "-SCORES.json"
        output_json = compute_score(
            model,
            dataset,
            model_outputs,
            True,  # Ignore comments since we're using compressed context
            result_dir=result_dir,
        )
        save_json(output_json, result_path)
        return

    # PHASE 1: Compress all contexts
    compressed_tasks = []

    # Initialize appropriate compressor based on compression method
    if compression_method in ["rag", "original"]:
        rag_compressor = RAGCompressor()
    else:
        rag_compressor = None

    # Initialize compressors if needed
    llm_lingua_compressor = None
    code_compressor = None
    mgcode_compressor = None
    if compression_method in ["llm_lingua", "longllmlingua_chunks"]:
        llm_lingua_compressor = PromptCompressor(compression_model)
    elif compression_method == "code_compressor":
        code_compressor = CodeCompressor(compression_model)
    elif compression_method == "mgcode_compressor":
        mgcode_compressor = MGCodeCompressor(compression_model)

    # Convert string strategy to enum
    try:
        chunk_strategy_enum = ChunkStrategy(chunk_strategy)
    except ValueError:
        raise ValueError(f"Invalid chunk strategy: {chunk_strategy}. "
                         f"Must be one of {[s.value for s in ChunkStrategy]}")

    # Check if compressed contexts already exist
    if os.path.exists(compressed_contexts_path) and os.path.getsize(compressed_contexts_path) > 0:
        logger.info(f"Loading pre-compressed contexts from {compressed_contexts_path}")
        with open(compressed_contexts_path) as f:
            compressed_tasks = [json.loads(line) for line in f]
    else:
        logger.info(f"Starting compression phase for {len(cache)} examples")
        # Process and compress each task
        for i, task in enumerate(tqdm(cache, desc="Compressing contexts")):
            # Make a copy of the original task
            compressed_task = dict(task)

            try:
                # Compression logic based on selected method
                if compression_method == "rag":
                    chunker = CodeChunker(
                        task["language"],
                        strategy=chunk_strategy_enum,
                        window_size=window_size,
                        overlap_size=overlap_size
                    )
                    compressed_context = compress_context(
                        task["code_context"],
                        task["description"],
                        task["language"],
                        rag_compressor,
                        chunker=chunker
                    )
                elif compression_method == "llm_lingua":
                    compressed_context = compress_context_llm_lingua(
                        compressor=llm_lingua_compressor,
                        code_context=task["code_context"],
                        target_function=task["description"],
                        language=task["language"],
                        target_token=llm_lingua_target_token
                    )
                elif compression_method == "longllmlingua_chunks":
                    compressed_context = compress_context_longllmlingua_chunks(
                        compressor=llm_lingua_compressor,
                        code_context=task["code_context"],
                        target_function=task["description"],
                        language=task["language"],
                        target_token=llm_lingua_target_token,
                        chunk_size=window_size,
                        overlap=overlap_size
                    )
                elif compression_method == "code_compressor":
                    compressed_context = compress_context_code_compressor(
                        compressor=code_compressor,
                        code_context=task["code_context"],
                        target_function=task["description"],
                        language=task["language"],
                        target_ratio=compression_ratio,
                        ppl_strategy=ppl_strategy,
                        condition_in_question=condition_in_question,
                        rank_only=rank_only,
                        use_iterative_compression=use_iterative_compression,
                        use_line_level_filter=use_line_level_filter
                    )
                elif compression_method == "mgcode_compressor":
                    compressed_context = compress_context_mgcode_compressor(
                        compressor=mgcode_compressor,
                        code_context=task["code_context"],
                        target_function=task["description"],
                        language=task["language"],
                        target_ratio=compression_ratio,
                        compression_mode=compression_mode
                    )
                elif compression_method == "original":
                    compressed_context = task["code_context"]
                else:
                    raise ValueError(f"Invalid compression method: {compression_method}")

                # Update task with compressed context
                compressed_task["code_context"] = compressed_context

                # Generate prompt
                if compression_method == "code_compressor":
                    compressed_task["prompt"] = compressed_context
                else:
                    prompt = ""
                    for key in task["template"].split("\n"):
                        prompt += compressed_task[key]
                    compressed_task["prompt"] = prompt

            except Exception as e:
                logger.error(f"Error compressing item {i} of {len(cache)}: {e}")
                # Use original context if compression fails
                compressed_task["code_context"] = task["code_context"]
                prompt = ""
                for key in task["template"].split("\n"):
                    prompt += compressed_task[key]
                compressed_task["prompt"] = prompt

            compressed_tasks.append(compressed_task)

            # Save intermediate results periodically
            if (i + 1) % 10 == 0 or i == len(cache) - 1:
                with open(compressed_contexts_path, "w") as f_out:
                    for t in compressed_tasks:
                        f_out.write(json.dumps(t) + "\n")
                    f_out.flush()
                logger.info(f"Saved {i+1}/{len(cache)} compressed contexts")

    # Clean up compressor objects to free memory
    del rag_compressor
    del llm_lingua_compressor
    del code_compressor
    del mgcode_compressor

    # Force garbage collection to free GPU memory
    import gc
    gc.collect()

    # Clear CUDA cache if torch is available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info("Cleared GPU memory cache")

    # PHASE 2: Generate responses with vLLM
    logger.info("Starting response generation phase")

    # Initialize vLLM provider
    from repoqa.provider.vllm import VllmProvider
    engine = VllmProvider(
        model,
        tensor_parallel_size=tensor_parallel_size,
        max_model_len=int(code_context_size * 1.5),
        trust_remote_code=trust_remote_code,
        gpu_memory_utilization=0.8  # Can use higher utilization now
    )

    # Generate responses for all compressed tasks
    model_outputs = []
    for i, task in enumerate(tqdm(compressed_tasks, desc="Generating responses")):
        # Generate reply
        replies = engine.generate_reply(
            task["prompt"], n=1, max_tokens=max_new_tokens
        )

        # Save result
        result = {**task, "output": replies}
        model_outputs.append(result)

    # Save all model outputs
    with open(model_output_path, "w") as f_out:
        for r in model_outputs:
            f_out.write(json.dumps(r) + "\n")
        f_out.flush()
    logger.info(f"Saved {len(model_outputs)} responses")

    # Compute and save scores
    file_base, _ = os.path.splitext(model_output_path)
    result_path = file_base + "-SCORES.json"
    output_json = compute_score(
        model,
        dataset,
        model_outputs,
        True,  # Ignore comments since we're using compressed context
        result_dir=result_dir,
    )
    save_json(output_json, result_path)


def main():
    from fire import Fire
    Fire(evaluate_model_rag)


if __name__ == "__main__":
    main()

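# Example invocation (illustrative; the exact flag spelling depends on how
# python-fire maps the keyword arguments of evaluate_model_rag, see run.sh for
# the flags actually used in this repo):
#
#   python main.py --model Qwen/Qwen2.5-Coder-7B-Instruct \
#       --backend vllm --compression-method code_compressor \
#       --compression-ratio 0.2 --result-dir code_compressor_exp_results --rank-only
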
21 repoqa/metric.py Normal file
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import re

from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu


def compute_function_similarity(
    candidate_function: str, reference_function: str
) -> float:
    candidate_tokens = [item for item in re.split(r"\s+", candidate_function.strip())]
    reference_tokens = [item for item in re.split(r"\s+", reference_function.strip())]

    chencherry = SmoothingFunction()

    return sentence_bleu(
        [reference_tokens], candidate_tokens, smoothing_function=chencherry.method4
    )
5 repoqa/provider/__init__.py Normal file
@@ -0,0 +1,5 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

from repoqa.provider.base import BaseProvider
35 repoqa/provider/anthropic.py Normal file
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List

from anthropic import Client

from repoqa.provider.base import BaseProvider
from repoqa.provider.request.anthropic import make_auto_request


class AnthropicProvider(BaseProvider):
    def __init__(self, model):
        self.model = model
        self.client = Client(api_key=os.getenv("ANTHROPIC_KEY"))

    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
    ) -> List[str]:
        assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
        replies = []
        for _ in range(n):
            reply = make_auto_request(
                self.client,
                message=question,
                model=self.model,
                temperature=temperature,
                max_tokens=max_tokens,
                system_msg=system_msg,
            )
            replies.append(reply.content[0].text)

        return replies
14 repoqa/provider/base.py Normal file
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import List


class BaseProvider(ABC):
    @abstractmethod
    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
    ) -> List[str]:
        ...
47 repoqa/provider/google.py Normal file
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List

import google.generativeai as genai

from repoqa.provider.base import BaseProvider
from repoqa.provider.request.google import make_auto_request


class GoogleProvider(BaseProvider):
    def __init__(self, model):
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model = model
        self.client = genai.GenerativeModel(model)

    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
    ) -> List[str]:
        assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
        replies = make_auto_request(
            self.client,
            question,
            self.model,
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
            system_msg=system_msg,
        )

        if len(replies.candidates) != n:
            print(f"[WARNING] # replies = {len(replies.candidates)} != {n = }")

        ret_texts = []
        for candidate in replies.candidates:
            parts = candidate.content.parts
            if parts:
                ret_texts.append(parts[0].text)
            else:
                print("Empty response!")
                ret_texts.append("")
                print(f"{candidate.safety_ratings = }")

        return ret_texts + [""] * (n - len(ret_texts))
66 repoqa/provider/hf.py Normal file
@@ -0,0 +1,66 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repoqa.provider.base import BaseProvider
from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq


class HfProvider(BaseProvider):
    def __init__(self, model, trust_remote_code=False, attn_implementation=None):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model, trust_remote_code=trust_remote_code
        )
        self.hf_model = AutoModelForCausalLM.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            torch_dtype="auto",
        ).cuda()
        self.stop_seq = []
        if self.tokenizer.chat_template:
            self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer))

    @torch.inference_mode()
    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
    ) -> List[str]:
        assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"

        prompt_tokens = self.tokenizer.apply_chat_template(
            construct_message_list(question, system_msg),
            return_tensors="pt",
            add_generation_prompt=True,
        ).cuda()
        input_length = prompt_tokens.size(-1)

        gen_args = {"do_sample": False}
        if temperature > 0:
            gen_args["do_sample"] = True
            gen_args["temperature"] = temperature

        output_text = self.hf_model.generate(
            input_ids=prompt_tokens,
            max_new_tokens=max_tokens,
            num_return_sequences=n,
            pad_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            stop_strings=self.stop_seq,
            tokenizer=self.tokenizer,
            **gen_args,
        )

        gen_strs = [
            self.tokenizer.decode(
                x[input_length:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            for x in output_text
        ]
        return gen_strs
46 repoqa/provider/openai.py Normal file
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import List

from openai import Client
from transformers import AutoTokenizer

from repoqa.provider.base import BaseProvider
from repoqa.provider.request import hacky_assistant_stop_seq
from repoqa.provider.request.openai import make_auto_request


class OpenAIProvider(BaseProvider):
    def __init__(self, model, base_url: str = None):
        self.model = model
        self.client = Client(
            api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url
        )
        self.stop_seq = []
        try:
            tokenizer = AutoTokenizer.from_pretrained(model)
            if tokenizer.chat_template:
                self.stop_seq.append(hacky_assistant_stop_seq(tokenizer))
            print("Using stop sequence: ", self.stop_seq)
        except Exception:
            print("Failed to automatically fetch stop tokens from HuggingFace.")

    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
    ) -> List[str]:
        assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"
        replies = make_auto_request(
            self.client,
            message=question,
            model=self.model,
            temperature=temperature,
            n=n,
            max_tokens=max_tokens,
            system_msg=system_msg,
            stop=self.stop_seq,
        )

        return [reply.message.content for reply in replies.choices]
21 repoqa/provider/request/__init__.py Normal file
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0


def construct_message_list(message, system_message=None):
    msglist = [{"role": "user", "content": message}]
    if system_message:
        msglist.insert(0, {"role": "system", "content": system_message})
    return msglist


def hacky_assistant_stop_seq(tokenizer) -> str:
    _magic_string_ = "&==NowOrNever==&Accelerate!!!==&"
    return tokenizer.apply_chat_template(
        [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": _magic_string_},
        ],
        tokenize=False,
    ).split(_magic_string_)[-1]
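# Note (added for clarity, not in the original file): hacky_assistant_stop_seq
# renders a dummy user/assistant exchange with the model's chat template and
# keeps whatever follows the magic string, i.e. the template's assistant-turn
# terminator. For a ChatML-style template (e.g. Qwen) this would typically be
# "<|im_end|>\n", which is then used as a stop sequence during generation.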
71 repoqa/provider/request/anthropic.py Normal file
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import signal
import time

import anthropic
from anthropic.types import Message

from repoqa.provider.request import construct_message_list


def make_request(
    client: anthropic.Client,
    message: str,
    model: str,
    max_tokens: int = 512,
    temperature: float = 1,
    system_msg="You are a helpful assistant good at coding.",
    **kwargs,
) -> Message:
    return client.messages.create(
        model=model,
        messages=construct_message_list(message, system_message=system_msg),
        max_tokens=max_tokens,
        temperature=temperature,
        **kwargs,
    )


def handler(signum, frame):
    # swallow signum and frame
    raise Exception("end of time")


def make_auto_request(client: anthropic.Client, *args, **kwargs) -> Message:
    ret = None
    while ret is None:
        try:
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(100)
            ret = make_request(client, *args, **kwargs)
            signal.alarm(0)
        except anthropic.RateLimitError:
            print("Rate limit exceeded. Waiting...")
            signal.alarm(0)
            time.sleep(10)
        except anthropic.APIConnectionError:
            print("API connection error. Waiting...")
            signal.alarm(0)
            time.sleep(5)
        except anthropic.InternalServerError:
            print("Internal server error. Waiting...")
            signal.alarm(0)
            time.sleep(5)
        except anthropic.APIError as e:
            print("Unknown API error")
            print(e)
            if (
                e.body["error"]["message"]
                == "Output blocked by content filtering policy"
            ):
                raise Exception("Content filtering policy blocked output")
            signal.alarm(0)
        except Exception as e:
            print("Unknown error. Waiting...")
            print(e)
            signal.alarm(0)
            time.sleep(1)
    return ret
63 repoqa/provider/request/google.py Normal file
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import signal
import time

import google.generativeai as genai
from google.api_core.exceptions import GoogleAPICallError, ResourceExhausted

from repoqa.provider.request import construct_message_list


def make_request(
    client: genai.GenerativeModel,
    message: str,
    model: str,
    max_tokens: int = 512,
    temperature: float = 1,
    n: int = 1,
    system_msg="You are a helpful assistant good at coding.",
    **kwargs,
) -> genai.types.GenerateContentResponse:
    messages = []
    if system_msg:
        messages.append({"role": "system", "parts": [system_msg]})
    messages.append({"role": "user", "parts": [message]})
    return client.generate_content(
        messages,
        generation_config=genai.types.GenerationConfig(
            candidate_count=n, max_output_tokens=max_tokens, temperature=temperature
        ),
        **kwargs,
    )


def handler(signum, frame):
    # swallow signum and frame
    raise Exception("end of time")


def make_auto_request(*args, **kwargs) -> genai.types.GenerateContentResponse:
    ret = None
    while ret is None:
        try:
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(100)
            ret = make_request(*args, **kwargs)
            signal.alarm(0)
        except ResourceExhausted as e:
            print("Rate limit exceeded. Waiting...", e.message)
            signal.alarm(0)
            time.sleep(10)
        except GoogleAPICallError as e:
            print(e.message)
            signal.alarm(0)
            time.sleep(1)
        except Exception as e:
            print("Unknown error. Waiting...")
            print(e)
            signal.alarm(0)
            time.sleep(1)
    return ret
63 repoqa/provider/request/openai.py Normal file
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

import signal
import time

import openai
from openai.types.chat import ChatCompletion

from repoqa.provider.request import construct_message_list


def make_request(
    client: openai.Client,
    message: str,
    model: str,
    max_tokens: int = 512,
    temperature: float = 1,
    n: int = 1,
    system_msg="You are a helpful assistant good at coding.",
    **kwargs,
) -> ChatCompletion:
    return client.chat.completions.create(
        model=model,
        messages=construct_message_list(message, system_message=system_msg),
        max_tokens=max_tokens,
        temperature=temperature,
        n=n,
        **kwargs,
    )


def handler(signum, frame):
    # swallow signum and frame
    raise Exception("end of time")


def make_auto_request(*args, **kwargs) -> ChatCompletion:
    ret = None
    while ret is None:
        try:
            signal.signal(signal.SIGALRM, handler)
            signal.alarm(100)
            ret = make_request(*args, **kwargs)
            signal.alarm(0)
        except openai.RateLimitError:
            print("Rate limit exceeded. Waiting...")
            signal.alarm(0)
            time.sleep(10)
        except openai.APIConnectionError:
            print("API connection error. Waiting...")
            signal.alarm(0)
            time.sleep(5)
        except openai.APIError as e:
            print(e)
            signal.alarm(0)
        except Exception as e:
            print("Unknown error. Waiting...")
            print(e)
            signal.alarm(0)
            time.sleep(1)
    return ret
53 repoqa/provider/vllm.py Normal file
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

from typing import List

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from repoqa.provider.base import BaseProvider
from repoqa.provider.request import construct_message_list, hacky_assistant_stop_seq


class VllmProvider(BaseProvider):
    def __init__(
        self, model, tensor_parallel_size, max_model_len=None, trust_remote_code=False, gpu_memory_utilization=0.9
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model, trust_remote_code=trust_remote_code
        )
        self.llm = LLM(
            model=model,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            trust_remote_code=trust_remote_code,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        self.stop_seq = []
        if self.tokenizer.chat_template:
            self.stop_seq.append(hacky_assistant_stop_seq(self.tokenizer))

    def generate_reply(
        self, question, n=1, max_tokens=1024, temperature=0.0, system_msg=None
    ) -> List[str]:
        assert temperature != 0 or n == 1, "n must be 1 when temperature is 0"

        prompt = self.tokenizer.apply_chat_template(
            construct_message_list(question, system_msg),
            tokenize=False,
            add_generation_prompt=True,
        )
        vllm_outputs = self.llm.generate(
            [prompt],
            SamplingParams(
                temperature=temperature,
                max_tokens=max_tokens,
                stop=self.stop_seq,
            ),
            use_tqdm=False,
        )

        gen_strs = [x.outputs[0].text for x in vllm_outputs]
        return gen_strs
35 repoqa/run.sh Normal file
@@ -0,0 +1,35 @@
#!/bin/bash

MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct"
MODEL_PATH_NAME="qwencoder-7b-instruct"
BACKEND="vllm"
COMPRESSION_METHOD="code_compressor"
BASE_RESULT_DIR="code_compressor_exp_results"
BASE_LOG_DIR="logs-combinations"

mkdir -p ${BASE_LOG_DIR}
mkdir -p ${BASE_RESULT_DIR}

echo "Starting experiments for ${MODEL_NAME}"

# Configuration arrays
COMPRESSION_RATIOS=(0.1 0.2 0.3 0.4)
GPU_IDS=(0 1 2 3)

echo "--- Running CodeCompressor with various compression ratios ---"
for i in "${!COMPRESSION_RATIOS[@]}"; do
    ratio="${COMPRESSION_RATIOS[$i]}"
    gpu_id="${GPU_IDS[$i]}"

    echo "Running CodeCompressor: compression_ratio=${ratio} on GPU ${gpu_id}"
    CUDA_VISIBLE_DEVICES=${gpu_id} nohup python main.py \
        --model ${MODEL_NAME} \
        --backend ${BACKEND} \
        --compression-method ${COMPRESSION_METHOD} \
        --compression-ratio ${ratio} \
        --result-dir ${BASE_RESULT_DIR} \
        --rank-only > "${BASE_LOG_DIR}/7B_code_compressor_${ratio}_rank_only_true.log" 2>&1 &
    echo "Started CodeCompressor: compression_ratio=${ratio} on GPU ${gpu_id}"
done

echo "--- All CodeCompressor experiments started ---"
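# Note (added for clarity): each run is detached with nohup, so progress can be
# followed from its log file, e.g.
#   tail -f logs-combinations/7B_code_compressor_0.2_rank_only_true.log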
90 repoqa/utility.py Normal file
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: (c) 2024 EvalPlus Team
#
# SPDX-License-Identifier: Apache-2.0

from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)

FUNCTION_QUERY = {
    "python": "(function_definition name: (_)) @fdef",
    "java": "(method_declaration name: (_)) @fdef",
    "typescript": "(function_declaration name: (_)) @fdef",
    "rust": "(function_item name: (_)) @fdef",
    "cpp": "(function_definition declarator: (function_declarator declarator: (identifier))) @fdef",
    "go": "(function_declaration name: (_)) @fdef",
}

COMMENT_QUERY = {
    "python": [
        "(block (expression_statement (string) @docstring))",
        "(comment) @comment",
    ],
    "java": ["(line_comment) @comment", "(block_comment) @comment"],
    "cpp": ["(comment) @comment"],
    "rust": ["(line_comment) @comment", "(block_comment) @comment"],
    "typescript": ["(comment) @comment"],
    "go": ["(comment) @comment"],
}

FUNCTION_NAME_QUERY = {
    "python": """
        ((function_definition
            name: (identifier) @function_name))
    """,
    "java": """
        (method_declaration
            name: (identifier) @method_name)
    """,
    "typescript": """
        (function_declaration
            name: (identifier) @function_name)
    """,
    "rust": """
        (function_item
            name: (identifier) @function_name)
    """,
    "cpp": """
        (function_definition
            name: (identifier) @function_name)
    """,
}


def topological_sort(graph):
    # Stack to store the topological order
    stack = []
    # Set to keep track of visited nodes
    visited = set()

    # Recursive function to process nodes
    def dfs(node):
        # Mark the current node as visited
        visited.add(node)
        # Recurse for all the vertices adjacent to this vertex
        for neighbour in graph.get(node, []):
            if neighbour not in visited:
                dfs(neighbour)
        # Push current vertex to stack which stores the result
        stack.append(node)

    # Call the recursive helper function to store the topological sort starting from all vertices one by one
    for node in graph:
        if node not in visited:
            dfs(node)

    return stack


def progress(note: str = "processing"):
    return Progress(
        TextColumn(f"{note} •" + "[progress.percentage]{task.percentage:>3.0f}%"),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
    )