refactor: remove code duplication in get_next_action method

This commit is contained in:
marginal23326
2025-01-29 00:41:52 +06:00
parent 0c9cb9ba11
commit 3fb8020387

View File

@@ -184,42 +184,39 @@ class CustomAgent(Agent):
@time_execution_async("--get_next_action")
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
"""Get next action from LLM based on current state"""
if self.use_deepseek_r1:
merged_input_messages = self.message_manager.merge_successive_human_messages(input_messages)
ai_message = self.llm.invoke(merged_input_messages)
self.message_manager._add_message_with_tokens(ai_message)
logger.info(f"🤯 Start Deep Thinking: ")
logger.info(ai_message.reasoning_content)
logger.info(f"🤯 End Deep Thinking")
if isinstance(ai_message.content, list):
ai_content = ai_message.content[0].replace("```json", "").replace("```", "")
else:
ai_content = ai_message.content.replace("```json", "").replace("```", "")
ai_content = repair_json(ai_content)
parsed_json = json.loads(ai_content)
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
logger.debug(ai_message.content)
raise ValueError(f'Could not parse response.')
else:
ai_message = self.llm.invoke(input_messages)
self.message_manager._add_message_with_tokens(ai_message)
if isinstance(ai_message.content, list):
ai_content = ai_message.content[0].replace("```json", "").replace("```", "")
else:
ai_content = ai_message.content.replace("```json", "").replace("```", "")
ai_content = repair_json(ai_content)
parsed_json = json.loads(ai_content)
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
logger.debug(ai_message.content)
raise ValueError(f'Could not parse response.')
messages_to_process = (
self.message_manager.merge_successive_human_messages(input_messages)
if self.use_deepseek_r1
else input_messages
)
# cut the number of actions to max_actions_per_step
ai_message = self.llm.invoke(messages_to_process)
self.message_manager._add_message_with_tokens(ai_message)
if self.use_deepseek_r1:
logger.info("🤯 Start Deep Thinking: ")
logger.info(ai_message.reasoning_content)
logger.info("🤯 End Deep Thinking")
if isinstance(ai_message.content, list):
ai_content = ai_message.content[0]
else:
ai_content = ai_message.content
ai_content = ai_content.replace("```json", "").replace("```", "")
ai_content = repair_json(ai_content)
parsed_json = json.loads(ai_content)
parsed: AgentOutput = self.AgentOutput(**parsed_json)
if parsed is None:
logger.debug(ai_message.content)
raise ValueError('Could not parse response.')
# Limit actions to maximum allowed per step
parsed.action = parsed.action[: self.max_actions_per_step]
self._log_response(parsed)
self.n_steps += 1
return parsed
@time_execution_async("--step")