Mirror of https://github.com/jbpayton/llm-auto-forge.git (synced 2024-06-08 15:46:36 +03:00)
Added paged reader and stripped extraneous tags from responses
@@ -1,17 +0,0 @@
import sys
from PIL import Image, ImageDraw, ImageFont
from langchain.tools import tool

@tool("meme_creator", return_direct=False)
def meme_creator(image_path: str, top_text: str, bottom_text: str) -> str:
    """This tool creates a meme with the given image and text."""
    try:
        image = Image.open(image_path)
        draw = ImageDraw.Draw(image)
        font = ImageFont.truetype('arial.ttf', size=45)
        draw.text((10, 10), top_text, fill='white', font=font)
        draw.text((10, image.height - 60), bottom_text, fill='white', font=font)
        image.save('meme.png')
        return 'Meme created successfully.'
    except:
        return 'Error: ' + str(sys.exc_info())
@@ -12,6 +12,7 @@ from langchain.utilities import GoogleSearchAPIWrapper
from tools.ToolRegistrationTool import tool_registration_tool
from tools.ToolQueryTool import tool_query_tool
from tools.WebScrapingCache import query_website, paged_read_website

util.load_secrets()
@@ -38,7 +39,9 @@ GoogleSearchTool = Tool(
tools = [GoogleSearchTool,
         tool_query_tool,
         tool_registration_tool] + file_tools
         tool_registration_tool,
         paged_read_website,
         ] + file_tools

# Initialize our agents with their respective roles and system prompts
tool_making_agent = DialogueAgentWithTools(name="ToolMaker",
@@ -50,8 +53,7 @@ tool_making_agent = DialogueAgentWithTools(name="ToolMaker",
                                           callbacks=[StreamingStdOutCallbackHandler()]),
                                           tools=tools)

tool_making_agent.receive("HumanUser", "Create a tool that can request html from a url, save it to a chroma store and "
                                       "ask a question about it. tools/ToolRegistry.py has an example of using chroma")
tool_making_agent.receive("HumanUser", "Use the internet and create the funniest meme picture ever.")

tool_making_agent.send()
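A quick way to sanity-check the newly wired-in paged reader outside the agent loop is to invoke the tool directly. This is only a sketch: the dict-style .run() call assumes LangChain's structured-tool interface for multi-argument tools, and the URL is purely illustrative.

from tools.WebScrapingCache import paged_read_website

# Fetch "page" 0 of a site; each page is one chunk produced by
# WebScrapingCache.paged_read, so stepping page = 0, 1, 2, ... walks the
# whole site without overflowing the agent's context window.
first_page = paged_read_website.run({"website_url": "https://example.com", "page": 0})  # illustrative URL
print(first_page)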
@@ -14,4 +14,5 @@ sentence_transformers
google-api-python-client

requests~=2.31.0
beautifulsoup4~=4.12.2
beautifulsoup4~=4.12.2
markdownify~=0.11.6
@@ -1,28 +1,29 @@
import json
from typing import Type
import re

import requests
from bs4 import BeautifulSoup
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import BaseTool
from langchain.tools import tool
from langchain.vectorstores import Chroma
from pydantic import BaseModel, Field
import markdownify


class ToolRegistry:
class WebScrapingCache:
    _instance = None
    _embeddings = None
    _vector_store = None
    _initialised = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ToolRegistry, cls).__new__(cls)
            cls._instance = super(WebScrapingCache, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if WebScrapingCache._initialised:
            return
        WebScrapingCache._initialised = True
        self._embeddings = None
        self._vector_store = None
        self._url_list = []

    def add_documents(self, docs):
        if self._embeddings is None:
@@ -33,14 +34,27 @@ class ToolRegistry:
        else:
            self._vector_store.add_documents(docs)

    def query_website(self, url: str, query: str):
        self.scrape_website(url)
    def query_website(self, url: str, query: str, keep_links: bool = False):
        self.scrape_website(url, keep_links=keep_links)
        filter_dict = dict()
        filter_dict["url"] = url
        results = self._vector_store.similarity_search_with_score(query, 5, filter=filter_dict)
        print(results[0])
        results = self._vector_store.max_marginal_relevance_search(query, 3, filter=filter_dict)
        return results

    def paged_read(self, url: str, page: int, keep_links: bool = False):
        docs = self.scrape_website(url, keep_links=keep_links, chunk_size=2000, chunk_overlap=0, cache=False)
        if docs is None:
            return "Error scraping website"
        if page > len(docs):
            return "Page not found"
        return str(docs[page]) + "\n\n" + f" = Page {page} of {len(docs)-1}"

    def scrape_website(self, url: str, keep_links=False, chunk_size=1024, chunk_overlap=128, cache=True):
        link_suffix = "(Keep links)" if keep_links else ""
        if url + link_suffix in self._url_list and cache:
            print("Site in cache, skipping...")
            return

    def scrape_website(self, url: str):
        print("Scraping website...")

        # Make the request
@@ -49,38 +63,55 @@ class ToolRegistry:
        # Check the response status code
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, "html.parser")
            text = soup.get_text()
            links = [a['href'] for a in soup.find_all('a', href=True)]
            if keep_links:
                tags_to_strip = []
            else:
                tags_to_strip = ['a']

            text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=1000, chunk_overlap=200)
            docs = text_splitter.create_documents([text], metadatas=[{"url": url, "type": "text"}])
            self.add_documents(docs)
            stripped_text = re.sub(r'<script.*?</script>', '', str(response.content))
            stripped_text = re.sub(r'<style.*?</style>', '', str(stripped_text))
            stripped_text = re.sub(r'<meta.*?</meta>', '', str(stripped_text))

            text = markdownify.markdownify(stripped_text, strip=tags_to_strip)

            text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                           chunk_overlap=chunk_overlap)
            docs = text_splitter.create_documents([text], metadatas=[{"url": url}])
            if cache:
                self.add_documents(docs)
                self._url_list.append(url + link_suffix)

            return docs

        else:
            print(f"HTTP request failed with status code {response.status_code}")
            return f"HTTP request failed with status code {response.status_code}"


class ScrapeWebsiteInput(BaseModel):
    """Inputs for scrape_website"""
    objective: str = Field(
        description="The objective & task that users give to the agent")
    url: str = Field(description="The url of the website to be scraped")


@tool("query_website", return_direct=False)
def query_website(website_url: str, query: str, keep_links: bool = False) -> str:
    """useful when you need to get data from a website url, passing both url and the query to the function; DO NOT
    make up any url, the url should only be from the search results. Links can be enabled or disabled as needed. """
    return str(WebScrapingCache().query_website(website_url, query, keep_links=keep_links))


class ScrapeWebsiteTool(BaseTool):
    name = "scrape_website"
    description = "useful when you need to get data from a website url, passing both url and objective to the function; DO NOT make up any url, the url should only be from the search results"
    args_schema: Type[BaseModel] = ScrapeWebsiteInput

    def _run(self, objective: str, url: str):
        return NotImplementedError("error here")

    def _arun(self, url: str):
        raise NotImplementedError("error here")


@tool("paged_read_website", return_direct=False)
def paged_read_website(website_url: str, page: int) -> str:
    """useful when you need to read data from a website without overflowing context, passing both url and the page number (zero indexed) to the function; DO NOT
    make up any url, the url should only be from the search results. Links can be enabled or disabled as needed. """
    return str(WebScrapingCache().paged_read(website_url, page))


if __name__ == "__main__":
    ToolRegistry().query_website('https://blazblue.fandom.com/wiki/Rachel_Alucard',
                                 'What does Rachel Alucard think about bell peppers?')
    query = "What does Rachel Alucard look like?"
    print(query)
    results = WebScrapingCache().query_website('https://blazblue.fandom.com/wiki/Rachel_Alucard', query)
    print(str(results))

    query = "Rachel Alucard and bell peppers?"
    print(query)
    results = WebScrapingCache().query_website('https://blazblue.fandom.com/wiki/Rachel_Alucard', query)
    print(str(results))

    doc = WebScrapingCache().paged_read('https://www.deeplearning.ai/resources/natural-language-processing/', 5)
    print(doc)
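For readers skimming the diff, the sketch below condenses the new scraping path in scrape_website: strip script/style/meta blocks, convert the remainder to markdown (dropping links unless keep_links is set), and split the result into fixed-size chunks that paged_read serves as pages. The inline HTML string is made up; everything else mirrors the calls above.

import re
import markdownify
from langchain.text_splitter import RecursiveCharacterTextSplitter

html = ("<html><head><style>body {}</style><script>var x = 1;</script></head>"
        "<body><h1>Title</h1><p>Some text with a <a href='https://example.com'>link</a>.</p></body></html>")

# Drop the tags that only add noise to the model's context.
stripped = re.sub(r'<script.*?</script>', '', html)
stripped = re.sub(r'<style.*?</style>', '', stripped)
stripped = re.sub(r'<meta.*?</meta>', '', stripped)

# Convert to markdown; strip=['a'] removes links, matching keep_links=False.
text = markdownify.markdownify(stripped, strip=['a'])

# Chunk into pages; paged_read uses chunk_size=2000 and chunk_overlap=0.
splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=0)
docs = splitter.create_documents([text], metadatas=[{"url": "https://example.com"}])
print(f"{len(docs)} page(s); page 0:\n{docs[0]}")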
@@ -1,2 +1,4 @@
from tools.ToolQueryTool import tool_query_tool
from tools.ToolRegistrationTool import tool_registration_tool
from tools.WebScrapingCache import query_website
from tools.WebScrapingCache import paged_read_website
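Since the package __init__ now re-exports both query_website and paged_read_website, either tool can be imported in one line. The sketch below shows the retrieval side that backs query_website, a metadata-filtered max-marginal-relevance search over Chroma; the two documents and their URLs are invented for illustration, and HuggingFaceEmbeddings() will download a sentence-transformers model on first use.

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.vectorstores import Chroma

# Tiny stand-ins for scraped chunks, tagged with the {"url": ...} metadata
# that WebScrapingCache attaches to every document it caches.
docs = [
    Document(page_content="Rachel Alucard dislikes bell peppers.", metadata={"url": "https://site-a.example"}),
    Document(page_content="An unrelated page about tea ceremonies.", metadata={"url": "https://site-b.example"}),
]

store = Chroma.from_documents(docs, HuggingFaceEmbeddings())

# Restrict the search to one source URL, then diversify the hits with MMR,
# the same kind of call query_website makes against the shared cache.
results = store.max_marginal_relevance_search("bell peppers", 1, filter={"url": "https://site-a.example"})
print(results)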