Merge pull request #2 from InCoB/activation-of-stop-button-and-cache

feat(custom-agent): Implement stop button functionality for custom mo…
This commit is contained in:
InCoB
2025-01-12 15:27:49 +00:00
committed by GitHub
3 changed files with 262 additions and 82 deletions

View File

@@ -20,9 +20,11 @@ from browser_use.agent.views import (
ActionResult,
AgentHistoryList,
AgentOutput,
AgentHistory,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
AgentEndTelemetryEvent,
@@ -34,6 +36,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
)
from src.utils.agent_state import AgentState
from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
@@ -72,6 +75,7 @@ class CustomAgent(Agent):
max_error_length: int = 400,
max_actions_per_step: int = 10,
tool_call_in_content: bool = True,
agent_state: AgentState = None,
):
super().__init__(
task=task,
@@ -92,6 +96,7 @@ class CustomAgent(Agent):
tool_call_in_content=tool_call_in_content,
)
self.add_infos = add_infos
self.agent_state = agent_state
self.message_manager = CustomMassageManager(
llm=self.llm,
task=self.task,
@@ -367,9 +372,21 @@ class CustomAgent(Agent):
)
for step in range(max_steps):
# 1) Check if stop requested
if self.agent_state and self.agent_state.is_stop_requested():
logger.info("🛑 Stop requested by user")
self._create_stop_history_item()
break
# 2) Store last valid state before step
if self.browser_context and self.agent_state:
state = await self.browser_context.get_state(use_vision=self.use_vision)
self.agent_state.set_last_valid_state(state)
if self._too_many_failures():
break
# 3) Do the step
await self.step(step_info)
if self.history.is_done():
@@ -403,3 +420,61 @@ class CustomAgent(Agent):
if self.generate_gif:
self.create_history_gif()
def _create_stop_history_item(self):
"""Create a history item for when the agent is stopped."""
try:
# Attempt to retrieve the last valid state from agent_state
state = None
if self.agent_state:
last_state = self.agent_state.get_last_valid_state()
if last_state:
# Convert to BrowserStateHistory
state = BrowserStateHistory(
url=getattr(last_state, 'url', ""),
title=getattr(last_state, 'title', ""),
tabs=getattr(last_state, 'tabs', []),
interacted_element=[None],
screenshot=getattr(last_state, 'screenshot', None)
)
else:
state = self._create_empty_state()
else:
state = self._create_empty_state()
# Create a final item in the agent history indicating done
stop_history = AgentHistory(
model_output=None,
state=state,
result=[ActionResult(extracted_content=None, error=None, is_done=True)]
)
self.history.history.append(stop_history)
except Exception as e:
logger.error(f"Error creating stop history item: {e}")
# Create empty state as fallback
state = self._create_empty_state()
stop_history = AgentHistory(
model_output=None,
state=state,
result=[ActionResult(extracted_content=None, error=None, is_done=True)]
)
self.history.history.append(stop_history)
def _convert_to_browser_state_history(self, browser_state):
return BrowserStateHistory(
url=getattr(browser_state, 'url', ""),
title=getattr(browser_state, 'title', ""),
tabs=getattr(browser_state, 'tabs', []),
interacted_element=[None],
screenshot=getattr(browser_state, 'screenshot', None)
)
def _create_empty_state(self):
return BrowserStateHistory(
url="",
title="",
tabs=[],
interacted_element=[None],
screenshot=None
)

30
src/utils/agent_state.py Normal file
View File

@@ -0,0 +1,30 @@
import asyncio
class AgentState:
_instance = None
def __init__(self):
if not hasattr(self, '_stop_requested'):
self._stop_requested = asyncio.Event()
self.last_valid_state = None # store the last valid browser state
def __new__(cls):
if cls._instance is None:
cls._instance = super(AgentState, cls).__new__(cls)
return cls._instance
def request_stop(self):
self._stop_requested.set()
def clear_stop(self):
self._stop_requested.clear()
self.last_valid_state = None
def is_stop_requested(self):
return self._stop_requested.is_set()
def set_last_valid_state(self, state):
self.last_valid_state = state
def get_last_valid_state(self):
return self.last_valid_state

239
webui.py
View File

@@ -6,6 +6,7 @@
# @FileName: webui.py
import pdb
import logging
from dotenv import load_dotenv
@@ -13,6 +14,8 @@ load_dotenv()
import argparse
import os
logger = logging.getLogger(__name__)
import gradio as gr
import argparse
@@ -27,6 +30,7 @@ from browser_use.browser.context import (
BrowserContextWindowSize,
)
from playwright.async_api import async_playwright
from src.utils.agent_state import AgentState
from src.agent.custom_agent import CustomAgent
from src.agent.custom_prompts import CustomSystemPrompt
@@ -45,6 +49,36 @@ from browser_use.browser.context import BrowserContextConfig, BrowserContextWind
_global_browser = None
_global_browser_context = None
# Create the global agent state instance
_global_agent_state = AgentState()
async def stop_agent():
"""Request the agent to stop and update UI with enhanced feedback"""
global _global_agent_state, _global_browser_context, _global_browser
try:
# Request stop
_global_agent_state.request_stop()
# Update UI immediately
message = "Stop requested - the agent will halt at the next safe point"
logger.info(f"🛑 {message}")
# Return UI updates
return (
message, # errors_output
gr.update(value="Stopping...", interactive=False), # stop_button
gr.update(interactive=False), # run_button
)
except Exception as e:
error_msg = f"Error during stop: {str(e)}"
logger.error(error_msg)
return (
error_msg,
gr.update(value="Stop", interactive=True),
gr.update(interactive=True)
)
async def run_browser_agent(
agent_type,
llm_provider,
@@ -68,79 +102,105 @@ async def run_browser_agent(
max_actions_per_step,
tool_call_in_content
):
# Disable recording if the checkbox is unchecked
if not enable_recording:
save_recording_path = None
global _global_agent_state
_global_agent_state.clear_stop() # Clear any previous stop requests
# Ensure the recording directory exists if recording is enabled
if save_recording_path:
os.makedirs(save_recording_path, exist_ok=True)
try:
# Disable recording if the checkbox is unchecked
if not enable_recording:
save_recording_path = None
# Get the list of existing videos before the agent runs
existing_videos = set()
if save_recording_path:
existing_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
# Ensure the recording directory exists if recording is enabled
if save_recording_path:
os.makedirs(save_recording_path, exist_ok=True)
# Get the list of existing videos before the agent runs
existing_videos = set()
if save_recording_path:
existing_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
)
# Run the agent
llm = utils.get_llm_model(
provider=llm_provider,
model_name=llm_model_name,
temperature=llm_temperature,
base_url=llm_base_url,
api_key=llm_api_key,
)
if agent_type == "org":
final_result, errors, model_actions, model_thoughts = await run_org_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
elif agent_type == "custom":
final_result, errors, model_actions, model_thoughts = await run_custom_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
add_infos=add_infos,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
else:
raise ValueError(f"Invalid agent type: {agent_type}")
# Get the list of videos after the agent runs (if recording is enabled)
latest_video = None
if save_recording_path:
new_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
)
if new_videos - existing_videos:
latest_video = list(new_videos - existing_videos)[0] # Get the first new video
return (
final_result,
errors,
model_actions,
model_thoughts,
latest_video,
gr.update(value="Stop", interactive=True), # Re-enable stop button
gr.update(value="Run", interactive=True) # Re-enable run button
)
# Run the agent
llm = utils.get_llm_model(
provider=llm_provider,
model_name=llm_model_name,
temperature=llm_temperature,
base_url=llm_base_url,
api_key=llm_api_key,
)
if agent_type == "org":
final_result, errors, model_actions, model_thoughts = await run_org_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
except Exception as e:
import traceback
traceback.print_exc()
errors = str(e) + "\n" + traceback.format_exc()
return (
'', # final_result
errors, # errors
'', # model_actions
'', # model_thoughts
None, # latest_video
gr.update(value="Stop", interactive=True), # Re-enable stop button
gr.update(value="Run", interactive=True) # Re-enable run button
)
elif agent_type == "custom":
final_result, errors, model_actions, model_thoughts = await run_custom_agent(
llm=llm,
use_own_browser=use_own_browser,
keep_browser_open=keep_browser_open,
headless=headless,
disable_security=disable_security,
window_w=window_w,
window_h=window_h,
save_recording_path=save_recording_path,
save_trace_path=save_trace_path,
task=task,
add_infos=add_infos,
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
)
else:
raise ValueError(f"Invalid agent type: {agent_type}")
# Get the list of videos after the agent runs (if recording is enabled)
latest_video = None
if save_recording_path:
new_videos = set(
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
)
if new_videos - existing_videos:
latest_video = list(new_videos - existing_videos)[0] # Get the first new video
return final_result, errors, model_actions, model_thoughts, latest_video
async def run_org_agent(
@@ -161,7 +221,11 @@ async def run_org_agent(
):
try:
global _global_browser, _global_browser_context
global _global_browser, _global_browser_context, _global_agent_state
# Clear any previous stop request
_global_agent_state.clear_stop()
if use_own_browser:
chrome_path = os.getenv("CHROME_PATH", None)
if chrome_path == "":
@@ -242,7 +306,10 @@ async def run_custom_agent(
tool_call_in_content
):
try:
global _global_browser, _global_browser_context
global _global_browser, _global_browser_context, _global_agent_state
# Clear any previous stop request
_global_agent_state.clear_stop()
if use_own_browser:
chrome_path = os.getenv("CHROME_PATH", None)
@@ -287,7 +354,8 @@ async def run_custom_agent(
controller=controller,
system_prompt_class=CustomSystemPrompt,
max_actions_per_step=max_actions_per_step,
tool_call_in_content=tool_call_in_content
tool_call_in_content=tool_call_in_content,
agent_state=_global_agent_state
)
history = await agent.run(max_steps=max_steps)
@@ -550,6 +618,24 @@ def create_ui(theme_name="Ocean"):
label="Model Thoughts", lines=3, show_label=True
)
# Bind the stop button click event after errors_output is defined
stop_button.click(
fn=stop_agent,
inputs=[],
outputs=[errors_output, stop_button, run_button],
)
# Run button click handler
run_button.click(
fn=run_browser_agent,
inputs=[
agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path,
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
],
outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display, stop_button, run_button],
)
with gr.TabItem("🎥 Recordings", id=6):
def list_recordings(save_recording_path):
if not os.path.exists(save_recording_path):
@@ -601,17 +687,6 @@ def create_ui(theme_name="Ocean"):
use_own_browser.change(fn=close_global_browser)
keep_browser_open.change(fn=close_global_browser)
# Run button click handler
run_button.click(
fn=run_browser_agent,
inputs=[
agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path,
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
],
outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display],
)
return demo
def main():