mirror of
https://github.com/browser-use/browser-use.git
synced 2025-02-18 01:18:20 +03:00
Added custom actions registry and fixed extraction layer (#20)
* Validator * Test mind2web * Cleaned up logger * Pytest logger * Cleaned up logger * Disable flag for human input * Multiple clicks per button * Multiple clicks per button * More structured system prompt * Fields with description * System prompt example * One logger * Cleaner logging * Log step in step function * Fix critical clicking error - wrong argument used * Improved thought process of agent * Improve system prompt * Remove human input message * Custome action registration * Pydantic model for custom actions * Pydantic model for custome output * Runs through, model outputs functions, but not called yet * Work in progress - description for custome actions * Description works, but schema not yet * Model can call the right action - but is not executed * Seperate is_controller_action and is_custom_action * Works! Model can call custom function * Use registry for action, but result is not feed back to model * Include result in messages * Works with custom function - but typing is not correct * Renamed registry * First test cases * Captcha tests * Pydantic for tests * Improve prompts for multy step * System prompt structure * Handle errors like validation error * Refactor error handling in agent * Refactor error handling in agent * Improved logging * Update view * Fix click parameter to index * Simplify dynamic actions * Use run instead of step * Rename history * Rename AgentService to Agent * Rename ControllerService to Controller * Pytest file * Rename get state * Rename BrowserService * reversed dom extraction recursion to while * Rename use_vision * Rename use_vision * reversed dom tree items and made browser less anoying * Renaming and fixing type errors * Renamed class names for agent * updated requirements * Update prompt * Action registration works for user and controller * Fix done call by returning ActionResult * Fix if result is none * Rename AgentOutput and ActionModel * Improved prompt Passes 6/8 tests from test_agent_actions * Calculate token cost * Improve display * Simplified logger * Test function calling * created super simple xpath extraction algo * Tests logging * tiny fixes to dom extraction * Remove test * Dont log number of clicks * Pytest file * merged per element js checks * Check if driver is still open * super fast processing * fixed agent planning and stuff * Fix example * Fix example * Improve error * Improved error correction * New line for step * small type error fixes * Test for pydantic * Fix line * Removed sample * fixed readme and examples --------- Co-authored-by: magmueller <mamagnus00@gmail.com>
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -160,4 +160,8 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
temp
|
||||
temp
|
||||
tmp
|
||||
|
||||
|
||||
.DS_Store
|
||||
53
.vscode/launch.json
vendored
53
.vscode/launch.json
vendored
@@ -2,20 +2,51 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Debug Tests",
|
||||
"name": "Python Debugger: Module",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "examples.extend_actions"
|
||||
},
|
||||
{
|
||||
"name": "Python: Debug extend_actions",
|
||||
"type": "module",
|
||||
"request": "launch",
|
||||
"module": "examples.extend_actions",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Python: Debug Captcha Tests",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/pytest",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"src/tests/test_kayak_search.py",
|
||||
"tests/test_agent_actions.py",
|
||||
"-v",
|
||||
"-s"
|
||||
"-k",
|
||||
"test_captcha_solver",
|
||||
"--capture=no",
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Python: Debug Ecommerce Interaction",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"tests/test_agent_actions.py",
|
||||
"-v",
|
||||
"-k",
|
||||
"test_ecommerce_interaction",
|
||||
"--capture=no",
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
}
|
||||
]
|
||||
}
|
||||
231
README.md
231
README.md
@@ -1,30 +1,30 @@
|
||||
<div align="center">
|
||||
# 🌐 Browser Use
|
||||
|
||||
# 🌐 Browser-Use
|
||||
|
||||
### Open-Source Web Automation with LLMs
|
||||
Make websites accessible for AI agents 🤖.
|
||||
|
||||
[](https://github.com/gregpr07/browser-use/stargazers)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://discord.gg/uaCtrbbv)
|
||||
[](https://link.browser-use.com/discord)
|
||||
|
||||
</div>
|
||||
Browser use is the easiest way to connect your AI agents with the browser. If you have used Browser Use for your project feel free to show it off in our [Discord](https://link.browser-use.com/discord).
|
||||
|
||||
Let LLMs interact with websites through a simple interface.
|
||||
# Quick start
|
||||
|
||||
## Short Example
|
||||
With pip:
|
||||
|
||||
```bash
|
||||
pip install browser-use
|
||||
```
|
||||
|
||||
Spin up your agent:
|
||||
|
||||
```python
|
||||
from langchain_openai import ChatOpenAI
|
||||
from browser_use import Agent
|
||||
|
||||
agent = Agent(
|
||||
task="Go to hackernews on show hn and give me top 10 post titles, their points and hours. Calculate for each the ratio of points per hour.",
|
||||
task="Find a one-way flight from Bali to Oman on 12 January 2025 on Google Flights. Return me the cheapest option.",
|
||||
llm=ChatOpenAI(model="gpt-4o"),
|
||||
)
|
||||
|
||||
@@ -32,42 +32,94 @@ agent = Agent(
|
||||
await agent.run()
|
||||
```
|
||||
|
||||
## Demo
|
||||
And don't forget to add your API keys to your `.env` file.
|
||||
|
||||
<div>
|
||||
<a href="https://www.loom.com/share/63612b5994164cb1bb36938d62fe9983">
|
||||
<img style="max-width:300px;" src="https://cdn.loom.com/sessions/thumbnails/63612b5994164cb1bb36938d62fe9983-7133f9e169672e6f-full-play.gif">
|
||||
</a>
|
||||
<p><i>Prompt: Go to hackernews on show hn and give me top 10 post titles, their points and hours. Calculate for each the ratio of points per hour. (1x speed) </i></p>
|
||||
</div>
|
||||
<div>
|
||||
<a href="https://www.loom.com/share/2af938b9f8024647950a9e18b3946054">
|
||||
<img style="max-width:300px;" src="https://cdn.loom.com/sessions/thumbnails/2af938b9f8024647950a9e18b3946054-b99c733cf670e568-full-play.gif">
|
||||
</a>
|
||||
<p><i>Prompt: Search the top 3 AI companies 2024 and find what out what concrete hardware each is using for their model. (1x speed)</i></p>
|
||||
</div>
|
||||
|
||||
```bash
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
```
|
||||
|
||||
# Demo
|
||||
|
||||
<div style="display: flex; justify-content: space-between; margin-top: 20px;">
|
||||
<div style="flex: 1; margin-right: 10px;">
|
||||
<img style="width: 100%;" src="./static/kayak.gif" alt="Kayak flight search demo">
|
||||
<p><i>Prompt: Go to kayak.com and find a one-way flight from Zürich to San Francisco on 12 January 2025. (2.5x speed)</i></p>
|
||||
</div>
|
||||
<div style="flex: 1; margin-left: 10px;">
|
||||
<img style="width: 100%;" src="./static/photos.gif" alt="Photos search demo">
|
||||
<p><i>Prompt: Opening new tabs and searching for images for these people: Albert Einstein, Oprah Winfrey, Steve Jobs. (2.5x speed)</i></p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
DEMO VIDEO HERE
|
||||
|
||||
## Local Setup
|
||||
# Features ⭐
|
||||
|
||||
- Vision + html extraction
|
||||
- Automatic multi-tab management
|
||||
- Extract clicked elements XPaths
|
||||
- Add custom actions (e.g. add data to database which the LLM can use)
|
||||
- Self-correcting
|
||||
- Use any LLM supported by LangChain (e.g. gpt4o, gpt4o mini, claude 3.5 sonnet, llama 3.1 405b, etc.)
|
||||
|
||||
## Register custom actions
|
||||
|
||||
If you want to add custom actions your agent can take, you can register them like this:
|
||||
|
||||
```python
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.browser.service import Browser
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
# Initialize controller first
|
||||
controller = Controller()
|
||||
|
||||
@controller.action('Ask user for information')
|
||||
def ask_human(question: str, display_question: bool) -> str:
|
||||
return input(f'\n{question}\nInput: ')
|
||||
```
|
||||
|
||||
Or define your parameters using Pydantic
|
||||
|
||||
```python
|
||||
class JobDetails(BaseModel):
|
||||
title: str
|
||||
company: str
|
||||
job_link: str
|
||||
salary: Optional[str] = None
|
||||
|
||||
@controller.action('Save job details which you found on page', param_model=JobDetails, requires_browser=True)
|
||||
def save_job(params: JobDetails, browser: Browser):
|
||||
print(params)
|
||||
|
||||
# use the browser normally
|
||||
browser.driver.get(params.job_link)
|
||||
```
|
||||
|
||||
and then run your agent:
|
||||
|
||||
```python
|
||||
model = ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3)
|
||||
agent = Agent(task=task, llm=model, controller=controller)
|
||||
|
||||
await agent.run()
|
||||
```
|
||||
|
||||
## Get XPath history
|
||||
|
||||
To get the entire history of everything the agent has done, you can use the output of the `run` method:
|
||||
|
||||
```python
|
||||
history: list[AgentHistory] = await agent.run()
|
||||
|
||||
print(history)
|
||||
```
|
||||
|
||||
## More examples
|
||||
|
||||
For more examples see the [examples](examples) folder or join the [Discord](https://link.browser-use.com/discord) and show off your project.
|
||||
|
||||
# Contributing
|
||||
|
||||
Contributions are welcome! Feel free to open issues for bugs or feature requests.
|
||||
|
||||
## Setup
|
||||
|
||||
1. Create a virtual environment and install dependencies:
|
||||
|
||||
```bash
|
||||
# To install all dependencies including dev
|
||||
pip install . ."[dev]"
|
||||
pip install -r requirements.txt -r requirements-dev.txt
|
||||
```
|
||||
|
||||
2. Add your API keys to the `.env` file:
|
||||
@@ -76,119 +128,22 @@ pip install . ."[dev]"
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
E.g. for OpenAI:
|
||||
or copy the following to your `.env` file:
|
||||
|
||||
```bash
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
```
|
||||
|
||||
You can use any LLM model supported by LangChain by adding the appropriate environment variables. See [langchain models](https://python.langchain.com/docs/integrations/chat/) for available options.
|
||||
|
||||
## Features
|
||||
|
||||
- Universal LLM Support - Works with any Language Model
|
||||
- Interactive Element Detection - Automatically finds interactive elements
|
||||
- Multi-Tab Management - Seamless handling of browser tabs
|
||||
- XPath Extraction for scraping functions - No more manual DevTools inspection
|
||||
- Vision Model Support - Process visual page information
|
||||
- Customizable Actions - Add your own browser interactions (e.g. add data to database which the LLM can use)
|
||||
- Handles dynamic content - dont worry about cookies or changing content
|
||||
- Chain-of-thought prompting with memory - Solve long-term tasks
|
||||
- Self-correcting - If the LLM makes a mistake, the agent will self-correct its actions
|
||||
|
||||
## Advanced Examples
|
||||
|
||||
### Chain of Agents
|
||||
|
||||
You can persist the browser across multiple agents and chain them together.
|
||||
|
||||
```python
|
||||
from asyncio import run
|
||||
from browser_use import Agent, Controller
|
||||
from dotenv import load_dotenv
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
load_dotenv()
|
||||
|
||||
# Persist browser state across agents
|
||||
controller = Controller()
|
||||
|
||||
# Initialize browser agent
|
||||
agent1 = Agent(
|
||||
task="Open 3 VCs websites in the New York area.",
|
||||
llm=ChatAnthropic(model="claude-3-5-sonnet-20240620", timeout=25, stop=None),
|
||||
controller=controller)
|
||||
agent2 = Agent(
|
||||
task="Give me the names of the founders of the companies in all tabs.",
|
||||
llm=ChatAnthropic(model="claude-3-5-sonnet-20240620", timeout=25, stop=None),
|
||||
controller=controller)
|
||||
|
||||
run(agent1.run())
|
||||
founders, history = run(agent2.run())
|
||||
|
||||
print(founders)
|
||||
```
|
||||
|
||||
You can use the `history` to run the agents again deterministically.
|
||||
|
||||
## Command Line Usage
|
||||
|
||||
Run examples directly from the command line (clone the repo first):
|
||||
### Building the package
|
||||
|
||||
```bash
|
||||
python examples/try.py "Your query here" --provider [openai|anthropic]
|
||||
hatch build
|
||||
```
|
||||
|
||||
### Anthropic
|
||||
|
||||
You need to add `ANTHROPIC_API_KEY` to your environment variables. Example usage:
|
||||
|
||||
```bash
|
||||
|
||||
python examples/try.py "Search the top 3 AI companies 2024 and find out in 3 new tabs what hardware each is using for their models" --provider anthropic
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
You need to add `OPENAI_API_KEY` to your environment variables. Example usage:
|
||||
|
||||
```bash
|
||||
python examples/try.py "Go to hackernews on show hn and give me top 10 post titles, their points and hours. Calculate for each the ratio of points per hour. " --provider anthropic
|
||||
```
|
||||
|
||||
## 🤖 Supported Models
|
||||
|
||||
All LangChain chat models are supported. Tested with:
|
||||
|
||||
- GPT-4o
|
||||
- GPT-4o Mini
|
||||
- Claude 3.5 Sonnet
|
||||
- LLama 3.1 405B
|
||||
|
||||
## Limitations
|
||||
|
||||
- When extracting page content, the message length increases and the LLM gets slower.
|
||||
- Currently one agent costs about 0.01$
|
||||
- Sometimes it tries to repeat the same task over and over again.
|
||||
- Some elements might not be extracted which you want to interact with.
|
||||
- What should we focus on the most?
|
||||
- Robustness
|
||||
- Speed
|
||||
- Cost reduction
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] Save agent actions and execute them deterministically
|
||||
- [ ] Pydantic forced output
|
||||
- [ ] Third party SERP API for faster Google Search results
|
||||
- [ ] Multi-step action execution to increase speed
|
||||
- [ ] Test on mind2web dataset
|
||||
- [ ] Add more browser actions
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Feel free to open issues for bugs or feature requests.
|
||||
|
||||
Feel free to join the [Discord](https://discord.gg/uaCtrbbv) for discussions and support.
|
||||
Feel free to join the [Discord](https://link.browser-use.com/discord) for discussions and support.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from browser_use.agent.service import AgentService as Agent
|
||||
from browser_use.browser.service import BrowserService as Browser
|
||||
from browser_use.controller.service import ControllerService as Controller
|
||||
from browser_use.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
from browser_use.agent.service import Agent as Agent
|
||||
from browser_use.browser.service import Browser as Browser
|
||||
from browser_use.controller.service import Controller as Controller
|
||||
from browser_use.dom.service import DomService
|
||||
|
||||
__all__ = ['Agent', 'Browser', 'Controller', 'DomService']
|
||||
|
||||
@@ -1,12 +1,73 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from browser_use.controller.views import ControllerPageState
|
||||
from browser_use.browser.views import BrowserState
|
||||
|
||||
|
||||
class AgentSystemPrompt:
|
||||
def __init__(self, task: str, default_action_description: str):
|
||||
def __init__(self, task: str, action_description: str, current_date: datetime):
|
||||
self.task = task
|
||||
self.default_action_description = default_action_description
|
||||
self.default_action_description = action_description
|
||||
self.current_date = current_date
|
||||
|
||||
def response_format(self) -> str:
|
||||
"""
|
||||
Returns the response format for the agent.
|
||||
|
||||
Returns:
|
||||
str: Response format
|
||||
"""
|
||||
return """
|
||||
{{
|
||||
"current_state": {{
|
||||
"valuation_previous_goal": "String starting with "Success", "Failed:" or "Unknown" to evaluate if the previous next_goal is achieved. If failed or unknown describe why.",
|
||||
"memory": "Your memory with things you need to remeber until the end of the task for the user. You can also store overall progress in a bigger task. You have access to this in the next steps.",
|
||||
"next_goal": "String describing the next immediate goal which can be achieved with one action"
|
||||
}},
|
||||
"action": {{
|
||||
// EXACTLY ONE of the following available actions must be specified
|
||||
}}
|
||||
}}"""
|
||||
|
||||
def example_response(self) -> str:
|
||||
"""
|
||||
Returns an example response for the agent.
|
||||
|
||||
Returns:
|
||||
str: Example response
|
||||
"""
|
||||
return """{"current_state": {"valuation_previous_goal": "Success", "memory": "We applied already for 3/7 jobs, 1. ..., 2. ..., 3. ...", "next_goal": "Click on the button x to apply for the next job"}, "action": {"click_element": {"index": 44,"num_clicks": 2}}}"""
|
||||
|
||||
def important_rules(self) -> str:
|
||||
"""
|
||||
Returns the important rules for the agent.
|
||||
|
||||
Returns:
|
||||
str: Important rules
|
||||
"""
|
||||
return """
|
||||
1. Only use indexes that exist in the input list for click or input text actions. If no indexes exist, try alternative actions, e.g. go back, search google etc.
|
||||
2. If stuck, try alternative approaches, e.g. go back, search google, or extract_page_content
|
||||
3. When you are done with the complete task, use the done action. Make sure to have all information the user needs and return the result.
|
||||
4. If an image is provided, use it to understand the context, the bounding boxes around the buttons have the same indexes as the interactive elements.
|
||||
6. ALWAYS respond in the RESPONSE FORMAT with valid JSON.
|
||||
7. If the page is empty use actions like "go_to_url", "search_google" or "open_tab"
|
||||
8. Remember: Choose EXACTLY ONE action per response. Invalid combinations or multiple actions will be rejected.
|
||||
9. If popups like cookies appear, accept or close them
|
||||
"""
|
||||
|
||||
def input_format(self) -> str:
|
||||
return """
|
||||
Example:
|
||||
33[:]\t<button>Interactive element</button>
|
||||
_[:] Text content...
|
||||
|
||||
Explanation:
|
||||
index[:] Interactible element with index. You can only interact with all elements which are clickable and refer to them by their index.
|
||||
_[:] elements are just for more context, but not interactable.
|
||||
\t: Tab indent (1 tab for depth 1 etc.). This is to help you understand which elements belong to each other.
|
||||
"""
|
||||
|
||||
def get_system_message(self) -> SystemMessage:
|
||||
"""
|
||||
@@ -15,56 +76,31 @@ class AgentSystemPrompt:
|
||||
Returns:
|
||||
str: Formatted system prompt
|
||||
"""
|
||||
# System prompts for the agent
|
||||
# output_format = """
|
||||
# {"valuation_previous_goal": "Success if completed, else short sentence of why not successful.", "goal": "short description what you want to achieve", "action": "action_name", "params": {"param_name": "param_value"}}
|
||||
# """
|
||||
time_str = self.current_date.strftime('%Y-%m-%d %H:%M')
|
||||
|
||||
AGENT_PROMPT = f"""
|
||||
You are an AI agent that helps users interact with websites.
|
||||
Your input are all the interactive elements with its context of the current page from.
|
||||
|
||||
This is how an input looks like:
|
||||
33:\t<button>Clickable element</button>
|
||||
_: Not clickable, only for your context
|
||||
\t: Tab indent (1 tab for depth 1 etc.). This is to help you understand which elements belong to each other.
|
||||
You are an AI agent that helps users interact with websites. You receive a list of interactive elements from the current webpage and must respond with specific actions. Today's date is {time_str}.
|
||||
|
||||
INPUT FORMAT:
|
||||
{self.input_format()}
|
||||
|
||||
In the beginning the list will be empty.
|
||||
On elements with _ you can not click.
|
||||
On elements with a index you can click.
|
||||
|
||||
Additional you get a list of your previous actions.
|
||||
You have to respond in the following RESPONSE FORMAT:
|
||||
{self.response_format()}
|
||||
|
||||
|
||||
Respond with a valid JSON object, containing 2 keys: current_state and action.
|
||||
In current_state there are 3 keys:
|
||||
valuation_previous_goal: Evaluate if the previous goal was achieved or what went wrong. E.g. Failed because ...
|
||||
memory: This you can use as a memory to store where you are in your overall task. E.g. if you need to find 10 jobs, you can store the already found jobs here.
|
||||
next_goal: The next goal you want to achieve e.g. clicking on ... button to ...
|
||||
Your AVAILABLE ACTIONS:
|
||||
{self.default_action_description}
|
||||
|
||||
For the action choose EXACTLY ONE from the following list:
|
||||
{self.default_action_description}
|
||||
To interact with elements, use their index number from the input line. Write it in the click_element() or input_text() actions.
|
||||
Make sure to only use indexes that are present in the list.
|
||||
If you need more text from the page you can use the extract_page_content action.
|
||||
Example:
|
||||
{self.example_response()}
|
||||
|
||||
If you see any cookie or accept privacy policy please always just accepted them without hesitation.
|
||||
|
||||
If you evaluate repeatedly that you dont achieve the next_goal, try to find a new element that can help you achieve your task or if persistent, go back or reload the page and try a different approach.
|
||||
|
||||
You can ask_human for clarification if you are completely stuck or if you really need more information.
|
||||
|
||||
If a picture is provided, use it to understand the context and the next action.
|
||||
|
||||
If you are sure you are done you can extract_page_content to get the markdown content and in the next action call done() with the text of the requested result to end the task and wait for further instructions.
|
||||
|
||||
"""
|
||||
IMPORTANT RULES:
|
||||
{self.important_rules()}
|
||||
"""
|
||||
return SystemMessage(content=AGENT_PROMPT)
|
||||
|
||||
|
||||
class AgentMessagePrompt:
|
||||
def __init__(self, state: ControllerPageState):
|
||||
def __init__(self, state: BrowserState):
|
||||
self.state = state
|
||||
|
||||
def get_user_message(self) -> HumanMessage:
|
||||
@@ -91,4 +127,4 @@ Interactive elements:
|
||||
return HumanMessage(content=state_description)
|
||||
|
||||
def get_message_for_history(self) -> HumanMessage:
|
||||
return HumanMessage(content=f'Currently on url: {self.state.url}')
|
||||
return HumanMessage(content=f'Step url: {self.state.url}')
|
||||
|
||||
@@ -1,216 +1,386 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from openai import RateLimitError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from browser_use.agent.prompts import AgentMessagePrompt, AgentSystemPrompt
|
||||
from browser_use.agent.views import (
|
||||
ActionResult,
|
||||
AgentError,
|
||||
AgentHistory,
|
||||
AgentOutput,
|
||||
ClickElementControllerHistoryItem,
|
||||
InputTextControllerHistoryItem,
|
||||
Output,
|
||||
)
|
||||
from browser_use.controller.service import ControllerService
|
||||
from browser_use.controller.views import (
|
||||
ControllerActionResult,
|
||||
ControllerPageState,
|
||||
ModelPricingCatalog,
|
||||
Pricing,
|
||||
TokenDetails,
|
||||
TokenUsage,
|
||||
)
|
||||
from browser_use.browser.views import BrowserState
|
||||
from browser_use.controller.service import Controller
|
||||
from browser_use.utils import time_execution_async
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
force=True, # Prevent changing root logger config
|
||||
)
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
|
||||
class AgentService:
|
||||
class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
llm: BaseChatModel,
|
||||
controller: ControllerService | None = None,
|
||||
controller: Optional[Controller] = None,
|
||||
use_vision: bool = True,
|
||||
save_conversation_path: str | None = None,
|
||||
save_conversation_path: Optional[str] = None,
|
||||
max_failures: int = 5,
|
||||
retry_delay: int = 10,
|
||||
):
|
||||
"""
|
||||
Agent service.
|
||||
|
||||
Args:
|
||||
task (str): Task to be performed.
|
||||
llm (AvailableModel): Model to be used.
|
||||
controller (ControllerService | None): You can reuse an existing or (automatically) create a new one.
|
||||
"""
|
||||
self.task = task
|
||||
self.use_vision = use_vision
|
||||
|
||||
self.controller_injected = controller is not None
|
||||
self.controller = controller or ControllerService()
|
||||
|
||||
self.llm = llm
|
||||
system_prompt = AgentSystemPrompt(
|
||||
task, default_action_description=self._get_action_description()
|
||||
).get_system_message()
|
||||
|
||||
# Init messages
|
||||
first_message = HumanMessage(content=f'Your main task is: {task}')
|
||||
self.messages: list[BaseMessage] = [system_prompt, first_message]
|
||||
self.n = 0
|
||||
|
||||
self.save_conversation_path = save_conversation_path
|
||||
if save_conversation_path is not None:
|
||||
|
||||
# Controller setup
|
||||
self.controller_injected = controller is not None
|
||||
self.controller = controller or Controller()
|
||||
|
||||
# Action and output models setup
|
||||
self._setup_action_models()
|
||||
|
||||
# Message history setup
|
||||
self.messages = self._initialize_messages()
|
||||
|
||||
# Tracking variables
|
||||
self.history: list[AgentHistory] = []
|
||||
self.n_steps = 1
|
||||
self.consecutive_failures = 0
|
||||
self.max_failures = max_failures
|
||||
self.retry_delay = retry_delay
|
||||
|
||||
if save_conversation_path:
|
||||
logger.info(f'Saving conversation to {save_conversation_path}')
|
||||
|
||||
self.action_history: list[AgentHistory] = []
|
||||
self.usage_metadata = TokenUsage(
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
total_tokens=0,
|
||||
input_token_details=TokenDetails(),
|
||||
output_token_details=TokenDetails(),
|
||||
)
|
||||
|
||||
async def run(self, max_steps: int = 100):
|
||||
"""
|
||||
Execute the task.
|
||||
def _setup_action_models(self) -> None:
|
||||
"""Setup dynamic action models from controller's registry"""
|
||||
# Get the dynamic action model from controller's registry
|
||||
self.ActionModel = self.controller.registry.create_action_model()
|
||||
# Create output model with the dynamic actions
|
||||
self.AgentOutput = AgentOutput.type_with_custom_actions(self.ActionModel)
|
||||
|
||||
@dev ctrl+c to interrupt
|
||||
"""
|
||||
def _initialize_messages(self) -> list[BaseMessage]:
|
||||
"""Initialize message history with system and first message"""
|
||||
# Get action descriptions from controller's registry
|
||||
action_descriptions = self.controller.registry.get_prompt_description()
|
||||
|
||||
system_prompt = AgentSystemPrompt(
|
||||
self.task, action_description=action_descriptions, current_date=datetime.now()
|
||||
).get_system_message()
|
||||
|
||||
first_message = HumanMessage(content=f'Your task is: {self.task}')
|
||||
return [system_prompt, first_message]
|
||||
|
||||
@time_execution_async('--step')
|
||||
async def step(self) -> None:
|
||||
"""Execute one step of the task"""
|
||||
logger.info(f'\n📍 Step {self.n_steps}')
|
||||
state = self.controller.browser.get_state(use_vision=self.use_vision)
|
||||
|
||||
try:
|
||||
logger.info('\n' + '=' * 50)
|
||||
model_output = await self.get_next_action(state)
|
||||
result = self.controller.act(model_output.action)
|
||||
if result.extracted_content:
|
||||
logger.info(f'📄 Result: {result.extracted_content}')
|
||||
self.consecutive_failures = 0
|
||||
|
||||
except Exception as e:
|
||||
result = self._handle_step_error(e, state)
|
||||
model_output = None
|
||||
self._update_messages_with_result(result)
|
||||
self._make_history_item(model_output, state, result)
|
||||
|
||||
def _handle_step_error(self, error: Exception, state: BrowserState) -> ActionResult:
|
||||
"""Handle all types of errors that can occur during a step"""
|
||||
error_msg = AgentError.format_error(error)
|
||||
prefix = f'❌ Result failed {self.consecutive_failures + 1}/{self.max_failures} times:\n '
|
||||
|
||||
if isinstance(error, (ValidationError, ValueError)):
|
||||
logger.error(f'{prefix}{error_msg}')
|
||||
self.consecutive_failures += 1
|
||||
elif isinstance(error, RateLimitError):
|
||||
logger.warning(f'{prefix}{error_msg}')
|
||||
time.sleep(self.retry_delay)
|
||||
self.consecutive_failures += 1
|
||||
else:
|
||||
logger.error(f'{prefix}{error_msg}')
|
||||
self.consecutive_failures += 1
|
||||
|
||||
return ActionResult(error=error_msg)
|
||||
|
||||
def _update_messages_with_result(self, result: ActionResult) -> None:
|
||||
"""Update message history with action results"""
|
||||
if result.extracted_content:
|
||||
self.messages.append(HumanMessage(content=result.extracted_content))
|
||||
if result.error:
|
||||
self.messages.append(HumanMessage(content=result.error))
|
||||
|
||||
def _make_history_item(
|
||||
self,
|
||||
model_output: AgentOutput | None,
|
||||
state: BrowserState,
|
||||
result: ActionResult,
|
||||
) -> None:
|
||||
"""Create and store history item"""
|
||||
history_item = AgentHistory(model_output=model_output, result=result, state=state)
|
||||
self.history.append(history_item)
|
||||
|
||||
@time_execution_async('--get_next_action')
|
||||
async def get_next_action(self, state: BrowserState) -> AgentOutput:
|
||||
"""Get next action from LLM based on current state"""
|
||||
new_message = AgentMessagePrompt(state).get_user_message()
|
||||
input_messages = self.messages + [new_message]
|
||||
|
||||
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
|
||||
response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
|
||||
|
||||
parsed: AgentOutput = response['parsed']
|
||||
|
||||
self._update_message_history(state, parsed)
|
||||
self._log_response(parsed)
|
||||
self._save_conversation(input_messages, parsed)
|
||||
self._update_usage_metadata(response['raw'])
|
||||
return parsed
|
||||
|
||||
def _calc_token_cost(self) -> float:
|
||||
"""
|
||||
Calculate the cost of tokens used in a request based on the model.
|
||||
|
||||
:param usage_metadata: TokenUsage model containing token usage details.
|
||||
:param model_name: The name of the model used.
|
||||
:return: Cost of the tokens used.
|
||||
"""
|
||||
if isinstance(self.llm, ChatOpenAI):
|
||||
model_name = self.llm.model_name
|
||||
elif isinstance(self.llm, ChatAnthropic):
|
||||
model_name = self.llm.model
|
||||
else:
|
||||
logger.debug('Model name not supported for pricing calculation')
|
||||
return 0
|
||||
|
||||
pricing_catalog = ModelPricingCatalog()
|
||||
|
||||
if model_name == 'gpt-4o':
|
||||
model_pricing = pricing_catalog.gpt_4o
|
||||
elif model_name == 'gpt-4o-mini':
|
||||
model_pricing = pricing_catalog.gpt_4o_mini
|
||||
elif model_name == 'claude-3-5-sonnet-20240620':
|
||||
model_pricing = pricing_catalog.claude_3_5_sonnet
|
||||
else:
|
||||
logger.debug(f'Unsupported model: {model_name}')
|
||||
return 0
|
||||
|
||||
uncached_input_tokens = (
|
||||
self.usage_metadata.input_tokens - self.usage_metadata.input_token_details.cache_read
|
||||
)
|
||||
factor = 1e6
|
||||
cost = (
|
||||
(uncached_input_tokens / factor) * model_pricing.uncached_input
|
||||
+ (self.usage_metadata.input_token_details.cache_read / factor)
|
||||
* model_pricing.cached_input
|
||||
+ (self.usage_metadata.output_tokens / factor) * model_pricing.output
|
||||
)
|
||||
return cost
|
||||
|
||||
def _update_usage_metadata(self, raw_response: AIMessage) -> None:
|
||||
"""
|
||||
Process the response and update usage.
|
||||
|
||||
:param raw_response: The response object containing usage metadata.
|
||||
"""
|
||||
# only supported for openai models for now
|
||||
if isinstance(self.llm, ChatAnthropic):
|
||||
token_usage_data: dict[str, Any] = raw_response.response_metadata['usage']
|
||||
usage_metadata = TokenUsage(
|
||||
input_tokens=token_usage_data.get('input_tokens', 0),
|
||||
output_tokens=token_usage_data.get('output_tokens', 0),
|
||||
total_tokens=token_usage_data.get('input_tokens', 0)
|
||||
+ token_usage_data.get('output_tokens', 0),
|
||||
)
|
||||
self.usage_metadata.input_tokens += usage_metadata.input_tokens
|
||||
self.usage_metadata.output_tokens += usage_metadata.output_tokens
|
||||
self.usage_metadata.total_tokens += usage_metadata.total_tokens
|
||||
|
||||
elif isinstance(self.llm, ChatOpenAI):
|
||||
token_usage_data: dict[str, Any] = raw_response.response_metadata['token_usage']
|
||||
usage_metadata = TokenUsage(
|
||||
input_tokens=token_usage_data.get('prompt_tokens', 0),
|
||||
output_tokens=token_usage_data.get('completion_tokens', 0),
|
||||
total_tokens=token_usage_data.get('total_tokens', 0),
|
||||
input_token_details=TokenDetails(
|
||||
audio=token_usage_data.get('prompt_tokens_details', {}).get('audio_tokens', 0),
|
||||
cache_read=token_usage_data.get('prompt_tokens_details', {}).get(
|
||||
'cached_tokens', 0
|
||||
),
|
||||
reasoning=0, # Assuming reasoning is not part of prompt_tokens_details
|
||||
),
|
||||
output_token_details=TokenDetails(
|
||||
audio=token_usage_data.get('completion_tokens_details', {}).get(
|
||||
'audio_tokens', 0
|
||||
),
|
||||
cache_read=0, # Assuming cache_read is not part of completion_tokens_details
|
||||
reasoning=token_usage_data.get('completion_tokens_details', {}).get(
|
||||
'reasoning_tokens', 0
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
self.usage_metadata.input_tokens += usage_metadata.input_tokens
|
||||
self.usage_metadata.output_tokens += usage_metadata.output_tokens
|
||||
self.usage_metadata.total_tokens += usage_metadata.total_tokens
|
||||
|
||||
# update usage metadata
|
||||
if usage_metadata.input_token_details:
|
||||
for detail_key in usage_metadata.input_token_details.model_dump():
|
||||
setattr(
|
||||
self.usage_metadata.input_token_details,
|
||||
detail_key,
|
||||
getattr(self.usage_metadata.input_token_details, detail_key)
|
||||
+ getattr(usage_metadata.input_token_details, detail_key),
|
||||
)
|
||||
|
||||
if usage_metadata.output_token_details:
|
||||
for detail_key in usage_metadata.output_token_details.model_dump():
|
||||
setattr(
|
||||
self.usage_metadata.output_token_details,
|
||||
detail_key,
|
||||
getattr(self.usage_metadata.output_token_details, detail_key)
|
||||
+ getattr(usage_metadata.output_token_details, detail_key),
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug('Model name not supported for pricing calculation')
|
||||
return
|
||||
|
||||
self._log_usage_metadata(usage_metadata)
|
||||
|
||||
def _log_usage_metadata(self, current_tokens: Optional[TokenUsage] = None) -> None:
|
||||
"""Log the usage metadata"""
|
||||
total_cost = self._calc_token_cost()
|
||||
total_tokens = self.usage_metadata.total_tokens
|
||||
logger.debug(
|
||||
f'🔢 Total Tokens: input: {self.usage_metadata.input_tokens} (cached: {self.usage_metadata.input_token_details.cache_read}) + output: {self.usage_metadata.output_tokens} = {total_tokens} = ${total_cost:.4f} 💰'
|
||||
)
|
||||
|
||||
if current_tokens:
|
||||
logger.debug(
|
||||
f'🔢 Last Tokens: input: {current_tokens.input_tokens} (cached: {current_tokens.input_token_details.cache_read}) + output: {current_tokens.output_tokens} = {current_tokens.total_tokens} '
|
||||
)
|
||||
|
||||
def _update_message_history(self, state: BrowserState, response: Any) -> None:
|
||||
"""Update message history with new interactions"""
|
||||
history_message = AgentMessagePrompt(state).get_message_for_history()
|
||||
self.messages.append(history_message)
|
||||
self.messages.append(AIMessage(content=response.model_dump_json(exclude_unset=True)))
|
||||
self.n_steps += 1
|
||||
|
||||
def _log_response(self, response: Any) -> None:
|
||||
"""Log the model's response"""
|
||||
if 'Success' in response.current_state.valuation_previous_goal:
|
||||
emoji = '👍'
|
||||
elif 'Failed' in response.current_state.valuation_previous_goal:
|
||||
emoji = '⚠️'
|
||||
else:
|
||||
emoji = '🤷'
|
||||
|
||||
logger.info(f'{emoji} Evaluation: {response.current_state.valuation_previous_goal}')
|
||||
logger.info(f'🧠 Memory: {response.current_state.memory}')
|
||||
logger.info(f'🎯 Next Goal: {response.current_state.next_goal}')
|
||||
logger.info(f'🛠️ Action: {response.action.model_dump_json(exclude_unset=True)}')
|
||||
|
||||
def _save_conversation(self, input_messages: list[BaseMessage], response: Any) -> None:
|
||||
"""Save conversation history to file if path is specified"""
|
||||
if not self.save_conversation_path:
|
||||
return
|
||||
|
||||
# create folders if not exists
|
||||
os.makedirs(os.path.dirname(self.save_conversation_path), exist_ok=True)
|
||||
|
||||
with open(self.save_conversation_path + f'_{self.n_steps}.txt', 'w') as f:
|
||||
self._write_messages_to_file(f, input_messages)
|
||||
self._write_response_to_file(f, response)
|
||||
|
||||
def _write_messages_to_file(self, f: Any, messages: list[BaseMessage]) -> None:
|
||||
"""Write messages to conversation file"""
|
||||
for message in messages:
|
||||
f.write(f' {message.__class__.__name__} \n')
|
||||
|
||||
if isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and item.get('type') == 'text':
|
||||
f.write(item['text'].strip() + '\n')
|
||||
elif isinstance(message.content, str):
|
||||
try:
|
||||
content = json.loads(message.content)
|
||||
f.write(json.dumps(content, indent=2) + '\n')
|
||||
except json.JSONDecodeError:
|
||||
f.write(message.content.strip() + '\n')
|
||||
|
||||
f.write('\n')
|
||||
|
||||
def _write_response_to_file(self, f: Any, response: Any) -> None:
|
||||
"""Write model response to conversation file"""
|
||||
f.write(' RESPONSE\n')
|
||||
f.write(json.dumps(json.loads(response.model_dump_json(exclude_unset=True)), indent=2))
|
||||
|
||||
async def run(self, max_steps: int = 100) -> list[AgentHistory]:
|
||||
"""Execute the task with maximum number of steps"""
|
||||
try:
|
||||
logger.info(f'🚀 Starting task: {self.task}')
|
||||
logger.info('=' * 50)
|
||||
|
||||
for i in range(max_steps):
|
||||
logger.info(f'\n📍 Step {i+1}')
|
||||
for step in range(max_steps):
|
||||
if self._too_many_failures():
|
||||
break
|
||||
|
||||
action, result = await self.step()
|
||||
await self.step()
|
||||
|
||||
if result.done:
|
||||
logger.info('\n✅ Task completed successfully')
|
||||
logger.info(f'Extracted content: \n{result.extracted_content}')
|
||||
return action.done, self.action_history
|
||||
if self._is_task_complete():
|
||||
logger.info('✅ Task completed successfully')
|
||||
break
|
||||
else:
|
||||
logger.info('❌ Failed to complete task in maximum steps')
|
||||
|
||||
return self.history
|
||||
|
||||
logger.info('\n' + '=' * 50)
|
||||
logger.info('❌ Failed to complete task in maximum steps')
|
||||
logger.info('=' * 50)
|
||||
return None, self.action_history
|
||||
finally:
|
||||
if not self.controller_injected:
|
||||
self.controller.browser.close()
|
||||
|
||||
@time_execution_async('--step')
|
||||
async def step(self) -> tuple[AgentHistory, ControllerActionResult]:
|
||||
state = self.controller.get_current_state(screenshot=self.use_vision)
|
||||
action = await self.get_next_action(state)
|
||||
def _too_many_failures(self) -> bool:
|
||||
"""Check if we should stop due to too many failures"""
|
||||
if self.consecutive_failures >= self.max_failures:
|
||||
logger.error(f'❌ Stopping due to {self.max_failures} consecutive failures')
|
||||
return True
|
||||
return False
|
||||
|
||||
if action.ask_human and action.ask_human.question:
|
||||
action = await self._take_human_input(action.ask_human.question)
|
||||
|
||||
result = self.controller.act(action)
|
||||
self.n += 1
|
||||
|
||||
if result.error:
|
||||
self.messages.append(HumanMessage(content=f'Error: {result.error}'))
|
||||
|
||||
if result.extracted_content:
|
||||
self.messages.append(
|
||||
HumanMessage(content=f'Extracted content:\n {result.extracted_content}')
|
||||
)
|
||||
|
||||
# Convert action to history and update click/input fields if present
|
||||
history_item = self._make_history_item(action, state)
|
||||
self.action_history.append(history_item)
|
||||
|
||||
return history_item, result
|
||||
|
||||
def _make_history_item(self, action: AgentOutput, state: ControllerPageState) -> AgentHistory:
|
||||
return AgentHistory(
|
||||
search_google=action.search_google,
|
||||
go_to_url=action.go_to_url,
|
||||
nothing=action.nothing,
|
||||
go_back=action.go_back,
|
||||
done=action.done,
|
||||
click_element=ClickElementControllerHistoryItem(
|
||||
id=action.click_element.id, xpath=state.selector_map.get(action.click_element.id)
|
||||
)
|
||||
if action.click_element and state.selector_map.get(action.click_element.id)
|
||||
else None,
|
||||
input_text=InputTextControllerHistoryItem(
|
||||
id=action.input_text.id,
|
||||
xpath=state.selector_map.get(action.input_text.id),
|
||||
text=action.input_text.text,
|
||||
)
|
||||
if action.input_text and state.selector_map.get(action.input_text.id)
|
||||
else None,
|
||||
extract_page_content=action.extract_page_content,
|
||||
switch_tab=action.switch_tab,
|
||||
open_tab=action.open_tab,
|
||||
ask_human=action.ask_human,
|
||||
url=state.url,
|
||||
)
|
||||
|
||||
async def _take_human_input(self, question: str) -> AgentOutput:
|
||||
human_input = input(f'\nHi, your input is required: {question}\n\n')
|
||||
logger.info('-' * 50)
|
||||
self.messages.append(HumanMessage(content=human_input))
|
||||
|
||||
structured_llm = self.llm.with_structured_output(AgentOutput)
|
||||
|
||||
action: AgentOutput = await structured_llm.ainvoke(self.messages) # type: ignore
|
||||
|
||||
self.messages.append(AIMessage(content=action.model_dump_json()))
|
||||
|
||||
return action
|
||||
|
||||
@time_execution_async('--get_next_action')
|
||||
async def get_next_action(self, state: ControllerPageState) -> AgentOutput:
|
||||
# TODO: include state, actions, etc.
|
||||
|
||||
new_message = AgentMessagePrompt(state).get_user_message()
|
||||
logger.info(f'current tabs: {state.tabs}')
|
||||
input_messages = self.messages + [new_message]
|
||||
|
||||
structured_llm = self.llm.with_structured_output(Output, include_raw=False)
|
||||
|
||||
#
|
||||
response: Output = await structured_llm.ainvoke(input_messages) # type: ignore
|
||||
|
||||
# Only append the output message
|
||||
history_new_message = AgentMessagePrompt(state).get_message_for_history()
|
||||
self.messages.append(history_new_message)
|
||||
self.messages.append(AIMessage(content=response.model_dump_json()))
|
||||
logger.info(f'current state\n: {response.current_state.model_dump_json(indent=4)}')
|
||||
logger.info(f'action\n: {response.action.model_dump_json(indent=4)}')
|
||||
self._save_conversation(input_messages, response)
|
||||
|
||||
return response.action
|
||||
|
||||
def _get_action_description(self) -> str:
|
||||
return AgentOutput.description()
|
||||
|
||||
def _save_conversation(self, input_messages: list[BaseMessage], response: Output):
|
||||
if self.save_conversation_path is not None:
|
||||
with open(self.save_conversation_path + f'_{self.n}.txt', 'w') as f:
|
||||
# Write messages with proper formatting
|
||||
for message in input_messages:
|
||||
f.write('=' * 33 + f' {message.__class__.__name__} ' + '=' * 33 + '\n\n')
|
||||
|
||||
# Handle different content types
|
||||
if isinstance(message.content, list):
|
||||
# Handle vision model messages
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and item.get('type') == 'text':
|
||||
f.write(item['text'].strip() + '\n')
|
||||
elif isinstance(message.content, str):
|
||||
try:
|
||||
# Try to parse and format JSON content
|
||||
content = json.loads(message.content)
|
||||
f.write(json.dumps(content, indent=2) + '\n')
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, write as regular text
|
||||
f.write(message.content.strip() + '\n')
|
||||
|
||||
f.write('\n')
|
||||
|
||||
# Write final response as formatted JSON
|
||||
f.write('=' * 33 + ' Response ' + '=' * 33 + '\n\n')
|
||||
f.write(json.dumps(json.loads(response.model_dump_json()), indent=2))
|
||||
def _is_task_complete(self) -> bool:
|
||||
"""Check if the task has been completed successfully"""
|
||||
return bool(self.history and self.history[-1].result.is_done)
|
||||
|
||||
@@ -1,55 +1,107 @@
|
||||
from typing import Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Type
|
||||
|
||||
from browser_use.controller.views import (
|
||||
ClickElementControllerAction,
|
||||
ControllerActions,
|
||||
InputTextControllerAction,
|
||||
)
|
||||
from openai import RateLimitError
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model
|
||||
|
||||
from browser_use.browser.views import BrowserState
|
||||
from browser_use.controller.registry.views import ActionModel
|
||||
|
||||
|
||||
class AskHumanAgentAction(BaseModel):
|
||||
question: str
|
||||
class TokenDetails(BaseModel):
|
||||
audio: int = 0
|
||||
cache_read: int = 0
|
||||
reasoning: int = 0
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
class TokenUsage(BaseModel):
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
total_tokens: int
|
||||
input_token_details: TokenDetails = Field(default=TokenDetails())
|
||||
output_token_details: TokenDetails = Field(default=TokenDetails())
|
||||
|
||||
# allow arbitrary types
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class Pricing(BaseModel):
|
||||
uncached_input: float # per 1M tokens
|
||||
cached_input: float
|
||||
output: float
|
||||
|
||||
|
||||
class ModelPricingCatalog(BaseModel):
|
||||
gpt_4o: Pricing = Field(default=Pricing(uncached_input=2.50, cached_input=1.25, output=10.00))
|
||||
gpt_4o_mini: Pricing = Field(
|
||||
default=Pricing(uncached_input=0.15, cached_input=0.075, output=0.60)
|
||||
)
|
||||
claude_3_5_sonnet: Pricing = Field(
|
||||
default=Pricing(uncached_input=3.00, cached_input=1.50, output=15.00)
|
||||
)
|
||||
|
||||
|
||||
class ActionResult(BaseModel):
|
||||
"""Result of executing an action"""
|
||||
|
||||
is_done: Optional[bool] = False
|
||||
extracted_content: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class AgentBrain(BaseModel):
|
||||
"""Current state of the agent"""
|
||||
|
||||
valuation_previous_goal: str
|
||||
memory: str
|
||||
next_goal: str
|
||||
|
||||
|
||||
class AgentOnlyAction(BaseModel):
|
||||
ask_human: Optional[AskHumanAgentAction] = None
|
||||
class AgentOutput(BaseModel):
|
||||
"""Output model for agent
|
||||
|
||||
@dev note: this model is extended with custom actions in AgentService. You can also use some fields that are not in this model as provided by the linter, as long as they are registered in the DynamicActions model.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
current_state: AgentBrain
|
||||
action: ActionModel
|
||||
|
||||
@staticmethod
|
||||
def description() -> str:
|
||||
return """
|
||||
- Ask human for help
|
||||
Example: {"ask_human": {"question": "To clarify ..."}}"""
|
||||
def type_with_custom_actions(custom_actions: Type[ActionModel]) -> Type['AgentOutput']:
|
||||
"""Extend actions with custom actions"""
|
||||
return create_model(
|
||||
'AgentOutput',
|
||||
__base__=AgentOutput,
|
||||
action=(custom_actions, Field(...)), # Properly annotated field with no default
|
||||
__module__=AgentOutput.__module__,
|
||||
)
|
||||
|
||||
|
||||
class AgentOutput(ControllerActions, AgentOnlyAction):
|
||||
class AgentHistory(BaseModel):
|
||||
"""History item for agent actions"""
|
||||
|
||||
model_output: AgentOutput | None
|
||||
result: ActionResult
|
||||
state: BrowserState
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|
||||
|
||||
|
||||
class AgentError:
|
||||
"""Container for agent error handling"""
|
||||
|
||||
VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
|
||||
RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
|
||||
NO_VALID_ACTION = 'No valid action found'
|
||||
|
||||
@staticmethod
|
||||
def description() -> str:
|
||||
return AgentOnlyAction.description() + ControllerActions.description()
|
||||
#
|
||||
|
||||
|
||||
class Output(BaseModel):
|
||||
current_state: AgentState
|
||||
action: AgentOutput
|
||||
|
||||
|
||||
class ClickElementControllerHistoryItem(ClickElementControllerAction):
|
||||
xpath: str | None
|
||||
|
||||
|
||||
class InputTextControllerHistoryItem(InputTextControllerAction):
|
||||
xpath: str | None
|
||||
|
||||
|
||||
class AgentHistory(AgentOutput):
|
||||
click_element: Optional[ClickElementControllerHistoryItem] = None
|
||||
input_text: Optional[InputTextControllerHistoryItem] = None
|
||||
url: str
|
||||
def format_error(error: Exception) -> str:
|
||||
"""Format error message based on error type"""
|
||||
if isinstance(error, ValidationError):
|
||||
return f'{AgentError.VALIDATION_ERROR}\nDetails: {str(error)}'
|
||||
if isinstance(error, RateLimitError):
|
||||
return AgentError.RATE_LIMIT_ERROR
|
||||
return f'Unexpected error: {str(error)}'
|
||||
|
||||
@@ -9,104 +9,114 @@ import tempfile
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from main_content_extractor import MainContentExtractor
|
||||
from Screenshot import Screenshot
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from selenium.webdriver.chrome.service import Service as ChromeService
|
||||
from selenium.webdriver.common.action_chains import ActionChains
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
|
||||
from browser_use.browser.views import BrowserState
|
||||
from browser_use.browser.views import BrowserState, TabInfo
|
||||
from browser_use.dom.service import DomService
|
||||
from browser_use.dom.views import SelectorMap
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class BrowserService:
|
||||
class Browser:
|
||||
def __init__(self, headless: bool = False, keep_open: bool = False):
|
||||
self.headless = headless
|
||||
self.driver: webdriver.Chrome | None = None
|
||||
self._ob = Screenshot.Screenshot()
|
||||
self.keep_open = keep_open
|
||||
self.MINIMUM_WAIT_TIME = 0.5
|
||||
self.MAXIMUM_WAIT_TIME = 5
|
||||
self._current_handle = None # Track current handle
|
||||
self._tab_cache = {}
|
||||
self.keep_open = keep_open
|
||||
self._tab_cache: dict[str, TabInfo] = {}
|
||||
self._current_handle = None
|
||||
self._ob = Screenshot.Screenshot()
|
||||
|
||||
def init(self) -> webdriver.Chrome:
|
||||
"""
|
||||
Sets up and returns a Selenium WebDriver instance with anti-detection measures.
|
||||
# Initialize driver during construction
|
||||
self.driver: webdriver.Chrome | None = self._setup_webdriver()
|
||||
self._cached_state = self._update_state()
|
||||
|
||||
Returns:
|
||||
webdriver.Chrome: Configured Chrome WebDriver instance
|
||||
"""
|
||||
chrome_options = Options()
|
||||
if self.headless:
|
||||
chrome_options.add_argument('--headless')
|
||||
def _setup_webdriver(self) -> webdriver.Chrome:
|
||||
"""Sets up and returns a Selenium WebDriver instance with anti-detection measures."""
|
||||
try:
|
||||
# if webdriver is not starting, try to kill it or rm -rf ~/.wdm
|
||||
chrome_options = Options()
|
||||
if self.headless:
|
||||
chrome_options.add_argument('--headless=new') # Updated headless argument
|
||||
|
||||
# Anti-detection measures
|
||||
chrome_options.add_argument('--disable-blink-features=AutomationControlled')
|
||||
chrome_options.add_experimental_option('excludeSwitches', ['enable-automation'])
|
||||
chrome_options.add_experimental_option('useAutomationExtension', False)
|
||||
# Essential automation and performance settings
|
||||
chrome_options.add_argument('--disable-blink-features=AutomationControlled')
|
||||
chrome_options.add_experimental_option('excludeSwitches', ['enable-automation'])
|
||||
chrome_options.add_experimental_option('useAutomationExtension', False)
|
||||
chrome_options.add_argument('--no-sandbox')
|
||||
chrome_options.add_argument('--window-size=1280,1024')
|
||||
chrome_options.add_argument('--disable-extensions')
|
||||
|
||||
# Additional stealth settings
|
||||
# chrome_options.add_argument('--start-maximized')
|
||||
chrome_options.add_argument('--window-size=1280,1024')
|
||||
chrome_options.add_argument('--disable-extensions')
|
||||
chrome_options.add_argument('--no-sandbox')
|
||||
chrome_options.add_argument('--disable-infobars')
|
||||
# Background process optimization
|
||||
chrome_options.add_argument('--disable-background-timer-throttling')
|
||||
chrome_options.add_argument('--disable-popup-blocking')
|
||||
|
||||
# Initialize the Chrome driver
|
||||
driver = webdriver.Chrome(
|
||||
service=Service(ChromeDriverManager().install()), options=chrome_options
|
||||
)
|
||||
# Additional stealth settings
|
||||
chrome_options.add_argument('--disable-infobars')
|
||||
# Much better when working in non-headless mode
|
||||
chrome_options.add_argument('--disable-backgrounding-occluded-windows')
|
||||
chrome_options.add_argument('--disable-renderer-backgrounding')
|
||||
|
||||
# Execute stealth scripts
|
||||
driver.execute_cdp_cmd(
|
||||
'Page.addScriptToEvaluateOnNewDocument',
|
||||
{
|
||||
'source': """
|
||||
Object.defineProperty(navigator, 'webdriver', {
|
||||
get: () => undefined
|
||||
});
|
||||
|
||||
Object.defineProperty(navigator, 'languages', {
|
||||
get: () => ['en-US', 'en']
|
||||
});
|
||||
|
||||
Object.defineProperty(navigator, 'plugins', {
|
||||
get: () => [1, 2, 3, 4, 5]
|
||||
});
|
||||
|
||||
window.chrome = {
|
||||
runtime: {}
|
||||
};
|
||||
|
||||
Object.defineProperty(navigator, 'permissions', {
|
||||
get: () => ({
|
||||
query: Promise.resolve.bind(Promise)
|
||||
})
|
||||
});
|
||||
"""
|
||||
},
|
||||
)
|
||||
# Initialize the Chrome driver with better error handling
|
||||
service = ChromeService(ChromeDriverManager().install())
|
||||
driver = webdriver.Chrome(service=service, options=chrome_options)
|
||||
|
||||
self.driver = driver
|
||||
# Execute stealth scripts
|
||||
driver.execute_cdp_cmd(
|
||||
'Page.addScriptToEvaluateOnNewDocument',
|
||||
{
|
||||
'source': """
|
||||
Object.defineProperty(navigator, 'webdriver', {
|
||||
get: () => undefined
|
||||
});
|
||||
|
||||
Object.defineProperty(navigator, 'languages', {
|
||||
get: () => ['en-US', 'en']
|
||||
});
|
||||
|
||||
Object.defineProperty(navigator, 'plugins', {
|
||||
get: () => [1, 2, 3, 4, 5]
|
||||
});
|
||||
|
||||
window.chrome = {
|
||||
runtime: {}
|
||||
};
|
||||
|
||||
Object.defineProperty(navigator, 'permissions', {
|
||||
get: () => ({
|
||||
query: Promise.resolve.bind(Promise)
|
||||
})
|
||||
});
|
||||
"""
|
||||
},
|
||||
)
|
||||
|
||||
# driver.get('https://www.google.com')
|
||||
return driver
|
||||
|
||||
return driver
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to initialize Chrome driver: {str(e)}')
|
||||
# Clean up any existing driver
|
||||
if hasattr(self, 'driver') and self.driver:
|
||||
try:
|
||||
self.driver.quit()
|
||||
self.driver = None
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
def _get_driver(self) -> webdriver.Chrome:
|
||||
if self.driver is None:
|
||||
self.driver = self.init()
|
||||
self.driver = self._setup_webdriver()
|
||||
return self.driver
|
||||
|
||||
def wait_for_page_load(self):
|
||||
@@ -131,7 +141,7 @@ class BrowserService:
|
||||
elapsed = time.time() - start_time
|
||||
remaining = max(self.MINIMUM_WAIT_TIME - elapsed, 0)
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'--Page loaded in {elapsed:.2f} seconds, waiting for additional {remaining:.2f} seconds'
|
||||
)
|
||||
|
||||
@@ -139,27 +149,36 @@ class BrowserService:
|
||||
if remaining > 0:
|
||||
time.sleep(remaining)
|
||||
|
||||
def get_updated_state(self) -> BrowserState:
|
||||
def _update_state(self, use_vision: bool = False) -> BrowserState:
|
||||
"""
|
||||
Update and return state.
|
||||
"""
|
||||
driver = self._get_driver()
|
||||
dom_service = DomService(driver)
|
||||
content = dom_service.get_clickable_elements()
|
||||
|
||||
screenshot_b64 = None
|
||||
if use_vision:
|
||||
screenshot_b64 = self.take_screenshot(selector_map=content.selector_map)
|
||||
|
||||
self.current_state = BrowserState(
|
||||
items=content.items,
|
||||
selector_map=content.selector_map,
|
||||
url=driver.current_url,
|
||||
title=driver.title,
|
||||
current_tab_handle=self._current_handle or driver.current_window_handle,
|
||||
tabs=self.get_tabs_info(),
|
||||
screenshot=screenshot_b64,
|
||||
)
|
||||
|
||||
return self.current_state
|
||||
|
||||
def close(self):
|
||||
if not self.keep_open:
|
||||
driver = self._get_driver()
|
||||
driver.quit()
|
||||
self.driver = None
|
||||
def close(self, force: bool = False):
|
||||
if not self.keep_open or force:
|
||||
if self.driver:
|
||||
driver = self._get_driver()
|
||||
driver.quit()
|
||||
self.driver = None
|
||||
else:
|
||||
input('Press Enter to close Browser...')
|
||||
self.keep_open = False
|
||||
@@ -174,44 +193,6 @@ class BrowserService:
|
||||
|
||||
# region - Browser Actions
|
||||
|
||||
def search_google(self, query: str):
|
||||
"""
|
||||
@dev TODO: add serp api call here
|
||||
"""
|
||||
driver = self._get_driver()
|
||||
driver.get(f'https://www.google.com/search?q={query}')
|
||||
self.wait_for_page_load()
|
||||
|
||||
def go_to_url(self, url: str):
|
||||
driver = self._get_driver()
|
||||
driver.get(url)
|
||||
self.wait_for_page_load()
|
||||
|
||||
def go_back(self):
|
||||
driver = self._get_driver()
|
||||
driver.back()
|
||||
self.wait_for_page_load()
|
||||
|
||||
def refresh(self):
|
||||
driver = self._get_driver()
|
||||
driver.refresh()
|
||||
self.wait_for_page_load()
|
||||
|
||||
def extract_page_content(self, value: Literal['text', 'markdown'] = 'markdown') -> str:
|
||||
"""
|
||||
TODO: switch to a better parser/extractor
|
||||
"""
|
||||
driver = self._get_driver()
|
||||
content = MainContentExtractor.extract(driver.page_source, output_format=value) # type: ignore TODO
|
||||
return content
|
||||
|
||||
def done(self, text: str):
|
||||
"""
|
||||
Ends the task and waits for further instructions.
|
||||
"""
|
||||
logger.info(f'Done on page {self.current_state.url}\n\n: {text}')
|
||||
return text
|
||||
|
||||
def take_screenshot(self, selector_map: SelectorMap | None, full_page: bool = False) -> str:
|
||||
"""
|
||||
Returns a base64 encoded screenshot of the current page.
|
||||
@@ -332,7 +313,6 @@ class BrowserService:
|
||||
|
||||
# Then send keys
|
||||
element.send_keys(text)
|
||||
logger.info(f'Input text into element with xpath: {xpath}')
|
||||
|
||||
self.wait_for_page_load()
|
||||
|
||||
@@ -341,13 +321,6 @@ class BrowserService:
|
||||
f'Failed to input text into element with xpath: {xpath}. Error: {str(e)}'
|
||||
)
|
||||
|
||||
def input_text_by_index(self, index: int, text: str, state: BrowserState):
|
||||
if index not in state.selector_map:
|
||||
raise Exception(f'Element index {index} not found in selector map')
|
||||
|
||||
xpath = state.selector_map[index]
|
||||
self._input_text_by_xpath(xpath, text)
|
||||
|
||||
def _click_element_by_xpath(self, xpath: str):
|
||||
"""
|
||||
Optimized method to click an element using xpath.
|
||||
@@ -362,7 +335,7 @@ class BrowserService:
|
||||
EC.element_to_be_clickable((By.XPATH, xpath)),
|
||||
message=f'Element not clickable: {xpath}',
|
||||
)
|
||||
driver.execute_script('arguments[0].click();', element)
|
||||
element.click()
|
||||
self.wait_for_page_load()
|
||||
return
|
||||
except Exception:
|
||||
@@ -399,30 +372,7 @@ class BrowserService:
|
||||
except Exception as e:
|
||||
raise Exception(f'Failed to click element with xpath: {xpath}. Error: {str(e)}')
|
||||
|
||||
@time_execution_sync('--click')
|
||||
def click_element_by_index(self, index: int, state: BrowserState):
|
||||
"""
|
||||
Clicks an element using its index from the selector map.
|
||||
"""
|
||||
if index not in state.selector_map:
|
||||
raise Exception(f'Element index {index} not found in selector map')
|
||||
|
||||
# Store current number of tabs before clicking
|
||||
driver = self._get_driver()
|
||||
initial_handles = len(driver.window_handles)
|
||||
|
||||
xpath = state.selector_map[index]
|
||||
self._click_element_by_xpath(xpath)
|
||||
logger.info(f'Clicked on index {index}: with xpath {xpath}')
|
||||
|
||||
# Check if new tab was opened
|
||||
current_handles = len(driver.window_handles)
|
||||
if current_handles > initial_handles:
|
||||
return self.handle_new_tab()
|
||||
|
||||
# endregion
|
||||
|
||||
def handle_new_tab(self) -> dict:
|
||||
def handle_new_tab(self) -> None:
|
||||
"""Handle newly opened tab and switch to it"""
|
||||
driver = self._get_driver()
|
||||
handles = driver.window_handles
|
||||
@@ -430,50 +380,31 @@ class BrowserService:
|
||||
|
||||
# Switch to new tab
|
||||
driver.switch_to.window(new_handle)
|
||||
self._current_handle = new_handle # Update current handle
|
||||
self._current_handle = new_handle
|
||||
|
||||
# Wait for page load
|
||||
self.wait_for_page_load()
|
||||
|
||||
# Get and cache tab info
|
||||
tab_info = {
|
||||
'handle': new_handle,
|
||||
'url': driver.current_url,
|
||||
'title': driver.title,
|
||||
'is_current': True,
|
||||
}
|
||||
# Create and cache tab info
|
||||
tab_info = TabInfo(handle=new_handle, url=driver.current_url, title=driver.title)
|
||||
self._tab_cache[new_handle] = tab_info
|
||||
|
||||
# Update is_current for all tabs
|
||||
for handle in self._tab_cache:
|
||||
self._tab_cache[handle]['is_current'] = handle == new_handle
|
||||
|
||||
return tab_info
|
||||
|
||||
def get_tabs_info(self) -> list[dict]:
|
||||
def get_tabs_info(self) -> list[TabInfo]:
|
||||
"""Get information about all tabs"""
|
||||
driver = self._get_driver()
|
||||
current_handle = driver.current_window_handle
|
||||
self._current_handle = current_handle # Update current handle
|
||||
self._current_handle = current_handle
|
||||
|
||||
tabs_info = []
|
||||
for handle in driver.window_handles:
|
||||
is_current = handle == current_handle
|
||||
|
||||
# Use cached info if available, otherwise get new info
|
||||
if handle in self._tab_cache:
|
||||
tab_info = self._tab_cache[handle]
|
||||
tab_info['is_current'] = is_current
|
||||
else:
|
||||
# Only switch if we need to get info
|
||||
if not is_current:
|
||||
if handle != current_handle:
|
||||
driver.switch_to.window(handle)
|
||||
tab_info = {
|
||||
'handle': handle,
|
||||
'url': driver.current_url,
|
||||
'title': driver.title,
|
||||
'is_current': is_current,
|
||||
}
|
||||
tab_info = TabInfo(handle=handle, url=driver.current_url, title=driver.title)
|
||||
self._tab_cache[handle] = tab_info
|
||||
|
||||
tabs_info.append(tab_info)
|
||||
@@ -484,33 +415,12 @@ class BrowserService:
|
||||
|
||||
return tabs_info
|
||||
|
||||
def switch_tab(self, handle: str) -> dict:
|
||||
"""Switch to specified tab and return its info"""
|
||||
driver = self._get_driver()
|
||||
# endregion
|
||||
|
||||
# Verify handle exists
|
||||
if handle not in driver.window_handles:
|
||||
raise ValueError(f'Tab handle {handle} not found')
|
||||
|
||||
# Only switch if we're not already on that tab
|
||||
current_handle = driver.current_window_handle
|
||||
if handle != current_handle:
|
||||
driver.switch_to.window(handle)
|
||||
# Wait for tab to be ready
|
||||
self.wait_for_page_load()
|
||||
|
||||
# Update and return tab info
|
||||
tab_info = {
|
||||
'handle': handle,
|
||||
'url': driver.current_url,
|
||||
'title': driver.title,
|
||||
'is_current': True,
|
||||
}
|
||||
self._tab_cache[handle] = tab_info
|
||||
return tab_info
|
||||
|
||||
def open_tab(self, url: str):
|
||||
driver = self._get_driver()
|
||||
driver.execute_script(f'window.open("{url}", "_blank");')
|
||||
self.wait_for_page_load()
|
||||
return self.handle_new_tab()
|
||||
@time_execution_sync('--get_state')
|
||||
def get_state(self, use_vision: bool = False) -> BrowserState:
|
||||
"""
|
||||
Get the current state of the browser including page content and tab information.
|
||||
"""
|
||||
self._cached_state = self._update_state(use_vision=use_vision)
|
||||
return self._cached_state
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
|
||||
from browser_use.browser.service import BrowserService
|
||||
from browser_use.dom.service import DomService
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
|
||||
def test_highlight_elements():
|
||||
browser = BrowserService(headless=False)
|
||||
|
||||
driver = browser.init()
|
||||
|
||||
dom_service = DomService(driver)
|
||||
|
||||
browser.go_to_url('https://www.kayak.ch')
|
||||
# browser.go_to_url('https://google.com/flights')
|
||||
|
||||
time.sleep(1)
|
||||
# browser._click_element_by_xpath(
|
||||
# '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]'
|
||||
# )
|
||||
browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
|
||||
|
||||
elements = time_execution_sync('get_clickable_elements')(dom_service.get_clickable_elements)()
|
||||
|
||||
# time_execution_sync('highlight_selector_map_elements')(browser.highlight_selector_map_elements)(
|
||||
# elements.selector_map
|
||||
# )
|
||||
|
||||
image = time_execution_sync('take_screenshot')(browser.take_screenshot)(elements.selector_map)
|
||||
|
||||
temp_image_path = os.path.join(os.path.dirname(__file__), 'temp', 'temp.png')
|
||||
with open(temp_image_path, 'wb') as f:
|
||||
f.write(base64.b64decode(image))
|
||||
|
||||
# time_execution_sync('remove_highlights')(browser.remove_highlights)()
|
||||
|
||||
input('Press Enter to continue...')
|
||||
@@ -1,18 +1,18 @@
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
|
||||
from browser_use.browser.service import BrowserService
|
||||
from browser_use.browser.service import Browser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def browser():
|
||||
browser_service = BrowserService(headless=True)
|
||||
browser_service.init()
|
||||
browser_service = Browser(headless=True)
|
||||
yield browser_service
|
||||
browser_service.close()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='takes too long')
|
||||
# @pytest.mark.skip(reason='takes too long')
|
||||
def test_take_full_page_screenshot(browser):
|
||||
# Go to a test page
|
||||
browser.go_to_url('https://example.com')
|
||||
@@ -30,3 +30,7 @@ def test_take_full_page_screenshot(browser):
|
||||
base64.b64decode(screenshot_b64)
|
||||
except Exception as e:
|
||||
pytest.fail(f'Failed to decode base64 screenshot: {str(e)}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_take_full_page_screenshot(Browser(headless=False))
|
||||
|
||||
59
browser_use/browser/tests/test_clicks.py
Normal file
59
browser_use/browser/tests/test_clicks.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import time
|
||||
from browser_use.browser.service import Browser
|
||||
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
|
||||
def test_highlight_elements():
|
||||
browser = Browser()
|
||||
|
||||
browser._get_driver().get('https://kayak.com')
|
||||
# browser.go_to_url('https://google.com/flights')
|
||||
# browser.go_to_url('https://immobilienscout24.de')
|
||||
|
||||
time.sleep(1)
|
||||
# browser._click_element_by_xpath(
|
||||
# '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]'
|
||||
# )
|
||||
# browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
|
||||
|
||||
while True:
|
||||
state = browser.get_state()
|
||||
|
||||
time_execution_sync('highlight_selector_map_elements')(
|
||||
browser.highlight_selector_map_elements
|
||||
)(state.selector_map)
|
||||
|
||||
print(state.dom_items_to_string(use_tabs=False))
|
||||
# print(state.selector_map)
|
||||
|
||||
# Find and print duplicate XPaths
|
||||
xpath_counts = {}
|
||||
for selector in state.selector_map.values():
|
||||
if selector in xpath_counts:
|
||||
xpath_counts[selector] += 1
|
||||
else:
|
||||
xpath_counts[selector] = 1
|
||||
|
||||
print('\nDuplicate XPaths found:')
|
||||
for xpath, count in xpath_counts.items():
|
||||
if count > 1:
|
||||
print(f'XPath: {xpath}')
|
||||
print(f'Count: {count}\n')
|
||||
|
||||
print(state.selector_map.keys(), 'Selector map keys')
|
||||
action = input('Select next action: ')
|
||||
|
||||
time_execution_sync('remove_highlight_elements')(browser.remove_highlights)()
|
||||
|
||||
xpath = state.selector_map[int(action)]
|
||||
|
||||
browser._click_element_by_xpath(xpath)
|
||||
|
||||
|
||||
def main():
|
||||
test_highlight_elements()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
50
browser_use/browser/tests/test_selenium.py
Normal file
50
browser_use/browser/tests/test_selenium.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.webdriver.chrome.service import Service
|
||||
from webdriver_manager.chrome import ChromeDriverManager
|
||||
|
||||
|
||||
def test_selenium():
|
||||
try:
|
||||
print('1. Setting up Chrome options...')
|
||||
chrome_options = Options()
|
||||
chrome_options.add_argument('--no-sandbox')
|
||||
# Uncomment to test headless mode
|
||||
# chrome_options.add_argument('--headless=new')
|
||||
|
||||
print('2. Installing/finding ChromeDriver...')
|
||||
service = Service(ChromeDriverManager().install())
|
||||
|
||||
print('3. Creating Chrome WebDriver...')
|
||||
driver = webdriver.Chrome(service=service, options=chrome_options)
|
||||
|
||||
print('4. Navigating to Google...')
|
||||
driver.get('https://www.google.com')
|
||||
|
||||
print('5. Getting page title...')
|
||||
title = driver.title
|
||||
print(f'Page title: {title}')
|
||||
|
||||
time.sleep(2) # Wait to see the page if not in headless mode
|
||||
|
||||
print('6. Closing browser...')
|
||||
driver.quit()
|
||||
|
||||
print('✅ Test completed successfully!')
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f'❌ Test failed with error: {str(e)}')
|
||||
print(f'Error type: {type(e).__name__}')
|
||||
return False
|
||||
|
||||
|
||||
# run with: pytest browser_use/browser/tests/test_selenium.py
|
||||
|
||||
#
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@@ -1,12 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.dom.views import ProcessedDomContent
|
||||
|
||||
|
||||
# Exceptions
|
||||
class BrowserException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Pydantic
|
||||
class TabInfo(BaseModel):
|
||||
"""Represents information about a browser tab"""
|
||||
|
||||
handle: str
|
||||
url: str
|
||||
title: str
|
||||
|
||||
|
||||
class BrowserState(ProcessedDomContent):
|
||||
url: str
|
||||
title: str
|
||||
current_tab_handle: str
|
||||
tabs: list[TabInfo]
|
||||
screenshot: Optional[str] = None
|
||||
|
||||
def model_dump(self) -> dict:
|
||||
dump = super().model_dump()
|
||||
# Add a summary of available tabs
|
||||
if self.tabs:
|
||||
dump['available_tabs'] = [
|
||||
f'Tab {i+1}: {tab.title} ({tab.url})' for i, tab in enumerate(self.tabs)
|
||||
]
|
||||
return dump
|
||||
|
||||
104
browser_use/controller/registry/service.py
Normal file
104
browser_use/controller/registry/service.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from browser_use.browser.service import Browser
|
||||
from browser_use.controller.registry.views import (
|
||||
ActionModel,
|
||||
ActionRegistry,
|
||||
RegisteredAction,
|
||||
)
|
||||
|
||||
|
||||
class Registry:
|
||||
"""Service for registering and managing actions"""
|
||||
|
||||
def __init__(self):
|
||||
self.registry = ActionRegistry()
|
||||
|
||||
def _create_param_model(self, function: Callable) -> Type[BaseModel]:
|
||||
"""Creates a Pydantic model from function signature"""
|
||||
sig = signature(function)
|
||||
params = {
|
||||
name: (param.annotation, ... if param.default == param.empty else param.default)
|
||||
for name, param in sig.parameters.items()
|
||||
if name != 'browser'
|
||||
}
|
||||
# TODO: make the types here work
|
||||
return create_model(
|
||||
f'{function.__name__}Params',
|
||||
__base__=ActionModel,
|
||||
**params, # type: ignore
|
||||
)
|
||||
|
||||
def action(
|
||||
self,
|
||||
description: str,
|
||||
param_model: Optional[Type[BaseModel]] = None,
|
||||
requires_browser: bool = False,
|
||||
):
|
||||
"""Decorator for registering actions"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
# Create param model from function if not provided
|
||||
actual_param_model = param_model or self._create_param_model(func)
|
||||
|
||||
action = RegisteredAction(
|
||||
name=func.__name__,
|
||||
description=description,
|
||||
function=func,
|
||||
param_model=actual_param_model,
|
||||
requires_browser=requires_browser,
|
||||
)
|
||||
self.registry.actions[func.__name__] = action
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def execute_action(
|
||||
self, action_name: str, params: dict, browser: Optional[Browser] = None
|
||||
) -> Any:
|
||||
"""Execute a registered action"""
|
||||
if action_name not in self.registry.actions:
|
||||
raise ValueError(f'Action {action_name} not found')
|
||||
|
||||
action = self.registry.actions[action_name]
|
||||
try:
|
||||
# Create the validated Pydantic model
|
||||
validated_params = action.param_model(**params)
|
||||
|
||||
# Check if the first parameter is a Pydantic model
|
||||
sig = signature(action.function)
|
||||
first_param = next(iter(sig.parameters.values()))
|
||||
is_pydantic = (
|
||||
hasattr(first_param.annotation, '__bases__')
|
||||
and BaseModel in first_param.annotation.__bases__
|
||||
)
|
||||
|
||||
# Execute with or without browser
|
||||
if action.requires_browser:
|
||||
if not browser:
|
||||
raise ValueError(f'Action {action_name} requires browser but none provided')
|
||||
if is_pydantic:
|
||||
return action.function(validated_params, browser=browser)
|
||||
return action.function(**validated_params.model_dump(), browser=browser)
|
||||
|
||||
if is_pydantic:
|
||||
return action.function(validated_params)
|
||||
return action.function(**validated_params.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f'Error executing action {action_name}: {str(e)}')
|
||||
|
||||
def create_action_model(self) -> Type[ActionModel]:
|
||||
"""Creates a Pydantic model from registered actions"""
|
||||
fields = {
|
||||
name: (Optional[action.param_model], None)
|
||||
for name, action in self.registry.actions.items()
|
||||
}
|
||||
return create_model('ActionModel', __base__=ActionModel, **fields) # type:ignore
|
||||
|
||||
def get_prompt_description(self) -> str:
|
||||
"""Get a description of all actions for the prompt"""
|
||||
return self.registry.get_prompt_description()
|
||||
45
browser_use/controller/registry/views.py
Normal file
45
browser_use/controller/registry/views.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Callable, Dict, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class RegisteredAction(BaseModel):
|
||||
"""Model for a registered action"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
function: Callable
|
||||
param_model: Type[BaseModel]
|
||||
requires_browser: bool = False
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def prompt_description(self) -> str:
|
||||
"""Get a description of the action for the prompt"""
|
||||
skip_keys = ['title']
|
||||
s = f'{self.description}: \n'
|
||||
s += '{' + str(self.name) + ': '
|
||||
s += str(
|
||||
{
|
||||
k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k not in skip_keys}
|
||||
for k, v in self.param_model.schema()['properties'].items()
|
||||
}
|
||||
)
|
||||
s += '}'
|
||||
return s
|
||||
|
||||
|
||||
class ActionModel(BaseModel):
|
||||
"""Base model for dynamically created action models"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ActionRegistry(BaseModel):
|
||||
"""Model representing the action registry"""
|
||||
|
||||
actions: Dict[str, RegisteredAction] = {}
|
||||
|
||||
def get_prompt_description(self) -> str:
|
||||
"""Get a description of all actions for the prompt"""
|
||||
return '\n'.join([action.prompt_description() for action in self.actions.values()])
|
||||
@@ -1,96 +1,170 @@
|
||||
import logging
|
||||
from browser_use.browser.service import BrowserService
|
||||
from browser_use.browser.views import BrowserState
|
||||
|
||||
from main_content_extractor import MainContentExtractor
|
||||
|
||||
from browser_use.agent.views import ActionModel, ActionResult
|
||||
from browser_use.browser.service import Browser
|
||||
from browser_use.browser.views import TabInfo
|
||||
from browser_use.controller.registry.service import Registry
|
||||
from browser_use.controller.views import (
|
||||
ControllerActionResult,
|
||||
ControllerActions,
|
||||
ControllerPageState,
|
||||
ClickElementAction,
|
||||
DoneAction,
|
||||
ExtractPageContentAction,
|
||||
GoToUrlAction,
|
||||
InputTextAction,
|
||||
OpenTabAction,
|
||||
SearchGoogleAction,
|
||||
SwitchTabAction,
|
||||
)
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class ControllerService:
|
||||
"""
|
||||
Controller service that interacts with the browser.
|
||||
|
||||
Right now this is just a LLM friendly wrapper around the browser service.
|
||||
In the future we can add the functionality that this is a self-contained agent that can plan and act single steps.
|
||||
|
||||
TODO: easy hanging fruit: pass in a list of actions, compare html that changed and self assess if goal is done -> makes clicking MUCH MUCH faster and cheaper.
|
||||
|
||||
TODO#2: from the state generate functions that can be passed directly into the LLM as function calls. Then it could actually in the same call request for example multiple actions and new state.
|
||||
"""
|
||||
|
||||
class Controller:
|
||||
def __init__(self, keep_open: bool = False):
|
||||
self.browser = BrowserService(keep_open=keep_open)
|
||||
self.cached_browser_state: BrowserState | None = None
|
||||
self.browser = Browser(keep_open=keep_open)
|
||||
self.registry = Registry()
|
||||
self._register_default_actions()
|
||||
|
||||
@time_execution_sync('--get_cached_browser_state')
|
||||
def get_cached_browser_state(self, force_update: bool = False) -> BrowserState:
|
||||
if self.cached_browser_state is None or force_update:
|
||||
self.cached_browser_state = self.browser.get_updated_state()
|
||||
return self.cached_browser_state
|
||||
def _register_default_actions(self):
|
||||
"""Register all default browser actions"""
|
||||
|
||||
return self.cached_browser_state
|
||||
# return self.browser.get_updated_state()
|
||||
|
||||
def get_current_state(self, screenshot: bool = False) -> ControllerPageState:
|
||||
browser_state = self.get_cached_browser_state(force_update=True)
|
||||
|
||||
# Get tab information without switching
|
||||
tabs = self.browser.get_tabs_info()
|
||||
|
||||
screenshot_b64 = None
|
||||
if screenshot:
|
||||
screenshot_b64 = self.browser.take_screenshot(selector_map=browser_state.selector_map)
|
||||
|
||||
return ControllerPageState(
|
||||
items=browser_state.items,
|
||||
url=browser_state.url,
|
||||
title=browser_state.title,
|
||||
selector_map=browser_state.selector_map,
|
||||
screenshot=screenshot_b64,
|
||||
tabs=tabs,
|
||||
# Basic Navigation Actions
|
||||
@self.registry.action(
|
||||
'Search Google', param_model=SearchGoogleAction, requires_browser=True
|
||||
)
|
||||
def search_google(params: SearchGoogleAction, browser: Browser):
|
||||
driver = browser._get_driver()
|
||||
driver.get(f'https://www.google.com/search?q={params.query}')
|
||||
browser.wait_for_page_load()
|
||||
|
||||
@self.registry.action('Navigate to URL', param_model=GoToUrlAction, requires_browser=True)
|
||||
def go_to_url(params: GoToUrlAction, browser: Browser):
|
||||
driver = browser._get_driver()
|
||||
driver.get(params.url)
|
||||
browser.wait_for_page_load()
|
||||
|
||||
@self.registry.action('Go back', requires_browser=True)
|
||||
def go_back(browser: Browser):
|
||||
driver = browser._get_driver()
|
||||
driver.back()
|
||||
browser.wait_for_page_load()
|
||||
|
||||
# Element Interaction Actions
|
||||
@self.registry.action(
|
||||
'Click element', param_model=ClickElementAction, requires_browser=True
|
||||
)
|
||||
def click_element(params: ClickElementAction, browser: Browser):
|
||||
state = browser._cached_state
|
||||
|
||||
if params.index not in state.selector_map:
|
||||
print(state.selector_map)
|
||||
raise Exception(
|
||||
f'Element with index {params.index} does not exist - retry or use alternative actions'
|
||||
)
|
||||
|
||||
xpath = state.selector_map[params.index]
|
||||
driver = browser._get_driver()
|
||||
initial_handles = len(driver.window_handles)
|
||||
|
||||
msg = None
|
||||
|
||||
for _ in range(params.num_clicks):
|
||||
try:
|
||||
browser._click_element_by_xpath(xpath)
|
||||
msg = f'🖱️ Clicked element {params.index}: {xpath}'
|
||||
if params.num_clicks > 1:
|
||||
msg += f' ({_ + 1}/{params.num_clicks} clicks)'
|
||||
except Exception as e:
|
||||
logger.warning(f'Element no longer available after {_ + 1} clicks: {str(e)}')
|
||||
break
|
||||
|
||||
if len(driver.window_handles) > initial_handles:
|
||||
browser.handle_new_tab()
|
||||
|
||||
return ActionResult(extracted_content=f'Clicked element {msg}')
|
||||
|
||||
@self.registry.action('Input text', param_model=InputTextAction, requires_browser=True)
|
||||
def input_text(params: InputTextAction, browser: Browser):
|
||||
state = browser._cached_state
|
||||
if params.index not in state.selector_map:
|
||||
raise Exception(
|
||||
f'Element index {params.index} does not exist - retry or use alternative actions'
|
||||
)
|
||||
|
||||
xpath = state.selector_map[params.index]
|
||||
browser._input_text_by_xpath(xpath, params.text)
|
||||
msg = f'⌨️ Input text "{params.text}" into element {params.index}: {xpath}'
|
||||
return ActionResult(extracted_content=msg)
|
||||
|
||||
# Tab Management Actions
|
||||
@self.registry.action('Switch tab', param_model=SwitchTabAction, requires_browser=True)
|
||||
def switch_tab(params: SwitchTabAction, browser: Browser):
|
||||
driver = browser._get_driver()
|
||||
|
||||
# Verify handle exists
|
||||
if params.handle not in driver.window_handles:
|
||||
raise ValueError(f'Tab handle {params.handle} not found')
|
||||
|
||||
# Only switch if we're not already on that tab
|
||||
if params.handle != driver.current_window_handle:
|
||||
driver.switch_to.window(params.handle)
|
||||
browser._current_handle = params.handle
|
||||
# Wait for tab to be ready
|
||||
browser.wait_for_page_load()
|
||||
|
||||
# Update and return tab info
|
||||
tab_info = TabInfo(handle=params.handle, url=driver.current_url, title=driver.title)
|
||||
browser._tab_cache[params.handle] = tab_info
|
||||
|
||||
@self.registry.action('Open new tab', param_model=OpenTabAction, requires_browser=True)
|
||||
def open_tab(params: OpenTabAction, browser: Browser):
|
||||
driver = browser._get_driver()
|
||||
driver.execute_script(f'window.open("{params.url}", "_blank");')
|
||||
browser.wait_for_page_load()
|
||||
browser.handle_new_tab()
|
||||
|
||||
# Content Actions
|
||||
@self.registry.action(
|
||||
'Extract page content', param_model=ExtractPageContentAction, requires_browser=True
|
||||
)
|
||||
def extract_content(params: ExtractPageContentAction, browser: Browser):
|
||||
driver = browser._get_driver()
|
||||
|
||||
content = MainContentExtractor.extract( # type: ignore
|
||||
html=driver.page_source,
|
||||
output_format=params.value,
|
||||
)
|
||||
return ActionResult(extracted_content=content)
|
||||
|
||||
@self.registry.action('Complete task', param_model=DoneAction, requires_browser=True)
|
||||
def done(params: DoneAction, browser: Browser):
|
||||
logger.info(f'✅ Done on page {browser._cached_state.url}\n\n: {params.text}')
|
||||
return ActionResult(is_done=True, extracted_content=params.text)
|
||||
|
||||
def action(self, description: str, **kwargs):
|
||||
"""Decorator for registering custom actions
|
||||
|
||||
@param description: Describe the LLM what the function does (better description == better function calling)
|
||||
"""
|
||||
return self.registry.action(description, **kwargs)
|
||||
|
||||
@time_execution_sync('--act')
|
||||
def act(self, action: ControllerActions) -> ControllerActionResult:
|
||||
def act(self, action: ActionModel) -> ActionResult:
|
||||
"""Execute an action"""
|
||||
try:
|
||||
current_state = self.get_cached_browser_state(force_update=False)
|
||||
|
||||
if action.search_google:
|
||||
self.browser.search_google(action.search_google.query)
|
||||
elif action.switch_tab:
|
||||
self.browser.switch_tab(action.switch_tab.handle)
|
||||
elif action.open_tab:
|
||||
self.browser.open_tab(action.open_tab.url)
|
||||
elif action.go_to_url:
|
||||
self.browser.go_to_url(action.go_to_url.url)
|
||||
elif action.nothing:
|
||||
# self.browser.nothing()
|
||||
# TODO: implement
|
||||
pass
|
||||
elif action.go_back:
|
||||
self.browser.go_back()
|
||||
elif action.done:
|
||||
self.browser.done(action.done.text)
|
||||
return ControllerActionResult(done=True, extracted_content=action.done.text)
|
||||
elif action.click_element:
|
||||
self.browser.click_element_by_index(action.click_element.id, current_state)
|
||||
elif action.input_text:
|
||||
self.browser.input_text_by_index(
|
||||
action.input_text.id, action.input_text.text, current_state
|
||||
)
|
||||
elif action.extract_page_content:
|
||||
content = self.browser.extract_page_content()
|
||||
return ControllerActionResult(done=False, extracted_content=content)
|
||||
else:
|
||||
raise ValueError(f'Unknown action: {action}')
|
||||
|
||||
return ControllerActionResult(done=False)
|
||||
|
||||
for action_name, params in action.model_dump(exclude_unset=True).items():
|
||||
if params is not None:
|
||||
result = self.registry.execute_action(action_name, params, browser=self.browser)
|
||||
if isinstance(result, str):
|
||||
return ActionResult(extracted_content=result)
|
||||
elif isinstance(result, ActionResult):
|
||||
return result
|
||||
elif result is None:
|
||||
return ActionResult()
|
||||
else:
|
||||
raise ValueError(f'Invalid action result type: {type(result)} of {result}')
|
||||
return ActionResult()
|
||||
except Exception as e:
|
||||
return ControllerActionResult(done=False, error=f'Error executing action: {str(e)}')
|
||||
raise e
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import pytest
|
||||
from browser_use.controller.service import ControllerService
|
||||
|
||||
from browser_use.agent.views import ActionModel
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
|
||||
def test_get_current_state():
|
||||
# Initialize controller
|
||||
controller = ControllerService()
|
||||
controller = Controller()
|
||||
|
||||
# Go to a test URL
|
||||
controller.browser.go_to_url('https://www.example.com')
|
||||
# controller.act(ActionModel(name='go_to_url', url='https://www.example.com'))
|
||||
|
||||
# Get current state without screenshot
|
||||
state = controller.get_current_state(screenshot=True)
|
||||
state = controller.browser.get_state(use_vision=True)
|
||||
|
||||
input('Press Enter to continue...')
|
||||
|
||||
@@ -1,99 +1,38 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.browser.views import BrowserState
|
||||
|
||||
|
||||
class SearchGoogleControllerAction(BaseModel):
|
||||
# Action Input Models
|
||||
class SearchGoogleAction(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
class GoToUrlControllerAction(BaseModel):
|
||||
class GoToUrlAction(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class ClickElementControllerAction(BaseModel):
|
||||
id: int
|
||||
class ClickElementAction(BaseModel):
|
||||
index: int
|
||||
num_clicks: int = 1
|
||||
|
||||
|
||||
class InputTextControllerAction(BaseModel):
|
||||
id: int
|
||||
class InputTextAction(BaseModel):
|
||||
index: int
|
||||
text: str
|
||||
|
||||
|
||||
class DoneControllerAction(BaseModel):
|
||||
class DoneAction(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class SwitchTabControllerAction(BaseModel):
|
||||
handle: str # The window handle to switch to
|
||||
class SwitchTabAction(BaseModel):
|
||||
handle: str
|
||||
|
||||
|
||||
class OpenTabControllerAction(BaseModel):
|
||||
class OpenTabAction(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class ControllerActions(BaseModel):
|
||||
"""
|
||||
Controller actions you can use to interact.
|
||||
"""
|
||||
|
||||
search_google: Optional[SearchGoogleControllerAction] = None
|
||||
go_to_url: Optional[GoToUrlControllerAction] = None
|
||||
nothing: Optional[Literal[True]] = None
|
||||
go_back: Optional[Literal[True]] = None
|
||||
done: Optional[DoneControllerAction] = None
|
||||
click_element: Optional[ClickElementControllerAction] = None
|
||||
input_text: Optional[InputTextControllerAction] = None
|
||||
extract_page_content: Optional[Literal[True]] = None
|
||||
switch_tab: Optional[SwitchTabControllerAction] = None
|
||||
open_tab: Optional[OpenTabControllerAction] = None
|
||||
|
||||
@staticmethod
|
||||
def description() -> str:
|
||||
"""
|
||||
Returns a human-readable description of available actions.
|
||||
"""
|
||||
return """
|
||||
- Search Google with a query
|
||||
Example: {"search_google": {"query": "weather today"}}
|
||||
- Navigate directly to a URL where you want to go
|
||||
Example: {"go_to_url": {"url": "https://abc.com"}}
|
||||
- Do nothing/wait
|
||||
Example: {"nothing": true}
|
||||
- Go back to previous page
|
||||
Example: {"go_back": true}
|
||||
- Mark entire task as complete
|
||||
Example: {"done": {"text": "This is the requested result of the task which is send to the human..."}}
|
||||
- Click an interactive element by its given ID
|
||||
Example: {"click_element": {"id": 1}}
|
||||
- Input text into an interactiveelement by its ID
|
||||
Example: {"input_text": {"id": 1, "text": "Hello world"}}
|
||||
- Get the page content in markdown
|
||||
Example: {"extract_page_content": true}
|
||||
- Switch to a different browser tab
|
||||
Example: {"switch_tab": {"handle": "CDwindow-1234..."}}
|
||||
- Open a new tab
|
||||
Example: {"open_tab": {"url": "https://abc.com"}}
|
||||
"""
|
||||
|
||||
|
||||
class ControllerActionResult(BaseModel):
|
||||
done: bool
|
||||
extracted_content: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class ControllerPageState(BrowserState):
|
||||
screenshot: Optional[str] = None
|
||||
tabs: list[dict] = [] # Add tabs info to state
|
||||
|
||||
def model_dump(self) -> dict:
|
||||
dump = super().model_dump()
|
||||
# Add a summary of available tabs
|
||||
if self.tabs:
|
||||
dump['available_tabs'] = [
|
||||
f"Tab {i+1}: {tab['title']} ({tab['url']})" for i, tab in enumerate(self.tabs)
|
||||
]
|
||||
return dump
|
||||
class ExtractPageContentAction(BaseModel):
|
||||
value: Literal['text', 'markdown', 'html'] = 'text'
|
||||
|
||||
@@ -1,14 +1,23 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from bs4 import BeautifulSoup, NavigableString, PageElement, Tag
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.remote.webelement import WebElement
|
||||
|
||||
from browser_use.dom.views import DomContentItem, ProcessedDomContent
|
||||
from browser_use.dom.views import (
|
||||
BatchCheckResults,
|
||||
DomContentItem,
|
||||
ElementCheckResult,
|
||||
ElementState,
|
||||
ProcessedDomContent,
|
||||
TextCheckResult,
|
||||
TextState,
|
||||
)
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class DomService:
|
||||
@@ -24,111 +33,239 @@ class DomService:
|
||||
|
||||
@time_execution_sync('--_process_content')
|
||||
def _process_content(self, html_content: str) -> ProcessedDomContent:
|
||||
"""
|
||||
Process HTML content to extract and clean relevant elements.
|
||||
Args:
|
||||
html_content: Raw HTML string to process
|
||||
Returns:
|
||||
ProcessedDomContent: Processed DOM content
|
||||
|
||||
@dev TODO: instead of of using enumerated index, use random 4 digit numbers -> a bit more tokens BUT updates on the screen wont click on incorrect items -> tricky because you have to consider that same elements need to have the same index ...
|
||||
"""
|
||||
|
||||
# Parse HTML content using BeautifulSoup with html.parser
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
candidate_elements: list[Tag | NavigableString] = []
|
||||
output_items: list[DomContentItem] = []
|
||||
selector_map: dict[int, str] = {}
|
||||
current_index = 0
|
||||
|
||||
def _process_element(element: PageElement):
|
||||
should_add_element = False
|
||||
# Collectors for batch processing with order tracking
|
||||
interactive_elements: dict[str, tuple[Tag, int]] = {} # xpath -> (element, order)
|
||||
text_nodes: dict[str, tuple[NavigableString, int]] = {} # xpath -> (text_node, order)
|
||||
xpath_order_counter = 0 # Track order of appearance
|
||||
|
||||
# if not self._quick_element_filter(element):
|
||||
# if isinstance(element, Tag):
|
||||
# element.decompose()
|
||||
# return
|
||||
dom_queue: list[tuple[PageElement, list, Optional[str]]] = (
|
||||
[(element, [], None) for element in reversed(list(soup.body.children))]
|
||||
if soup.body
|
||||
else []
|
||||
)
|
||||
|
||||
# First pass: collect all elements that need checking
|
||||
while dom_queue:
|
||||
element, path_indices, parent_xpath = dom_queue.pop()
|
||||
|
||||
if isinstance(element, Tag):
|
||||
# Don't add any children of non-interactive elements
|
||||
if not self._is_element_accepted(element):
|
||||
element.decompose()
|
||||
return
|
||||
|
||||
if self._is_interactive_element(element) or self._is_leaf_element(element):
|
||||
if (
|
||||
self._is_active(element)
|
||||
and self._is_top_element(element)
|
||||
and self._is_visible(element)
|
||||
):
|
||||
should_add_element = True
|
||||
|
||||
elif isinstance(element, NavigableString) and element.strip():
|
||||
if self._is_visible(element):
|
||||
should_add_element = True
|
||||
|
||||
if should_add_element:
|
||||
if isinstance(element, (Tag, NavigableString)):
|
||||
candidate_elements.append(element)
|
||||
|
||||
if isinstance(element, Tag):
|
||||
for child in element.children:
|
||||
_process_element(child)
|
||||
|
||||
for element in soup.body.children if soup.body else []:
|
||||
_process_element(element)
|
||||
|
||||
# Process candidates
|
||||
selector_map: dict[int, str] = {}
|
||||
output_items: list[DomContentItem] = []
|
||||
|
||||
for index, element in enumerate(candidate_elements):
|
||||
xpath = self._generate_xpath(element)
|
||||
|
||||
depth = max(len(xpath.split('/')) - 2, 0)
|
||||
|
||||
# Skip text nodes that are direct children of already processed elements
|
||||
if isinstance(element, NavigableString) and element.parent in [
|
||||
e for e in candidate_elements
|
||||
]:
|
||||
continue
|
||||
|
||||
if isinstance(element, NavigableString):
|
||||
text_content = self._cap_text_length(element.strip())
|
||||
if text_content:
|
||||
output_string = f'{text_content}'
|
||||
output_items.append(
|
||||
DomContentItem(
|
||||
index=index, text=output_string, clickable=False, depth=depth
|
||||
)
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# don't add empty text nodes
|
||||
continue
|
||||
|
||||
else:
|
||||
text_content = self._extract_text_from_all_children(element)
|
||||
|
||||
tag_name = element.name
|
||||
attributes = self._get_essential_attributes(element)
|
||||
|
||||
opening_tag = f"<{tag_name}{' ' + attributes if attributes else ''}>"
|
||||
closing_tag = f'</{tag_name}>'
|
||||
|
||||
output_string = f'{opening_tag}{text_content}{closing_tag}'
|
||||
output_items.append(
|
||||
DomContentItem(index=index, text=output_string, clickable=True, depth=depth)
|
||||
siblings = (
|
||||
list(element.parent.find_all(element.name, recursive=False))
|
||||
if element.parent
|
||||
else []
|
||||
)
|
||||
sibling_index = siblings.index(element) + 1 if siblings else 1
|
||||
current_path = path_indices + [(element.name, sibling_index)]
|
||||
element_xpath = '//' + '/'.join(f'{tag}[{idx}]' for tag, idx in current_path)
|
||||
|
||||
selector_map[index] = xpath
|
||||
# Add children to queue with their path information
|
||||
for child in reversed(list(element.children)):
|
||||
dom_queue.append((child, current_path, element_xpath)) # Pass parent's xpath
|
||||
|
||||
# Remove all elements from selector map that are not in output items
|
||||
selector_map = {
|
||||
k: v for k, v in selector_map.items() if k in [i.index for i in output_items]
|
||||
}
|
||||
# Collect interactive elements with their order
|
||||
if (
|
||||
self._is_interactive_element(element) or self._is_leaf_element(element)
|
||||
) and self._is_active(element):
|
||||
interactive_elements[element_xpath] = (element, xpath_order_counter)
|
||||
xpath_order_counter += 1
|
||||
|
||||
elif isinstance(element, NavigableString) and element.strip():
|
||||
if element.parent and element.parent not in [e[0] for e in dom_queue]:
|
||||
if parent_xpath:
|
||||
text_nodes[parent_xpath] = (element, xpath_order_counter)
|
||||
xpath_order_counter += 1
|
||||
|
||||
# Batch check all elements
|
||||
element_results = self._batch_check_elements(interactive_elements)
|
||||
text_results = self._batch_check_texts(text_nodes)
|
||||
|
||||
# Create ordered results
|
||||
ordered_results: list[
|
||||
tuple[int, str, bool, str, int, bool]
|
||||
] = [] # [(order, xpath, is_clickable, content, depth, is_text_only), ...]
|
||||
|
||||
# Process interactive elements
|
||||
for xpath, (element, order) in interactive_elements.items():
|
||||
if xpath in element_results.elements:
|
||||
result = element_results.elements[xpath]
|
||||
if result.isVisible and result.isTopElement:
|
||||
text_content = self._extract_text_from_all_children(element)
|
||||
tag_name = element.name
|
||||
attributes = self._get_essential_attributes(element)
|
||||
output_string = f"<{tag_name}{' ' + attributes if attributes else ''}>{text_content}</{tag_name}>"
|
||||
|
||||
depth = len(xpath.split('/')) - 2
|
||||
ordered_results.append((order, xpath, True, output_string, depth, False))
|
||||
|
||||
# Process text nodes
|
||||
for xpath, (text_node, order) in text_nodes.items():
|
||||
if xpath in text_results.texts:
|
||||
result = text_results.texts[xpath]
|
||||
if result.isVisible:
|
||||
text_content = self._cap_text_length(text_node.strip())
|
||||
if text_content:
|
||||
depth = len(xpath.split('/')) - 2
|
||||
ordered_results.append((order, xpath, False, text_content, depth, True))
|
||||
|
||||
# Sort by original order
|
||||
ordered_results.sort(key=lambda x: x[0])
|
||||
|
||||
# Build final output maintaining order
|
||||
for i, (_, xpath, is_clickable, content, depth, is_text_only) in enumerate(ordered_results):
|
||||
output_items.append(
|
||||
DomContentItem(
|
||||
index=i,
|
||||
text=content,
|
||||
# clickable=is_clickable,
|
||||
depth=depth,
|
||||
is_text_only=is_text_only,
|
||||
)
|
||||
)
|
||||
# if is_clickable: # Only add clickable elements to selector map
|
||||
# TODO: make this right, for now we add all elements (except text) to selector map
|
||||
if not is_text_only:
|
||||
selector_map[i] = xpath
|
||||
|
||||
return ProcessedDomContent(items=output_items, selector_map=selector_map)
|
||||
|
||||
def _cap_text_length(self, text: str, max_length: int = 150) -> str:
|
||||
def _batch_check_elements(self, elements: dict[str, tuple[Tag, int]]) -> BatchCheckResults:
|
||||
"""Batch check all interactive elements at once."""
|
||||
if not elements:
|
||||
return BatchCheckResults(elements={}, texts={})
|
||||
|
||||
check_script = """
|
||||
return (function() {
|
||||
const results = {};
|
||||
const elements = %s;
|
||||
|
||||
for (const [xpath, elementData] of Object.entries(elements)) {
|
||||
const element = document.evaluate(xpath, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
|
||||
if (!element) continue;
|
||||
|
||||
// Check visibility
|
||||
const isVisible = element.checkVisibility({
|
||||
checkOpacity: true,
|
||||
checkVisibilityCSS: true
|
||||
});
|
||||
|
||||
if (!isVisible) continue;
|
||||
|
||||
// Check if topmost
|
||||
const rect = element.getBoundingClientRect();
|
||||
const points = [
|
||||
{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.25},
|
||||
{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.25},
|
||||
{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.75},
|
||||
{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.75},
|
||||
{x: rect.left + rect.width / 2, y: rect.top + rect.height / 2}
|
||||
];
|
||||
|
||||
const isTopElement = points.some(point => {
|
||||
const topEl = document.elementFromPoint(point.x, point.y);
|
||||
let current = topEl;
|
||||
while (current && current !== document.body) {
|
||||
if (current === element) return true;
|
||||
current = current.parentElement;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
if (isTopElement) {
|
||||
results[xpath] = {
|
||||
xpath: xpath,
|
||||
isVisible: true,
|
||||
isTopElement: true
|
||||
};
|
||||
}
|
||||
}
|
||||
return results;
|
||||
})();
|
||||
""" % json.dumps({xpath: {} for xpath in elements.keys()})
|
||||
|
||||
try:
|
||||
results = self.driver.execute_script(check_script)
|
||||
return BatchCheckResults(
|
||||
elements={xpath: ElementCheckResult(**data) for xpath, data in results.items()},
|
||||
texts={},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('Error in batch element check: %s', e)
|
||||
return BatchCheckResults(elements={}, texts={})
|
||||
|
||||
def _batch_check_texts(
|
||||
self, texts: dict[str, tuple[NavigableString, int]]
|
||||
) -> BatchCheckResults:
|
||||
"""Batch check all text nodes at once."""
|
||||
if not texts:
|
||||
return BatchCheckResults(elements={}, texts={})
|
||||
|
||||
check_script = """
|
||||
return (function() {
|
||||
const results = {};
|
||||
const texts = %s;
|
||||
|
||||
for (const [xpath, textData] of Object.entries(texts)) {
|
||||
const parent = document.evaluate(xpath, document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
|
||||
if (!parent) continue;
|
||||
|
||||
try {
|
||||
const range = document.createRange();
|
||||
const textNode = parent.childNodes[textData.index];
|
||||
range.selectNodeContents(textNode);
|
||||
const rect = range.getBoundingClientRect();
|
||||
|
||||
const isVisible = (
|
||||
rect.width !== 0 &&
|
||||
rect.height !== 0 &&
|
||||
rect.top >= 0 &&
|
||||
rect.top <= window.innerHeight &&
|
||||
parent.checkVisibility({
|
||||
checkOpacity: true,
|
||||
checkVisibilityCSS: true
|
||||
})
|
||||
);
|
||||
|
||||
if (isVisible) {
|
||||
results[xpath] = {
|
||||
xpath: xpath,
|
||||
isVisible: true
|
||||
};
|
||||
}
|
||||
} catch (e) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return results;
|
||||
})();
|
||||
""" % json.dumps(
|
||||
{
|
||||
xpath: {'index': list(text_node[0].parent.children).index(text_node[0])}
|
||||
for xpath, text_node in texts.items()
|
||||
if text_node[0].parent
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
results = self.driver.execute_script(check_script)
|
||||
return BatchCheckResults(
|
||||
elements={},
|
||||
texts={xpath: TextCheckResult(**data) for xpath, data in results.items()},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error('Error in batch text check: %s', e)
|
||||
return BatchCheckResults(elements={}, texts={})
|
||||
|
||||
def _cap_text_length(self, text: str, max_length: int = 250) -> str:
|
||||
if len(text) > max_length:
|
||||
half_length = max_length // 2
|
||||
return text[:half_length] + '...' + text[-half_length:]
|
||||
@@ -230,75 +367,6 @@ class DomService:
|
||||
|
||||
return element.name not in leaf_element_deny_list
|
||||
|
||||
def _generate_xpath(self, element: Tag | NavigableString) -> str:
|
||||
# Generate cache key based on element properties
|
||||
cache_key = None
|
||||
if isinstance(element, Tag):
|
||||
attributes = [
|
||||
element.get('id', ''),
|
||||
element.get('class', ''),
|
||||
element.name,
|
||||
element.get_text().strip(),
|
||||
]
|
||||
cache_key = '|'.join(str(attr) for attr in attributes)
|
||||
|
||||
# Return cached xpath if exists
|
||||
if cache_key in self.xpath_cache:
|
||||
return self.xpath_cache[cache_key]
|
||||
|
||||
if isinstance(element, NavigableString):
|
||||
if element.parent:
|
||||
return self._generate_xpath(element.parent)
|
||||
return ''
|
||||
|
||||
if not hasattr(element, 'name'):
|
||||
return ''
|
||||
|
||||
element_id = getattr(element, 'get', lambda x: None)('id')
|
||||
if element_id:
|
||||
xpath = f"//*[@id='{element_id}']"
|
||||
if cache_key:
|
||||
self.xpath_cache[cache_key] = xpath
|
||||
return xpath
|
||||
|
||||
parts = []
|
||||
current = element
|
||||
|
||||
while current and getattr(current, 'name', None):
|
||||
if current.name == '[document]':
|
||||
break
|
||||
|
||||
parent = getattr(current, 'parent', None)
|
||||
if parent and hasattr(parent, 'children'):
|
||||
# Get only visible element nodes
|
||||
siblings = [
|
||||
s for s in parent.find_all(current.name, recursive=False) if isinstance(s, Tag)
|
||||
]
|
||||
|
||||
if len(siblings) > 1:
|
||||
try:
|
||||
index = siblings.index(current) + 1
|
||||
parts.insert(0, f'{current.name}[{index}]')
|
||||
except ValueError:
|
||||
parts.insert(0, current.name)
|
||||
else:
|
||||
parts.insert(0, current.name)
|
||||
|
||||
current = parent
|
||||
|
||||
if parts and parts[0] != 'html':
|
||||
parts.insert(0, 'html')
|
||||
if len(parts) > 1 and parts[1] != 'body':
|
||||
parts.insert(1, 'body')
|
||||
|
||||
xpath = '//' + '/'.join(parts)
|
||||
|
||||
# Cache the generated xpath
|
||||
if cache_key:
|
||||
self.xpath_cache[cache_key] = xpath
|
||||
|
||||
return xpath
|
||||
|
||||
def _get_essential_attributes(self, element: Tag) -> str:
|
||||
"""
|
||||
Collects essential attributes from an element.
|
||||
@@ -333,11 +401,11 @@ class DomService:
|
||||
if attr in element.attrs:
|
||||
element_attr = element[attr]
|
||||
if isinstance(element_attr, str):
|
||||
element_attr = element_attr[:50]
|
||||
element_attr = element_attr
|
||||
elif isinstance(element_attr, (list, tuple)):
|
||||
element_attr = ' '.join(str(v)[:50] for v in element_attr)
|
||||
element_attr = ' '.join(str(v) for v in element_attr)
|
||||
|
||||
attrs.append(f'{attr}="{element_attr}"')
|
||||
attrs.append(f'{attr}="{self._cap_text_length(element_attr, 25)}"')
|
||||
|
||||
state_attributes_prefixes = (
|
||||
'aria-',
|
||||
@@ -351,115 +419,6 @@ class DomService:
|
||||
|
||||
return ' '.join(attrs)
|
||||
|
||||
def _is_visible(self, element: Tag | NavigableString) -> bool:
|
||||
if not isinstance(element, Tag):
|
||||
return self._is_text_visible(element)
|
||||
|
||||
element_id = element.get('id', '')
|
||||
if element_id:
|
||||
js_selector = f'document.getElementById("{element_id}")'
|
||||
else:
|
||||
xpath = self._generate_xpath(element)
|
||||
js_selector = f'document.evaluate("{xpath}", document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue'
|
||||
|
||||
visibility_check = f"""
|
||||
return (function() {{
|
||||
const element = {js_selector};
|
||||
|
||||
if (!element) {{
|
||||
return false;
|
||||
}}
|
||||
|
||||
// Force return as boolean
|
||||
return Boolean(element.checkVisibility({{
|
||||
checkOpacity: true,
|
||||
checkVisibilityCSS: true
|
||||
}}));
|
||||
}}());
|
||||
"""
|
||||
|
||||
try:
|
||||
# todo: parse responses with pydantic
|
||||
is_visible = self.driver.execute_script(visibility_check)
|
||||
return bool(is_visible)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_text_visible(self, element: NavigableString) -> bool:
|
||||
"""Check if text node is visible using JavaScript."""
|
||||
parent = element.parent
|
||||
if not parent:
|
||||
return False
|
||||
|
||||
xpath = self._generate_xpath(parent)
|
||||
visibility_check = f"""
|
||||
return (function() {{
|
||||
const parent = document.evaluate("{xpath}", document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
|
||||
|
||||
if (!parent) {{
|
||||
return false;
|
||||
}}
|
||||
|
||||
const range = document.createRange();
|
||||
const textNode = parent.childNodes[{list(parent.children).index(element)}];
|
||||
range.selectNodeContents(textNode);
|
||||
const rect = range.getBoundingClientRect();
|
||||
|
||||
if (rect.width === 0 || rect.height === 0 ||
|
||||
rect.top < 0 || rect.top > window.innerHeight) {{
|
||||
return false;
|
||||
}}
|
||||
|
||||
// Force return as boolean
|
||||
return Boolean(parent.checkVisibility({{
|
||||
checkOpacity: true,
|
||||
checkVisibilityCSS: true
|
||||
}}));
|
||||
}}());
|
||||
"""
|
||||
try:
|
||||
is_visible = self.driver.execute_script(visibility_check)
|
||||
return bool(is_visible)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_top_element(self, element: Tag | NavigableString, rect=None) -> bool:
|
||||
"""Check if element is the topmost at its position."""
|
||||
xpath = self._generate_xpath(element)
|
||||
check_top = f"""
|
||||
return (function() {{
|
||||
const elem = document.evaluate("{xpath}", document, null, XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
|
||||
if (!elem) {{
|
||||
return false;
|
||||
}}
|
||||
|
||||
const rect = elem.getBoundingClientRect();
|
||||
const points = [
|
||||
{{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.25}},
|
||||
{{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.25}},
|
||||
{{x: rect.left + rect.width * 0.25, y: rect.top + rect.height * 0.75}},
|
||||
{{x: rect.left + rect.width * 0.75, y: rect.top + rect.height * 0.75}},
|
||||
{{x: rect.left + rect.width / 2, y: rect.top + rect.height / 2}}
|
||||
];
|
||||
|
||||
return Boolean(points.some(point => {{
|
||||
const topEl = document.elementFromPoint(point.x, point.y);
|
||||
let current = topEl;
|
||||
while (current && current !== document.body) {{
|
||||
if (current === elem) return true;
|
||||
current = current.parentElement;
|
||||
}}
|
||||
return false;
|
||||
}}));
|
||||
}}());
|
||||
"""
|
||||
try:
|
||||
is_top = self.driver.execute_script(check_top)
|
||||
return bool(is_top)
|
||||
except Exception:
|
||||
logger.error(f'Error checking top element: {element}')
|
||||
return False
|
||||
|
||||
def _is_active(self, element: Tag) -> bool:
|
||||
"""Check if element is active (not disabled)."""
|
||||
return not (
|
||||
@@ -467,40 +426,3 @@ class DomService:
|
||||
or element.get('hidden') is not None
|
||||
or element.get('aria-disabled') == 'true'
|
||||
)
|
||||
|
||||
# def _quick_element_filter(self, element: PageElement) -> bool:
|
||||
# """
|
||||
# Quick pre-filter to eliminate elements before expensive checks.
|
||||
# Returns True if element passes initial filtering.
|
||||
# """
|
||||
# if isinstance(element, NavigableString):
|
||||
# # Quick check for empty or whitespace-only strings
|
||||
# return bool(element.strip())
|
||||
|
||||
# if not isinstance(element, Tag):
|
||||
# return False
|
||||
|
||||
# style = element.get('style')
|
||||
|
||||
# # Quick attribute checks that would make element invisible/non-interactive
|
||||
# if any(
|
||||
# [
|
||||
# element.get('aria-hidden') == 'true',
|
||||
# element.get('hidden') is not None,
|
||||
# element.get('disabled') is not None,
|
||||
# style and ('display: none' in style or 'visibility: hidden' in style),
|
||||
# element.has_attr('class')
|
||||
# and any(cls in element['class'] for cls in ['hidden', 'invisible']),
|
||||
# # Common hidden class patterns
|
||||
# element.get('type') == 'hidden',
|
||||
# ]
|
||||
# ):
|
||||
# return False
|
||||
|
||||
# # Skip elements that definitely won't be interactive or visible
|
||||
# non_interactive_display = ['none', 'hidden']
|
||||
# computed_style = element.get('style', '') or ''
|
||||
# if any(display in computed_style for display in non_interactive_display):
|
||||
# return False
|
||||
|
||||
# return True
|
||||
|
||||
@@ -2,33 +2,32 @@ import time
|
||||
|
||||
from tokencost import count_string_tokens
|
||||
|
||||
from browser_use.browser.service import BrowserService
|
||||
from browser_use.browser.service import Browser
|
||||
from browser_use.dom.service import DomService
|
||||
from browser_use.utils import time_execution_sync
|
||||
|
||||
|
||||
# @pytest.mark.skip("slow af")
|
||||
def test_process_html_file():
|
||||
browser = BrowserService(headless=False)
|
||||
browser = Browser(headless=False)
|
||||
|
||||
driver = browser.init()
|
||||
driver = browser._get_driver()
|
||||
|
||||
dom_service = DomService(driver)
|
||||
|
||||
browser.go_to_url('https://www.kayak.ch')
|
||||
driver.get('https://kayak.com/flights')
|
||||
# browser.go_to_url('https://google.com/flights')
|
||||
# browser.go_to_url('https://immobilienscout24.de')
|
||||
|
||||
time.sleep(1)
|
||||
time.sleep(3)
|
||||
# browser._click_element_by_xpath(
|
||||
# '/html/body/div[5]/div/div[2]/div/div/div[3]/div/div[1]/button[1]'
|
||||
# )
|
||||
browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
|
||||
# browser._click_element_by_xpath("//button[div/div[text()='Alle akzeptieren']]")
|
||||
|
||||
elements = time_execution_sync('get_clickable_elements')(
|
||||
dom_service.get_clickable_elements().dom_items_to_string
|
||||
)()
|
||||
elements = time_execution_sync('get_clickable_elements')(dom_service.get_clickable_elements)()
|
||||
|
||||
print(elements)
|
||||
print('Tokens:', count_string_tokens(elements, model='gpt-4o'))
|
||||
print(elements.dom_items_to_string(use_tabs=False))
|
||||
print('Tokens:', count_string_tokens(elements.dom_items_to_string(), model='gpt-4o'))
|
||||
|
||||
input('Press Enter to continue...')
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DomContentItem(BaseModel):
|
||||
index: int
|
||||
text: str
|
||||
clickable: bool
|
||||
is_text_only: bool
|
||||
depth: int
|
||||
|
||||
|
||||
@@ -15,13 +16,38 @@ class ProcessedDomContent(BaseModel):
|
||||
items: list[DomContentItem]
|
||||
selector_map: SelectorMap
|
||||
|
||||
def dom_items_to_string(self) -> str:
|
||||
def dom_items_to_string(self, use_tabs: bool = True) -> str:
|
||||
"""Convert the processed DOM content to HTML."""
|
||||
formatted_text = ''
|
||||
for item in self.items:
|
||||
item_depth = '\t' * item.depth * 1
|
||||
if item.clickable:
|
||||
formatted_text += f'{item.index}:{item_depth}{item.text}\n'
|
||||
item_depth = '\t' * item.depth * 1 if use_tabs else ''
|
||||
if item.is_text_only:
|
||||
formatted_text += f'_[:]{item_depth}{item.text}\n'
|
||||
else:
|
||||
formatted_text += f'{item.index}:{item_depth}{item.text}\n'
|
||||
formatted_text += f'{item.index}[:]{item_depth}{item.text}\n'
|
||||
return formatted_text
|
||||
|
||||
|
||||
class ElementState(BaseModel):
|
||||
isVisible: bool
|
||||
isTopElement: bool
|
||||
|
||||
|
||||
class TextState(BaseModel):
|
||||
isVisible: bool
|
||||
|
||||
|
||||
class ElementCheckResult(BaseModel):
|
||||
xpath: str
|
||||
isVisible: bool
|
||||
isTopElement: bool
|
||||
|
||||
|
||||
class TextCheckResult(BaseModel):
|
||||
xpath: str
|
||||
isVisible: bool
|
||||
|
||||
|
||||
class BatchCheckResults(BaseModel):
|
||||
elements: Dict[str, ElementCheckResult]
|
||||
texts: Dict[str, TextCheckResult]
|
||||
|
||||
47
browser_use/logging_config.py
Normal file
47
browser_use/logging_config.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def setup_logging():
|
||||
# Check if handlers are already set up
|
||||
if logging.getLogger().hasHandlers():
|
||||
return
|
||||
|
||||
# Clear existing handlers
|
||||
root = logging.getLogger()
|
||||
root.handlers = []
|
||||
|
||||
class BrowserUseFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
if record.name.startswith('browser_use.'):
|
||||
record.name = record.name.split('.')[-2]
|
||||
return super().format(record)
|
||||
|
||||
# Setup single handler for all loggers
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setFormatter(BrowserUseFormatter('%(levelname)-8s [%(name)s] %(message)s'))
|
||||
|
||||
# Configure root logger only
|
||||
root.addHandler(console)
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Configure browser_use logger to prevent propagation
|
||||
browser_use_logger = logging.getLogger('browser_use')
|
||||
browser_use_logger.propagate = False
|
||||
browser_use_logger.addHandler(console)
|
||||
|
||||
# Silence third-party loggers
|
||||
for logger in [
|
||||
'WDM',
|
||||
'httpx',
|
||||
'selenium',
|
||||
'urllib3',
|
||||
'asyncio',
|
||||
'langchain',
|
||||
'openai',
|
||||
'httpcore',
|
||||
'charset_normalizer',
|
||||
]:
|
||||
third_party = logging.getLogger(logger)
|
||||
third_party.setLevel(logging.ERROR)
|
||||
third_party.propagate = False
|
||||
@@ -3,13 +3,13 @@ import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Coroutine, ParamSpec, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define generic type variables for return type and parameters
|
||||
R = TypeVar('R')
|
||||
P = ParamSpec('P')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
def time_execution_sync(additional_text: str = '') -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@@ -18,7 +18,7 @@ def time_execution_sync(additional_text: str = '') -> Callable[[Callable[P, R]],
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f'{additional_text} Execution time: {execution_time:.2f} seconds')
|
||||
logger.debug(f'{additional_text} Execution time: {execution_time:.2f} seconds')
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -35,7 +35,7 @@ def time_execution_async(
|
||||
start_time = time.time()
|
||||
result = await func(*args, **kwargs)
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f'{additional_text} Execution time: {execution_time:.2f} seconds')
|
||||
logger.debug(f'{additional_text} Execution time: {execution_time:.2f} seconds')
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from browser_use.logging_config import setup_logging
|
||||
|
||||
# Get the absolute path to the project root
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
setup_logging()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -10,9 +9,6 @@ import asyncio
|
||||
|
||||
from browser_use import Agent, Controller
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
# Persist the browser state across agents
|
||||
controller = Controller()
|
||||
|
||||
|
||||
93
examples/extend_actions.py
Normal file
93
examples/extend_actions.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.browser.service import Browser
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
# Initialize controller first
|
||||
controller = Controller()
|
||||
|
||||
|
||||
# Action Models
|
||||
class JobDetails(BaseModel):
|
||||
title: str
|
||||
company: str
|
||||
job_link: str
|
||||
salary: Optional[str] = None
|
||||
|
||||
|
||||
@controller.action('Save job details which you found on page', param_model=JobDetails)
|
||||
def save_job(params: JobDetails):
|
||||
with open('jobs.txt', 'a') as f:
|
||||
f.write(f'{params.title} at {params.company}: {params.salary}\n')
|
||||
|
||||
|
||||
class StarredPeople(BaseModel):
|
||||
usernames: List[str]
|
||||
|
||||
|
||||
@controller.action('Save people who starred the repo', param_model=StarredPeople)
|
||||
def save_starred_people(params: StarredPeople):
|
||||
with open('starred_people.txt', 'a') as f:
|
||||
for username in params.usernames:
|
||||
f.write(f'{username}\n')
|
||||
|
||||
|
||||
# Browser-requiring action example
|
||||
class PageSaver(BaseModel):
|
||||
filename: str
|
||||
|
||||
|
||||
@controller.action('Save current page info', param_model=PageSaver, requires_browser=True)
|
||||
def save_page_info(params: PageSaver, browser: Browser):
|
||||
state = browser.get_state()
|
||||
with open(params.filename, 'w') as f:
|
||||
f.write(f'URL: {state.url}\n')
|
||||
f.write(f'Title: {state.title}\n')
|
||||
f.write(f'HTML: {state.items}\n')
|
||||
|
||||
|
||||
class Job(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
company: str
|
||||
salary: Optional[str] = None
|
||||
|
||||
|
||||
class Jobs(BaseModel):
|
||||
jobs: List[Job]
|
||||
|
||||
|
||||
@controller.action('Save jobs', param_model=Jobs, requires_browser=True)
|
||||
def save_jobs(params: Jobs, browser: Browser):
|
||||
with open('jobs.txt', 'a') as f:
|
||||
for job in params.jobs:
|
||||
f.write(f'{job.title} at {job.company}: {job.salary} ({job.link})\n')
|
||||
|
||||
|
||||
# Without Pydantic model - using simple parameters
|
||||
@controller.action('Ask user for information')
|
||||
def ask_human(question: str, display_question: bool) -> str:
|
||||
return input(f'\n{question}\nInput: ')
|
||||
|
||||
|
||||
async def main():
|
||||
task = 'Find 10 software developer jobs in San Francisco at YC startups in google and save the jobs to a file. Then ask human for more information'
|
||||
|
||||
model = ChatOpenAI(model='gpt-4o')
|
||||
agent = Agent(task=task, llm=model, controller=controller)
|
||||
|
||||
await agent.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
@dev You need to add ANTHROPIC_API_KEY to your environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use.agent.service import AgentService
|
||||
from browser_use.controller.service import ControllerService
|
||||
|
||||
task = 'Go to kayak.com and find a one-way flight from Zürich to San Francisco on 12 January 2025.'
|
||||
controller = ControllerService()
|
||||
# model = ChatAnthropic(
|
||||
# model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3
|
||||
# )
|
||||
model = ChatOpenAI(model='gpt-4o')
|
||||
agent = AgentService(task, model, controller, use_vision=True)
|
||||
|
||||
|
||||
async def main():
|
||||
max_steps = 50
|
||||
# Run the agent step by step
|
||||
for i in range(max_steps):
|
||||
print(f'\n📍 Step {i+1}')
|
||||
action, result = await agent.step()
|
||||
|
||||
print('Action:', action)
|
||||
print('Result:', result)
|
||||
|
||||
if result.done:
|
||||
print('\n✅ Task completed successfully!')
|
||||
print('Extracted content:', result.extracted_content)
|
||||
break
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -1,36 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use.agent.service import AgentService
|
||||
from browser_use.controller.service import ControllerService
|
||||
|
||||
people = ['Albert Einstein', 'Oprah Winfrey', 'Steve Jobs']
|
||||
task = f'Opening new tabs and searching for images for these people: {", ".join(people)}. Then ask me for further instructions.'
|
||||
controller = ControllerService(keep_open=True)
|
||||
model = ChatOpenAI(model='gpt-4o')
|
||||
agent = AgentService(task, model, controller, use_vision=True)
|
||||
|
||||
|
||||
async def main():
|
||||
max_steps = 50
|
||||
# Run the agent step by step
|
||||
for i in range(max_steps):
|
||||
print(f'\n📍 Step {i+1}')
|
||||
action, result = await agent.step()
|
||||
|
||||
print('Action:', action)
|
||||
print('Result:', result)
|
||||
|
||||
if result.done:
|
||||
print('\n✅ Task completed successfully!')
|
||||
print('Extracted content:', result.extracted_content)
|
||||
break
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -4,12 +4,9 @@ Simple try of the agent.
|
||||
@dev You need to add OPENAI_API_KEY to your environment variables.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from browser_use.controller.service import ControllerService
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
@@ -18,19 +15,15 @@ from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use import Agent
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
llm = ChatOpenAI(model='gpt-4o')
|
||||
agent = Agent(
|
||||
task='Opening new tabs to search for images of Albert Einstein, Oprah Winfrey, and Steve Jobs. Then ask user for further instructions.',
|
||||
llm=llm,
|
||||
task='Find a one-way flight from Bali to Oman on 12 January 2025 on Google Flights. Return me the cheapest option.',
|
||||
llm=ChatOpenAI(model='gpt-4o'),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
result, history = await agent.run()
|
||||
print(result)
|
||||
print(history)
|
||||
await agent.run()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
Simple try of the agent.
|
||||
|
||||
@dev You need to add OPENAI_API_KEY to your environment variables.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use import Agent
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
llm = ChatOpenAI(model='gpt-4o')
|
||||
agent = Agent(
|
||||
task='Apply to 2025 batch of Talent Kick. Use dummy data. Here is some information about me: Name: John Doe, Email: john.doe@example.com, Phone: +1234567890, LinkedIn: https://www.linkedin.com/in/john-doe, Github: https://github.com/john-doe, Twitter: https://twitter.com/john-doe, StackOverflow: https://stackoverflow.com/users/123456/john-doe, Youtube: https://www.youtube.com/user/john-doe, Education: BSc Computer Science from MIT (2020-2024), Work Experience: Software Engineer at Google (2024-present), Skills: Python, JavaScript, React, Node.js, AWS, Docker, Kubernetes, Projects: Built an AI-powered chatbot with 10k+ users, Created an open-source library with 1k+ stars on Github, Achievements: Won first place in MIT Hackathon 2023, Published paper on ML at ICML 2024, Languages: English (Native), Spanish (Fluent), Mandarin (Basic), Interests: AI/ML, Open Source, Competitive Programming, Hobbies: Playing guitar, Rock climbing, Chess',
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
await agent.run()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -17,9 +17,7 @@ import argparse
|
||||
import asyncio
|
||||
|
||||
from browser_use import Agent
|
||||
from browser_use.controller.service import ControllerService
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
|
||||
def get_llm(provider: str):
|
||||
@@ -50,14 +48,13 @@ llm = get_llm(args.provider)
|
||||
agent = Agent(
|
||||
task=args.query,
|
||||
llm=llm,
|
||||
controller=ControllerService(keep_open=True),
|
||||
controller=Controller(keep_open=True),
|
||||
# save_conversation_path='./tmp/try_flight/',
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
result, history = await agent.run()
|
||||
print(result)
|
||||
print(history)
|
||||
await agent.run()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import asyncio
|
||||
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
from browser_use.agent.service import AgentService
|
||||
from browser_use.controller.service import ControllerService
|
||||
|
||||
task = 'Open 3 wikipedia pages in different tabs and summarize the content of all pages.'
|
||||
controller = ControllerService()
|
||||
model = ChatAnthropic(
|
||||
model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None, temperature=0.3
|
||||
)
|
||||
agent = AgentService(task, model, controller, use_vision=True)
|
||||
|
||||
|
||||
async def main():
|
||||
max_steps = 50
|
||||
# Run the agent step by step
|
||||
for i in range(max_steps):
|
||||
print(f'\n📍 Step {i+1}')
|
||||
action, result = await agent.step()
|
||||
|
||||
print('Action:', action)
|
||||
print('Result:', result)
|
||||
|
||||
if result.done:
|
||||
print('\n✅ Task completed successfully!')
|
||||
print('Extracted content:', result.extracted_content)
|
||||
break
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -4,7 +4,7 @@ description = "Let LLMs interact with websites through a simple interface"
|
||||
authors = [
|
||||
{ name = "Gregor Zunic" }
|
||||
]
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
classifiers = [
|
||||
@@ -12,24 +12,11 @@ classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
dynamic = ["dependencies", "optional-dependencies"]
|
||||
|
||||
dependencies = [
|
||||
"MainContentExtractor>=0.0.4",
|
||||
"Selenium-Screenshot>=2.1.0",
|
||||
"beautifulsoup4>=4.12.3",
|
||||
"langchain>=0.3.7",
|
||||
"langchain-openai>=0.2.5",
|
||||
"langchain-anthropic>=0.2.4",
|
||||
"langchain-fireworks>=0.2.5",
|
||||
"pydantic>=2.9.2",
|
||||
"pytest>=8.3.3",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"python-dotenv>=1.0.1",
|
||||
"requests>=2.32.3",
|
||||
"selenium>=4.26.1",
|
||||
"webdriver-manager>=4.0.2",
|
||||
"lxml_html_clean>=0.3.1"
|
||||
]
|
||||
[tool.setuptools.dynamic]
|
||||
dependencies = {file = ["requirements.txt"]}
|
||||
optional-dependencies = {dev = {file = ["requirements-dev.txt"]}}
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
@@ -42,9 +29,3 @@ docstring-code-format = true
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"tokencost>=0.1.16",
|
||||
"hatch>=1.13.0",
|
||||
]
|
||||
|
||||
29
pytest.ini
29
pytest.ini
@@ -1,2 +1,29 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
integration: marks tests as integration tests
|
||||
unit: marks tests as unit tests
|
||||
|
||||
testpaths =
|
||||
tests
|
||||
|
||||
python_files =
|
||||
test_*.py
|
||||
*_test.py
|
||||
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
log_cli = true
|
||||
; log_cli_level = DEBUG
|
||||
log_cli_format = %(levelname)-8s [%(name)s] %(message)s
|
||||
filterwarnings =
|
||||
ignore::pytest.PytestDeprecationWarning
|
||||
ignore::DeprecationWarning
|
||||
|
||||
log_level = INFO
|
||||
|
||||
|
||||
3
requirements-dev.txt
Normal file
3
requirements-dev.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
tokencost>=0.1.16
|
||||
hatch>=1.13.0
|
||||
build>=1.2.2
|
||||
@@ -11,4 +11,4 @@ pytest-asyncio>=0.24.0
|
||||
python-dotenv>=1.0.1
|
||||
requests>=2.32.3
|
||||
selenium>=4.26.1
|
||||
webdriver-manager>=4.0.2
|
||||
webdriver-manager>=4.0.2
|
||||
16858
tests/mind2web_data/processed.json
Normal file
16858
tests/mind2web_data/processed.json
Normal file
File diff suppressed because it is too large
Load Diff
208
tests/test_agent_actions.py
Normal file
208
tests/test_agent_actions.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
"""Initialize language model for testing"""
|
||||
|
||||
# return ChatAnthropic(model_name='claude-3-5-sonnet-20240620', timeout=25, stop=None)
|
||||
return ChatOpenAI(model='gpt-4o')
|
||||
# return ChatOpenAI(model='gpt-4o-mini')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def agent_with_controller():
|
||||
"""Create agent with controller for testing"""
|
||||
controller = Controller(keep_open=False)
|
||||
print('init controller')
|
||||
try:
|
||||
yield controller
|
||||
finally:
|
||||
if controller.browser:
|
||||
controller.browser.close(force=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ecommerce_interaction(llm, agent_with_controller):
|
||||
"""Test complex ecommerce interaction sequence"""
|
||||
agent = Agent(
|
||||
task="Go to amazon.com, search for 'laptop', filter by 4+ stars, and find the price of the first result",
|
||||
llm=llm,
|
||||
controller=agent_with_controller,
|
||||
save_conversation_path='tmp/test_ecommerce_interaction/conversation',
|
||||
)
|
||||
|
||||
history = await agent.run(max_steps=20)
|
||||
|
||||
# Verify sequence of actions
|
||||
action_sequence = []
|
||||
for h in history:
|
||||
action = getattr(h.model_output, 'action', None)
|
||||
if action and (getattr(action, 'go_to_url', None) or getattr(action, 'open_tab', None)):
|
||||
action_sequence.append('navigate')
|
||||
elif action and getattr(action, 'input_text', None):
|
||||
action_sequence.append('input')
|
||||
# Check that the input is 'laptop'
|
||||
inp = action.input_text.text.lower()
|
||||
if inp == 'laptop':
|
||||
action_sequence.append('input_exact_correct')
|
||||
elif 'laptop' in inp:
|
||||
action_sequence.append('correct_in_input')
|
||||
else:
|
||||
action_sequence.append('incorrect_input')
|
||||
|
||||
elif action and getattr(action, 'click_element', None):
|
||||
action_sequence.append('click')
|
||||
|
||||
if action is None:
|
||||
print(h.result)
|
||||
print(h.model_output)
|
||||
|
||||
# Verify essential steps were performed
|
||||
assert 'navigate' in action_sequence # Navigated to Amazon
|
||||
assert 'input' in action_sequence # Entered search term
|
||||
assert 'click' in action_sequence # Clicked search/filter
|
||||
assert 'input_exact_correct' in action_sequence or 'correct_in_input' in action_sequence
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_recovery(llm, agent_with_controller):
|
||||
"""Test agent's ability to recover from errors"""
|
||||
agent = Agent(
|
||||
task='Navigate to nonexistent-site.com and then recover by going to google.com',
|
||||
llm=llm,
|
||||
controller=agent_with_controller,
|
||||
)
|
||||
|
||||
history = await agent.run(max_steps=10)
|
||||
|
||||
recovery_action = next(
|
||||
(
|
||||
h
|
||||
for h in history
|
||||
if h.model_output
|
||||
and getattr(h.model_output, 'action', None)
|
||||
and getattr(h.model_output.action, 'go_to_url', None)
|
||||
and getattr(h.model_output.action.go_to_url, 'url', '').endswith('google.com') # type: ignore -> pretty weird way to do this
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert recovery_action is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_contact_email(llm, agent_with_controller):
|
||||
"""Test agent's ability to find contact email on a website"""
|
||||
agent = Agent(
|
||||
task='Go to https://browser-use.com/ and find out the contact email',
|
||||
llm=llm,
|
||||
controller=agent_with_controller,
|
||||
)
|
||||
|
||||
history = await agent.run(max_steps=10)
|
||||
|
||||
# Verify the agent found the contact email
|
||||
email_action = next(
|
||||
(
|
||||
h
|
||||
for h in history
|
||||
if h.result.extracted_content and 'info@browser-use.com' in h.result.extracted_content
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert email_action is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_finds_installation_command(llm, agent_with_controller):
|
||||
"""Test agent's ability to find the pip installation command for browser-use on the web"""
|
||||
agent = Agent(
|
||||
task='Find the pip installation command for the browser-use repo',
|
||||
llm=llm,
|
||||
controller=agent_with_controller,
|
||||
)
|
||||
|
||||
history = await agent.run(max_steps=10)
|
||||
|
||||
# Verify the agent found the correct installation command
|
||||
install_command_action = next(
|
||||
(
|
||||
h
|
||||
for h in history
|
||||
if h.result.extracted_content
|
||||
and 'pip install browser-use' in h.result.extracted_content
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert install_command_action is not None
|
||||
|
||||
|
||||
class CaptchaTest(BaseModel):
|
||||
name: str
|
||||
url: str
|
||||
success_text: str
|
||||
additional_text: str | None = None
|
||||
|
||||
|
||||
# pytest tests/test_agent_actions.py -v -k "test_captcha_solver" --capture=no --log-cli-level=INFO
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'captcha',
|
||||
[
|
||||
# good test for num_clicks
|
||||
CaptchaTest(
|
||||
name='Rotate Captcha',
|
||||
url='https://2captcha.com/demo/rotatecaptcha',
|
||||
success_text='Captcha is passed successfully',
|
||||
additional_text='Use num_clicks with number to click multiple times at once in same direction. click done when image is exact correct position.',
|
||||
),
|
||||
CaptchaTest(
|
||||
name='Text Captcha',
|
||||
url='https://2captcha.com/demo/text',
|
||||
success_text='Captcha is passed successfully!',
|
||||
),
|
||||
CaptchaTest(
|
||||
name='Basic Captcha',
|
||||
url='https://captcha.com/demos/features/captcha-demo.aspx',
|
||||
success_text='Correct!',
|
||||
),
|
||||
CaptchaTest(
|
||||
name='MT Captcha',
|
||||
url='https://2captcha.com/demo/mtcaptcha',
|
||||
success_text='Verified Successfully',
|
||||
additional_text='Stop when you solved it successfully.',
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_captcha_solver(llm, agent_with_controller, captcha: CaptchaTest):
|
||||
"""Test agent's ability to solve different types of captchas"""
|
||||
agent = Agent(
|
||||
task=f'Go to {captcha.url} and solve the captcha. {captcha.additional_text}',
|
||||
llm=llm,
|
||||
controller=agent_with_controller,
|
||||
)
|
||||
|
||||
history = await agent.run(max_steps=10)
|
||||
|
||||
# Verify the agent solved the captcha
|
||||
solved = False
|
||||
for h in history:
|
||||
last = h.state.items
|
||||
if any(captcha.success_text in item.text for item in last):
|
||||
solved = True
|
||||
break
|
||||
assert solved, f'Failed to solve {captcha.name}'
|
||||
|
||||
|
||||
# python -m pytest tests/test_agent_actions.py -v --capture=no
|
||||
|
||||
|
||||
# pytest tests/test_agent_actions.py -v -k "test_captcha_solver" --capture=no --log-cli-level=INFO
|
||||
157
tests/test_core_functionality.py
Normal file
157
tests/test_core_functionality.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
"""Initialize language model for testing"""
|
||||
return ChatOpenAI(model='gpt-4o') # Use appropriate model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def controller():
|
||||
"""Initialize the controller"""
|
||||
controller = Controller()
|
||||
try:
|
||||
yield controller
|
||||
finally:
|
||||
if controller.browser:
|
||||
controller.browser.close(force=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_google(llm, controller):
|
||||
"""Test 'Search Google' action"""
|
||||
agent = Agent(
|
||||
task="Search Google for 'OpenAI'.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=2)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'search_google' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_go_to_url(llm, controller):
|
||||
"""Test 'Navigate to URL' action"""
|
||||
agent = Agent(
|
||||
task="Navigate to 'https://www.python.org'.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=2)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'go_to_url' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_go_back(llm, controller):
|
||||
"""Test 'Go back' action"""
|
||||
agent = Agent(
|
||||
task="Go to 'https://www.example.com', then go back.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=3)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'go_to_url' in action_names
|
||||
assert 'go_back' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_click_element(llm, controller):
|
||||
"""Test 'Click element' action"""
|
||||
agent = Agent(
|
||||
task="Go to 'https://www.python.org' and click on the first link.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=4)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'go_to_url' in action_names
|
||||
assert 'click_element' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_text(llm, controller):
|
||||
"""Test 'Input text' action"""
|
||||
agent = Agent(
|
||||
task="Go to 'https://www.google.com' and input 'OpenAI' into the search box.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=4)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'go_to_url' in action_names
|
||||
assert 'input_text' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_tab(llm, controller):
|
||||
"""Test 'Switch tab' action"""
|
||||
agent = Agent(
|
||||
task="Open new tabs with 'https://www.google.com' and 'https://www.wikipedia.org', then switch to the first tab.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=6)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
open_tab_count = action_names.count('open_tab')
|
||||
assert open_tab_count >= 2
|
||||
assert 'switch_tab' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_new_tab(llm, controller):
|
||||
"""Test 'Open new tab' action"""
|
||||
agent = Agent(
|
||||
task="Open a new tab and go to 'https://www.example.com'.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=3)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'open_tab' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_page_content(llm, controller):
|
||||
"""Test 'Extract page content' action"""
|
||||
agent = Agent(
|
||||
task="Go to 'https://www.example.com' and extract the page content.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=3)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'go_to_url' in action_names
|
||||
assert 'extract_content' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_done_action(llm, controller):
|
||||
"""Test 'Complete task' action"""
|
||||
agent = Agent(
|
||||
task="Navigate to 'https://www.example.com' and signal that the task is done.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=3)
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
assert 'go_to_url' in action_names
|
||||
assert 'done' in action_names
|
||||
116
tests/test_mind2web.py
Normal file
116
tests/test_mind2web.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Test browser automation using Mind2Web dataset tasks with pytest framework.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.controller.service import Controller
|
||||
from browser_use.utils import logger
|
||||
|
||||
# Constants
|
||||
MAX_STEPS = 50
|
||||
TEST_SUBSET_SIZE = 10
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def test_cases() -> List[Dict[str, Any]]:
|
||||
"""Load test cases from Mind2Web dataset"""
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'mind2web_data/processed.json')
|
||||
logger.info(f'Loading test cases from {file_path}')
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
subset = data[:TEST_SUBSET_SIZE]
|
||||
logger.info(f'Loaded {len(subset)}/{len(data)} test cases')
|
||||
return subset
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def llm():
|
||||
"""Initialize the language model"""
|
||||
return ChatOpenAI(model='gpt-4o')
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
async def controller():
|
||||
"""Initialize the controller"""
|
||||
controller = Controller()
|
||||
try:
|
||||
yield controller
|
||||
finally:
|
||||
if controller.browser:
|
||||
controller.browser.close(force=True)
|
||||
|
||||
|
||||
# run with: pytest -s -v tests/test_mind2web.py:test_random_samples
|
||||
@pytest.mark.asyncio
|
||||
async def test_random_samples(test_cases: List[Dict[str, Any]], llm, controller, validator):
|
||||
"""Test a random sampling of tasks across different websites"""
|
||||
import random
|
||||
|
||||
logger.info('=== Testing Random Samples ===')
|
||||
|
||||
# Take random samples
|
||||
samples = random.sample(test_cases, 1)
|
||||
|
||||
for i, case in enumerate(samples, 1):
|
||||
task = f"Go to {case['website']}.com and {case['confirmed_task']}"
|
||||
logger.info(f'--- Random Sample {i}/{len(samples)} ---')
|
||||
logger.info(f'Task: {task}\n')
|
||||
|
||||
agent = Agent(task, llm, controller)
|
||||
|
||||
await agent.run()
|
||||
|
||||
logger.info('Validating random sample task...')
|
||||
|
||||
# TODO: Validate the task
|
||||
|
||||
|
||||
def test_dataset_integrity(test_cases):
|
||||
"""Test the integrity of the test dataset"""
|
||||
logger.info('\n=== Testing Dataset Integrity ===')
|
||||
|
||||
required_fields = ['website', 'confirmed_task', 'action_reprs']
|
||||
missing_fields = []
|
||||
|
||||
logger.info(f'Checking {len(test_cases)} test cases for required fields')
|
||||
|
||||
for i, case in enumerate(test_cases, 1):
|
||||
logger.debug(f'Checking case {i}/{len(test_cases)}')
|
||||
|
||||
for field in required_fields:
|
||||
if field not in case:
|
||||
missing_fields.append(f'Case {i}: {field}')
|
||||
logger.warning(f"Missing field '{field}' in case {i}")
|
||||
|
||||
# Type checks
|
||||
if not isinstance(case.get('confirmed_task'), str):
|
||||
logger.error(f"Case {i}: 'confirmed_task' must be string")
|
||||
assert False, 'Task must be string'
|
||||
|
||||
if not isinstance(case.get('action_reprs'), list):
|
||||
logger.error(f"Case {i}: 'action_reprs' must be list")
|
||||
assert False, 'Actions must be list'
|
||||
|
||||
if len(case.get('action_reprs', [])) == 0:
|
||||
logger.error(f"Case {i}: 'action_reprs' must not be empty")
|
||||
assert False, 'Must have at least one action'
|
||||
|
||||
if missing_fields:
|
||||
logger.error('Dataset integrity check failed')
|
||||
assert False, f'Missing fields: {missing_fields}'
|
||||
else:
|
||||
logger.info('✅ Dataset integrity check passed')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
176
tests/test_self_registered_actions.py
Normal file
176
tests/test_self_registered_actions.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
"""Initialize the language model"""
|
||||
return ChatOpenAI(model='gpt-4o') # Use appropriate model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def controller():
|
||||
"""Initialize the controller with self-registered actions"""
|
||||
controller = Controller()
|
||||
|
||||
# Define custom actions without Pydantic models
|
||||
@controller.action('Print a message')
|
||||
def print_message(message: str):
|
||||
print(f'Message: {message}')
|
||||
return f'Printed message: {message}'
|
||||
|
||||
@controller.action('Add two numbers')
|
||||
def add_numbers(a: int, b: int):
|
||||
result = a + b
|
||||
return f'The sum is {result}'
|
||||
|
||||
@controller.action('Concatenate strings')
|
||||
def concatenate_strings(str1: str, str2: str):
|
||||
result = str1 + str2
|
||||
return f'Concatenated string: {result}'
|
||||
|
||||
# Define Pydantic models
|
||||
class SimpleModel(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
|
||||
class NestedModel(BaseModel):
|
||||
user: SimpleModel
|
||||
address: Address
|
||||
|
||||
# Add actions with Pydantic model arguments
|
||||
@controller.action('Process simple model', param_model=SimpleModel)
|
||||
def process_simple_model(model: SimpleModel):
|
||||
return f'Processed {model.name}, age {model.age}'
|
||||
|
||||
@controller.action('Process nested model', param_model=NestedModel)
|
||||
def process_nested_model(model: NestedModel):
|
||||
user_info = f'{model.user.name}, age {model.user.age}'
|
||||
address_info = f'{model.address.street}, {model.address.city}'
|
||||
return f'Processed user {user_info} at address {address_info}'
|
||||
|
||||
@controller.action('Process multiple models')
|
||||
def process_multiple_models(model1: SimpleModel, model2: Address):
|
||||
return f'Processed {model1.name} living at {model2.street}, {model2.city}'
|
||||
|
||||
try:
|
||||
yield controller
|
||||
finally:
|
||||
if controller.browser:
|
||||
controller.browser.close(force=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_self_registered_actions_no_pydantic(llm, controller):
|
||||
"""Test self-registered actions with individual arguments"""
|
||||
agent = Agent(
|
||||
task="First, print the message 'Hello, World!'. Then, add 10 and 20. Next, concatenate 'foo' and 'bar'.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=10)
|
||||
# Check that custom actions were executed
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
|
||||
assert 'print_message' in action_names
|
||||
assert 'add_numbers' in action_names
|
||||
assert 'concatenate_strings' in action_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_arguments_actions(llm, controller):
|
||||
"""Test actions with mixed argument types"""
|
||||
|
||||
# Define another action during the test
|
||||
@controller.action('Calculate the area of a rectangle')
|
||||
def calculate_area(length: float, width: float):
|
||||
area = length * width
|
||||
return f'The area is {area}'
|
||||
|
||||
agent = Agent(
|
||||
task='Calculate the area of a rectangle with length 5.5 and width 3.2.',
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=5)
|
||||
|
||||
# Check that the action was executed
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
|
||||
assert 'calculate_area' in action_names
|
||||
# check result
|
||||
correct = 'The area is 17.6'
|
||||
assert correct in [h.result.extracted_content for h in history if h.model_output]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pydantic_simple_model(llm, controller):
|
||||
"""Test action with a simple Pydantic model argument"""
|
||||
agent = Agent(
|
||||
task="Process a simple model with name 'Alice' and age 30.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=5)
|
||||
|
||||
# Check that the action was executed
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
|
||||
assert 'process_simple_model' in action_names
|
||||
correct = 'Processed Alice, age 30'
|
||||
assert correct in [h.result.extracted_content for h in history if h.model_output]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pydantic_nested_model(llm, controller):
|
||||
"""Test action with a nested Pydantic model argument"""
|
||||
agent = Agent(
|
||||
task="Process a nested model with user name 'Bob', age 25, living at '123 Maple St', 'Springfield'.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=5)
|
||||
|
||||
# Check that the action was executed
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
|
||||
assert 'process_nested_model' in action_names
|
||||
correct = 'Processed user Bob, age 25 at address 123 Maple St, Springfield'
|
||||
assert correct in [h.result.extracted_content for h in history if h.model_output]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pydantic_multiple_models(llm, controller):
|
||||
"""Test action with multiple Pydantic model arguments"""
|
||||
agent = Agent(
|
||||
task="Process models with user name 'Carol', age 28, living at '456 Oak Ave', 'Shelbyville'.",
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
history = await agent.run(max_steps=5)
|
||||
|
||||
# Check that the action was executed
|
||||
actions = [h.model_output.action for h in history if h.model_output and h.model_output.action]
|
||||
action_names = [list(action.model_dump(exclude_unset=True).keys())[0] for action in actions]
|
||||
|
||||
assert 'process_multiple_models' in action_names
|
||||
correct = 'Processed Carol living at 456 Oak Ave, Shelbyville'
|
||||
assert correct in [h.result.extracted_content for h in history if h.model_output]
|
||||
|
||||
|
||||
# run this file with:
|
||||
# pytest tests/test_self_registered_actions.py --capture=no
|
||||
48
tests/test_stress.py
Normal file
48
tests/test_stress.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from browser_use.agent.service import Agent
|
||||
from browser_use.controller.service import Controller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
"""Initialize the language model"""
|
||||
return ChatOpenAI(model='gpt-4o') # Use appropriate model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def controller():
|
||||
"""Initialize the controller"""
|
||||
controller = Controller()
|
||||
try:
|
||||
yield controller
|
||||
finally:
|
||||
if controller.browser:
|
||||
controller.browser.close(force=True)
|
||||
|
||||
|
||||
# should get rate limited
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_10_tabs_and_extract_content(llm, controller):
|
||||
"""Stress test: Open 10 tabs and extract content"""
|
||||
agent = Agent(
|
||||
task='Open new tabs with example.com, example.net, example.org, and seven more example sites. Then, extract the content from each.',
|
||||
llm=llm,
|
||||
controller=controller,
|
||||
)
|
||||
start_time = time.time()
|
||||
history = await agent.run(max_steps=50)
|
||||
end_time = time.time()
|
||||
|
||||
total_time = end_time - start_time
|
||||
|
||||
print(f'Total time: {total_time:.2f} seconds')
|
||||
# Check for errors
|
||||
errors = [h.result.error for h in history if h.result and h.result.error]
|
||||
assert len(errors) == 0, 'Errors occurred during the test'
|
||||
# check if 10 tabs were opened
|
||||
assert len(controller.browser.current_state.tabs) >= 10, '10 tabs were not opened'
|
||||
Reference in New Issue
Block a user