Merge pull request #173 from vvincent1234/fix/adapt_latest_browser-use
Fix/adapt latest browser use
@@ -1,6 +1,3 @@
-browser-use==0.1.19
-langchain-google-genai==2.0.8
+browser-use==0.1.29
 pyperclip==1.9.0
-gradio==5.9.1
-langchain-ollama==0.2.2
-langchain-openai==0.2.14
+gradio==5.10.0
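After `pip install -r requirements.txt`, a quick sanity check that the new pins actually took effect (a sketch, not part of the diff; importlib.metadata ships with the standard library):

from importlib.metadata import version

for pkg in ("browser-use", "gradio", "pyperclip"):
    print(pkg, version(pkg))  # expect 0.1.29, 5.10.0, 1.9.0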
@@ -2,12 +2,12 @@ import json
 import logging
 import pdb
 import traceback
-from typing import Optional, Type
+from typing import Optional, Type, List, Dict, Any, Callable
 from PIL import Image, ImageDraw, ImageFont
 import os
 import base64
 import io
-
+import platform
 from browser_use.agent.prompts import SystemPrompt
 from browser_use.agent.service import Agent
 from browser_use.agent.views import (
@@ -21,9 +21,9 @@ from browser_use.browser.context import BrowserContext
 from browser_use.browser.views import BrowserStateHistory
 from browser_use.controller.service import Controller
 from browser_use.telemetry.views import (
-    AgentEndTelemetryEvent,
-    AgentRunTelemetryEvent,
-    AgentStepErrorTelemetryEvent,
+    AgentEndTelemetryEvent,
+    AgentRunTelemetryEvent,
+    AgentStepTelemetryEvent,
 )
 from browser_use.utils import time_execution_async
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -70,6 +70,11 @@ class CustomAgent(Agent):
             max_actions_per_step: int = 10,
             tool_call_in_content: bool = True,
             agent_state: AgentState = None,
+            initial_actions: Optional[List[Dict[str, Dict[str, Any]]]] = None,
+            # Cloud Callbacks
+            register_new_step_callback: Callable[['BrowserState', 'AgentOutput', int], None] | None = None,
+            register_done_callback: Callable[['AgentHistoryList'], None] | None = None,
+            tool_calling_method: Optional[str] = 'auto',
     ):
         super().__init__(
             task=task,
@@ -88,15 +93,22 @@ class CustomAgent(Agent):
             max_error_length=max_error_length,
             max_actions_per_step=max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
+            initial_actions=initial_actions,
+            register_new_step_callback=register_new_step_callback,
+            register_done_callback=register_done_callback,
+            tool_calling_method=tool_calling_method
         )
-        if hasattr(self.llm, 'model_name') and self.llm.model_name in ["deepseek-reasoner"]:
+        if self.model_name in ["deepseek-reasoner"] or self.model_name.startswith("deepseek-r1"):
             # deepseek-reasoner does not support function calling
-            self.use_function_calling = False
-            # TODO: deepseek-reasoner only support 64000 context
+            self.use_deepseek_r1 = True
+            # deepseek-reasoner only support 64000 context
             self.max_input_tokens = 64000
         else:
-            self.use_function_calling = True
+            self.use_deepseek_r1 = False

         # custom new info
         self.add_infos = add_infos
         # agent_state for Stop
         self.agent_state = agent_state
         self.message_manager = CustomMassageManager(
             llm=self.llm,
@@ -107,8 +119,7 @@ class CustomAgent(Agent):
             include_attributes=self.include_attributes,
             max_error_length=self.max_error_length,
             max_actions_per_step=self.max_actions_per_step,
-            tool_call_in_content=tool_call_in_content,
-            use_function_calling=self.use_function_calling
+            use_deepseek_r1=self.use_deepseek_r1
         )

     def _setup_action_models(self) -> None:
@@ -167,57 +178,39 @@ class CustomAgent(Agent):
     @time_execution_async("--get_next_action")
     async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
         """Get next action from LLM based on current state"""
-        if self.use_function_calling:
-            try:
-                structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
-                response: dict[str, Any] = await structured_llm.ainvoke(input_messages)  # type: ignore
-
-                parsed: AgentOutput = response['parsed']
-                # cut the number of actions to max_actions_per_step
-                parsed.action = parsed.action[: self.max_actions_per_step]
-                self._log_response(parsed)
-                self.n_steps += 1
-
-                return parsed
-            except Exception as e:
-                # If something goes wrong, try to invoke the LLM again without structured output,
-                # and Manually parse the response. Temporarily solution for DeepSeek
-                ret = self.llm.invoke(input_messages)
-                if isinstance(ret.content, list):
-                    parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
-                else:
-                    parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
-                parsed: AgentOutput = self.AgentOutput(**parsed_json)
-                if parsed is None:
-                    raise ValueError(f'Could not parse response.')
-
-                # cut the number of actions to max_actions_per_step
-                parsed.action = parsed.action[: self.max_actions_per_step]
-                self._log_response(parsed)
-                self.n_steps += 1
-
-                return parsed
-        else:
-            ret = self.llm.invoke(input_messages)
-            if not self.use_function_calling:
-                self.message_manager._add_message_with_tokens(ret)
-            logger.info(f"🤯 Start Deep Thinking: ")
-            logger.info(ret.reasoning_content)
-            logger.info(f"🤯 End Deep Thinking")
-            if isinstance(ret.content, list):
-                parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
-            else:
-                parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
-            parsed: AgentOutput = self.AgentOutput(**parsed_json)
-            if parsed is None:
-                raise ValueError(f'Could not parse response.')
-
-            # cut the number of actions to max_actions_per_step
-            parsed.action = parsed.action[: self.max_actions_per_step]
-            self._log_response(parsed)
-            self.n_steps += 1
-
-            return parsed
+        if self.use_deepseek_r1:
+            merged_input_messages = self.message_manager.merge_successive_human_messages(input_messages)
+            ai_message = self.llm.invoke(merged_input_messages)
+            self.message_manager._add_message_with_tokens(ai_message)
+            logger.info(f"🤯 Start Deep Thinking: ")
+            logger.info(ai_message.reasoning_content)
+            logger.info(f"🤯 End Deep Thinking")
+            if isinstance(ai_message.content, list):
+                parsed_json = json.loads(ai_message.content[0].replace("```json", "").replace("```", ""))
+            else:
+                parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
+            parsed: AgentOutput = self.AgentOutput(**parsed_json)
+            if parsed is None:
+                logger.debug(ai_message.content)
+                raise ValueError(f'Could not parse response.')
+        else:
+            ai_message = self.llm.invoke(input_messages)
+            self.message_manager._add_message_with_tokens(ai_message)
+            if isinstance(ai_message.content, list):
+                parsed_json = json.loads(ai_message.content[0].replace("```json", "").replace("```", ""))
+            else:
+                parsed_json = json.loads(ai_message.content.replace("```json", "").replace("```", ""))
+            parsed: AgentOutput = self.AgentOutput(**parsed_json)
+            if parsed is None:
+                logger.debug(ai_message.content)
+                raise ValueError(f'Could not parse response.')
+
+        # cut the number of actions to max_actions_per_step
+        parsed.action = parsed.action[: self.max_actions_per_step]
+        self._log_response(parsed)
+        self.n_steps += 1
+
+        return parsed

     @time_execution_async("--step")
     async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
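Both branches above recover the agent's JSON by stripping markdown code fences with str.replace before json.loads. A slightly more defensive version of that parse, as an illustrative sketch (not the PR's code):

import json

def parse_agent_json(raw: str) -> dict:
    # Strip an optional ```json ... ``` fence, then parse.
    text = raw.strip()
    if text.startswith("```"):
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0]
    return json.loads(text)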
@@ -231,13 +224,20 @@ class CustomAgent(Agent):
             state = await self.browser_context.get_state(use_vision=self.use_vision)
             self.message_manager.add_state_message(state, self._last_result, step_info)
             input_messages = self.message_manager.get_messages()
-            model_output = await self.get_next_action(input_messages)
-            self.update_step_info(model_output, step_info)
-            logger.info(f"🧠 All Memory: \n{step_info.memory}")
-            self._save_conversation(input_messages, model_output)
-            if self.use_function_calling:
-                self.message_manager._remove_last_state_message()  # we dont want the whole state in the chat history
-                self.message_manager.add_model_output(model_output)
+            try:
+                model_output = await self.get_next_action(input_messages)
+                if self.register_new_step_callback:
+                    self.register_new_step_callback(state, model_output, self.n_steps)
+                self.update_step_info(model_output, step_info)
+                logger.info(f"🧠 All Memory: \n{step_info.memory}")
+                self._save_conversation(input_messages, model_output)
+                # should we remove last state message? at least, deepseek-reasoner cannot remove
+                if self.model_name != "deepseek-reasoner":
+                    self.message_manager._remove_last_state_message()
+            except Exception as e:
+                # model call failed, remove last state message from history
+                self.message_manager._remove_last_state_message()
+                raise e

             result: list[ActionResult] = await self.controller.multi_act(
                 model_output.action, self.browser_context
@@ -258,34 +258,172 @@ class CustomAgent(Agent):
                 self.consecutive_failures = 0

         except Exception as e:
-            result = self._handle_step_error(e)
+            result = await self._handle_step_error(e)
             self._last_result = result

         finally:
+            actions = [a.model_dump(exclude_unset=True) for a in model_output.action] if model_output else []
+            self.telemetry.capture(
+                AgentStepTelemetryEvent(
+                    agent_id=self.agent_id,
+                    step=self.n_steps,
+                    actions=actions,
+                    consecutive_failures=self.consecutive_failures,
+                    step_error=[r.error for r in result if r.error] if result else ['No result'],
+                )
+            )
             if not result:
                 return
-            for r in result:
-                if r.error:
-                    self.telemetry.capture(
-                        AgentStepErrorTelemetryEvent(
-                            agent_id=self.agent_id,
-                            error=r.error,
-                        )
-                    )

             if state:
                 self._make_history_item(model_output, state, result)

+    async def run(self, max_steps: int = 100) -> AgentHistoryList:
+        """Execute the task with maximum number of steps"""
+        try:
+            self._log_agent_run()
+
+            # Execute initial actions if provided
+            if self.initial_actions:
+                result = await self.controller.multi_act(self.initial_actions, self.browser_context, check_for_new_elements=False)
+                self._last_result = result
+
+            step_info = CustomAgentStepInfo(
+                task=self.task,
+                add_infos=self.add_infos,
+                step_number=1,
+                max_steps=max_steps,
+                memory="",
+                task_progress="",
+                future_plans=""
+            )
+
+            for step in range(max_steps):
+                # 1) Check if stop requested
+                if self.agent_state and self.agent_state.is_stop_requested():
+                    logger.info("🛑 Stop requested by user")
+                    self._create_stop_history_item()
+                    break
+
+                # 2) Store last valid state before step
+                if self.browser_context and self.agent_state:
+                    state = await self.browser_context.get_state(use_vision=self.use_vision)
+                    self.agent_state.set_last_valid_state(state)
+
+                if self._too_many_failures():
+                    break
+
+                # 3) Do the step
+                await self.step(step_info)
+
+                if self.history.is_done():
+                    if (
+                        self.validate_output and step < max_steps - 1
+                    ):  # if last step, we dont need to validate
+                        if not await self._validate_output():
+                            continue
+
+                    logger.info("✅ Task completed successfully")
+                    break
+            else:
+                logger.info("❌ Failed to complete task in maximum steps")
+
+            return self.history
+
+        finally:
+            self.telemetry.capture(
+                AgentEndTelemetryEvent(
+                    agent_id=self.agent_id,
+                    success=self.history.is_done(),
+                    steps=self.n_steps,
+                    max_steps_reached=self.n_steps >= max_steps,
+                    errors=self.history.errors(),
+                )
+            )
+
+            if not self.injected_browser_context:
+                await self.browser_context.close()
+
+            if not self.injected_browser and self.browser:
+                await self.browser.close()
+
+            if self.generate_gif:
+                output_path: str = 'agent_history.gif'
+                if isinstance(self.generate_gif, str):
+                    output_path = self.generate_gif
+
+                self.create_history_gif(output_path=output_path)
+
+    def _create_stop_history_item(self):
+        """Create a history item for when the agent is stopped."""
+        try:
+            # Attempt to retrieve the last valid state from agent_state
+            state = None
+            if self.agent_state:
+                last_state = self.agent_state.get_last_valid_state()
+                if last_state:
+                    # Convert to BrowserStateHistory
+                    state = BrowserStateHistory(
+                        url=getattr(last_state, 'url', ""),
+                        title=getattr(last_state, 'title', ""),
+                        tabs=getattr(last_state, 'tabs', []),
+                        interacted_element=[None],
+                        screenshot=getattr(last_state, 'screenshot', None)
+                    )
+                else:
+                    state = self._create_empty_state()
+            else:
+                state = self._create_empty_state()
+
+            # Create a final item in the agent history indicating done
+            stop_history = AgentHistory(
+                model_output=None,
+                state=state,
+                result=[ActionResult(extracted_content=None, error=None, is_done=True)]
+            )
+            self.history.history.append(stop_history)
+
+        except Exception as e:
+            logger.error(f"Error creating stop history item: {e}")
+            # Create empty state as fallback
+            state = self._create_empty_state()
+            stop_history = AgentHistory(
+                model_output=None,
+                state=state,
+                result=[ActionResult(extracted_content=None, error=None, is_done=True)]
+            )
+            self.history.history.append(stop_history)
+
+    def _convert_to_browser_state_history(self, browser_state):
+        return BrowserStateHistory(
+            url=getattr(browser_state, 'url', ""),
+            title=getattr(browser_state, 'title', ""),
+            tabs=getattr(browser_state, 'tabs', []),
+            interacted_element=[None],
+            screenshot=getattr(browser_state, 'screenshot', None)
+        )
+
+    def _create_empty_state(self):
+        return BrowserStateHistory(
+            url="",
+            title="",
+            tabs=[],
+            interacted_element=[None],
+            screenshot=None
+        )
+
     def create_history_gif(
         self,
         output_path: str = 'agent_history.gif',
         duration: int = 3000,
         show_goals: bool = True,
         show_task: bool = True,
         show_logo: bool = False,
         font_size: int = 40,
         title_font_size: int = 56,
         goal_font_size: int = 44,
         margin: int = 40,
         line_spacing: float = 1.5,
     ) -> None:
         """Create a GIF from the agent's history with overlaid task and goal text."""
         if not self.history.history:
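run() now polls agent_state for a user stop request and stashes the last browser state before each step. The AgentState class itself is not in this diff (it lives in src/utils/agent_state.py); a minimal sketch of the interface run() relies on, assuming a simple thread-safe flag:

import threading

class AgentState:
    # Sketch only: the real implementation is in src/utils/agent_state.py.
    def __init__(self):
        self._stop_requested = threading.Event()
        self._last_valid_state = None

    def request_stop(self):
        self._stop_requested.set()

    def clear_stop(self):
        self._stop_requested.clear()

    def is_stop_requested(self) -> bool:
        return self._stop_requested.is_set()

    def set_last_valid_state(self, state):
        self._last_valid_state = state

    def get_last_valid_state(self):
        return self._last_valid_state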
@@ -306,10 +444,9 @@ class CustomAgent(Agent):

         for font_name in font_options:
             try:
-                import platform
-                if platform.system() == "Windows":
+                if platform.system() == 'Windows':
                     # Need to specify the abs font path on Windows
-                    font_name = os.path.join(os.getenv("WIN_FONT_DIR", "C:\\Windows\\Fonts"), font_name + ".ttf")
+                    font_name = os.path.join(os.getenv('WIN_FONT_DIR', 'C:\\Windows\\Fonts'), font_name + '.ttf')
                 regular_font = ImageFont.truetype(font_name, font_size)
                 title_font = ImageFont.truetype(font_name, title_font_size)
                 goal_font = ImageFont.truetype(font_name, goal_font_size)
@@ -386,134 +523,4 @@ class CustomAgent(Agent):
             )
             logger.info(f'Created GIF at {output_path}')
         else:
             logger.warning('No images found in history to create GIF')
-
-    async def run(self, max_steps: int = 100) -> AgentHistoryList:
-        """Execute the task with maximum number of steps"""
-        try:
-            logger.info(f"🚀 Starting task: {self.task}")
-
-            self.telemetry.capture(
-                AgentRunTelemetryEvent(
-                    agent_id=self.agent_id,
-                    task=self.task,
-                )
-            )
-
-            step_info = CustomAgentStepInfo(
-                task=self.task,
-                add_infos=self.add_infos,
-                step_number=1,
-                max_steps=max_steps,
-                memory="",
-                task_progress="",
-                future_plans=""
-            )
-
-            for step in range(max_steps):
-                # 1) Check if stop requested
-                if self.agent_state and self.agent_state.is_stop_requested():
-                    logger.info("🛑 Stop requested by user")
-                    self._create_stop_history_item()
-                    break
-
-                # 2) Store last valid state before step
-                if self.browser_context and self.agent_state:
-                    state = await self.browser_context.get_state(use_vision=self.use_vision)
-                    self.agent_state.set_last_valid_state(state)
-
-                if self._too_many_failures():
-                    break
-
-                # 3) Do the step
-                await self.step(step_info)
-
-                if self.history.is_done():
-                    if (
-                        self.validate_output and step < max_steps - 1
-                    ):  # if last step, we dont need to validate
-                        if not await self._validate_output():
-                            continue
-
-                    logger.info("✅ Task completed successfully")
-                    break
-            else:
-                logger.info("❌ Failed to complete task in maximum steps")
-
-            return self.history
-
-        finally:
-            self.telemetry.capture(
-                AgentEndTelemetryEvent(
-                    agent_id=self.agent_id,
-                    task=self.task,
-                    success=self.history.is_done(),
-                    steps=len(self.history.history),
-                )
-            )
-            if not self.injected_browser_context:
-                await self.browser_context.close()
-
-            if not self.injected_browser and self.browser:
-                await self.browser.close()
-
-            if self.generate_gif:
-                self.create_history_gif()
-
-    def _create_stop_history_item(self):
-        """Create a history item for when the agent is stopped."""
-        try:
-            # Attempt to retrieve the last valid state from agent_state
-            state = None
-            if self.agent_state:
-                last_state = self.agent_state.get_last_valid_state()
-                if last_state:
-                    # Convert to BrowserStateHistory
-                    state = BrowserStateHistory(
-                        url=getattr(last_state, 'url', ""),
-                        title=getattr(last_state, 'title', ""),
-                        tabs=getattr(last_state, 'tabs', []),
-                        interacted_element=[None],
-                        screenshot=getattr(last_state, 'screenshot', None)
-                    )
-                else:
-                    state = self._create_empty_state()
-            else:
-                state = self._create_empty_state()
-
-            # Create a final item in the agent history indicating done
-            stop_history = AgentHistory(
-                model_output=None,
-                state=state,
-                result=[ActionResult(extracted_content=None, error=None, is_done=True)]
-            )
-            self.history.history.append(stop_history)
-
-        except Exception as e:
-            logger.error(f"Error creating stop history item: {e}")
-            # Create empty state as fallback
-            state = self._create_empty_state()
-            stop_history = AgentHistory(
-                model_output=None,
-                state=state,
-                result=[ActionResult(extracted_content=None, error=None, is_done=True)]
-            )
-            self.history.history.append(stop_history)
-
-    def _convert_to_browser_state_history(self, browser_state):
-        return BrowserStateHistory(
-            url=getattr(browser_state, 'url', ""),
-            title=getattr(browser_state, 'title', ""),
-            tabs=getattr(browser_state, 'tabs', []),
-            interacted_element=[None],
-            screenshot=getattr(browser_state, 'screenshot', None)
-        )
-
-    def _create_empty_state(self):
-        return BrowserStateHistory(
-            url="",
-            title="",
-            tabs=[],
-            interacted_element=[None],
-            screenshot=None
-        )
@@ -15,6 +15,7 @@ from langchain_core.messages import (
     AIMessage,
     BaseMessage,
     HumanMessage,
     ToolMessage
 )
 from langchain_openai import ChatOpenAI
+from ..utils.llm import DeepSeekR1ChatOpenAI
@@ -31,13 +32,13 @@ class CustomMassageManager(MessageManager):
             action_descriptions: str,
             system_prompt_class: Type[SystemPrompt],
             max_input_tokens: int = 128000,
-            estimated_tokens_per_character: int = 3,
+            estimated_characters_per_token: int = 3,
             image_tokens: int = 800,
             include_attributes: list[str] = [],
             max_error_length: int = 400,
             max_actions_per_step: int = 10,
             tool_call_in_content: bool = False,
-            use_function_calling: bool = True
+            message_context: Optional[str] = None,
+            use_deepseek_r1: bool = False
     ):
         super().__init__(
             llm=llm,
@@ -45,55 +46,30 @@ class CustomMassageManager(MessageManager):
             action_descriptions=action_descriptions,
             system_prompt_class=system_prompt_class,
             max_input_tokens=max_input_tokens,
-            estimated_tokens_per_character=estimated_tokens_per_character,
+            estimated_characters_per_token=estimated_characters_per_token,
             image_tokens=image_tokens,
             include_attributes=include_attributes,
             max_error_length=max_error_length,
             max_actions_per_step=max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
+            message_context=message_context
         )
-        self.use_function_calling = use_function_calling
         self.tool_id = 1
+        self.use_deepseek_r1 = use_deepseek_r1
         # Custom: Move Task info to state_message
         self.history = MessageHistory()
         self._add_message_with_tokens(self.system_prompt)
-
-        if self.use_function_calling:
-            tool_calls = [
-                {
-                    'name': 'CustomAgentOutput',
-                    'args': {
-                        'current_state': {
-                            'prev_action_evaluation': 'Unknown - No previous actions to evaluate.',
-                            'important_contents': '',
-                            'completed_contents': '',
-                            'thought': 'Now Google is open. Need to type OpenAI to search.',
-                            'summary': 'Type OpenAI to search.',
-                        },
-                        'action': [],
-                    },
-                    'id': '',
-                    'type': 'tool_call',
-                }
-            ]
-            if self.tool_call_in_content:
-                # openai throws error if tool_calls are not responded -> move to content
-                example_tool_call = AIMessage(
-                    content=f'{tool_calls}',
-                    tool_calls=[],
-                )
-            else:
-                example_tool_call = AIMessage(
-                    content=f'',
-                    tool_calls=tool_calls,
-                )
-
-            self._add_message_with_tokens(example_tool_call)
+        if self.message_context:
+            context_message = HumanMessage(content=self.message_context)
+            self._add_message_with_tokens(context_message)

     def cut_messages(self):
         """Get current message list, potentially trimmed to max tokens"""
         diff = self.history.total_tokens - self.max_input_tokens
-        while diff > 0 and len(self.history.messages) > 1:
-            self.history.remove_message(1)  # alway remove the oldest one
+        min_message_len = 2 if self.message_context is not None else 1
+
+        while diff > 0 and len(self.history.messages) > min_message_len:
+            self.history.remove_message(min_message_len)  # alway remove the oldest message
             diff = self.history.total_tokens - self.max_input_tokens

     def add_state_message(
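cut_messages() now protects the optional context message at index 1 when trimming. Related: get_next_action() above calls self.message_manager.merge_successive_human_messages(), which is not shown in this diff; deepseek-reasoner rejects two human turns in a row, so a plausible sketch of that helper (hypothetical, assuming plain string content) is:

from langchain_core.messages import BaseMessage, HumanMessage

def merge_successive_human_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
    merged: list[BaseMessage] = []
    for msg in messages:
        if merged and isinstance(msg, HumanMessage) and isinstance(merged[-1], HumanMessage):
            # Fold consecutive human turns into one message.
            merged[-1] = HumanMessage(content=f"{merged[-1].content}\n{msg.content}")
        else:
            merged.append(msg)
    return merged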
@@ -119,10 +95,10 @@ class CustomMassageManager(MessageManager):
                 tokens = self.llm.get_num_tokens(text)
             except Exception:
                 tokens = (
-                    len(text) // self.ESTIMATED_TOKENS_PER_CHARACTER
+                    len(text) // self.estimated_characters_per_token
                 )  # Rough estimate if no tokenizer available
         else:
             tokens = (
-                len(text) // self.ESTIMATED_TOKENS_PER_CHARACTER
+                len(text) // self.estimated_characters_per_token
             )  # Rough estimate if no tokenizer available
         return tokens

@@ -1,7 +1,7 @@
 import pdb
 from typing import List, Optional

-from browser_use.agent.prompts import SystemPrompt
+from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt
 from browser_use.agent.views import ActionResult
 from browser_use.browser.views import BrowserState
 from langchain_core.messages import HumanMessage, SystemMessage
@@ -19,19 +19,14 @@ class CustomSystemPrompt(SystemPrompt):
     {
         "current_state": {
             "prev_action_evaluation": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Ignore the action result. The website is the ground truth. Also mention if something unexpected happened like new suggestions in an input field. Shortly state why/why not. Note that the result you output must be consistent with the reasoning you output afterwards. If you consider it to be 'Failed,' you should reflect on this during your thought.",
-            "important_contents": "Output important contents closely related to user\'s instruction or task on the current page. If there is, please output the contents. If not, please output empty string ''.",
+            "important_contents": "Output important contents closely related to user\'s instruction on the current page. If there is, please output the contents. If not, please output empty string ''.",
             "task_progress": "Task Progress is a general summary of the current contents that have been completed. Just summarize the contents that have been actually completed based on the content at current step and the history operations. Please list each completed item individually, such as: 1. Input username. 2. Input Password. 3. Click confirm button. Please return string type not a list.",
             "future_plans": "Based on the user's request and the current state, outline the remaining steps needed to complete the task. This should be a concise list of actions yet to be performed, such as: 1. Select a date. 2. Choose a specific time slot. 3. Confirm booking. Please return string type not a list.",
             "thought": "Think about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation. If your output of prev_action_evaluation is 'Failed', please reflect and output your reflection here.",
             "summary": "Please generate a brief natural language description for the operation in next actions based on your Thought."
         },
         "action": [
-            {
-                "action_name": {
-                    // action-specific parameters
-                }
-            },
-            // ... more actions in sequence
+            * actions in sequences, please refer to **Common action sequences**. Each output action MUST be formated as: \{action_name\: action_params\}*
         ]
     }

@@ -44,7 +39,6 @@ class CustomSystemPrompt(SystemPrompt):
          {"click_element": {"index": 3}}
        ]
    - Navigation and extraction: [
-       {"open_new_tab": {}},
        {"go_to_url": {"url": "https://example.com"}},
        {"extract_page_content": {}}
      ]
@@ -127,7 +121,7 @@ class CustomSystemPrompt(SystemPrompt):
         AGENT_PROMPT = f"""You are a precise browser automation agent that interacts with websites through structured commands. Your role is to:
     1. Analyze the provided webpage elements and structure
     2. Plan a sequence of actions to accomplish the given task
-    3. Respond with valid JSON containing your action sequence and state assessment
+    3. Your final result MUST be a valid JSON as the **RESPONSE FORMAT** described, containing your action sequence and state assessment, No need extra content to expalin.

     Current date and time: {time_str}

@@ -142,7 +136,7 @@ class CustomSystemPrompt(SystemPrompt):
         return SystemMessage(content=AGENT_PROMPT)


-class CustomAgentMessagePrompt:
+class CustomAgentMessagePrompt(AgentMessagePrompt):
     def __init__(
             self,
             state: BrowserState,
@@ -151,11 +145,12 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
             max_error_length: int = 400,
             step_info: Optional[CustomAgentStepInfo] = None,
     ):
-        self.state = state
-        self.result = result
-        self.max_error_length = max_error_length
-        self.include_attributes = include_attributes
-        self.step_info = step_info
+        super(CustomAgentMessagePrompt, self).__init__(state=state,
+                                                       result=result,
+                                                       include_attributes=include_attributes,
+                                                       max_error_length=max_error_length,
+                                                       step_info=step_info
+                                                       )

     def get_user_message(self) -> HumanMessage:
         if self.step_info:
@@ -164,8 +159,26 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
             step_info_description = ''

         elements_text = self.state.element_tree.clickable_elements_to_string(include_attributes=self.include_attributes)
-        if not elements_text:
+
+        has_content_above = (self.state.pixels_above or 0) > 0
+        has_content_below = (self.state.pixels_below or 0) > 0
+
+        if elements_text != '':
+            if has_content_above:
+                elements_text = (
+                    f'... {self.state.pixels_above} pixels above - scroll or extract content to see more ...\n{elements_text}'
+                )
+            else:
+                elements_text = f'[Start of page]\n{elements_text}'
+            if has_content_below:
+                elements_text = (
+                    f'{elements_text}\n... {self.state.pixels_below} pixels below - scroll or extract content to see more ...'
+                )
+            else:
+                elements_text = f'{elements_text}\n[End of page]'
+        else:
             elements_text = 'empty page'

         state_description = f"""
 {step_info_description}
 1. Task: {self.step_info.task}
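The new pixels_above/pixels_below markers tell the LLM whether the visible element list is a window into a longer page. A toy run of that logic (illustrative, using a stand-in state object):

from types import SimpleNamespace

state = SimpleNamespace(pixels_above=320, pixels_below=0)  # hypothetical values
elements_text = "[33]<button>Search</button>"

if (state.pixels_above or 0) > 0:
    elements_text = f"... {state.pixels_above} pixels above - scroll or extract content to see more ...\n{elements_text}"
else:
    elements_text = f"[Start of page]\n{elements_text}"
if (state.pixels_below or 0) > 0:
    elements_text = f"{elements_text}\n... {state.pixels_below} pixels below - scroll or extract content to see more ..."
else:
    elements_text = f"{elements_text}\n[End of page]"
print(elements_text)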
@@ -181,15 +194,17 @@ class CustomAgentMessagePrompt(AgentMessagePrompt):
 """

         if self.result:
             for i, result in enumerate(self.result):
-                if result.extracted_content:
-                    state_description += f"\nResult of action {i + 1}/{len(self.result)}: {result.extracted_content}"
-                if result.error:
-                    # only use last 300 characters of error
-                    error = result.error[-self.max_error_length:]
-                    state_description += (
-                        f"\nError of action {i + 1}/{len(self.result)}: ...{error}"
-                    )
+                if result.include_in_memory:
+                    if result.extracted_content:
+                        state_description += f"\nResult of previous action {i + 1}/{len(self.result)}: {result.extracted_content}"
+                    if result.error:
+                        # only use last 300 characters of error
+                        error = result.error[-self.max_error_length:]
+                        state_description += (
+                            f"\nError of previous action {i + 1}/{len(self.result)}: ...{error}"
+                        )

         if self.state.screenshot:
             # Format message for vision model

@@ -45,7 +45,7 @@ class CustomAgentOutput(AgentOutput):
     ) -> Type["CustomAgentOutput"]:
         """Extend actions with custom actions"""
         return create_model(
-            "AgentOutput",
+            "CustomAgentOutput",
             __base__=CustomAgentOutput,
             action=(
                 list[custom_actions],

@@ -3,11 +3,11 @@ import pdb

 from playwright.async_api import Browser as PlaywrightBrowser
 from playwright.async_api import (
-    BrowserContext as PlaywrightBrowserContext,
+	BrowserContext as PlaywrightBrowserContext,
 )
 from playwright.async_api import (
-    Playwright,
-    async_playwright,
+	Playwright,
+	async_playwright,
 )
 from browser_use.browser.browser import Browser
 from browser_use.browser.context import BrowserContext, BrowserContextConfig
@@ -25,96 +25,57 @@ class CustomBrowser(Browser):
             config: BrowserContextConfig = BrowserContextConfig()
     ) -> CustomBrowserContext:
         return CustomBrowserContext(config=config, browser=self)

-    async def _setup_browser(self, playwright: Playwright) -> PlaywrightBrowser:
-        """Sets up and returns a Playwright Browser instance with anti-detection measures."""
-        if self.config.wss_url:
-            browser = await playwright.chromium.connect(self.config.wss_url)
-            return browser
-        elif self.config.chrome_instance_path:
-            import subprocess
-
-            import requests
-
-            try:
-                # Check if browser is already running
-                response = requests.get('http://localhost:9222/json/version', timeout=2)
-                if response.status_code == 200:
-                    logger.info('Reusing existing Chrome instance')
-                    browser = await playwright.chromium.connect_over_cdp(
-                        endpoint_url='http://localhost:9222',
-                        timeout=20000,  # 20 second timeout for connection
-                    )
-                    return browser
-            except requests.ConnectionError:
-                logger.debug('No existing Chrome instance found, starting a new one')
-
-            # Start a new Chrome instance
-            subprocess.Popen(
-                [
-                    self.config.chrome_instance_path,
-                    '--remote-debugging-port=9222',
-                ],
-                stdout=subprocess.DEVNULL,
-                stderr=subprocess.DEVNULL,
-            )
-
-            # Attempt to connect again after starting a new instance
-            for _ in range(10):
-                try:
-                    response = requests.get('http://localhost:9222/json/version', timeout=2)
-                    if response.status_code == 200:
-                        break
-                except requests.ConnectionError:
-                    pass
-                await asyncio.sleep(1)
-
-            try:
-                browser = await playwright.chromium.connect_over_cdp(
-                    endpoint_url='http://localhost:9222',
-                    timeout=20000,  # 20 second timeout for connection
-                )
-                return browser
-            except Exception as e:
-                logger.error(f'Failed to start a new Chrome instance.: {str(e)}')
-                raise RuntimeError(
-                    ' To start chrome in Debug mode, you need to close all existing Chrome instances and try again otherwise we can not connect to the instance.'
-                )
-        else:
-            try:
-                disable_security_args = []
-                if self.config.disable_security:
-                    disable_security_args = [
-                        '--disable-web-security',
-                        '--disable-site-isolation-trials',
-                        '--disable-features=IsolateOrigins,site-per-process',
-                    ]
-
-                browser = await playwright.chromium.launch(
-                    headless=self.config.headless,
-                    args=[
-                        '--no-sandbox',
-                        '--disable-blink-features=AutomationControlled',
-                        '--disable-infobars',
-                        '--disable-background-timer-throttling',
-                        '--disable-popup-blocking',
-                        '--disable-backgrounding-occluded-windows',
-                        '--disable-renderer-backgrounding',
-                        '--disable-window-activation',
-                        '--disable-focus-on-load',
-                        '--no-first-run',
-                        '--no-default-browser-check',
-                        '--no-startup-window',
-                        '--window-position=0,0',
-                        # '--window-size=1280,1000',
-                    ]
-                    + disable_security_args
-                    + self.config.extra_chromium_args,
-                    proxy=self.config.proxy,
-                )
-
-                return browser
-            except Exception as e:
-                logger.error(f'Failed to initialize Playwright browser: {str(e)}')
-                raise
+    async def _setup_browser_with_instance(self, playwright: Playwright) -> PlaywrightBrowser:
+        """Sets up and returns a Playwright Browser instance with anti-detection measures."""
+        if not self.config.chrome_instance_path:
+            raise ValueError('Chrome instance path is required')
+        import subprocess
+
+        import requests
+
+        try:
+            # Check if browser is already running
+            response = requests.get('http://localhost:9222/json/version', timeout=2)
+            if response.status_code == 200:
+                logger.info('Reusing existing Chrome instance')
+                browser = await playwright.chromium.connect_over_cdp(
+                    endpoint_url='http://localhost:9222',
+                    timeout=20000,  # 20 second timeout for connection
+                )
+                return browser
+        except requests.ConnectionError:
+            logger.debug('No existing Chrome instance found, starting a new one')
+
+        # Start a new Chrome instance
+        subprocess.Popen(
+            [
+                self.config.chrome_instance_path,
+                '--remote-debugging-port=9222',
+            ] + self.config.extra_chromium_args,
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL,
+        )
+
+        # try to connect first in case the browser have not started
+        for _ in range(10):
+            try:
+                response = requests.get('http://localhost:9222/json/version', timeout=2)
+                if response.status_code == 200:
+                    break
+            except requests.ConnectionError:
+                pass
+            await asyncio.sleep(1)
+
+        # Attempt to connect again after starting a new instance
+        try:
+            browser = await playwright.chromium.connect_over_cdp(
+                endpoint_url='http://localhost:9222',
+                timeout=20000,  # 20 second timeout for connection
+            )
+            return browser
+        except Exception as e:
+            logger.error(f'Failed to start a new Chrome instance.: {str(e)}')
+            raise RuntimeError(
+                ' To start chrome in Debug mode, you need to close all existing Chrome instances and try again otherwise we can not connect to the instance.'
+            )
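The rewritten method only ever connects over CDP to localhost:9222. To check by hand that a debug-mode Chrome is reachable before running the agent (a sketch using the same endpoint the code polls):

import requests

resp = requests.get("http://localhost:9222/json/version", timeout=2)
print(resp.status_code)                         # 200 when Chrome is up
print(resp.json().get("webSocketDebuggerUrl"))  # CDP endpoint Playwright attaches to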
@@ -16,73 +16,4 @@ class CustomBrowserContext(BrowserContext):
             browser: "Browser",
             config: BrowserContextConfig = BrowserContextConfig()
     ):
         super(CustomBrowserContext, self).__init__(browser=browser, config=config)
-
-    async def _create_context(self, browser: PlaywrightBrowser) -> PlaywrightBrowserContext:
-        """Creates a new browser context with anti-detection measures and loads cookies if available."""
-        # If we have a context, return it directly
-
-        # Check if we should use existing context for persistence
-        if self.browser.config.chrome_instance_path and len(browser.contexts) > 0:
-            # Connect to existing Chrome instance instead of creating new one
-            context = browser.contexts[0]
-        else:
-            # Original code for creating new context
-            context = await browser.new_context(
-                viewport=self.config.browser_window_size,
-                no_viewport=False,
-                user_agent=(
-                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
-                    "(KHTML, like Gecko) Chrome/85.0.4183.102 Safari/537.36"
-                ),
-                java_script_enabled=True,
-                bypass_csp=self.config.disable_security,
-                ignore_https_errors=self.config.disable_security,
-                record_video_dir=self.config.save_recording_path,
-                record_video_size=self.config.browser_window_size,
-            )
-
-        if self.config.trace_path:
-            await context.tracing.start(screenshots=True, snapshots=True, sources=True)
-
-        # Load cookies if they exist
-        if self.config.cookies_file and os.path.exists(self.config.cookies_file):
-            with open(self.config.cookies_file, "r") as f:
-                cookies = json.load(f)
-                logger.info(
-                    f"Loaded {len(cookies)} cookies from {self.config.cookies_file}"
-                )
-                await context.add_cookies(cookies)
-
-        # Expose anti-detection scripts
-        await context.add_init_script(
-            """
-            // Webdriver property
-            Object.defineProperty(navigator, 'webdriver', {
-                get: () => undefined
-            });
-
-            // Languages
-            Object.defineProperty(navigator, 'languages', {
-                get: () => ['en-US', 'en']
-            });
-
-            // Plugins
-            Object.defineProperty(navigator, 'plugins', {
-                get: () => [1, 2, 3, 4, 5]
-            });
-
-            // Chrome runtime
-            window.chrome = { runtime: {} };
-
-            // Permissions
-            const originalQuery = window.navigator.permissions.query;
-            window.navigator.permissions.query = (parameters) => (
-                parameters.name === 'notifications' ?
-                    Promise.resolve({ state: Notification.permission }) :
-                    originalQuery(parameters)
-            );
-            """
-        )
-
-        return context
@@ -1,12 +1,16 @@
 import pyperclip
+from typing import Optional, Type
+from pydantic import BaseModel
 from browser_use.agent.views import ActionResult
 from browser_use.browser.context import BrowserContext
 from browser_use.controller.service import Controller


 class CustomController(Controller):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, exclude_actions: list[str] = [],
+                 output_model: Optional[Type[BaseModel]] = None
+                 ):
+        super().__init__(exclude_actions=exclude_actions, output_model=output_model)
         self._register_custom_actions()

     def _register_custom_actions(self):

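The widened constructor mirrors the upstream Controller signature, so callers can now exclude built-in actions or request typed final output. A usage sketch (the action name and schema are hypothetical):

from pydantic import BaseModel

class TaskOutput(BaseModel):
    first_url: str

controller = CustomController(exclude_actions=["open_tab"], output_model=TaskOutput)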
@@ -11,7 +11,7 @@ def default_config():
         "max_steps": 100,
         "max_actions_per_step": 10,
         "use_vision": True,
-        "tool_call_in_content": True,
+        "tool_calling_method": "auto",
         "llm_provider": "openai",
         "llm_model_name": "gpt-4o",
         "llm_temperature": 1.0,
@@ -56,7 +56,7 @@ def save_current_config(*args):
         "max_steps": args[1],
         "max_actions_per_step": args[2],
         "use_vision": args[3],
-        "tool_call_in_content": args[4],
+        "tool_calling_method": args[4],
         "llm_provider": args[5],
         "llm_model_name": args[6],
         "llm_temperature": args[7],
@@ -86,7 +86,7 @@ def update_ui_from_config(config_file):
         gr.update(value=loaded_config.get("max_steps", 100)),
         gr.update(value=loaded_config.get("max_actions_per_step", 10)),
         gr.update(value=loaded_config.get("use_vision", True)),
-        gr.update(value=loaded_config.get("tool_call_in_content", True)),
+        gr.update(value=loaded_config.get("tool_calling_method", True)),
         gr.update(value=loaded_config.get("llm_provider", "openai")),
         gr.update(value=loaded_config.get("llm_model_name", "gpt-4o")),
         gr.update(value=loaded_config.get("llm_temperature", 1.0)),

@@ -25,6 +25,7 @@ from langchain_core.outputs import (
     LLMResult,
     RunInfo,
 )
+from langchain_ollama import ChatOllama
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.runnables import Runnable, RunnableConfig
 from langchain_core.tools import BaseTool
@@ -98,4 +99,38 @@ class DeepSeekR1ChatOpenAI(ChatOpenAI):

         reasoning_content = response.choices[0].message.reasoning_content
         content = response.choices[0].message.content
         return AIMessage(content=content, reasoning_content=reasoning_content)
+
+
+class DeepSeekR1ChatOllama(ChatOllama):
+
+    async def ainvoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        org_ai_message = await super().ainvoke(input=input)
+        org_content = org_ai_message.content
+        reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
+        content = org_content.split("</think>")[1]
+        if "**JSON Response:**" in content:
+            content = content.split("**JSON Response:**")[-1]
+        return AIMessage(content=content, reasoning_content=reasoning_content)
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[list[str]] = None,
+        **kwargs: Any,
+    ) -> AIMessage:
+        org_ai_message = super().invoke(input=input)
+        org_content = org_ai_message.content
+        reasoning_content = org_content.split("</think>")[0].replace("<think>", "")
+        content = org_content.split("</think>")[1]
+        if "**JSON Response:**" in content:
+            content = content.split("**JSON Response:**")[-1]
+        return AIMessage(content=content, reasoning_content=reasoning_content)
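Note that content = org_content.split("</think>")[1] assumes the model always emits its reasoning block; if the tag is missing, it raises IndexError. A defensive variant of the split (illustrative, not part of the PR):

def split_reasoning(org_content: str) -> tuple[str, str]:
    # Returns (reasoning, answer); tolerates a missing <think> block.
    if "</think>" in org_content:
        reasoning, content = org_content.split("</think>", 1)
        return reasoning.replace("<think>", ""), content
    return "", org_content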
@@ -10,7 +10,7 @@ from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 import gradio as gr

-from .llm import DeepSeekR1ChatOpenAI
+from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama

 def get_llm_model(provider: str, **kwargs):
     """
@@ -94,12 +94,20 @@ def get_llm_model(provider: str, **kwargs):
         else:
             base_url = kwargs.get("base_url")

-        return ChatOllama(
-            model=kwargs.get("model_name", "qwen2.5:7b"),
-            temperature=kwargs.get("temperature", 0.0),
-            num_ctx=kwargs.get("num_ctx", 32000),
-            base_url=base_url,
-        )
+        if kwargs.get("model_name", "qwen2.5:7b").startswith("deepseek-r1"):
+            return DeepSeekR1ChatOllama(
+                model=kwargs.get("model_name", "deepseek-r1:7b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                base_url=kwargs.get("base_url", base_url),
+            )
+        else:
+            return ChatOllama(
+                model=kwargs.get("model_name", "qwen2.5:7b"),
+                temperature=kwargs.get("temperature", 0.0),
+                num_ctx=kwargs.get("num_ctx", 32000),
+                base_url=kwargs.get("base_url", base_url),
+            )
     elif provider == "azure_openai":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
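With the branch above, an Ollama-served DeepSeek-R1 model now routes to DeepSeekR1ChatOllama instead of plain ChatOllama. A usage sketch (the base_url is Ollama's default local endpoint):

llm = get_llm_model(
    provider="ollama",
    model_name="deepseek-r1:14b",
    temperature=0.5,
    base_url="http://localhost:11434",
)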
@@ -125,7 +133,7 @@ model_names = {
     "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"],
     "deepseek": ["deepseek-chat", "deepseek-reasoner"],
     "gemini": ["gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-1.5-flash-latest", "gemini-1.5-flash-8b-latest", "gemini-2.0-flash-thinking-exp-1219"],
-    "ollama": ["qwen2.5:7b", "llama2:7b"],
+    "ollama": ["qwen2.5:7b", "llama2:7b", "deepseek-r1:14b", "deepseek-r1:32b"],
     "azure_openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo"]
 }


@@ -40,7 +40,15 @@ async def test_browser_use_org():

     window_w, window_h = 1920, 1080
     use_vision = False
-    chrome_path = os.getenv("CHROME_PATH", None)
+    use_own_browser = False
+    if use_own_browser:
+        chrome_path = os.getenv("CHROME_PATH", None)
+        if chrome_path == "":
+            chrome_path = None
+    else:
+        chrome_path = None
+
+    tool_calling_method = "json_schema"  # setting to json_schema when using ollma

     browser = Browser(
         config=BrowserConfig(
@@ -64,7 +72,8 @@ async def test_browser_use_org():
         task="go to google.com and type 'OpenAI' click search and give me the first url",
         llm=llm,
         browser_context=browser_context,
-        use_vision=use_vision
+        use_vision=use_vision,
+        tool_calling_method=tool_calling_method
     )
     history: AgentHistoryList = await agent.run(max_steps=10)

@@ -242,22 +251,32 @@ async def test_browser_use_custom_v2():
     #     api_key=os.getenv("GOOGLE_API_KEY", "")
     # )

-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-reasoner",
-        temperature=0.8
-    )
+    # llm = utils.get_llm_model(
+    #     provider="deepseek",
+    #     model_name="deepseek-reasoner",
+    #     temperature=0.8
+    # )

     # llm = utils.get_llm_model(
     #     provider="deepseek",
     #     model_name="deepseek-chat",
     #     temperature=0.8
     # )

     # llm = utils.get_llm_model(
     #     provider="ollama", model_name="qwen2.5:7b", temperature=0.5
     # )

+    # llm = utils.get_llm_model(
+    #     provider="ollama", model_name="deepseek-r1:14b", temperature=0.5
+    # )
+
     controller = CustomController()
     use_own_browser = False
     disable_security = True
     use_vision = False  # Set to False when using DeepSeek
-    tool_call_in_content = True  # Set to True when using Ollama
-    max_actions_per_step = 1
-
+    max_actions_per_step = 10
     playwright = None
     browser = None
     browser_context = None
@@ -288,7 +307,7 @@ async def test_browser_use_custom_v2():
             )
         )
     agent = CustomAgent(
-        task="go to google.com and type 'OpenAI' click search and give me the first url",
+        task="go to google.com and type 'Nvidia' click search and give me the first url",
         add_infos="",  # some hints for llm to complete the task
         llm=llm,
        browser=browser,
@@ -296,7 +315,6 @@ async def test_browser_use_custom_v2():
         controller=controller,
         system_prompt_class=CustomSystemPrompt,
         use_vision=use_vision,
-        tool_call_in_content=tool_call_in_content,
         max_actions_per_step=max_actions_per_step
     )
     history: AgentHistoryList = await agent.run(max_steps=10)

@@ -142,6 +142,14 @@ def test_ollama_model():
     llm = ChatOllama(model="qwen2.5:7b")
     ai_msg = llm.invoke("Sing a ballad of LangChain.")
     print(ai_msg.content)

+def test_deepseek_r1_ollama_model():
+    from src.utils.llm import DeepSeekR1ChatOllama
+
+    llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
+    ai_msg = llm.invoke("how many r in strawberry?")
+    print(ai_msg.content)
+    pdb.set_trace()
+

 if __name__ == '__main__':
@@ -151,3 +159,4 @@ if __name__ == '__main__':
     # test_deepseek_model()
     # test_ollama_model()
     test_deepseek_r1_model()
+    # test_deepseek_r1_ollama_model()
webui.py
@@ -21,6 +21,7 @@ from browser_use.browser.context import (
     BrowserContextConfig,
     BrowserContextWindowSize,
 )
+from langchain_ollama import ChatOllama
 from playwright.async_api import async_playwright
 from src.utils.agent_state import AgentState

@@ -91,7 +92,7 @@ async def run_browser_agent(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_call_in_content
+        tool_calling_method
 ):
     global _global_agent_state
     _global_agent_state.clear_stop()  # Clear any previous stop requests
@@ -137,7 +138,7 @@ async def run_browser_agent(
             max_steps=max_steps,
             use_vision=use_vision,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content
+            tool_calling_method=tool_calling_method
         )
     elif agent_type == "custom":
         final_result, errors, model_actions, model_thoughts, trace_file, history_file = await run_custom_agent(
@@ -156,7 +157,7 @@ async def run_browser_agent(
             max_steps=max_steps,
             use_vision=use_vision,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content
+            tool_calling_method=tool_calling_method
         )
     else:
         raise ValueError(f"Invalid agent type: {agent_type}")
@@ -215,7 +216,7 @@ async def run_org_agent(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_call_in_content
+        tool_calling_method
 ):
     try:
         global _global_browser, _global_browser_context, _global_agent_state
@@ -251,7 +252,7 @@ async def run_org_agent(
                 ),
             )
         )
-        
+
         agent = Agent(
             task=task,
             llm=llm,
@@ -259,7 +260,7 @@ async def run_org_agent(
             browser=_global_browser,
             browser_context=_global_browser_context,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content
+            tool_calling_method=tool_calling_method
         )
         history = await agent.run(max_steps=max_steps)

@@ -306,7 +307,7 @@ async def run_custom_agent(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_call_in_content
+        tool_calling_method
 ):
     try:
         global _global_browser, _global_browser_context, _global_agent_state
@@ -345,7 +346,7 @@ async def run_custom_agent(
                 ),
             )
         )
-        
+
         # Create and run agent
         agent = CustomAgent(
             task=task,
@@ -357,8 +358,8 @@ async def run_custom_agent(
             controller=controller,
             system_prompt_class=CustomSystemPrompt,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content,
-            agent_state=_global_agent_state
+            agent_state=_global_agent_state,
+            tool_calling_method=tool_calling_method
         )
         history = await agent.run(max_steps=max_steps)

@@ -411,7 +412,7 @@ async def run_with_stream(
         max_steps,
         use_vision,
         max_actions_per_step,
-        tool_call_in_content
+        tool_calling_method
 ):
     global _global_agent_state
     stream_vw = 80
@@ -439,7 +440,7 @@ async def run_with_stream(
             max_steps=max_steps,
             use_vision=use_vision,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content
+            tool_calling_method=tool_calling_method
         )
         # Add HTML content at the start of the result array
         html_content = f"<h1 style='width:{stream_vw}vw; height:{stream_vh}vh'>Using browser...</h1>"
@@ -471,7 +472,7 @@ async def run_with_stream(
                 max_steps=max_steps,
                 use_vision=use_vision,
                 max_actions_per_step=max_actions_per_step,
-                tool_call_in_content=tool_call_in_content
+                tool_calling_method=tool_calling_method
             )
         )

@@ -628,32 +629,38 @@ def create_ui(config, theme_name="Ocean"):
                         value=config['agent_type'],
                         info="Select the type of agent to use",
                     )
-                    max_steps = gr.Slider(
-                        minimum=1,
-                        maximum=200,
-                        value=config['max_steps'],
-                        step=1,
-                        label="Max Run Steps",
-                        info="Maximum number of steps the agent will take",
-                    )
-                    max_actions_per_step = gr.Slider(
-                        minimum=1,
-                        maximum=20,
-                        value=config['max_actions_per_step'],
-                        step=1,
-                        label="Max Actions per Step",
-                        info="Maximum number of actions the agent will take per step",
-                    )
-                    use_vision = gr.Checkbox(
-                        label="Use Vision",
-                        value=config['use_vision'],
-                        info="Enable visual processing capabilities",
-                    )
-                    tool_call_in_content = gr.Checkbox(
-                        label="Use Tool Calls in Content",
-                        value=config['tool_call_in_content'],
-                        info="Enable Tool Calls in content",
-                    )
+                    with gr.Column():
+                        max_steps = gr.Slider(
+                            minimum=1,
+                            maximum=200,
+                            value=config['max_steps'],
+                            step=1,
+                            label="Max Run Steps",
+                            info="Maximum number of steps the agent will take",
+                        )
+                        max_actions_per_step = gr.Slider(
+                            minimum=1,
+                            maximum=20,
+                            value=config['max_actions_per_step'],
+                            step=1,
+                            label="Max Actions per Step",
+                            info="Maximum number of actions the agent will take per step",
+                        )
+                    with gr.Column():
+                        use_vision = gr.Checkbox(
+                            label="Use Vision",
+                            value=config['use_vision'],
+                            info="Enable visual processing capabilities",
+                        )
+                        tool_calling_method = gr.Dropdown(
+                            label="Tool Calling Method",
+                            value=config['tool_calling_method'],
+                            interactive=True,
+                            allow_custom_value=True,  # Allow users to input custom model names
+                            choices=["auto", "json_schema", "function_calling"],
+                            info="Tool Calls Funtion Name",
+                            visible=False
+                        )

             with gr.TabItem("🔧 LLM Configuration", id=2):
                 with gr.Group():
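The tool_call_in_content checkbox is replaced by a tool_calling_method dropdown, which browser-use forwards to the LLM's structured-output call. Roughly what the setting selects, as a simplified sketch (llm and AgentOutput stand in for the agent's actual objects):

structured_llm = llm.with_structured_output(
    AgentOutput,
    include_raw=True,
    method="json_schema",  # or "function_calling"; "auto" lets browser-use pick
)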
@@ -803,7 +810,7 @@ def create_ui(config, theme_name="Ocean"):
             fn=update_ui_from_config,
             inputs=[config_file_input],
             outputs=[
-                agent_type, max_steps, max_actions_per_step, use_vision, tool_call_in_content,
+                agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
                 llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
                 use_own_browser, keep_browser_open, headless, disable_security, enable_recording,
                 window_w, window_h, save_recording_path, save_trace_path, save_agent_history_path,
@@ -814,7 +821,7 @@ def create_ui(config, theme_name="Ocean"):
         save_config_button.click(
             fn=save_current_config,
             inputs=[
-                agent_type, max_steps, max_actions_per_step, use_vision, tool_call_in_content,
+                agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
                 llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
                 use_own_browser, keep_browser_open, headless, disable_security,
                 enable_recording, window_w, window_h, save_recording_path, save_trace_path,
@@ -866,7 +873,7 @@ def create_ui(config, theme_name="Ocean"):
                 agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
                 use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h,
                 save_recording_path, save_agent_history_path, save_trace_path,  # Include the new path
-                enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
+                enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method
             ],
             outputs=[
                 browser_view,  # Browser view