Merge pull request #2 from InCoB/activation-of-stop-button-and-cache
feat(custom-agent): Implement stop button functionality for custom mo…
This commit is contained in:
@@ -20,9 +20,11 @@ from browser_use.agent.views import (
|
||||
ActionResult,
|
||||
AgentHistoryList,
|
||||
AgentOutput,
|
||||
AgentHistory,
|
||||
)
|
||||
from browser_use.browser.browser import Browser
|
||||
from browser_use.browser.context import BrowserContext
|
||||
from browser_use.browser.views import BrowserStateHistory
|
||||
from browser_use.controller.service import Controller
|
||||
from browser_use.telemetry.views import (
|
||||
AgentEndTelemetryEvent,
|
||||
@@ -34,6 +36,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
)
|
||||
from src.utils.agent_state import AgentState
|
||||
|
||||
from .custom_massage_manager import CustomMassageManager
|
||||
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
|
||||
@@ -72,6 +75,7 @@ class CustomAgent(Agent):
|
||||
max_error_length: int = 400,
|
||||
max_actions_per_step: int = 10,
|
||||
tool_call_in_content: bool = True,
|
||||
agent_state: AgentState = None,
|
||||
):
|
||||
super().__init__(
|
||||
task=task,
|
||||
@@ -92,6 +96,7 @@ class CustomAgent(Agent):
|
||||
tool_call_in_content=tool_call_in_content,
|
||||
)
|
||||
self.add_infos = add_infos
|
||||
self.agent_state = agent_state
|
||||
self.message_manager = CustomMassageManager(
|
||||
llm=self.llm,
|
||||
task=self.task,
|
||||
@@ -367,9 +372,21 @@ class CustomAgent(Agent):
|
||||
)
|
||||
|
||||
for step in range(max_steps):
|
||||
# 1) Check if stop requested
|
||||
if self.agent_state and self.agent_state.is_stop_requested():
|
||||
logger.info("🛑 Stop requested by user")
|
||||
self._create_stop_history_item()
|
||||
break
|
||||
|
||||
# 2) Store last valid state before step
|
||||
if self.browser_context and self.agent_state:
|
||||
state = await self.browser_context.get_state(use_vision=self.use_vision)
|
||||
self.agent_state.set_last_valid_state(state)
|
||||
|
||||
if self._too_many_failures():
|
||||
break
|
||||
|
||||
# 3) Do the step
|
||||
await self.step(step_info)
|
||||
|
||||
if self.history.is_done():
|
||||
@@ -403,3 +420,61 @@ class CustomAgent(Agent):
|
||||
|
||||
if self.generate_gif:
|
||||
self.create_history_gif()
|
||||
|
||||
def _create_stop_history_item(self):
|
||||
"""Create a history item for when the agent is stopped."""
|
||||
try:
|
||||
# Attempt to retrieve the last valid state from agent_state
|
||||
state = None
|
||||
if self.agent_state:
|
||||
last_state = self.agent_state.get_last_valid_state()
|
||||
if last_state:
|
||||
# Convert to BrowserStateHistory
|
||||
state = BrowserStateHistory(
|
||||
url=getattr(last_state, 'url', ""),
|
||||
title=getattr(last_state, 'title', ""),
|
||||
tabs=getattr(last_state, 'tabs', []),
|
||||
interacted_element=[None],
|
||||
screenshot=getattr(last_state, 'screenshot', None)
|
||||
)
|
||||
else:
|
||||
state = self._create_empty_state()
|
||||
else:
|
||||
state = self._create_empty_state()
|
||||
|
||||
# Create a final item in the agent history indicating done
|
||||
stop_history = AgentHistory(
|
||||
model_output=None,
|
||||
state=state,
|
||||
result=[ActionResult(extracted_content=None, error=None, is_done=True)]
|
||||
)
|
||||
self.history.history.append(stop_history)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating stop history item: {e}")
|
||||
# Create empty state as fallback
|
||||
state = self._create_empty_state()
|
||||
stop_history = AgentHistory(
|
||||
model_output=None,
|
||||
state=state,
|
||||
result=[ActionResult(extracted_content=None, error=None, is_done=True)]
|
||||
)
|
||||
self.history.history.append(stop_history)
|
||||
|
||||
def _convert_to_browser_state_history(self, browser_state):
|
||||
return BrowserStateHistory(
|
||||
url=getattr(browser_state, 'url', ""),
|
||||
title=getattr(browser_state, 'title', ""),
|
||||
tabs=getattr(browser_state, 'tabs', []),
|
||||
interacted_element=[None],
|
||||
screenshot=getattr(browser_state, 'screenshot', None)
|
||||
)
|
||||
|
||||
def _create_empty_state(self):
|
||||
return BrowserStateHistory(
|
||||
url="",
|
||||
title="",
|
||||
tabs=[],
|
||||
interacted_element=[None],
|
||||
screenshot=None
|
||||
)
|
||||
|
||||
30
src/utils/agent_state.py
Normal file
30
src/utils/agent_state.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import asyncio
|
||||
|
||||
class AgentState:
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_stop_requested'):
|
||||
self._stop_requested = asyncio.Event()
|
||||
self.last_valid_state = None # store the last valid browser state
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(AgentState, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def request_stop(self):
|
||||
self._stop_requested.set()
|
||||
|
||||
def clear_stop(self):
|
||||
self._stop_requested.clear()
|
||||
self.last_valid_state = None
|
||||
|
||||
def is_stop_requested(self):
|
||||
return self._stop_requested.is_set()
|
||||
|
||||
def set_last_valid_state(self, state):
|
||||
self.last_valid_state = state
|
||||
|
||||
def get_last_valid_state(self):
|
||||
return self.last_valid_state
|
||||
239
webui.py
239
webui.py
@@ -6,6 +6,7 @@
|
||||
# @FileName: webui.py
|
||||
|
||||
import pdb
|
||||
import logging
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -13,6 +14,8 @@ load_dotenv()
|
||||
import argparse
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import gradio as gr
|
||||
import argparse
|
||||
|
||||
@@ -27,6 +30,7 @@ from browser_use.browser.context import (
|
||||
BrowserContextWindowSize,
|
||||
)
|
||||
from playwright.async_api import async_playwright
|
||||
from src.utils.agent_state import AgentState
|
||||
|
||||
from src.agent.custom_agent import CustomAgent
|
||||
from src.agent.custom_prompts import CustomSystemPrompt
|
||||
@@ -45,6 +49,36 @@ from browser_use.browser.context import BrowserContextConfig, BrowserContextWind
|
||||
_global_browser = None
|
||||
_global_browser_context = None
|
||||
|
||||
# Create the global agent state instance
|
||||
_global_agent_state = AgentState()
|
||||
|
||||
async def stop_agent():
|
||||
"""Request the agent to stop and update UI with enhanced feedback"""
|
||||
global _global_agent_state, _global_browser_context, _global_browser
|
||||
|
||||
try:
|
||||
# Request stop
|
||||
_global_agent_state.request_stop()
|
||||
|
||||
# Update UI immediately
|
||||
message = "Stop requested - the agent will halt at the next safe point"
|
||||
logger.info(f"🛑 {message}")
|
||||
|
||||
# Return UI updates
|
||||
return (
|
||||
message, # errors_output
|
||||
gr.update(value="Stopping...", interactive=False), # stop_button
|
||||
gr.update(interactive=False), # run_button
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error during stop: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return (
|
||||
error_msg,
|
||||
gr.update(value="Stop", interactive=True),
|
||||
gr.update(interactive=True)
|
||||
)
|
||||
|
||||
async def run_browser_agent(
|
||||
agent_type,
|
||||
llm_provider,
|
||||
@@ -68,79 +102,105 @@ async def run_browser_agent(
|
||||
max_actions_per_step,
|
||||
tool_call_in_content
|
||||
):
|
||||
# Disable recording if the checkbox is unchecked
|
||||
if not enable_recording:
|
||||
save_recording_path = None
|
||||
global _global_agent_state
|
||||
_global_agent_state.clear_stop() # Clear any previous stop requests
|
||||
|
||||
# Ensure the recording directory exists if recording is enabled
|
||||
if save_recording_path:
|
||||
os.makedirs(save_recording_path, exist_ok=True)
|
||||
try:
|
||||
# Disable recording if the checkbox is unchecked
|
||||
if not enable_recording:
|
||||
save_recording_path = None
|
||||
|
||||
# Get the list of existing videos before the agent runs
|
||||
existing_videos = set()
|
||||
if save_recording_path:
|
||||
existing_videos = set(
|
||||
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
|
||||
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
|
||||
# Ensure the recording directory exists if recording is enabled
|
||||
if save_recording_path:
|
||||
os.makedirs(save_recording_path, exist_ok=True)
|
||||
|
||||
# Get the list of existing videos before the agent runs
|
||||
existing_videos = set()
|
||||
if save_recording_path:
|
||||
existing_videos = set(
|
||||
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
|
||||
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
|
||||
)
|
||||
|
||||
# Run the agent
|
||||
llm = utils.get_llm_model(
|
||||
provider=llm_provider,
|
||||
model_name=llm_model_name,
|
||||
temperature=llm_temperature,
|
||||
base_url=llm_base_url,
|
||||
api_key=llm_api_key,
|
||||
)
|
||||
if agent_type == "org":
|
||||
final_result, errors, model_actions, model_thoughts = await run_org_agent(
|
||||
llm=llm,
|
||||
use_own_browser=use_own_browser,
|
||||
keep_browser_open=keep_browser_open,
|
||||
headless=headless,
|
||||
disable_security=disable_security,
|
||||
window_w=window_w,
|
||||
window_h=window_h,
|
||||
save_recording_path=save_recording_path,
|
||||
save_trace_path=save_trace_path,
|
||||
task=task,
|
||||
max_steps=max_steps,
|
||||
use_vision=use_vision,
|
||||
max_actions_per_step=max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content
|
||||
)
|
||||
elif agent_type == "custom":
|
||||
final_result, errors, model_actions, model_thoughts = await run_custom_agent(
|
||||
llm=llm,
|
||||
use_own_browser=use_own_browser,
|
||||
keep_browser_open=keep_browser_open,
|
||||
headless=headless,
|
||||
disable_security=disable_security,
|
||||
window_w=window_w,
|
||||
window_h=window_h,
|
||||
save_recording_path=save_recording_path,
|
||||
save_trace_path=save_trace_path,
|
||||
task=task,
|
||||
add_infos=add_infos,
|
||||
max_steps=max_steps,
|
||||
use_vision=use_vision,
|
||||
max_actions_per_step=max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {agent_type}")
|
||||
|
||||
# Get the list of videos after the agent runs (if recording is enabled)
|
||||
latest_video = None
|
||||
if save_recording_path:
|
||||
new_videos = set(
|
||||
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
|
||||
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
|
||||
)
|
||||
if new_videos - existing_videos:
|
||||
latest_video = list(new_videos - existing_videos)[0] # Get the first new video
|
||||
|
||||
return (
|
||||
final_result,
|
||||
errors,
|
||||
model_actions,
|
||||
model_thoughts,
|
||||
latest_video,
|
||||
gr.update(value="Stop", interactive=True), # Re-enable stop button
|
||||
gr.update(value="Run", interactive=True) # Re-enable run button
|
||||
)
|
||||
|
||||
# Run the agent
|
||||
llm = utils.get_llm_model(
|
||||
provider=llm_provider,
|
||||
model_name=llm_model_name,
|
||||
temperature=llm_temperature,
|
||||
base_url=llm_base_url,
|
||||
api_key=llm_api_key,
|
||||
)
|
||||
if agent_type == "org":
|
||||
final_result, errors, model_actions, model_thoughts = await run_org_agent(
|
||||
llm=llm,
|
||||
use_own_browser=use_own_browser,
|
||||
keep_browser_open=keep_browser_open,
|
||||
headless=headless,
|
||||
disable_security=disable_security,
|
||||
window_w=window_w,
|
||||
window_h=window_h,
|
||||
save_recording_path=save_recording_path,
|
||||
save_trace_path=save_trace_path,
|
||||
task=task,
|
||||
max_steps=max_steps,
|
||||
use_vision=use_vision,
|
||||
max_actions_per_step=max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
errors = str(e) + "\n" + traceback.format_exc()
|
||||
return (
|
||||
'', # final_result
|
||||
errors, # errors
|
||||
'', # model_actions
|
||||
'', # model_thoughts
|
||||
None, # latest_video
|
||||
gr.update(value="Stop", interactive=True), # Re-enable stop button
|
||||
gr.update(value="Run", interactive=True) # Re-enable run button
|
||||
)
|
||||
elif agent_type == "custom":
|
||||
final_result, errors, model_actions, model_thoughts = await run_custom_agent(
|
||||
llm=llm,
|
||||
use_own_browser=use_own_browser,
|
||||
keep_browser_open=keep_browser_open,
|
||||
headless=headless,
|
||||
disable_security=disable_security,
|
||||
window_w=window_w,
|
||||
window_h=window_h,
|
||||
save_recording_path=save_recording_path,
|
||||
save_trace_path=save_trace_path,
|
||||
task=task,
|
||||
add_infos=add_infos,
|
||||
max_steps=max_steps,
|
||||
use_vision=use_vision,
|
||||
max_actions_per_step=max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {agent_type}")
|
||||
|
||||
# Get the list of videos after the agent runs (if recording is enabled)
|
||||
latest_video = None
|
||||
if save_recording_path:
|
||||
new_videos = set(
|
||||
glob.glob(os.path.join(save_recording_path, "*.[mM][pP]4"))
|
||||
+ glob.glob(os.path.join(save_recording_path, "*.[wW][eE][bB][mM]"))
|
||||
)
|
||||
if new_videos - existing_videos:
|
||||
latest_video = list(new_videos - existing_videos)[0] # Get the first new video
|
||||
|
||||
return final_result, errors, model_actions, model_thoughts, latest_video
|
||||
|
||||
|
||||
async def run_org_agent(
|
||||
@@ -161,7 +221,11 @@ async def run_org_agent(
|
||||
|
||||
):
|
||||
try:
|
||||
global _global_browser, _global_browser_context
|
||||
global _global_browser, _global_browser_context, _global_agent_state
|
||||
|
||||
# Clear any previous stop request
|
||||
_global_agent_state.clear_stop()
|
||||
|
||||
if use_own_browser:
|
||||
chrome_path = os.getenv("CHROME_PATH", None)
|
||||
if chrome_path == "":
|
||||
@@ -242,7 +306,10 @@ async def run_custom_agent(
|
||||
tool_call_in_content
|
||||
):
|
||||
try:
|
||||
global _global_browser, _global_browser_context
|
||||
global _global_browser, _global_browser_context, _global_agent_state
|
||||
|
||||
# Clear any previous stop request
|
||||
_global_agent_state.clear_stop()
|
||||
|
||||
if use_own_browser:
|
||||
chrome_path = os.getenv("CHROME_PATH", None)
|
||||
@@ -287,7 +354,8 @@ async def run_custom_agent(
|
||||
controller=controller,
|
||||
system_prompt_class=CustomSystemPrompt,
|
||||
max_actions_per_step=max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content
|
||||
tool_call_in_content=tool_call_in_content,
|
||||
agent_state=_global_agent_state
|
||||
)
|
||||
history = await agent.run(max_steps=max_steps)
|
||||
|
||||
@@ -550,6 +618,24 @@ def create_ui(theme_name="Ocean"):
|
||||
label="Model Thoughts", lines=3, show_label=True
|
||||
)
|
||||
|
||||
# Bind the stop button click event after errors_output is defined
|
||||
stop_button.click(
|
||||
fn=stop_agent,
|
||||
inputs=[],
|
||||
outputs=[errors_output, stop_button, run_button],
|
||||
)
|
||||
|
||||
# Run button click handler
|
||||
run_button.click(
|
||||
fn=run_browser_agent,
|
||||
inputs=[
|
||||
agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
|
||||
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path,
|
||||
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
|
||||
],
|
||||
outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display, stop_button, run_button],
|
||||
)
|
||||
|
||||
with gr.TabItem("🎥 Recordings", id=6):
|
||||
def list_recordings(save_recording_path):
|
||||
if not os.path.exists(save_recording_path):
|
||||
@@ -601,17 +687,6 @@ def create_ui(theme_name="Ocean"):
|
||||
use_own_browser.change(fn=close_global_browser)
|
||||
keep_browser_open.change(fn=close_global_browser)
|
||||
|
||||
# Run button click handler
|
||||
run_button.click(
|
||||
fn=run_browser_agent,
|
||||
inputs=[
|
||||
agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
|
||||
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path,
|
||||
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content
|
||||
],
|
||||
outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display],
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user