Added paged reader and stripped extraneous tags from responses

This commit is contained in:
joeyp
2023-08-12 02:03:51 -04:00
parent c118b4b423
commit 587a09cd86
5 changed files with 78 additions and 59 deletions
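For orientation, a minimal sketch of the paging idea this commit adds, with illustrative names only (the committed implementation is WebScrapingCache.paged_read, which splits pages with a RecursiveCharacterTextSplitter): break the scraped text into fixed-size chunks and return just the requested chunk, so a single tool call cannot overflow the agent's context window.

import textwrap

def read_page(text: str, page: int, chunk_size: int = 2000) -> str:
    # Split the text into pieces of at most chunk_size characters (wrapping on
    # whitespace) and return only the requested, zero-indexed page.
    chunks = textwrap.wrap(text, chunk_size)
    if page >= len(chunks):
        return "Page not found"
    return f"{chunks[page]}\n\n = Page {page} of {len(chunks) - 1}"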

View File

@@ -1,17 +0,0 @@
import sys
from PIL import Image, ImageDraw, ImageFont
from langchain.tools import tool
@tool("meme_creator", return_direct=False)
def meme_creator(image_path: str, top_text: str, bottom_text: str) -> str:
"""This tool creates a meme with the given image and text."""
try:
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
font = ImageFont.truetype('arial.ttf', size=45)
draw.text((10, 10), top_text, fill='white', font=font)
draw.text((10, image.height - 60), bottom_text, fill='white', font=font)
image.save('meme.png')
return 'Meme created successfully.'
except:
return 'Error: ' + str(sys.exc_info())

View File

@@ -12,6 +12,7 @@ from langchain.utilities import GoogleSearchAPIWrapper
from tools.ToolRegistrationTool import tool_registration_tool
from tools.ToolQueryTool import tool_query_tool
from tools.WebScrapingCache import query_website, paged_read_website
util.load_secrets()
@@ -38,7 +39,9 @@ GoogleSearchTool = Tool(
tools = [GoogleSearchTool,
         tool_query_tool,
         tool_registration_tool] + file_tools
         tool_registration_tool,
         paged_read_website,
         ] + file_tools
# Initialize our agents with their respective roles and system prompts
tool_making_agent = DialogueAgentWithTools(name="ToolMaker",
@@ -50,8 +53,7 @@ tool_making_agent = DialogueAgentWithTools(name="ToolMaker",
                                           callbacks=[StreamingStdOutCallbackHandler()]),
                                           tools=tools)
tool_making_agent.receive("HumanUser", "Create a tool that can request html from a url, save it to a chroma store and "
                          "ask a question about it. tools/ToolRegistry.py has an example of using chroma")
tool_making_agent.receive("HumanUser", "Use the internet and create the funniest meme picture ever.")
tool_making_agent.send()

View File

@@ -14,4 +14,5 @@ sentence_transformers
google-api-python-client
requests~=2.31.0
beautifulsoup4~=4.12.2
markdownify~=0.11.6

View File

@@ -1,28 +1,29 @@
import json
from typing import Type
import re
import requests
from bs4 import BeautifulSoup
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool
from langchain.tools import tool
from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field
import markdownify
class ToolRegistry:
class WebScrapingCache:
    _instance = None
    _embeddings = None
    _vector_store = None
    _initialised = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ToolRegistry, cls).__new__(cls)
            cls._instance = super(WebScrapingCache, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if WebScrapingCache._initialised:
            return
        WebScrapingCache._initialised = True
        self._embeddings = None
        self._vector_store = None
        self._url_list = []

    def add_documents(self, docs):
        if self._embeddings is None:
@@ -33,14 +34,27 @@ class ToolRegistry:
        else:
            self._vector_store.add_documents(docs)

    def query_website(self, url: str, query: str):
        self.scrape_website(url)
    def query_website(self, url: str, query: str, keep_links: bool = False):
        self.scrape_website(url, keep_links=keep_links)
        filter_dict = dict()
        filter_dict["url"] = url
        results = self._vector_store.similarity_search_with_score(query, 5, filter=filter_dict)
        print(results[0])
        results = self._vector_store.max_marginal_relevance_search(query, 3, filter=filter_dict)
        return results

    def paged_read(self, url: str, page: int, keep_links: bool = False):
        docs = self.scrape_website(url, keep_links=keep_links, chunk_size=2000, chunk_overlap=0, cache=False)
        if docs is None:
            return "Error scraping website"
        if page >= len(docs):
            return "Page not found"
        return str(docs[page]) + "\n\n" + f" = Page {page} of {len(docs)-1}"

    def scrape_website(self, url: str, keep_links=False, chunk_size=1024, chunk_overlap=128, cache=True):
        link_suffix = "(Keep links)" if keep_links else ""
        if url + link_suffix in self._url_list and cache:
            print("Site in cache, skipping...")
            return
    def scrape_website(self, url: str):
        print("Scraping website...")
        # Make the request
@@ -49,38 +63,55 @@ class ToolRegistry:
        # Check the response status code
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, "html.parser")
            text = soup.get_text()
            links = [a['href'] for a in soup.find_all('a', href=True)]
            if keep_links:
                tags_to_strip = []
            else:
                tags_to_strip = ['a']
            text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1000, chunk_overlap=200)
            docs = text_splitter.create_documents([text], metadatas=[{"url": url, "type": "text"}])
            self.add_documents(docs)
            stripped_text = re.sub(r'<script.*?</script>', '', str(response.content))
            stripped_text = re.sub(r'<style.*?</style>', '', str(stripped_text))
            stripped_text = re.sub(r'<meta.*?</meta>', '', str(stripped_text))
            text = markdownify.markdownify(stripped_text, strip=tags_to_strip)
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                           chunk_overlap=chunk_overlap)
            docs = text_splitter.create_documents([text], metadatas=[{"url": url}])
            if cache:
                self.add_documents(docs)
                self._url_list.append(url + link_suffix)
            return docs
        else:
            print(f"HTTP request failed with status code {response.status_code}")
            return f"HTTP request failed with status code {response.status_code}"
class ScrapeWebsiteInput(BaseModel):
"""Inputs for scrape_website"""
objective: str = Field(
description="The objective & task that users give to the agent")
url: str = Field(description="The url of the website to be scraped")
@tool("query_website", return_direct=False)
def query_website(website_url: str, query: str, keep_links: bool = False) -> str:
"""useful when you need to get data from a website url, passing both url and the query to the function; DO NOT
make up any url, the url should only be from the search results. Links can be enabled or disabled as needed. """
return str(WebScrapingCache().query_website(website_url, query, keep_links=keep_links))
class ScrapeWebsiteTool(BaseTool):
name = "scrape_website"
description = "useful when you need to get data from a website url, passing both url and objective to the function; DO NOT make up any url, the url should only be from the search results"
args_schema: Type[BaseModel] = ScrapeWebsiteInput
def _run(self, objective: str, url: str):
return NotImplementedError("error here")
def _arun(self, url: str):
raise NotImplementedError("error here")
@tool("paged_read_website", return_direct=False)
def paged_read_website(website_url: str, page: int) -> str:
"""useful when you need to read data from a website without overflowing context, passing both url and the page number (zero indexed) to the function; DO NOT
make up any url, the url should only be from the search results. Links can be enabled or disabled as needed. """
return str(WebScrapingCache().paged_read( website_url, page))
if __name__ == "__main__":
    ToolRegistry().query_website('https://blazblue.fandom.com/wiki/Rachel_Alucard',
                                 'What does Rachel Alucard think about bell peppers?')
    query = "What does Rachel Alucard look like?"
    print(query)
    results = WebScrapingCache().query_website('https://blazblue.fandom.com/wiki/Rachel_Alucard', query)
    print(str(results))
    query = "Rachel Alucard and bell peppers?"
    print(query)
    results = WebScrapingCache().query_website('https://blazblue.fandom.com/wiki/Rachel_Alucard', query)
    print(str(results))
    doc = WebScrapingCache().paged_read('https://www.deeplearning.ai/resources/natural-language-processing/', 5)
    print(doc)
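The tag stripping named in the commit message happens in scrape_website above: script, style, and meta blocks are removed from the raw response with regular expressions, and markdownify's strip argument then drops anchor markup unless keep_links is set. A small illustration of that strip behaviour on made-up HTML (not part of the commit):

import markdownify

html = '<p>See the <a href="https://example.com">docs</a> for details.</p>'
print(markdownify.markdownify(html))               # anchor kept as a markdown link: [docs](https://example.com)
print(markdownify.markdownify(html, strip=['a']))  # anchor markup stripped, its text kept: docs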

View File

@@ -1,2 +1,4 @@
from tools.ToolQueryTool import tool_query_tool
from tools.ToolRegistrationTool import tool_registration_tool
from tools.WebScrapingCache import query_website
from tools.WebScrapingCache import paged_read_website