merge new version

This commit is contained in:
katiue
2025-01-13 14:10:33 +07:00
4 changed files with 305 additions and 104 deletions

View File

@@ -20,9 +20,11 @@ from browser_use.agent.views import (
ActionResult,
AgentHistoryList,
AgentOutput,
AgentHistory,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
AgentEndTelemetryEvent,
@@ -34,6 +36,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
)
from src.utils.agent_state import AgentState
from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
@@ -72,6 +75,7 @@ class CustomAgent(Agent):
max_error_length: int = 400,
max_actions_per_step: int = 10,
tool_call_in_content: bool = True,
agent_state: AgentState = None,
):
super().__init__(
task=task,
@@ -92,6 +96,7 @@ class CustomAgent(Agent):
tool_call_in_content=tool_call_in_content,
)
self.add_infos = add_infos
self.agent_state = agent_state
self.message_manager = CustomMassageManager(
llm=self.llm,
task=self.task,
@@ -367,9 +372,21 @@ class CustomAgent(Agent):
)
for step in range(max_steps):
# 1) Check if stop requested
if self.agent_state and self.agent_state.is_stop_requested():
logger.info("🛑 Stop requested by user")
self._create_stop_history_item()
break
# 2) Store last valid state before step
if self.browser_context and self.agent_state:
state = await self.browser_context.get_state(use_vision=self.use_vision)
self.agent_state.set_last_valid_state(state)
if self._too_many_failures():
break
# 3) Do the step
await self.step(step_info)
if self.history.is_done():
@@ -403,3 +420,61 @@ class CustomAgent(Agent):
if self.generate_gif:
self.create_history_gif()
def _create_stop_history_item(self):
    """Append a final history entry recording that the agent was stopped.

    The entry's state is the last browser state stored on
    ``self.agent_state`` (converted to a ``BrowserStateHistory``) when one
    is available, and an empty placeholder state otherwise. Its single
    ``ActionResult`` is flagged ``is_done=True``.

    If anything raises while the entry is being built, the error is logged
    and the entry is rebuilt with the empty placeholder state instead.
    """
    try:
        # agent_state is optional (defaults to None in __init__).
        last_state = (
            self.agent_state.get_last_valid_state() if self.agent_state else None
        )
        if last_state:
            # Reuse the shared converter rather than duplicating its field list.
            state = self._convert_to_browser_state_history(last_state)
        else:
            state = self._create_empty_state()
        self._append_stop_history(state)
    except Exception as e:
        logger.error(f"Error creating stop history item: {e}")
        # Fall back to an empty state so a stop entry is still recorded.
        self._append_stop_history(self._create_empty_state())


def _append_stop_history(self, state):
    """Append an ``AgentHistory`` with no model output and a done result."""
    self.history.history.append(
        AgentHistory(
            model_output=None,
            state=state,
            result=[ActionResult(extracted_content=None, error=None, is_done=True)],
        )
    )
def _convert_to_browser_state_history(self, browser_state):
    """Build a ``BrowserStateHistory`` from the attributes of ``browser_state``.

    Attributes missing on ``browser_state`` fall back to defaults: an empty
    string for ``url`` and ``title``, an empty list for ``tabs`` and ``None``
    for ``screenshot``. ``interacted_element`` is always ``[None]``.
    """
    def read(attr, fallback):
        # Tolerate objects that lack some of the expected attributes.
        return getattr(browser_state, attr, fallback)

    page_url = read('url', "")
    page_title = read('title', "")
    open_tabs = read('tabs', [])
    shot = read('screenshot', None)
    return BrowserStateHistory(
        url=page_url,
        title=page_title,
        tabs=open_tabs,
        interacted_element=[None],
        screenshot=shot,
    )
def _create_empty_state(self):
    """Return a placeholder ``BrowserStateHistory`` carrying no page data."""
    blank_fields = {
        "url": "",
        "title": "",
        "tabs": [],
        "interacted_element": [None],
        "screenshot": None,
    }
    return BrowserStateHistory(**blank_fields)

30
src/utils/agent_state.py Normal file
View File

@@ -0,0 +1,30 @@
import asyncio
class AgentState:
    """Singleton holding a stop flag and the last stored browser state.

    Every ``AgentState()`` call returns the same shared instance.
    """

    _instance = None

    def __new__(cls):
        # Build the shared instance on first use; hand it back afterwards.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every AgentState() call, so initialise only once.
        if hasattr(self, '_stop_requested'):
            return
        self._stop_requested = asyncio.Event()
        self.last_valid_state = None  # last browser state handed to set_last_valid_state

    def request_stop(self):
        """Raise the stop flag."""
        self._stop_requested.set()

    def clear_stop(self):
        """Lower the stop flag and forget the stored browser state."""
        self._stop_requested.clear()
        self.last_valid_state = None

    def is_stop_requested(self):
        """Return True while the stop flag is raised."""
        return self._stop_requested.is_set()

    def set_last_valid_state(self, state):
        """Store ``state`` as the last valid browser state."""
        self.last_valid_state = state

    def get_last_valid_state(self):
        """Return the stored browser state, or None if none is stored."""
        return self.last_valid_state

View File

@@ -196,7 +196,7 @@ async def capture_screenshot(browser_context) -> str:
encoded = base64.b64encode(screenshot).decode('utf-8')
return f'<img src="data:image/jpeg;base64,{encoded}" style="width:80vw; height:90vh ; border:1px solid #ccc;">'
except Exception as e:
return f"<div class='error' style='width:80vw; height:90vh'>Waiting for browser session...</div>"
return f"<h1 class='error' style='width:80vw; height:90vh'>Waiting for browser session...</h1>"
except Exception as e:
return f"<div class='error' style='width:80vw; height:90vh'>Waiting for browser session...</div>"
return f"<h1 class='error' style='width:80vw; height:90vh'>Waiting for browser session...</h1>"

300
webui.py
View File

@@ -5,10 +5,20 @@
# @Project : browser-use-webui
# @FileName: webui.py
import pdb
import logging
from dotenv import load_dotenv
load_dotenv()
import os
import glob
import asyncio
import argparse
import os
logger = logging.getLogger(__name__)
import gradio as gr
from browser_use.agent.service import Agent
@@ -18,6 +28,8 @@ from browser_use.browser.context import (
BrowserContextConfig,
BrowserContextWindowSize,
)
from playwright.async_api import async_playwright
from src.utils.agent_state import AgentState
from src.utils import utils
from src.agent.custom_agent import CustomAgent
@@ -36,6 +48,36 @@ load_dotenv()
_global_browser = None
_global_browser_context = None
# Create the global agent state instance
_global_agent_state = AgentState()
async def stop_agent():
    """Ask the running agent to stop and report the request to the UI.

    Returns:
        A 3-tuple of (status or error text, stop-button update, run-button
        update), matching the ``outputs`` list the stop button is bound to.
    """
    try:
        # Raise the shared stop flag.
        _global_agent_state.request_stop()

        message = "Stop requested - the agent will halt at the next safe point"
        logger.info(f"🛑 {message}")

        # Disable both buttons while the stop request is pending.
        stop_button_update = gr.update(value="Stopping...", interactive=False)
        run_button_update = gr.update(interactive=False)
        ui_updates = (message, stop_button_update, run_button_update)
    except Exception as e:
        error_msg = f"Error during stop: {str(e)}"
        logger.error(error_msg)
        # Restore the buttons so the user can try again.
        ui_updates = (
            error_msg,
            gr.update(value="Stop", interactive=True),
            gr.update(interactive=True),
        )
    return ui_updates
async def run_browser_agent(
agent_type,
llm_provider,
@@ -59,79 +101,108 @@ async def run_browser_agent(
max_actions_per_step,
tool_call_in_content
):
# Disable recording if the checkbox is unchecked
if not enable_recording:
save_recording_path = None
global _global_agent_state
_global_agent_state.clear_stop() # Clear any previous stop requests
# Ensure the recording directory exists if recording is enabled
if save_recording_path:
os.makedirs(save_recording_path, exist_ok=True)
try:
# Disable recording if the checkbox is unchecked
if not enable_recording:
save_recording_path = None
# Get the list of existing videos before the agent runs
existing_videos = set()
if save_recording_path:
existing_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
# Ensure the recording directory exists if recording is enabled
if save_recording_path:
os.makedirs(save_recording_path, exist_ok=True)
# Get the list of existing videos before the agent runs
existing_videos = set()
if save_recording_path:
existing_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
)
# Run the agent
llm = utils.get_llm_model(
provider=llm_provider,
model_name=llm_model_name,
temperature=llm_temperature,
base_url=llm_base_url,
api_key=llm_api_key,
)
if agent_type == "org":
final_result, errors, model_actions, model_thoughts, trace_file = await run_org_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
elif agent_type == "custom":
final_result, errors, model_actions, model_thoughts, trace_file = await run_custom_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
add_infos=add_infos,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
else:
raise ValueError(f"Invalid agent type: {agent_type}")
# Get the list of videos after the agent runs (if recording is enabled)
latest_video = None
if save_recording_path:
new_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
)
if new_videos - existing_videos:
latest_video = list(new_videos - existing_videos)[0] # Get the first new video
return (
final_result,
errors,
model_actions,
model_thoughts,
latest_video,
trace_file,
gr.update(value="Stop", interactive=True), # Re-enable stop button
gr.update(value="Run", interactive=True) # Re-enable run button
)
# Run the agent
llm = utils.get_llm_model(
provider=llm_provider,
model_name=llm_model_name,
temperature=llm_temperature,
base_url=llm_base_url,
api_key=llm_api_key,
)
if agent_type == "org":
final_result, errors, model_actions, model_thoughts, recorded_files, trace_file = await run_org_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
except Exception as e:
import traceback
traceback.print_exc()
errors = str(e) + "\n" + traceback.format_exc()
return (
'', # final_result
errors, # errors
'', # model_actions
'', # model_thoughts
None, # latest_video
None, # trace_file
gr.update(value="Stop", interactive=True), # Re-enable stop button
gr.update(value="Run", interactive=True) # Re-enable run button
)
elif agent_type == "custom":
final_result, errors, model_actions, model_thoughts, recorded_files, trace_file = await run_custom_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
add_infos=add_infos,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
else:
raise ValueError(f"Invalid agent type: {agent_type}")
# Get the list of videos after the agent runs (if recording is enabled)
latest_video = None
if save_recording_path:
new_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
)
if new_videos - existing_videos:
latest_video = list(new_videos - existing_videos)[0] # Get the first new video
return final_result, errors, model_actions, model_thoughts, latest_video, trace_file
async def run_org_agent(
llm,
@@ -150,7 +221,11 @@ async def run_org_agent(
tool_call_in_content
):
try:
global _global_browser, _global_browser_context
global _global_browser, _global_browser_context, _global_agent_state
# Clear any previous stop request
_global_agent_state.clear_stop()
if use_own_browser:
chrome_path = os.getenv("CHROME_PATH", None)
if chrome_path == "":
@@ -196,15 +271,14 @@ async def run_org_agent(
model_actions = history.model_actions()
model_thoughts = history.model_thoughts()
recorded_files = get_latest_files(save_recording_path)
trace_file = get_latest_files(save_trace_path)
return final_result, errors, model_actions, model_thoughts, recorded_files.get('.webm'), trace_file.get('.zip')
return final_result, errors, model_actions, model_thoughts, trace_file.get('.zip')
except Exception as e:
import traceback
traceback.print_exc()
errors = str(e) + "\n" + traceback.format_exc()
return '', errors, '', '', None, None
return '', errors, '', '', None
finally:
# Handle cleanup based on persistence configuration
if not keep_browser_open:
@@ -234,7 +308,10 @@ async def run_custom_agent(
tool_call_in_content
):
try:
global _global_browser, _global_browser_context
global _global_browser, _global_browser_context, _global_agent_state
# Clear any previous stop request
_global_agent_state.clear_stop()
if use_own_browser:
chrome_path = os.getenv("CHROME_PATH", None)
@@ -279,7 +356,8 @@ async def run_custom_agent(
controller=controller,
system_prompt_class=CustomSystemPrompt,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
tool_call_in_content=tool_call_in_content,
agent_state=_global_agent_state
)
history = await agent.run(max_steps=max_steps)
@@ -288,15 +366,14 @@ async def run_custom_agent(
model_actions = history.model_actions()
model_thoughts = history.model_thoughts()
recorded_files = get_latest_files(save_recording_path)
trace_file = get_latest_files(save_trace_path)
return final_result, errors, model_actions, model_thoughts, recorded_files.get('.webm'), trace_file.get('.zip')
return final_result, errors, model_actions, model_thoughts, trace_file.get('.zip')
except Exception as e:
import traceback
traceback.print_exc()
errors = str(e) + "\n" + traceback.format_exc()
return '', errors, '', '', None, None
return '', errors, '', '', None
finally:
# Handle cleanup based on persistence configuration
if not keep_browser_open:
@@ -355,7 +432,9 @@ async def run_with_stream(
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
yield result
# Add HTML content at the start of the result array
html_content = "<h1 style='width:80vw; height:90vh'>Using browser...</h1>"
yield [html_content] + list(result)
else:
try:
# Run the browser agent in the background
@@ -386,16 +465,17 @@ async def run_with_stream(
)
# Initialize values for streaming
html_content = "<div style='width:80vw; height:90vh'>Using browser...</div>"
html_content = "<h1 style='width:80vw; height:90vh'>Using browser...</h1>"
final_result = errors = model_actions = model_thoughts = ""
latest_videos = trace = None
# Periodically update the stream while the agent task is running
while not agent_task.done():
try:
html_content = await capture_screenshot(_global_browser_context)
except Exception as e:
html_content = f"<div style='width:80vw; height:90vh'>Waiting for browser session...</div>"
html_content = f"<h1 style='width:80vw; height:90vh'>Waiting for browser session...</h1>"
yield [
html_content,
@@ -405,14 +485,16 @@ async def run_with_stream(
model_thoughts,
latest_videos,
trace,
gr.update(value="Stop", interactive=True), # Re-enable stop button
gr.update(value="Run", interactive=True) # Re-enable run button
]
await asyncio.sleep(0.01)
# Once the agent task completes, get the results
try:
result = await agent_task
if isinstance(result, tuple) and len(result) == 6:
final_result, errors, model_actions, model_thoughts, latest_videos, trace = result
if isinstance(result, tuple) and len(result) == 8:
final_result, errors, model_actions, model_thoughts, latest_videos, trace, stop_button, run_button = result
else:
errors = "Unexpected result format from agent"
except Exception as e:
@@ -426,18 +508,22 @@ async def run_with_stream(
model_thoughts,
latest_videos,
trace,
stop_button,
run_button
]
except Exception as e:
import traceback
yield [
f"<div style='width:80vw; height:90vh'>Waiting for browser session...</div>",
f"<h1 style='width:80vw; height:90vh'>Waiting for browser session...</h1>",
"",
f"Error: {str(e)}\n{traceback.format_exc()}",
"",
"",
None,
None,
gr.update(value="Stop", interactive=True), # Re-enable stop button
gr.update(value="Run", interactive=True) # Re-enable run button
]
# Define the theme map globally
@@ -654,7 +740,7 @@ def create_ui(theme_name="Ocean"):
with gr.Row():
browser_view = gr.HTML(
value="<div style='width:80vw; height:90vh'>Waiting for browser session...</div>",
value="<h1 style='width:80vw; height:90vh'>Waiting for browser session...</h1>",
label="Live Browser View",
)
@@ -684,6 +770,35 @@ def create_ui(theme_name="Ocean"):
)
trace_file = gr.File(label="Trace File")
# Bind the stop button click event after errors_output is defined
stop_button.click(
fn=stop_agent,
inputs=[],
outputs=[errors_output, stop_button, run_button],
)
# Run button click handler
run_button.click(
fn=run_with_stream,
inputs=[
agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path,
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
],
outputs=[
browser_view, # Browser view
final_result_output, # Final result
errors_output, # Errors
model_actions_output, # Model actions
model_thoughts_output, # Model thoughts
recording_display, # Latest recording
trace_file, # Trace file
stop_button, # Stop button
run_button # Run button
],
)
with gr.TabItem("🎥 Recordings", id=6):
def list_recordings(save_recording_path):
if not os.path.exists(save_recording_path):
@@ -734,25 +849,6 @@ def create_ui(theme_name="Ocean"):
use_own_browser.change(fn=close_global_browser)
keep_browser_open.change(fn=close_global_browser)
# Run button click handler
run_button.click(
fn=run_with_stream,
inputs=[
agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path,
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
],
outputs=[
browser_view, # HTML view
final_result_output, # Final result
errors_output, # Errors
model_actions_output, # Model actions
model_thoughts_output, # Model thoughts
recording_display, # Video file (.webm)
trace_file # Trace file (.zip)
],
queue=True,
)
return demo