Merge pull request #249 from vvincent1234/fix/dr-bugs

Fix/dr bugs
warmshao committed 2025-02-08 18:31:22 +08:00 (via GitHub)
5 changed files with 130 additions and 70 deletions

View File

@@ -58,11 +58,6 @@ Activate the virtual environment:
```bash
source .venv/bin/activate
```
Alternative activation for Windows:
```bash
.\.venv\Scripts\Activate
```
#### Step 3: Install Dependencies
Install Python packages:

View File

@@ -66,6 +66,6 @@ class CustomController(Controller):
)
# go back to org url
await page.go_back()
msg = f'📄 Extracted page content as {output_format}\n: {content}\n'
msg = f'Extracted page content:\n {content}\n'
logger.info(msg)
return ActionResult(extracted_content=msg)
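This prefix is load-bearing: deep_research (later in this PR) splits each agent's output on the literal string "Extracted page content:" to chunk long results before recording them. A minimal sketch of that coupling, with the length cap taken from the TODO further down in the diff; the helper name is ours:
```python
# Hypothetical helper mirroring the split in deep_research below: it cuts
# agent output at the exact prefix that extract_content now logs.
MAX_CHARS = 128000 * 3  # ~128k tokens at roughly 3 chars per token

def split_extracted_content(query_result: str) -> list[str]:
    chunks = query_result.split("Extracted page content:")
    # Drop empty fragments and cap each chunk's length.
    return [chunk[:MAX_CHARS] for chunk in chunks if chunk]
```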

View File

@@ -1,4 +1,3 @@
import pdb
from dotenv import load_dotenv
@@ -21,17 +20,51 @@ from json_repair import repair_json
from src.agent.custom_prompts import CustomSystemPrompt, CustomAgentMessagePrompt
from src.controller.custom_controller import CustomController
from src.browser.custom_browser import CustomBrowser
from src.browser.custom_context import BrowserContextConfig
from browser_use.browser.context import (
BrowserContextConfig,
BrowserContextWindowSize,
)
logger = logging.getLogger(__name__)
async def deep_research(task, llm, agent_state, **kwargs):
async def deep_research(task, llm, agent_state=None, **kwargs):
task_id = str(uuid4())
save_dir = kwargs.get("save_dir", os.path.join(f"./tmp/deep_research/{task_id}"))
logger.info(f"Save Deep Research at: {save_dir}")
os.makedirs(save_dir, exist_ok=True)
# max query num per iteration
max_query_num = kwargs.get("max_query_num", 3)
use_own_browser = kwargs.get("use_own_browser", False)
extra_chromium_args = []
if use_own_browser:
# TODO: if using own browser, max query num must be 1 per iteration; how to solve this?
max_query_num = 1
chrome_path = os.getenv("CHROME_PATH", None)
if chrome_path == "":
chrome_path = None
chrome_user_data = os.getenv("CHROME_USER_DATA", None)
if chrome_user_data:
extra_chromium_args += [f"--user-data-dir={chrome_user_data}"]
browser = CustomBrowser(
config=BrowserConfig(
headless=kwargs.get("headless", False),
disable_security=kwargs.get("disable_security", True),
chrome_instance_path=chrome_path,
extra_chromium_args=extra_chromium_args,
)
)
browser_context = await browser.new_context()
else:
browser = None
browser_context = None
controller = CustomController()
search_system_prompt = f"""
You are a **Deep Researcher**, an AI agent specializing in in-depth information gathering and research using a web browser with **automated execution capabilities**. Your expertise lies in formulating comprehensive research plans and executing them meticulously to fulfill complex user requests. You will analyze user instructions, devise a detailed research plan, and determine the necessary search queries to gather the required information.
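For context on the new use_own_browser branch above: a personal Chrome profile can back only one automation session at a time, hence max_query_num is forced to 1 and a single CustomBrowser plus one shared context are created up front. The two environment variables it reads could be set like this (paths are illustrative assumptions, not values from the diff):
```python
import os

# Illustrative paths only; point these at your local Chrome install/profile.
os.environ["CHROME_PATH"] = "/usr/bin/google-chrome"
os.environ["CHROME_USER_DATA"] = "/home/me/.config/chrome-deep-research"

# With use_own_browser=True, the block above then builds
#   extra_chromium_args = ["--user-data-dir=/home/me/.config/chrome-deep-research"]
# and passes chrome_instance_path="/usr/bin/google-chrome" to BrowserConfig.
```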
@@ -111,26 +144,12 @@ Provide your output as a JSON formatted list. Each item in the list must adhere
1. **User Instruction:** The original instruction given by the user. This helps you determine what kind of information will be useful and how to structure your thinking.
2. **Previous Recorded Information:** Textual data gathered and recorded from previous searches and processing, represented as a single text string.
3. **Current Search Results:** Textual data gathered from the most recent search query.
3. **Current Search Plan:** Research plan for current search.
4. **Current Search Query:** The current search query.
5. **Current Search Results:** Textual data gathered from the most recent search query.
"""
record_messages = [SystemMessage(content=record_system_prompt)]
use_own_browser = kwargs.get("use_own_browser", False)
extra_chromium_args = []
if use_own_browser:
# if use own browser, max query num should be 1 per iter
max_query_num = 1
chrome_path = os.getenv("CHROME_PATH", None)
if chrome_path == "":
chrome_path = None
chrome_user_data = os.getenv("CHROME_USER_DATA", None)
if chrome_user_data:
extra_chromium_args += [f"--user-data-dir={chrome_user_data}"]
else:
chrome_path = None
browser = None
controller = CustomController()
search_iteration = 0
max_search_iterations = kwargs.get("max_search_iterations", 10) # Limit search iterations to prevent infinite loop
use_vision = kwargs.get("use_vision", False)
@@ -167,35 +186,42 @@ Provide your output as a JSON formatted list. Each item in the list must adhere
logger.info(query_tasks)
# 2. Perform Web Search and Auto exec
# Paralle BU agents
# Parallel BU agents
add_infos = "1. Please click on the most relevant link to get information and go deeper, instead of just staying on the search page. \n" \
"2. When opening a PDF file, please remember to extract the content using extract_content instead of simply opening it for the user to view."
"2. When opening a PDF file, please remember to extract the content using extract_content instead of simply opening it for the user to view.\n"
if use_own_browser:
browser = CustomBrowser(
config=BrowserConfig(
headless=kwargs.get("headless", False),
disable_security=kwargs.get("disable_security", True),
chrome_instance_path=chrome_path,
extra_chromium_args=extra_chromium_args,
)
agent = CustomAgent(
task=query_tasks[0],
llm=llm,
add_infos=add_infos,
browser=browser,
browser_context=browser_context,
use_vision=use_vision,
system_prompt_class=CustomSystemPrompt,
agent_prompt_class=CustomAgentMessagePrompt,
max_actions_per_step=5,
controller=controller,
agent_state=agent_state
)
agents = [CustomAgent(
task=task,
llm=llm,
add_infos=add_infos,
browser=browser,
use_vision=use_vision,
system_prompt_class=CustomSystemPrompt,
agent_prompt_class=CustomAgentMessagePrompt,
max_actions_per_step=5,
controller=controller,
agent_state=agent_state
) for task in query_tasks]
query_results = await asyncio.gather(*[agent.run(max_steps=kwargs.get("max_steps", 10)) for agent in agents])
if browser:
await browser.close()
browser = None
logger.info("Browser closed.")
agent_result = await agent.run(max_steps=kwargs.get("max_steps", 10))
query_results = [agent_result]
else:
agents = [CustomAgent(
task=task,
llm=llm,
add_infos=add_infos,
browser=browser,
browser_context=browser_context,
use_vision=use_vision,
system_prompt_class=CustomSystemPrompt,
agent_prompt_class=CustomAgentMessagePrompt,
max_actions_per_step=5,
controller=controller,
agent_state=agent_state
) for task in query_tasks]
query_results = await asyncio.gather(
*[agent.run(max_steps=kwargs.get("max_steps", 10)) for agent in agents])
if agent_state and agent_state.is_stop_requested():
# Stop
break
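The branch above is the heart of the fix: a shared personal browser gets exactly one CustomAgent run against the shared browser_context, while the default path fans out one agent per query with asyncio.gather. A stripped-down sketch of that pattern (make_agent stands in for the CustomAgent construction):
```python
import asyncio

async def run_queries(query_tasks, make_agent, max_steps=10, shared_browser=False):
    # Hypothetical condensation of the branch above.
    if shared_browser:
        # A shared browser context cannot serve concurrent agents,
        # so run a single agent on the first (and only) query.
        agent = make_agent(query_tasks[0])
        return [await agent.run(max_steps=max_steps)]
    # Independent agents each spawn their own browser and can run in parallel.
    agents = [make_agent(task) for task in query_tasks]
    return await asyncio.gather(*(agent.run(max_steps=max_steps) for agent in agents))
```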
@@ -211,19 +237,27 @@ Provide your output as a JSON formatted list. Each item in the list must adhere
with open(querr_save_path, "w", encoding="utf-8") as fw:
fw.write(f"Query: {query_tasks[i]}\n")
fw.write(query_result)
history_infos_ = json.dumps(history_infos, indent=4)
record_prompt = f"User Instruction:{task}. \nPrevious Recorded Information:\n {json.dumps(history_infos_)} \n Current Search Results: {query_result}\n "
record_messages.append(HumanMessage(content=record_prompt))
ai_record_msg = llm.invoke(record_messages[:1] + record_messages[-1:])
record_messages.append(ai_record_msg)
if hasattr(ai_record_msg, "reasoning_content"):
logger.info("🤯 Start Record Deep Thinking: ")
logger.info(ai_record_msg.reasoning_content)
logger.info("🤯 End Record Deep Thinking")
record_content = ai_record_msg.content
record_content = repair_json(record_content)
new_record_infos = json.loads(record_content)
history_infos.extend(new_record_infos)
# split query result in case the content is too long
query_results_split = query_result.split("Extracted page content:")
for qi, query_result_ in enumerate(query_results_split):
if not query_result_:
continue
else:
# TODO: limit content length: 128k tokens, ~3 chars per token
query_result_ = query_result_[:128000*3]
history_infos_ = json.dumps(history_infos, indent=4)
record_prompt = f"User Instruction:{task}. \nPrevious Recorded Information:\n {history_infos_}\n Current Search Iteration: {search_iteration}\n Current Search Plan:\n{query_plan}\n Current Search Query:\n {query_tasks[i]}\n Current Search Results: {query_result_}\n "
record_messages.append(HumanMessage(content=record_prompt))
ai_record_msg = llm.invoke(record_messages[:1] + record_messages[-1:])
record_messages.append(ai_record_msg)
if hasattr(ai_record_msg, "reasoning_content"):
logger.info("🤯 Start Record Deep Thinking: ")
logger.info(ai_record_msg.reasoning_content)
logger.info("🤯 End Record Deep Thinking")
record_content = ai_record_msg.content
record_content = repair_json(record_content)
new_record_infos = json.loads(record_content)
history_infos.extend(new_record_infos)
logger.info("\nFinish Searching, Start Generating Report...")
@@ -258,7 +292,7 @@ Provide your output as a JSON formatted list. Each item in the list must adhere
1. **User Instruction:** The original instruction given by the user. This helps you determine what kind of information will be useful and how to structure your thinking.
2. **Search Information:** Information gathered from the search queries.
"""
history_infos_ = json.dumps(history_infos, indent=4)
record_json_path = os.path.join(save_dir, "record_infos.json")
logger.info(f"save All recorded information at {record_json_path}")
@@ -288,5 +322,6 @@ Provide your output as a JSON formatted list. Each item in the list must adhere
finally:
if browser:
await browser.close()
browser = None
logger.info("Browser closed.")
if browser_context:
await browser_context.close()
logger.info("Browser closed.")

View File

@@ -357,5 +357,5 @@ async def test_browser_use_parallel():
if __name__ == "__main__":
# asyncio.run(test_browser_use_org())
asyncio.run(test_browser_use_parallel())
# asyncio.run(test_browser_use_custom())
# asyncio.run(test_browser_use_parallel())
asyncio.run(test_browser_use_custom())

View File

@@ -0,0 +1,30 @@
import asyncio
import os
from dotenv import load_dotenv
load_dotenv()
import sys
sys.path.append(".")
async def test_deep_research():
    from src.utils.deep_research import deep_research
    from src.utils import utils

    task = "write a report about DeepSeek-R1, get its pdf"
    llm = utils.get_llm_model(
        provider="gemini",
        model_name="gemini-2.0-flash-thinking-exp-01-21",
        temperature=1.0,
        api_key=os.getenv("GOOGLE_API_KEY", "")
    )
    report_content, report_file_path = await deep_research(
        task=task,
        llm=llm,
        agent_state=None,
        max_search_iterations=1,
        max_query_num=3,
        use_own_browser=False,
    )

if __name__ == "__main__":
    asyncio.run(test_deep_research())