# mirror of https://github.com/YerbaPage/LongCodeZip.git
# synced 2025-10-22 23:19:46 +03:00
# 279 lines, 11 KiB, Python
# The example file is from trae-agent
|
|
|
|
"""Base Agent class for LLM-based agents."""
|
|
|
|
class BaseAgent(ABC):
    """Base class for LLM-based agents."""

    # Executes tool calls either on the host or inside a Docker container.
    _tool_caller: Union[ToolExecutor, DockerToolExecutor]

    def __init__(
        self, agent_config: AgentConfig, docker_config: dict | None = None, docker_keep: bool = True
    ):
        """Initialize the agent.

        Args:
            agent_config: Configuration object containing model parameters and other settings.
            docker_config: Configuration for running in a Docker environment. When
                provided, workspace-touching tools are routed into a container;
                "workspace_dir" is a required key.
            docker_keep: If True, the Docker container (if any) is kept running
                after task execution instead of being stopped.
        """
        self._llm_client = LLMClient(agent_config.model)
        self._model_config = agent_config.model
        self._max_steps = agent_config.max_steps
        self._initial_messages: list[LLMMessage] = []
        self._task: str = ""
        self._tools: list[Tool] = [
            tools_registry[tool_name](model_provider=self._model_config.model_provider.provider)
            for tool_name in agent_config.tools
        ]
        self.docker_keep = docker_keep
        self.docker_manager: DockerManager | None = None

        original_tool_executor = ToolExecutor(self._tools)
        if docker_config:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            # Tool binaries are shipped from "dist" (not the "tools" source dir).
            # tools_dir = os.path.join(project_root, 'tools')
            tools_dir = os.path.join(project_root, "dist")

            is_interactive_mode = False
            self.docker_manager = DockerManager(
                image=docker_config.get("image"),
                container_id=docker_config.get("container_id"),
                dockerfile_path=docker_config.get("dockerfile_path"),
                docker_image_file=docker_config.get("docker_image_file"),
                workspace_dir=docker_config["workspace_dir"],
                tools_dir=tools_dir,
                interactive=is_interactive_mode,
            )
            self._tool_caller = DockerToolExecutor(
                original_executor=original_tool_executor,
                docker_manager=self.docker_manager,
                docker_tools=["bash", "str_replace_based_edit_tool", "json_edit_tool"],
                # Indexed access for consistency with the DockerManager call above:
                # "workspace_dir" is required, so fail fast instead of passing None.
                host_workspace_dir=docker_config["workspace_dir"],
                container_workspace_dir=self.docker_manager.container_workspace,
            )
        else:
            self._tool_caller = original_tool_executor

        self._cli_console: CLIConsole | None = None

        # Trajectory recorder
        self._trajectory_recorder: TrajectoryRecorder | None = None

        # CKG tool-specific: clear the older CKG databases
        clear_older_ckg()
@abstractmethod
def new_task(
    self,
    task: str,
    extra_args: dict[str, str] | None = None,
    tool_names: list[str] | None = None,
):
    """Set up a fresh task for this agent.

    Args:
        task: The task to execute (free-form text).
        extra_args: Optional extra task parameters.
        tool_names: Optional explicit list of tool names to enable.
    """
async def execute_task(self) -> AgentExecution:
    """Execute a task using the agent.

    Runs the think/act loop up to ``self._max_steps`` times, finalizing each
    step (trajectory recording + CLI update) as it goes. Docker and tool
    resources are released in the ``finally`` block regardless of outcome.

    Returns:
        The populated AgentExecution record (steps, state, result, timing).
    """
    import time

    # Bring the container up first so tool calls can be routed into it.
    if self.docker_manager:
        self.docker_manager.start()

    start_time = time.time()
    execution = AgentExecution(task=self._task, steps=[])
    step: AgentStep | None = None

    try:
        messages = self._initial_messages
        step_number = 1
        execution.agent_state = AgentState.RUNNING

        while step_number <= self._max_steps:
            step = AgentStep(step_number=step_number, state=AgentStepState.THINKING)
            try:
                messages = await self._run_llm_step(step, messages, execution)
                await self._finalize_step(
                    step, messages, execution
                )  # record trajectory for this step and update the CLI console
                if execution.agent_state == AgentState.COMPLETED:
                    break
                step_number += 1
            except Exception as error:
                # A failed step is still finalized so it gets recorded/displayed.
                execution.agent_state = AgentState.ERROR
                step.state = AgentStepState.ERROR
                step.error = str(error)
                await self._finalize_step(step, messages, execution)
                break
        if step_number > self._max_steps and not execution.success:
            execution.final_result = "Task execution exceeded maximum steps without completion."
            execution.agent_state = AgentState.ERROR

    except Exception as e:
        # NOTE(review): agent_state is not set to ERROR on this path — confirm
        # whether callers rely on final_result alone to detect this failure.
        execution.final_result = f"Agent execution failed: {str(e)}"

    finally:
        # Stop the container only when the caller did not ask to keep it alive.
        if self.docker_manager and not self.docker_keep:
            self.docker_manager.stop()

        # Ensure tool resources are released whether an exception occurs or not.
        await self._close_tools()

        execution.execution_time = time.time() - start_time

        # Clean up any MCP clients
        with contextlib.suppress(Exception):
            await self.cleanup_mcp_clients()

    self._update_cli_console(step, execution)
    return execution
async def _close_tools(self):
|
|
"""Release tool resources, mainly about BashTool object."""
|
|
if self._tool_caller:
|
|
# Ensure all tool resources are properly released.
|
|
res = await self._tool_caller.close_tools()
|
|
return res
|
|
|
|
async def _run_llm_step(
    self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
) -> list["LLMMessage"]:
    """Run one LLM turn: query the model, then either finish or dispatch tools.

    Returns the messages to feed into the next step: the unchanged input on
    verified completion, a retry prompt when the completion claim fails
    verification, or the tool-result messages otherwise.
    """
    # Display thinking state
    step.state = AgentStepState.THINKING
    self._update_cli_console(step, execution)
    # Get LLM response
    llm_response = self._llm_client.chat(messages, self._model_config, self._tools)
    step.llm_response = llm_response

    # Display step with LLM response
    self._update_cli_console(step, execution)

    # Update token usage
    self._update_llm_usage(llm_response, execution)

    if self.llm_indicates_task_completed(llm_response):
        if self._is_task_completed(llm_response):
            # Verified completion: mark the execution successful and stop.
            execution.agent_state = AgentState.COMPLETED
            execution.final_result = llm_response.content
            execution.success = True
            return messages
        else:
            # The LLM claimed completion but verification failed; push back.
            execution.agent_state = AgentState.RUNNING
            return [LLMMessage(role="user", content=self.task_incomplete_message())]
    else:
        # Not done yet: execute whatever tool calls the model requested.
        tool_calls = llm_response.tool_calls
        return await self._tool_call_handler(tool_calls, step)
async def _finalize_step(
    self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
) -> None:
    """Record a finished step and publish it to the recorder, console, and execution.

    Args:
        step: The step to finalize (state may already be ERROR).
        messages: Conversation messages associated with this step.
        execution: Execution record the step is appended to.
    """
    # Preserve an ERROR state set by the caller's error path; previously the
    # state was unconditionally overwritten to COMPLETED, so failed steps were
    # recorded in the trajectory as if they had succeeded.
    if step.state != AgentStepState.ERROR:
        step.state = AgentStepState.COMPLETED
    self._record_handler(step, messages)
    self._update_cli_console(step, execution)
    execution.steps.append(step)
def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None:
    """Build reflection text from failed tool results. Override for custom reflection logic."""
    if not tool_results:
        return None

    # Collect one note per failed tool call; successes contribute nothing.
    failure_notes = []
    for result in tool_results:
        if result.success:
            continue
        failure_notes.append(
            f"The tool execution failed with error: {result.error}. Consider trying a different approach or fixing the parameters."
        )

    return "\n".join(failure_notes)
def llm_indicates_task_completed(self, llm_response: LLMResponse) -> bool:
    """Check if the LLM indicates that the task is completed. Override for custom logic.

    Performs a case-insensitive substring scan of the response content for a
    fixed set of completion phrases.

    Args:
        llm_response: The latest LLM response to inspect.

    Returns:
        True if any completion-indicator phrase occurs in the response content.
    """
    completion_indicators = [
        "task completed",
        "task finished",
        "done",
        "completed successfully",
        "finished successfully",
    ]

    # Guard against a response with no textual content (e.g. a tool-call-only
    # response whose `content` may be None) instead of raising AttributeError.
    response_lower = (llm_response.content or "").lower()
    return any(indicator in response_lower for indicator in completion_indicators)
def _is_task_completed(self, llm_response: LLMResponse) -> bool:  # pyright: ignore[reportUnusedParameter]
    """Verify a completion claim from the response. Override for custom logic.

    The base implementation trusts the LLM's own completion signal and always
    returns True; subclasses may add real verification.
    """
    return True
def task_incomplete_message(self) -> str:
    """Produce the feedback prompt used when the task is judged incomplete.

    Override for custom retry messaging.
    """
    retry_prompt = "The task is incomplete. Please try again."
    return retry_prompt
@abstractmethod
async def cleanup_mcp_clients(self) -> None:
    """Tear down any MCP client connections.

    Abstract hook: subclasses that maintain MCP clients implement the real
    cleanup; others may implement it as a no-op.
    """
def _update_cli_console(
    self, step: AgentStep | None = None, agent_execution: AgentExecution | None = None
) -> None:
    """Push the latest step/execution state to the CLI console, if one is attached."""
    console = self.cli_console
    if not console:
        # Headless run: no console attached, nothing to display.
        return
    console.update_status(step, agent_execution)
def _update_llm_usage(self, llm_response: LLMResponse, execution: AgentExecution):
    """Accumulate token usage from an LLM response onto the execution record.

    Args:
        llm_response: Response whose `usage` (if present) is counted.
        execution: Execution record holding the running `total_tokens`.
    """
    if not llm_response.usage:
        # No usage info on this response; nothing to record.
        return
    # The first recorded usage seeds the counter; later ones are summed on.
    # Explicit None check (matching the stated intent of the original inline
    # comment) so a zero-valued usage object is accumulated, not replaced.
    if execution.total_tokens is None:
        execution.total_tokens = llm_response.usage
    else:
        execution.total_tokens += llm_response.usage
def _record_handler(self, step: AgentStep, messages: list[LLMMessage]) -> None:
    """Forward a finished step to the trajectory recorder, when one is configured."""
    recorder = self.trajectory_recorder
    if not recorder:
        return
    recorder.record_agent_step(
        step_number=step.step_number,
        state=step.state.value,
        llm_messages=messages,
        llm_response=step.llm_response,
        tool_calls=step.tool_calls,
        tool_results=step.tool_results,
        reflection=step.reflection,
        error=step.error,
    )
async def _tool_call_handler(
    self, tool_calls: list[ToolCall] | None, step: AgentStep
) -> list[LLMMessage]:
    """Execute the LLM's tool calls and turn the results into follow-up messages."""
    # No tool calls at all: nudge the LLM to keep working on the task.
    if not tool_calls:
        return [
            LLMMessage(
                role="user",
                content="It seems that you have not completed the task.",
            )
        ]

    step.state = AgentStepState.CALLING_TOOL
    step.tool_calls = tool_calls
    self._update_cli_console(step)

    # Pick the execution strategy the model config asks for.
    if self._model_config.parallel_tool_calls:
        run_calls = self._tool_caller.parallel_tool_call
    else:
        run_calls = self._tool_caller.sequential_tool_call
    results = await run_calls(tool_calls)
    step.tool_results = results
    self._update_cli_console(step)

    # Feed every tool result back into the conversation.
    follow_ups = [LLMMessage(role="user", tool_result=result) for result in results]

    reflection = self.reflect_on_result(results)
    if reflection:
        step.state = AgentStepState.REFLECTING
        step.reflection = reflection

        # Display reflection
        self._update_cli_console(step)

        follow_ups.append(LLMMessage(role="assistant", content=reflection))

    return follow_ups