Resolve merge conflicts with the new version

katiue
2025-01-09 01:26:43 +07:00
parent a4c2deffa0
commit 7c66ac14f4
4 changed files with 28 additions and 35 deletions

View File

@@ -85,7 +85,6 @@ class CustomAgent(Agent):
             include_attributes=include_attributes,
             max_error_length=max_error_length,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content,
         )
         self.add_infos = add_infos
         self.message_manager = CustomMassageManager(
@@ -156,7 +155,7 @@ class CustomAgent(Agent):
             parsed: AgentOutput = response['parsed']
             # cut the number of actions to max_actions_per_step
             parsed.action = parsed.action[: self.max_actions_per_step]
-            self._log_response(parsed)
+            self._log_response(parsed)  # type: ignore
             self.n_steps += 1
             return parsed
@@ -165,7 +164,7 @@ class CustomAgent(Agent):
             # and manually parse the response. Temporary solution for DeepSeek
             ret = self.llm.invoke(input_messages)
             if isinstance(ret.content, list):
-                parsed_json = json.loads(ret.content[0].replace("```json", "").replace("```", ""))
+                parsed_json = json.loads(str(ret.content[0]).replace("```json", "").replace("```", ""))
             else:
                 parsed_json = json.loads(ret.content.replace("```json", "").replace("```", ""))
             parsed: AgentOutput = self.AgentOutput(**parsed_json)
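Note: the hunk above hardens the manual DeepSeek fallback, since `ret.content[0]` may not be a `str` and is coerced before the Markdown fences are stripped. A minimal standalone sketch of that fence-stripping step (the helper name is illustrative, not part of the codebase):

    import json

    def parse_fenced_json(raw) -> dict:
        # Models such as DeepSeek may wrap their JSON in ```json ... ``` fences;
        # strip them before parsing, mirroring the replace() calls above.
        cleaned = str(raw).replace("```json", "").replace("```", "").strip()
        return json.loads(cleaned)

    # parse_fenced_json('```json\n{"action": []}\n```') -> {'action': []}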
@@ -193,7 +192,7 @@ class CustomAgent(Agent):
         input_messages = self.message_manager.get_messages()
         model_output = await self.get_next_action(input_messages)
         if step_info is not None:
-            self.update_step_info(model_output, step_info)
+            self.update_step_info(model_output=CustomAgentOutput(**model_output.dict()), step_info=step_info)
         logger.info(f'🧠 All Memory: {step_info.memory}')
         self._save_conversation(input_messages, model_output)
         self.message_manager._remove_last_state_message()  # we don't want the whole state in the chat history
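The `update_step_info` call now re-wraps the generic model output as a `CustomAgentOutput` so the keyword signature type-checks. A toy sketch of that re-wrapping pattern, with hypothetical minimal pydantic models standing in for the real browser-use classes:

    from pydantic import BaseModel

    class AgentOutput(BaseModel):          # stand-in for the library's output model
        action: list = []

    class CustomAgentOutput(AgentOutput):  # stand-in for the custom subclass
        pass

    base = AgentOutput(action=[])
    custom = CustomAgentOutput(**base.dict())  # .dict() is the pydantic v1-style API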

View File

@@ -41,6 +41,7 @@ class CustomMassageManager(MessageManager):
         max_actions_per_step: int = 10,
         tool_call_in_content: bool = False,
     ):
+        self.tool_call_in_content = tool_call_in_content
         super().__init__(
             llm=llm,
             task=task,
@@ -52,13 +53,17 @@ class CustomMassageManager(MessageManager):
             include_attributes=include_attributes,
             max_error_length=max_error_length,
             max_actions_per_step=max_actions_per_step,
-            tool_call_in_content=tool_call_in_content,
         )
         # Custom: Move Task info to state_message
         self.history = MessageHistory()
         self._add_message_with_tokens(self.system_prompt)
-        tool_calls = [
+        tool_calls = self._create_tool_calls()
+        example_tool_call = self._create_example_tool_call(tool_calls)
+        self._add_message_with_tokens(example_tool_call)
+
+    def _create_tool_calls(self):
+        return [
             {
                 'name': 'AgentOutput',
                 'args': {
@@ -73,20 +78,20 @@ class CustomMassageManager(MessageManager):
                 'type': 'tool_call',
             }
         ]
+
+    def _create_example_tool_call(self, tool_calls):
         if self.tool_call_in_content:
             # openai throws error if tool_calls are not responded -> move to content
-            example_tool_call = AIMessage(
+            return AIMessage(
                 content=f'{tool_calls}',
                 tool_calls=[],
             )
         else:
-            example_tool_call = AIMessage(
+            return AIMessage(
                 content=f'',
                 tool_calls=tool_calls,
             )
-        self._add_message_with_tokens(example_tool_call)

     def add_state_message(
         self,
         state: BrowserState,
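These hunks stop forwarding `tool_call_in_content` to the upstream `MessageManager` (which no longer accepts it), keep the flag on the subclass, and split the inline example tool call into two helpers. A self-contained sketch of the branching the second helper performs, using langchain's `AIMessage` (the function name is illustrative):

    from langchain_core.messages import AIMessage

    def build_example_tool_call(tool_calls: list, tool_call_in_content: bool) -> AIMessage:
        if tool_call_in_content:
            # Some OpenAI-compatible endpoints reject unanswered tool_calls,
            # so serialize the example call into the message content instead.
            return AIMessage(content=f'{tool_calls}', tool_calls=[])
        return AIMessage(content='', tool_calls=tool_calls)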

View File

@@ -22,7 +22,6 @@ def get_llm_model(provider: str, **kwargs):
     :param kwargs:
     :return:
     """
-    if provider == "anthropic":
     if provider == "anthropic":
         if not kwargs.get("base_url", ""):
             base_url = "https://api.anthropic.com"
@@ -35,7 +34,6 @@ def get_llm_model(provider: str, **kwargs):
             api_key = kwargs.get("api_key")
         return ChatAnthropic(
-            model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
             model_name=kwargs.get("model_name", "claude-3-5-sonnet-20240620"),
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
@@ -43,7 +41,6 @@ def get_llm_model(provider: str, **kwargs):
             timeout=kwargs.get("timeout", 60),
             stop=kwargs.get("stop", None),
         )
-    elif provider == "openai":
     elif provider == "openai":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
@@ -56,13 +53,11 @@ def get_llm_model(provider: str, **kwargs):
             api_key = kwargs.get("api_key")
         return ChatOpenAI(
-            model=kwargs.get("model_name", "gpt-4o"),
             model=kwargs.get("model_name", "gpt-4o"),
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=SecretStr(api_key or ""),
         )
-    elif provider == "deepseek":
     elif provider == "deepseek":
         if not kwargs.get("base_url", ""):
             base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
@@ -75,28 +70,23 @@ def get_llm_model(provider: str, **kwargs):
             api_key = kwargs.get("api_key")
         return ChatOpenAI(
-            model=kwargs.get("model_name", "deepseek-chat"),
             model=kwargs.get("model_name", "deepseek-chat"),
             temperature=kwargs.get("temperature", 0.0),
             base_url=base_url,
             api_key=SecretStr(api_key or ""),
         )
-    elif provider == "gemini":
     elif provider == "gemini":
         if not kwargs.get("api_key", ""):
             api_key = os.getenv("GOOGLE_API_KEY", "")
         else:
             api_key = kwargs.get("api_key")
         return ChatGoogleGenerativeAI(
-            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
             model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
             temperature=kwargs.get("temperature", 0.0),
             api_key=SecretStr(api_key or ""),
         )
-    elif provider == "ollama":
     elif provider == "ollama":
         return ChatOllama(
-            model=kwargs.get("model_name", "qwen2.5:7b"),
             model=kwargs.get("model_name", "qwen2.5:7b"),
             temperature=kwargs.get("temperature", 0.0),
             num_ctx=128000,
@@ -111,7 +101,6 @@ def get_llm_model(provider: str, **kwargs):
         else:
             api_key = kwargs.get("api_key")
         return AzureChatOpenAI(
-            model=kwargs.get("model_name", "gpt-4o"),
             model=kwargs.get("model_name", "gpt-4o"),
             temperature=kwargs.get("temperature", 0.0),
             api_version="2024-05-01-preview",
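These hunks remove the duplicated lines left behind by the earlier merge, so every provider branch now builds exactly one client with a single `model`/`model_name` argument. A hypothetical call into the deepseek branch (DEEPSEEK_ENDPOINT comes from the hunk above; DEEPSEEK_API_KEY is an assumed variable name):

    import os

    llm = get_llm_model(
        provider="deepseek",
        model_name="deepseek-chat",   # matches the default in the hunk above
        temperature=0.0,
        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
        api_key=os.getenv("DEEPSEEK_API_KEY", ""),
    )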

View File

@@ -6,6 +6,7 @@
 # @FileName: webui.py
 import pdb
+import glob
 from dotenv import load_dotenv
 load_dotenv()
@@ -54,13 +55,6 @@ async def run_browser_agent(
         tool_call_in_content,
         browser_context=None  # Added optional argument
 ):
-    # Ensure the recording directory exists
-    os.makedirs(save_recording_path, exist_ok=True)
-
-    # Get the list of existing videos before the agent runs
-    existing_videos = set(glob.glob(os.path.join(save_recording_path, '*.[mM][pP]4')) +
-                          glob.glob(os.path.join(save_recording_path, '*.[wW][eE][bB][mM]')))
-
     # Run the agent
     llm = utils.get_llm_model(
         provider=llm_provider,
@@ -162,7 +156,6 @@ async def run_org_agent(
         llm=llm,
         use_vision=use_vision,
         max_actions_per_step=max_actions_per_step,
-        tool_call_in_content=tool_call_in_content,
         browser_context=browser_context,
     )
     history = await agent.run(max_steps=max_steps)
@@ -208,7 +201,7 @@ async def run_custom_agent(
         chrome_use_data = None
     browser_context_ = await playwright.chromium.launch_persistent_context(
-        user_data_dir=chrome_use_data,
+        user_data_dir=chrome_use_data if chrome_use_data else "",
         executable_path=chrome_exe,
         no_viewport=False,
         headless=headless,  # keep the browser window visible
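The `user_data_dir` change matters because Playwright's `launch_persistent_context` expects a string path and, per its documentation, treats an empty string as a request for a temporary profile. A minimal async sketch under that assumption:

    from playwright.async_api import async_playwright

    async def open_browser(chrome_exe: str, chrome_use_data: str | None):
        playwright = await async_playwright().start()
        # user_data_dir must be a str; "" yields a temporary profile,
        # which is what the `if chrome_use_data else ""` fix relies on.
        return await playwright.chromium.launch_persistent_context(
            user_data_dir=chrome_use_data or "",
            executable_path=chrome_exe,
            no_viewport=False,
            headless=False,
        )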
@@ -234,7 +227,9 @@ async def run_custom_agent(
         llm=llm,
         browser_context=browser_context,
         controller=controller,
-        system_prompt_class=CustomSystemPrompt
+        system_prompt_class=CustomSystemPrompt,
+        max_actions_per_step=max_actions_per_step,
+        tool_call_in_content=tool_call_in_content
     )
     history = await agent.run(max_steps=max_steps)
     final_result = history.final_result()
@@ -268,7 +263,9 @@ async def run_custom_agent(
         llm=llm,
         browser_context=browser_context_in,
         controller=controller,
-        system_prompt_class=CustomSystemPrompt
+        system_prompt_class=CustomSystemPrompt,
+        max_actions_per_step=max_actions_per_step,
+        tool_call_in_content=tool_call_in_content
     )
     history = await agent.run(max_steps=max_steps)
@@ -317,6 +314,8 @@ async def run_with_stream(
         add_infos,
         max_steps,
         use_vision,
+        max_actions_per_step,
+        tool_call_in_content,
 ):
     """Wrapper to run the agent and handle streaming."""
     browser = None
@@ -360,6 +359,8 @@ async def run_with_stream(
                 add_infos,
                 max_steps,
                 use_vision,
+                max_actions_per_step,
+                tool_call_in_content,
                 browser_context=browser_context  # Explicit keyword argument
             )
         )
@@ -430,7 +431,7 @@ async def run_with_stream(
         if browser:
             await browser.close()
-from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft
+from gradio.themes import Citrus, Default, Glass, Monochrome, Ocean, Origin, Soft, Base
 # Define the theme map globally
 theme_map = {
@@ -472,7 +473,6 @@ def create_ui(theme_name="Ocean"):
             ### Control your browser with AI assistance
             """,
-            elem_classes=["header-text"],
             elem_classes=["header-text"],
         )
         with gr.Tabs() as tabs:
@@ -636,7 +636,7 @@ def create_ui(theme_name="Ocean"):
                 model_actions_output,
                 model_thoughts_output,
                 recording_file,
-                trace_file, max_actions_per_step, tool_call_in_content
+                trace_file
             ],
             queue=True,
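The stale `max_actions_per_step` and `tool_call_in_content` entries are dropped because they are inputs to the handler, not values it returns: Gradio expects one output component per element of the returned tuple. A toy sketch of that wiring (component names are illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        run_button = gr.Button("Run Agent")
        final_result = gr.Textbox(label="Final Result")
        trace_file = gr.File(label="Trace File")
        run_button.click(
            fn=lambda: ("done", None),  # one return value per output component
            inputs=[],
            outputs=[final_result, trace_file],
            queue=True,
        )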
)