mirror of
https://github.com/raghujhts13/Advanced-LangChain-RAG.git
synced 2024-05-26 19:18:39 +03:00
Initial commit
135
ingest.py
Normal file
@@ -0,0 +1,135 @@
# Used to load the PDF files from the source_files directory into LangChain documents
from langchain_community.document_loaders import PyPDFLoader
# Used to recursively chunk each document's page_content into appropriately sized pieces
# based on a list of separator characters, balancing context size and LLM performance
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Using Ollama for serving llama3 embeddings
from langchain_community.embeddings import OllamaEmbeddings
# Using the Chroma vectorstore
from langchain_community.vectorstores import Chroma
# For setting up the chat template
from langchain_core.prompts import ChatPromptTemplate
# For using the Groq llama3 API
from langchain_groq import ChatGroq
import os
import subprocess
import json

os.environ['GROQ_API_KEY'] = ''

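# A minimal sketch for loading the key from a .env file instead of hard-coding it
# (assumes the python-dotenv package is installed; nothing else in this script requires it):
# from dotenv import load_dotenv
# load_dotenv()  # reads GROQ_API_KEY from a local .env file into os.environ
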
class MyCustomError(Exception):
    def __init__(self, message):
        super().__init__(message)


class local_pdf_gpt_ingester:
    def __init__(self, path, embedding_model, vectorstore, images=True):
        self.path = path
        # since we are using Ollama, the model should already have been served; if not, we can
        # either throw an error or pull the model ourselves (see load_embeddings below)
        self.embeddings = embedding_model
        # here we only handle the Chroma case, but this can be extended to other vectorstores
        self.vectorstore = vectorstore
        self.images = images

    def extract_file_metadata(self, context):
        # uses the Groq-hosted llama3 model to pull the client, service provider and contract
        # date out of the given text and return them as a JSON string
        try:
            llm = ChatGroq(model_name='llama3-8b-8192')
            system = "You are a helpful assistant. From the given text extract the client name, service provider name and the date of the contract as a JSON response with keys as client, service_provider and contract_date (in the format dd-mm-YYYY). Do not return any other additional messages."
            prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])
            chain = prompt | llm
            response = chain.invoke({"text": context})
            return response.content
        except Exception as e:
            print(e)
            return False

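    # For illustration only (the values below are made up), the returned string is expected to
    # look like: '{"client": "shiro corp", "service_provider": "kuro inc", "contract_date": "01-01-2020"}'
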
    def parse_meta(self, info):
        # repairs the LLM response into a parseable JSON string before loading it
        text = info.strip()
        # Check and add '{' at the start if missing
        if not text.startswith('{') and '{' not in text:
            text = '{' + text
        # Check and add '}' at the end if missing
        if not text.endswith('}') and '}' not in text:
            text = text + '}'
        dictionary = json.loads(text)
        return dictionary

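    # A more defensive variant (only a sketch, assuming the model may wrap the JSON in extra prose):
    # import re
    # match = re.search(r'\{.*\}', text, re.DOTALL)
    # if match:
    #     dictionary = json.loads(match.group(0))
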
    def pdf_to_documents(self, **kwargs):
        # using the LangChain PDF loader to load the documents and convert images to flat text
        # (only OCR happens, not an image embedding) - we cannot use this loader for multimodal
        # content (for that, use the pypdf library to extract images and text separately, or the
        # unstructured.io library)
        files = os.listdir(self.path)
        files = [f"{self.path}/{f}" for f in files if f.endswith(".pdf")]
        return_chunks = []
        for f in files:
            loader = PyPDFLoader(f, extract_images=self.images)
            # setting up the splitter (configured here for llama3; adjust appropriately for other LLMs)
            splitter = RecursiveCharacterTextSplitter(chunk_size=kwargs.get('chunk_size', 3800),
                                                      chunk_overlap=kwargs.get('chunk_overlap', 50),
                                                      separators=kwargs.get('separators', ["\n\n", "\n", " ", "."]),
                                                      is_separator_regex=kwargs.get('is_separator_regex', False))
            # splitting the documents into appropriately sized chunks
            doc_chunks = loader.load_and_split(text_splitter=splitter)
            # extract the file-level metadata from the first chunk of this document
            my_chunk = doc_chunks[0].page_content
            meta_info = self.extract_file_metadata(my_chunk)
            meta_info = self.parse_meta(meta_info)
            return_chunks.append({'doc': [i.page_content for i in doc_chunks], 'metadata': meta_info})
        return return_chunks

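    # Example usage with the splitter settings overridden per call (values here are hypothetical):
    # chunks = local_pdf_gpt_ingester("source_files/", "llama3", "chroma").pdf_to_documents(
    #     chunk_size=2000, chunk_overlap=100)
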
    def pull_model_to_ollama(self, model_name):
        # Construct the command
        command = ["ollama", "pull", model_name]
        try:
            # Execute the command
            result = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

            # Print the standard output and error
            print("Output:\n", result.stdout)
            print("Error (if any):\n", result.stderr)
            if result.stderr:
                return False
            else:
                return True

        except subprocess.CalledProcessError as e:
            # Handle errors in the command execution
            print(f"An error occurred while pulling the model '{model_name}':")
            print("Return code:", e.returncode)
            print("Output:\n", e.output)
            print("Error:\n", e.stderr)
            return False

    def load_embeddings(self):
        # check whether the embedding model is already available; if not, validate and pull it into Ollama
        try:
            # find all the llama3 models here - https://ollama.com/library/llama3:8b
            # we are using the default connection and parameters for Ollama embeddings
            self.embeddings = OllamaEmbeddings(model=self.embeddings)
        except Exception:
            model_check = self.pull_model_to_ollama(self.embeddings)
            if model_check:
                self.embeddings = OllamaEmbeddings(model=self.embeddings)
            else:
                raise MyCustomError(f"The provided model name - {self.embeddings} is invalid. Find the list of supported llama3 models here - https://ollama.com/library/llama3:8b")

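    # Note: constructing OllamaEmbeddings may not contact the Ollama server at all, so the except
    # branch above might never trigger for a missing model. A sketch of forcing validation:
    # _ = self.embeddings.embed_query("ping")  # raises if the model has not been pulled yet
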
    def embed_and_store(self, documents):
        # we are going to use the Ollama embeddings with the Chroma store here
        if self.vectorstore == 'chroma':
            try:
                # the default behaviour is get_or_create collection, so if the name already exists
                # this will append to, and not overwrite, that content
                self.vectorstore = Chroma(collection_name="MSA_4k_chunks",
                                          embedding_function=self.embeddings,
                                          persist_directory="./db")
            except Exception as e:
                raise MyCustomError(f"Chroma DB error - {e}")

        for doc in documents:
            # attach the same file-level metadata to every chunk of that file
            metas = [doc['metadata'] for i in doc['doc']]
            # optionally you can also pass IDs. Passing metadata can be done like this, or you can
            # add the metadata to the Document objects themselves; there is no specific reason for
            # doing it this way. Adding metadata also allows easy updates and deletions, since we
            # can always check whether a document is present or absent based on its metadata
            # (see the sketch after this method).
            self.vectorstore.add_texts(texts=doc['doc'],
                                       metadatas=metas)

        print("The documents have been embedded and stored in the vector database")

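    # A sketch of such a presence check using the stored metadata (illustrative only; it assumes
    # the store's get() accepts a metadata filter, and the client value here is made up):
    # existing = self.vectorstore.get(where={'client': 'shiro corp'})
    # already_ingested = len(existing['ids']) > 0
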
if __name__ == '__main__':
    msa_bot = local_pdf_gpt_ingester("source_files/", "llama3", "chroma")
    msa_docs = msa_bot.pdf_to_documents()
    msa_bot.load_embeddings()
    msa_bot.embed_and_store(msa_docs)
130
model.py
Normal file
@@ -0,0 +1,130 @@
# Using Ollama for serving llama3 embeddings
from langchain_community.embeddings import OllamaEmbeddings
# Using the Chroma vectorstore
from langchain_community.vectorstores import Chroma
# For setting up the chat template
from langchain_core.prompts import ChatPromptTemplate
# For using the Groq llama3 API
from langchain_groq import ChatGroq
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import os
import json
# fuzzywuzzy is used to match the client name extracted from the query against the stored metadata
from fuzzywuzzy import process

os.environ['GROQ_API_KEY'] = ''

class MyCustomError(Exception):
    def __init__(self, message):
        super().__init__(message)


class local_llama3_chatbot:
    def __init__(self, model, vector_dir, embedding):
        self.llm = model
        self.db = vector_dir
        self.embedding = embedding
        self.clients = []

    def load_base_assets(self):
        # replace the plain names with live objects: the Groq chat model, the Ollama embeddings
        # and the persisted Chroma collection created by ingest.py
        self.llm = ChatGroq(model_name=self.llm, temperature=0)
        self.embedding = OllamaEmbeddings(model=self.embedding)
        self.db = Chroma(collection_name="MSA_4k_chunks",
                         embedding_function=self.embedding,
                         persist_directory=self.db)

    def parse_meta(self, info):
        # repairs the LLM response into a parseable JSON string before loading it
        text = info.strip()
        # drop any leading prose before the first '{'
        if '{' in text:
            text = text[text.find('{'):]
        print("Before conversion - " + info)
        # Check and add '{' at the start if missing
        if not text.startswith('{') and '{' not in text:
            text = '{' + text
        # Check and add '}' at the end if missing
        if not text.endswith('}') and '}' not in text:
            text = text + '}'
        dictionary = json.loads(text)
        print("After conversion - " + str(dictionary))
        return dictionary

    def fetch_chat_metadata(self, user_query):
        # extracts the required metadata from the query and also determines the nature of the
        # query (to check for questions that compare documents)
        system = "You are a helpful assistant who answers exactly in the format specified by the user."
        prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{text}")])
        chain = prompt | self.llm
        user_query = 'From the given text extract the organization name and the year as a python dictionary response with keys as client and contract_date (in the format YYYY). Do not return any other additional messages. For unfound values return the value as "" for the corresponding key. Some sample client names will look like ["abc corp", "abc org", "abc inc", "abc ai"] TEXT: ' + user_query
        meta_response = chain.invoke({"text": user_query})
        metadata = meta_response.content
        # rebuild the chain with the default-temperature llama3 model for the classification step
        self.llm = ChatGroq(model_name='llama3-8b-8192')
        chain = prompt | self.llm
        check_query = f'From the given text, check whether it is a cross-document question and return the response as "cross-document" if True and if False return the response as "solo-document". Do not return any other additional message in your response. A cross-document question involves more than 2 organization names and asks for aggregations across organizations. TEXT: {user_query}'
        response = chain.invoke({"text": check_query})
        q_type = response.content
        try:
            metadata = eval(metadata)
        except Exception:
            metadata = self.parse_meta(metadata)
        return {'question_type': q_type, 'metadata': metadata}

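    # eval on LLM output executes arbitrary code. A safer sketch with the same behaviour for a
    # well-formed dictionary literal (standard library only):
    # import ast
    # metadata = ast.literal_eval(metadata)
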
    def model_call(self, user_query, classification):
        client = classification['metadata']['client']
        year = classification['metadata']['contract_date']
        searchtype = classification['question_type']
        # fuzzy-match the extracted client name against the client names stored in the collection metadata
        metadatas = self.db.get()['metadatas']
        self.clients = [i['client'] for i in metadatas]
        correct_client = process.extractOne(client, self.clients)[0]
        print(f'client is {correct_client}')
        # default retriever: filter on the matched client only
        retriever = self.db.as_retriever(search_kwargs={"k": 3, "filter": {
            'client': correct_client
        }})
        # The filter dict could be built dynamically instead of using this many if statements;
        # it is written out explicitly here for simplicity (a sketch of the dynamic version
        # follows this method).
        if searchtype == 'solo-document' and not isinstance(year, int):
            print("entered - " + searchtype)
            retriever = self.db.as_retriever(search_kwargs={"k": 3,
                                                            "filter": {
                                                                'client': correct_client
                                                            }})
        elif searchtype == 'cross-document' and isinstance(year, int):
            print("entered - " + searchtype)
            retriever = self.db.as_retriever(search_kwargs={"k": 3,
                                                            "filter": {
                                                                'contract_date': {'$gt': f'01-01-{year}'}
                                                            }})
        elif searchtype == 'solo-document' and isinstance(year, int):
            print("entered - " + searchtype)
            retriever = self.db.as_retriever(search_kwargs={"k": 3,
                                                            "filter": {
                                                                'client': correct_client,
                                                                'contract_date': {'$gt': f'01-01-{year}'}
                                                            }})
        qa_chain_response = self.final_llm_call(user_query, retriever)
        # qa_chain_response = self.vector_searcher(user_query, correct_client)
        return qa_chain_response

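    # A sketch of building the filter dynamically instead (illustrative only; it assumes the
    # store accepts an '$and' of these conditions and a '$gt' comparison on the stored date strings):
    # conditions = []
    # if searchtype == 'solo-document':
    #     conditions.append({'client': correct_client})
    # if isinstance(year, int):
    #     conditions.append({'contract_date': {'$gt': f'01-01-{year}'}})
    # flt = None if not conditions else (conditions[0] if len(conditions) == 1 else {'$and': conditions})
    # search_kwargs = {"k": 3} if flt is None else {"k": 3, "filter": flt}
    # retriever = self.db.as_retriever(search_kwargs=search_kwargs)
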
    def vector_searcher(self, user_query, correct_client):
        # plain similarity search with a client filter; kept as an alternative to the retrieval chain
        return self.db.similarity_search(user_query, k=3, filter={'client': correct_client})

    def final_llm_call(self, user_query, retriever):
        system_prompt = (
            "Use the given context to answer the question. "
            "If you don't know the answer, say you don't know. "
            "Use three sentences maximum and keep the answer concise. "
            "Context: {context}"
        )
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
            ]
        )
        question_answer_chain = create_stuff_documents_chain(self.llm, prompt)
        chain = create_retrieval_chain(retriever, question_answer_chain)
        # the retrieval chain returns a dict; the generated answer is under the 'answer' key
        response = chain.invoke({"input": user_query})
        return response

if __name__ == '__main__':
    chatbot = local_llama3_chatbot('llama3-8b-8192', './db', 'llama3')
    query = input("query : ")  # e.g. 'what is the address of shiro corp'
    chatbot.load_base_assets()
    while query != 'exit':
        classification = chatbot.fetch_chat_metadata(query)
        output = chatbot.model_call(query, classification)
        # print(output)
        print("Output is : ", output['answer'])
        query = input("query : ")
    print("========================End Of Conversation=================================")
BIN
source_files/file-sample_100kB.doc
Normal file
Binary file not shown.
BIN
source_files/template-master-service-agreement-kuro-2.pdf
Normal file
Binary file not shown.
BIN
source_files/template-master-service-agreement-shiro-1.pdf
Normal file
Binary file not shown.