refactor: remove code duplication in get_next_action method
This commit is contained in:
@@ -184,42 +184,39 @@ class CustomAgent(Agent):
|
|||||||
@time_execution_async("--get_next_action")
|
@time_execution_async("--get_next_action")
|
||||||
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
|
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
|
||||||
"""Get next action from LLM based on current state"""
|
"""Get next action from LLM based on current state"""
|
||||||
if self.use_deepseek_r1:
|
messages_to_process = (
|
||||||
merged_input_messages = self.message_manager.merge_successive_human_messages(input_messages)
|
self.message_manager.merge_successive_human_messages(input_messages)
|
||||||
ai_message = self.llm.invoke(merged_input_messages)
|
if self.use_deepseek_r1
|
||||||
self.message_manager._add_message_with_tokens(ai_message)
|
else input_messages
|
||||||
logger.info(f"🤯 Start Deep Thinking: ")
|
)
|
||||||
logger.info(ai_message.reasoning_content)
|
|
||||||
logger.info(f"🤯 End Deep Thinking")
|
|
||||||
if isinstance(ai_message.content, list):
|
|
||||||
ai_content = ai_message.content[0].replace("```json", "").replace("```", "")
|
|
||||||
else:
|
|
||||||
ai_content = ai_message.content.replace("```json", "").replace("```", "")
|
|
||||||
ai_content = repair_json(ai_content)
|
|
||||||
parsed_json = json.loads(ai_content)
|
|
||||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
|
||||||
if parsed is None:
|
|
||||||
logger.debug(ai_message.content)
|
|
||||||
raise ValueError(f'Could not parse response.')
|
|
||||||
else:
|
|
||||||
ai_message = self.llm.invoke(input_messages)
|
|
||||||
self.message_manager._add_message_with_tokens(ai_message)
|
|
||||||
if isinstance(ai_message.content, list):
|
|
||||||
ai_content = ai_message.content[0].replace("```json", "").replace("```", "")
|
|
||||||
else:
|
|
||||||
ai_content = ai_message.content.replace("```json", "").replace("```", "")
|
|
||||||
ai_content = repair_json(ai_content)
|
|
||||||
parsed_json = json.loads(ai_content)
|
|
||||||
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
|
||||||
if parsed is None:
|
|
||||||
logger.debug(ai_message.content)
|
|
||||||
raise ValueError(f'Could not parse response.')
|
|
||||||
|
|
||||||
# cut the number of actions to max_actions_per_step
|
ai_message = self.llm.invoke(messages_to_process)
|
||||||
|
self.message_manager._add_message_with_tokens(ai_message)
|
||||||
|
|
||||||
|
if self.use_deepseek_r1:
|
||||||
|
logger.info("🤯 Start Deep Thinking: ")
|
||||||
|
logger.info(ai_message.reasoning_content)
|
||||||
|
logger.info("🤯 End Deep Thinking")
|
||||||
|
|
||||||
|
if isinstance(ai_message.content, list):
|
||||||
|
ai_content = ai_message.content[0]
|
||||||
|
else:
|
||||||
|
ai_content = ai_message.content
|
||||||
|
|
||||||
|
ai_content = ai_content.replace("```json", "").replace("```", "")
|
||||||
|
ai_content = repair_json(ai_content)
|
||||||
|
parsed_json = json.loads(ai_content)
|
||||||
|
parsed: AgentOutput = self.AgentOutput(**parsed_json)
|
||||||
|
|
||||||
|
if parsed is None:
|
||||||
|
logger.debug(ai_message.content)
|
||||||
|
raise ValueError('Could not parse response.')
|
||||||
|
|
||||||
|
# Limit actions to maximum allowed per step
|
||||||
parsed.action = parsed.action[: self.max_actions_per_step]
|
parsed.action = parsed.action[: self.max_actions_per_step]
|
||||||
self._log_response(parsed)
|
self._log_response(parsed)
|
||||||
self.n_steps += 1
|
self.n_steps += 1
|
||||||
|
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
@time_execution_async("--step")
|
@time_execution_async("--step")
|
||||||
|
|||||||
Reference in New Issue
Block a user