mirror of
https://github.com/assafelovic/gpt-researcher.git
synced 2024-04-09 14:09:35 +03:00
update
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
# main
|
||||
import os, sys
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from permchain_example.researcher import Researcher
|
||||
from permchain_example.editor_actors.editor import EditorActor
|
||||
from permchain_example.reviser_actors.reviser import ReviserActor
|
||||
from permchain_example.search_actors.gpt_researcher import GPTResearcherActor
|
||||
from permchain_example.writer_actors.writer import WriterActor
|
||||
from permchain_example.research_team import ResearchTeam
|
||||
from examples.permchain_example.researcher import Researcher
|
||||
from examples.permchain_example.editor_actors.editor import EditorActor
|
||||
from examples.permchain_example.reviser_actors.reviser import ReviserActor
|
||||
from examples.permchain_example.search_actors.gpt_researcher import GPTResearcherActor
|
||||
from examples.permchain_example.writer_actors.writer import WriterActor
|
||||
from examples.permchain_example.research_team import ResearchTeam
|
||||
from gpt_researcher_old.processing.text import md_to_pdf
|
||||
|
||||
|
||||
22
examples/sample_report.py
Normal file
22
examples/sample_report.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from gpt_researcher import GPTResearcher
|
||||
import asyncio
|
||||
|
||||
|
||||
async def main():
    """
    Run one sample research task and return the generated report.

    A GPTResearcher is built for a fixed example query with the
    "research_report" report type and no config file, and its run()
    coroutine is awaited.

    Returns:
        The value produced by GPTResearcher.run()
    """
    researcher = GPTResearcher(
        query="What happened in the latest burning man floods?",
        report_type="research_report",
        config_path=None,
    )
    return await researcher.run()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
4
gpt_researcher/__init__.py
Normal file
4
gpt_researcher/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .master import GPTResearcher
|
||||
from .config import Config
|
||||
|
||||
__all__ = ['GPTResearcher', 'Config']
|
||||
3
gpt_researcher/config/__init__.py
Normal file
3
gpt_researcher/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .config import Config
|
||||
|
||||
__all__ = ['Config']
|
||||
@@ -1,8 +1,14 @@
|
||||
# config file
|
||||
import json
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self):
|
||||
self.retriver = "tavily"
|
||||
"""Config class for GPT Researcher."""
|
||||
|
||||
def __init__(self, config_file: str = None):
|
||||
"""Initialize the config class."""
|
||||
self.config_file = config_file
|
||||
self.retriever = "tavily"
|
||||
self.llm_provider = "ChatOpenAI"
|
||||
self.fast_llm_model = "gpt-3.5-turbo-16k"
|
||||
self.smart_llm_model = "gpt-4"
|
||||
@@ -12,8 +18,16 @@ class Config:
|
||||
self.summary_token_limit = 700
|
||||
self.temperature = 1.0
|
||||
self.user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) " \
|
||||
"Chrome/83.0.4103.97 Safari/537.36 "
|
||||
"Chrome/83.0.4103.97 Safari/537.36 "
|
||||
self.memory_backend = "local"
|
||||
self.load_config_file()
|
||||
|
||||
|
||||
def load_config_file(self) -> None:
    """
    Overlay settings from the JSON file at self.config_file onto this object.

    Does nothing when self.config_file is None. Otherwise the file is parsed
    as JSON and every top-level key/value pair is assigned as an attribute,
    so a key that matches an existing attribute replaces its value.

    Returns:
        None

    Raises:
        OSError: the file cannot be opened
        json.JSONDecodeError: the file is not valid JSON
    """
    if self.config_file is None:
        return
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default text encoding when reading it.
    with open(self.config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    for key, value in config.items():
        setattr(self, key, value)
|
||||
|
||||
|
||||
3
gpt_researcher/master/__init__.py
Normal file
3
gpt_researcher/master/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent import GPTResearcher
|
||||
|
||||
__all__ = ['GPTResearcher']
|
||||
@@ -1,26 +1,44 @@
|
||||
import time
|
||||
|
||||
from gpt_researcher.config import Config
|
||||
from gpt_researcher.master.functions import *
|
||||
|
||||
|
||||
class GPTResearcher:
|
||||
def __init__(self, query, report_type, websocket=None):
|
||||
"""
|
||||
GPT Researcher
|
||||
"""
|
||||
def __init__(self, query, report_type, config_path=None, websocket=None):
|
||||
"""
|
||||
Initialize the GPT Researcher class.
|
||||
Args:
|
||||
query:
|
||||
report_type:
|
||||
config_path:
|
||||
websocket:
|
||||
"""
|
||||
self.query = query
|
||||
self.agent = None
|
||||
self.role = None
|
||||
self.report_type = report_type
|
||||
self.websocket = websocket
|
||||
self.retriever = get_retriever()
|
||||
self.cfg = Config(config_path)
|
||||
self.retriever = get_retriever(self.cfg.retriever)
|
||||
self.context = []
|
||||
self.visited_urls = set()
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Runs the GPT Researcher
|
||||
Returns:
|
||||
Report
|
||||
"""
|
||||
print(f"🔎 Running research for '{self.query}'...")
|
||||
# Generate Agent
|
||||
self.agent, self.role = await choose_agent(self.query)
|
||||
self.agent, self.role = await choose_agent(self.query, self.cfg)
|
||||
await self.stream_output("logs", self.agent)
|
||||
|
||||
# Generate Sub-Queries
|
||||
sub_queries = await get_sub_queries(self.query, self.role)
|
||||
sub_queries = await get_sub_queries(self.query, self.role, self.cfg)
|
||||
await self.stream_output("logs",
|
||||
f"🧠 I will conduct my research based on the following queries: {sub_queries}...")
|
||||
|
||||
@@ -35,11 +53,19 @@ class GPTResearcher:
|
||||
await self.stream_output("logs", f"✍️ Writing {self.report_type} for research task: {self.query}...")
|
||||
report = await generate_report(query=self.query, context=self.context,
|
||||
agent_role_prompt=self.role, report_type=self.report_type,
|
||||
websocket=self.websocket)
|
||||
websocket=self.websocket, cfg=self.cfg)
|
||||
time.sleep(1)
|
||||
return report
|
||||
|
||||
async def run_sub_query(self, sub_query):
|
||||
"""
|
||||
Runs a sub-query
|
||||
Args:
|
||||
sub_query:
|
||||
|
||||
Returns:
|
||||
Summary
|
||||
"""
|
||||
# Get Urls
|
||||
retriever = self.retriever(sub_query)
|
||||
urls = retriever.search()
|
||||
@@ -54,12 +80,21 @@ class GPTResearcher:
|
||||
await self.stream_output("logs", f"📝 Summarizing sources...")
|
||||
raw_data = scrape_urls(urls_to_scrape)
|
||||
# Summarize Raw Data
|
||||
summary = await summarize(query=sub_query, text=raw_data, agent_role_prompt=self.role)
|
||||
summary = await summarize(query=sub_query, text=raw_data, agent_role_prompt=self.role, cfg=self.cfg)
|
||||
|
||||
# Run Tasks
|
||||
return summary
|
||||
|
||||
async def stream_output(self, type, output):
    """
    Emit a message to the attached websocket, or to stdout when there is none.

    Args:
        type: sent as the "type" field of the JSON payload
        output: sent as the "output" field; printed instead when no
            websocket is attached

    Returns:
        None
    """
    if self.websocket:
        await self.websocket.send_json({"type": type, "output": output})
    else:
        # No websocket attached: fall back to the console.
        print(output)
|
||||
|
||||
@@ -1,25 +1,41 @@
|
||||
from gpt_researcher.utils.llm import *
|
||||
from gpt_researcher.config.config import Config
|
||||
from gpt_researcher.scraper.scraper import Scraper
|
||||
from gpt_researcher.scraper import Scraper
|
||||
from gpt_researcher.master.prompts import *
|
||||
import json
|
||||
|
||||
cfg = Config()
|
||||
|
||||
def get_retriever(retriever):
|
||||
"""
|
||||
Gets the retriever
|
||||
Args:
|
||||
retriever: retriever name
|
||||
|
||||
def get_retriever():
|
||||
if cfg.retriver == "duckduckgo":
|
||||
from gpt_researcher.retrievers.duckduckgo.duckduckgo import Duckduckgo
|
||||
Returns:
|
||||
retriever: Retriever class
|
||||
|
||||
"""
|
||||
if retriever == "duckduckgo":
|
||||
from gpt_researcher.retrievers import Duckduckgo
|
||||
retriever = Duckduckgo
|
||||
elif cfg.retriver == "tavily":
|
||||
from gpt_researcher.retrievers.tavily_search.tavily_search import TavilySearch
|
||||
elif retriever == "tavily":
|
||||
from gpt_researcher.retrievers import TavilySearch
|
||||
retriever = TavilySearch
|
||||
else:
|
||||
raise Exception("Retriever not found.")
|
||||
return retriever
|
||||
|
||||
|
||||
async def choose_agent(query):
|
||||
async def choose_agent(query, cfg):
|
||||
"""
|
||||
Chooses the agent automatically
|
||||
Args:
|
||||
query: original query
|
||||
cfg: Config
|
||||
|
||||
Returns:
|
||||
agent: Agent name
|
||||
agent_role_prompt: Agent role prompt
|
||||
"""
|
||||
try:
|
||||
response = await create_chat_completion(
|
||||
model=cfg.smart_llm_model,
|
||||
@@ -35,7 +51,18 @@ async def choose_agent(query):
|
||||
return "Default Agent", "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text."
|
||||
|
||||
|
||||
async def get_sub_queries(query, agent_role_prompt):
|
||||
async def get_sub_queries(query, agent_role_prompt, cfg):
|
||||
"""
|
||||
Gets the sub queries
|
||||
Args:
|
||||
query: original query
|
||||
agent_role_prompt: agent role prompt
|
||||
cfg: Config
|
||||
|
||||
Returns:
|
||||
sub_queries: List of sub queries
|
||||
|
||||
"""
|
||||
response = await create_chat_completion(
|
||||
model=cfg.smart_llm_model,
|
||||
messages=[
|
||||
@@ -49,6 +76,15 @@ async def get_sub_queries(query, agent_role_prompt):
|
||||
|
||||
|
||||
def scrape_urls(urls):
|
||||
"""
|
||||
Scrapes the urls
|
||||
Args:
|
||||
urls: List of urls
|
||||
|
||||
Returns:
|
||||
text: str
|
||||
|
||||
"""
|
||||
text = ""
|
||||
try:
|
||||
text = Scraper(urls).run()
|
||||
@@ -57,7 +93,19 @@ def scrape_urls(urls):
|
||||
return text
|
||||
|
||||
|
||||
async def summarize(query, text, agent_role_prompt):
|
||||
async def summarize(query, text, agent_role_prompt, cfg):
|
||||
"""
|
||||
Summarizes the text
|
||||
Args:
|
||||
query:
|
||||
text:
|
||||
agent_role_prompt:
|
||||
cfg:
|
||||
|
||||
Returns:
|
||||
summary:
|
||||
|
||||
"""
|
||||
summary = ""
|
||||
try:
|
||||
summary = await create_chat_completion(
|
||||
@@ -73,11 +121,24 @@ async def summarize(query, text, agent_role_prompt):
|
||||
return summary
|
||||
|
||||
|
||||
async def generate_report(query, context, agent_role_prompt, report_type, websocket):
|
||||
async def generate_report(query, context, agent_role_prompt, report_type, websocket, cfg):
|
||||
"""
|
||||
generates the final report
|
||||
Args:
|
||||
query:
|
||||
context:
|
||||
agent_role_prompt:
|
||||
report_type:
|
||||
websocket:
|
||||
cfg:
|
||||
|
||||
Returns:
|
||||
report:
|
||||
|
||||
"""
|
||||
generate_prompt = get_report_by_type(report_type)
|
||||
report = ""
|
||||
try:
|
||||
print("Generating report...")
|
||||
report = await create_chat_completion(
|
||||
model=cfg.smart_llm_model,
|
||||
messages=[
|
||||
@@ -89,7 +150,6 @@ async def generate_report(query, context, agent_role_prompt, report_type, websoc
|
||||
websocket=websocket,
|
||||
max_tokens=cfg.smart_token_limit
|
||||
)
|
||||
print("Report generated.")
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}Error in generate_report: {e}{Style.RESET_ALL}")
|
||||
|
||||
|
||||
4
gpt_researcher/retrievers/__init__.py
Normal file
4
gpt_researcher/retrievers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .tavily_search.tavily_search import TavilySearch
|
||||
from .duckduckgo.duckduckgo import Duckduckgo
|
||||
|
||||
__all__ = ["TavilySearch", "Duckduckgo"]
|
||||
@@ -1,9 +1,11 @@
|
||||
from itertools import islice
|
||||
from duckduckgo_search import DDGS
|
||||
from gpt_researcher.scraper.scraper import Scraper
|
||||
|
||||
|
||||
class Duckduckgo:
|
||||
"""
|
||||
Duckduckgo API Retriever
|
||||
"""
|
||||
def __init__(self):
|
||||
self.ddg = DDGS()
|
||||
|
||||
|
||||
@@ -6,12 +6,25 @@ from tavily import TavilyClient
|
||||
|
||||
|
||||
class TavilySearch():
|
||||
"""
|
||||
Tavily API Retriever
|
||||
"""
|
||||
def __init__(self, query):
    """
    Store the query and build the Tavily client.

    Args:
        query: the query that search() passes to the client
    """
    self.query = query
    # The key is resolved first because the client is constructed from it.
    key = self.get_api_key()
    self.api_key = key
    self.client = TavilyClient(key)
|
||||
|
||||
def get_api_key(self):
|
||||
"""
|
||||
Gets the Tavily API key
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# Get the API key
|
||||
try:
|
||||
api_key = os.environ["TAVILY_API_KEY"]
|
||||
@@ -20,6 +33,11 @@ class TavilySearch():
|
||||
return api_key
|
||||
|
||||
def search(self):
|
||||
"""
|
||||
Searches the query
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# Search the query
|
||||
results = self.client.search(self.query, search_depth="basic", max_results=5)
|
||||
# Return the results
|
||||
|
||||
3
gpt_researcher/scraper/__init__.py
Normal file
3
gpt_researcher/scraper/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .scraper import Scraper
|
||||
|
||||
__all__ = ["Scraper"]
|
||||
@@ -5,7 +5,15 @@ from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
class Scraper:
|
||||
"""
|
||||
Scraper class to extract the content from the links
|
||||
"""
|
||||
def __init__(self, urls):
|
||||
"""
|
||||
Initialize the Scraper class.
|
||||
Args:
|
||||
urls:
|
||||
"""
|
||||
self.urls = urls
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
|
||||
@@ -68,7 +68,7 @@ async def send_chat_completion_request(
|
||||
return await stream_response(model, messages, temperature, max_tokens, llm_provider, websocket)
|
||||
|
||||
|
||||
async def stream_response(model, messages, temperature, max_tokens, llm_provider, websocket):
|
||||
async def stream_response(model, messages, temperature, max_tokens, llm_provider, websocket=None):
|
||||
paragraph = ""
|
||||
response = ""
|
||||
print(f"streaming response...")
|
||||
@@ -86,7 +86,10 @@ async def stream_response(model, messages, temperature, max_tokens, llm_provider
|
||||
response += content
|
||||
paragraph += content
|
||||
if "\n" in paragraph:
|
||||
await websocket.send_json({"type": "report", "output": paragraph})
|
||||
if websocket is not None:
|
||||
await websocket.send_json({"type": "report", "output": paragraph})
|
||||
else:
|
||||
print(paragraph)
|
||||
paragraph = ""
|
||||
print(f"streaming response complete")
|
||||
return response
|
||||
|
||||
@@ -7,12 +7,15 @@ from gpt_researcher.master.agent import GPTResearcher
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manage websockets"""
|
||||
def __init__(self):
|
||||
"""Initialize the WebSocketManager class."""
|
||||
self.active_connections: List[WebSocket] = []
|
||||
self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
|
||||
self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
|
||||
|
||||
async def start_sender(self, websocket: WebSocket):
|
||||
"""Start the sender task."""
|
||||
queue = self.message_queues[websocket]
|
||||
while True:
|
||||
message = await queue.get()
|
||||
@@ -22,23 +25,27 @@ class WebSocketManager:
|
||||
break
|
||||
|
||||
async def connect(self, websocket: WebSocket):
    """
    Accept a websocket and set up its outgoing-message machinery.

    Once the handshake is accepted, the socket is recorded as active, given
    its own message queue, and a task running start_sender is scheduled
    for it.
    """
    await websocket.accept()
    self.active_connections.append(websocket)
    # Register the queue before scheduling the sender: start_sender looks
    # the queue up in message_queues.
    outbox = asyncio.Queue()
    self.message_queues[websocket] = outbox
    sender = asyncio.create_task(self.start_sender(websocket))
    self.sender_tasks[websocket] = sender
|
||||
|
||||
async def disconnect(self, websocket: WebSocket):
    """
    Forget a websocket and stop its sender task.

    Safe to call for a socket that is unknown or already disconnected: each
    piece of bookkeeping is removed only if present, so a missing entry in
    one collection never prevents the sender task from being cancelled or
    the queue from being dropped.

    Args:
        websocket: the connection to remove

    Returns:
        None
    """
    if websocket in self.active_connections:
        self.active_connections.remove(websocket)
    sender = self.sender_tasks.pop(websocket, None)
    if sender is not None:
        sender.cancel()
    self.message_queues.pop(websocket, None)
|
||||
|
||||
async def start_streaming(self, task, report_type, websocket):
    """
    Delegate to run_agent and return whatever it produces.

    Args:
        task: forwarded to run_agent unchanged
        report_type: forwarded to run_agent unchanged
        websocket: forwarded to run_agent unchanged

    Returns:
        The awaited result of run_agent
    """
    return await run_agent(task, report_type, websocket)
|
||||
|
||||
|
||||
async def run_agent(task, report_type, websocket):
|
||||
"""Run the agent."""
|
||||
# measure time
|
||||
start_time = datetime.datetime.now()
|
||||
# run agent
|
||||
|
||||
@@ -87,7 +87,7 @@ class GPTResearcher:
|
||||
|
||||
# Example
|
||||
async def main():
|
||||
researcher = GPTResearcher.from_json("config.json")
|
||||
researcher = GPTResearcher.from_json("../config.json")
|
||||
|
||||
report, path = await researcher.conduct_research("rank the strongest characters in jujutsu kaisen",
|
||||
"research_report")
|
||||
|
||||
Reference in New Issue
Block a user