This commit is contained in:
Rotem B. Weiss
2023-11-12 16:30:28 +02:00
parent 6b5b9b6b34
commit 4acd5116e0
25 changed files with 222 additions and 36 deletions

View File

@@ -1,13 +1,13 @@
# main
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from permchain_example.researcher import Researcher
from permchain_example.editor_actors.editor import EditorActor
from permchain_example.reviser_actors.reviser import ReviserActor
from permchain_example.search_actors.gpt_researcher import GPTResearcherActor
from permchain_example.writer_actors.writer import WriterActor
from permchain_example.research_team import ResearchTeam
from examples.permchain_example.researcher import Researcher
from examples.permchain_example.editor_actors.editor import EditorActor
from examples.permchain_example.reviser_actors.reviser import ReviserActor
from examples.permchain_example.search_actors.gpt_researcher import GPTResearcherActor
from examples.permchain_example.writer_actors.writer import WriterActor
from examples.permchain_example.research_team import ResearchTeam
from gpt_researcher_old.processing.text import md_to_pdf

22
examples/sample_report.py Normal file
View File

@@ -0,0 +1,22 @@
from gpt_researcher import GPTResearcher
import asyncio
async def main():
    """
    Sample script showing how to run a research report end to end.

    Returns:
        str: The generated research report.
    """
    # Query to research
    query = "What happened in the latest burning man floods?"
    # Report Type
    report_type = "research_report"
    # Run Research (config_path=None -> GPTResearcher falls back to default Config values)
    researcher = GPTResearcher(query=query, report_type=report_type, config_path=None)
    report = await researcher.run()
    return report


if __name__ == "__main__":
    # Capture and print the report: previously the script computed the
    # report but never displayed it, so running the sample showed nothing.
    report = asyncio.run(main())
    print(report)

View File

@@ -0,0 +1,4 @@
from .master import GPTResearcher
from .config import Config
__all__ = ['GPTResearcher', 'Config']

View File

@@ -0,0 +1,3 @@
from .config import Config
__all__ = ['Config']

View File

@@ -1,8 +1,14 @@
# config file
import json
class Config:
def __init__(self):
self.retriver = "tavily"
"""Config class for GPT Researcher."""
def __init__(self, config_file: str = None):
"""Initialize the config class."""
self.config_file = config_file
self.retriever = "tavily"
self.llm_provider = "ChatOpenAI"
self.fast_llm_model = "gpt-3.5-turbo-16k"
self.smart_llm_model = "gpt-4"
@@ -12,8 +18,16 @@ class Config:
self.summary_token_limit = 700
self.temperature = 1.0
self.user_agent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) " \
"Chrome/83.0.4103.97 Safari/537.36 "
"Chrome/83.0.4103.97 Safari/537.36 "
self.memory_backend = "local"
self.load_config_file()
def load_config_file(self) -> None:
"""Load the config file."""
if self.config_file is None:
return None
with open(self.config_file, "r") as f:
config = json.load(f)
for key, value in config.items():
self.__dict__[key] = value

View File

@@ -0,0 +1,3 @@
from .agent import GPTResearcher
__all__ = ['GPTResearcher']

View File

@@ -1,26 +1,44 @@
import time
from gpt_researcher.config import Config
from gpt_researcher.master.functions import *
class GPTResearcher:
def __init__(self, query, report_type, websocket=None):
"""
GPT Researcher
"""
def __init__(self, query, report_type, config_path=None, websocket=None):
"""
Initialize the GPT Researcher class.
Args:
query:
report_type:
config_path:
websocket:
"""
self.query = query
self.agent = None
self.role = None
self.report_type = report_type
self.websocket = websocket
self.retriever = get_retriever()
self.cfg = Config(config_path)
self.retriever = get_retriever(self.cfg.retriever)
self.context = []
self.visited_urls = set()
async def run(self):
"""
Runs the GPT Researcher
Returns:
Report
"""
print(f"🔎 Running research for '{self.query}'...")
# Generate Agent
self.agent, self.role = await choose_agent(self.query)
self.agent, self.role = await choose_agent(self.query, self.cfg)
await self.stream_output("logs", self.agent)
# Generate Sub-Queries
sub_queries = await get_sub_queries(self.query, self.role)
sub_queries = await get_sub_queries(self.query, self.role, self.cfg)
await self.stream_output("logs",
f"🧠 I will conduct my research based on the following queries: {sub_queries}...")
@@ -35,11 +53,19 @@ class GPTResearcher:
await self.stream_output("logs", f"✍️ Writing {self.report_type} for research task: {self.query}...")
report = await generate_report(query=self.query, context=self.context,
agent_role_prompt=self.role, report_type=self.report_type,
websocket=self.websocket)
websocket=self.websocket, cfg=self.cfg)
time.sleep(1)
return report
async def run_sub_query(self, sub_query):
"""
Runs a sub-query
Args:
sub_query:
Returns:
Summary
"""
# Get Urls
retriever = self.retriever(sub_query)
urls = retriever.search()
@@ -54,12 +80,21 @@ class GPTResearcher:
await self.stream_output("logs", f"📝 Summarizing sources...")
raw_data = scrape_urls(urls_to_scrape)
# Summarize Raw Data
summary = await summarize(query=sub_query, text=raw_data, agent_role_prompt=self.role)
summary = await summarize(query=sub_query, text=raw_data, agent_role_prompt=self.role, cfg=self.cfg)
# Run Tasks
return summary
async def stream_output(self, type, output):
"""
Streams output to the websocket
Args:
type:
output:
Returns:
None
"""
if not self.websocket:
return print(output)
await self.websocket.send_json({"type": type, "output": output})

View File

@@ -1,25 +1,41 @@
from gpt_researcher.utils.llm import *
from gpt_researcher.config.config import Config
from gpt_researcher.scraper.scraper import Scraper
from gpt_researcher.scraper import Scraper
from gpt_researcher.master.prompts import *
import json
cfg = Config()
def get_retriever(retriever):
"""
Gets the retriever
Args:
retriever: retriever name
def get_retriever():
if cfg.retriver == "duckduckgo":
from gpt_researcher.retrievers.duckduckgo.duckduckgo import Duckduckgo
Returns:
retriever: Retriever class
"""
if retriever == "duckduckgo":
from gpt_researcher.retrievers import Duckduckgo
retriever = Duckduckgo
elif cfg.retriver == "tavily":
from gpt_researcher.retrievers.tavily_search.tavily_search import TavilySearch
elif retriever == "tavily":
from gpt_researcher.retrievers import TavilySearch
retriever = TavilySearch
else:
raise Exception("Retriever not found.")
return retriever
async def choose_agent(query):
async def choose_agent(query, cfg):
"""
Chooses the agent automatically
Args:
query: original query
cfg: Config
Returns:
agent: Agent name
agent_role_prompt: Agent role prompt
"""
try:
response = await create_chat_completion(
model=cfg.smart_llm_model,
@@ -35,7 +51,18 @@ async def choose_agent(query):
return "Default Agent", "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text."
async def get_sub_queries(query, agent_role_prompt):
async def get_sub_queries(query, agent_role_prompt, cfg):
"""
Gets the sub queries
Args:
query: original query
agent_role_prompt: agent role prompt
cfg: Config
Returns:
sub_queries: List of sub queries
"""
response = await create_chat_completion(
model=cfg.smart_llm_model,
messages=[
@@ -49,6 +76,15 @@ async def get_sub_queries(query, agent_role_prompt):
def scrape_urls(urls):
"""
Scrapes the urls
Args:
urls: List of urls
Returns:
text: str
"""
text = ""
try:
text = Scraper(urls).run()
@@ -57,7 +93,19 @@ def scrape_urls(urls):
return text
async def summarize(query, text, agent_role_prompt):
async def summarize(query, text, agent_role_prompt, cfg):
"""
Summarizes the text
Args:
query:
text:
agent_role_prompt:
cfg:
Returns:
summary:
"""
summary = ""
try:
summary = await create_chat_completion(
@@ -73,11 +121,24 @@ async def summarize(query, text, agent_role_prompt):
return summary
async def generate_report(query, context, agent_role_prompt, report_type, websocket):
async def generate_report(query, context, agent_role_prompt, report_type, websocket, cfg):
"""
generates the final report
Args:
query:
context:
agent_role_prompt:
report_type:
websocket:
cfg:
Returns:
report:
"""
generate_prompt = get_report_by_type(report_type)
report = ""
try:
print("Generating report...")
report = await create_chat_completion(
model=cfg.smart_llm_model,
messages=[
@@ -89,7 +150,6 @@ async def generate_report(query, context, agent_role_prompt, report_type, websoc
websocket=websocket,
max_tokens=cfg.smart_token_limit
)
print("Report generated.")
except Exception as e:
print(f"{Fore.RED}Error in generate_report: {e}{Style.RESET_ALL}")

View File

@@ -0,0 +1,4 @@
from .tavily_search.tavily_search import TavilySearch
from .duckduckgo.duckduckgo import Duckduckgo
__all__ = ["TavilySearch", "Duckduckgo"]

View File

@@ -1,9 +1,11 @@
from itertools import islice
from duckduckgo_search import DDGS
from gpt_researcher.scraper.scraper import Scraper
class Duckduckgo:
"""
Duckduckgo API Retriever
"""
def __init__(self):
self.ddg = DDGS()

View File

@@ -6,12 +6,25 @@ from tavily import TavilyClient
class TavilySearch():
"""
Tavily API Retriever
"""
def __init__(self, query):
"""
Initializes the TavilySearch object
Args:
query:
"""
self.query = query
self.api_key = self.get_api_key()
self.client = TavilyClient(self.api_key)
def get_api_key(self):
"""
Gets the Tavily API key
Returns:
"""
# Get the API key
try:
api_key = os.environ["TAVILY_API_KEY"]
@@ -20,6 +33,11 @@ class TavilySearch():
return api_key
def search(self):
"""
Searches the query
Returns:
"""
# Search the query
results = self.client.search(self.query, search_depth="basic", max_results=5)
# Return the results

View File

@@ -0,0 +1,3 @@
from .scraper import Scraper
__all__ = ["Scraper"]

View File

@@ -5,7 +5,15 @@ from bs4 import BeautifulSoup
class Scraper:
"""
Scraper class to extract the content from the links
"""
def __init__(self, urls):
"""
Initialize the Scraper class.
Args:
urls:
"""
self.urls = urls
self.session = requests.Session()
self.session.headers.update({

View File

@@ -68,7 +68,7 @@ async def send_chat_completion_request(
return await stream_response(model, messages, temperature, max_tokens, llm_provider, websocket)
async def stream_response(model, messages, temperature, max_tokens, llm_provider, websocket):
async def stream_response(model, messages, temperature, max_tokens, llm_provider, websocket=None):
paragraph = ""
response = ""
print(f"streaming response...")
@@ -86,7 +86,10 @@ async def stream_response(model, messages, temperature, max_tokens, llm_provider
response += content
paragraph += content
if "\n" in paragraph:
await websocket.send_json({"type": "report", "output": paragraph})
if websocket is not None:
await websocket.send_json({"type": "report", "output": paragraph})
else:
print(paragraph)
paragraph = ""
print(f"streaming response complete")
return response

View File

@@ -7,12 +7,15 @@ from gpt_researcher.master.agent import GPTResearcher
class WebSocketManager:
"""Manage websockets"""
def __init__(self):
"""Initialize the WebSocketManager class."""
self.active_connections: List[WebSocket] = []
self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
async def start_sender(self, websocket: WebSocket):
"""Start the sender task."""
queue = self.message_queues[websocket]
while True:
message = await queue.get()
@@ -22,23 +25,27 @@ class WebSocketManager:
break
async def connect(self, websocket: WebSocket):
"""Connect a websocket."""
await websocket.accept()
self.active_connections.append(websocket)
self.message_queues[websocket] = asyncio.Queue()
self.sender_tasks[websocket] = asyncio.create_task(self.start_sender(websocket))
async def disconnect(self, websocket: WebSocket):
"""Disconnect a websocket."""
self.active_connections.remove(websocket)
self.sender_tasks[websocket].cancel()
del self.sender_tasks[websocket]
del self.message_queues[websocket]
async def start_streaming(self, task, report_type, websocket):
"""Start streaming the output."""
report = await run_agent(task, report_type, websocket)
return report
async def run_agent(task, report_type, websocket):
"""Run the agent."""
# measure time
start_time = datetime.datetime.now()
# run agent

View File

@@ -87,7 +87,7 @@ class GPTResearcher:
# Example
async def main():
researcher = GPTResearcher.from_json("config.json")
researcher = GPTResearcher.from_json("../config.json")
report, path = await researcher.conduct_research("rank the strongest characters in jujutsu kaisen",
"research_report")