mirror of https://github.com/YerbaPage/LongCodeZip.git
synced 2025-10-22 23:19:46 +03:00

Merge pull request #3 from STEVENTAN100/main

Fixed code in experiments/module-summarization.
@@ -213,7 +213,7 @@ class CodeCompressor:
         self.load_model(model_name, device_map, model_config)

         logger.debug("Initializing Entropy chunking...")
-        self.ppl_chunking = EntropyChunking()
+        self.entropy_chunking = EntropyChunking()

         # Add caching system for model outputs and token information
         self.cache = {
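The rename fixes a name mismatch: the chunker was assigned to self.ppl_chunking but presumably read back as self.entropy_chunking elsewhere in the class, which only fails at call time. A minimal, self-contained sketch of that failure mode (the EntropyChunking stub below is illustrative, not the repo's actual implementation):

```python
class EntropyChunking:
    """Illustrative stand-in for the real chunker; splits code on blank lines."""
    def chunk(self, code: str) -> list[str]:
        return [block for block in code.split("\n\n") if block.strip()]


class CodeCompressor:
    def __init__(self):
        self.ppl_chunking = EntropyChunking()  # bug: assigned under the old name

    def compress(self, code: str) -> list[str]:
        # ...but read under the new name -> AttributeError at call time,
        # which the one-line rename above eliminates
        return self.entropy_chunking.chunk(code)


try:
    CodeCompressor().compress("def f():\n    pass\n\ndef g():\n    pass")
except AttributeError as err:
    print(err)  # 'CodeCompressor' object has no attribute 'entropy_chunking'
```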
@@ -576,7 +576,9 @@ class CodeCompressor:
         start_time = time.time()

         if not context_list:
-            return [], [], []
+            # Always return a 4-tuple: (selected_contexts, used_indices, dynamic_ratio, demonstrations_sort)
+            # Keep API consistent for callers that unpack 4 values
+            return [], [], [], []

         # Get token counts for each context
         logger.debug("Calculating token lengths for contexts")
@@ -588,7 +590,9 @@ class CodeCompressor:
             logger.debug(f"All contexts fit within budget ({total_tokens} <= {target_token})")
             end_time = time.time()
             logger.debug(f"Context budget control completed in {end_time - start_time:.2f} seconds")
-            return context_list, list(range(len(context_list))), [0.0] * len(context_list)
+            # Build a default demonstrations_sort with zero scores to preserve structure
+            demonstrations_sort = list(zip(range(len(context_list)), [0.0] * len(context_list)))
+            return context_list, list(range(len(context_list))), [0.0] * len(context_list), demonstrations_sort

         # Rank contexts by relevance if question is provided
         logger.debug("Ranking contexts by relevance")
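Together these two hunks give the budget-control method a single return contract: every exit path yields a 4-tuple of (selected_contexts, used_indices, dynamic_ratio, demonstrations_sort), so callers can unpack four values unconditionally. A minimal sketch of that contract (the class and method names below are stand-ins, since the enclosing method's name is not visible in the hunks):

```python
class StubCompressor:
    """Stand-in that mirrors the patched return contract."""

    def control_context_budget(self, context_list, target_token):
        if not context_list:
            return [], [], [], []  # same arity as the success path
        # Default demonstrations_sort: each index paired with a zero score
        demonstrations_sort = list(zip(range(len(context_list)),
                                       [0.0] * len(context_list)))
        return (context_list,
                list(range(len(context_list))),
                [0.0] * len(context_list),
                demonstrations_sort)


# Unpacking is safe whether or not any contexts were supplied
for contexts in (["def a(): ...", "def b(): ..."], []):
    selected, indices, ratios, demonstrations_sort = \
        StubCompressor().control_context_budget(contexts, target_token=512)
    assert len(demonstrations_sort) == len(selected)
```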
The last hunk in this file changes nothing but the missing newline at end of file:

@@ -1883,4 +1887,4 @@ if __name__ == "__main__":
         min_lines_for_fine_grained=5,
         importance_beta=0.5
     )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
\ No newline at end of file
+    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
The remaining hunks apply to a second file, which wraps vLLM:
@@ -19,7 +19,7 @@ from itertools import cycle
 class LLMGenerator:
     def __init__(self, model_name, device, **model_args):
         # Create a vllm LLM instance
-        engine_args = EngineArgs(model=model_name, device=device, **model_args)
+        engine_args = EngineArgs(model=model_name, gpu_memory_utilization=0.8, device=device, **model_args)
         self.model = LLM(**vars(engine_args))
         self.model_name = model_name
         self.device = device
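gpu_memory_utilization caps the fraction of GPU memory vLLM pre-allocates for weights and KV cache; its default is 0.9, so pinning it to 0.8 leaves headroom when two engines (here LLMGenerator and LLMScorer) or other processes share one device. A standalone sketch of the patched construction, with an illustrative model name:

```python
from vllm import LLM, EngineArgs

# Same pattern as the patched constructors: cap vLLM's memory pool at 80%
engine_args = EngineArgs(
    model="Qwen/Qwen2.5-Coder-1.5B-Instruct",  # illustrative; any HF model id
    gpu_memory_utilization=0.8,
    device="cuda",
)
model = LLM(**vars(engine_args))
print(model.generate(["def hello_world():"])[0].outputs[0].text)
```

The trade-off: a lower cap shrinks the KV cache, so very long prompts or large batches may need the value raised again.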
@@ -58,7 +58,7 @@ class LLMGenerator:
 class LLMScorer:
     def __init__(self, model_name, device, **model_args):
         # Create a vllm LLM instance
-        engine_args = EngineArgs(model=model_name, device=device, **model_args)
+        engine_args = EngineArgs(model=model_name, gpu_memory_utilization=0.8, device=device, **model_args)
         self.model = LLM(**vars(engine_args))
         self.model_name = model_name
         self.device = device
The same trailing-newline fix closes out the second file:

@@ -1315,4 +1315,4 @@ def run_documentation_task(

 if __name__ == "__main__":

-    fire.Fire(run_documentation_task)
\ No newline at end of file
+    fire.Fire(run_documentation_task)