Added custom actions registry and fixed extraction layer (#20)

* Validator

* Test mind2web

* Cleaned up logger

* Pytest logger

* Cleaned up logger

* Disable flag for human input

* Multiple clicks per button

* Multiple clicks per button

* More structured system prompt

* Fields with description

* System prompt example

* One logger

* Cleaner logging

* Log step in step function

* Fix critical clicking error - wrong argument used

* Improved thought process of agent

* Improve system prompt

* Remove human input message

* Custom action registration

* Pydantic model for custom actions

* Pydantic model for custom output

* Runs through; model outputs functions, but they are not called yet

* Work in progress - description for custom actions

* Description works, but schema not yet

* Model can call the right action - but it is not executed yet

* Separate is_controller_action and is_custom_action

* Works! Model can call custom function

* Use registry for action, but result is not fed back to model

* Include result in messages

* Works with custom function - but typing is not correct

* Renamed registry

* First test cases

* Captcha tests

* Pydantic for tests

* Improve prompts for multi-step

* System prompt structure

* Handle errors like validation error

* Refactor error handling in agent

* Refactor error handling in agent

* Improved logging

* Update view

* Fix click parameter to index

* Simplify dynamic actions

* Use run instead of step

* Rename history

* Rename AgentService to Agent

* Rename ControllerService to Controller

* Pytest file

* Rename get_state

* Rename BrowserService

* Converted DOM extraction recursion to a while loop

* Rename use_vision

* Rename use_vision

* Reversed DOM tree items and made browser less annoying

* Renaming and fixing type errors

* Renamed class names for agent

* Updated requirements

* Update prompt

* Action registration works for user and controller

* Fix done call by returning ActionResult

* Fix handling when result is None

* Rename AgentOutput and ActionModel

* Improved prompt; passes 6/8 tests from test_agent_actions

* Calculate token cost

* Improve display

* Simplified logger

* Test function calling

* Created a simple XPath extraction algorithm

* Tests logging

* Tiny fixes to DOM extraction

* Remove test

* Don't log number of clicks

* Pytest file

* Merged per-element JS checks

* Check if driver is still open

* Much faster processing

* Fixed agent planning

* Fix example

* Fix example

* Improve error

* Improved error correction

* New line for step

* Small type error fixes

* Test for pydantic

* Fix line

* Removed sample

* Fixed README and examples
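
For context, the headline change here is the custom actions registry. Below is a minimal sketch of how it is used, adapted from the README added in this PR (the `JobDetails` / `save_job` action and the task string are illustrative):

```python
from typing import Optional

from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel

from browser_use.agent.service import Agent
from browser_use.browser.service import Browser
from browser_use.controller.service import Controller

# The controller owns the action registry; initialize it first
controller = Controller()


@controller.action('Ask user for information')
def ask_human(question: str, display_question: bool) -> str:
    # A plain function registered this way becomes an action the model can call
    return input(f'\n{question}\nInput: ')


class JobDetails(BaseModel):
    title: str
    company: str
    job_link: str
    salary: Optional[str] = None


@controller.action('Save job details which you found on page', param_model=JobDetails, requires_browser=True)
def save_job(params: JobDetails, browser: Browser):
    # Actions registered with requires_browser=True receive the Browser instance
    browser.driver.get(params.job_link)


# The agent picks up the registered actions from the controller's registry
model = ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3)
agent = Agent(task='Apply to the first matching job and save its details', llm=model, controller=controller)
# await agent.run()
```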

---------

Co-authored-by: magmueller <mamagnus00@gmail.com>
Authored by Gregor Žunič on 2024-11-15 21:42:02 +01:00; committed by GitHub
Commit 89c63fdd63 (parent 5b5ee3e9b3)
42 changed files with 19263 additions and 1342 deletions

.gitignore (6 lines changed)

@@ -160,4 +160,8 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
temp
temp
tmp
.DS_Store

.vscode/launch.json (53 lines changed)

@@ -2,20 +2,51 @@
"version": "0.2.0",
"configurations": [
{
"name": "Python: Debug Tests",
"name": "Python Debugger: Module",
"type": "debugpy",
"request": "launch",
"module": "examples.extend_actions"
},
{
"name": "Python: Debug extend_actions",
"type": "module",
"request": "launch",
"module": "examples.extend_actions",
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
},
{
"name": "Python: Debug Captcha Tests",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/pytest",
"module": "pytest",
"args": [
"src/tests/test_kayak_search.py",
"tests/test_agent_actions.py",
"-v",
"-s"
"-k",
"test_captcha_solver",
"--capture=no",
],
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
}
]
}
"justMyCode": false
},
{
"name": "Python: Debug Ecommerce Interaction",
"type": "python",
"request": "launch",
"module": "pytest",
"args": [
"tests/test_agent_actions.py",
"-v",
"-k",
"test_ecommerce_interaction",
"--capture=no",
],
"console": "integratedTerminal",
"justMyCode": false
}
]
}

README.md (231 lines changed)

@@ -1,30 +1,30 @@
<div align="center">
# 🌐 Browser Use
# 🌐 Browser-Use
### Open-Source Web Automation with LLMs
Make websites accessible for AI agents 🤖.
[![GitHub stars](https://img.shields.io/github/stars/gregpr07/browser-use?style=social)](https://github.com/gregpr07/browser-use/stargazers)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
[![Discord](https://img.shields.io/discord/1303749220842340412?color=7289DA&label=Discord&logo=discord&logoColor=white)](https://discord.gg/uaCtrbbv)
[![Discord](https://img.shields.io/discord/1303749220842340412?color=7289DA&label=Discord&logo=discord&logoColor=white)](https://link.browser-use.com/discord)
</div>
Browser use is the easiest way to connect your AI agents with the browser. If you have used Browser Use for your project, feel free to show it off in our [Discord](https://link.browser-use.com/discord).
Let LLMs interact with websites through a simple interface.
# Quick start
## Short Example
With pip:
```bash
pip install browser-use
```
Spin up your agent:
```python
from langchain_openai import ChatOpenAI
from browser_use import Agent
agent = Agent(
task="Go to hackernews on show hn and give me top 10 post titles, their points and hours. Calculate for each the ratio of points per hour.",
task="Find a one-way flight from Bali to Oman on 12 January 2025 on Google Flights. Return me the cheapest option.",
llm=ChatOpenAI(model="gpt-4o"),
)
@@ -32,42 +32,94 @@ agent = Agent(
await agent.run()
```
## Demo
And don't forget to add your API keys to your `.env` file.
<div>
<a href="https://www.loom.com/share/63612b5994164cb1bb36938d62fe9983">
<img style="max-width:300px;" src="https://cdn.loom.com/sessions/thumbnails/63612b5994164cb1bb36938d62fe9983-7133f9e169672e6f-full-play.gif">
</a>
<p><i>Prompt: Go to hackernews on show hn and give me top 10 post titles, their points and hours. Calculate for each the ratio of points per hour. (1x speed) </i></p>
</div>
<div>
<a href="https://www.loom.com/share/2af938b9f8024647950a9e18b3946054">
<img style="max-width:300px;" src="https://cdn.loom.com/sessions/thumbnails/2af938b9f8024647950a9e18b3946054-b99c733cf670e568-full-play.gif">
</a>
<p><i>Prompt: Search the top 3 AI companies 2024 and find out what concrete hardware each is using for their model. (1x speed)</i></p>
</div>
```bash
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
```
# Demo
<div style="display: flex; justify-content: space-between; margin-top: 20px;">
<div style="flex: 1; margin-right: 10px;">
<img style="width: 100%;" src="./static/kayak.gif" alt="Kayak flight search demo">
<p><i>Prompt: Go to kayak.com and find a one-way flight from Zürich to San Francisco on 12 January 2025. (2.5x speed)</i></p>
</div>
<div style="flex: 1; margin-left: 10px;">
<img style="width: 100%;" src="./static/photos.gif" alt="Photos search demo">
<p><i>Prompt: Opening new tabs and searching for images for these people: Albert Einstein, Oprah Winfrey, Steve Jobs. (2.5x speed)</i></p>
</div>
</div>
</div>
DEMO VIDEO HERE
## Local Setup
# Features ⭐
- Vision + html extraction
- Automatic multi-tab management
- Extract clicked elements XPaths
- Add custom actions (e.g. add data to database which the LLM can use)
- Self-correcting
- Use any LLM supported by LangChain (e.g. gpt4o, gpt4o mini, claude 3.5 sonnet, llama 3.1 405b, etc.)
## Register custom actions
If you want to add custom actions your agent can take, you can register them like this:
```python
from browser_use.agent.service import Agent
from browser_use.browser.service import Browser
from browser_use.controller.service import Controller
# Initialize controller first
controller = Controller()
@controller.action('Ask user for information')
def ask_human(question: str, display_question: bool) -> str:
return input(f'\n{question}\nInput: ')
```
Or define your parameters using Pydantic
```python
class JobDetails(BaseModel):
title: str
company: str
job_link: str
salary: Optional[str] = None
@controller.action('Save job details which you found on page', param_model=JobDetails, requires_browser=True)
def save_job(params: JobDetails, browser: Browser):
print(params)
# use the browser normally
browser.driver.get(params.job_link)
```
and then run your agent:
```python
model = ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3)
agent = Agent(task=task, llm=model, controller=controller)
await agent.run()
```
## Get XPath history
To get the entire history of everything the agent has done, you can use the output of the `run` method:
```python
history: list[AgentHistory] = await agent.run()
print(history)
```
## More examples
For more examples see the [examples](examples) folder or join the [Discord](https://link.browser-use.com/discord) and show off your project.
# Contributing
Contributions are welcome! Feel free to open issues for bugs or feature requests.
## Setup
1. Create a virtual environment and install dependencies:
```bash
# To install all dependencies including dev
pip install . ."[dev]"
pip install -r requirements.txt -r requirements-dev.txt
```
2. Add your API keys to the `.env` file:
@@ -76,119 +128,22 @@ pip install . ."[dev]"
cp .env.example .env
```
E.g. for OpenAI:
or copy the following to your `.env` file:
```bash
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
```
You can use any LLM model supported by LangChain by adding the appropriate environment variables. See [langchain models](https://python.langchain.com/docs/integrations/chat/) for available options.
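
For example, switching the agent to Claude only needs a different LangChain chat model; a short sketch mirroring the Anthropic example above (requires `ANTHROPIC_API_KEY` in your environment):

```python
from langchain_anthropic import ChatAnthropic

from browser_use import Agent

llm = ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3)
agent = Agent(task='Your task here', llm=llm)
# await agent.run()
```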
## Features
- Universal LLM Support - Works with any Language Model
- Interactive Element Detection - Automatically finds interactive elements
- Multi-Tab Management - Seamless handling of browser tabs
- XPath Extraction for scraping functions - No more manual DevTools inspection
- Vision Model Support - Process visual page information
- Customizable Actions - Add your own browser interactions (e.g. add data to database which the LLM can use)
- Handles dynamic content - don't worry about cookies or changing content
- Chain-of-thought prompting with memory - Solve long-term tasks
- Self-correcting - If the LLM makes a mistake, the agent will self-correct its actions
## Advanced Examples
### Chain of Agents
You can persist the browser across multiple agents and chain them together.
```python
from asyncio import run
from browser_use import Agent, Controller
from dotenv import load_dotenv
from langchain_anthropic import ChatAnthropic
load_dotenv()
# Persist browser state across agents
controller = Controller()
# Initialize browser agent
agent1 = Agent(
task="Open 3 VCs websites in the New York area.",
llm=ChatAnthropic(model="claude-3-5-sonnet-20240620", timeout=25, stop=None),
controller=controller)
agent2 = Agent(
task="Give me the names of the founders of the companies in all tabs.",
llm=ChatAnthropic(model="claude-3-5-sonnet-20240620", timeout=25, stop=None),
controller=controller)
run(agent1.run())
founders, history = run(agent2.run())
print(founders)
```
You can use the `history` to run the agents again deterministically.
## Command Line Usage
Run examples directly from the command line (clone the repo first):
### Building the package
```bash
python examples/try.py "Your query here" --provider [openai|anthropic]
hatch build
```
### Anthropic
You need to add `ANTHROPIC_API_KEY` to your environment variables. Example usage:
```bash
python examples/try.py "Search the top 3 AI companies 2024 and find out in 3 new tabs what hardware each is using for their models" --provider anthropic
```
### OpenAI
You need to add `OPENAI_API_KEY` to your environment variables. Example usage:
```bash
python examples/try.py "Go to hackernews on show hn and give me top 10 post titles, their points and hours. Calculate for each the ratio of points per hour. " --provider anthropic
```
## 🤖 Supported Models
All LangChain chat models are supported. Tested with:
- GPT-4o
- GPT-4o Mini
- Claude 3.5 Sonnet
- LLama 3.1 405B
## Limitations
- When extracting page content, the message length increases and the LLM gets slower.
- Currently one agent costs about 0.01$
- Sometimes it tries to repeat the same task over and over again.
- Some elements might not be extracted which you want to interact with.
- What should we focus on the most?
- Robustness
- Speed
- Cost reduction
## Roadmap
- [x] Save agent actions and execute them deterministically
- [ ] Pydantic forced output
- [ ] Third party SERP API for faster Google Search results
- [ ] Multi-step action execution to increase speed
- [ ] Test on mind2web dataset
- [ ] Add more browser actions
## Contributing
Contributions are welcome! Feel free to open issues for bugs or feature requests.
Feel free to join the [Discord](https://discord.gg/uaCtrbbv) for discussions and support.
Feel free to join the [Discord](https://link.browser-use.com/discord) for discussions and support.
---

browser_use/__init__.py

@@ -1,6 +1,10 @@
from browser_use.agent.service import AgentService as Agent
from browser_use.browser.service import BrowserService as Browser
from browser_use.controller.service import ControllerService as Controller
from browser_use.logging_config import setup_logging
setup_logging()
from browser_use.agent.service import Agent as Agent
from browser_use.browser.service import Browser as Browser
from browser_use.controller.service import Controller as Controller
from browser_use.dom.service import DomService
__all__ = ['Agent', 'Browser', 'Controller', 'DomService']

browser_use/agent/prompts.py

@@ -1,12 +1,73 @@
from datetime import datetime
from langchain_core.messages import HumanMessage, SystemMessage
from browser_use.controller.views import ControllerPageState
from browser_use.browser.views import BrowserState
class AgentSystemPrompt:
def __init__(self, task: str, default_action_description: str):
def __init__(self, task: str, action_description: str, current_date: datetime):
self.task = task
self.default_action_description = default_action_description
self.default_action_description = action_description
self.current_date = current_date
def response_format(self) -> str:
"""
Returns the response format for the agent.
Returns:
str: Response format
"""
return """
{{
"current_state": {{
"valuation_previous_goal": "String starting with "Success", "Failed:" or "Unknown" to evaluate if the previous next_goal is achieved. If failed or unknown describe why.",
"memory": "Your memory with things you need to remeber until the end of the task for the user. You can also store overall progress in a bigger task. You have access to this in the next steps.",
"next_goal": "String describing the next immediate goal which can be achieved with one action"
}},
"action": {{
// EXACTLY ONE of the following available actions must be specified
}}
}}"""
def example_response(self) -> str:
"""
Returns an example response for the agent.
Returns:
str: Example response
"""
return """{"current_state": {"valuation_previous_goal": "Success", "memory": "We applied already for 3/7 jobs, 1. ..., 2. ..., 3. ...", "next_goal": "Click on the button x to apply for the next job"}, "action": {"click_element": {"index": 44,"num_clicks": 2}}}"""
def important_rules(self) -> str:
"""
Returns the important rules for the agent.
Returns:
str: Important rules
"""
return """
1. Only use indexes that exist in the input list for click or input text actions. If no indexes exist, try alternative actions, e.g. go back, search google etc.
2. If stuck, try alternative approaches, e.g. go back, search google, or extract_page_content
3. When you are done with the complete task, use the done action. Make sure to have all information the user needs and return the result.
4. If an image is provided, use it to understand the context, the bounding boxes around the buttons have the same indexes as the interactive elements.
6. ALWAYS respond in the RESPONSE FORMAT with valid JSON.
7. If the page is empty use actions like "go_to_url", "search_google" or "open_tab"
8. Remember: Choose EXACTLY ONE action per response. Invalid combinations or multiple actions will be rejected.
9. If popups like cookies appear, accept or close them
"""
def input_format(self) -> str:
return """
Example:
33[:]\t<button>Interactive element</button>
_[:] Text content...
Explanation:
index[:] Interactible element with index. You can only interact with all elements which are clickable and refer to them by their index.
_[:] elements are just for more context, but not interactable.
\t: Tab indent (1 tab for depth 1 etc.). This is to help you understand which elements belong to each other.
"""
def get_system_message(self) -> SystemMessage:
"""
@@ -15,56 +76,31 @@ class AgentSystemPrompt:
Returns:
str: Formatted system prompt
"""
# System prompts for the agent
# output_format = """
# {"valuation_previous_goal": "Success if completed, else short sentence of why not successful.", "goal": "short description what you want to achieve", "action": "action_name", "params": {"param_name": "param_value"}}
# """
time_str = self.current_date.strftime('%Y-%m-%d %H:%M')
AGENT_PROMPT = f"""
You are an AI agent that helps users interact with websites.
Your input are all the interactive elements with its context of the current page from.
This is how an input looks like:
33:\t<button>Clickable element</button>
_: Not clickable, only for your context
\t: Tab indent (1 tab for depth 1 etc.). This is to help you understand which elements belong to each other.
You are an AI agent that helps users interact with websites. You receive a list of interactive elements from the current webpage and must respond with specific actions. Today's date is {time_str}.
INPUT FORMAT:
{self.input_format()}
In the beginning the list will be empty.
On elements with _ you can not click.
On elements with a index you can click.
Additional you get a list of your previous actions.
You have to respond in the following RESPONSE FORMAT:
{self.response_format()}
Respond with a valid JSON object, containing 2 keys: current_state and action.
In current_state there are 3 keys:
valuation_previous_goal: Evaluate if the previous goal was achieved or what went wrong. E.g. Failed because ...
memory: This you can use as a memory to store where you are in your overall task. E.g. if you need to find 10 jobs, you can store the already found jobs here.
next_goal: The next goal you want to achieve e.g. clicking on ... button to ...
Your AVAILABLE ACTIONS:
{self.default_action_description}
For the action choose EXACTLY ONE from the following list:
{self.default_action_description}
To interact with elements, use their index number from the input line. Write it in the click_element() or input_text() actions.
Make sure to only use indexes that are present in the list.
If you need more text from the page you can use the extract_page_content action.
Example:
{self.example_response()}
If you see any cookie or accept privacy policy please always just accepted them without hesitation.
If you evaluate repeatedly that you dont achieve the next_goal, try to find a new element that can help you achieve your task or if persistent, go back or reload the page and try a different approach.
You can ask_human for clarification if you are completely stuck or if you really need more information.
If a picture is provided, use it to understand the context and the next action.
If you are sure you are done you can extract_page_content to get the markdown content and in the next action call done() with the text of the requested result to end the task and wait for further instructions.
"""
IMPORTANT RULES:
{self.important_rules()}
"""
return SystemMessage(content=AGENT_PROMPT)
class AgentMessagePrompt:
def __init__(self, state: ControllerPageState):
def __init__(self, state: BrowserState):
self.state = state
def get_user_message(self) -> HumanMessage:
@@ -91,4 +127,4 @@ Interactive elements:
return HumanMessage(content=state_description)
def get_message_for_history(self) -> HumanMessage:
return HumanMessage(content=f'Currently on url: {self.state.url}')
return HumanMessage(content=f'Step url: {self.state.url}')

browser_use/agent/service.py

@@ -1,216 +1,386 @@
from __future__ import annotations
import json
import logging
import os
import time
from datetime import datetime
from typing import Any, Optional, TypeVar
from dotenv import load_dotenv
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from openai import RateLimitError
from pydantic import BaseModel, ValidationError
from browser_use.agent.prompts import AgentMessagePrompt, AgentSystemPrompt
from browser_use.agent.views import (
ActionResult,
AgentError,
AgentHistory,
AgentOutput,
ClickElementControllerHistoryItem,
InputTextControllerHistoryItem,
Output,
)
from browser_use.controller.service import ControllerService
from browser_use.controller.views import (
ControllerActionResult,
ControllerPageState,
ModelPricingCatalog,
Pricing,
TokenDetails,
TokenUsage,
)
from browser_use.browser.views import BrowserState
from browser_use.controller.service import Controller
from browser_use.utils import time_execution_async
load_dotenv()
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
force=True, # Prevent changing root logger config
)
T = TypeVar('T', bound=BaseModel)
class AgentService:
class Agent:
def __init__(
self,
task: str,
llm: BaseChatModel,
controller: ControllerService | None = None,
controller: Optional[Controller] = None,
use_vision: bool = True,
save_conversation_path: str | None = None,
save_conversation_path: Optional[str] = None,
max_failures: int = 5,
retry_delay: int = 10,
):
"""
Agent service.
Args:
task (str): Task to be performed.
llm (AvailableModel): Model to be used.
controller (ControllerService | None): You can reuse an existing or (automatically) create a new one.
"""
self.task = task
self.use_vision = use_vision
self.controller_injected = controller is not None
self.controller = controller or ControllerService()
self.llm = llm
system_prompt = AgentSystemPrompt(
task, default_action_description=self._get_action_description()
).get_system_message()
# Init messages
first_message = HumanMessage(content=f'Your main task is: {task}')
self.messages: list[BaseMessage] = [system_prompt, first_message]
self.n = 0
self.save_conversation_path = save_conversation_path
if save_conversation_path is not None:
# Controller setup
self.controller_injected = controller is not None
self.controller = controller or Controller()
# Action and output models setup
self._setup_action_models()
# Message history setup
self.messages = self._initialize_messages()
# Tracking variables
self.history: list[AgentHistory] = []
self.n_steps = 1
self.consecutive_failures = 0
self.max_failures = max_failures
self.retry_delay = retry_delay
if save_conversation_path:
logger.info(f'Saving conversation to {save_conversation_path}')
self.action_history: list[AgentHistory] = []
self.usage_metadata = TokenUsage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
input_token_details=TokenDetails(),
output_token_details=TokenDetails(),
)
async def run(self, max_steps: int = 100):
"""
Execute the task.
def _setup_action_models(self) -> None:
"""Setup dynamic action models from controller's registry"""
# Get the dynamic action model from controller's registry
self.ActionModel = self.controller.registry.create_action_model()
# Create output model with the dynamic actions
self.AgentOutput = AgentOutput.type_with_custom_actions(self.ActionModel)
@dev ctrl+c to interrupt
"""
def _initialize_messages(self) -> list[BaseMessage]:
"""Initialize message history with system and first message"""
# Get action descriptions from controller's registry
action_descriptions = self.controller.registry.get_prompt_description()
system_prompt = AgentSystemPrompt(
self.task, action_description=action_descriptions, current_date=datetime.now()
).get_system_message()
first_message = HumanMessage(content=f'Your task is: {self.task}')
return [system_prompt, first_message]
@time_execution_async('--step')
async def step(self) -> None:
"""Execute one step of the task"""
logger.info(f'\n📍 Step {self.n_steps}')
state = self.controller.browser.get_state(use_vision=self.use_vision)
try:
logger.info('\n' + '=' * 50)
model_output = await self.get_next_action(state)
result = self.controller.act(model_output.action)
if result.extracted_content:
logger.info(f'📄 Result: {result.extracted_content}')
self.consecutive_failures = 0
except Exception as e:
result = self._handle_step_error(e, state)
model_output = None
self._update_messages_with_result(result)
self._make_history_item(model_output, state, result)
def _handle_step_error(self, error: Exception, state: BrowserState) -> ActionResult:
"""Handle all types of errors that can occur during a step"""
error_msg = AgentError.format_error(error)
prefix = f'❌ Result failed {self.consecutive_failures + 1}/{self.max_failures} times:\n '
if isinstance(error, (ValidationError, ValueError)):
logger.error(f'{prefix}{error_msg}')
self.consecutive_failures += 1
elif isinstance(error, RateLimitError):
logger.warning(f'{prefix}{error_msg}')
time.sleep(self.retry_delay)
self.consecutive_failures += 1
else:
logger.error(f'{prefix}{error_msg}')
self.consecutive_failures += 1
return ActionResult(error=error_msg)
def _update_messages_with_result(self, result: ActionResult) -> None:
"""Update message history with action results"""
if result.extracted_content:
self.messages.append(HumanMessage(content=result.extracted_content))
if result.error:
self.messages.append(HumanMessage(content=result.error))
def _make_history_item(
self,
model_output: AgentOutput | None,
state: BrowserState,
result: ActionResult,
) -> None:
"""Create and store history item"""
history_item = AgentHistory(model_output=model_output, result=result, state=state)
self.history.append(history_item)
@time_execution_async('--get_next_action')
async def get_next_action(self, state: BrowserState) -> AgentOutput:
"""Get next action from LLM based on current state"""
new_message = AgentMessagePrompt(state).get_user_message()
input_messages = self.messages + [new_message]
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
parsed: AgentOutput = response['parsed']
self._update_message_history(state, parsed)
self._log_response(parsed)
self._save_conversation(input_messages, parsed)
self._update_usage_metadata(response['raw'])
return parsed
def _calc_token_cost(self) -> float:
"""
Calculate the cost of tokens used in a request based on the model.
:param usage_metadata: TokenUsage model containing token usage details.
:param model_name: The name of the model used.
:return: Cost of the tokens used.
"""
if isinstance(self.llm, ChatOpenAI):
model_name = self.llm.model_name
elif isinstance(self.llm, ChatAnthropic):
model_name = self.llm.model
else:
logger.debug('Model name not supported for pricing calculation')
return 0
pricing_catalog = ModelPricingCatalog()
if model_name == 'gpt-4o':
model_pricing = pricing_catalog.gpt_4o
elif model_name == 'gpt-4o-mini':
model_pricing = pricing_catalog.gpt_4o_mini
elif model_name == 'claude-3-5-sonnet-20240620':
model_pricing = pricing_catalog.claude_3_5_sonnet
else:
logger.debug(f'Unsupported model: {model_name}')
return 0
uncached_input_tokens = (
self.usage_metadata.input_tokens - self.usage_metadata.input_token_details.cache_read
)
factor = 1e6
cost = (
(uncached_input_tokens / factor) * model_pricing.uncached_input
+ (self.usage_metadata.input_token_details.cache_read / factor)
* model_pricing.cached_input
+ (self.usage_metadata.output_tokens / factor) * model_pricing.output
)
return cost
def _update_usage_metadata(self, raw_response: AIMessage) -> None:
"""
Process the response and update usage.
:param raw_response: The response object containing usage metadata.
"""
# only supported for openai models for now
if isinstance(self.llm, ChatAnthropic):
token_usage_data: dict[str, Any] = raw_response.response_metadata['usage']
usage_metadata = TokenUsage(
input_tokens=token_usage_data.get('input_tokens', 0),
output_tokens=token_usage_data.get('output_tokens', 0),
total_tokens=token_usage_data.get('input_tokens', 0)
+ token_usage_data.get('output_tokens', 0),
)
self.usage_metadata.input_tokens += usage_metadata.input_tokens
self.usage_metadata.output_tokens += usage_metadata.output_tokens
self.usage_metadata.total_tokens += usage_metadata.total_tokens
elif isinstance(self.llm, ChatOpenAI):
token_usage_data: dict[str, Any] = raw_response.response_metadata['token_usage']
usage_metadata = TokenUsage(
input_tokens=token_usage_data.get('prompt_tokens', 0),
output_tokens=token_usage_data.get('completion_tokens', 0),
total_tokens=token_usage_data.get('total_tokens', 0),
input_token_details=TokenDetails(
audio=token_usage_data.get('prompt_tokens_details', {}).get('audio_tokens', 0),
cache_read=token_usage_data.get('prompt_tokens_details', {}).get(
'cached_tokens', 0
),
reasoning=0, # Assuming reasoning is not part of prompt_tokens_details
),
output_token_details=TokenDetails(
audio=token_usage_data.get('completion_tokens_details', {}).get(
'audio_tokens', 0
),
cache_read=0, # Assuming cache_read is not part of completion_tokens_details
reasoning=token_usage_data.get('completion_tokens_details', {}).get(
'reasoning_tokens', 0
),
),
)
self.usage_metadata.input_tokens += usage_metadata.input_tokens
self.usage_metadata.output_tokens += usage_metadata.output_tokens
self.usage_metadata.total_tokens += usage_metadata.total_tokens
# update usage metadata
if usage_metadata.input_token_details:
for detail_key in usage_metadata.input_token_details.model_dump():
setattr(
self.usage_metadata.input_token_details,
detail_key,
getattr(self.usage_metadata.input_token_details, detail_key)
+ getattr(usage_metadata.input_token_details, detail_key),
)
if usage_metadata.output_token_details:
for detail_key in usage_metadata.output_token_details.model_dump():
setattr(
self.usage_metadata.output_token_details,
detail_key,
getattr(self.usage_metadata.output_token_details, detail_key)
+ getattr(usage_metadata.output_token_details, detail_key),
)
else:
logger.debug('Model name not supported for pricing calculation')
return
self._log_usage_metadata(usage_metadata)
def _log_usage_metadata(self, current_tokens: Optional[TokenUsage] = None) -> None:
"""Log the usage metadata"""
total_cost = self._calc_token_cost()
total_tokens = self.usage_metadata.total_tokens
logger.debug(
f'🔢 Total Tokens: input: {self.usage_metadata.input_tokens} (cached: {self.usage_metadata.input_token_details.cache_read}) + output: {self.usage_metadata.output_tokens} = {total_tokens} = ${total_cost:.4f} 💰'
)
if current_tokens:
logger.debug(
f'🔢 Last Tokens: input: {current_tokens.input_tokens} (cached: {current_tokens.input_token_details.cache_read}) + output: {current_tokens.output_tokens} = {current_tokens.total_tokens} '
)
def _update_message_history(self, state: BrowserState, response: Any) -> None:
"""Update message history with new interactions"""
history_message = AgentMessagePrompt(state).get_message_for_history()
self.messages.append(history_message)
self.messages.append(AIMessage(content=response.model_dump_json(exclude_unset=True)))
self.n_steps += 1
def _log_response(self, response: Any) -> None:
"""Log the model's response"""
if 'Success' in response.current_state.valuation_previous_goal:
emoji = '👍'
elif 'Failed' in response.current_state.valuation_previous_goal:
emoji = '⚠️'
else:
emoji = '🤷'
logger.info(f'{emoji} Evaluation: {response.current_state.valuation_previous_goal}')
logger.info(f'🧠 Memory: {response.current_state.memory}')
logger.info(f'🎯 Next Goal: {response.current_state.next_goal}')
logger.info(f'🛠️ Action: {response.action.model_dump_json(exclude_unset=True)}')
def _save_conversation(self, input_messages: list[BaseMessage], response: Any) -> None:
"""Save conversation history to file if path is specified"""
if not self.save_conversation_path:
return
# create folders if not exists
os.makedirs(os.path.dirname(self.save_conversation_path), exist_ok=True)
with open(self.save_conversation_path + f'_{self.n_steps}.txt', 'w') as f:
self._write_messages_to_file(f, input_messages)
self._write_response_to_file(f, response)
def _write_messages_to_file(self, f: Any, messages: list[BaseMessage]) -> None:
"""Write messages to conversation file"""
for message in messages:
f.write(f' {message.__class__.__name__} \n')
if isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and item.get('type') == 'text':
f.write(item['text'].strip() + '\n')
elif isinstance(message.content, str):
try:
content = json.loads(message.content)
f.write(json.dumps(content, indent=2) + '\n')
except json.JSONDecodeError:
f.write(message.content.strip() + '\n')
f.write('\n')
def _write_response_to_file(self, f: Any, response: Any) -> None:
"""Write model response to conversation file"""
f.write(' RESPONSE\n')
f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))
async def run(self, max_steps: int = 100) -> list[AgentHistory]:
"""Execute the task with maximum number of steps"""
try:
logger.info(f'🚀 Starting task: {self.task}')
logger.info('=' * 50)
for i in range(max_steps):
logger.info(f'\n📍 Step {i+1}')
for step in range(max_steps):
if self._too_many_failures():
break
action, result = await self.step()
await self.step()
if result.done:
logger.info('\n✅ Task completed successfully')
logger.info(f'Extracted content: \n{result.extracted_content}')
return action.done, self.action_history
if self._is_task_complete():
logger.info('✅ Task completed successfully')
break
else:
logger.info('❌ Failed to complete task in maximum steps')
return self.history
logger.info('\n' + '=' * 50)
logger.info('❌ Failed to complete task in maximum steps')
logger.info('=' * 50)
return None, self.action_history
finally:
if not self.controller_injected:
self.controller.browser.close()
@time_execution_async('--step')
async def step(self) -> tuple[AgentHistory, ControllerActionResult]:
state = self.controller.get_current_state(screenshot=self.use_vision)
action = await self.get_next_action(state)
def _too_many_failures(self) -> bool:
"""Check if we should stop due to too many failures"""
if self.consecutive_failures >= self.max_failures:
logger.error(f'❌ Stopping due to {self.max_failures} consecutive failures')
return True
return False
if action.ask_human and action.ask_human.question:
action = await self._take_human_input(action.ask_human.question)
result = self.controller.act(action)
self.n += 1
if result.error:
self.messages.append(HumanMessage(content=f'Error: {result.error}'))
if result.extracted_content:
self.messages.append(
HumanMessage(content=f'Extracted content:\n {result.extracted_content}')
)
# Convert action to history and update click/input fields if present
history_item = self._make_history_item(action, state)
self.action_history.append(history_item)
return history_item, result
def _make_history_item(self, action: AgentOutput, state: ControllerPageState) -> AgentHistory:
return AgentHistory(
search_google=action.search_google,
go_to_url=action.go_to_url,
nothing=action.nothing,
go_back=action.go_back,
done=action.done,
click_element=ClickElementControllerHistoryItem(
id=action.click_element.id, xpath=state.selector_map.get(action.click_element.id)
)
if action.click_element and state.selector_map.get(action.click_element.id)
else None,
input_text=InputTextControllerHistoryItem(
id=action.input_text.id,
xpath=state.selector_map.get(action.input_text.id),
text=action.input_text.text,
)
if action.input_text and state.selector_map.get(action.input_text.id)
else None,
extract_page_content=action.extract_page_content,
switch_tab=action.switch_tab,
open_tab=action.open_tab,
ask_human=action.ask_human,
url=state.url,
)
async def _take_human_input(self, question: str) -> AgentOutput:
human_input = input(f'\nHi, your input is required: {question}\n\n')
logger.info('-' * 50)
self.messages.append(HumanMessage(content=human_input))
structured_llm = self.llm.with_structured_output(AgentOutput)
action: AgentOutput = await structured_llm.ainvoke(self.messages) # type: ignore
self.messages.append(AIMessage(content=action.model_dump_json()))
return action
@time_execution_async('--get_next_action')
async def get_next_action(self, state: ControllerPageState) -> AgentOutput:
# TODO: include state, actions, etc.
new_message = AgentMessagePrompt(state).get_user_message()
logger.info(f'current tabs: {state.tabs}')
input_messages = self.messages + [new_message]
structured_llm = self.llm.with_structured_output(Output, include_raw=False)
#
response: Output = await structured_llm.ainvoke(input_messages) # type: ignore
# Only append the output message
history_new_message = AgentMessagePrompt(state).get_message_for_history()
self.messages.append(history_new_message)
self.messages.append(AIMessage(content=response.model_dump_json()))
logger.info(f'current state\n: {response.current_state.model_dump_json(indent=4)}')
logger.info(f'action\n: {response.action.model_dump_json(indent=4)}')
self._save_conversation(input_messages, response)
return response.action
def _get_action_description(self) -> str:
return AgentOutput.description()
def _save_conversation(self, input_messages: list[BaseMessage], response: Output):
if self.save_conversation_path is not None:
with open(self.save_conversation_path + f'_{self.n}.txt', 'w') as f:
# Write messages with proper formatting
for message in input_messages:
f.write('=' * 33 + f' {message.__class__.__name__} ' + '=' * 33 + '\n\n')
# Handle different content types
if isinstance(message.content, list):
# Handle vision model messages
for item in message.content:
if isinstance(item, dict) and item.get('type') == 'text':
f.write(item['text'].strip() + '\n')
elif isinstance(message.content, str):
try:
# Try to parse and format JSON content
content = json.loads(message.content)
f.write(json.dumps(content, indent=2) + '\n')
except json.JSONDecodeError:
# If not JSON, write as regular text
f.write(message.content.strip() + '\n')
f.write('\n')
# Write final response as formatted JSON
f.write('=' * 33 + ' Response ' + '=' * 33 + '\n\n')
f.write(json.dumps(json.loads(response.model_dump_json()), indent=2))
def _is_task_complete(self) -> bool:
"""Check if the task has been completed successfully"""
return bool(self.history and self.history[-1].result.is_done)

browser_use/agent/views.py

@@ -1,55 +1,107 @@
from typing import Optional
from __future__ import annotations
from pydantic import BaseModel
from typing import Optional, Type
from browser_use.controller.views import (
ClickElementControllerAction,
ControllerActions,
InputTextControllerAction,
)
from openai import RateLimitError
from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model
from browser_use.browser.views import BrowserState
from browser_use.controller.registry.views import ActionModel
class AskHumanAgentAction(BaseModel):
question: str
class TokenDetails(BaseModel):
audio: int = 0
cache_read: int = 0
reasoning: int = 0
class AgentState(BaseModel):
class TokenUsage(BaseModel):
input_tokens: int
output_tokens: int
total_tokens: int
input_token_details: TokenDetails = Field(default=TokenDetails())
output_token_details: TokenDetails = Field(default=TokenDetails())
# allow arbitrary types
model_config = ConfigDict(arbitrary_types_allowed=True)
class Pricing(BaseModel):
uncached_input: float # per 1M tokens
cached_input: float
output: float
class ModelPricingCatalog(BaseModel):
gpt_4o: Pricing = Field(default=Pricing(uncached_input=2.50, cached_input=1.25, output=10.00))
gpt_4o_mini: Pricing = Field(
default=Pricing(uncached_input=0.15, cached_input=0.075, output=0.60)
)
claude_3_5_sonnet: Pricing = Field(
default=Pricing(uncached_input=3.00, cached_input=1.50, output=15.00)
)
class ActionResult(BaseModel):
"""Result of executing an action"""
is_done: Optional[bool] = False
extracted_content: Optional[str] = None
error: Optional[str] = None
class AgentBrain(BaseModel):
"""Current state of the agent"""
valuation_previous_goal: str
memory: str
next_goal: str
class AgentOnlyAction(BaseModel):
ask_human: Optional[AskHumanAgentAction] = None
class AgentOutput(BaseModel):
"""Output model for agent
@dev note: this model is extended with custom actions in AgentService. You can also use some fields that are not in this model as provided by the linter, as long as they are registered in the DynamicActions model.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
current_state: AgentBrain
action: ActionModel
@staticmethod
def description() -> str:
return """
- Ask human for help
Example: {"ask_human": {"question": "To clarify ..."}}"""
def type_with_custom_actions(custom_actions: Type[ActionModel]) -> Type['AgentOutput']:
"""Extend actions with custom actions"""
return create_model(
'AgentOutput',
__base__=AgentOutput,
action=(custom_actions, Field(...)), # Properly annotated field with no default
__module__=AgentOutput.__module__,
)
class AgentOutput(ControllerActions, AgentOnlyAction):
class AgentHistory(BaseModel):
"""History item for agent actions"""
model_output: AgentOutput | None
result: ActionResult
state: BrowserState
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
class AgentError:
"""Container for agent error handling"""
VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
NO_VALID_ACTION = 'No valid action found'
@staticmethod
def description() -> str:
return AgentOnlyAction.description() + ControllerActions.description()
#
class Output(BaseModel):
current_state: AgentState
action: AgentOutput
class ClickElementControllerHistoryItem(ClickElementControllerAction):
xpath: str | None
class InputTextControllerHistoryItem(InputTextControllerAction):
xpath: str | None
class AgentHistory(AgentOutput):
click_element: Optional[ClickElementControllerHistoryItem] = None
input_text: Optional[InputTextControllerHistoryItem] = None
url: str
def format_error(error: Exception) -> str:
"""Format error message based on error type"""
if isinstance(error, ValidationError):
return f'{AgentError.VALIDATION_ERROR}\nDetails: {str(error)}'
if isinstance(error, RateLimitError):
return AgentError.RATE_LIMIT_ERROR
return f'Unexpected error: {str(error)}'

browser_use/browser/service.py

@@ -9,104 +9,114 @@ import tempfile
import time
from typing import Literal
from main_content_extractor import MainContentExtractor
from Screenshot import Screenshot
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.service import Service as ChromeService
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
from browser_use.browser.views import BrowserState
from browser_use.browser.views import BrowserState, TabInfo
from browser_use.dom.service import DomService
from browser_use.dom.views import SelectorMap
from browser_use.utils import time_execution_sync
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class BrowserService:
class Browser:
def __init__(self, headless: bool = False, keep_open: bool = False):
self.headless = headless
self.driver: webdriver.Chrome | None = None
self._ob = Screenshot.Screenshot()
self.keep_open = keep_open
self.MINIMUM_WAIT_TIME = 0.5
self.MAXIMUM_WAIT_TIME = 5
self._current_handle = None # Track current handle
self._tab_cache = {}
self.keep_open = keep_open
self._tab_cache: dict[str, TabInfo] = {}
self._current_handle = None
self._ob = Screenshot.Screenshot()
def init(self) -> webdriver.Chrome:
"""
Sets up and returns a Selenium WebDriver instance with anti-detection measures.
# Initialize driver during construction
self.driver: webdriver.Chrome | None = self._setup_webdriver()
self._cached_state = self._update_state()
Returns:
webdriver.Chrome: Configured Chrome WebDriver instance
"""
chrome_options = Options()
if self.headless:
chrome_options.add_argument('--headless')
def _setup_webdriver(self) -> webdriver.Chrome:
"""Sets up and returns a Selenium WebDriver instance with anti-detection measures."""
try:
# if webdriver is not starting, try to kill it or rm -rf ~/.wdm
chrome_options = Options()
if self.headless:
chrome_options.add_argument('--headless=new') # Updated headless argument
# Anti-detection measures
chrome_options.add_argument('--disable-blink-features=AutomationControlled')
chrome_options.add_experimental_option('excludeSwitches', ['enable-automation'])
chrome_options.add_experimental_option('useAutomationExtension', False)
# Essential automation and performance settings
chrome_options.add_argument('--disable-blink-features=AutomationControlled')
chrome_options.add_experimental_option('excludeSwitches', ['enable-automation'])
chrome_options.add_experimental_option('useAutomationExtension', False)
chrome_options.add_argument('--no-sandbox')
chrome_options.add_argument('--window-size=1280,1024')
chrome_options.add_argument('--disable-extensions')
# Additional stealth settings
# chrome_options.add_argument('--start-maximized')
chrome_options.add_argument('--window-size=1280,1024')
chrome_options.add_argument('--disable-extensions')
chrome_options.add_argument('--no-sandbox')
chrome_options.add_argument('--disable-infobars')
# Background process optimization
chrome_options.add_argument('--disable-background-timer-throttling')
chrome_options.add_argument('--disable-popup-blocking')
# Initialize the Chrome driver
driver = webdriver.Chrome(
service=Service(ChromeDriverManager().install()), options=chrome_options
)
# Additional stealth settings
chrome_options.add_argument('--disable-infobars')
# Much better when working in non-headless mode
chrome_options.add_argument('--disable-backgrounding-occluded-windows')
chrome_options.add_argument('--disable-renderer-backgrounding')
# Execute stealth scripts
driver.execute_cdp_cmd(
'Page.addScriptToEvaluateOnNewDocument',
{
'source': """
Object.defineProperty(navigator, 'webdriver', {
get: () => undefined
});
Object.defineProperty(navigator, 'languages', {
get: () => ['en-US', 'en']
});
Object.defineProperty(navigator, 'plugins', {
get: () => [1, 2, 3, 4, 5]
});
window.chrome = {
runtime: {}
};
Object.defineProperty(navigator, 'permissions', {
get: () => ({
query: Promise.resolve.bind(Promise)
})
});
"""
},
)
# Initialize the Chrome driver with better error handling
service = ChromeService(ChromeDriverManager().install())
driver = webdriver.Chrome(service=service, options=chrome_options)
self.driver = driver
# Execute stealth scripts
driver.execute_cdp_cmd(
'Page.addScriptToEvaluateOnNewDocument',
{
'source': """
Object.defineProperty(navigator, 'webdriver', {
get: () => undefined
});
Object.defineProperty(navigator, 'languages', {
get: () => ['en-US', 'en']
});
Object.defineProperty(navigator, 'plugins', {
get: () => [1, 2, 3, 4, 5]
});
window.chrome = {
runtime: {}
};
Object.defineProperty(navigator, 'permissions', {
get: () => ({
query: Promise.resolve.bind(Promise)
})
});
"""
},
)
# driver.get('https://www.google.com')
return driver
return driver
except Exception as e:
logger.error(f'Failed to initialize Chrome driver: {str(e)}')
# Clean up any existing driver
if hasattr(self, 'driver') and self.driver:
try:
self.driver.quit()
self.driver = None
except Exception:
pass
raise
def _get_driver(self) -> webdriver.Chrome:
if self.driver is None:
self.driver = self.init()
self.driver = self._setup_webdriver()
return self.driver
def wait_for_page_load(self):
@@ -131,7 +141,7 @@ class BrowserService:
elapsed = time.time() - start_time
remaining = max(self.MINIMUM_WAIT_TIME - elapsed, 0)
logger.info(
logger.debug(
f'--Page loaded in {elapsed:.2f} seconds, waiting for additional {remaining:.2f} seconds'
)
@@ -139,27 +149,36 @@ class BrowserService:
if remaining > 0:
time.sleep(remaining)
def get_updated_state(self) -> BrowserState:
def _update_state(self, use_vision: bool = False) -> BrowserState:
"""
Update and return state.
"""
driver = self._get_driver()
dom_service = DomService(driver)
content = dom_service.get_clickable_elements()
screenshot_b64 = None
if use_vision:
screenshot_b64 = self.take_screenshot(selector_map=content.selector_map)
self.current_state = BrowserState(
items=content.items,
selector_map=content.selector_map,
url=driver.current_url,
title=driver.title,
current_tab_handle=self._current_handle or driver.current_window_handle,
tabs=self.get_tabs_info(),
screenshot=screenshot_b64,
)
return self.current_state
def close(self):
if not self.keep_open:
driver = self._get_driver()
driver.quit()
self.driver = None
def close(self, force: bool = False):
if not self.keep_open or force:
if self.driver:
driver = self._get_driver()
driver.quit()
self.driver = None
else:
input('Press Enter to close Browser...')
self.keep_open = False
@@ -174,44 +193,6 @@ class BrowserService:
# region - Browser Actions
def search_google(self, query: str):
"""
@dev TODO: add serp api call here
"""
driver = self._get_driver()
driver.get(f'https://www.google.com/search?q={query}')
self.wait_for_page_load()
def go_to_url(self, url: str):
driver = self._get_driver()
driver.get(url)
self.wait_for_page_load()
def go_back(self):
driver = self._get_driver()
driver.back()
self.wait_for_page_load()
def refresh(self):
driver = self._get_driver()
driver.refresh()
self.wait_for_page_load()
def extract_page_content(self, value: Literal['text', 'markdown'] = 'markdown') -> str:
"""
TODO: switch to a better parser/extractor
"""
driver = self._get_driver()
content = MainContentExtractor.extract(driver.page_source, output_format=value) # type: ignore TODO
return content
def done(self, text: str):
"""
Ends the task and waits for further instructions.
"""
logger.info(f'Done on page {self.current_state.url}\n\n: {text}')
return text
def take_screenshot(self, selector_map: SelectorMap | None, full_page: bool = False) -> str:
"""
Returns a base64 encoded screenshot of the current page.
@@ -332,7 +313,6 @@ class BrowserService:
# Then send keys
element.send_keys(text)
logger.info(f'Input text into element with xpath: {xpath}')
self.wait_for_page_load()
@@ -341,13 +321,6 @@ class BrowserService:
f'Failed to input text into element with xpath: {xpath}. Error: {str(e)}'
)
def input_text_by_index(self, index: int, text: str, state: BrowserState):
if index not in state.selector_map:
raise Exception(f'Element index {index} not found in selector map')
xpath = state.selector_map[index]
self._input_text_by_xpath(xpath, text)
def _click_element_by_xpath(self, xpath: str):
"""
Optimized method to click an element using xpath.
@@ -362,7 +335,7 @@ class BrowserService:
EC.element_to_be_clickable((By.XPATH, xpath)),
message=f'Element not clickable: {xpath}',
)
driver.execute_script('arguments[0].click();', element)
element.click()
self.wait_for_page_load()
return
except Exception:
@@ -399,30 +372,7 @@ class BrowserService:
except Exception as e:
raise Exception(f'Failed to click element with xpath: {xpath}. Error: {str(e)}')
@time_execution_sync('--click')
def click_element_by_index(self, index: int, state: BrowserState):
"""
Clicks an element using its index from the selector map.
"""
if index not in state.selector_map:
raise Exception(f'Element index {index} not found in selector map')
# Store current number of tabs before clicking
driver = self._get_driver()
initial_handles = len(driver.window_handles)
xpath = state.selector_map[index]
self._click_element_by_xpath(xpath)
logger.info(f'Clicked on index {index}: with xpath {xpath}')
# Check if new tab was opened
current_handles = len(driver.window_handles)
if current_handles > initial_handles:
return self.handle_new_tab()
# endregion
def handle_new_tab(self) -> dict:
def handle_new_tab(self) -> None:
"""Handle newly opened tab and switch to it"""
driver = self._get_driver()
handles = driver.window_handles
@@ -430,50 +380,31 @@ class BrowserService:
# Switch to new tab
driver.switch_to.window(new_handle)
self._current_handle = new_handle # Update current handle
self._current_handle = new_handle
# Wait for page load
self.wait_for_page_load()
# Get and cache tab info
tab_info = {
'handle': new_handle,
'url': driver.current_url,
'title': driver.title,
'is_current': True,
}
# Create and cache tab info
tab_info = TabInfo(handle=new_handle, url=driver.current_url, title=driver.title)
self._tab_cache[new_handle] = tab_info
# Update is_current for all tabs
for handle in self._tab_cache:
self._tab_cache[handle]['is_current'] = handle == new_handle
return tab_info
def get_tabs_info(self) -> list[dict]:
def get_tabs_info(self) -> list[TabInfo]:
"""Get information about all tabs"""
driver = self._get_driver()
current_handle = driver.current_window_handle
self._current_handle = current_handle # Update current handle
self._current_handle = current_handle
tabs_info = []
for handle in driver.window_handles:
is_current = handle == current_handle
# Use cached info if available, otherwise get new info
if handle in self._tab_cache:
tab_info = self._tab_cache[handle]
tab_info['is_current'] = is_current
else:
# Only switch if we need to get info
if not is_current:
if handle != current_handle:
driver.switch_to.window(handle)
tab_info = {
'handle': handle,
'url': driver.current_url,
'title': driver.title,
'is_current': is_current,
}
tab_info = TabInfo(handle=handle, url=driver.current_url, title=driver.title)
self._tab_cache[handle] = tab_info
tabs_info.append(tab_info)
@@ -484,33 +415,12 @@ class BrowserService:
return tabs_info
def switch_tab(self, handle: str) -> dict:
"""Switch to specified tab and return its info"""
driver = self._get_driver()
# endregion
# Verify handle exists
if handle not in driver.window_handles:
raise ValueError(f'Tab handle {handle} not found')
# Only switch if we're not already on that tab
current_handle = driver.current_window_handle
if handle != current_handle:
driver.switch_to.window(handle)
# Wait for tab to be ready
self.wait_for_page_load()
# Update and return tab info
tab_info = {
'handle': handle,
'url': driver.current_url,
'title': driver.title,
'is_current': True,
}
self._tab_cache[handle] = tab_info
return tab_info
def open_tab(self, url: str):
driver = self._get_driver()
driver.execute_script(f'window.open("{url}", "_blank");')
self.wait_for_page_load()
return self.handle_new_tab()
@time_execution_sync('--get_state')
def get_state(self, use_vision: bool = False) -> BrowserState:
"""
Get the current state of the browser including page content and tab information.
"""
self._cached_state = self._update_state(use_vision=use_vision)
return self._cached_state

(deleted browser test: element highlighting)

@@ -1,40 +0,0 @@
import base64
import os
import time
from browser_use.browser.service import BrowserService
from browser_use.dom.service import DomService
from browser_use.utils import time_execution_sync
def test_highlight_elements():
browser = BrowserService(headless=False)
driver = browser.init()
dom_service = DomService(driver)
browser.go_to_url('https://www.kayak.ch')
# browser.go_to_url('https://google.com/flights')
time.sleep(1)
# browser._click_element_by_xpath(
# '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]'
# )
browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
elements = time_execution_sync('get_clickable_elements')(dom_service.get_clickable_elements)()
# time_execution_sync('highlight_selector_map_elements')(browser.highlight_selector_map_elements)(
# elements.selector_map
# )
image = time_execution_sync('take_screenshot')(browser.take_screenshot)(elements.selector_map)
temp_image_path = os.path.join(os.path.dirname(__file__), 'temp', 'temp.png')
with open(temp_image_path, 'wb') as f:
f.write(base64.b64decode(image))
# time_execution_sync('remove_highlights')(browser.remove_highlights)()
input('Press Enter to continue...')

(browser test: full-page screenshot)

@@ -1,18 +1,18 @@
import base64
import pytest
from browser_use.browser.service import BrowserService
from browser_use.browser.service import Browser
@pytest.fixture
def browser():
browser_service = BrowserService(headless=True)
browser_service.init()
browser_service = Browser(headless=True)
yield browser_service
browser_service.close()
@pytest.mark.skip(reason='takes too long')
# @pytest.mark.skip(reason='takes too long')
def test_take_full_page_screenshot(browser):
# Go to a test page
browser.go_to_url('https://example.com')
@@ -30,3 +30,7 @@ def test_take_full_page_screenshot(browser):
base64.b64decode(screenshot_b64)
except Exception as e:
pytest.fail(f'Failed to decode base64 screenshot: {str(e)}')
if __name__ == '__main__':
test_take_full_page_screenshot(Browser(headless=False))

View File

@@ -0,0 +1,59 @@
import time
from browser_use.browser.service import Browser
from browser_use.utils import time_execution_sync
def test_highlight_elements():
browser = Browser()
browser._get_driver().get('https://kayak.com')
# browser.go_to_url('https://google.com/flights')
# browser.go_to_url('https://immobilienscout24.de')
time.sleep(1)
# browser._click_element_by_xpath(
# '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]'
# )
# browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
while True:
state = browser.get_state()
time_execution_sync('highlight_selector_map_elements')(
browser.highlight_selector_map_elements
)(state.selector_map)
print(state.dom_items_to_string(use_tabs=False))
# print(state.selector_map)
# Find and print duplicate XPaths
xpath_counts = {}
for selector in state.selector_map.values():
if selector in xpath_counts:
xpath_counts[selector] += 1
else:
xpath_counts[selector] = 1
print('\nDuplicate XPaths found:')
for xpath, count in xpath_counts.items():
if count > 1:
print(f'XPath: {xpath}')
print(f'Count: {count}\n')
print(state.selector_map.keys(), 'Selector map keys')
action = input('Select next action: ')
time_execution_sync('remove_highlight_elements')(browser.remove_highlights)()
xpath = state.selector_map[int(action)]
browser._click_element_by_xpath(xpath)
def main():
test_highlight_elements()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,50 @@
import time
import pytest
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
def test_selenium():
try:
print('1. Setting up Chrome options...')
chrome_options = Options()
chrome_options.add_argument('--no-sandbox')
# Uncomment to test headless mode
# chrome_options.add_argument('--headless=new')
print('2. Installing/finding ChromeDriver...')
service = Service(ChromeDriverManager().install())
print('3. Creating Chrome WebDriver...')
driver = webdriver.Chrome(service=service, options=chrome_options)
print('4. Navigating to Google...')
driver.get('https://www.google.com')
print('5. Getting page title...')
title = driver.title
print(f'Page title: {title}')
time.sleep(2) # Wait to see the page if not in headless mode
print('6. Closing browser...')
driver.quit()
print('✅ Test completed successfully!')
return True
except Exception as e:
print(f'❌ Test failed with error: {str(e)}')
print(f'Error type: {type(e).__name__}')
return False
# run with: pytest browser_use/browser/tests/test_selenium.py
#
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -1,12 +1,31 @@
from typing import Optional
from pydantic import BaseModel
from browser_use.dom.views import ProcessedDomContent
# Exceptions
class BrowserException(Exception):
pass
# Pydantic
class TabInfo(BaseModel):
"""Represents information about a browser tab"""
handle: str
url: str
title: str
class BrowserState(ProcessedDomContent):
url: str
title: str
current_tab_handle: str
tabs: list[TabInfo]
screenshot: Optional[str] = None
def model_dump(self) -> dict:
dump = super().model_dump()
# Add a summary of available tabs
if self.tabs:
dump['available_tabs'] = [
f'Tab {i+1}: {tab.title} ({tab.url})' for i, tab in enumerate(self.tabs)
]
return dump
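
A small sketch of what the model_dump() override above adds; the values are made up, and items/selector_map come from ProcessedDomContent as shown in the dom views later in this diff.

# Illustrative only: the available_tabs summary added to the dumped state.
from browser_use.browser.views import BrowserState, TabInfo

state = BrowserState(
    items=[],
    selector_map={},
    url='https://example.com',
    title='Example Domain',
    current_tab_handle='CDwindow-1',
    tabs=[TabInfo(handle='CDwindow-1', url='https://example.com', title='Example Domain')],
)
dump = state.model_dump()
print(dump['available_tabs'])  # ['Tab 1: Example Domain (https://example.com)']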

View File

@@ -0,0 +1,104 @@
from inspect import signature
from typing import Any, Callable, Optional, Type
from pydantic import BaseModel, create_model
from browser_use.browser.service import Browser
from browser_use.controller.registry.views import (
ActionModel,
ActionRegistry,
RegisteredAction,
)
class Registry:
"""Service for registering and managing actions"""
def __init__(self):
self.registry = ActionRegistry()
def _create_param_model(self, function: Callable) -> Type[BaseModel]:
"""Creates a Pydantic model from function signature"""
sig = signature(function)
params = {
name: (param.annotation, ... if param.default == param.empty else param.default)
for name, param in sig.parameters.items()
if name != 'browser'
}
# TODO: make the types here work
return create_model(
f'{function.__name__}Params',
__base__=ActionModel,
**params, # type: ignore
)
def action(
self,
description: str,
param_model: Optional[Type[BaseModel]] = None,
requires_browser: bool = False,
):
"""Decorator for registering actions"""
def decorator(func: Callable):
# Create param model from function if not provided
actual_param_model = param_model or self._create_param_model(func)
action = RegisteredAction(
name=func.__name__,
description=description,
function=func,
param_model=actual_param_model,
requires_browser=requires_browser,
)
self.registry.actions[func.__name__] = action
return func
return decorator
def execute_action(
self, action_name: str, params: dict, browser: Optional[Browser] = None
) -> Any:
"""Execute a registered action"""
if action_name not in self.registry.actions:
raise ValueError(f'Action {action_name} not found')
action = self.registry.actions[action_name]
try:
# Create the validated Pydantic model
validated_params = action.param_model(**params)
# Check if the first parameter is a Pydantic model
sig = signature(action.function)
first_param = next(iter(sig.parameters.values()))
is_pydantic = (
hasattr(first_param.annotation, '__bases__')
and BaseModel in first_param.annotation.__bases__
)
# Execute with or without browser
if action.requires_browser:
if not browser:
raise ValueError(f'Action {action_name} requires browser but none provided')
if is_pydantic:
return action.function(validated_params, browser=browser)
return action.function(**validated_params.model_dump(), browser=browser)
if is_pydantic:
return action.function(validated_params)
return action.function(**validated_params.model_dump())
except Exception as e:
raise Exception(f'Error executing action {action_name}: {str(e)}')
def create_action_model(self) -> Type[ActionModel]:
"""Creates a Pydantic model from registered actions"""
fields = {
name: (Optional[action.param_model], None)
for name, action in self.registry.actions.items()
}
return create_model('ActionModel', __base__=ActionModel, **fields) # type:ignore
def get_prompt_description(self) -> str:
"""Get a description of all actions for the prompt"""
return self.registry.get_prompt_description()
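
A hedged sketch of using the Registry above on its own; the action name, parameters and file path are made up for illustration.

# Hypothetical standalone use of the Registry (names and params are illustrative).
from browser_use.controller.registry.service import Registry

registry = Registry()

@registry.action('Save a note to disk')
def save_note(text: str, filename: str = 'notes.txt'):
    with open(filename, 'a') as f:
        f.write(text + '\n')
    return f'Saved note to {filename}'

# The decorator derives a pydantic param model from the signature,
# so execute_action validates the params before calling the function.
result = registry.execute_action('save_note', {'text': 'hello world'})
print(result)  # 'Saved note to notes.txt'

# The combined ActionModel and prompt description consumed by the agent:
DynamicActionModel = registry.create_action_model()
print(registry.get_prompt_description())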

View File

@@ -0,0 +1,45 @@
from typing import Callable, Dict, Type
from pydantic import BaseModel, ConfigDict
class RegisteredAction(BaseModel):
"""Model for a registered action"""
name: str
description: str
function: Callable
param_model: Type[BaseModel]
requires_browser: bool = False
model_config = ConfigDict(arbitrary_types_allowed=True)
def prompt_description(self) -> str:
"""Get a description of the action for the prompt"""
skip_keys = ['title']
s = f'{self.description}: \n'
s += '{' + str(self.name) + ': '
s += str(
{
k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k not in skip_keys}
for k, v in self.param_model.schema()['properties'].items()
}
)
s += '}'
return s
class ActionModel(BaseModel):
"""Base model for dynamically created action models"""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ActionRegistry(BaseModel):
"""Model representing the action registry"""
actions: Dict[str, RegisteredAction] = {}
def get_prompt_description(self) -> str:
"""Get a description of all actions for the prompt"""
return '\n'.join([action.prompt_description() for action in self.actions.values()])
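
For reference, a hedged illustration of the prompt string these models produce; the param model below is defined locally just for the example and mirrors the SearchGoogleAction used elsewhere in this diff.

# Illustrative: what prompt_description() roughly renders for one registered action.
from pydantic import BaseModel

from browser_use.controller.registry.views import RegisteredAction

class SearchGoogleParams(BaseModel):
    query: str

def search_google(params: SearchGoogleParams):
    ...

action = RegisteredAction(
    name='search_google',
    description='Search Google',
    function=search_google,
    param_model=SearchGoogleParams,
    requires_browser=True,
)
print(action.prompt_description())
# Search Google:
# {search_google: {'query': {'type': 'string'}}}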

View File

@@ -1,96 +1,170 @@
import logging
from browser_use.browser.service import BrowserService
from browser_use.browser.views import BrowserState
from main_content_extractor import MainContentExtractor
from browser_use.agent.views import ActionModel, ActionResult
from browser_use.browser.service import Browser
from browser_use.browser.views import TabInfo
from browser_use.controller.registry.service import Registry
from browser_use.controller.views import (
ControllerActionResult,
ControllerActions,
ControllerPageState,
ClickElementAction,
DoneAction,
ExtractPageContentAction,
GoToUrlAction,
InputTextAction,
OpenTabAction,
SearchGoogleAction,
SwitchTabAction,
)
from browser_use.utils import time_execution_sync
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class ControllerService:
"""
Controller service that interacts with the browser.
Right now this is just an LLM-friendly wrapper around the browser service.
In the future this could become a self-contained agent that plans and executes single steps on its own.
TODO: low-hanging fruit: pass in a list of actions, compare the HTML that changed, and self-assess whether the goal is done -> makes clicking MUCH faster and cheaper.
TODO #2: from the state, generate functions that can be passed directly to the LLM as function calls. Then a single call could request, for example, multiple actions and the new state.
"""
class Controller:
def __init__(self, keep_open: bool = False):
self.browser = BrowserService(keep_open=keep_open)
self.cached_browser_state: BrowserState | None = None
self.browser = Browser(keep_open=keep_open)
self.registry = Registry()
self._register_default_actions()
@time_execution_sync('--get_cached_browser_state')
def get_cached_browser_state(self, force_update: bool = False) -> BrowserState:
if self.cached_browser_state is None or force_update:
self.cached_browser_state = self.browser.get_updated_state()
return self.cached_browser_state
def _register_default_actions(self):
"""Register all default browser actions"""
return self.cached_browser_state
# return self.browser.get_updated_state()
def get_current_state(self, screenshot: bool = False) -> ControllerPageState:
browser_state = self.get_cached_browser_state(force_update=True)
# Get tab information without switching
tabs = self.browser.get_tabs_info()
screenshot_b64 = None
if screenshot:
screenshot_b64 = self.browser.take_screenshot(selector_map=browser_state.selector_map)
return ControllerPageState(
items=browser_state.items,
url=browser_state.url,
title=browser_state.title,
selector_map=browser_state.selector_map,
screenshot=screenshot_b64,
tabs=tabs,
# Basic Navigation Actions
@self.registry.action(
'Search Google', param_model=SearchGoogleAction, requires_browser=True
)
def search_google(params: SearchGoogleAction, browser: Browser):
driver = browser._get_driver()
driver.get(f'https://www.google.com/search?q={params.query}')
browser.wait_for_page_load()
@self.registry.action('Navigate to URL', param_model=GoToUrlAction, requires_browser=True)
def go_to_url(params: GoToUrlAction, browser: Browser):
driver = browser._get_driver()
driver.get(params.url)
browser.wait_for_page_load()
@self.registry.action('Go back', requires_browser=True)
def go_back(browser: Browser):
driver = browser._get_driver()
driver.back()
browser.wait_for_page_load()
# Element Interaction Actions
@self.registry.action(
'Click element', param_model=ClickElementAction, requires_browser=True
)
def click_element(params: ClickElementAction, browser: Browser):
state = browser._cached_state
if params.index not in state.selector_map:
print(state.selector_map)
raise Exception(
f'Element with index {params.index} does not exist - retry or use alternative actions'
)
xpath = state.selector_map[params.index]
driver = browser._get_driver()
initial_handles = len(driver.window_handles)
msg = None
for _ in range(params.num_clicks):
try:
browser._click_element_by_xpath(xpath)
msg = f'🖱️ Clicked element {params.index}: {xpath}'
if params.num_clicks > 1:
msg += f' ({_ + 1}/{params.num_clicks} clicks)'
except Exception as e:
logger.warning(f'Element no longer available after {_ + 1} clicks: {str(e)}')
break
if len(driver.window_handles) > initial_handles:
browser.handle_new_tab()
return ActionResult(extracted_content=msg)
@self.registry.action('Input text', param_model=InputTextAction, requires_browser=True)
def input_text(params: InputTextAction, browser: Browser):
state = browser._cached_state
if params.index not in state.selector_map:
raise Exception(
f'Element index {params.index} does not exist - retry or use alternative actions'
)
xpath = state.selector_map[params.index]
browser._input_text_by_xpath(xpath, params.text)
msg = f'⌨️ Input text "{params.text}" into element {params.index}: {xpath}'
return ActionResult(extracted_content=msg)
# Tab Management Actions
@self.registry.action('Switch tab', param_model=SwitchTabAction, requires_browser=True)
def switch_tab(params: SwitchTabAction, browser: Browser):
driver = browser._get_driver()
# Verify handle exists
if params.handle not in driver.window_handles:
raise ValueError(f'Tab handle {params.handle} not found')
# Only switch if we're not already on that tab
if params.handle != driver.current_window_handle:
driver.switch_to.window(params.handle)
browser._current_handle = params.handle
# Wait for tab to be ready
browser.wait_for_page_load()
# Update and return tab info
tab_info = TabInfo(handle=params.handle, url=driver.current_url, title=driver.title)
browser._tab_cache[params.handle] = tab_info
@self.registry.action('Open new tab', param_model=OpenTabAction, requires_browser=True)
def open_tab(params: OpenTabAction, browser: Browser):
driver = browser._get_driver()
driver.execute_script(f'window.open("{params.url}", "_blank");')
browser.wait_for_page_load()
browser.handle_new_tab()
# Content Actions
@self.registry.action(
'Extract page content', param_model=ExtractPageContentAction, requires_browser=True
)
def extract_content(params: ExtractPageContentAction, browser: Browser):
driver = browser._get_driver()
content = MainContentExtractor.extract( # type: ignore
html=driver.page_source,
output_format=params.value,
)
return ActionResult(extracted_content=content)
@self.registry.action('Complete task', param_model=DoneAction, requires_browser=True)
def done(params: DoneAction, browser: Browser):
logger.info(f'✅ Done on page {browser._cached_state.url}\n\n: {params.text}')
return ActionResult(is_done=True, extracted_content=params.text)
def action(self, description: str, **kwargs):
"""Decorator for registering custom actions
@param description: Describes to the LLM what the function does (a better description means better function calling)
"""
return self.registry.action(description, **kwargs)
@time_execution_sync('--act')
def act(self, action: ControllerActions) -> ControllerActionResult:
def act(self, action: ActionModel) -> ActionResult:
"""Execute an action"""
try:
current_state = self.get_cached_browser_state(force_update=False)
if action.search_google:
self.browser.search_google(action.search_google.query)
elif action.switch_tab:
self.browser.switch_tab(action.switch_tab.handle)
elif action.open_tab:
self.browser.open_tab(action.open_tab.url)
elif action.go_to_url:
self.browser.go_to_url(action.go_to_url.url)
elif action.nothing:
# self.browser.nothing()
# TODO: implement
pass
elif action.go_back:
self.browser.go_back()
elif action.done:
self.browser.done(action.done.text)
return ControllerActionResult(done=True, extracted_content=action.done.text)
elif action.click_element:
self.browser.click_element_by_index(action.click_element.id, current_state)
elif action.input_text:
self.browser.input_text_by_index(
action.input_text.id, action.input_text.text, current_state
)
elif action.extract_page_content:
content = self.browser.extract_page_content()
return ControllerActionResult(done=False, extracted_content=content)
else:
raise ValueError(f'Unknown action: {action}')
return ControllerActionResult(done=False)
for action_name, params in action.model_dump(exclude_unset=True).items():
if params is not None:
result = self.registry.execute_action(action_name, params, browser=self.browser)
if isinstance(result, str):
return ActionResult(extracted_content=result)
elif isinstance(result, ActionResult):
return result
elif result is None:
return ActionResult()
else:
raise ValueError(f'Invalid action result type: {type(result)} of {result}')
return ActionResult()
except Exception as e:
return ControllerActionResult(done=False, error=f'Error executing action: {str(e)}')
raise e
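
A hedged sketch of how the refactored act() dispatches a dynamically created ActionModel; the payload below is built by hand purely for illustration (the agent normally produces it).

# Illustrative dispatch through the registry-based Controller.
from browser_use.controller.service import Controller

controller = Controller(keep_open=False)

# Build the dynamic ActionModel the agent would otherwise emit.
DynamicActionModel = controller.registry.create_action_model()
action = DynamicActionModel(go_to_url={'url': 'https://example.com'})

# act() iterates model_dump(exclude_unset=True), so only the single set action
# is executed, and the registry validates its params before calling the function.
result = controller.act(action)
print(result.extracted_content, result.is_done)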

View File

@@ -1,15 +1,17 @@
import pytest
from browser_use.controller.service import ControllerService
from browser_use.agent.views import ActionModel
from browser_use.controller.service import Controller
def test_get_current_state():
# Initialize controller
controller = ControllerService()
controller = Controller()
# Go to a test URL
controller.browser.go_to_url('https://www.example.com')
# controller.act(ActionModel(name='go_to_url', url='https://www.example.com'))
# Get current state without screenshot
state = controller.get_current_state(screenshot=True)
state = controller.browser.get_state(use_vision=True)
input('Press Enter to continue...')

View File

@@ -1,99 +1,38 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
from browser_use.browser.views import BrowserState
class SearchGoogleControllerAction(BaseModel):
# Action Input Models
class SearchGoogleAction(BaseModel):
query: str
class GoToUrlControllerAction(BaseModel):
class GoToUrlAction(BaseModel):
url: str
class ClickElementControllerAction(BaseModel):
id: int
class ClickElementAction(BaseModel):
index: int
num_clicks: int = 1
class InputTextControllerAction(BaseModel):
id: int
class InputTextAction(BaseModel):
index: int
text: str
class DoneControllerAction(BaseModel):
class DoneAction(BaseModel):
text: str
class SwitchTabControllerAction(BaseModel):
handle: str # The window handle to switch to
class SwitchTabAction(BaseModel):
handle: str
class OpenTabControllerAction(BaseModel):
class OpenTabAction(BaseModel):
url: str
class ControllerActions(BaseModel):
"""
Controller actions you can use to interact.
"""
search_google: Optional[SearchGoogleControllerAction] = None
go_to_url: Optional[GoToUrlControllerAction] = None
nothing: Optional[Literal[True]] = None
go_back: Optional[Literal[True]] = None
done: Optional[DoneControllerAction] = None
click_element: Optional[ClickElementControllerAction] = None
input_text: Optional[InputTextControllerAction] = None
extract_page_content: Optional[Literal[True]] = None
switch_tab: Optional[SwitchTabControllerAction] = None
open_tab: Optional[OpenTabControllerAction] = None
@staticmethod
def description() -> str:
"""
Returns a human-readable description of available actions.
"""
return """
- Search Google with a query
Example: {"search_google": {"query": "weather today"}}
- Navigate directly to a URL where you want to go
Example: {"go_to_url": {"url": "https://abc.com"}}
- Do nothing/wait
Example: {"nothing": true}
- Go back to previous page
Example: {"go_back": true}
- Mark entire task as complete
Example: {"done": {"text": "This is the requested result of the task which is send to the human..."}}
- Click an interactive element by its given ID
Example: {"click_element": {"id": 1}}
- Input text into an interactive element by its ID
Example: {"input_text": {"id": 1, "text": "Hello world"}}
- Get the page content in markdown
Example: {"extract_page_content": true}
- Switch to a different browser tab
Example: {"switch_tab": {"handle": "CDwindow-1234..."}}
- Open a new tab
Example: {"open_tab": {"url": "https://abc.com"}}
"""
class ControllerActionResult(BaseModel):
done: bool
extracted_content: Optional[str] = None
error: Optional[str] = None
class ControllerPageState(BrowserState):
screenshot: Optional[str] = None
tabs: list[dict] = [] # Add tabs info to state
def model_dump(self) -> dict:
dump = super().model_dump()
# Add a summary of available tabs
if self.tabs:
dump['available_tabs'] = [
f"Tab {i+1}: {tab['title']} ({tab['url']})" for i, tab in enumerate(self.tabs)
]
return dump
class ExtractPageContentAction(BaseModel):
value: Literal['text', 'markdown', 'html'] = 'text'

View File

@@ -1,14 +1,23 @@
import json
import logging
from typing import Optional
from bs4 import BeautifulSoup, NavigableString, PageElement, Tag
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webelement import WebElement
from browser_use.dom.views import DomContentItem, ProcessedDomContent
from browser_use.dom.views import (
BatchCheckResults,
DomContentItem,
ElementCheckResult,
ElementState,
ProcessedDomContent,
TextCheckResult,
TextState,
)
from browser_use.utils import time_execution_sync
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class DomService:
@@ -24,111 +33,239 @@ class DomService:
@time_execution_sync('--_process_content')
def _process_content(self, html_content: str) -> ProcessedDomContent:
"""
Process HTML content to extract and clean relevant elements.
Args:
html_content: Raw HTML string to process
Returns:
ProcessedDomContent: Processed DOM content
@dev TODO: instead of using an enumerated index, use random 4-digit numbers -> a bit more tokens BUT updates on the screen won't click on incorrect items -> tricky because you have to consider that the same elements need to keep the same index ...
"""
# Parse HTML content using BeautifulSoup with html.parser
soup = BeautifulSoup(html_content, 'html.parser')
candidate_elements: list[Tag | NavigableString] = []
output_items: list[DomContentItem] = []
selector_map: dict[int, str] = {}
current_index = 0
def _process_element(element: PageElement):
should_add_element = False
# Collectors for batch processing with order tracking
interactive_elements: dict[str, tuple[Tag, int]] = {} # xpath -> (element, order)
text_nodes: dict[str, tuple[NavigableString, int]] = {} # xpath -> (text_node, order)
xpath_order_counter = 0 # Track order of appearance
# if not self._quick_element_filter(element):
# if isinstance(element, Tag):
# element.decompose()
# return
dom_queue: list[tuple[PageElement, list, Optional[str]]] = (
[(element, [], None) for element in reversed(list(soup.body.children))]
if soup.body
else []
)
# First pass: collect all elements that need checking
while dom_queue:
element, path_indices, parent_xpath = dom_queue.pop()
if isinstance(element, Tag):
# Don't add any children of non-interactive elements
if not self._is_element_accepted(element):
element.decompose()
return
if self._is_interactive_element(element) or self._is_leaf_element(element):
if (
self._is_active(element)
and self._is_top_element(element)
and self._is_visible(element)
):
should_add_element = True
elif isinstance(element, NavigableString) and element.strip():
if self._is_visible(element):
should_add_element = True
if should_add_element:
if isinstance(element, (Tag, NavigableString)):
candidate_elements.append(element)
if isinstance(element, Tag):
for child in element.children:
_process_element(child)
for element in soup.body.children if soup.body else []:
_process_element(element)
# Process candidates
selector_map: dict[int, str] = {}
output_items: list[DomContentItem] = []
for index, element in enumerate(candidate_elements):
xpath = self._generate_xpath(element)
depth = max(len(xpath.split('/')) - 2, 0)
# Skip text nodes that are direct children of already processed elements
if isinstance(element, NavigableString) and element.parent in [
e for e in candidate_elements
]:
continue
if isinstance(element, NavigableString):
text_content = self._cap_text_length(element.strip())
if text_content:
output_string = f'{text_content}'
output_items.append(
DomContentItem(
index=index, text=output_string, clickable=False, depth=depth
)
)
continue
else:
# don't add empty text nodes
continue
else:
text_content = self._extract_text_from_all_children(element)
tag_name = element.name
attributes = self._get_essential_attributes(element)
opening_tag = f"<{tag_name}{' ' + attributes if attributes else ''}>"
closing_tag = f'</{tag_name}>'
output_string = f'{opening_tag}{text_content}{closing_tag}'
output_items.append(
DomContentItem(index=index, text=output_string, clickable=True, depth=depth)
siblings = (
list(element.parent.find_all(element.name, recursive=False))
if element.parent
else []
)
sibling_index = siblings.index(element) + 1 if siblings else 1
current_path = path_indices + [(element.name, sibling_index)]
element_xpath = '//' + '/'.join(f'{tag}[{idx}]' for tag, idx in current_path)
selector_map[index] = xpath
# Add children to queue with their path information
for child in reversed(list(element.children)):
dom_queue.append((child, current_path, element_xpath)) # Pass parent's xpath
# Remove all elements from selector map that are not in output items
selector_map = {
k: v for k, v in selector_map.items() if k in [i.index for i in output_items]
}
# Collect interactive elements with their order
if (
self._is_interactive_element(element) or self._is_leaf_element(element)
) and self._is_active(element):
interactive_elements[element_xpath] = (element, xpath_order_counter)
xpath_order_counter += 1
elif isinstance(element, NavigableString) and element.strip():
if element.parent and element.parent not in [e[0] for e in dom_queue]:
if parent_xpath:
text_nodes[parent_xpath] = (element, xpath_order_counter)
xpath_order_counter += 1
# Batch check all elements
element_results = self._batch_check_elements(interactive_elements)
text_results = self._batch_check_texts(text_nodes)
# Create ordered results
ordered_results: list[
tuple[int, str, bool, str, int, bool]
] = [] # [(order, xpath, is_clickable, content, depth, is_text_only), ...]
# Process interactive elements
for xpath, (element, order) in interactive_elements.items():
if xpath in element_results.elements:
result = element_results.elements[xpath]
if result.isVisible and result.isTopElement:
text_content = self._extract_text_from_all_children(element)
tag_name = element.name
attributes = self._get_essential_attributes(element)
output_string = f"<{tag_name}{' ' + attributes if attributes else ''}>{text_content}</{tag_name}>"
depth = len(xpath.split('/')) - 2
ordered_results.append((order, xpath, True, output_string, depth, False))
# Process text nodes
for xpath, (text_node, order) in text_nodes.items():
if xpath in text_results.texts:
result = text_results.texts[xpath]
if result.isVisible:
text_content = self._cap_text_length(text_node.strip())
if text_content:
depth = len(xpath.split('/')) - 2
ordered_results.append((order, xpath, False, text_content, depth, True))
# Sort by original order
ordered_results.sort(key=lambda x: x[0])
# Build final output maintaining order
for i, (_, xpath, is_clickable, content, depth, is_text_only) in enumerate(ordered_results):
output_items.append(
DomContentItem(
index=i,
text=content,
# clickable=is_clickable,
depth=depth,
is_text_only=is_text_only,
)
)
# if is_clickable: # Only add clickable elements to selector map
# TODO: make this right, for now we add all elements (except text) to selector map
if not is_text_only:
selector_map[i] = xpath
return ProcessedDomContent(items=output_items, selector_map=selector_map)
def _cap_text_length(self, text: str, max_length: int = 150) -> str:
def _batch_check_elements(self, elements: dict[str, tuple[Tag, int]]) -> BatchCheckResults:
"""Batch check all interactive elements at once."""
if not elements:
return BatchCheckResults(elements={}, texts={})
check_script = """
return (function() {
const results = {};
const elements = %s;
for (const [xpath, elementData] of Object.entries(elements)) {
const element = document.evaluate(xpath, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!element) continue;
// Check visibility
const isVisible = element.checkVisibility({
checkOpacity: true,
checkVisibilityCSS: true
});
if (!isVisible) continue;
// Check if topmost
const rect = element.getBoundingClientRect();
const points = [
{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.25},
{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.25},
{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.75},
{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.75},
{x: rect.left + rect.width / 2, y: rect.top + rect.height / 2}
];
const isTopElement = points.some(point => {
const topEl = document.elementFromPoint(point.x, point.y);
let current = topEl;
while (current && current !== document.body) {
if (current === element) return true;
current = current.parentElement;
}
return false;
});
if (isTopElement) {
results[xpath] = {
xpath: xpath,
isVisible: true,
isTopElement: true
};
}
}
return results;
})();
""" % json.dumps({xpath: {} for xpath in elements.keys()})
try:
results = self.driver.execute_script(check_script)
return BatchCheckResults(
elements={xpath: ElementCheckResult(**data) for xpath, data in results.items()},
texts={},
)
except Exception as e:
logger.error('Error in batch element check: %s', e)
return BatchCheckResults(elements={}, texts={})
def _batch_check_texts(
self, texts: dict[str, tuple[NavigableString, int]]
) -> BatchCheckResults:
"""Batch check all text nodes at once."""
if not texts:
return BatchCheckResults(elements={}, texts={})
check_script = """
return (function() {
const results = {};
const texts = %s;
for (const [xpath, textData] of Object.entries(texts)) {
const parent = document.evaluate(xpath, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!parent) continue;
try {
const range = document.createRange();
const textNode = parent.childNodes[textData.index];
range.selectNodeContents(textNode);
const rect = range.getBoundingClientRect();
const isVisible = (
rect.width !== 0 &&
rect.height !== 0 &&
rect.top >= 0 &&
rect.top <= window.innerHeight &&
parent.checkVisibility({
checkOpacity: true,
checkVisibilityCSS: true
})
);
if (isVisible) {
results[xpath] = {
xpath: xpath,
isVisible: true
};
}
} catch (e) {
continue;
}
}
return results;
})();
""" % json.dumps(
{
xpath: {'index': list(text_node[0].parent.children).index(text_node[0])}
for xpath, text_node in texts.items()
if text_node[0].parent
}
)
try:
results = self.driver.execute_script(check_script)
return BatchCheckResults(
elements={},
texts={xpath: TextCheckResult(**data) for xpath, data in results.items()},
)
except Exception as e:
logger.error('Error in batch text check: %s', e)
return BatchCheckResults(elements={}, texts={})
def _cap_text_length(self, text: str, max_length: int = 250) -> str:
if len(text) > max_length:
half_length = max_length // 2
return text[:half_length] + '...' + text[-half_length:]
@@ -230,75 +367,6 @@ class DomService:
return element.name not in leaf_element_deny_list
def _generate_xpath(self, element: Tag | NavigableString) -> str:
# Generate cache key based on element properties
cache_key = None
if isinstance(element, Tag):
attributes = [
element.get('id', ''),
element.get('class', ''),
element.name,
element.get_text().strip(),
]
cache_key = '|'.join(str(attr) for attr in attributes)
# Return cached xpath if exists
if cache_key in self.xpath_cache:
return self.xpath_cache[cache_key]
if isinstance(element, NavigableString):
if element.parent:
return self._generate_xpath(element.parent)
return ''
if not hasattr(element, 'name'):
return ''
element_id = getattr(element, 'get', lambda x: None)('id')
if element_id:
xpath = f"//*[@id='{element_id}']"
if cache_key:
self.xpath_cache[cache_key] = xpath
return xpath
parts = []
current = element
while current and getattr(current, 'name', None):
if current.name == '[document]':
break
parent = getattr(current, 'parent', None)
if parent and hasattr(parent, 'children'):
# Get only visible element nodes
siblings = [
s for s in parent.find_all(current.name, recursive=False) if isinstance(s, Tag)
]
if len(siblings) > 1:
try:
index = siblings.index(current) + 1
parts.insert(0, f'{current.name}[{index}]')
except ValueError:
parts.insert(0, current.name)
else:
parts.insert(0, current.name)
current = parent
if parts and parts[0] != 'html':
parts.insert(0, 'html')
if len(parts) > 1 and parts[1] != 'body':
parts.insert(1, 'body')
xpath = '//' + '/'.join(parts)
# Cache the generated xpath
if cache_key:
self.xpath_cache[cache_key] = xpath
return xpath
def _get_essential_attributes(self, element: Tag) -> str:
"""
Collects essential attributes from an element.
@@ -333,11 +401,11 @@ class DomService:
if attr in element.attrs:
element_attr = element[attr]
if isinstance(element_attr, str):
element_attr = element_attr[:50]
element_attr = element_attr
elif isinstance(element_attr, (list, tuple)):
element_attr = ' '.join(str(v)[:50] for v in element_attr)
element_attr = ' '.join(str(v) for v in element_attr)
attrs.append(f'{attr}="{element_attr}"')
attrs.append(f'{attr}="{self._cap_text_length(element_attr, 25)}"')
state_attributes_prefixes = (
'aria-',
@@ -351,115 +419,6 @@ class DomService:
return ' '.join(attrs)
def _is_visible(self, element: Tag | NavigableString) -> bool:
if not isinstance(element, Tag):
return self._is_text_visible(element)
element_id = element.get('id', '')
if element_id:
js_selector = f'document.getElementById("{element_id}")'
else:
xpath = self._generate_xpath(element)
js_selector = f'document.evaluate("{xpath}", document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue'
visibility_check = f"""
return (function() {{
const element = {js_selector};
if (!element) {{
return false;
}}
// Force return as boolean
return Boolean(element.checkVisibility({{
checkOpacity: true,
checkVisibilityCSS: true
}}));
}}());
"""
try:
# todo: parse responses with pydantic
is_visible = self.driver.execute_script(visibility_check)
return bool(is_visible)
except Exception:
return False
def _is_text_visible(self, element: NavigableString) -> bool:
"""Check if text node is visible using JavaScript."""
parent = element.parent
if not parent:
return False
xpath = self._generate_xpath(parent)
visibility_check = f"""
return (function() {{
const parent = document.evaluate("{xpath}", document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!parent) {{
return false;
}}
const range = document.createRange();
const textNode = parent.childNodes[{list(parent.children).index(element)}];
range.selectNodeContents(textNode);
const rect = range.getBoundingClientRect();
if (rect.width === 0 || rect.height === 0 ||
rect.top < 0 || rect.top > window.innerHeight) {{
return false;
}}
// Force return as boolean
return Boolean(parent.checkVisibility({{
checkOpacity: true,
checkVisibilityCSS: true
}}));
}}());
"""
try:
is_visible = self.driver.execute_script(visibility_check)
return bool(is_visible)
except Exception:
return False
def _is_top_element(self, element: Tag | NavigableString, rect=None) -> bool:
"""Check if element is the topmost at its position."""
xpath = self._generate_xpath(element)
check_top = f"""
return (function() {{
const elem = document.evaluate("{xpath}", document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!elem) {{
return false;
}}
const rect = elem.getBoundingClientRect();
const points = [
{{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.25}},
{{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.25}},
{{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.75}},
{{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.75}},
{{x: rect.left + rect.width / 2, y: rect.top + rect.height / 2}}
];
return Boolean(points.some(point => {{
const topEl = document.elementFromPoint(point.x, point.y);
let current = topEl;
while (current && current !== document.body) {{
if (current === elem) return true;
current = current.parentElement;
}}
return false;
}}));
}}());
"""
try:
is_top = self.driver.execute_script(check_top)
return bool(is_top)
except Exception:
logger.error(f'Error checking top element: {element}')
return False
def _is_active(self, element: Tag) -> bool:
"""Check if element is active (not disabled)."""
return not (
@@ -467,40 +426,3 @@ class DomService:
or element.get('hidden') is not None
or element.get('aria-disabled') == 'true'
)
# def _quick_element_filter(self, element: PageElement) -> bool:
# """
# Quick pre-filter to eliminate elements before expensive checks.
# Returns True if element passes initial filtering.
# """
# if isinstance(element, NavigableString):
# # Quick check for empty or whitespace-only strings
# return bool(element.strip())
# if not isinstance(element, Tag):
# return False
# style = element.get('style')
# # Quick attribute checks that would make element invisible/non-interactive
# if any(
# [
# element.get('aria-hidden') == 'true',
# element.get('hidden') is not None,
# element.get('disabled') is not None,
# style and ('display: none' in style or 'visibility: hidden' in style),
# element.has_attr('class')
# and any(cls in element['class'] for cls in ['hidden', 'invisible']),
# # Common hidden class patterns
# element.get('type') == 'hidden',
# ]
# ):
# return False
# # Skip elements that definitely won't be interactive or visible
# non_interactive_display = ['none', 'hidden']
# computed_style = element.get('style', '') or ''
# if any(display in computed_style for display in non_interactive_display):
# return False
# return True

View File

@@ -2,33 +2,32 @@ import time
from tokencost import count_string_tokens
from browser_use.browser.service import BrowserService
from browser_use.browser.service import Browser
from browser_use.dom.service import DomService
from browser_use.utils import time_execution_sync
# @pytest.mark.skip("slow af")
def test_process_html_file():
browser = BrowserService(headless=False)
browser = Browser(headless=False)
driver = browser.init()
driver = browser._get_driver()
dom_service = DomService(driver)
browser.go_to_url('https://www.kayak.ch')
driver.get('https://kayak.com/flights')
# browser.go_to_url('https://google.com/flights')
# browser.go_to_url('https://immobilienscout24.de')
time.sleep(1)
time.sleep(3)
# browser._click_element_by_xpath(
# '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]'
# )
browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
# browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
elements = time_execution_sync('get_clickable_elements')(
dom_service.get_clickable_elements().dom_items_to_string
)()
elements = time_execution_sync('get_clickable_elements')(dom_service.get_clickable_elements)()
print(elements)
print('Tokens:', count_string_tokens(elements, model='gpt-4o'))
print(elements.dom_items_to_string(use_tabs=False))
print('Tokens:', count_string_tokens(elements.dom_items_to_string(), model='gpt-4o'))
input('Press Enter to continue...')

View File

@@ -1,10 +1,11 @@
from typing import Dict, List
from pydantic import BaseModel
class DomContentItem(BaseModel):
index: int
text: str
clickable: bool
is_text_only: bool
depth: int
@@ -15,13 +16,38 @@ class ProcessedDomContent(BaseModel):
items: list[DomContentItem]
selector_map: SelectorMap
def dom_items_to_string(self) -> str:
def dom_items_to_string(self, use_tabs: bool = True) -> str:
"""Convert the processed DOM content to HTML."""
formatted_text = ''
for item in self.items:
item_depth = '\t' * item.depth * 1
if item.clickable:
formatted_text += f'{item.index}:{item_depth}{item.text}\n'
item_depth = '\t' * item.depth * 1 if use_tabs else ''
if item.is_text_only:
formatted_text += f'_[:]{item_depth}{item.text}\n'
else:
formatted_text += f'{item.index}:{item_depth}{item.text}\n'
formatted_text += f'{item.index}[:]{item_depth}{item.text}\n'
return formatted_text
class ElementState(BaseModel):
isVisible: bool
isTopElement: bool
class TextState(BaseModel):
isVisible: bool
class ElementCheckResult(BaseModel):
xpath: str
isVisible: bool
isTopElement: bool
class TextCheckResult(BaseModel):
xpath: str
isVisible: bool
class BatchCheckResults(BaseModel):
elements: Dict[str, ElementCheckResult]
texts: Dict[str, TextCheckResult]
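
A small, hedged illustration of the string format dom_items_to_string() produces, assuming the final DomContentItem fields are index, text, is_text_only and depth as used by the service above; the items are made up.

# Illustrative output of dom_items_to_string (values are made up).
from browser_use.dom.views import DomContentItem, ProcessedDomContent

content = ProcessedDomContent(
    items=[
        DomContentItem(index=0, text='<button>Accept all</button>', is_text_only=False, depth=2),
        DomContentItem(index=1, text='Cookie notice text', is_text_only=True, depth=3),
    ],
    selector_map={0: '//html[1]/body[1]/div[1]/button[1]'},
)
print(content.dom_items_to_string(use_tabs=False))
# 0[:]<button>Accept all</button>
# _[:]Cookie notice text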

View File

@@ -0,0 +1,47 @@
import logging
import sys
def setup_logging():
# Check if handlers are already set up
if logging.getLogger().hasHandlers():
return
# Clear existing handlers
root = logging.getLogger()
root.handlers = []
class BrowserUseFormatter(logging.Formatter):
def format(self, record):
if record.name.startswith('browser_use.'):
record.name = record.name.split('.')[-2]
return super().format(record)
# Setup single handler for all loggers
console = logging.StreamHandler(sys.stdout)
console.setFormatter(BrowserUseFormatter('%(levelname)-8s [%(name)s] %(message)s'))
# Configure root logger only
root.addHandler(console)
root.setLevel(logging.INFO)
# Configure browser_use logger to prevent propagation
browser_use_logger = logging.getLogger('browser_use')
browser_use_logger.propagate = False
browser_use_logger.addHandler(console)
# Silence third-party loggers
for logger in [
'WDM',
'httpx',
'selenium',
'urllib3',
'asyncio',
'langchain',
'openai',
'httpcore',
'charset_normalizer',
]:
third_party = logging.getLogger(logger)
third_party.setLevel(logging.ERROR)
third_party.propagate = False
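
A hedged example of opting into the shared logging setup from user code; setup_logging() is also called from the package __init__ below, so this is mostly illustrative.

# Illustrative: configure logging once, then log through a browser_use-style logger.
import logging

from browser_use.logging_config import setup_logging

setup_logging()
logger = logging.getLogger('browser_use.agent.service')
logger.info('hello')  # prints: INFO     [agent] hello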

View File

@@ -3,13 +3,13 @@ import time
from functools import wraps
from typing import Any, Callable, Coroutine, ParamSpec, TypeVar
logger = logging.getLogger(__name__)
# Define generic type variables for return type and parameters
R = TypeVar('R')
P = ParamSpec('P')
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
def time_execution_sync(additional_text: str = '') -> Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@@ -18,7 +18,7 @@ def time_execution_sync(additional_text: str = '') -> Callable[[Callable[P, R]],
start_time = time.time()
result = func(*args, **kwargs)
execution_time = time.time() - start_time
logger.info(f'{additional_text} Execution time: {execution_time:.2f} seconds')
logger.debug(f'{additional_text} Execution time: {execution_time:.2f} seconds')
return result
return wrapper
@@ -35,7 +35,7 @@ def time_execution_async(
start_time = time.time()
result = await func(*args, **kwargs)
execution_time = time.time() - start_time
logger.info(f'{additional_text} Execution time: {execution_time:.2f} seconds')
logger.debug(f'{additional_text} Execution time: {execution_time:.2f} seconds')
return result
return wrapper
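
A brief, hedged sketch of using the timing decorator above; the function is illustrative, and because the decorator now logs at DEBUG you have to raise the package log level to see its output.

# Illustrative use of time_execution_sync (logged at DEBUG level after this change).
import logging

from browser_use.utils import time_execution_sync

logging.getLogger('browser_use').setLevel(logging.DEBUG)

@time_execution_sync('--slow_sum')
def slow_sum(n: int) -> int:
    return sum(range(n))

slow_sum(10_000_000)  # emits a DEBUG log like: --slow_sum Execution time: 0.12 seconds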

View File

@@ -1,6 +1,10 @@
import os
import sys
from browser_use.logging_config import setup_logging
# Get the absolute path to the project root
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
setup_logging()

View File

@@ -1,4 +1,3 @@
import logging
import os
import sys
@@ -10,9 +9,6 @@ import asyncio
from browser_use import Agent, Controller
logging.basicConfig(level=logging.INFO)
# Persist the browser state across agents
controller = Controller()

View File

@@ -0,0 +1,93 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
from typing import List, Optional
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from browser_use.agent.service import Agent
from browser_use.browser.service import Browser
from browser_use.controller.service import Controller
# Initialize controller first
controller = Controller()
# Action Models
class JobDetails(BaseModel):
title: str
company: str
job_link: str
salary: Optional[str] = None
@controller.action('Save job details which you found on page', param_model=JobDetails)
def save_job(params: JobDetails):
with open('jobs.txt', 'a') as f:
f.write(f'{params.title} at {params.company}: {params.salary}\n')
class StarredPeople(BaseModel):
usernames: List[str]
@controller.action('Save people who starred the repo', param_model=StarredPeople)
def save_starred_people(params: StarredPeople):
with open('starred_people.txt', 'a') as f:
for username in params.usernames:
f.write(f'{username}\n')
# Browser-requiring action example
class PageSaver(BaseModel):
filename: str
@controller.action('Save current page info', param_model=PageSaver, requires_browser=True)
def save_page_info(params: PageSaver, browser: Browser):
state = browser.get_state()
with open(params.filename, 'w') as f:
f.write(f'URL: {state.url}\n')
f.write(f'Title: {state.title}\n')
f.write(f'HTML: {state.items}\n')
class Job(BaseModel):
title: str
link: str
company: str
salary: Optional[str] = None
class Jobs(BaseModel):
jobs: List[Job]
@controller.action('Save jobs', param_model=Jobs, requires_browser=True)
def save_jobs(params: Jobs, browser: Browser):
with open('jobs.txt', 'a') as f:
for job in params.jobs:
f.write(f'{job.title} at {job.company}: {job.salary} ({job.link})\n')
# Without Pydantic model - using simple parameters
@controller.action('Ask user for information')
def ask_human(question: str, display_question: bool) -> str:
return input(f'\n{question}\nInput: ')
async def main():
task = 'Find 10 software developer jobs in San Francisco at YC startups on Google and save the jobs to a file. Then ask the human for more information'
model = ChatOpenAI(model='gpt-4o')
agent = Agent(task=task, llm=model, controller=controller)
await agent.run()
if __name__ == '__main__':
asyncio.run(main())

View File

@@ -1,43 +0,0 @@
"""
@dev You need to add ANTHROPIC_API_KEY to your environment variables.
"""
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from browser_use.agent.service import AgentService
from browser_use.controller.service import ControllerService
task = 'Go to kayak.com and find a one-way flight from Zürich to San Francisco on 12 January 2025.'
controller = ControllerService()
# model = ChatAnthropic(
# model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3
# )
model = ChatOpenAI(model='gpt-4o')
agent = AgentService(task, model, controller, use_vision=True)
async def main():
max_steps = 50
# Run the agent step by step
for i in range(max_steps):
print(f'\n📍 Step {i+1}')
action, result = await agent.step()
print('Action:', action)
print('Result:', result)
if result.done:
print('\n✅ Task completed successfully!')
print('Extracted content:', result.extracted_content)
break
asyncio.run(main())

View File

@@ -1,36 +0,0 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
from langchain_openai import ChatOpenAI
from browser_use.agent.service import AgentService
from browser_use.controller.service import ControllerService
people = ['Albert Einstein', 'Oprah Winfrey', 'Steve Jobs']
task = f'Opening new tabs and searching for images for these people: {", ".join(people)}. Then ask me for further instructions.'
controller = ControllerService(keep_open=True)
model = ChatOpenAI(model='gpt-4o')
agent = AgentService(task, model, controller, use_vision=True)
async def main():
max_steps = 50
# Run the agent step by step
for i in range(max_steps):
print(f'\n📍 Step {i+1}')
action, result = await agent.step()
print('Action:', action)
print('Result:', result)
if result.done:
print('\n✅ Task completed successfully!')
print('Extracted content:', result.extracted_content)
break
asyncio.run(main())

View File

@@ -4,12 +4,9 @@ Simple try of the agent.
@dev You need to add OPENAI_API_KEY to your environment variables.
"""
import logging
import os
import sys
from browser_use.controller.service import ControllerService
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
@@ -18,19 +15,15 @@ from langchain_openai import ChatOpenAI
from browser_use import Agent
logging.basicConfig(level=logging.INFO)
llm = ChatOpenAI(model='gpt-4o')
agent = Agent(
task='Opening new tabs to search for images of Albert Einstein, Oprah Winfrey, and Steve Jobs. Then ask user for further instructions.',
llm=llm,
task='Find a one-way flight from Bali to Oman on 12 January 2025 on Google Flights. Return me the cheapest option.',
llm=ChatOpenAI(model='gpt-4o'),
)
async def main():
result, history = await agent.run()
print(result)
print(history)
await agent.run()
asyncio.run(main())

View File

@@ -1,32 +0,0 @@
"""
Simple try of the agent.
@dev You need to add OPENAI_API_KEY to your environment variables.
"""
import logging
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
from langchain_openai import ChatOpenAI
from browser_use import Agent
logging.basicConfig(level=logging.INFO)
llm = ChatOpenAI(model='gpt-4o')
agent = Agent(
task='Apply to 2025 batch of Talent Kick. Use dummy data. Here is some information about me: Name: John Doe, Email: john.doe@example.com, Phone: +1234567890, LinkedIn: https://www.linkedin.com/in/john-doe, Github: https://github.com/john-doe, Twitter: https://twitter.com/john-doe, StackOverflow: https://stackoverflow.com/users/123456/john-doe, Youtube: https://www.youtube.com/user/john-doe, Education: BSc Computer Science from MIT (2020-2024), Work Experience: Software Engineer at Google (2024-present), Skills: Python, JavaScript, React, Node.js, AWS, Docker, Kubernetes, Projects: Built an AI-powered chatbot with 10k+ users, Created an open-source library with 1k+ stars on Github, Achievements: Won first place in MIT Hackathon 2023, Published paper on ML at ICML 2024, Languages: English (Native), Spanish (Fluent), Mandarin (Basic), Interests: AI/ML, Open Source, Competitive Programming, Hobbies: Playing guitar, Rock climbing, Chess',
llm=llm,
)
async def main():
await agent.run()
asyncio.run(main())

View File

@@ -17,9 +17,7 @@ import argparse
import asyncio
from browser_use import Agent
from browser_use.controller.service import ControllerService
logging.basicConfig(level=logging.INFO)
from browser_use.controller.service import Controller
def get_llm(provider: str):
@@ -50,14 +48,13 @@ llm = get_llm(args.provider)
agent = Agent(
task=args.query,
llm=llm,
controller=ControllerService(keep_open=True),
controller=Controller(keep_open=True),
# save_conversation_path='./tmp/try_flight/',
)
async def main():
result, history = await agent.run()
print(result)
print(history)
await agent.run()
asyncio.run(main())

View File

@@ -1,37 +0,0 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
from langchain_anthropic import ChatAnthropic
from browser_use.agent.service import AgentService
from browser_use.controller.service import ControllerService
task = 'Open 3 wikipedia pages in different tabs and summarize the content of all pages.'
controller = ControllerService()
model = ChatAnthropic(
model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3
)
agent = AgentService(task, model, controller, use_vision=True)
async def main():
max_steps = 50
# Run the agent step by step
for i in range(max_steps):
print(f'\n📍 Step {i+1}')
action, result = await agent.step()
print('Action:', action)
print('Result:', result)
if result.done:
print('\n✅ Task completed successfully!')
print('Extracted content:', result.extracted_content)
break
asyncio.run(main())

View File

@@ -4,7 +4,7 @@ description = "Let LLMs interact with websites through a simple interface"
authors = [
{ name = "Gregor Zunic" }
]
version = "0.1.0"
version = "0.1.1"
readme = "README.md"
requires-python = ">=3.11"
classifiers = [
@@ -12,24 +12,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dynamic = ["dependencies", "optional-dependencies"]
dependencies = [
"MainContentExtractor>=0.0.4",
"Selenium-Screenshot>=2.1.0",
"beautifulsoup4>=4.12.3",
"langchain>=0.3.7",
"langchain-openai>=0.2.5",
"langchain-anthropic>=0.2.4",
"langchain-fireworks>=0.2.5",
"pydantic>=2.9.2",
"pytest>=8.3.3",
"pytest-asyncio>=0.24.0",
"python-dotenv>=1.0.1",
"requests>=2.32.3",
"selenium>=4.26.1",
"webdriver-manager>=4.0.2",
"lxml_html_clean>=0.3.1"
]
[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
optional-dependencies = {dev = {file = ["requirements-dev.txt"]}}
[tool.ruff]
line-length = 100
@@ -42,9 +29,3 @@ docstring-code-format = true
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project.optional-dependencies]
dev = [
"tokencost>=0.1.16",
"hatch>=1.13.0",
]

View File

@@ -1,2 +1,29 @@
[pytest]
asyncio_mode = auto
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks tests as integration tests
unit: marks tests as unit tests
testpaths =
tests
python_files =
test_*.py
*_test.py
addopts =
-v
--strict-markers
--tb=short
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
log_cli = true
; log_cli_level = DEBUG
log_cli_format = %(levelname)-8s [%(name)s] %(message)s
filterwarnings =
ignore::pytest.PytestDeprecationWarning
ignore::DeprecationWarning
log_level = INFO
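
With the markers above registered, slow or integration tests can be deselected from the command line; the commands below are illustrative.

# pytest -m "not slow"                       # skip tests marked as slow
# pytest -m integration -v                   # run only integration tests
# pytest tests/test_agent_actions.py -v -k "captcha" --log-cli-level=DEBUG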

3
requirements-dev.txt Normal file
View File

@@ -0,0 +1,3 @@
tokencost>=0.1.16
hatch>=1.13.0
build>=1.2.2

View File

@@ -11,4 +11,4 @@ pytest-asyncio>=0.24.0
python-dotenv>=1.0.1
requests>=2.32.3
selenium>=4.26.1
webdriver-manager>=4.0.2
webdriver-manager>=4.0.2

File diff suppressed because it is too large

208
tests/test_agent_actions.py Normal file
View File

@@ -0,0 +1,208 @@
import asyncio
import pytest
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from browser_use.agent.service import Agent
from browser_use.controller.service import Controller
@pytest.fixture
def llm():
"""Initialize language model for testing"""
# return ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None)
return ChatOpenAI(model='gpt-4o')
# return ChatOpenAI(model='gpt-4o-mini')
@pytest.fixture
async def agent_with_controller():
"""Create agent with controller for testing"""
controller = Controller(keep_open=False)
print('init controller')
try:
yield controller
finally:
if controller.browser:
controller.browser.close(force=True)
@pytest.mark.asyncio
async def test_ecommerce_interaction(llm, agent_with_controller):
"""Test complex ecommerce interaction sequence"""
agent = Agent(
task="Go to amazon.com, search for 'laptop', filter by 4+ stars, and find the price of the first result",
llm=llm,
controller=agent_with_controller,
save_conversation_path='tmp/test_ecommerce_interaction/conversation',
)
history = await agent.run(max_steps=20)
# Verify sequence of actions
action_sequence = []
for h in history:
action = getattr(h.model_output, 'action', None)
if action and (getattr(action, 'go_to_url', None) or getattr(action, 'open_tab', None)):
action_sequence.append('navigate')
elif action and getattr(action, 'input_text', None):
action_sequence.append('input')
# Check that the input is 'laptop'
inp = action.input_text.text.lower()
if inp == 'laptop':
action_sequence.append('input_exact_correct')
elif 'laptop' in inp:
action_sequence.append('correct_in_input')
else:
action_sequence.append('incorrect_input')
elif action and getattr(action, 'click_element', None):
action_sequence.append('click')
if action is None:
print(h.result)
print(h.model_output)
# Verify essential steps were performed
assert 'navigate' in action_sequence # Navigated to Amazon
assert 'input' in action_sequence # Entered search term
assert 'click' in action_sequence # Clicked search/filter
assert 'input_exact_correct' in action_sequence or 'correct_in_input' in action_sequence
@pytest.mark.asyncio
async def test_error_recovery(llm, agent_with_controller):
"""Test agent's ability to recover from errors"""
agent = Agent(
task='Navigate to nonexistent-site.com and then recover by going to google.com',
llm=llm,
controller=agent_with_controller,
)
history = await agent.run(max_steps=10)
recovery_action = next(
(
h
for h in history
if h.model_output
and getattr(h.model_output, 'action', None)
and getattr(h.model_output.action, 'go_to_url', None)
and getattr(h.model_output.action.go_to_url, 'url', '').endswith('google.com') # type: ignore -> pretty weird way to do this
),
None,
)
assert recovery_action is not None
@pytest.mark.asyncio
async def test_find_contact_email(llm, agent_with_controller):
"""Test agent's ability to find contact email on a website"""
agent = Agent(
task='Go to https://browser-use.com/ and find out the contact email',
llm=llm,
controller=agent_with_controller,
)
history = await agent.run(max_steps=10)
# Verify the agent found the contact email
email_action = next(
(
h
for h in history
if h.result.extracted_content and 'info@browser-use.com' in h.result.extracted_content
),
None,
)
assert email_action is not None
@pytest.mark.asyncio
async def test_agent_finds_installation_command(llm, agent_with_controller):
"""Test agent's ability to find the pip installation command for browser-use on the web"""
agent = Agent(
task='Find the pip installation command for the browser-use repo',
llm=llm,
controller=agent_with_controller,
)
history = await agent.run(max_steps=10)
# Verify the agent found the correct installation command
install_command_action = next(
(
h
for h in history
			if h.result and h.result.extracted_content
			and 'pip install browser-use' in h.result.extracted_content
),
None,
)
assert install_command_action is not None
class CaptchaTest(BaseModel):
name: str
url: str
success_text: str
additional_text: str | None = None
# pytest tests/test_agent_actions.py -v -k "test_captcha_solver" --capture=no --log-cli-level=INFO
@pytest.mark.asyncio
@pytest.mark.parametrize(
'captcha',
[
# good test for num_clicks
CaptchaTest(
name='Rotate Captcha',
url='https://2captcha.com/demo/rotatecaptcha',
success_text='Captcha is passed successfully',
			additional_text='Use num_clicks with a number to click multiple times at once in the same direction. Click done when the image is in the exact correct position.',
),
CaptchaTest(
name='Text Captcha',
url='https://2captcha.com/demo/text',
success_text='Captcha is passed successfully!',
),
CaptchaTest(
name='Basic Captcha',
url='https://captcha.com/demos/features/captcha-demo.aspx',
success_text='Correct!',
),
CaptchaTest(
name='MT Captcha',
url='https://2captcha.com/demo/mtcaptcha',
success_text='Verified Successfully',
			additional_text='Stop once you have solved it successfully.',
),
],
)
async def test_captcha_solver(llm, agent_with_controller, captcha: CaptchaTest):
"""Test agent's ability to solve different types of captchas"""
agent = Agent(
task=f'Go to {captcha.url} and solve the captcha. {captcha.additional_text}',
llm=llm,
controller=agent_with_controller,
)
history = await agent.run(max_steps=10)
# Verify the agent solved the captcha
solved = False
for h in history:
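		# Scan the extracted page elements of this step for the captcha success text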
last = h.state.items
if any(captcha.success_text in item.text for item in last):
solved = True
break
assert solved, f'Failed to solve {captcha.name}'
# python -m pytest tests/test_agent_actions.py -v --capture=no
# pytest tests/test_agent_actions.py -v -k "test_captcha_solver" --capture=no --log-cli-level=INFO


@@ -0,0 +1,157 @@
import asyncio
import pytest
from langchain_openai import ChatOpenAI
from browser_use.agent.service import Agent
from browser_use.controller.service import Controller
@pytest.fixture
def llm():
"""Initialize language model for testing"""
return ChatOpenAI(model='gpt-4o') # Use appropriate model
@pytest.fixture
async def controller():
"""Initialize the controller"""
controller = Controller()
try:
yield controller
finally:
if controller.browser:
controller.browser.close(force=True)
@pytest.mark.asyncio
async def test_search_google(llm, controller):
"""Test 'Search Google' action"""
agent = Agent(
task="Search Google for 'OpenAI'.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=2)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
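	# Each action model populates exactly one field, so exclude_unset exposes just that action's name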
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'search_google' in action_names
@pytest.mark.asyncio
async def test_go_to_url(llm, controller):
"""Test 'Navigate to URL' action"""
agent = Agent(
task="Navigate to 'https://www.python.org'.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=2)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'go_to_url' in action_names
@pytest.mark.asyncio
async def test_go_back(llm, controller):
"""Test 'Go back' action"""
agent = Agent(
task="Go to 'https://www.example.com', then go back.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=3)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'go_to_url' in action_names
assert 'go_back' in action_names
@pytest.mark.asyncio
async def test_click_element(llm, controller):
"""Test 'Click element' action"""
agent = Agent(
task="Go to 'https://www.python.org' and click on the first link.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=4)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'go_to_url' in action_names
assert 'click_element' in action_names
@pytest.mark.asyncio
async def test_input_text(llm, controller):
"""Test 'Input text' action"""
agent = Agent(
task="Go to 'https://www.google.com' and input 'OpenAI' into the search box.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=4)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'go_to_url' in action_names
assert 'input_text' in action_names
@pytest.mark.asyncio
async def test_switch_tab(llm, controller):
"""Test 'Switch tab' action"""
agent = Agent(
task="Open new tabs with 'https://www.google.com' and 'https://www.wikipedia.org', then switch to the first tab.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=6)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
open_tab_count = action_names.count('open_tab')
assert open_tab_count >= 2
assert 'switch_tab' in action_names
@pytest.mark.asyncio
async def test_open_new_tab(llm, controller):
"""Test 'Open new tab' action"""
agent = Agent(
task="Open a new tab and go to 'https://www.example.com'.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=3)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'open_tab' in action_names
@pytest.mark.asyncio
async def test_extract_page_content(llm, controller):
"""Test 'Extract page content' action"""
agent = Agent(
task="Go to 'https://www.example.com' and extract the page content.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=3)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'go_to_url' in action_names
assert 'extract_content' in action_names
@pytest.mark.asyncio
async def test_done_action(llm, controller):
"""Test 'Complete task' action"""
agent = Agent(
task="Navigate to 'https://www.example.com' and signal that the task is done.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=3)
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'go_to_url' in action_names
assert 'done' in action_names

116
tests/test_mind2web.py Normal file

@@ -0,0 +1,116 @@
"""
Test browser automation using Mind2Web dataset tasks with pytest framework.
"""
import json
import logging
import os
from typing import Any, Dict, List
import pytest
from langchain_openai import ChatOpenAI
from browser_use.agent.service import Agent
from browser_use.controller.service import Controller
from browser_use.utils import logger
# Constants
MAX_STEPS = 50
TEST_SUBSET_SIZE = 10
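# Only the first TEST_SUBSET_SIZE entries of the dataset are loaded per test session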
@pytest.fixture(scope='session')
def test_cases() -> List[Dict[str, Any]]:
"""Load test cases from Mind2Web dataset"""
file_path = os.path.join(os.path.dirname(__file__), 'mind2web_data/processed.json')
logger.info(f'Loading test cases from {file_path}')
with open(file_path, 'r') as f:
data = json.load(f)
subset = data[:TEST_SUBSET_SIZE]
logger.info(f'Loaded {len(subset)}/{len(data)} test cases')
return subset
@pytest.fixture(scope='session')
def llm():
"""Initialize the language model"""
return ChatOpenAI(model='gpt-4o')
@pytest.fixture(scope='function')
async def controller():
"""Initialize the controller"""
controller = Controller()
try:
yield controller
finally:
if controller.browser:
controller.browser.close(force=True)
# run with: pytest -s -v tests/test_mind2web.py::test_random_samples
@pytest.mark.asyncio
async def test_random_samples(test_cases: List[Dict[str, Any]], llm, controller):
"""Test a random sampling of tasks across different websites"""
import random
logger.info('=== Testing Random Samples ===')
	# Take a single random sample from the loaded test cases
samples = random.sample(test_cases, 1)
for i, case in enumerate(samples, 1):
task = f"Go to {case['website']}.com and {case['confirmed_task']}"
logger.info(f'--- Random Sample {i}/{len(samples)} ---')
logger.info(f'Task: {task}\n')
agent = Agent(task, llm, controller)
		await agent.run(max_steps=MAX_STEPS)
logger.info('Validating random sample task...')
# TODO: Validate the task
def test_dataset_integrity(test_cases):
"""Test the integrity of the test dataset"""
logger.info('\n=== Testing Dataset Integrity ===')
required_fields = ['website', 'confirmed_task', 'action_reprs']
missing_fields = []
logger.info(f'Checking {len(test_cases)} test cases for required fields')
for i, case in enumerate(test_cases, 1):
logger.debug(f'Checking case {i}/{len(test_cases)}')
for field in required_fields:
if field not in case:
missing_fields.append(f'Case {i}: {field}')
logger.warning(f"Missing field '{field}' in case {i}")
# Type checks
if not isinstance(case.get('confirmed_task'), str):
logger.error(f"Case {i}: 'confirmed_task' must be string")
assert False, 'Task must be string'
if not isinstance(case.get('action_reprs'), list):
logger.error(f"Case {i}: 'action_reprs' must be list")
assert False, 'Actions must be list'
if len(case.get('action_reprs', [])) == 0:
logger.error(f"Case {i}: 'action_reprs' must not be empty")
assert False, 'Must have at least one action'
if missing_fields:
logger.error('Dataset integrity check failed')
assert False, f'Missing fields: {missing_fields}'
else:
logger.info('✅ Dataset integrity check passed')
if __name__ == '__main__':
pytest.main([__file__, '-v'])


@@ -0,0 +1,176 @@
import asyncio
import pytest
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from browser_use.agent.service import Agent
from browser_use.controller.service import Controller
@pytest.fixture
def llm():
"""Initialize the language model"""
return ChatOpenAI(model='gpt-4o') # Use appropriate model
@pytest.fixture
async def controller():
"""Initialize the controller with self-registered actions"""
controller = Controller()
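	# Actions registered via @controller.action are exposed to the model alongside the built-in browser actions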
# Define custom actions without Pydantic models
@controller.action('Print a message')
def print_message(message: str):
print(f'Message: {message}')
return f'Printed message: {message}'
@controller.action('Add two numbers')
def add_numbers(a: int, b: int):
result = a + b
return f'The sum is {result}'
@controller.action('Concatenate strings')
def concatenate_strings(str1: str, str2: str):
result = str1 + str2
return f'Concatenated string: {result}'
# Define Pydantic models
class SimpleModel(BaseModel):
name: str
age: int
class Address(BaseModel):
street: str
city: str
class NestedModel(BaseModel):
user: SimpleModel
address: Address
# Add actions with Pydantic model arguments
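	# param_model lets the registry validate the model's arguments against the given Pydantic schema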
@controller.action('Process simple model', param_model=SimpleModel)
def process_simple_model(model: SimpleModel):
return f'Processed {model.name}, age {model.age}'
@controller.action('Process nested model', param_model=NestedModel)
def process_nested_model(model: NestedModel):
user_info = f'{model.user.name}, age {model.user.age}'
address_info = f'{model.address.street}, {model.address.city}'
return f'Processed user {user_info} at address {address_info}'
@controller.action('Process multiple models')
def process_multiple_models(model1: SimpleModel, model2: Address):
return f'Processed {model1.name} living at {model2.street}, {model2.city}'
try:
yield controller
finally:
if controller.browser:
controller.browser.close(force=True)
@pytest.mark.asyncio
async def test_self_registered_actions_no_pydantic(llm, controller):
"""Test self-registered actions with individual arguments"""
agent = Agent(
task="First, print the message 'Hello, World!'. Then, add 10 and 20. Next, concatenate 'foo' and 'bar'.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=10)
# Check that custom actions were executed
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'print_message' in action_names
assert 'add_numbers' in action_names
assert 'concatenate_strings' in action_names
@pytest.mark.asyncio
async def test_mixed_arguments_actions(llm, controller):
"""Test actions with mixed argument types"""
# Define another action during the test
@controller.action('Calculate the area of a rectangle')
def calculate_area(length: float, width: float):
area = length * width
return f'The area is {area}'
agent = Agent(
task='Calculate the area of a rectangle with length 5.5 and width 3.2.',
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=5)
# Check that the action was executed
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'calculate_area' in action_names
	# Check that the action's result (5.5 * 3.2 = 17.6) was fed back into the history
correct = 'The area is 17.6'
	assert correct in [h.result.extracted_content for h in history if h.model_output and h.result]
@pytest.mark.asyncio
async def test_pydantic_simple_model(llm, controller):
"""Test action with a simple Pydantic model argument"""
agent = Agent(
task="Process a simple model with name 'Alice' and age 30.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=5)
# Check that the action was executed
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'process_simple_model' in action_names
correct = 'Processed Alice, age 30'
	assert correct in [h.result.extracted_content for h in history if h.model_output and h.result]
@pytest.mark.asyncio
async def test_pydantic_nested_model(llm, controller):
"""Test action with a nested Pydantic model argument"""
agent = Agent(
task="Process a nested model with user name 'Bob', age 25, living at '123 Maple St', 'Springfield'.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=5)
# Check that the action was executed
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'process_nested_model' in action_names
correct = 'Processed user Bob, age 25 at address 123 Maple St, Springfield'
	assert correct in [h.result.extracted_content for h in history if h.model_output and h.result]
@pytest.mark.asyncio
async def test_pydantic_multiple_models(llm, controller):
"""Test action with multiple Pydantic model arguments"""
agent = Agent(
task="Process models with user name 'Carol', age 28, living at '456 Oak Ave', 'Shelbyville'.",
llm=llm,
controller=controller,
)
history = await agent.run(max_steps=5)
# Check that the action was executed
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
assert 'process_multiple_models' in action_names
correct = 'Processed Carol living at 456 Oak Ave, Shelbyville'
	assert correct in [h.result.extracted_content for h in history if h.model_output and h.result]
# run this file with:
# pytest tests/test_self_registered_actions.py --capture=no

48
tests/test_stress.py Normal file

@@ -0,0 +1,48 @@
import asyncio
import time
import pytest
from langchain_openai import ChatOpenAI
from browser_use.agent.service import Agent
from browser_use.controller.service import Controller
@pytest.fixture
def llm():
"""Initialize the language model"""
return ChatOpenAI(model='gpt-4o') # Use appropriate model
@pytest.fixture
async def controller():
"""Initialize the controller"""
controller = Controller()
try:
yield controller
finally:
if controller.browser:
controller.browser.close(force=True)
# Stress test that may hit LLM rate limits
@pytest.mark.asyncio
async def test_open_10_tabs_and_extract_content(llm, controller):
"""Stress test: Open 10 tabs and extract content"""
agent = Agent(
task='Open new tabs with example.com, example.net, example.org, and seven more example sites. Then, extract the content from each.',
llm=llm,
controller=controller,
)
start_time = time.time()
history = await agent.run(max_steps=50)
end_time = time.time()
total_time = end_time - start_time
print(f'Total time: {total_time:.2f} seconds')
# Check for errors
errors = [h.result.error for h in history if h.result and h.result.error]
	assert len(errors) == 0, f'Errors occurred during the test: {errors}'
	# check that at least 10 tabs were opened
	assert len(controller.browser.current_state.tabs) >= 10, 'Expected at least 10 open tabs'