diff --git a/assets/example_context.py b/assets/example_context.py
new file mode 100644
index 0000000..841e578
--- /dev/null
+++ b/assets/example_context.py
@@ -0,0 +1,279 @@
+# This example file is from trae-agent. demo.py only reads it as plain-text
+# context for compression, so the imports and trae-agent modules it references
+# are intentionally omitted.
+
+"""Base Agent class for LLM-based agents."""
+
+class BaseAgent(ABC):
+    """Base class for LLM-based agents."""
+
+    _tool_caller: Union[ToolExecutor, DockerToolExecutor]
+
+    def __init__(
+        self, agent_config: AgentConfig, docker_config: dict | None = None, docker_keep: bool = True
+    ):
+        """Initialize the agent.
+
+        Args:
+            agent_config: Configuration object containing model parameters and other settings.
+            docker_config: Configuration for running in a Docker environment.
+            docker_keep: Whether to keep the Docker container after the task finishes.
+        """
+        self._llm_client = LLMClient(agent_config.model)
+        self._model_config = agent_config.model
+        self._max_steps = agent_config.max_steps
+        self._initial_messages: list[LLMMessage] = []
+        self._task: str = ""
+        self._tools: list[Tool] = [
+            tools_registry[tool_name](model_provider=self._model_config.model_provider.provider)
+            for tool_name in agent_config.tools
+        ]
+        self.docker_keep = docker_keep
+        self.docker_manager: DockerManager | None = None
+        original_tool_executor = ToolExecutor(self._tools)
+        if docker_config:
+            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+            # tools_dir = os.path.join(project_root, 'tools')
+
+            tools_dir = os.path.join(project_root, "dist")
+
+            is_interactive_mode = False
+            self.docker_manager = DockerManager(
+                image=docker_config.get("image"),
+                container_id=docker_config.get("container_id"),
+                dockerfile_path=docker_config.get("dockerfile_path"),
+                docker_image_file=docker_config.get("docker_image_file"),
+                workspace_dir=docker_config["workspace_dir"],
+                tools_dir=tools_dir,
+                interactive=is_interactive_mode,
+            )
+            self._tool_caller = DockerToolExecutor(
+                original_executor=original_tool_executor,
+                docker_manager=self.docker_manager,
+                docker_tools=["bash", "str_replace_based_edit_tool", "json_edit_tool"],
+                host_workspace_dir=docker_config.get("workspace_dir"),
+                container_workspace_dir=self.docker_manager.container_workspace,
+            )
+        else:
+            self._tool_caller = original_tool_executor
+
+        self._cli_console: CLIConsole | None = None
+
+        # Trajectory recorder
+        self._trajectory_recorder: TrajectoryRecorder | None = None
+
+        # CKG tool-specific: clear the older CKG databases
+        clear_older_ckg()
+
+    @abstractmethod
+    def new_task(
+        self,
+        task: str,
+        extra_args: dict[str, str] | None = None,
+        tool_names: list[str] | None = None,
+    ):
+        """Create a new task."""
+        pass
+
+    async def execute_task(self) -> AgentExecution:
+        """Execute a task using the agent."""
+        import time
+
+        if self.docker_manager:
+            self.docker_manager.start()
+
+        start_time = time.time()
+        execution = AgentExecution(task=self._task, steps=[])
+        step: AgentStep | None = None
+
+        try:
+            messages = self._initial_messages
+            step_number = 1
+            execution.agent_state = AgentState.RUNNING
+
+            while step_number <= self._max_steps:
+                step = AgentStep(step_number=step_number, state=AgentStepState.THINKING)
+                try:
+                    messages = await self._run_llm_step(step, messages, execution)
+                    await self._finalize_step(
+                        step, messages, execution
+                    )  # record trajectory for this step and update the CLI console
+                    if execution.agent_state == AgentState.COMPLETED:
+                        break
+                    step_number += 1
+                except Exception as error:
+                    execution.agent_state = AgentState.ERROR
+                    step.state = AgentStepState.ERROR
+                    step.error = str(error)
+                    await self._finalize_step(step, messages, execution)
+                    break
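+            # All steps were consumed without reaching COMPLETED, so surface the
+            # overrun as an explicit error rather than a silent partial result.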
+            if step_number > self._max_steps and not execution.success:
+                execution.final_result = "Task execution exceeded maximum steps without completion."
+                execution.agent_state = AgentState.ERROR
+
+        except Exception as e:
+            execution.final_result = f"Agent execution failed: {str(e)}"
+
+        finally:
+            if self.docker_manager and not self.docker_keep:
+                self.docker_manager.stop()
+
+            # Ensure tool resources are released whether an exception occurs or not.
+            await self._close_tools()
+
+            execution.execution_time = time.time() - start_time
+
+            # Clean up any MCP clients
+            with contextlib.suppress(Exception):
+                await self.cleanup_mcp_clients()
+
+        self._update_cli_console(step, execution)
+        return execution
+
+    async def _close_tools(self):
+        """Release tool resources, mainly the BashTool objects."""
+        if self._tool_caller:
+            # Ensure all tool resources are properly released.
+            res = await self._tool_caller.close_tools()
+            return res
+
+    async def _run_llm_step(
+        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
+    ) -> list["LLMMessage"]:
+        # Display thinking state
+        step.state = AgentStepState.THINKING
+        self._update_cli_console(step, execution)
+        # Get LLM response
+        llm_response = self._llm_client.chat(messages, self._model_config, self._tools)
+        step.llm_response = llm_response
+
+        # Display step with LLM response
+        self._update_cli_console(step, execution)
+
+        # Update token usage
+        self._update_llm_usage(llm_response, execution)
+
+        if self.llm_indicates_task_completed(llm_response):
+            if self._is_task_completed(llm_response):
+                execution.agent_state = AgentState.COMPLETED
+                execution.final_result = llm_response.content
+                execution.success = True
+                return messages
+            else:
+                execution.agent_state = AgentState.RUNNING
+                return [LLMMessage(role="user", content=self.task_incomplete_message())]
+        else:
+            tool_calls = llm_response.tool_calls
+            return await self._tool_call_handler(tool_calls, step)
+
+    async def _finalize_step(
+        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
+    ) -> None:
+        step.state = AgentStepState.COMPLETED
+        self._record_handler(step, messages)
+        self._update_cli_console(step, execution)
+        execution.steps.append(step)
+
+    def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None:
+        """Reflect on tool execution results. Override for custom reflection logic."""
+        if len(tool_results) == 0:
+            return None
+
+        reflection = "\n".join(
+            f"The tool execution failed with error: {tool_result.error}. Consider trying a different approach or fixing the parameters."
+            for tool_result in tool_results
+            if not tool_result.success
+        )
+
+        return reflection
+
+    def llm_indicates_task_completed(self, llm_response: LLMResponse) -> bool:
+        """Check if the LLM indicates that the task is completed. Override for custom logic."""
+        completion_indicators = [
+            "task completed",
+            "task finished",
+            "done",
+            "completed successfully",
+            "finished successfully",
+        ]
+
+        response_lower = llm_response.content.lower()
+        return any(indicator in response_lower for indicator in completion_indicators)
+
+    def _is_task_completed(self, llm_response: LLMResponse) -> bool:  # pyright: ignore[reportUnusedParameter]
+        """Check if the task is completed based on the response. Override for custom logic."""
+        return True
+
+    def task_incomplete_message(self) -> str:
+        """Return a message indicating that the task is incomplete. Override for custom logic."""
+        return "The task is incomplete. Please try again."
+
+    @abstractmethod
+    async def cleanup_mcp_clients(self) -> None:
+        """Clean up MCP clients. Override in subclasses that use MCP."""
+        pass
+
+    def _update_cli_console(
+        self, step: AgentStep | None = None, agent_execution: AgentExecution | None = None
+    ) -> None:
+        if self.cli_console:
+            self.cli_console.update_status(step, agent_execution)
+
+    def _update_llm_usage(self, llm_response: LLMResponse, execution: AgentExecution):
+        if not llm_response.usage:
+            return
+        # Initialize execution.total_tokens from the first usage report, then accumulate.
+        if not execution.total_tokens:
+            execution.total_tokens = llm_response.usage
+        else:
+            execution.total_tokens += llm_response.usage
+
+    def _record_handler(self, step: AgentStep, messages: list[LLMMessage]) -> None:
+        if self.trajectory_recorder:
+            self.trajectory_recorder.record_agent_step(
+                step_number=step.step_number,
+                state=step.state.value,
+                llm_messages=messages,
+                llm_response=step.llm_response,
+                tool_calls=step.tool_calls,
+                tool_results=step.tool_results,
+                reflection=step.reflection,
+                error=step.error,
+            )
+
+    async def _tool_call_handler(
+        self, tool_calls: list[ToolCall] | None, step: AgentStep
+    ) -> list[LLMMessage]:
+        messages: list[LLMMessage] = []
+        if not tool_calls or len(tool_calls) <= 0:
+            messages = [
+                LLMMessage(
+                    role="user",
+                    content="It seems that you have not completed the task.",
+                )
+            ]
+            return messages
+
+        step.state = AgentStepState.CALLING_TOOL
+        step.tool_calls = tool_calls
+        self._update_cli_console(step)
+
+        if self._model_config.parallel_tool_calls:
+            tool_results = await self._tool_caller.parallel_tool_call(tool_calls)
+        else:
+            tool_results = await self._tool_caller.sequential_tool_call(tool_calls)
+        step.tool_results = tool_results
+        self._update_cli_console(step)
+        for tool_result in tool_results:
+            # Add tool result to conversation
+            message = LLMMessage(role="user", tool_result=tool_result)
+            messages.append(message)
+
+        reflection = self.reflect_on_result(tool_results)
+        if reflection:
+            step.state = AgentStepState.REFLECTING
+            step.reflection = reflection
+
+            # Display reflection
+            self._update_cli_console(step)
+
+            messages.append(LLMMessage(role="assistant", content=reflection))
+
+        return messages
\ No newline at end of file
diff --git a/demo.py b/demo.py
index 1935264..8371669 100644
--- a/demo.py
+++ b/demo.py
@@ -3,33 +3,15 @@ from loguru import logger
 
 if __name__ == "__main__":
-    context = """
-    def add(a, b):
-        return a + b
+    with open("assets/example_context.py", "r") as f:
+        context = f.read()
-
-    def quick_sort(arr):
-        if len(arr) <= 1:
-            return arr
-        pivot = arr[len(arr) // 2]
-        left = [x for x in arr if x < pivot]
-        middle = [x for x in arr if x == pivot]
-        right = [x for x in arr if x > pivot]
-        return quick_sort(left) + middle + quick_sort(right)
-
-    def search_with_binary_search(arr, target):
-        left, right = 0, len(arr) - 1
-        while left <= right:
-            mid = (left + right) // 2
-            if arr[mid] == target:
-                return mid
-            elif arr[mid] < target:
-                left = mid + 1
-            else:
-                right = mid - 1
-        return -1
-    """
-
-    question = "How to write a quick sort algorithm?"
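+    # The query is now a code snippet (the signature of the function to be
+    # completed), so compression is steered toward context relevant to it.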
+    question = '''
+    async def _finalize_step(
+        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
+    ) -> None:
+        step.state = AgentStepState.COMPLETED
+    '''
 
     # Initialize compressor
     logger.info("Initializing compressor...")
@@ -44,48 +26,22 @@ if __name__ == "__main__":
     target_ratio = min(1.0, max(0.0, target_token / original_tokens))
     logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
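+    # e.g. a context of ~2,000 tokens with target_token=1000 gives a ratio of 0.5;
+    # the min/max clamp keeps the ratio within [0.0, 1.0] for very short inputs.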
-    result = compressor.compress_code_file(
-        code=context,
-        query=question, # Using current function context as query focus
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=False, # False to use fine-grained compression
-        fine_grained_importance_method="contrastive_perplexity", # Explicitly test default
-        min_lines_for_fine_grained=5, # Min number of lines for fine-grained compression
-        importance_beta=0.5, # Sensitivity to importance score
-        use_knapsack=True,
-    )
-
-    # show the compressed code
-    logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
-    logger.info(f"Current function context: \n{question}")
-    # final prompt
-    final_prompt = result['compressed_prompt']
-    # get the completion
-    tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
-    # Increase max_new_tokens for potentially longer completions
-    completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
-    # Decode only the generated part, skipping special tokens
-    completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
-
-    # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
-    completion = completion.strip()
-    # More robust cleanup: Find the first meaningful line if generation includes noise
-    completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal
-    cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found
-
-    logger.info(f"Cleaned Completion: {cleaned_completion}")
-
-    # Optional: Test with conditional_ppl method
-    logger.info("\nTesting fine-grained compression with conditional_ppl...")
+    logger.info("\nTesting compression with coarse-grained compression only...")
     result_cond = compressor.compress_code_file(
         code=context,
         query=question,
         instruction="Complete the following code function given the context.",
         rate=target_ratio,
-        rank_only=False,
-        fine_grained_importance_method="conditional_ppl",
-        min_lines_for_fine_grained=5,
-        importance_beta=0.5
+        rank_only=True  # coarse-grained compression only
     )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
\ No newline at end of file
+    logger.info(f"Compressed prompt: \n{result_cond['compressed_prompt']}")
+
+    logger.info("\nTesting compression with coarse-grained and fine-grained compression...")
+    result_cond = compressor.compress_code_file(
+        code=context,
+        query=question,
+        instruction="Complete the following code function given the context.",
+        rate=target_ratio,
+        rank_only=False  # coarse-grained and fine-grained compression
+    )
+    logger.info(f"Compressed prompt: \n{result_cond['compressed_prompt']}")
\ No newline at end of file
diff --git a/longcodezip.py b/longcodezip.py
index 4625286..04721ea 100644
--- a/longcodezip.py
+++ b/longcodezip.py
@@ -1803,93 +1803,4 @@ class CodeCompressor:
                 selected.add(idx)
                 current_weight += weight
 
-        return selected
-
-if __name__ == "__main__":
-
-    context = """
-    def add(a, b):
-        return a + b
-
-    def quick_sort(arr):
-        if len(arr) <= 1:
-            return arr
-        pivot = arr[len(arr) // 2]
-        left = [x for x in arr if x < pivot]
-        middle = [x for x in arr if x == pivot]
-        right = [x for x in arr if x > pivot]
-        return quick_sort(left) + middle + quick_sort(right)
-
-    def search_with_binary_search(arr, target):
-        left, right = 0, len(arr) - 1
-        while left <= right:
-            mid = (left + right) // 2
-            if arr[mid] == target:
-                return mid
-            elif arr[mid] < target:
-                left = mid + 1
-            else:
-                right = mid - 1
-        return -1
-    """
-
-    question = "How to write a quick sort algorithm?"
-
-    # Initialize compressor
-    logger.info("Initializing compressor...")
-    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
-    compressor = CodeCompressor(model_name=model_name)
-
-    # Test function-based code file compression with query
-    logger.info("\nTesting function-based code file compression with query...")
-
-    original_tokens = len(compressor.tokenizer.encode(context))
-    target_token = 64
-    target_ratio = min(1.0, max(0.0, target_token / original_tokens))
-    logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
-
-    result = compressor.compress_code_file(
-        code=context,
-        query=question, # Using current function context as query focus
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=True, # Only use coarse-grained compression
-        fine_grained_importance_method="conditional_ppl", # Explicitly test default
-        min_lines_for_fine_grained=5, # Min number of lines for fine-grained compression
-        importance_beta=0.5, # Sensitivity to importance score
-        use_knapsack=True,
-    )
-
-    # show the compressed code
-    logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
-    logger.info(f"Current function context: \n{question}")
-    # final prompt
-    final_prompt = result['compressed_prompt']
-    # get the completion
-    tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
-    # Increase max_new_tokens for potentially longer completions
-    completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
-    # Decode only the generated part, skipping special tokens
-    completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
-
-    # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
-    completion = completion.strip()
-    # More robust cleanup: Find the first meaningful line if generation includes noise
-    completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))] # Simple comment removal
-    cleaned_completion = completion_lines[0] if completion_lines else completion # Take first non-comment line or original if none found
-
-    logger.info(f"Cleaned Completion: {cleaned_completion}")
-
-    # Optional: Test with conditional_ppl method
-    logger.info("\nTesting fine-grained compression with conditional_ppl...")
-    result_cond = compressor.compress_code_file(
-        code=context,
-        query=question,
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=False,
-        fine_grained_importance_method="conditional_ppl",
-        min_lines_for_fine_grained=5,
-        importance_beta=0.5
-    )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
\ No newline at end of file
+        return selected
\ No newline at end of file