resolve to merge with new version
@@ -85,7 +85,6 @@ class CustomAgent(Agent):
             include_attributes=include_attributes,
             max_error_length=max_error_length,
             max_actions_per_step=max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
         )
         self.add_infos = add_infos
         self.message_manager = CustomMassageManager(
@@ -156,7 +155,7 @@ class CustomAgent(Agent):
             parsed: AgentOutput = response['parsed']
             # cut the number of actions to max_actions_per_step
             parsed.action = parsed.action[: self.max_actions_per_step]
-            self._log_response(parsed)
+            self._log_response(parsed)  # type: ignore
             self.n_steps += 1

             return parsed
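
The slice above caps the actions per step without a length check: Python slicing simply returns the whole list when it is shorter than the bound. A minimal sketch of that behavior (action names hypothetical):

    actions = ["click", "input_text", "scroll"]
    assert actions[:10] == actions                  # shorter than the cap: unchanged
    assert actions[:2] == ["click", "input_text"]   # longer than the cap: truncated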
@@ -165,7 +164,7 @@ class CustomAgent(Agent):
             # and manually parse the response. Temporary solution for DeepSeek
             ret = self.llm.invoke(input_messages)
             if isinstance(ret.content, list):
-                parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
+                parsed_json = json.loads(str(ret.content[0]).replace("```json", "").replace("```", ""))
             else:
                 parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
             parsed: AgentOutput = self.AgentOutput(**parsed_json)
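
The manual path exists because DeepSeek apparently returns the JSON wrapped in a markdown code fence, so the fence markers are stripped before json.loads. A minimal standalone sketch of the same idea, with a made-up response string:

    import json

    raw = '```json\n{"current_state": {}, "action": []}\n```'
    # drop the fence markers, then parse what remains
    parsed_json = json.loads(raw.replace("```json", "").replace("```", ""))
    assert parsed_json["action"] == []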
@@ -193,7 +192,7 @@ class CustomAgent(Agent):
             input_messages = self.message_manager.get_messages()
             model_output = await self.get_next_action(input_messages)
             if step_info is not None:
-                self.update_step_info(model_output, step_info)
+                self.update_step_info(model_output=CustomAgentOutput(**model_output.dict()), step_info=step_info)
                 logger.info(f'🧠 All Memory: {step_info.memory}')
             self._save_conversation(input_messages, model_output)
             self.message_manager._remove_last_state_message()  # we don't want the whole state in the chat history
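
Re-wrapping model_output as CustomAgentOutput(**model_output.dict()) re-validates the same field data against the subclass, so update_step_info receives the exact type it annotates. The pattern, sketched with hypothetical pydantic models:

    from pydantic import BaseModel

    class AgentOutput(BaseModel):
        memory: str = ""

    class CustomAgentOutput(AgentOutput):
        pass  # custom fields would live here

    base = AgentOutput(memory="visited page 1")
    custom = CustomAgentOutput(**base.dict())  # same data, subclass instance
    assert isinstance(custom, CustomAgentOutput)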
@@ -41,6 +41,7 @@ class CustomMassageManager(MessageManager):
         max_actions_per_step: int = 10,
         tool_call_in_content: bool = False,
     ):
+        self.tool_call_in_content = tool_call_in_content
         super().__init__(
             llm=llm,
             task=task,
@@ -52,13 +53,17 @@ class CustomMassageManager(MessageManager):
             include_attributes=include_attributes,
             max_error_length=max_error_length,
             max_actions_per_step=max_actions_per_step,
             tool_call_in_content=tool_call_in_content,
         )

         # Custom: Move Task info to state_message
         self.history = MessageHistory()
         self._add_message_with_tokens(self.system_prompt)
-        tool_calls = [
+        tool_calls = self._create_tool_calls()
+        example_tool_call = self._create_example_tool_call(tool_calls)
+        self._add_message_with_tokens(example_tool_call)
+
+    def _create_tool_calls(self):
+        return [
             {
                 'name': 'AgentOutput',
                 'args': {
@@ -73,20 +78,20 @@ class CustomMassageManager(MessageManager):
                     'type': 'tool_call',
                 }
             ]

+    def _create_example_tool_call(self, tool_calls):
         if self.tool_call_in_content:
             # openai throws error if tool_calls are not responded -> move to content
-            example_tool_call = AIMessage(
+            return AIMessage(
                 content=f'{tool_calls}',
                 tool_calls=[],
             )
         else:
-            example_tool_call = AIMessage(
+            return AIMessage(
                 content=f'',
                 tool_calls=tool_calls,
             )

-        self._add_message_with_tokens(example_tool_call)

     def add_state_message(
         self,
         state: BrowserState,
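
The two helpers split out above keep one decision in one place: when tool_call_in_content is set, the example tool call is serialized into the message text (for providers that reject an unanswered tool_calls field, per the comment in the hunk); otherwise it rides in the structured tool_calls list. A minimal sketch of the two shapes with langchain_core:

    from langchain_core.messages import AIMessage

    tool_calls = [{'name': 'AgentOutput', 'args': {}, 'id': '1', 'type': 'tool_call'}]

    in_content = AIMessage(content=f'{tool_calls}', tool_calls=[])  # call embedded as text
    structured = AIMessage(content='', tool_calls=tool_calls)       # call as structured field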
@@ -22,7 +22,6 @@ def get_llm_model(provider: str, **kwargs):
     :param kwargs:
     :return:
     """
-    if provider == "anthropic":
     if provider == "anthropic":
         if not kwargs.get("base_url", ""):
             base_url = "https://api.anthropic.com"
@@ -35,7 +34,6 @@ def get_llm_model(provider: str, **kwargs):
             api_key = kwargs.get("api_key")

         return ChatAnthropic(
-            model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
             model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
@@ -43,7 +41,6 @@ def get_llm_model(provider: str, **kwargs):
             timeout=kwargs.get("timeout", 60),
             stop=kwargs.get("stop", None),
         )
-    elif provider == "openai":
     elif provider == "openai":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
@@ -56,13 +53,11 @@ def get_llm_model(provider: str, **kwargs):
             api_key = kwargs.get("api_key")

         return ChatOpenAI(
-            model=kwargs.get("model_name", "gpt-4o"),
             model=kwargs.get("model_name", "gpt-4o"),
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=SecretStr(api_key or ""),
         )
-    elif provider == "deepseek":
     elif provider == "deepseek":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
@@ -75,28 +70,23 @@ def get_llm_model(provider: str, **kwargs):
             api_key = kwargs.get("api_key")

         return ChatOpenAI(
-            model=kwargs.get("model_name", "deepseek-chat"),
             model=kwargs.get("model_name", "deepseek-chat"),
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=SecretStr(api_key or ""),
         )
-    elif provider == "gemini":
     elif provider == "gemini":
         if not kwargs.get("api_key", ""):
             api_key = os.getenv("GOOGLE_API_KEY", "")
         else:
             api_key = kwargs.get("api_key")
         return ChatGoogleGenerativeAI(
-            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
             model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
             temperature=kwargs.get("temperature", 0.0),
             api_key=SecretStr(api_key or ""),
         )
-    elif provider == "ollama":
     elif provider == "ollama":
         return ChatOllama(
-            model=kwargs.get("model_name", "qwen2.5:7b"),
             model=kwargs.get("model_name", "qwen2.5:7b"),
             temperature=kwargs.get("temperature", 0.0),
             num_ctx=128000,
@@ -111,7 +101,6 @@ def get_llm_model(provider: str, **kwargs):
         else:
             api_key = kwargs.get("api_key")
         return AzureChatOpenAI(
-            model=kwargs.get("model_name", "gpt-4o"),
             model=kwargs.get("model_name", "gpt-4o"),
             temperature=kwargs.get("temperature", 0.0),
             api_version="2024-05-01-preview",
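
With the duplicated merge lines removed, get_llm_model stays a single provider dispatch: each branch resolves base_url and api_key from kwargs, falling back to environment variables, then returns the matching chat model. A hedged usage sketch (the environment variables are assumed to be set the way the deepseek branch expects):

    # assumes DEEPSEEK_ENDPOINT (and the corresponding API key) are exported
    llm = get_llm_model(
        provider="deepseek",
        model_name="deepseek-chat",
        temperature=0.0,
    )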
webui.py
@@ -6,6 +6,7 @@
 # @FileName: webui.py

 import pdb
+import glob

 from dotenv import load_dotenv
 load_dotenv()
@@ -54,13 +55,6 @@ async def run_browser_agent(
         tool_call_in_content,
         browser_context=None  # Added optional argument
 ):
-    # Ensure the recording directory exists
-    os.makedirs(save_recording_path, exist_ok=True)
-
-    # Get the list of existing videos before the agent runs
-    existing_videos = set(glob.glob(os.path.join(save_recording_path, '*.[mM][pP]4')) +
-                          glob.glob(os.path.join(save_recording_path, '*.[wW][eE][bB][mM]')))
-
     # Run the agent
     llm = utils.get_llm_model(
         provider=llm_provider,
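
The glob block in this hunk captures the pre-run recording set so the newly produced video can be found by set difference after the agent finishes. The idea, as a standalone sketch:

    import glob
    import os

    def list_recordings(path):
        # case-insensitive match for .mp4 and .webm files
        return set(glob.glob(os.path.join(path, '*.[mM][pP]4')) +
                   glob.glob(os.path.join(path, '*.[wW][eE][bB][mM]')))

    before = list_recordings(save_recording_path)
    # ... run the agent ...
    new_videos = list_recordings(save_recording_path) - before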
@@ -162,7 +156,6 @@ async def run_org_agent(
         llm=llm,
         use_vision=use_vision,
         max_actions_per_step=max_actions_per_step,
         tool_call_in_content=tool_call_in_content,
         browser_context=browser_context,
     )
     history = await agent.run(max_steps=max_steps)
@@ -208,7 +201,7 @@ async def run_custom_agent(
         chrome_use_data = None

     browser_context_ = await playwright.chromium.launch_persistent_context(
-        user_data_dir=chrome_use_data,
+        user_data_dir=chrome_use_data if chrome_use_data else "",
         executable_path=chrome_exe,
         no_viewport=False,
         headless=headless,  # keep the browser window visible
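
The user_data_dir change matters because launch_persistent_context requires that argument: Playwright documents that an empty string makes it use a temporary profile directory, which is presumably why a None value is normalized away here. A sketch of the call under that assumption (chrome_path is hypothetical):

    browser_context = await playwright.chromium.launch_persistent_context(
        user_data_dir="",            # empty string -> temporary profile directory
        executable_path=chrome_path,
        no_viewport=False,
        headless=False,              # keep the browser window visible
    )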
@@ -234,7 +227,9 @@ async def run_custom_agent(
         llm=llm,
         browser_context=browser_context,
         controller=controller,
-        system_prompt_class=CustomSystemPrompt
+        system_prompt_class=CustomSystemPrompt,
+        max_actions_per_step=max_actions_per_step,
+        tool_call_in_content=tool_call_in_content
     )
     history = await agent.run(max_steps=max_steps)
     final_result = history.final_result()
@@ -268,7 +263,9 @@ async def run_custom_agent(
         llm=llm,
         browser_context=browser_context_in,
         controller=controller,
-        system_prompt_class=CustomSystemPrompt
+        system_prompt_class=CustomSystemPrompt,
+        max_actions_per_step=max_actions_per_step,
+        tool_call_in_content=tool_call_in_content
     )
     history = await agent.run(max_steps=max_steps)

@@ -317,6 +314,8 @@ async def run_with_stream(
         add_infos,
         max_steps,
         use_vision,
+        max_actions_per_step,
+        tool_call_in_content,
 ):
     """Wrapper to run the agent and handle streaming."""
     browser = None
@@ -360,6 +359,8 @@ async def run_with_stream(
                 add_infos,
                 max_steps,
                 use_vision,
+                max_actions_per_step,
+                tool_call_in_content,
                 browser_context=browser_context  # Explicit keyword argument
             )
         )
@@ -430,7 +431,7 @@ async def run_with_stream(
         if browser:
             await browser.close()

-from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft
+from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft, Base

 # Define the theme map globally
 theme_map = {
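
With Base imported alongside the other themes, the global theme_map that follows can key human-readable names to theme instances for create_ui to look up. A hedged sketch of that shape (entries illustrative, not the repo's actual map):

    from gradio.themes import Base, Default, Ocean, Soft

    theme_map = {
        "Default": Default(),
        "Soft": Soft(),
        "Ocean": Ocean(),
        "Base": Base(),
    }

    theme = theme_map.get("Ocean", Default())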
@@ -472,7 +473,6 @@ def create_ui(theme_name="Ocean"):
|
||||
### Control your browser with AI assistance
|
||||
""",
|
||||
elem_classes=["header-text"],
|
||||
elem_classes=["header-text"],
|
||||
)
|
||||
|
||||
with gr.Tabs() as tabs:
|
||||
@@ -636,7 +636,7 @@ def create_ui(theme_name="Ocean"):
|
||||
model_actions_output,
|
||||
model_thoughts_output,
|
||||
recording_file,
|
||||
trace_file, max_actions_per_step, tool_call_in_content
|
||||
trace_file
|
||||
],
|
||||
queue=True,
|
||||
)
|
||||
|
||||