refactor: remove code duplication in get_next_action method
This commit is contained in:
@@ -184,42 +184,39 @@ class CustomAgent(Agent):
|
||||
@time_execution_async("--get_next_action")
|
||||
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
|
||||
"""Get next action from LLM based on current state"""
|
||||
if self.use_deepseek_r1:
|
||||
merged_input_messages = self.message_manager.merge_successive_human_messages(input_messages)
|
||||
ai_message = self.llm.invoke(merged_input_messages)
|
||||
self.message_manager._add_message_with_tokens(ai_message)
|
||||
logger.info(f"🤯 Start Deep Thinking: ")
|
||||
logger.info(ai_message.reasoning_content)
|
||||
logger.info(f"🤯 End Deep Thinking")
|
||||
if isinstance(ai_message.content, list):
|
||||
ai_content = ai_message.content[0].replace("```json", "").replace("```", "")
|
||||
else:
|
||||
ai_content = ai_message.content.replace("```json", "").replace("```", "")
|
||||
ai_content = repair_json(ai_content)
|
||||
parsed_json = json.loads(ai_content)
|
||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
||||
if parsed is None:
|
||||
logger.debug(ai_message.content)
|
||||
raise ValueError(f'Could not parse response.')
|
||||
else:
|
||||
ai_message = self.llm.invoke(input_messages)
|
||||
self.message_manager._add_message_with_tokens(ai_message)
|
||||
if isinstance(ai_message.content, list):
|
||||
ai_content = ai_message.content[0].replace("```json", "").replace("```", "")
|
||||
else:
|
||||
ai_content = ai_message.content.replace("```json", "").replace("```", "")
|
||||
ai_content = repair_json(ai_content)
|
||||
parsed_json = json.loads(ai_content)
|
||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
||||
if parsed is None:
|
||||
logger.debug(ai_message.content)
|
||||
raise ValueError(f'Could not parse response.')
|
||||
messages_to_process = (
|
||||
self.message_manager.merge_successive_human_messages(input_messages)
|
||||
if self.use_deepseek_r1
|
||||
else input_messages
|
||||
)
|
||||
|
||||
# cut the number of actions to max_actions_per_step
|
||||
ai_message = self.llm.invoke(messages_to_process)
|
||||
self.message_manager._add_message_with_tokens(ai_message)
|
||||
|
||||
if self.use_deepseek_r1:
|
||||
logger.info("🤯 Start Deep Thinking: ")
|
||||
logger.info(ai_message.reasoning_content)
|
||||
logger.info("🤯 End Deep Thinking")
|
||||
|
||||
if isinstance(ai_message.content, list):
|
||||
ai_content = ai_message.content[0]
|
||||
else:
|
||||
ai_content = ai_message.content
|
||||
|
||||
ai_content = ai_content.replace("```json", "").replace("```", "")
|
||||
ai_content = repair_json(ai_content)
|
||||
parsed_json = json.loads(ai_content)
|
||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
||||
|
||||
if parsed is None:
|
||||
logger.debug(ai_message.content)
|
||||
raise ValueError('Could not parse response.')
|
||||
|
||||
# Limit actions to maximum allowed per step
|
||||
parsed.action = parsed.action[: self.max_actions_per_step]
|
||||
self._log_response(parsed)
|
||||
self.n_steps += 1
|
||||
|
||||
|
||||
return parsed
|
||||
|
||||
@time_execution_async("--step")
|
||||
|
||||
Reference in New Issue
Block a user