Fix deepseek-r1 model support for the Ollama provider
This commit is contained in:
@@ -242,17 +242,17 @@ class CustomAgent(Agent):
|
||||
logger.info(f"🧠 All Memory: \n{step_info.memory}")
|
||||
self._save_conversation(input_messages, model_output)
|
||||
if self.model_name != "deepseek-reasoner":
|
||||
# remove pre-prev message
|
||||
self.message_manager._remove_last_state_message()
|
||||
# remove prev message
|
||||
self.message_manager._remove_state_message_by_index(-1)
|
||||
except Exception as e:
|
||||
# model call failed, remove last state message from history
|
||||
self.message_manager._remove_last_state_message()
|
||||
self.message_manager._remove_state_message_by_index(-1)
|
||||
raise e
|
||||
|
||||
result: list[ActionResult] = await self.controller.multi_act(
|
||||
model_output.action, self.browser_context
|
||||
)
|
||||
actions: list[ActionModel] = model_output.action
|
||||
result: list[ActionResult] = await self.controller.multi_act(
|
||||
actions, self.browser_context
|
||||
)
|
||||
if len(result) != len(actions):
|
||||
# I think something changes, such information should let LLM know
|
||||
for ri in range(len(result), len(actions)):
|
||||
@@ -261,6 +261,9 @@ class CustomAgent(Agent):
|
||||
error=f"{actions[ri].model_dump_json(exclude_unset=True)} is Failed to execute. \
|
||||
Something new appeared after action {actions[len(result) - 1].model_dump_json(exclude_unset=True)}",
|
||||
is_done=False))
|
||||
if len(actions) == 0:
|
||||
# TODO: fix no action case
|
||||
result = [ActionResult(is_done=True, extracted_content=step_info.memory, include_in_memory=True)]
|
||||
self._last_result = result
|
||||
self._last_actions = actions
|
||||
if len(result) > 0 and result[-1].is_done:
|
||||
|
||||
@@ -70,18 +70,6 @@ class CustomMassageManager(MessageManager):
|
||||
while diff > 0 and len(self.history.messages) > min_message_len:
|
||||
self.history.remove_message(min_message_len) # alway remove the oldest message
|
||||
diff = self.history.total_tokens - self.max_input_tokens
|
||||
|
||||
def _remove_state_message_by_index(self, remove_ind=-1) -> None:
    """Remove the Nth most recent state (HumanMessage) entry from history.

    Scans the message history from newest to oldest, counting entries whose
    wrapped ``message`` is a ``HumanMessage``; when the count reaches
    ``abs(remove_ind)`` that entry is removed.

    Args:
        remove_ind: Negative index of the state message to drop, counted
            from the end (-1 = most recent HumanMessage, -2 = the one
            before it, ...). Only its magnitude is used.
    """
    i = 0
    remove_cnt = 0
    # Bound must be `i < len(...)`: `i` is incremented BEFORE indexing
    # `messages[-i]`, so with the original `i <= len(...)` the final
    # iteration indexed `messages[-(len+1)]` and raised IndexError
    # whenever fewer than abs(remove_ind) HumanMessages were present.
    # The `i < len` form also makes the empty-history guard redundant.
    while i < len(self.history.messages):
        i += 1
        if isinstance(self.history.messages[-i].message, HumanMessage):
            remove_cnt += 1
            # Only remove while positioned ON the matched HumanMessage;
            # checking outside this branch could delete a non-state
            # message that merely follows the match.
            if remove_cnt == abs(remove_ind):
                self.history.remove_message(-i)
                break
|
||||
|
||||
def add_state_message(
|
||||
self,
|
||||
@@ -115,3 +103,15 @@ class CustomMassageManager(MessageManager):
|
||||
len(text) // self.estimated_characters_per_token
|
||||
) # Rough estimate if no tokenizer available
|
||||
return tokens
|
||||
|
||||
def _remove_state_message_by_index(self, remove_ind=-1) -> None:
    """Remove last state message from history"""
    # Walk the history from the newest entry backwards; the
    # abs(remove_ind)-th HumanMessage encountered is the one to drop.
    target = abs(remove_ind)
    seen = 0
    for idx in range(len(self.history.messages) - 1, -1, -1):
        entry = self.history.messages[idx]
        if isinstance(entry.message, HumanMessage):
            seen += 1
            if seen == target:
                self.history.remove_message(idx)
                break
|
||||
@@ -183,7 +183,7 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
|
||||
|
||||
state_description = f"""
|
||||
{step_info_description}
|
||||
1. Task: {self.step_info.task}
|
||||
1. Task: {self.step_info.task}.
|
||||
2. Hints(Optional):
|
||||
{self.step_info.add_infos}
|
||||
3. Memory:
|
||||
|
||||
@@ -94,12 +94,11 @@ def get_llm_model(provider: str, **kwargs):
|
||||
else:
|
||||
base_url = kwargs.get("base_url")
|
||||
|
||||
if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
|
||||
if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
|
||||
return DeepSeekR1ChatOllama(
|
||||
model=kwargs.get("model_name", "deepseek-r1:7b"),
|
||||
model=kwargs.get("model_name", "deepseek-r1:14b"),
|
||||
temperature=kwargs.get("temperature", 0.0),
|
||||
num_ctx=kwargs.get("num_ctx", 32000),
|
||||
num_predict=kwargs.get("num_predict", 1024),
|
||||
base_url=kwargs.get("base_url", base_url),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -32,10 +32,14 @@ async def test_browser_use_org():
|
||||
# api_key=os.getenv("AZURE_OPENAI_API_KEY", ""),
|
||||
# )
|
||||
|
||||
# llm = utils.get_llm_model(
|
||||
# provider="deepseek",
|
||||
# model_name="deepseek-chat",
|
||||
# temperature=0.8
|
||||
# )
|
||||
|
||||
llm = utils.get_llm_model(
|
||||
provider="deepseek",
|
||||
model_name="deepseek-chat",
|
||||
temperature=0.8
|
||||
provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
|
||||
)
|
||||
|
||||
window_w, window_h = 1920, 1080
|
||||
@@ -152,9 +156,9 @@ async def test_browser_use_custom():
|
||||
controller = CustomController()
|
||||
use_own_browser = True
|
||||
disable_security = True
|
||||
use_vision = True # Set to False when using DeepSeek
|
||||
use_vision = False # Set to False when using DeepSeek
|
||||
|
||||
max_actions_per_step = 10
|
||||
max_actions_per_step = 1
|
||||
playwright = None
|
||||
browser = None
|
||||
browser_context = None
|
||||
@@ -189,7 +193,7 @@ async def test_browser_use_custom():
|
||||
)
|
||||
)
|
||||
agent = CustomAgent(
|
||||
task="go to google.com and type 'OpenAI' click search and give me the first url",
|
||||
task="Search 'Nvidia' and give me the first url",
|
||||
add_infos="", # some hints for llm to complete the task
|
||||
llm=llm,
|
||||
browser=browser,
|
||||
|
||||
Reference in New Issue
Block a user