mirror of
https://github.com/YerbaPage/LongCodeZip.git
synced 2025-10-22 23:19:46 +03:00
fix the demo
279
assets/example_context.py
Normal file
@@ -0,0 +1,279 @@
# The example file is from trae-agent

"""Base Agent class for LLM-based agents."""

class BaseAgent(ABC):
    """Base class for LLM-based agents."""

    _tool_caller: Union[ToolExecutor, DockerToolExecutor]

    def __init__(
        self, agent_config: AgentConfig, docker_config: dict | None = None, docker_keep: bool = True
    ):
        """Initialize the agent.

        Args:
            agent_config: Configuration object containing model parameters and other settings.
            docker_config: Configuration for running in a Docker environment.
        """
        self._llm_client = LLMClient(agent_config.model)
        self._model_config = agent_config.model
        self._max_steps = agent_config.max_steps
        self._initial_messages: list[LLMMessage] = []
        self._task: str = ""
        self._tools: list[Tool] = [
            tools_registry[tool_name](model_provider=self._model_config.model_provider.provider)
            for tool_name in agent_config.tools
        ]
        self.docker_keep = docker_keep
        self.docker_manager: DockerManager | None = None
        original_tool_executor = ToolExecutor(self._tools)
        if docker_config:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            # tools_dir = os.path.join(project_root, 'tools')

            tools_dir = os.path.join(project_root, "dist")

            is_interactive_mode = False
            self.docker_manager = DockerManager(
                image=docker_config.get("image"),
                container_id=docker_config.get("container_id"),
                dockerfile_path=docker_config.get("dockerfile_path"),
                docker_image_file=docker_config.get("docker_image_file"),
                workspace_dir=docker_config["workspace_dir"],
                tools_dir=tools_dir,
                interactive=is_interactive_mode,
            )
            self._tool_caller = DockerToolExecutor(
                original_executor=original_tool_executor,
                docker_manager=self.docker_manager,
                docker_tools=["bash", "str_replace_based_edit_tool", "json_edit_tool"],
                host_workspace_dir=docker_config.get("workspace_dir"),
                container_workspace_dir=self.docker_manager.container_workspace,
            )
        else:
            self._tool_caller = original_tool_executor

        self._cli_console: CLIConsole | None = None

        # Trajectory recorder
        self._trajectory_recorder: TrajectoryRecorder | None = None

        # CKG tool-specific: clear the older CKG databases
        clear_older_ckg()

    @abstractmethod
    def new_task(
        self,
        task: str,
        extra_args: dict[str, str] | None = None,
        tool_names: list[str] | None = None,
    ):
        """Create a new task."""
        pass

    async def execute_task(self) -> AgentExecution:
        """Execute a task using the agent."""
        import time

        if self.docker_manager:
            self.docker_manager.start()

        start_time = time.time()
        execution = AgentExecution(task=self._task, steps=[])
        step: AgentStep | None = None

        try:
            messages = self._initial_messages
            step_number = 1
            execution.agent_state = AgentState.RUNNING

            while step_number <= self._max_steps:
                step = AgentStep(step_number=step_number, state=AgentStepState.THINKING)
                try:
                    messages = await self._run_llm_step(step, messages, execution)
                    await self._finalize_step(
                        step, messages, execution
                    )  # record trajectory for this step and update the CLI console
                    if execution.agent_state == AgentState.COMPLETED:
                        break
                    step_number += 1
                except Exception as error:
                    execution.agent_state = AgentState.ERROR
                    step.state = AgentStepState.ERROR
                    step.error = str(error)
                    await self._finalize_step(step, messages, execution)
                    break
            if step_number > self._max_steps and not execution.success:
                execution.final_result = "Task execution exceeded maximum steps without completion."
                execution.agent_state = AgentState.ERROR

        except Exception as e:
            execution.final_result = f"Agent execution failed: {str(e)}"

        finally:
            if self.docker_manager and not self.docker_keep:
                self.docker_manager.stop()

            # Ensure tool resources are released whether an exception occurs or not.
            await self._close_tools()

        execution.execution_time = time.time() - start_time

        # Clean up any MCP clients
        with contextlib.suppress(Exception):
            await self.cleanup_mcp_clients()

        self._update_cli_console(step, execution)
        return execution

    async def _close_tools(self):
        """Release tool resources, mainly about BashTool object."""
        if self._tool_caller:
            # Ensure all tool resources are properly released.
            res = await self._tool_caller.close_tools()
            return res

    async def _run_llm_step(
        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
    ) -> list["LLMMessage"]:
        # Display thinking state
        step.state = AgentStepState.THINKING
        self._update_cli_console(step, execution)
        # Get LLM response
        llm_response = self._llm_client.chat(messages, self._model_config, self._tools)
        step.llm_response = llm_response

        # Display step with LLM response
        self._update_cli_console(step, execution)

        # Update token usage
        self._update_llm_usage(llm_response, execution)

        if self.llm_indicates_task_completed(llm_response):
            if self._is_task_completed(llm_response):
                execution.agent_state = AgentState.COMPLETED
                execution.final_result = llm_response.content
                execution.success = True
                return messages
            else:
                execution.agent_state = AgentState.RUNNING
                return [LLMMessage(role="user", content=self.task_incomplete_message())]
        else:
            tool_calls = llm_response.tool_calls
            return await self._tool_call_handler(tool_calls, step)

    async def _finalize_step(
        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
    ) -> None:
        step.state = AgentStepState.COMPLETED
        self._record_handler(step, messages)
        self._update_cli_console(step, execution)
        execution.steps.append(step)

    def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None:
        """Reflect on tool execution result. Override for custom reflection logic."""
        if len(tool_results) == 0:
            return None

        reflection = "\n".join(
            f"The tool execution failed with error: {tool_result.error}. Consider trying a different approach or fixing the parameters."
            for tool_result in tool_results
            if not tool_result.success
        )

        return reflection

    def llm_indicates_task_completed(self, llm_response: LLMResponse) -> bool:
        """Check if the LLM indicates that the task is completed. Override for custom logic."""
        completion_indicators = [
            "task completed",
            "task finished",
            "done",
            "completed successfully",
            "finished successfully",
        ]

        response_lower = llm_response.content.lower()
        return any(indicator in response_lower for indicator in completion_indicators)

    def _is_task_completed(self, llm_response: LLMResponse) -> bool:  # pyright: ignore[reportUnusedParameter]
        """Check if the task is completed based on the response. Override for custom logic."""
        return True

    def task_incomplete_message(self) -> str:
        """Return a message indicating that the task is incomplete. Override for custom logic."""
        return "The task is incomplete. Please try again."

    @abstractmethod
    async def cleanup_mcp_clients(self) -> None:
        """Clean up MCP clients. Override in subclasses that use MCP."""
        pass

    def _update_cli_console(
        self, step: AgentStep | None = None, agent_execution: AgentExecution | None = None
    ) -> None:
        if self.cli_console:
            self.cli_console.update_status(step, agent_execution)

    def _update_llm_usage(self, llm_response: LLMResponse, execution: AgentExecution):
        if not llm_response.usage:
            return
        # If no usage has been recorded yet, take this response's usage as-is;
        # otherwise accumulate it.
        if not execution.total_tokens:
            execution.total_tokens = llm_response.usage
        else:
            execution.total_tokens += llm_response.usage

    def _record_handler(self, step: AgentStep, messages: list[LLMMessage]) -> None:
        if self.trajectory_recorder:
            self.trajectory_recorder.record_agent_step(
                step_number=step.step_number,
                state=step.state.value,
                llm_messages=messages,
                llm_response=step.llm_response,
                tool_calls=step.tool_calls,
                tool_results=step.tool_results,
                reflection=step.reflection,
                error=step.error,
            )

    async def _tool_call_handler(
        self, tool_calls: list[ToolCall] | None, step: AgentStep
    ) -> list[LLMMessage]:
        messages: list[LLMMessage] = []
        if not tool_calls or len(tool_calls) <= 0:
            messages = [
                LLMMessage(
                    role="user",
                    content="It seems that you have not completed the task.",
                )
            ]
            return messages

        step.state = AgentStepState.CALLING_TOOL
        step.tool_calls = tool_calls
        self._update_cli_console(step)

        if self._model_config.parallel_tool_calls:
            tool_results = await self._tool_caller.parallel_tool_call(tool_calls)
        else:
            tool_results = await self._tool_caller.sequential_tool_call(tool_calls)
        step.tool_results = tool_results
        self._update_cli_console(step)
        for tool_result in tool_results:
            # Add tool result to conversation
            message = LLMMessage(role="user", tool_result=tool_result)
            messages.append(message)

        reflection = self.reflect_on_result(tool_results)
        if reflection:
            step.state = AgentStepState.REFLECTING
            step.reflection = reflection

            # Display reflection
            self._update_cli_console(step)

            messages.append(LLMMessage(role="assistant", content=reflection))

        return messages
86
demo.py
@@ -3,33 +3,15 @@ from loguru import logger
 
 if __name__ == "__main__":
 
-    context = """
-def add(a, b):
-    return a + b
+    with open("assets/example_context.py", "r") as f:
+        context = f.read()
 
-def quick_sort(arr):
-    if len(arr) <= 1:
-        return arr
-    pivot = arr[len(arr) // 2]
-    left = [x for x in arr if x < pivot]
-    middle = [x for x in arr if x == pivot]
-    right = [x for x in arr if x > pivot]
-    return quick_sort(left) + middle + quick_sort(right)
-
-def search_with_binary_search(arr, target):
-    left, right = 0, len(arr) - 1
-    while left <= right:
-        mid = (left + right) // 2
-        if arr[mid] == target:
-            return mid
-        elif arr[mid] < target:
-            left = mid + 1
-        else:
-            right = mid - 1
-    return -1
-    """
-
-    question = "How to write a quick sort algorithm?"
+    question = '''
+    async def _finalize_step(
+        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
+    ) -> None:
+        step.state = AgentStepState.COMPLETED
+    '''
 
     # Initialize compressor
     logger.info("Initializing compressor...")
@@ -44,48 +26,22 @@ if __name__ == "__main__":
     target_ratio = min(1.0, max(0.0, target_token / original_tokens))
     logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
 
-    result = compressor.compress_code_file(
-        code=context,
-        query=question, # Using current function context as query focus
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=False, # False to use fine-grained compression
-        fine_grained_importance_method="contrastive_perplexity", # Explicitly test default
-        min_lines_for_fine_grained=5, # Min number of lines for fine-grained compression
-        importance_beta=0.5, # Sensitivity to importance score
-        use_knapsack=True,
-    )
-
-    # show the compressed code
-    logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
-    logger.info(f"Current function context: \n{question}")
-    # final prompt
-    final_prompt = result['compressed_prompt']
-    # get the completion
-    tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
-    # Increase max_new_tokens for potentially longer completions
-    completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
-    # Decode only the generated part, skipping special tokens
-    completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
-
-    # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
-    completion = completion.strip()
-    # More robust cleanup: Find the first meaningful line if generation includes noise
-    completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal
-    cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found
-
-    logger.info(f"Cleaned Completion: {cleaned_completion}")
-
-    # Optional: Test with conditional_ppl method
-    logger.info("\nTesting fine-grained compression with conditional_ppl...")
+    logger.info("\nTesting compression with Coarse-grained compression only...")
     result_cond = compressor.compress_code_file(
         code=context,
         query=question,
         instruction="Complete the following code function given the context.",
         rate=target_ratio,
-        rank_only=False,
-        fine_grained_importance_method="conditional_ppl",
-        min_lines_for_fine_grained=5,
-        importance_beta=0.5
+        rank_only=True # Coarse-grained compression
     )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
+    logger.info(f"Compressed prompt: \n{result_cond['compressed_prompt']}")
+
+    logger.info("\nTesting compression with Coarse-grained and Fine-grained compression...")
+    result_cond = compressor.compress_code_file(
+        code=context,
+        query=question,
+        instruction="Complete the following code function given the context.",
+        rate=target_ratio,
+        rank_only=False # Coarse-grained and Fine-grained compression
+    )
+    logger.info(f"Compressed prompt: \n{result_cond['compressed_prompt']}")
@@ -1803,93 +1803,4 @@ class CodeCompressor:
                 selected.add(idx)
                 current_weight += weight
 
         return selected
-
-if __name__ == "__main__":
-
-    context = """
-def add(a, b):
-    return a + b
-
-def quick_sort(arr):
-    if len(arr) <= 1:
-        return arr
-    pivot = arr[len(arr) // 2]
-    left = [x for x in arr if x < pivot]
-    middle = [x for x in arr if x == pivot]
-    right = [x for x in arr if x > pivot]
-    return quick_sort(left) + middle + quick_sort(right)
-
-def search_with_binary_search(arr, target):
-    left, right = 0, len(arr) - 1
-    while left <= right:
-        mid = (left + right) // 2
-        if arr[mid] == target:
-            return mid
-        elif arr[mid] < target:
-            left = mid + 1
-        else:
-            right = mid - 1
-    return -1
-    """
-
-    question = "How to write a quick sort algorithm?"
-
-    # Initialize compressor
-    logger.info("Initializing compressor...")
-    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
-    compressor = CodeCompressor(model_name=model_name)
-
-    # Test function-based code file compression with query
-    logger.info("\nTesting function-based code file compression with query...")
-
-    original_tokens = len(compressor.tokenizer.encode(context))
-    target_token = 64
-    target_ratio = min(1.0, max(0.0, target_token / original_tokens))
-    logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
-
-    result = compressor.compress_code_file(
-        code=context,
-        query=question, # Using current function context as query focus
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=True, # Only use coarse-grained compression
-        fine_grained_importance_method="conditional_ppl", # Explicitly test default
-        min_lines_for_fine_grained=5, # Min number of lines for fine-grained compression
-        importance_beta=0.5, # Sensitivity to importance score
-        use_knapsack=True,
-    )
-
-    # show the compressed code
-    logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
-    logger.info(f"Current function context: \n{question}")
-    # final prompt
-    final_prompt = result['compressed_prompt']
-    # get the completion
-    tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
-    # Increase max_new_tokens for potentially longer completions
-    completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
-    # Decode only the generated part, skipping special tokens
-    completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
-
-    # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
-    completion = completion.strip()
-    # More robust cleanup: Find the first meaningful line if generation includes noise
-    completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal
-    cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found
-
-    logger.info(f"Cleaned Completion: {cleaned_completion}")
-
-    # Optional: Test with conditional_ppl method
-    logger.info("\nTesting fine-grained compression with conditional_ppl...")
-    result_cond = compressor.compress_code_file(
-        code=context,
-        query=question,
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=False,
-        fine_grained_importance_method="conditional_ppl",
-        min_lines_for_fine_grained=5,
-        importance_beta=0.5
-    )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
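Note: the snippet below is a minimal sketch (not part of the commit) of how the fixed demo is intended to be driven. It assumes CodeCompressor is importable from the repository's code_compressor module (the import path is not shown in this diff) and that the Qwen/Qwen2.5-Coder-7B-Instruct weights can be loaded locally.

from code_compressor import CodeCompressor  # assumed import path; not visible in this diff

# Load the long context that the demo now reads from the new asset file.
with open("assets/example_context.py", "r") as f:
    context = f.read()

# The query is the function whose body should be completed, as in the updated demo.
question = '''
async def _finalize_step(
    self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
) -> None:
    step.state = AgentStepState.COMPLETED
'''

compressor = CodeCompressor(model_name="Qwen/Qwen2.5-Coder-7B-Instruct")
target_token = 64
original_tokens = len(compressor.tokenizer.encode(context))
target_ratio = min(1.0, max(0.0, target_token / original_tokens))

# rank_only=True keeps only coarse-grained (function-level) selection;
# rank_only=False additionally applies fine-grained, line-level compression.
result = compressor.compress_code_file(
    code=context,
    query=question,
    instruction="Complete the following code function given the context.",
    rate=target_ratio,
    rank_only=True,
)
print(result["compressed_prompt"])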