LongCodeZip/experiments/long-code-completion/main.py
YerbaPage a391badfe1 packaging
2025-10-11 21:33:12 +08:00

import os
import json
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
from llmlingua import PromptCompressor
import fire
from utils import load_data, compute_EM, compute_ES
from vllm import LLM, SamplingParams
from loguru import logger
from code_compressor import CodeCompressor
import gc
from typing import List
import re
# Helper function for splitting code by functions (standalone version)
def split_code_by_functions_standalone(code: str, language: str = "python") -> List[str]:
"""
Split code into chunks based on function and class definitions for various languages.
Standalone version that doesn't require CodeCompressor instance.
Args:
code: The code to split
language: Programming language of the code (python, cpp, java, typescript, rust, go)
Returns:
List of code chunks, each containing a function, class, or class method
"""
# Define regex patterns for different languages
patterns = {
# Python: Simplified to match 'def' or 'class' followed by content until the next def/class or end
"python": r'(^|\n)(\s*)(def|class)\s+[^\n]+(\n(?!\s*(?:def|class)\s)[^\n]*)*',
# C++: Improved to better handle multi-line declarations
"cpp": r'(^|\n)(\s*)(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s*:\s*[^{]*)?|(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*\s+)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
# Java: Improved for multi-line method declarations
"java": r'(^|\n)(\s*)(?:(?:public|private|protected|static|final|abstract|synchronized)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:<.*>)?(?:[a-zA-Z_][a-zA-Z0-9_<>:,\s]*)\s+[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*throws\s+[^{;]*)?)\s*(?:{[^}]*}|[^;]*;)?',
# TypeScript: Enhanced to handle multi-line methods and arrow functions
"typescript": r'(^|\n)(\s*)(?:(?:public|private|protected|static|abstract)\s+)*(?:class\s+[a-zA-Z_][a-zA-Z0-9_]*(?:\s+extends\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+implements\s+[^{]*)?|(?:(?:public|private|protected|static|async)\s+)*(?:function\s+)?(?:[a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<.*>)?\s*\([^{;]*\)\s*(?::\s*[^{;]*\s*)?(?:=>)?)\s*(?:{[^}]*}|[^;]*;)?',
# Rust: Improved for multi-line function declarations
"rust": r'(^|\n)(\s*)(?:pub\s+)?(?:struct\s+[a-zA-Z_][a-zA-Z0-9_]*|impl(?:\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\s+for\s+[a-zA-Z_][a-zA-Z0-9_]*)?|(?:async\s+)?fn\s+[a-zA-Z_][a-zA-Z0-9_]*\s*(?:<.*>)?\s*\([^{;]*\)(?:\s*->\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
# Go: Improved for multi-line function declarations
"go": r'(^|\n)(\s*)(?:type\s+[a-zA-Z_][a-zA-Z0-9_]*\s+struct|func\s+(?:\([^)]*\)\s*)?[a-zA-Z_][a-zA-Z0-9_]*\s*\([^{;]*\)(?:\s*[^{;]*\s*)?)\s*(?:{[^}]*}|[^;]*;)?',
}
# Use default Python pattern if language not supported
if language.lower() not in patterns:
language = "python"
function_pattern = re.compile(patterns[language.lower()], re.MULTILINE)
matches = list(function_pattern.finditer(code))
if not matches:
return [code] if code.strip() else [] # No matches, return whole code if not empty
result_chunks = []
# Add code before first match if exists
if matches[0].start() > 0:
pre_code = code[:matches[0].start()].strip()
if pre_code:
result_chunks.append(pre_code)
# Process each match
for i, match in enumerate(matches):
start = match.start()
# End is either start of next match or end of code
if i < len(matches) - 1:
end = matches[i + 1].start()
else:
end = len(code)
chunk = code[start:end].strip()
if chunk:
result_chunks.append(chunk)
return result_chunks
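
# Illustrative behavior of the splitter above (hypothetical snippet, not used by the pipeline):
# for "import os\n\ndef add(a, b):\n    return a + b\n\nclass Greeter:\n    def hi(self):\n        print('hi')",
# the Python pattern yields one chunk for the leading "import os", one for "def add(...)",
# one for the bare "class Greeter:" header, and one for the "def hi(self)" method body,
# because every def/class line starts a new match.
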
# Helper function for function-level RAG retrieval
def function_rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, language: str, top_k: int) -> str:
"""Uses function-level chunking and retrieves top_k similar functions."""
if not background_code.strip():
return "" # Return empty if no background context
# Split code into function-based chunks
chunks = split_code_by_functions_standalone(background_code, language)
if not chunks:
return "" # Return empty if chunking results in nothing
query_embedding = compute_embedding(query_code, model, tokenizer, device)
chunk_embeddings = []
valid_chunks = []
for chunk in chunks:
if chunk.strip():
chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device))
valid_chunks.append(chunk)
if not valid_chunks:
return ""
# Stack embeddings for efficient similarity calculation
chunk_embeddings_tensor = torch.stack(chunk_embeddings)
# Compute cosine similarity
similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1)
# Get top_k indices
top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices
# Retrieve relevant chunks
retrieved_chunks = [valid_chunks[i] for i in top_k_indices.tolist()]
# Combine relevant chunks (maintain order by similarity score)
combined_code = "\n\n".join(retrieved_chunks)
return combined_code
# Helper function for sliding window chunking
def chunk_sliding_window(code: str, window_size: int, overlap: int) -> List[str]:
    """Splits code into overlapping chunks using a sliding window."""
lines = code.splitlines()
if not lines:
return []
chunks = []
start = 0
stride = window_size - overlap
if stride <= 0:
raise ValueError("Overlap size must be smaller than window size.")
while True:
end = min(start + window_size, len(lines))
chunk_lines = lines[start:end]
if not chunk_lines: # Should not happen if lines is not empty, but safety check
break
chunks.append("\n".join(chunk_lines))
if end == len(lines):
break # Exit loop if we reached the end
next_start = start + stride
# If the next window would go past the end, break
if next_start >= len(lines):
# Add the final overlapping chunk if needed
final_start = max(0, len(lines) - window_size)
if final_start > start: # Ensure it's a new chunk not already added
final_chunk_lines = lines[final_start:]
chunks.append("\n".join(final_chunk_lines))
break
start = next_start
# Handle case where code is shorter than window size
if not chunks and lines:
return ["\n".join(lines)]
# Remove duplicates while preserving order (important for RAG)
seen = set()
unique_chunks = []
for chunk in chunks:
if chunk not in seen:
seen.add(chunk)
unique_chunks.append(chunk)
return unique_chunks
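
# Sketch of the windowing above (assumed numbers, for illustration only): with a 10-line input,
# window_size=4 and overlap=2, the stride is 2, so windows start at lines 0, 2, 4 and 6;
# consecutive chunks share `overlap` lines and exact duplicates are dropped afterwards.
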
# Helper function to compute embeddings (using mean pooling)
def compute_embedding(text: str, model, tokenizer, device) -> torch.Tensor:
"""Computes sentence embedding for a text using the provided model."""
if not text.strip(): # Handle empty strings
return torch.zeros(model.config.hidden_size).to(device)
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(device)
with torch.no_grad():
outputs = model(**inputs)
# Mean pool the last hidden state
embedding = outputs.last_hidden_state.mean(dim=1).squeeze()
return embedding
# Helper function for RAG retrieval
def rag_retrieve(background_code: str, query_code: str, model, tokenizer, device, window_size: int, overlap: int, top_k: int) -> str:
"""Chunks background, embeds chunks and query, retrieves top_k similar chunks."""
if not background_code.strip():
return "" # Return empty if no background context
chunks = chunk_sliding_window(background_code, window_size, overlap)
if not chunks:
return "" # Return empty if chunking results in nothing
query_embedding = compute_embedding(query_code, model, tokenizer, device)
chunk_embeddings = []
valid_chunks = []
for chunk in chunks:
if chunk.strip():
chunk_embeddings.append(compute_embedding(chunk, model, tokenizer, device))
valid_chunks.append(chunk)
if not valid_chunks:
return ""
# Stack embeddings for efficient similarity calculation
chunk_embeddings_tensor = torch.stack(chunk_embeddings)
# Compute cosine similarity
similarities = torch.cosine_similarity(query_embedding.unsqueeze(0), chunk_embeddings_tensor, dim=1)
# Get top_k indices
top_k_indices = torch.topk(similarities, k=min(top_k, len(valid_chunks)), dim=0).indices
    # Retrieve the top_k chunks and sort them by their original position
    retrieved_chunk_contents = [valid_chunks[i] for i in top_k_indices.tolist()]
    # Approximate each chunk's original start line from its position in the ordered chunk list
    # (chunk_sliding_window returns chunks in document order with a fixed stride)
    chunk_start_lines = {}
    start_line_num = 0
    stride = window_size - overlap
    for chunk_content in chunks:
        chunk_start_lines[chunk_content] = start_line_num
        start_line_num += stride
    # Rough approximation; doesn't perfectly handle the final overlapping chunk
sorted_relevant_chunks = sorted(
retrieved_chunk_contents,
key=lambda content: chunk_start_lines.get(content, float('inf')) # Sort by approximate start line
)
    # Combine relevant chunks, separated by blank lines for clarity
    combined_code = "\n\n".join(sorted_relevant_chunks)
return combined_code
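
# Note: unlike function_rag_retrieve, which keeps the retrieved functions in similarity order,
# rag_retrieve re-orders the retrieved windows by their approximate position in the original file.
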
# Helper function for LLMLingua compression
def compress_llmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str) -> str:
"""Compresses context using LLMLingua."""
if not context.strip():
return ""
try:
# Ensure no "<|endoftext|>"
context_clean = context.replace("<|endoftext|>", "")
compressed = compressor.compress_prompt(
context_clean,
instruction=instruction,
question=query + "\n" + instruction, # Combine query and instruction for question
target_token=target_token
)
# Ensure result exists and is string
result = compressed.get('compressed_prompt', '')
return result if isinstance(result, str) else ""
except Exception as e:
logger.error(f"LLMLingua compression failed: {e}")
# Fallback: Truncate based on target tokens (approximate)
tokens = compressor.tokenizer.encode(context_clean)
if len(tokens) > target_token:
return compressor.tokenizer.decode(tokens[:target_token])
return context_clean
# Helper function for LongLLMLingua compression
def compress_longllmlingua(context: str, query: str, compressor: PromptCompressor, target_token: int, instruction: str, chunk_size: int, overlap: int) -> str:
"""Compresses context using LongLLMLingua with sliding window chunks."""
if not context.strip():
return ""
try:
# Ensure no "<|endoftext|>"
context_clean = context.replace("<|endoftext|>", "")
# Use our sliding window chunker
chunks = chunk_sliding_window(context_clean, chunk_size, overlap)
if not chunks:
return "" # Handle case where context is too short or chunking fails
compressed = compressor.compress_prompt(
chunks,
instruction=instruction,
question=query + "\n" + instruction, # Combine query and instruction for question
target_token=target_token,
rank_method="longllmlingua" # Use the specified rank method
)
# Ensure result exists and is string
result = compressed.get('compressed_prompt', '')
return result if isinstance(result, str) else ""
except Exception as e:
logger.error(f"LongLLMLingua compression failed: {e}")
# Fallback: Truncate based on target tokens (approximate)
tokens = compressor.tokenizer.encode(context_clean)
if len(tokens) > target_token:
return compressor.tokenizer.decode(tokens[:target_token])
return context_clean
# Helper function for CodeCompressor (Rank Only or Fine-grained)
def compress_code_compressor(context: str, query: str, compressor: CodeCompressor, target_token: int, instruction: str, language: str, rank_only: bool, fine_ratio: float, importance_beta: float) -> str:
"""Compresses context using CodeCompressor based on target tokens and rank_only flag."""
if not context.strip():
return ""
try:
# Ensure no "<|endoftext|>"
context_clean = context.replace("<|endoftext|>", "")
if not context_clean.strip():
return "" # Return empty if clean context is empty
# Tokenize to get original length
# Use the compressor's tokenizer
original_tokens = len(compressor.tokenizer.encode(context_clean))
if original_tokens == 0:
return "" # Avoid division by zero
# Calculate target ratio
target_ratio = min(1.0, max(0.0, target_token / original_tokens))
logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
# Pass rank_only and fine_ratio
# Assuming compressor is already initialized with the correct model
compressed_result = compressor.compress_code_file(
code=context_clean,
query=query, # Using current function context as query focus
instruction=instruction,
rate=target_ratio,
language=language,
rank_only=rank_only, # Ensure rank_only mode is set
fine_ratio=fine_ratio if not rank_only else None, # Pass fine_ratio only if not rank_only
importance_beta=importance_beta if not rank_only else None, # Pass importance_beta only if not rank_only
)
# Extract compressed content - check both possible keys
compressed_context = compressed_result.get("compressed_code")
if not isinstance(compressed_context, str):
logger.error(f"CodeCompressor returned non-string: {type(compressed_context)}")
compressed_context = "" # Fallback
# Log results
compressed_tokens_count = len(compressor.tokenizer.encode(compressed_context))
final_ratio = (compressed_tokens_count / original_tokens) if original_tokens > 0 else 0
logger.info(f"CodeCompressor: Compressed tokens={compressed_tokens_count}, Actual ratio={final_ratio:.4f}")
return compressed_context
except Exception as e:
logger.error(f"CodeCompressor compression failed: {e}", exc_info=True)
# Fallback: Truncate approximately based on target tokens (less ideal for rank_only)
tokens = compressor.tokenizer.encode(context_clean)
if len(tokens) > target_token:
logger.warning(f"CodeCompressor falling back to simple truncation.")
return compressor.tokenizer.decode(tokens[:target_token])
return context_clean
# Function to save scores
def save_json(data: dict, file_path: str):
"""Saves dictionary data to a JSON file."""
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w') as f:
json.dump(data, f, indent=4)
def generate_completions(llm, batch_prompts, max_new_tokens=128):
    """Generates deterministic (temperature=0) completions for a batch of prompts with vLLM."""
sampling_params = SamplingParams(
temperature=0,
top_p=0.95,
max_tokens=max_new_tokens
)
batch_outputs = llm.generate(
batch_prompts,
sampling_params,
use_tqdm=False
)
return [x.outputs[0].text for x in batch_outputs]
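
# Minimal usage sketch (assumes `llm` was created as in evaluate_completion below):
#   completions = generate_completions(llm, ["def fib(n):\n    "], max_new_tokens=64)
#   # `completions` is a list of raw completion strings, one per prompt
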
def evaluate_completion(
model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
method: str = "full",
result_dir: str = "results/completion_baselines",
embed_model_name: str = "microsoft/unixcoder-base",
compression_model_name: str = "Qwen/Qwen2.5-Coder-7B-Instruct",
dataset_path: str = "microsoft/LCC_python",
dataset_split: str = "test",
num_examples: int = 200,
max_new_tokens: int = 128,
batch_size: int = 16,
# RAG params
rag_window_size: int = 80,
rag_overlap: int = 40,
rag_top_k: int = 3,
# Function RAG params
function_rag_language: str = "python",
function_rag_top_k: int = 3,
# LLMLingua params
lingua_target_token: int = 500,
lingua_instruction: str = "Complete the following code function given the context.",
# LongLLMLingua params
longlingua_chunk_size: int = 80,
longlingua_overlap: int = 40,
# CodeCompressor params (New)
code_compressor_target_token: int = 500,
# vLLM params
tensor_parallel_size: int = 1,
trust_remote_code: bool = True,
gpu_memory_utilization: float = 0.9,
filter_current_lines_max: int = 50,
filter_background_tokens_min: int = 3000,
# New CodeCompressor fine-grained param
code_compressor_fine_ratio: float = 1.0, # Default 1.0 means rank_only=True
# New CodeCompressor importance beta param
importance_beta: float = 0.0, # Default beta is 0.0
):
"""Evaluates code completion baselines with a specified context preparation method."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# --- 1. Load Data ---
# Assuming python for now, might need modification if dataset has multiple languages
# Note: Language info might be needed for CodeCompressor if not always python
dataset, _ = load_data(path=dataset_path, split=dataset_split, num_examples=num_examples,
filter_current_lines_max=filter_current_lines_max, filter_background_tokens_min=filter_background_tokens_min)
logger.info(f"Loaded {len(dataset)} examples from {dataset_path} ({dataset_split} split)")
# --- 2. Initialize Models ---
embed_model = None
embed_tokenizer = None
if method == "rag" or method == "function_rag":
logger.info(f"Initializing embedding model: {embed_model_name}")
embed_tokenizer = AutoTokenizer.from_pretrained(embed_model_name)
embed_model = AutoModel.from_pretrained(embed_model_name).to(device)
embed_model.eval() # Set to evaluation mode
logger.info(f"Embedding model {embed_model_name} initialized.")
lingua_compressor = None
if method == "llmlingua" or method == "longllmlingua":
logger.info(f"Initializing LLMLingua compressor: {compression_model_name}")
lingua_compressor = PromptCompressor(model_name=compression_model_name, device_map="auto")
logger.info(f"LLMLingua compressor {compression_model_name} initialized.")
code_compressor_instance = None # Renamed to avoid conflict
if method == "code_compressor":
logger.info(f"Initializing CodeCompressor: {compression_model_name}")
# Assuming CodeCompressor takes model name and potentially device
# Pass device explicitly if needed by your CodeCompressor implementation
code_compressor_instance = CodeCompressor(compression_model_name)
logger.info(f"CodeCompressor {compression_model_name} initialized.")
if method in ["full", "no_context"]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# try to compress a dummy prompt to avoid cuda error when initializing the vllm (strange bug)
code_compressor_instance = PromptCompressor(model_name=compression_model_name, device_map="auto")
logger.info(f"CodeCompressor {compression_model_name} initialized.")
dummy_prompt = "def hello_world():\n print('Hello, World!')"*100
compressed_prompt = code_compressor_instance.compress_prompt(dummy_prompt, instruction="Complete the following code function given the context.", question="Complete the following code function given the context.", target_token=500)
logger.info(f"Compressed prompt: {compressed_prompt}")
# --- 3. Process the Specified Method ---
logger.info(f"--- Processing Method: {method} ---")
# Modify result directory based on method and parameters
method_suffix = f"method_{method}"
if method == "rag":
method_suffix += f"_w{rag_window_size}_o{rag_overlap}_k{rag_top_k}"
elif method == "function_rag":
method_suffix += f"_lang{function_rag_language}_k{function_rag_top_k}"
elif method == "llmlingua":
method_suffix += f"_t{lingua_target_token}"
elif method == "longllmlingua":
method_suffix += f"_t{lingua_target_token}_cs{longlingua_chunk_size}_o{longlingua_overlap}"
elif method == "code_compressor":
# Determine if rank_only based on fine_ratio
rank_only_for_suffix = (code_compressor_fine_ratio == 1.0)
suffix_detail = "_rankonly" if rank_only_for_suffix else f"fr{code_compressor_fine_ratio}"
# Add importance_beta to suffix
if importance_beta > 0:
suffix_detail += f"_b{importance_beta}"
# Use code_compressor_target_token for consistency
method_suffix += f"_t{code_compressor_target_token}{suffix_detail}" # Updated suffix
method_result_dir = os.path.join(result_dir, method_suffix)
os.makedirs(method_result_dir, exist_ok=True)
model_output_path = os.path.join(
method_result_dir,
f"{model_name.replace('/', '_slash_')}.jsonl",
)
score_output_path = os.path.join(
method_result_dir,
f"{model_name.replace('/', '_slash_')}-SCORES.json",
)
all_prompts = []
original_data = [] # Store original data to merge with results
# Prepare prompts based on method
for i, example in enumerate(tqdm(dataset, desc=f"Preparing prompts for {method}")):
background_ctx = example['background_context']
current_func_ctx = example['current_function_context'] # This is the prefix
ground_truth = example['gt'] # This is the completion target
# Determine language - assuming python for now based on dataset path
language = "python" # IMPORTANT: Make dynamic if dataset contains multiple languages
context_for_prompt = ""
try:
if method == "full":
context_for_prompt = background_ctx + "\n\n" + current_func_ctx
# some models have max context length of 32768, so we truncate the context (from the head) if it exceeds that
tokenized_context = tokenizer.encode(context_for_prompt)
                if len(tokenized_context) > 32768 - 256:
                    logger.warning(f"Context length {len(tokenized_context)} exceeds the model limit; truncating from the head to {32768 - 256} tokens.")
                    context_for_prompt = tokenizer.decode(tokenized_context[-(32768 - 256):])
elif method == "rag":
if not embed_model or not embed_tokenizer:
raise ValueError("RAG method selected but embedding model not initialized.")
retrieved_ctx = rag_retrieve(
background_ctx, current_func_ctx,
embed_model, embed_tokenizer, device,
rag_window_size, rag_overlap, rag_top_k
)
context_for_prompt = retrieved_ctx + "\n\n" + current_func_ctx
elif method == "function_rag":
if not embed_model or not embed_tokenizer:
raise ValueError("Function RAG method selected but embedding model not initialized.")
retrieved_ctx = function_rag_retrieve(
background_ctx, current_func_ctx,
embed_model, embed_tokenizer, device,
function_rag_language, function_rag_top_k
)
context_for_prompt = retrieved_ctx + "\n\n" + current_func_ctx
elif method == "llmlingua":
if not lingua_compressor:
raise ValueError("LLMLingua method selected but compressor not initialized.")
compressed_ctx = compress_llmlingua(
background_ctx, current_func_ctx,
lingua_compressor, lingua_target_token, lingua_instruction
)
context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx
elif method == "longllmlingua":
if not lingua_compressor:
raise ValueError("LongLLMLingua method selected but compressor not initialized.")
compressed_ctx = compress_longllmlingua(
background_ctx, current_func_ctx,
lingua_compressor, lingua_target_token, lingua_instruction,
longlingua_chunk_size, longlingua_overlap
)
context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx
elif method == "code_compressor":
if not code_compressor_instance:
raise ValueError("CodeCompressor method selected but compressor not initialized.")
# Determine rank_only based on fine_ratio
rank_only = (code_compressor_fine_ratio == 1.0)
logger.info(f"CodeCompressor mode: {'Rank Only' if rank_only else f'Fine-grained (ratio={code_compressor_fine_ratio})'}")
# Use current_func_ctx as the query for CodeCompressor to focus retrieval
compressed_ctx = compress_code_compressor(
context=background_ctx,
query=current_func_ctx, # Query is the current function prefix
compressor=code_compressor_instance,
target_token=code_compressor_target_token,
instruction=lingua_instruction, # Reusing lingua instruction
language=language,
rank_only=rank_only, # Pass determined rank_only flag
fine_ratio=code_compressor_fine_ratio, # Pass fine_ratio
importance_beta=importance_beta, # Pass importance_beta
)
# Combine the compressed background context with the original current function context
context_for_prompt = compressed_ctx + "\n\n" + current_func_ctx
elif method == "no_context":
context_for_prompt = current_func_ctx
else:
raise ValueError(f"Unknown method: {method}")
prompt = context_for_prompt.strip()
all_prompts.append(prompt)
original_data.append({
"id": example.get("id", i),
"gt": ground_truth,
"original_background_context": background_ctx,
"original_current_function_context": current_func_ctx,
"language": language # Store language if needed later
})
except Exception as e:
logger.warning(f"Error processing example {i} (ID: {example.get('id', 'N/A')}) for method {method}: {e}", exc_info=True)
continue # Skip this example
# --- 4. Clean up Compression/Embedding Models ---
logger.info("Freeing up GPU memory from compression/embedding models")
if embed_model:
del embed_model
if embed_tokenizer:
del embed_tokenizer
if lingua_compressor:
del lingua_compressor
if code_compressor_instance:
del code_compressor_instance # Clean up CodeCompressor
torch.cuda.empty_cache()
gc.collect()
logger.info("GPU memory freed")
# --- 5. Initialize Generation LLM ---
# Check if there are any prompts to process before initializing LLM
if not all_prompts:
logger.error(f"No valid prompts were prepared for method {method}. Skipping generation and scoring.")
return
logger.info(f"Initializing generation LLM: {model_name}")
llm = LLM(
model=model_name,
trust_remote_code=trust_remote_code,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
max_model_len=32768
)
logger.info(f"Generation LLM {model_name} initialized.")
# --- 6. Generate Completions ---
all_outputs = []
logger.info(f"Generating completions for {len(all_prompts)} prompts...")
for i in tqdm(range(0, len(all_prompts), batch_size), desc=f"Generating completions for {method}"):
batch_prompts = all_prompts[i:i + batch_size]
if not batch_prompts:
continue
try:
batch_outputs = generate_completions(llm, batch_prompts, max_new_tokens=max_new_tokens)
all_outputs.extend(batch_outputs)
except Exception as e:
logger.error(f"Error during generation for batch starting at index {i}: {e}")
all_outputs.extend(["ERROR_GENERATING"] * len(batch_prompts))
# --- 7. Evaluate and Save Results ---
model_outputs_data = []
total_es = 0
total_em = 0
valid_scores = 0
if len(all_outputs) != len(original_data):
logger.warning(f"Warning: Mismatch between generated outputs ({len(all_outputs)}) and original data ({len(original_data)}). Scores might be inaccurate.")
min_len = min(len(all_outputs), len(original_data))
all_outputs = all_outputs[:min_len]
original_data = original_data[:min_len]
all_prompts = all_prompts[:min_len]
logger.info(f"Calculating scores and saving results for {len(all_outputs)} examples...")
# make sure that the path exists
os.makedirs(os.path.dirname(model_output_path), exist_ok=True)
with open(model_output_path, "w") as f_out:
for i in range(len(all_outputs)):
output = all_outputs[i]
# Ensure index is valid for original_data and all_prompts
if i >= len(original_data) or i >= len(all_prompts):
logger.error(f"Index {i} out of bounds after potential mismatch alignment. Stopping result processing.")
break
orig_data = original_data[i]
prompt = all_prompts[i]
gt = orig_data['gt']
result = {
**orig_data,
"prompt": prompt,
"output": output,
}
es = 0
em = 0
if output != "ERROR_GENERATING" and gt is not None:
try:
es = compute_ES(gt, output)
em = compute_EM(gt, output)
total_es += es
total_em += em
valid_scores += 1
except Exception as e:
logger.error(f"Error scoring example {orig_data.get('id', i)}: {e}")
result['es'] = es
result['em'] = em
model_outputs_data.append(result)
f_out.write(json.dumps(result) + "\n")
logger.info(f"Raw results saved to {model_output_path}")
avg_es = (total_es / valid_scores) if valid_scores > 0 else 0
avg_em = (total_em / valid_scores) if valid_scores > 0 else 0
# Update the parameters dictionary in scores
scores = {
"model_name": model_name,
"method": method,
"num_examples_scored": valid_scores,
"num_examples_total": len(original_data), # Use length of original_data before potential alignment issues
"average_es": avg_es,
"average_em": avg_em,
"parameters": {
"dataset_path": dataset_path,
"dataset_split": dataset_split,
"filter_current_lines_max": filter_current_lines_max,
"filter_background_tokens_min": filter_background_tokens_min,
"embed_model_name": embed_model_name if method == "rag" or method == "function_rag" else None,
# Combine compression model name reporting
"compression_model_name": compression_model_name if method in ["llmlingua", "longllmlingua", "code_compressor"] else None,
"max_new_tokens": max_new_tokens,
"batch_size": batch_size,
# RAG specific params
"rag_window_size": rag_window_size if method == "rag" else None,
"rag_overlap": rag_overlap if method == "rag" else None,
"rag_top_k": rag_top_k if method == "rag" else None,
# Function RAG params
"function_rag_language": function_rag_language if method == "function_rag" else None,
"function_rag_top_k": function_rag_top_k if method == "function_rag" else None,
# Lingua specific params (shared target token name)
"lingua_target_token": lingua_target_token if method == "llmlingua" or method == "longllmlingua" else None,
# LongLingua specific params
"longlingua_chunk_size": longlingua_chunk_size if method == "longllmlingua" else None,
"longlingua_overlap": longlingua_overlap if method == "longllmlingua" else None,
# CodeCompressor specific params
"code_compressor_target_token": code_compressor_target_token if method == "code_compressor" else None, # Added parameter
"code_compressor_rank_only": (code_compressor_fine_ratio == 1.0) if method == "code_compressor" else None, # Determined by fine_ratio
"code_compressor_fine_ratio": code_compressor_fine_ratio if method == "code_compressor" else None, # Added parameter
"importance_beta": importance_beta if method == "code_compressor" else None, # Added parameter
}
}
logger.info(f"Method {method}: Avg ES = {avg_es:.2f}, Avg EM = {avg_em:.2f} ({valid_scores}/{len(original_data)} scored)")
save_json(scores, score_output_path)
logger.info(f"Scores saved to {score_output_path}")
logger.info("Evaluation complete.")
# Clean up LLM explicitly
if 'llm' in locals() and llm is not None:
del llm
logger.info("Generation LLM deleted.")
torch.cuda.empty_cache()
gc.collect()
if __name__ == "__main__":
fire.Fire(evaluate_completion)
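
# Example invocation (illustrative flags; fire exposes each keyword argument of
# evaluate_completion as a command-line option):
#   python main.py --model_name Qwen/Qwen2.5-Coder-7B-Instruct --method code_compressor \
#       --code_compressor_target_token 2048 --code_compressor_fine_ratio 1.0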