fix the demo

YerbaPage
2025-10-11 21:09:19 +08:00
parent 61429182d1
commit 60201d365f
3 changed files with 301 additions and 155 deletions

assets/example_context.py (new file, +279 lines)

@@ -0,0 +1,279 @@
# The example file is from trae-agent
"""Base Agent class for LLM-based agents."""

# NOTE: the trae-agent-specific names used below (AgentConfig, LLMClient,
# LLMMessage, Tool, ToolExecutor, DockerToolExecutor, DockerManager,
# tools_registry, clear_older_ckg, and friends) are assumed to be importable
# from the surrounding package; only the standard-library imports this
# excerpt visibly needs are listed here.
import contextlib
import os
from abc import ABC, abstractmethod
from typing import Union


class BaseAgent(ABC):
    """Base class for LLM-based agents."""

    _tool_caller: Union[ToolExecutor, DockerToolExecutor]

    def __init__(
        self, agent_config: AgentConfig, docker_config: dict | None = None, docker_keep: bool = True
    ):
        """Initialize the agent.

        Args:
            agent_config: Configuration object containing model parameters and other settings.
            docker_config: Configuration for running in a Docker environment.
            docker_keep: Whether to keep the Docker container running after execution.
        """
        self._llm_client = LLMClient(agent_config.model)
        self._model_config = agent_config.model
        self._max_steps = agent_config.max_steps
        self._initial_messages: list[LLMMessage] = []
        self._task: str = ""
        self._tools: list[Tool] = [
            tools_registry[tool_name](model_provider=self._model_config.model_provider.provider)
            for tool_name in agent_config.tools
        ]
        self.docker_keep = docker_keep
        self.docker_manager: DockerManager | None = None
        original_tool_executor = ToolExecutor(self._tools)
        if docker_config:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            # tools_dir = os.path.join(project_root, 'tools')
            tools_dir = os.path.join(project_root, "dist")
            is_interactive_mode = False
            self.docker_manager = DockerManager(
                image=docker_config.get("image"),
                container_id=docker_config.get("container_id"),
                dockerfile_path=docker_config.get("dockerfile_path"),
                docker_image_file=docker_config.get("docker_image_file"),
                workspace_dir=docker_config["workspace_dir"],
                tools_dir=tools_dir,
                interactive=is_interactive_mode,
            )
            self._tool_caller = DockerToolExecutor(
                original_executor=original_tool_executor,
                docker_manager=self.docker_manager,
                docker_tools=["bash", "str_replace_based_edit_tool", "json_edit_tool"],
                host_workspace_dir=docker_config.get("workspace_dir"),
                container_workspace_dir=self.docker_manager.container_workspace,
            )
        else:
            self._tool_caller = original_tool_executor
        self._cli_console: CLIConsole | None = None
        # Trajectory recorder
        self._trajectory_recorder: TrajectoryRecorder | None = None
        # CKG tool-specific: clear the older CKG databases
        clear_older_ckg()
    @abstractmethod
    def new_task(
        self,
        task: str,
        extra_args: dict[str, str] | None = None,
        tool_names: list[str] | None = None,
    ):
        """Create a new task."""
        pass

    async def execute_task(self) -> AgentExecution:
        """Execute a task using the agent."""
        import time

        if self.docker_manager:
            self.docker_manager.start()
        start_time = time.time()
        execution = AgentExecution(task=self._task, steps=[])
        step: AgentStep | None = None
        try:
            messages = self._initial_messages
            step_number = 1
            execution.agent_state = AgentState.RUNNING
            while step_number <= self._max_steps:
                step = AgentStep(step_number=step_number, state=AgentStepState.THINKING)
                try:
                    messages = await self._run_llm_step(step, messages, execution)
                    await self._finalize_step(
                        step, messages, execution
                    )  # record trajectory for this step and update the CLI console
                    if execution.agent_state == AgentState.COMPLETED:
                        break
                    step_number += 1
                except Exception as error:
                    execution.agent_state = AgentState.ERROR
                    step.state = AgentStepState.ERROR
                    step.error = str(error)
                    await self._finalize_step(step, messages, execution)
                    break
            if step_number > self._max_steps and not execution.success:
                execution.final_result = "Task execution exceeded maximum steps without completion."
                execution.agent_state = AgentState.ERROR
        except Exception as e:
            execution.final_result = f"Agent execution failed: {str(e)}"
        finally:
            if self.docker_manager and not self.docker_keep:
                self.docker_manager.stop()
            # Ensure tool resources are released whether an exception occurs or not.
            await self._close_tools()
            execution.execution_time = time.time() - start_time
            # Clean up any MCP clients
            with contextlib.suppress(Exception):
                await self.cleanup_mcp_clients()
        self._update_cli_console(step, execution)
        return execution
    async def _close_tools(self):
        """Release tool resources, mainly the BashTool objects."""
        if self._tool_caller:
            # Ensure all tool resources are properly released.
            res = await self._tool_caller.close_tools()
            return res
    async def _run_llm_step(
        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
    ) -> list["LLMMessage"]:
        # Display thinking state
        step.state = AgentStepState.THINKING
        self._update_cli_console(step, execution)
        # Get LLM response
        llm_response = self._llm_client.chat(messages, self._model_config, self._tools)
        step.llm_response = llm_response
        # Display step with LLM response
        self._update_cli_console(step, execution)
        # Update token usage
        self._update_llm_usage(llm_response, execution)
        if self.llm_indicates_task_completed(llm_response):
            if self._is_task_completed(llm_response):
                execution.agent_state = AgentState.COMPLETED
                execution.final_result = llm_response.content
                execution.success = True
                return messages
            else:
                execution.agent_state = AgentState.RUNNING
                return [LLMMessage(role="user", content=self.task_incomplete_message())]
        else:
            tool_calls = llm_response.tool_calls
            return await self._tool_call_handler(tool_calls, step)

    async def _finalize_step(
        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
    ) -> None:
        step.state = AgentStepState.COMPLETED
        self._record_handler(step, messages)
        self._update_cli_console(step, execution)
        execution.steps.append(step)
    def reflect_on_result(self, tool_results: list[ToolResult]) -> str | None:
        """Reflect on tool execution result. Override for custom reflection logic."""
        if len(tool_results) == 0:
            return None
        reflection = "\n".join(
            f"The tool execution failed with error: {tool_result.error}. Consider trying a different approach or fixing the parameters."
            for tool_result in tool_results
            if not tool_result.success
        )
        return reflection

    def llm_indicates_task_completed(self, llm_response: LLMResponse) -> bool:
        """Check if the LLM indicates that the task is completed. Override for custom logic."""
        completion_indicators = [
            "task completed",
            "task finished",
            "done",
            "completed successfully",
            "finished successfully",
        ]
        response_lower = llm_response.content.lower()
        return any(indicator in response_lower for indicator in completion_indicators)

    def _is_task_completed(self, llm_response: LLMResponse) -> bool:  # pyright: ignore[reportUnusedParameter]
        """Check if the task is completed based on the response. Override for custom logic."""
        return True

    def task_incomplete_message(self) -> str:
        """Return a message indicating that the task is incomplete. Override for custom logic."""
        return "The task is incomplete. Please try again."

    @abstractmethod
    async def cleanup_mcp_clients(self) -> None:
        """Clean up MCP clients. Override in subclasses that use MCP."""
        pass

    def _update_cli_console(
        self, step: AgentStep | None = None, agent_execution: AgentExecution | None = None
    ) -> None:
        if self.cli_console:
            self.cli_console.update_status(step, agent_execution)
    def _update_llm_usage(self, llm_response: LLMResponse, execution: AgentExecution):
        if not llm_response.usage:
            return
        # Initialize the running total from the first response's usage,
        # then accumulate usage from subsequent responses.
        if not execution.total_tokens:
            execution.total_tokens = llm_response.usage
        else:
            execution.total_tokens += llm_response.usage
    def _record_handler(self, step: AgentStep, messages: list[LLMMessage]) -> None:
        if self.trajectory_recorder:
            self.trajectory_recorder.record_agent_step(
                step_number=step.step_number,
                state=step.state.value,
                llm_messages=messages,
                llm_response=step.llm_response,
                tool_calls=step.tool_calls,
                tool_results=step.tool_results,
                reflection=step.reflection,
                error=step.error,
            )

    async def _tool_call_handler(
        self, tool_calls: list[ToolCall] | None, step: AgentStep
    ) -> list[LLMMessage]:
        messages: list[LLMMessage] = []
        if not tool_calls:
            messages = [
                LLMMessage(
                    role="user",
                    content="It seems that you have not completed the task.",
                )
            ]
            return messages
        step.state = AgentStepState.CALLING_TOOL
        step.tool_calls = tool_calls
        self._update_cli_console(step)
        if self._model_config.parallel_tool_calls:
            tool_results = await self._tool_caller.parallel_tool_call(tool_calls)
        else:
            tool_results = await self._tool_caller.sequential_tool_call(tool_calls)
        step.tool_results = tool_results
        self._update_cli_console(step)
        for tool_result in tool_results:
            # Add tool result to conversation
            message = LLMMessage(role="user", tool_result=tool_result)
            messages.append(message)
        reflection = self.reflect_on_result(tool_results)
        if reflection:
            step.state = AgentStepState.REFLECTING
            step.reflection = reflection
            # Display reflection
            self._update_cli_console(step)
            messages.append(LLMMessage(role="assistant", content=reflection))
        return messages
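
For orientation, here is a minimal sketch of how a concrete agent might build on this base class. It assumes the trae-agent names above (BaseAgent, LLMMessage) are importable; SimpleAgent and its behavior are illustrative only and not part of this commit.

# Hypothetical subclass for illustration; not part of this commit.
class SimpleAgent(BaseAgent):
    def new_task(
        self,
        task: str,
        extra_args: dict[str, str] | None = None,
        tool_names: list[str] | None = None,
    ):
        # Record the task and seed the conversation with it as the first user message.
        self._task = task
        self._initial_messages = [LLMMessage(role="user", content=task)]

    async def cleanup_mcp_clients(self) -> None:
        # This sketch uses no MCP clients, so there is nothing to release.
        return None

After new_task() seeds the conversation, execute_task() drives the think/act loop until the LLM signals completion or _max_steps is exhausted.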

demo.py (86 lines changed)

@@ -3,33 +3,15 @@ from loguru import logger
 if __name__ == "__main__":
-    context = """
-def add(a, b):
-    return a + b
+    with open("assets/example_context.py", "r") as f:
+        context = f.read()
-def quick_sort(arr):
-    if len(arr) <= 1:
-        return arr
-    pivot = arr[len(arr) // 2]
-    left = [x for x in arr if x < pivot]
-    middle = [x for x in arr if x == pivot]
-    right = [x for x in arr if x > pivot]
-    return quick_sort(left) + middle + quick_sort(right)
-def search_with_binary_search(arr, target):
-    left, right = 0, len(arr) - 1
-    while left <= right:
-        mid = (left + right) // 2
-        if arr[mid] == target:
-            return mid
-        elif arr[mid] < target:
-            left = mid + 1
-        else:
-            right = mid - 1
-    return -1
-"""
-    question = "How to write a quick sort algorithm?"
+    question = '''
+    async def _finalize_step(
+        self, step: "AgentStep", messages: list["LLMMessage"], execution: "AgentExecution"
+    ) -> None:
+        step.state = AgentStepState.COMPLETED
+    '''
     # Initialize compressor
     logger.info("Initializing compressor...")
@@ -44,48 +26,22 @@ if __name__ == "__main__":
     target_ratio = min(1.0, max(0.0, target_token / original_tokens))
     logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
-    result = compressor.compress_code_file(
-        code=context,
-        query=question,  # Using current function context as query focus
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=False,  # False to use fine-grained compression
-        fine_grained_importance_method="contrastive_perplexity",  # Explicitly test default
-        min_lines_for_fine_grained=5,  # Min number of lines for fine-grained compression
-        importance_beta=0.5,  # Sensitivity to importance score
-        use_knapsack=True,
-    )
-    # show the compressed code
-    logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
-    logger.info(f"Current function context: \n{question}")
-    # final prompt
-    final_prompt = result['compressed_prompt']
-    # get the completion
-    tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
-    # Increase max_new_tokens for potentially longer completions
-    completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
-    # Decode only the generated part, skipping special tokens
-    completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
-    # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
-    completion = completion.strip()
-    # More robust cleanup: find the first meaningful line if generation includes noise
-    completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))]  # Simple comment removal
-    cleaned_completion = completion_lines[0] if completion_lines else completion  # Take the first non-comment line, or the original if none found
-    logger.info(f"Cleaned Completion: {cleaned_completion}")
-    # Optional: Test with conditional_ppl method
-    logger.info("\nTesting fine-grained compression with conditional_ppl...")
+    logger.info("\nTesting compression with Coarse-grained compression only...")
     result_cond = compressor.compress_code_file(
         code=context,
         query=question,
         instruction="Complete the following code function given the context.",
         rate=target_ratio,
-        rank_only=False,
-        fine_grained_importance_method="conditional_ppl",
-        min_lines_for_fine_grained=5,
-        importance_beta=0.5
+        rank_only=True  # Coarse-grained compression
     )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
+    logger.info(f"Compressed prompt: \n{result_cond['compressed_prompt']}")
+    logger.info("\nTesting compression with Coarse-grained and Fine-grained compression...")
+    result_cond = compressor.compress_code_file(
+        code=context,
+        query=question,
+        instruction="Complete the following code function given the context.",
+        rate=target_ratio,
+        rank_only=False  # Coarse-grained and Fine-grained compression
+    )
+    logger.info(f"Compressed prompt: \n{result_cond['compressed_prompt']}")


@@ -1803,93 +1803,4 @@ class CodeCompressor:
                 selected.add(idx)
                 current_weight += weight
         return selected
-if __name__ == "__main__":
-    context = """
-def add(a, b):
-    return a + b
-def quick_sort(arr):
-    if len(arr) <= 1:
-        return arr
-    pivot = arr[len(arr) // 2]
-    left = [x for x in arr if x < pivot]
-    middle = [x for x in arr if x == pivot]
-    right = [x for x in arr if x > pivot]
-    return quick_sort(left) + middle + quick_sort(right)
-def search_with_binary_search(arr, target):
-    left, right = 0, len(arr) - 1
-    while left <= right:
-        mid = (left + right) // 2
-        if arr[mid] == target:
-            return mid
-        elif arr[mid] < target:
-            left = mid + 1
-        else:
-            right = mid - 1
-    return -1
-"""
-    question = "How to write a quick sort algorithm?"
-    # Initialize compressor
-    logger.info("Initializing compressor...")
-    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
-    compressor = CodeCompressor(model_name=model_name)
-    # Test function-based code file compression with query
-    logger.info("\nTesting function-based code file compression with query...")
-    original_tokens = len(compressor.tokenizer.encode(context))
-    target_token = 64
-    target_ratio = min(1.0, max(0.0, target_token / original_tokens))
-    logger.info(f"CodeCompressor: Original tokens={original_tokens}, Target tokens={target_token}, Calculated ratio={target_ratio:.4f}")
-    result = compressor.compress_code_file(
-        code=context,
-        query=question,  # Using current function context as query focus
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=True,  # Only use coarse-grained compression
-        fine_grained_importance_method="conditional_ppl",  # Explicitly test default
-        min_lines_for_fine_grained=5,  # Min number of lines for fine-grained compression
-        importance_beta=0.5,  # Sensitivity to importance score
-        use_knapsack=True,
-    )
-    # show the compressed code
-    logger.info(f"Compressed code (using {result['fine_grained_method_used']}): \n{result['compressed_code']}")
-    logger.info(f"Current function context: \n{question}")
-    # final prompt
-    final_prompt = result['compressed_prompt']
-    # get the completion
-    tokenized_prompt = compressor.tokenizer(final_prompt, return_tensors="pt").to(compressor.device)
-    # Increase max_new_tokens for potentially longer completions
-    completion_ids = compressor.model.generate(**tokenized_prompt, max_new_tokens=128, pad_token_id=compressor.tokenizer.eos_token_id)
-    # Decode only the generated part, skipping special tokens
-    completion = compressor.tokenizer.decode(completion_ids[0][len(tokenized_prompt.input_ids[0]):], skip_special_tokens=True)
-    # Basic cleanup: remove leading/trailing whitespace and potentially stop words if needed
-    completion = completion.strip()
-    # More robust cleanup: find the first meaningful line if generation includes noise
-    completion_lines = [line for line in completion.split("\n") if line.strip() and not line.strip().startswith(("#", "//"))]  # Simple comment removal
-    cleaned_completion = completion_lines[0] if completion_lines else completion  # Take the first non-comment line, or the original if none found
-    logger.info(f"Cleaned Completion: {cleaned_completion}")
-    # Optional: Test with conditional_ppl method
-    logger.info("\nTesting fine-grained compression with conditional_ppl...")
-    result_cond = compressor.compress_code_file(
-        code=context,
-        query=question,
-        instruction="Complete the following code function given the context.",
-        rate=target_ratio,
-        rank_only=False,
-        fine_grained_importance_method="conditional_ppl",
-        min_lines_for_fine_grained=5,
-        importance_beta=0.5
-    )
-    logger.info(f"Compressed code (using {result_cond['fine_grained_method_used']}): \n{result_cond['compressed_code']}")
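
The deleted __main__ block above used the standard Hugging Face generate-then-slice decoding pattern: tokenize the prompt, generate, then decode only the tokens past the prompt length. A minimal self-contained sketch of that pattern, assuming the transformers library is installed (the checkpoint name follows the one used in the deleted code):

# Sketch of the generate-then-slice decoding used in the removed demo code.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")

inputs = tokenizer("def quick_sort(arr):", return_tensors="pt")
output_ids = model.generate(
    **inputs, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id
)
# Slice off the prompt tokens so only the newly generated text is decoded.
completion = tokenizer.decode(
    output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
)
print(completion.strip())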