Fix deepseek-r1 model handling for the Ollama provider

This commit is contained in:
vincent
2025-01-28 20:38:29 +08:00
parent 7c8949600d
commit 75ab5051ec
5 changed files with 34 additions and 28 deletions

View File

@@ -242,17 +242,17 @@ class CustomAgent(Agent):
logger.info(f"🧠 All Memory: \n{step_info.memory}")
self._save_conversation(input_messages, model_output)
if self.model_name != "deepseek-reasoner":
# remove pre-prev message
self.message_manager._remove_last_state_message()
# remove prev message
self.message_manager._remove_state_message_by_index(-1)
except Exception as e:
# model call failed, remove last state message from history
self.message_manager._remove_last_state_message()
self.message_manager._remove_state_message_by_index(-1)
raise e
result: list[ActionResult] = await self.controller.multi_act(
model_output.action, self.browser_context
)
actions: list[ActionModel] = model_output.action
result: list[ActionResult] = await self.controller.multi_act(
actions, self.browser_context
)
if len(result) != len(actions):
# Something changed mid-execution (fewer results than actions); surface this to the LLM
for ri in range(len(result), len(actions)):
@@ -261,6 +261,9 @@ class CustomAgent(Agent):
error=f"{actions[ri].model_dump_json(exclude_unset=True)} is Failed to execute. \
Something new appeared after action {actions[len(result) - 1].model_dump_json(exclude_unset=True)}",
is_done=False))
if len(actions) == 0:
# TODO: fix no action case
result = [ActionResult(is_done=True, extracted_content=step_info.memory, include_in_memory=True)]
self._last_result = result
self._last_actions = actions
if len(result) > 0 and result[-1].is_done:

View File

@@ -70,18 +70,6 @@ class CustomMassageManager(MessageManager):
while diff > 0 and len(self.history.messages) > min_message_len:
self.history.remove_message(min_message_len) # always remove the oldest message
diff = self.history.total_tokens - self.max_input_tokens
def _remove_state_message_by_index(self, remove_ind=-1) -> None:
    """Remove the abs(remove_ind)-th most recent state (HumanMessage) from history.

    Walks the message history from newest to oldest, counting
    HumanMessage entries, and removes the one whose count matches
    abs(remove_ind). If there are fewer matching messages, nothing
    is removed.

    Bug fix: the previous loop condition ``i <= len(self.history.messages)``
    let ``i`` reach ``len + 1`` before indexing, so ``messages[-i]`` raised
    IndexError whenever fewer than abs(remove_ind) HumanMessages existed.
    The redundant ``len(...) and`` truthiness check is also dropped.

    Args:
        remove_ind: 1-based position from the end among HumanMessages;
            sign is ignored (abs is taken). Defaults to -1 (most recent).
    """
    i = 0
    remove_cnt = 0
    n = len(self.history.messages)
    while i < n:  # i ranges over 1..n after the increment below, always a valid negative index
        i += 1
        if isinstance(self.history.messages[-i].message, HumanMessage):
            remove_cnt += 1
            if remove_cnt == abs(remove_ind):
                self.history.remove_message(-i)
                break
def add_state_message(
self,
@@ -115,3 +103,15 @@ class CustomMassageManager(MessageManager):
len(text) // self.estimated_characters_per_token
) # Rough estimate if no tokenizer available
return tokens
def _remove_state_message_by_index(self, remove_ind=-1) -> None:
    """Remove the abs(remove_ind)-th most recent state (HumanMessage) from history.

    Scans the history from newest to oldest; when the abs(remove_ind)-th
    HumanMessage is found it is removed and the scan stops. If fewer
    matching messages exist, the history is left untouched.

    Args:
        remove_ind: 1-based position from the end among HumanMessages;
            sign is ignored (abs is taken). Defaults to -1 (most recent).
    """
    target = abs(remove_ind)
    seen = 0
    for idx in range(len(self.history.messages) - 1, -1, -1):
        if isinstance(self.history.messages[idx].message, HumanMessage):
            seen += 1
            if seen == target:
                self.history.remove_message(idx)
                break

View File

@@ -183,7 +183,7 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
state_description = f"""
{step_info_description}
1. Task: {self.step_info.task}
1. Task: {self.step_info.task}.
2. Hints(Optional):
{self.step_info.add_infos}
3. Memory:

View File

@@ -94,12 +94,11 @@ def get_llm_model(provider: str, **kwargs):
else:
base_url = kwargs.get("base_url")
if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
return DeepSeekR1ChatOllama(
model=kwargs.get("model_name", "deepseek-r1:7b"),
model=kwargs.get("model_name", "deepseek-r1:14b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
num_predict=kwargs.get("num_predict", 1024),
base_url=kwargs.get("base_url", base_url),
)
else:

View File

@@ -32,10 +32,14 @@ async def test_browser_use_org():
# api_key=os.getenv("AZURE_OPENAI_API_KEY", ""),
# )
# llm = utils.get_llm_model(
# provider="deepseek",
# model_name="deepseek-chat",
# temperature=0.8
# )
llm = utils.get_llm_model(
provider="deepseek",
model_name="deepseek-chat",
temperature=0.8
provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
)
window_w, window_h = 1920, 1080
@@ -152,9 +156,9 @@ async def test_browser_use_custom():
controller = CustomController()
use_own_browser = True
disable_security = True
use_vision = True # Set to False when using DeepSeek
use_vision = False # Set to False when using DeepSeek
max_actions_per_step = 10
max_actions_per_step = 1
playwright = None
browser = None
browser_context = None
@@ -189,7 +193,7 @@ async def test_browser_use_custom():
)
)
agent = CustomAgent(
task="go to google.com and type 'OpenAI' click search and give me the first url",
task="Search 'Nvidia' and give me the first url",
add_infos="", # some hints for llm to complete the task
llm=llm,
browser=browser,