Initial commit

This commit is contained in:
shiroyasha13
2024-05-19 19:53:10 +05:30
committed by GitHub
parent abfbfc3422
commit ce498257fd
5 changed files with 265 additions and 0 deletions

135
ingest.py Normal file

@@ -0,0 +1,135 @@
# Used to load the PDF files from the source_files directory into langchain documents
from langchain_community.document_loaders import PyPDFLoader
# Used to recursively chunk the langchain document's page_content into appropriately sized pieces,
# based on a list of separator characters and tuned to the LLM's context/performance trade-off
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Using ollama for serving llama3 embeddings
from langchain_community.embeddings import OllamaEmbeddings
# Using chroma vectorstore
from langchain_community.vectorstores import Chroma
# For setting up the chat template
from langchain_core.prompts import ChatPromptTemplate
# for using the groq llama3 api
from langchain_groq import ChatGroq
import os
import subprocess
import json
os.environ['GROQ_API_KEY'] = ''
class MyCustomError(Exception):
def __init__(self, message):
super().__init__(message)
class local_pdf_gpt_ingester:
def __init__(self, path, embedding_model, vectorstore, images=True):
self.path = path
        # since we are using ollama, the model should already have been served; if not, we can throw an error or pull the model ourselves
self.embeddings = embedding_model
# here we are only using the conditions for Chroma, but can be extended
self.vectorstore = vectorstore
self.images = images
def extract_file_metadata(self, context):
try:
llm = ChatGroq(model_name = 'llama3-8b-8192')
system = "You are a helpful assistant.From the given text extract the client name, service provider name and the date of the contract as a JSON response with keys as client, service_provider and contract_date(in the format dd-mm-YYYY). Do not return any other additonal messages."
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])
chain = prompt | llm
response = chain.invoke({"text":context})
return response.content
except Exception as e:
print(e)
return False
def parse_meta(self, info):
text = info.strip()
if not text.startswith('{') and '{' not in text:
text = '{' + text
# Check and add '}' at the end if missing
if not text.endswith('}') and '}' not in text:
text = text + '}'
dictionary = json.loads(text)
return dictionary
def pdf_to_documents(self, **kwargs):
        # using the langchain pdf loader for loading the documents and converting images to flat text (only OCR happens, not an image
        # embedding) - we cannot use this loader for multimodal content (we could use the pypdf library to extract images and text separately
        # or use the unstructured.io library)
files = os.listdir(self.path)
files = [f"{self.path}/{f}" for f in files if f.endswith(".pdf")]
return_chunks = []
for f in files:
docs = PyPDFLoader(f,extract_images = self.images)
# setting up the splitter (here the splitter is configured to work with llama3, need to change appropriately)
splitter = RecursiveCharacterTextSplitter(chunk_size = kwargs.get('chunk_size', 3800),
chunk_overlap = kwargs.get('chunk_overlap',50),
separators = kwargs.get('separators',["\n\n","\n"," ","."]),
is_separator_regex=kwargs.get('is_separator_regex', False)
)
# splitting the documents into appropriate chunks and returning the chunked content
doc_chunks = docs.load_and_split(text_splitter = splitter)
my_chunk = doc_chunks[0].page_content
meta_info = self.extract_file_metadata(my_chunk)
meta_info = self.parse_meta(meta_info)
return_chunks.append({'doc':[i.page_content for i in doc_chunks], 'metadata': meta_info})
return return_chunks
    def pull_model_to_ollama(self, model_name):
# Construct the command
command = ["ollama", "pull", model_name]
try:
# Execute the command
result = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Print the standard output and error
print("Output:\n", result.stdout)
print("Error (if any):\n", result.stderr)
if result.stderr:
return False
else:
return True
except subprocess.CalledProcessError as e:
# Handle errors in the command execution
print(f"An error occurred while pulling the model '{model_name}':")
print("Return code:", e.returncode)
print("Output:\n", e.output)
print("Error:\n", e.stderr)
def load_embeddings(self):
        # need to check whether the embedding model is already present; if not, we have to validate and pull
        # the model into our local ollama (a sketch of an explicit availability check is shown after this file)
try:
# find all the llama3 models here - https://ollama.com/library/llama3:8b
            # we are using the default connection and parameters for ollama embeddings
self.embeddings = OllamaEmbeddings(model=self.embeddings)
        except Exception:
model_check = self.pull_model_to_ollama(self.embeddings)
if model_check:
self.embeddings = OllamaEmbeddings(model=self.embeddings)
else:
raise MyCustomError(f"The provided model name - {self.embeddings} is invalid. Find the list of supported llama3 models here - https://ollama.com/library/llama3:8b")
def embed_and_store(self, documents):
# we are going to use the ollama embeddings with Chroma store here
if self.vectorstore=='chroma':
try:
            # the default method is get_or_create_collection, so if the name already exists this will append
# and not overwrite that content
self.vectorstore = Chroma(collection_name="MSA_4k_chunks",
embedding_function=self.embeddings,
persist_directory="./db")
except Exception as e:
raise MyCustomError(f"Chroma DB error - {e}")
for doc in documents:
metas = [doc['metadata'] for i in doc['doc']]
            # optionally you can also pass IDs. passing metadata can be done like this, or you can directly
            # add the metadata to the document object itself; no specific reason for doing it this way.
            # adding metadata also makes updates and deletions easier, since we can always check whether a
            # document is present or absent based on its metadata.
self.vectorstore.add_texts(texts = doc['doc'],
metadatas = metas)
print("The documents have been embedded and stored in the vector database")
if __name__=='__main__':
msa_bot = local_pdf_gpt_ingester("source_files/", "llama3", "chroma")
msa_docs = msa_bot.pdf_to_documents()
msa_bot.load_embeddings()
msa_bot.embed_and_store(msa_docs)
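
The load_embeddings method above only discovers a missing model when the OllamaEmbeddings call fails; the comment about checking "whether the embedding model is already present" could also be made explicit. A minimal sketch of such a check, assuming the ollama CLI is on the PATH (the helper name is_model_available is hypothetical and not part of this commit):

# sketch: ask the local Ollama instance which models are already pulled
import subprocess

def is_model_available(model_name):
    # `ollama list` prints a header line followed by one line per locally pulled model,
    # e.g. "llama3:latest  <id>  <size>  <modified>"
    result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
    if result.returncode != 0:
        return False
    models = [line.split()[0] for line in result.stdout.splitlines()[1:] if line.strip()]
    return any(m.startswith(model_name) for m in models)

# possible usage inside load_embeddings (sketch):
# if not is_model_available(self.embeddings):
#     self.pull_model_to_ollama(self.embeddings)
# self.embeddings = OllamaEmbeddings(model=self.embeddings)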

130
model.py Normal file

@@ -0,0 +1,130 @@
from langchain_community.embeddings import OllamaEmbeddings
# Using chroma vectorstore
from langchain_community.vectorstores import Chroma
# For setting up the chat template
from langchain_core.prompts import ChatPromptTemplate
# for using the groq llama3 api
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import os
import json
from fuzzywuzzy import process
os.environ['GROQ_API_KEY'] = ''
class MyCustomError(Exception):
def __init__(self, message):
super().__init__(message)
class local_llama3_chatbot:
def __init__(self, model, vector_dir, embedding):
self.llm = model
self.db = vector_dir
self.embedding = embedding
self.clients = []
def load_base_assets(self):
self.llm = ChatGroq(model_name = self.llm, temperature=0)
self.embedding = OllamaEmbeddings(model=self.embedding)
self.db = Chroma(collection_name="MSA_4k_chunks",
embedding_function=self.embedding,
persist_directory=self.db)
def parse_meta(self, info):
text = info.strip()
if '{' in text:
text = text[text.find('{'):]
print("Before conversion - "+info)
if not text.startswith('{') and '{' not in text:
text = '{' + text
# Check and add '}' at the end if missing
if not text.endswith('}') and '}' not in text:
text = text + '}'
dictionary = json.loads(text)
print("after conversion - "+str(dictionary))
return dictionary
def fetch_chat_metadata(self, user_query):
        # function to get the required metadata from the query and also understand the nature of the
        # query (i.e. check whether it compares documents)
system = "You are a helpful assistant who answers exactly in the format specified by the user."
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])
chain = prompt | self.llm
        extract_query = 'From the given text extract the organization name and the year as a python dictionary response with keys as client and contract_date (in the format YYYY). Do not return any other additional messages. For unfound values return the value as "" for the corresponding key. Some sample client names will look like ["abc corp", "abc org", "abc inc", "abc ai"] TEXT: ' + user_query
        meta_response = chain.invoke({"text": extract_query})
        metadata = meta_response.content
        # re-instantiate the Groq model for the cross-document / solo-document classification call
        self.llm = 'llama3-8b-8192'
        self.llm = ChatGroq(model_name=self.llm)
        chain = prompt | self.llm
        check_query = f'From the given text, check whether it is a cross-document question and return the response as "cross-document" if True and if False return the response as "solo-document". Do not return any other additional message in your response. A cross-document question involves more than 2 organization names and asks for aggregations across organizations. TEXT: {user_query}'
response = chain.invoke({"text":check_query})
q_type = response.content
try:
metadata = eval(metadata)
except:
metadata = self.parse_meta(metadata)
return {'question_type': q_type, 'metadata': metadata}
def model_call(self, user_query, classification):
client = classification['metadata']['client']
year = classification['metadata']['contract_date']
searchtype = classification['question_type']
metadatas = self.db.get()['metadatas']
self.clients = [i['client'] for i in metadatas]
correct_client = process.extractOne(client, self.clients)[0]
print(f'client is {correct_client}')
retriever = self.db.as_retriever(search_kwargs={"k":3,"filter":{
'client': correct_client
}})
        # You can build the filter dict dynamically instead of using this many if statements;
        # I have done it this way for simplicity (a sketch of a dynamic filter builder is shown
        # after this file).
        if searchtype == 'solo-document' and not isinstance(year, int):
print("entered - "+searchtype)
retriever = self.db.as_retriever(search_kwargs={"k":3,
"filter":{
'client': correct_client
}})
elif searchtype=='cross-document' and isinstance(year,int):
print("entered - "+searchtype)
retriever = self.db.as_retriever(search_kwargs={"k":3,
"filter":{
'contract_date': {'$gt': f'01-01-{year}'}
}})
elif searchtype=='solo-document' and isinstance(year,int):
print("entered - "+searchtype)
retriever = self.db.as_retriever(search_kwargs={"k":3,
"filter":{
'client': correct_client,
'contract_date': {'$gt': f'01-01-{year}'}
}})
qa_chain_response = self.final_llm_call(user_query,retriever)
# qa_chain_response = self.vector_searcher(user_query, correct_client)
return qa_chain_response
def vector_searcher(self, user_query, correct_client):
return self.db.similarity_search(user_query, k=3,filter={'client': correct_client})
def final_llm_call(self, user_query, retriever):
system_prompt = (
"Use the given context to answer the question. "
"If you don't know the answer, say you don't know. "
"Use three sentence maximum and keep the answer concise. "
"Context: {context}"
)
prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(self.llm, prompt)
chain = create_retrieval_chain(retriever, question_answer_chain)
response = chain.invoke({"input":user_query})
return response
if __name__ == '__main__':
chatbot = local_llama3_chatbot('llama3-8b-8192', './db','llama3')
query = input("query : ") # 'what is the address of shiro corp'
chatbot.load_base_assets()
while query!='exit':
classification = chatbot.fetch_chat_metadata(query)
output = chatbot.model_call(query, classification)
# print(output)
print("Output is : ",output['answer'])
query = input("query : ")
print("========================End Of Conversation=================================")

Binary file not shown.