Mirror of https://github.com/NirDiamant/RAG_Techniques.git, synced 2025-04-07 00:48:52 +03:00
adding runnable scripts
@@ -1,9 +1,11 @@
import os
import sys
import argparse
from dotenv import load_dotenv

sys.path.append(os.path.abspath(
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks
# Add the parent directory to the path since we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from helper_functions import *
from evaluation.evalute_rag import *

@@ -13,15 +15,10 @@ load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define document(s) path
path = "../data/Understanding_Climate_Change.pdf"


# Define the HyDe retriever class - creating vector store, generating hypothetical document, and retrieving
class HyDERetriever:
def __init__(self, files_path, chunk_size=500, chunk_overlap=100):
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)

self.embeddings = OpenAIEmbeddings()
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@@ -30,7 +27,7 @@ class HyDERetriever:
self.hyde_prompt = PromptTemplate(
input_variables=["query", "chunk_size"],
template="""Given the question '{query}', generate a hypothetical document that directly answers this question. The document should be detailed and in-depth.
the document size has be exactly {chunk_size} characters.""",
The document size has to be exactly {chunk_size} characters.""",
)
self.hyde_chain = self.hyde_prompt | self.llm

@@ -44,16 +41,38 @@ class HyDERetriever:
return similar_docs, hypothetical_doc


# Create a HyDe retriever instance
retriever = HyDERetriever(path)
# Main class for running the retrieval process
class ClimateChangeRAG:
def __init__(self, path, query):
self.retriever = HyDERetriever(path)
self.query = query

# Demonstrate on a use case
test_query = "What is the main cause of climate change?"
results, hypothetical_doc = retriever.retrieve(test_query)
def run(self):
# Retrieve results and hypothetical document
results, hypothetical_doc = self.retriever.retrieve(self.query)

# Plot the hypothetical document and the retrieved documents
docs_content = [doc.page_content for doc in results]
# Plot the hypothetical document and the retrieved documents
docs_content = [doc.page_content for doc in results]

print("hypothetical_doc:\n")
print(text_wrap(hypothetical_doc) + "\n")
show_context(docs_content)
print("Hypothetical document:\n")
print(text_wrap(hypothetical_doc) + "\n")
show_context(docs_content)


# Argument parsing function
def parse_args():
parser = argparse.ArgumentParser(description="Run the Climate Change RAG method.")
parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
help="Path to the PDF file to process.")
parser.add_argument("--query", type=str, default="What is the main cause of climate change?",
help="Query to test the retriever (default: 'What is the main topic of the document?').")
return parser.parse_args()


if __name__ == "__main__":
# Parse command-line arguments
args = parse_args()

# Create and run the RAG method instance
rag_runner = ClimateChangeRAG(args.path, args.query)
rag_runner.run()
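The hunks above wrap the HyDE (Hypothetical Document Embedding) example in a runnable ClimateChangeRAG class. For orientation, here is a minimal sketch of the core retrieval step, assuming the llm, hyde_prompt, and FAISS vector store built in HyDERetriever.__init__; it is a reading aid, not the exact method body elided by the diff.

# Minimal HyDE sketch (assumes llm, hyde_prompt, and vectorstore as built above).
def hyde_retrieve(llm, hyde_prompt, vectorstore, query, chunk_size=500, k=3):
    # 1) Ask the LLM for a hypothetical document that answers the query.
    hypothetical_doc = (hyde_prompt | llm).invoke(
        {"query": query, "chunk_size": chunk_size}
    ).content
    # 2) Search the vector store with the hypothetical document instead of the raw query.
    similar_docs = vectorstore.similarity_search(hypothetical_doc, k=k)
    return similar_docs, hypothetical_doc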
@@ -8,13 +8,13 @@ from langchain.text_splitter import CharacterTextSplitter
from langchain.prompts import PromptTemplate

from langchain_core.retrievers import BaseRetriever
from typing import Dict, Any
from typing import List, Dict, Any
from langchain.docstore.document import Document
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field

sys.path.append(os.path.abspath(
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *

@@ -25,11 +25,25 @@ load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')


# Define the query classifer class
class categories_options(BaseModel):
# Define all the required classes and strategies
class CategoriesOptions(BaseModel):
category: str = Field(
description="The category of the query, the options are: Factual, Analytical, Opinion, or Contextual",
example="Factual")
example="Factual"
)


class RelevantScore(BaseModel):
score: float = Field(description="The relevance score of the document to the query", example=8.0)


class SelectedIndices(BaseModel):
indices: List[int] = Field(description="Indices of selected documents", example=[0, 1, 2, 3])


class SubQueries(BaseModel):
sub_queries: List[str] = Field(description="List of sub-queries for comprehensive analysis",
example=["What is the population of New York?", "What is the GDP of New York?"])


class QueryClassifier:
@@ -39,14 +53,13 @@ class QueryClassifier:
input_variables=["query"],
template="Classify the following query into one of these categories: Factual, Analytical, Opinion, or Contextual.\nQuery: {query}\nCategory:"
)
self.chain = self.prompt | self.llm.with_structured_output(categories_options)
self.chain = self.prompt | self.llm.with_structured_output(CategoriesOptions)

def classify(self, query):
print("clasiffying query")
print("Classifying query...")
return self.chain.invoke(query).category

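The rename from categories_options to CategoriesOptions above also highlights the structured-output pattern the whole script relies on. A self-contained sketch of that pattern is shown below; the model choice (gpt-4o-mini) is an assumption borrowed from elsewhere in this diff.

# Minimal sketch of LLM query classification via structured output,
# mirroring QueryClassifier above; the model name is an assumption.
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field

class CategoriesOptions(BaseModel):
    category: str = Field(description="Factual, Analytical, Opinion, or Contextual")

llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini")
prompt = PromptTemplate(
    input_variables=["query"],
    template="Classify the following query into one of these categories: "
             "Factual, Analytical, Opinion, or Contextual.\nQuery: {query}\nCategory:",
)
chain = prompt | llm.with_structured_output(CategoriesOptions)
# chain.invoke({"query": "What is the main cause of climate change?"}).category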
# Define the Base Retriever class, such that the complex ones will inherit from it
class BaseRetrievalStrategy:
def __init__(self, texts):
self.embeddings = OpenAIEmbeddings()
@@ -59,95 +72,67 @@ class BaseRetrievalStrategy:
return self.db.similarity_search(query, k=k)


# Define Factual retriever strategy
class relevant_score(BaseModel):
score: float = Field(description="The relevance score of the document to the query", example=8.0)


class FactualRetrievalStrategy(BaseRetrievalStrategy):
def retrieve(self, query, k=4):
print("retrieving factual")
# Use LLM to enhance the query
print("Retrieving factual information...")
enhanced_query_prompt = PromptTemplate(
input_variables=["query"],
template="Enhance this factual query for better information retrieval: {query}"
)
query_chain = enhanced_query_prompt | self.llm
enhanced_query = query_chain.invoke(query).content
print(f'enhande query: {enhanced_query}')
print(f'Enhanced query: {enhanced_query}')

# Retrieve documents using the enhanced query
docs = self.db.similarity_search(enhanced_query, k=k * 2)

# Use LLM to rank the relevance of retrieved documents
ranking_prompt = PromptTemplate(
input_variables=["query", "doc"],
template="On a scale of 1-10, how relevant is this document to the query: '{query}'?\nDocument: {doc}\nRelevance score:"
)
ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)
ranking_chain = ranking_prompt | self.llm.with_structured_output(RelevantScore)

ranked_docs = []
print("ranking docs")
print("Ranking documents...")
for doc in docs:
input_data = {"query": enhanced_query, "doc": doc.page_content}
score = float(ranking_chain.invoke(input_data).score)
ranked_docs.append((doc, score))

# Sort by relevance score and return top k
ranked_docs.sort(key=lambda x: x[1], reverse=True)
return [doc for doc, _ in ranked_docs[:k]]


# Define Analytical reriever strategy
class SelectedIndices(BaseModel):
indices: List[int] = Field(description="Indices of selected documents", example=[0, 1, 2, 3])


class SubQueries(BaseModel):
sub_queries: List[str] = Field(description="List of sub-queries for comprehensive analysis",
example=["What is the population of New York?", "What is the GDP of New York?"])


class AnalyticalRetrievalStrategy(BaseRetrievalStrategy):
def retrieve(self, query, k=4):
print("retrieving analytical")
# Use LLM to generate sub-queries for comprehensive analysis
print("Retrieving analytical information...")
sub_queries_prompt = PromptTemplate(
input_variables=["query", "k"],
template="Generate {k} sub-questions for: {query}"
)

llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
sub_queries_chain = sub_queries_prompt | llm.with_structured_output(SubQueries)

sub_queries_chain = sub_queries_prompt | self.llm.with_structured_output(SubQueries)
input_data = {"query": query, "k": k}
sub_queries = sub_queries_chain.invoke(input_data).sub_queries
print(f'sub queries for comprehensive analysis: {sub_queries}')
print(f'Sub-queries: {sub_queries}')

all_docs = []
for sub_query in sub_queries:
all_docs.extend(self.db.similarity_search(sub_query, k=2))

# Use LLM to ensure diversity and relevance
diversity_prompt = PromptTemplate(
input_variables=["query", "docs", "k"],
template="""Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {docs}\n
Return only the indices of selected documents as a list of integers."""
template="Select the most diverse and relevant set of {k} documents for the query: '{query}'\nDocuments: {docs}\n"
)
diversity_chain = diversity_prompt | self.llm.with_structured_output(SelectedIndices)
docs_text = "\n".join([f"{i}: {doc.page_content[:50]}..." for i, doc in enumerate(all_docs)])
input_data = {"query": query, "docs": docs_text, "k": k}
selected_indices_result = diversity_chain.invoke(input_data).indices
print(f'selected diverse and relevant documents')
selected_indices = diversity_chain.invoke(input_data).indices

return [all_docs[i] for i in selected_indices_result if i < len(all_docs)]
return [all_docs[i] for i in selected_indices if i < len(all_docs)]


# Define Opinion retriever strategy
class OpinionRetrievalStrategy(BaseRetrievalStrategy):
def retrieve(self, query, k=3):
print("retrieving opinion")
# Use LLM to identify potential viewpoints
print("Retrieving opinions...")
viewpoints_prompt = PromptTemplate(
input_variables=["query", "k"],
template="Identify {k} distinct viewpoints or perspectives on the topic: {query}"
@@ -155,13 +140,12 @@ class OpinionRetrievalStrategy(BaseRetrievalStrategy):
viewpoints_chain = viewpoints_prompt | self.llm
input_data = {"query": query, "k": k}
viewpoints = viewpoints_chain.invoke(input_data).content.split('\n')
print(f'viewpoints: {viewpoints}')
print(f'Viewpoints: {viewpoints}')

all_docs = []
for viewpoint in viewpoints:
all_docs.extend(self.db.similarity_search(f"{query} {viewpoint}", k=2))

# Use LLM to classify and select diverse opinions
opinion_prompt = PromptTemplate(
input_variables=["query", "docs", "k"],
template="Classify these documents into distinct opinions on '{query}' and select the {k} most representative and diverse viewpoints:\nDocuments: {docs}\nSelected indices:"
@@ -171,16 +155,13 @@ class OpinionRetrievalStrategy(BaseRetrievalStrategy):
docs_text = "\n".join([f"{i}: {doc.page_content[:100]}..." for i, doc in enumerate(all_docs)])
input_data = {"query": query, "docs": docs_text, "k": k}
selected_indices = opinion_chain.invoke(input_data).indices
print(f'selected diverse and relevant documents')

return [all_docs[int(i)] for i in selected_indices.split() if i.isdigit() and int(i) < len(all_docs)]
return [all_docs[int(i)] for i in selected_indices if i.isdigit() and int(i) < len(all_docs)]


# Define Contextual retriever strategy
class ContextualRetrievalStrategy(BaseRetrievalStrategy):
def retrieve(self, query, k=4, user_context=None):
print("retrieving contextual")
# Use LLM to incorporate user context into the query
print("Retrieving contextual information...")
context_prompt = PromptTemplate(
input_variables=["query", "context"],
template="Given the user context: {context}\nReformulate the query to best address the user's needs: {query}"
@@ -188,18 +169,15 @@ class ContextualRetrievalStrategy(BaseRetrievalStrategy):
context_chain = context_prompt | self.llm
input_data = {"query": query, "context": user_context or "No specific context provided"}
contextualized_query = context_chain.invoke(input_data).content
print(f'contextualized query: {contextualized_query}')
print(f'Contextualized query: {contextualized_query}')

# Retrieve documents using the contextualized query
docs = self.db.similarity_search(contextualized_query, k=k * 2)

# Use LLM to rank the relevance of retrieved documents considering the user context
ranking_prompt = PromptTemplate(
input_variables=["query", "context", "doc"],
template="Given the query: '{query}' and user context: '{context}', rate the relevance of this document on a scale of 1-10:\nDocument: {doc}\nRelevance score:"
)
ranking_chain = ranking_prompt | self.llm.with_structured_output(relevant_score)
print("ranking docs")
ranking_chain = ranking_prompt | self.llm.with_structured_output(RelevantScore)

ranked_docs = []
for doc in docs:
@@ -208,14 +186,13 @@ class ContextualRetrievalStrategy(BaseRetrievalStrategy):
score = float(ranking_chain.invoke(input_data).score)
ranked_docs.append((doc, score))

# Sort by relevance score and return top k
ranked_docs.sort(key=lambda x: x[1], reverse=True)

return [doc for doc, _ in ranked_docs[:k]]


# Define the Adapive retriever class
class AdaptiveRetriever:
# Define the main Adaptive RAG class
class AdaptiveRAG:
def __init__(self, texts: List[str]):
self.classifier = QueryClassifier()
self.strategies = {
@@ -224,35 +201,7 @@ class AdaptiveRetriever:
"Opinion": OpinionRetrievalStrategy(texts),
"Contextual": ContextualRetrievalStrategy(texts)
}

def get_relevant_documents(self, query: str) -> List[Document]:
category = self.classifier.classify(query)
strategy = self.strategies[category]
return strategy.retrieve(query)


# Define aditional retriever that inherits from langchain BaseRetriever
class PydanticAdaptiveRetriever(BaseRetriever):
adaptive_retriever: AdaptiveRetriever = Field(exclude=True)

class Config:
arbitrary_types_allowed = True

def get_relevant_documents(self, query: str) -> List[Document]:
return self.adaptive_retriever.get_relevant_documents(query)

async def aget_relevant_documents(self, query: str) -> List[Document]:
return self.get_relevant_documents(query)


# Define the Adaptive RAG class
class AdaptiveRAG:
def __init__(self, texts: List[str]):
adaptive_retriever = AdaptiveRetriever(texts)
self.retriever = PydanticAdaptiveRetriever(adaptive_retriever=adaptive_retriever)
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)

# Create a custom prompt
prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

@@ -260,34 +209,39 @@ class AdaptiveRAG:

Question: {question}
Answer:"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

# Create the LLM chain
self.llm_chain = prompt | self.llm
self.prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
self.llm_chain = self.prompt | self.llm

def answer(self, query: str) -> str:
docs = self.retriever.get_relevant_documents(query)
category = self.classifier.classify(query)
strategy = self.strategies[category]
docs = strategy.retrieve(query)
input_data = {"context": "\n".join([doc.page_content for doc in docs]), "question": query}
return self.llm_chain.invoke(input_data)
return self.llm_chain.invoke(input_data).content


# Demonstrate use of this model
# Usage
texts = [
"The Earth is the third planet from the Sun and the only astronomical object known to harbor life."
]
rag_system = AdaptiveRAG(texts)
# Argument parsing functions
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Run AdaptiveRAG system.")
parser.add_argument('--texts', nargs='+', help="Input texts for retrieval")
return parser.parse_args()

# Showcase the four different types of queries
factual_result = rag_system.answer("What is the distance between the Earth and the Sun?").content
print(f"Answer: {factual_result}")

analytical_result = rag_system.answer("How does the Earth's distance from the Sun affect its climate?").content
print(f"Answer: {analytical_result}")
if __name__ == "__main__":
args = parse_args()
texts = args.texts or [
"The Earth is the third planet from the Sun and the only astronomical object known to harbor life."]
rag_system = AdaptiveRAG(texts)

opinion_result = rag_system.answer("What are the different theories about the origin of life on Earth?").content
print(f"Answer: {opinion_result}")
queries = [
"What is the distance between the Earth and the Sun?",
"How does the Earth's distance from the Sun affect its climate?",
"What are the different theories about the origin of life on Earth?",
"How does the Earth's position in the Solar System influence its habitability?"
]

contextual_result = rag_system.answer(
"How does the Earth's position in the Solar System influence its habitability?").content
print(f"Answer: {contextual_result}")
for query in queries:
print(f"Query: {query}")
result = rag_system.answer(query)
print(f"Answer: {result}")

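To summarize the refactor above, the new AdaptiveRAG class folds the old AdaptiveRetriever/PydanticAdaptiveRetriever pair into a single classify-then-dispatch flow. The sketch below is a condensed restatement of that flow using the names defined in the diff; it is an orientation aid, not additional repository code.

# Condensed restatement of AdaptiveRAG.answer(): classify, pick a strategy, answer over its context.
def adaptive_answer(classifier, strategies, llm_chain, query):
    category = classifier.classify(query)   # "Factual" | "Analytical" | "Opinion" | "Contextual"
    strategy = strategies[category]         # strategy instances built as in AdaptiveRAG.__init__
    docs = strategy.retrieve(query)         # strategy-specific retrieval
    context = "\n".join(doc.page_content for doc in docs)
    return llm_chain.invoke({"context": context, "question": query}).content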
@@ -1,88 +1,35 @@
import nest_asyncio
import random

nest_asyncio.apply()
from dotenv import load_dotenv

from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.core.prompts import PromptTemplate

from llama_index.core.evaluation import (
DatasetGenerator,
FaithfulnessEvaluator,
RelevancyEvaluator
)
from llama_index.llms.openai import OpenAI

import openai
import time
import os
from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.core.prompts import PromptTemplate
from llama_index.core.evaluation import DatasetGenerator, FaithfulnessEvaluator, RelevancyEvaluator
from llama_index.llms.openai import OpenAI

# Apply asyncio fix for Jupyter notebooks
nest_asyncio.apply()

# Load environment variables
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

# Read Docs
data_dir = "../data"
documents = SimpleDirectoryReader(data_dir).load_data()

# Create evaluation questions and pick k out of them
num_eval_questions = 25

eval_documents = documents[0:20]
data_generator = DatasetGenerator.from_documents(eval_documents)
eval_questions = data_generator.generate_questions_from_nodes()
k_eval_questions = random.sample(eval_questions, num_eval_questions)

# Define metrics evaluators and modify llama_index faithfullness evaluator prompt to rely on the context
# We will use GPT-4 for evaluating the responses
gpt4 = OpenAI(temperature=0, model="gpt-4o")

# Define service context for GPT-4 for evaluation
service_context_gpt4 = ServiceContext.from_defaults(llm=gpt4)

# Define Faithfulness and Relevancy Evaluators which are based on GPT-4
faithfulness_gpt4 = FaithfulnessEvaluator(service_context=service_context_gpt4)

faithfulness_new_prompt_template = PromptTemplate(""" Please tell if a given piece of information is directly supported by the context.
You need to answer with either YES or NO.
Answer YES if any part of the context explicitly supports the information, even if most of the context is unrelated. If the context does not explicitly support the information, answer NO. Some examples are provided below.

Information: Apple pie is generally double-crusted.
Context: An apple pie is a fruit pie in which the principal filling ingredient is apples.
Apple pie is often served with whipped cream, ice cream ('apple pie à la mode'), custard, or cheddar cheese.
It is generally double-crusted, with pastry both above and below the filling; the upper crust may be solid or latticed (woven of crosswise strips).
Answer: YES

Information: Apple pies taste bad.
Context: An apple pie is a fruit pie in which the principal filling ingredient is apples.
Apple pie is often served with whipped cream, ice cream ('apple pie à la mode'), custard, or cheddar cheese.
It is generally double-crusted, with pastry both above and below the filling; the upper crust may be solid or latticed (woven of crosswise strips).
Answer: NO

Information: Paris is the capital of France.
Context: This document describes a day trip in Paris. You will visit famous landmarks like the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral.
Answer: NO

Information: {query_str}
Context: {context_str}
Answer:

""")

faithfulness_gpt4.update_prompts(
{"your_prompt_key": faithfulness_new_prompt_template}) # Update the prompts dictionary with the new prompt template
relevancy_gpt4 = RelevancyEvaluator(service_context=service_context_gpt4)
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')


# Function to evaluate metrics for each chunk size
# Define function to calculate average response time, average faithfulness and average relevancy metrics for given chunk size
# We use GPT-3.5-Turbo to generate response and GPT-4 to evaluate it.
def evaluate_response_time_and_accuracy(chunk_size, eval_questions):
# Utility functions
def evaluate_response_time_and_accuracy(chunk_size, eval_questions, eval_documents, faithfulness_evaluator,
relevancy_evaluator):
"""
Evaluate the average response time, faithfulness, and relevancy of responses generated by GPT-3.5-turbo for a given chunk size.

Parameters:
chunk_size (int): The size of data chunks being processed.
eval_questions (list): List of evaluation questions.
eval_documents (list): Documents used for evaluation.
faithfulness_evaluator (FaithfulnessEvaluator): Evaluator for faithfulness.
relevancy_evaluator (RelevancyEvaluator): Evaluator for relevancy.

Returns:
tuple: A tuple containing the average response time, faithfulness, and relevancy metrics.
@@ -92,32 +39,23 @@ def evaluate_response_time_and_accuracy(chunk_size, eval_questions):
total_faithfulness = 0
total_relevancy = 0

# create vector index
# Create vector index
llm = OpenAI(model="gpt-3.5-turbo")

service_context = ServiceContext.from_defaults(llm=llm, chunk_size=chunk_size, chunk_overlap=chunk_size // 5)
vector_index = VectorStoreIndex.from_documents(
eval_documents, service_context=service_context
)
# build query engine
vector_index = VectorStoreIndex.from_documents(eval_documents, service_context=service_context)

# Build query engine
query_engine = vector_index.as_query_engine(similarity_top_k=5)
num_questions = len(eval_questions)

# Iterate over each question in eval_questions to compute metrics.
# While BatchEvalRunner can be used for faster evaluations (see: https://docs.llamaindex.ai/en/latest/examples/evaluation/batch_eval.html),
# we're using a loop here to specifically measure response time for different chunk sizes.
# Iterate over each question in eval_questions to compute metrics
for question in eval_questions:
start_time = time.time()
response_vector = query_engine.query(question)
elapsed_time = time.time() - start_time

faithfulness_result = faithfulness_gpt4.evaluate_response(
response=response_vector
).passing

relevancy_result = relevancy_gpt4.evaluate_response(
query=question, response=response_vector
).passing
faithfulness_result = faithfulness_evaluator.evaluate_response(response=response_vector).passing
relevancy_result = relevancy_evaluator.evaluate_response(query=question, response=response_vector).passing

total_response_time += elapsed_time
total_faithfulness += faithfulness_result
@@ -130,11 +68,72 @@ def evaluate_response_time_and_accuracy(chunk_size, eval_questions):
return average_response_time, average_faithfulness, average_relevancy



# Test different chunk sizes
chunk_sizes = [128, 256]
# Define the main class for the RAG method

for chunk_size in chunk_sizes:
avg_response_time, avg_faithfulness, avg_relevancy = evaluate_response_time_and_accuracy(chunk_size,
k_eval_questions)
print(
f"Chunk size {chunk_size} - Average Response time: {avg_response_time:.2f}s, Average Faithfulness: {avg_faithfulness:.2f}, Average Relevancy: {avg_relevancy:.2f}")
class RAGEvaluator:
def __init__(self, data_dir, num_eval_questions, chunk_sizes):
self.data_dir = data_dir
self.num_eval_questions = num_eval_questions
self.chunk_sizes = chunk_sizes
self.documents = self.load_documents()
self.eval_questions = self.generate_eval_questions()
self.service_context_gpt4 = self.create_service_context()
self.faithfulness_evaluator = self.create_faithfulness_evaluator()
self.relevancy_evaluator = self.create_relevancy_evaluator()

def load_documents(self):
return SimpleDirectoryReader(self.data_dir).load_data()

def generate_eval_questions(self):
eval_documents = self.documents[0:20]
data_generator = DatasetGenerator.from_documents(eval_documents)
eval_questions = data_generator.generate_questions_from_nodes()
return random.sample(eval_questions, self.num_eval_questions)

def create_service_context(self):
gpt4 = OpenAI(temperature=0, model="gpt-4o")
return ServiceContext.from_defaults(llm=gpt4)

def create_faithfulness_evaluator(self):
faithfulness_evaluator = FaithfulnessEvaluator(service_context=self.service_context_gpt4)
faithfulness_new_prompt_template = PromptTemplate("""
Please tell if a given piece of information is directly supported by the context.
You need to answer with either YES or NO.
Answer YES if any part of the context explicitly supports the information, even if most of the context is unrelated. If the context does not explicitly support the information, answer NO. Some examples are provided below.
...
""")
faithfulness_evaluator.update_prompts({"your_prompt_key": faithfulness_new_prompt_template})
return faithfulness_evaluator

def create_relevancy_evaluator(self):
return RelevancyEvaluator(service_context=self.service_context_gpt4)

def run(self):
for chunk_size in self.chunk_sizes:
avg_response_time, avg_faithfulness, avg_relevancy = evaluate_response_time_and_accuracy(
chunk_size,
self.eval_questions,
self.documents[0:20],
self.faithfulness_evaluator,
self.relevancy_evaluator
)
print(f"Chunk size {chunk_size} - Average Response time: {avg_response_time:.2f}s, "
f"Average Faithfulness: {avg_faithfulness:.2f}, Average Relevancy: {avg_relevancy:.2f}")


# Argument Parsing

def parse_args():
import argparse
parser = argparse.ArgumentParser(description='RAG Method Evaluation')
parser.add_argument('--data_dir', type=str, default='../data', help='Directory of the documents')
parser.add_argument('--num_eval_questions', type=int, default=25, help='Number of evaluation questions')
parser.add_argument('--chunk_sizes', nargs='+', type=int, default=[128, 256], help='List of chunk sizes')
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
evaluator = RAGEvaluator(data_dir=args.data_dir, num_eval_questions=args.num_eval_questions,
chunk_sizes=args.chunk_sizes)
evaluator.run()

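As a reading aid for the chunk-size evaluation above, the sketch below restates the per-chunk-size averaging that evaluate_response_time_and_accuracy performs: totals divided by the number of questions, with the boolean .passing results averaged into pass rates. It is a simplification, not the function's full body.

# Simplified restatement of the averaging inside evaluate_response_time_and_accuracy.
def average_metrics(per_question_results):
    # per_question_results: list of (elapsed_seconds, faithfulness_passing, relevancy_passing)
    n = len(per_question_results)
    avg_time = sum(t for t, _, _ in per_question_results) / n
    avg_faithfulness = sum(f for _, f, _ in per_question_results) / n  # True/False average -> pass rate
    avg_relevancy = sum(r for _, _, r in per_question_results) / n
    return avg_time, avg_faithfulness, avg_relevancy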
@@ -2,11 +2,9 @@ import os
import sys
from dotenv import load_dotenv
from langchain.docstore.document import Document

sys.path.append(os.path.abspath(
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *
from typing import List

# Load environment variables from a .env file
load_dotenv()
@@ -14,12 +12,6 @@ load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define path to PDF
path = "../data/Understanding_Climate_Change.pdf"

# Read PDF to string
content = read_pdf_to_string(path)


# Function to split text into chunks with metadata of the chunk chronological index
def split_text_to_chunks_with_indices(text: str, chunk_size: int, chunk_overlap: int) -> List[Document]:
@@ -33,32 +25,8 @@ def split_text_to_chunks_with_indices(text: str, chunk_size: int, chunk_overlap:
return chunks

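The body of split_text_to_chunks_with_indices is elided by the hunk above. Based on its signature and on the 'index' metadata key used by get_chunk_by_index further down, one plausible implementation is sketched here; it is an inference, not the repository's exact code.

# Plausible sketch of split_text_to_chunks_with_indices, inferred from its signature
# and from the 'index' metadata consumed by get_chunk_by_index below.
def split_text_to_chunks_with_indices(text: str, chunk_size: int, chunk_overlap: int) -> List[Document]:
    chunks = []
    start, index = 0, 0
    while start < len(text):
        chunks.append(Document(page_content=text[start:start + chunk_size],
                               metadata={"index": index}))
        start += chunk_size - chunk_overlap  # advance by the stride, keeping the overlap
        index += 1
    return chunks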
# Split our document accordingly
chunks_size = 400
chunk_overlap = 200
docs = split_text_to_chunks_with_indices(content, chunks_size, chunk_overlap)

# Create vector store and retriever
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_documents(docs, embeddings)
chunks_query_retriever = vectorstore.as_retriever(search_kwargs={"k": 1})


# Function to draw the kth chunk (in the original order) from the vector store

# Function to retrieve a chunk from the vectorstore based on its index in the metadata
def get_chunk_by_index(vectorstore, target_index: int) -> Document:
"""
Retrieve a chunk from the vectorstore based on its index in the metadata.

Args:
vectorstore (VectorStore): The vectorstore containing the chunks.
target_index (int): The index of the chunk to retrieve.

Returns:
Optional[Document]: The retrieved chunk as a Document object, or None if not found.
"""
# This is a simplified version. In practice, you might need a more efficient method
# to retrieve chunks by index, depending on your vectorstore implementation.
all_docs = vectorstore.similarity_search("", k=vectorstore.index.ntotal)
for doc in all_docs:
if doc.metadata.get('index') == target_index:
@@ -66,29 +34,9 @@ def get_chunk_by_index(vectorstore, target_index: int) -> Document:
return None


# Check the function
chunk = get_chunk_by_index(vectorstore, 0)
print(chunk.page_content)


# Function that retrieves from the vector stroe based on semantic similarity and then pads each retrieved chunk with its num_neighbors before and after, taking into account the chunk overlap to construct a meaningful wide window arround it
# Function that retrieves from the vectorstore based on semantic similarity and pads each retrieved chunk with its neighboring chunks
def retrieve_with_context_overlap(vectorstore, retriever, query: str, num_neighbors: int = 1, chunk_size: int = 200,
chunk_overlap: int = 20) -> List[str]:
"""
Retrieve chunks based on a query, then fetch neighboring chunks and concatenate them,
accounting for overlap and correct indexing.

Args:
vectorstore (VectorStore): The vectorstore containing the chunks.
retriever: The retriever object to get relevant documents.
query (str): The query to search for relevant chunks.
num_neighbors (int): The number of chunks to retrieve before and after each relevant chunk.
chunk_size (int): The size of each chunk when originally split.
chunk_overlap (int): The overlap between chunks when originally split.

Returns:
List[str]: List of concatenated chunk sequences, each centered on a relevant chunk.
"""
relevant_chunks = retriever.get_relevant_documents(query)
result_sequences = []

@@ -99,7 +47,7 @@ def retrieve_with_context_overlap(vectorstore, retriever, query: str, num_neighb

# Determine the range of chunks to retrieve
start_index = max(0, current_index - num_neighbors)
end_index = current_index + num_neighbors + 1 # +1 because range is exclusive at the end
end_index = current_index + num_neighbors + 1

# Retrieve all chunks in the range
neighbor_chunks = []
@@ -123,68 +71,77 @@ def retrieve_with_context_overlap(vectorstore, retriever, query: str, num_neighb
return result_sequences


# Comparing regular retrival and retrival with context window
# Baseline approach
query = "Explain the role of deforestation and fossil fuels in climate change."
baseline_chunk = chunks_query_retriever.get_relevant_documents(query
,
k=1
)
# Focused context enrichment approach
enriched_chunks = retrieve_with_context_overlap(
vectorstore,
chunks_query_retriever,
query,
num_neighbors=1,
chunk_size=400,
chunk_overlap=200
)
# Main class that encapsulates the RAG method
class RAGMethod:
def __init__(self, chunk_size: int = 400, chunk_overlap: int = 200):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.docs = self._prepare_docs()
self.vectorstore, self.retriever = self._prepare_retriever()

print("Baseline Chunk:")
print(baseline_chunk[0].page_content)
print("\nEnriched Chunks:")
print(enriched_chunks[0])
def _prepare_docs(self) -> List[Document]:
content = """
Artificial Intelligence (AI) has a rich history dating back to the mid-20th century. The term "Artificial Intelligence" was coined in 1956 at the Dartmouth Conference, marking the field's official beginning.

# An example that showcases the superiority of additional context window
document_content = """
Artificial Intelligence (AI) has a rich history dating back to the mid-20th century. The term "Artificial Intelligence" was coined in 1956 at the Dartmouth Conference, marking the field's official beginning.
In the 1950s and 1960s, AI research focused on symbolic methods and problem-solving. The Logic Theorist, created in 1955 by Allen Newell and Herbert A. Simon, is often considered the first AI program.

In the 1950s and 1960s, AI research focused on symbolic methods and problem-solving. The Logic Theorist, created in 1955 by Allen Newell and Herbert A. Simon, is often considered the first AI program.
The 1960s saw the development of expert systems, which used predefined rules to solve complex problems. DENDRAL, created in 1965, was one of the first expert systems, designed to analyze chemical compounds.

The 1960s saw the development of expert systems, which used predefined rules to solve complex problems. DENDRAL, created in 1965, was one of the first expert systems, designed to analyze chemical compounds.
However, the 1970s brought the first "AI Winter," a period of reduced funding and interest in AI research, largely due to overpromised capabilities and underdelivered results.

However, the 1970s brought the first "AI Winter," a period of reduced funding and interest in AI research, largely due to overpromised capabilities and underdelivered results.
The 1980s saw a resurgence with the popularization of expert systems in corporations. The Japanese government's Fifth Generation Computer Project also spurred increased investment in AI research globally.

The 1980s saw a resurgence with the popularization of expert systems in corporations. The Japanese government's Fifth Generation Computer Project also spurred increased investment in AI research globally.
Neural networks gained prominence in the 1980s and 1990s. The backpropagation algorithm, although discovered earlier, became widely used for training multi-layer networks during this time.

Neural networks gained prominence in the 1980s and 1990s. The backpropagation algorithm, although discovered earlier, became widely used for training multi-layer networks during this time.
The late 1990s and 2000s marked the rise of machine learning approaches. Support Vector Machines (SVMs) and Random Forests became popular for various classification and regression tasks.

The late 1990s and 2000s marked the rise of machine learning approaches. Support Vector Machines (SVMs) and Random Forests became popular for various classification and regression tasks.
Deep Learning, a subset of machine learning using neural networks with many layers, began to show promising results in the early 2010s. The breakthrough came in 2012 when a deep neural network significantly outperformed other machine learning methods in the ImageNet competition.

Deep Learning, a subset of machine learning using neural networks with many layers, began to show promising results in the early 2010s. The breakthrough came in 2012 when a deep neural network significantly outperformed other machine learning methods in the ImageNet competition.
Since then, deep learning has revolutionized many AI applications, including image and speech recognition, natural language processing, and game playing. In 2016, Google's AlphaGo defeated a world champion Go player, a landmark achievement in AI.

Since then, deep learning has revolutionized many AI applications, including image and speech recognition, natural language processing, and game playing. In 2016, Google's AlphaGo defeated a world champion Go player, a landmark achievement in AI.
The current era of AI is characterized by the integration of deep learning with other AI techniques, the development of more efficient and powerful hardware, and the ethical considerations surrounding AI deployment.

The current era of AI is characterized by the integration of deep learning with other AI techniques, the development of more efficient and powerful hardware, and the ethical considerations surrounding AI deployment.
Transformers, introduced in 2017, have become a dominant architecture in natural language processing, enabling models like GPT (Generative Pre-trained Transformer) to generate human-like text.

Transformers, introduced in 2017, have become a dominant architecture in natural language processing, enabling models like GPT (Generative Pre-trained Transformer) to generate human-like text.
As AI continues to evolve, new challenges and opportunities arise. Explainable AI, robust and fair machine learning, and artificial general intelligence (AGI) are among the key areas of current and future research in the field.
"""
return split_text_to_chunks_with_indices(content, self.chunk_size, self.chunk_overlap)

As AI continues to evolve, new challenges and opportunities arise. Explainable AI, robust and fair machine learning, and artificial general intelligence (AGI) are among the key areas of current and future research in the field.
"""
def _prepare_retriever(self):
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_documents(self.docs, embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
return vectorstore, retriever

chunks_size = 250
chunk_overlap = 20
document_chunks = split_text_to_chunks_with_indices(document_content, chunks_size, chunk_overlap)
document_vectorstore = FAISS.from_documents(document_chunks, embeddings)
document_retriever = document_vectorstore.as_retriever(search_kwargs={"k": 1})
def run(self, query: str, num_neighbors: int = 1):
baseline_chunk = self.retriever.get_relevant_documents(query)
enriched_chunks = retrieve_with_context_overlap(self.vectorstore, self.retriever, query, num_neighbors,
self.chunk_size, self.chunk_overlap)
return baseline_chunk[0].page_content, enriched_chunks[0]

query = "When did deep learning become prominent in AI?"
context = document_retriever.get_relevant_documents(query)
context_pages_content = [doc.page_content for doc in context]

print("Regular retrieval:\n")
show_context(context_pages_content)
# Argument parsing function
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Run RAG method on a given PDF and query.")
parser.add_argument("--query", type=str, default="When did deep learning become prominent in AI?",
help="Query to test the retriever (default: 'What is the main topic of the document?').")
parser.add_argument('--chunk_size', type=int, default=400, help="Size of text chunks.")
parser.add_argument('--chunk_overlap', type=int, default=200, help="Overlap between chunks.")
parser.add_argument('--num_neighbors', type=int, default=1, help="Number of neighboring chunks for context.")
return parser.parse_args()

sequences = retrieve_with_context_overlap(document_vectorstore, document_retriever, query, num_neighbors=1)
print("\nRetrieval with context enrichment:\n")
show_context(sequences)

# Main execution
if __name__ == "__main__":
args = parse_args()

# Initialize and run the RAG method
rag_method = RAGMethod(chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
baseline, enriched = rag_method.run(args.query, num_neighbors=args.num_neighbors)

print("Baseline Chunk:")
print(baseline)

print("\nEnriched Chunks:")
print(enriched)
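To make the neighbor-padding arithmetic above concrete, here is a small worked sketch of how overlapping neighbor chunks can be concatenated without repeating the shared region. The numbers mirror the script's defaults (chunk_size=400, chunk_overlap=200); the concatenation rule is an assumption consistent with the docstring of retrieve_with_context_overlap, whose body is elided by the hunks.

# Sketch: merge an ordered run of neighboring chunks, dropping the overlapping
# prefix of each subsequent chunk (assumes chunks were split with a fixed overlap).
def concatenate_with_overlap(chunk_texts, chunk_overlap=200):
    merged = chunk_texts[0]
    for text in chunk_texts[1:]:
        merged += text[chunk_overlap:]  # skip the part already present in the previous chunk
    return merged

# With chunk_size=400 and chunk_overlap=200, three neighboring chunks cover
# roughly 400 + 200 + 200 = 800 characters of the original document.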
@@ -1,50 +1,125 @@
import os
import sys
import time
import argparse
from dotenv import load_dotenv
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA

sys.path.append(os.path.abspath(
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *

# Add the parent directory to the path since we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Load environment variables from a .env file
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define document's path
path = "../data/Understanding_Climate_Change.pdf"

# Create a vector store
vector_store = encode_pdf(path)
class ContextualCompressionRAG:
"""
A class to handle the process of creating a retrieval-based Question Answering system
with a contextual compression retriever.
"""

# Create a retriever + contexual compressor + combine them
# Create a retriever
retriever = vector_store.as_retriever()
def __init__(self, path, model_name="gpt-4o-mini", temperature=0, max_tokens=4000):
"""
Initializes the ContextualCompressionRAG by setting up the document store and retriever.

# Create a contextual compressor
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
compressor = LLMChainExtractor.from_llm(llm)
Args:
path (str): Path to the PDF file to process.
model_name (str): The name of the language model to use (default: gpt-4o-mini).
temperature (float): The temperature for the language model.
max_tokens (int): The maximum tokens for the language model (default: 4000).
"""
print("\n--- Initializing Contextual Compression RAG ---")
self.path = path
self.model_name = model_name
self.temperature = temperature
self.max_tokens = max_tokens

# Combine the retriever with the compressor
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=retriever
)
# Step 1: Create a vector store
self.vector_store = self._encode_document()

# Create a QA chain with the compressed retriever
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=compression_retriever,
# Step 2: Create a retriever
self.retriever = self.vector_store.as_retriever()

# Step 3: Initialize language model and create a contextual compressor
self.llm = self._initialize_llm()
self.compressor = LLMChainExtractor.from_llm(self.llm)

# Step 4: Combine the retriever with the compressor
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=self.compressor,
base_retriever=self.retriever
)

# Step 5: Create a QA chain with the compressed retriever
self.qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
retriever=self.compression_retriever,
return_source_documents=True
)
)

# Example usage
query = "What is the main topic of the document?"
result = qa_chain.invoke({"query": query})
print(result["result"])
print("Source documents:", result["source_documents"])
def _encode_document(self):
"""Helper function to encode the document into a vector store."""
return encode_pdf(self.path)

def _initialize_llm(self):
"""Helper function to initialize the language model."""
return ChatOpenAI(temperature=self.temperature, model_name=self.model_name, max_tokens=self.max_tokens)

def run(self, query):
"""
Executes a query using the QA chain and prints the result.

Args:
query (str): The query to run against the document.
"""
print("\n--- Running Query ---")
start_time = time.time()
result = self.qa_chain.invoke({"query": query})
elapsed_time = time.time() - start_time

# Display the result and the source documents
print(f"Result: {result['result']}")
print(f"Source Documents: {result['source_documents']}")
print(f"Query Execution Time: {elapsed_time:.2f} seconds")
return result, elapsed_time


# Function to parse command line arguments
def parse_args():
parser = argparse.ArgumentParser(description="Process a PDF document with contextual compression RAG.")
parser.add_argument("--model_name", type=str, default="gpt-4o-mini",
help="Name of the language model to use (default: gpt-4o-mini).")
parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
help="Path to the PDF file to process.")
parser.add_argument("--query", type=str, default="What is the main topic of the document?",
help="Query to test the retriever (default: 'What is the main topic of the document?').")
parser.add_argument("--temperature", type=float, default=0,
help="Temperature setting for the language model (default: 0).")
parser.add_argument("--max_tokens", type=int, default=4000,
help="Max tokens for the language model (default: 4000).")

return parser.parse_args()


# Main function to run the RAG pipeline
def main(args):
# Initialize ContextualCompressionRAG
contextual_compression_rag = ContextualCompressionRAG(
path=args.path,
model_name=args.model_name,
temperature=args.temperature,
max_tokens=args.max_tokens
)

# Run a query
contextual_compression_rag.run(args.query)


if __name__ == '__main__':
# Call the main function with parsed arguments
main(parse_args())
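For completeness, the compressed-retrieval pipeline introduced above can also be exercised directly in Python, which is equivalent to the __main__ path with its default arguments; everything here mirrors ContextualCompressionRAG as defined in the diff.

# Programmatic usage, mirroring main()/parse_args() defaults above.
rag = ContextualCompressionRAG(path="../data/Understanding_Climate_Change.pdf")
result, elapsed = rag.run("What is the main topic of the document?")
# result["result"] holds the answer; result["source_documents"] holds the
# LLM-compressed snippets passed on by the ContextualCompressionRetriever.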
@@ -1,202 +1,152 @@
import os
import sys
import argparse
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.tools import DuckDuckGoSearchResults
from helper_functions import encode_pdf
import json

sys.path.append(os.path.abspath(
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks

# Load environment variables from a .env file
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
from langchain.tools import DuckDuckGoSearchResults

# Define files path
path = "../data/Understanding_Climate_Change.pdf"

# Create a vector store
vectorstore = encode_pdf(path)

# Initialize OpenAI language model

llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=1000, temperature=0)

# Initialize search tool
search = DuckDuckGoSearchResults()


# Define retrieval evaluator, knowledge refinement and query rewriter llm chains
# Retrieval Evaluator
class RetrievalEvaluatorInput(BaseModel):
relevance_score: float = Field(...,
description="The relevance score of the document to the query. the score should be between 0 and 1.")
"""
Model for capturing the relevance score of a document to a query.
"""
relevance_score: float = Field(..., description="Relevance score between 0 and 1, "
"indicating the document's relevance to the query.")


def retrieval_evaluator(query: str, document: str) -> float:
class QueryRewriterInput(BaseModel):
"""
Model for capturing a rewritten query suitable for web search.
"""
query: str = Field(..., description="The query rewritten for better web search results.")


class KnowledgeRefinementInput(BaseModel):
"""
Model for extracting key points from a document.
"""
key_points: str = Field(..., description="Key information extracted from the document in bullet-point form.")


class CRAG:
"""
A class to handle the CRAG process for document retrieval, evaluation, and knowledge refinement.
"""

def __init__(self, path, model="gpt-4o-mini", max_tokens=1000, temperature=0, lower_threshold=0.3,
upper_threshold=0.7):
"""
Initializes the CRAG Retriever by encoding the PDF document and creating the necessary models and search tools.

Args:
path (str): Path to the PDF file to encode.
model (str): The language model to use for the CRAG process.
max_tokens (int): Maximum tokens to use in LLM responses (default: 1000).
temperature (float): The temperature to use for LLM responses (default: 0).
lower_threshold (float): Lower threshold for document evaluation scores (default: 0.3).
upper_threshold (float): Upper threshold for document evaluation scores (default: 0.7).
"""
print("\n--- Initializing CRAG Process ---")

self.lower_threshold = lower_threshold
self.upper_threshold = upper_threshold

# Encode the PDF document into a vector store
self.vectorstore = encode_pdf(path)

# Initialize OpenAI language model
self.llm = ChatOpenAI(model=model, max_tokens=max_tokens, temperature=temperature)

# Initialize search tool
self.search = DuckDuckGoSearchResults()

@staticmethod
def retrieve_documents(query, faiss_index, k=3):
docs = faiss_index.similarity_search(query, k=k)
return [doc.page_content for doc in docs]

def evaluate_documents(self, query, documents):
return [self.retrieval_evaluator(query, doc) for doc in documents]

def retrieval_evaluator(self, query, document):
prompt = PromptTemplate(
input_variables=["query", "document"],
template="On a scale from 0 to 1, how relevant is the following document to the query? Query: {query}\nDocument: {document}\nRelevance score:"
template="On a scale from 0 to 1, how relevant is the following document to the query? "
"Query: {query}\nDocument: {document}\nRelevance score:"
)
chain = prompt | llm.with_structured_output(RetrievalEvaluatorInput)
chain = prompt | self.llm.with_structured_output(RetrievalEvaluatorInput)
input_variables = {"query": query, "document": document}
result = chain.invoke(input_variables).relevance_score
return result


# Knowledge Refinement
class KnowledgeRefinementInput(BaseModel):
key_points: str = Field(..., description="The document to extract key information from.")


def knowledge_refinement(document: str) -> List[str]:
def knowledge_refinement(self, document):
prompt = PromptTemplate(
input_variables=["document"],
template="Extract the key information from the following document in bullet points:\n{document}\nKey points:"
template="Extract the key information from the following document in bullet points:"
"\n{document}\nKey points:"
)
chain = prompt | llm.with_structured_output(KnowledgeRefinementInput)
chain = prompt | self.llm.with_structured_output(KnowledgeRefinementInput)
input_variables = {"document": document}
result = chain.invoke(input_variables).key_points
return [point.strip() for point in result.split('\n') if point.strip()]


# Web Search Query Rewriter
class QueryRewriterInput(BaseModel):
query: str = Field(..., description="The query to rewrite.")


def rewrite_query(query: str) -> str:
def rewrite_query(self, query):
prompt = PromptTemplate(
input_variables=["query"],
template="Rewrite the following query to make it more suitable for a web search:\n{query}\nRewritten query:"
)
chain = prompt | llm.with_structured_output(QueryRewriterInput)
chain = prompt | self.llm.with_structured_output(QueryRewriterInput)
input_variables = {"query": query}
return chain.invoke(input_variables).query.strip()


# Helper function to parse search results

def parse_search_results(results_string: str) -> List[Tuple[str, str]]:
"""
Parse a JSON string of search results into a list of title-link tuples.

Args:
results_string (str): A JSON-formatted string containing search results.

Returns:
List[Tuple[str, str]]: A list of tuples, where each tuple contains the title and link of a search result.
If parsing fails, an empty list is returned.
"""
@staticmethod
def parse_search_results(results_string):
try:
# Attempt to parse the JSON string
results = json.loads(results_string)
# Extract and return the title and link from each result
return [(result.get('title', 'Untitled'), result.get('link', '')) for result in results]
except json.JSONDecodeError:
# Handle JSON decoding errors by returning an empty list
print("Error parsing search results. Returning empty list.")
return []


# Define sub functions for the CRAG process
def retrieve_documents(query: str, faiss_index: FAISS, k: int = 3) -> List[str]:
"""
Retrieve documents based on a query using a FAISS index.

Args:
query (str): The query string to search for.
faiss_index (FAISS): The FAISS index used for similarity search.
k (int): The number of top documents to retrieve. Defaults to 3.

Returns:
List[str]: A list of the retrieved document contents.
"""
docs = faiss_index.similarity_search(query, k=k)
return [doc.page_content for doc in docs]


def evaluate_documents(query: str, documents: List[str]) -> List[float]:
"""
Evaluate the relevance of documents based on a query.

Args:
query (str): The query string.
documents (List[str]): A list of document contents to evaluate.

Returns:
List[float]: A list of relevance scores for each document.
"""
return [retrieval_evaluator(query, doc) for doc in documents]


def perform_web_search(query: str) -> Tuple[List[str], List[Tuple[str, str]]]:
"""
Perform a web search based on a query.

Args:
query (str): The query string to search for.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[Tuple[str, str]]]:
|
||||
- A list of refined knowledge obtained from the web search.
|
||||
- A list of tuples containing titles and links of the sources.
|
||||
"""
|
||||
rewritten_query = rewrite_query(query)
|
||||
web_results = search.run(rewritten_query)
|
||||
web_knowledge = knowledge_refinement(web_results)
|
||||
sources = parse_search_results(web_results)
|
||||
def perform_web_search(self, query):
|
||||
rewritten_query = self.rewrite_query(query)
|
||||
web_results = self.search.run(rewritten_query)
|
||||
web_knowledge = self.knowledge_refinement(web_results)
|
||||
sources = self.parse_search_results(web_results)
|
||||
return web_knowledge, sources
|
||||
|
||||
|
||||
def generate_response(query: str, knowledge: str, sources: List[Tuple[str, str]]) -> str:
|
||||
"""
|
||||
Generate a response to a query using knowledge and sources.
|
||||
|
||||
Args:
|
||||
query (str): The query string.
|
||||
knowledge (str): The refined knowledge to use in the response.
|
||||
sources (List[Tuple[str, str]]): A list of tuples containing titles and links of the sources.
|
||||
|
||||
Returns:
|
||||
str: The generated response.
|
||||
"""
|
||||
def generate_response(self, query, knowledge, sources):
|
||||
response_prompt = PromptTemplate(
|
||||
input_variables=["query", "knowledge", "sources"],
|
||||
template="Based on the following knowledge, answer the query. Include the sources with their links (if available) at the end of your answer:\nQuery: {query}\nKnowledge: {knowledge}\nSources: {sources}\nAnswer:"
|
||||
template="Based on the following knowledge, answer the query. "
|
||||
"Include the sources with their links (if available) at the end of your answer:"
|
||||
"\nQuery: {query}\nKnowledge: {knowledge}\nSources: {sources}\nAnswer:"
|
||||
)
|
||||
input_variables = {
|
||||
"query": query,
|
||||
"knowledge": knowledge,
|
||||
"sources": "\n".join([f"{title}: {link}" if link else title for title, link in sources])
|
||||
}
|
||||
response_chain = response_prompt | llm
|
||||
response_chain = response_prompt | self.llm
|
||||
return response_chain.invoke(input_variables).content
|
||||
|
||||
|
||||
# CRAG process
|
||||
|
||||
def crag_process(query: str, faiss_index: FAISS) -> str:
|
||||
"""
|
||||
Process a query by retrieving, evaluating, and using documents or performing a web search to generate a response.
|
||||
|
||||
Args:
|
||||
query (str): The query string to process.
|
||||
faiss_index (FAISS): The FAISS index used for document retrieval.
|
||||
|
||||
Returns:
|
||||
str: The generated response based on the query.
|
||||
"""
|
||||
def run(self, query):
|
||||
print(f"\nProcessing query: {query}")
|
||||
|
||||
# Retrieve and evaluate documents
|
||||
retrieved_docs = retrieve_documents(query, faiss_index)
|
||||
eval_scores = evaluate_documents(query, retrieved_docs)
|
||||
retrieved_docs = self.retrieve_documents(query, self.vectorstore)
|
||||
eval_scores = self.evaluate_documents(query, retrieved_docs)
|
||||
|
||||
print(f"\nRetrieved {len(retrieved_docs)} documents")
|
||||
print(f"Evaluation scores: {eval_scores}")
|
||||
@@ -212,13 +162,12 @@ def crag_process(query: str, faiss_index: FAISS) -> str:
|
||||
sources.append(("Retrieved document", ""))
|
||||
elif max_score < 0.3:
|
||||
print("\nAction: Incorrect - Performing web search")
|
||||
final_knowledge, sources = perform_web_search(query)
|
||||
final_knowledge, sources = self.perform_web_search(query)
|
||||
else:
|
||||
print("\nAction: Ambiguous - Combining retrieved document and web search")
|
||||
best_doc = retrieved_docs[eval_scores.index(max_score)]
|
||||
# Refine the retrieved knowledge
|
||||
retrieved_knowledge = knowledge_refinement(best_doc)
|
||||
web_knowledge, web_sources = perform_web_search(query)
|
||||
retrieved_knowledge = self.knowledge_refinement(best_doc)
|
||||
web_knowledge, web_sources = self.perform_web_search(query)
|
||||
final_knowledge = "\n".join(retrieved_knowledge + web_knowledge)
|
||||
sources = [("Retrieved document", "")] + web_sources
|
||||
|
||||
@@ -229,24 +178,59 @@ def crag_process(query: str, faiss_index: FAISS) -> str:
|
||||
for title, link in sources:
|
||||
print(f"{title}: {link}" if link else title)
|
||||
|
||||
# Generate response
|
||||
print("\nGenerating response...")
|
||||
response = generate_response(query, final_knowledge, sources)
|
||||
|
||||
response = self.generate_response(query, final_knowledge, sources)
|
||||
print("\nResponse generated")
|
||||
return response
|
||||
|
||||
|
||||
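# Illustrative sketch (not part of the diff): the truncated hunks above apply the correct /
# incorrect / ambiguous decision of CRAG to the relevance scores. This is a hedged reconstruction
# of that rule using the lower_threshold / upper_threshold defaults; the function name is an
# assumption for demonstration only.
def _choose_action_sketch(eval_scores, lower_threshold=0.3, upper_threshold=0.7):
    """Return 'correct', 'incorrect', or 'ambiguous' based on the best relevance score."""
    max_score = max(eval_scores)
    if max_score > upper_threshold:
        return "correct"      # use the best retrieved document as-is
    elif max_score < lower_threshold:
        return "incorrect"    # fall back to a web search
    return "ambiguous"        # combine the best document with web results

# Example: _choose_action_sketch([0.2, 0.65, 0.4]) -> 'ambiguous'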
# Example query with high relevance to the document
# Function to validate command line inputs
def validate_args(args):
    if args.max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer.")
    if args.temperature < 0 or args.temperature > 1:
        raise ValueError("temperature must be between 0 and 1.")
    return args

query = "What are the main causes of climate change?"
result = crag_process(query, vectorstore)
print(f"Query: {query}")
print(f"Answer: {result}")

# Example query with low relevance to the document
# Function to parse command line arguments
def parse_args():
    parser = argparse.ArgumentParser(description="CRAG Process for Document Retrieval and Query Answering.")
    parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
                        help="Path to the PDF file to encode.")
    parser.add_argument("--model", type=str, default="gpt-4o-mini",
                        help="Language model to use (default: gpt-4o-mini).")
    parser.add_argument("--max_tokens", type=int, default=1000,
                        help="Maximum tokens to use in LLM responses (default: 1000).")
    parser.add_argument("--temperature", type=float, default=0,
                        help="Temperature to use for LLM responses (default: 0).")
    parser.add_argument("--query", type=str, default="What are the main causes of climate change?",
                        help="Query to test the CRAG process.")
    parser.add_argument("--lower_threshold", type=float, default=0.3,
                        help="Lower threshold for score evaluation (default: 0.3).")
    parser.add_argument("--upper_threshold", type=float, default=0.7,
                        help="Upper threshold for score evaluation (default: 0.7).")

query = "how did harry beat quirrell?"
result = crag_process(query, vectorstore)
print(f"Query: {query}")
print(f"Answer: {result}")
    return validate_args(parser.parse_args())


# Main function to handle argument parsing and call the CRAG class
def main(args):
    # Initialize the CRAG process
    crag = CRAG(
        path=args.path,
        model=args.model,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        lower_threshold=args.lower_threshold,
        upper_threshold=args.upper_threshold
    )

    # Process the query
    response = crag.run(args.query)
    print(f"Query: {args.query}")
    print(f"Answer: {response}")


if __name__ == '__main__':
    main(parse_args())
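# Hedged usage note (not in the diff): assuming this script is saved as crag.py, it could be run as
#   python crag.py --path ../data/Understanding_Climate_Change.pdf --query "What are the main causes of climate change?"
# The file name is an assumption; the flags mirror the parse_args definition above.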
@@ -7,6 +7,8 @@ from enum import Enum
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel, Field
import argparse

from dotenv import load_dotenv

@@ -14,8 +16,7 @@ load_dotenv()

os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path

from helper_functions import *

@@ -23,30 +24,19 @@ from helper_functions import *
class QuestionGeneration(Enum):
    """
    Enum class to specify the level of question generation for document processing.

    Attributes:
        DOCUMENT_LEVEL (int): Represents question generation at the entire document level.
        FRAGMENT_LEVEL (int): Represents question generation at the individual text fragment level.
    """
    DOCUMENT_LEVEL = 1
    FRAGMENT_LEVEL = 2

# Depending on the model, for Mistral 7B it can be max 8000, for Llama 3.1 8B 128k
DOCUMENT_MAX_TOKENS = 4000
DOCUMENT_OVERLAP_TOKENS = 100

# Embeddings and text similarity calculated on shorter texts
FRAGMENT_MAX_TOKENS = 128
FRAGMENT_OVERLAP_TOKENS = 16

# Questions generated on document or fragment level
QUESTION_GENERATION = QuestionGeneration.DOCUMENT_LEVEL
# how many questions will be generated for specific document or fragment
QUESTIONS_PER_DOCUMENT = 40


# Define classes and functions used by this pipeline
class QuestionList(BaseModel):
    question_list: List[str] = Field(..., title="List of questions generated for the document or fragment")

@@ -55,30 +45,11 @@ class OpenAIEmbeddingsWrapper(OpenAIEmbeddings):
    """
    A wrapper class for OpenAI embeddings, providing a similar interface to the original OllamaEmbeddings.
    """

    def __call__(self, query: str) -> List[float]:
        """
        Allows the instance to be used as a callable to generate an embedding for a query.

        Args:
            query (str): The query string to be embedded.

        Returns:
            List[float]: The embedding for the query as a list of floats.
        """
        return self.embed_query(query)


def clean_and_filter_questions(questions: List[str]) -> List[str]:
    """
    Cleans and filters a list of questions.

    Args:
        questions (List[str]): A list of questions to be cleaned and filtered.

    Returns:
        List[str]: A list of cleaned and filtered questions that end with a question mark.
    """
    cleaned_questions = []
    for question in questions:
        cleaned_question = re.sub(r'^\d+\.\s*', '', question.strip())
@@ -88,45 +59,20 @@ def clean_and_filter_questions(questions: List[str]) -> List[str]:


def generate_questions(text: str) -> List[str]:
    """
    Generates a list of questions based on the provided text using OpenAI.

    Args:
        text (str): The context data from which questions are generated.

    Returns:
        List[str]: A list of unique, filtered questions.
    """
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    prompt = PromptTemplate(
        input_variables=["context", "num_questions"],
        template="Using the context data: {context}\n\nGenerate a list of at least {num_questions} "
                 "possible questions that can be asked about this context. Ensure the questions are "
                 "directly answerable within the context and do not include any answers or headers. "
                 "Separate the questions with a new line character."
                 "possible questions that can be asked about this context."
    )
    chain = prompt | llm.with_structured_output(QuestionList)
    input_data = {"context": text, "num_questions": QUESTIONS_PER_DOCUMENT}
    result = chain.invoke(input_data)

    # Extract the list of questions from the QuestionList object
    questions = result.question_list

    filtered_questions = clean_and_filter_questions(questions)
    return list(set(filtered_questions))
    return list(set(clean_and_filter_questions(questions)))


def generate_answer(content: str, question: str) -> str:
    """
    Generates an answer to a given question based on the provided context using OpenAI.

    Args:
        content (str): The context data used to generate the answer.
        question (str): The question for which the answer is generated.

    Returns:
        str: The precise answer to the question based on the provided context.
    """
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    prompt = PromptTemplate(
        input_variables=["context", "question"],
@@ -138,17 +84,6 @@ def generate_answer(content: str, question: str) -> str:


def split_document(document: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """
    Splits a document into smaller chunks of text.

    Args:
        document (str): The text of the document to be split.
        chunk_size (int): The size of each chunk in terms of the number of tokens.
        chunk_overlap (int): The number of overlapping tokens between consecutive chunks.

    Returns:
        List[str]: A list of text chunks, where each chunk is a string of the document content.
    """
    tokens = re.findall(r'\b\w+\b', document)
    chunks = []
    for i in range(0, len(tokens), chunk_size - chunk_overlap):
@@ -160,68 +95,16 @@ def split_document(document: str, chunk_size: int, chunk_overlap: int) -> List[s
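# Illustrative note (not part of the diff): split_document walks the token list with a stride of
# chunk_size - chunk_overlap, so consecutive chunks share chunk_overlap tokens. The same arithmetic
# on toy values, as a hedged sketch:
#   tokens = list(range(10)); chunk_size = 4; chunk_overlap = 2
#   starts = range(0, 10, 4 - 2)            # -> 0, 2, 4, 6, 8
#   chunks = [tokens[i:i + 4] for i in starts]
#   # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9]]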
def print_document(comment: str, document: Any) -> None:
    """
    Prints a comment followed by the content of a document.

    Args:
        comment (str): The comment or description to print before the document details.
        document (Any): The document whose content is to be printed.

    Returns:
        None
    """
    print(
        f'{comment} (type: {document.metadata["type"]}, index: {document.metadata["index"]}): {document.page_content}')
    print(f'{comment} (type: {document.metadata["type"]}, index: {document.metadata["index"]}): {document.page_content}')


# Example usage
class DocumentProcessor:
    def __init__(self, content: str, embedding_model: OpenAIEmbeddings):
        self.content = content
        self.embedding_model = embedding_model

# Initialize OpenAIEmbeddings
embeddings = OpenAIEmbeddingsWrapper()

# Example document
example_text = "This is an example document. It contains information about various topics."

# Generate questions
questions = generate_questions(example_text)
print("Generated Questions:")
for q in questions:
    print(f"- {q}")

# Generate an answer
sample_question = questions[0] if questions else "What is this document about?"
answer = generate_answer(example_text, sample_question)
print(f"\nQuestion: {sample_question}")
print(f"Answer: {answer}")

# Split document
chunks = split_document(example_text, chunk_size=10, chunk_overlap=2)
print("\nDocument Chunks:")
for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1}: {chunk}")

# Example of using OpenAIEmbeddings
doc_embedding = embeddings.embed_documents([example_text])
query_embedding = embeddings.embed_query("What is the main topic?")
print("\nDocument Embedding (first 5 elements):", doc_embedding[0][:5])
print("Query Embedding (first 5 elements):", query_embedding[:5])


# Main pipeline
def process_documents(content: str, embedding_model: OpenAIEmbeddings):
    """
    Process the document content, split it into fragments, generate questions,
    create a FAISS vector store, and return a retriever.

    Args:
        content (str): The content of the document to process.
        embedding_model (OpenAIEmbeddings): The embedding model to use for vectorization.

    Returns:
        VectorStoreRetriever: A retriever for the most relevant FAISS document.
    """
    # Split the whole text content into text documents
    text_documents = split_document(content, DOCUMENT_MAX_TOKENS, DOCUMENT_OVERLAP_TOKENS)
    def run(self):
        text_documents = split_document(self.content, DOCUMENT_MAX_TOKENS, DOCUMENT_OVERLAP_TOKENS)
        print(f'Text content split into: {len(text_documents)} documents')

        documents = []
@@ -261,39 +144,44 @@ def process_documents(content: str, embedding_model: OpenAIEmbeddings):
            print_document("Dataset", document)

        print(f'Creating store, calculating embeddings for {len(documents)} FAISS documents')
        vectorstore = FAISS.from_documents(documents, embedding_model)
        vectorstore = FAISS.from_documents(documents, self.embedding_model)

        print("Creating retriever returning the most relevant FAISS document")
        return vectorstore.as_retriever(search_kwargs={"k": 1})


# Example
# Load sample PDF document to string variable
path = "../data/Understanding_Climate_Change.pdf"
content = read_pdf_to_string(path)
def parse_args():
    parser = argparse.ArgumentParser(description="Process a document and create a retriever.")
    parser.add_argument('--path', type=str, default='../data/Understanding_Climate_Change.pdf',
                        help="Path to the PDF document to process")
    return parser.parse_args()

# Instantiate OpenAI Embeddings class that will be used by FAISS
embedding_model = OpenAIEmbeddings()

# Process documents and create retriever
document_query_retriever = process_documents(content, embedding_model)
if __name__ == "__main__":
    args = parse_args()

# Example usage of the retriever
query = "What is climate change?"
retrieved_docs = document_query_retriever.get_relevant_documents(query)
print(f"\nQuery: {query}")
print(f"Retrieved document: {retrieved_docs[0].page_content}")
    # Load sample PDF document to string variable
    content = read_pdf_to_string(args.path)

# Find the most relevant FAISS document in the store. In most cases, this will be an augmented question rather than the original text document.
query = "How do freshwater ecosystems change due to alterations in climatic factors?"
print(f'Question:{os.linesep}{query}{os.linesep}')
retrieved_documents = document_query_retriever.invoke(query)
    # Instantiate OpenAI Embeddings class that will be used by FAISS
    embedding_model = OpenAIEmbeddings()

for doc in retrieved_documents:
    # Process documents and create retriever
    processor = DocumentProcessor(content, embedding_model)
    document_query_retriever = processor.run()

    # Example usage of the retriever
    query = "What is climate change?"
    retrieved_docs = document_query_retriever.get_relevant_documents(query)
    print(f"\nQuery: {query}")
    print(f"Retrieved document: {retrieved_docs[0].page_content}")

    # Further query example
    query = "How do freshwater ecosystems change due to alterations in climatic factors?"
    retrieved_documents = document_query_retriever.get_relevant_documents(query)
    for doc in retrieved_documents:
        print_document("Relevant fragment retrieved", doc)

    # Find the parent text document and use it as context for the generative model to generate an answer to the question.
    context = doc.metadata['text']
    print(f'{os.linesep}Context:{os.linesep}{context}')
    answer = generate_answer(context, query)
    print(f'{os.linesep}Answer:{os.linesep}{answer}')
        context = doc.metadata['text']
        answer = generate_answer(context, query)
        print(f'{os.linesep}Answer:{os.linesep}{answer}')
@@ -2,8 +2,7 @@ import os
import sys
from dotenv import load_dotenv

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path
from helper_functions import *
from evaluation.evalute_rag import *

@@ -14,18 +13,14 @@ load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')


# Define the explainable retriever class
# Define utility classes/functions
class ExplainableRetriever:
    def __init__(self, texts):
        self.embeddings = OpenAIEmbeddings()

        self.vectorstore = FAISS.from_texts(texts, self.embeddings)
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)

        # Create a base retriever
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})

        # Create an explanation chain
        explain_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="""
@@ -42,39 +37,49 @@ class ExplainableRetriever:
        self.explain_chain = explain_prompt | self.llm

    def retrieve_and_explain(self, query):
        # Retrieve relevant documents
        docs = self.retriever.get_relevant_documents(query)

        explained_results = []

        for doc in docs:
            # Generate explanation
            input_data = {"query": query, "context": doc.page_content}
            explanation = self.explain_chain.invoke(input_data).content

            explained_results.append({
                "content": doc.page_content,
                "explanation": explanation
            })

        return explained_results


# Create a mock example and explainable retriever instance
# Usage
texts = [
class ExplainableRAGMethod:
    def __init__(self, texts):
        self.explainable_retriever = ExplainableRetriever(texts)

    def run(self, query):
        return self.explainable_retriever.retrieve_and_explain(query)


# Argument Parsing
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Explainable RAG Method")
    parser.add_argument('--query', type=str, default='Why is the sky blue?', help="Query for the retriever")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Sample texts (these can be replaced by actual data)
    texts = [
        "The sky is blue because of the way sunlight interacts with the atmosphere.",
        "Photosynthesis is the process by which plants use sunlight to produce energy.",
        "Global warming is caused by the increase of greenhouse gases in Earth's atmosphere."
    ]
    ]

    explainable_retriever = ExplainableRetriever(texts)
    explainable_rag = ExplainableRAGMethod(texts)
    results = explainable_rag.run(args.query)

    # Show the results
    query = "Why is the sky blue?"
    results = explainable_retriever.retrieve_and_explain(query)

    for i, result in enumerate(results, 1):
    for i, result in enumerate(results, 1):
        print(f"Result {i}:")
        print(f"Content: {result['content']}")
        print(f"Explanation: {result['explanation']}")
@@ -2,27 +2,21 @@ import os
import sys
from dotenv import load_dotenv
from langchain.docstore.document import Document

from typing import List
from rank_bm25 import BM25Okapi
import numpy as np

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
# Add the parent directory to the path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from helper_functions import *
from evaluation.evalute_rag import *

# Load environment variables from a .env file
# Load environment variables
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define document path
path = "../data/Understanding_Climate_Change.pdf"


# Encode the pdf to vector store and return split document from the step before to create BM25 instance
# Function to encode the PDF to a vector store and return split documents
def encode_pdf_and_get_split_documents(path, chunk_size=1000, chunk_overlap=200):
    """
    Encodes a PDF book into a vector store using OpenAI embeddings.
@@ -35,53 +29,35 @@ def encode_pdf_and_get_split_documents(path, chunk_size=1000, chunk_overlap=200)
    Returns:
        A FAISS vector store containing the encoded book content.
    """

    # Load PDF documents
    loader = PyPDFLoader(path)
    documents = loader.load()

    # Split documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    texts = text_splitter.split_documents(documents)
    cleaned_texts = replace_t_with_space(texts)

    # Create embeddings and vector store
    embeddings = OpenAIEmbeddings()
    vectorstore = FAISS.from_documents(cleaned_texts, embeddings)

    return vectorstore, cleaned_texts


# Create vectorstore and get the chunked documents
vectorstore, cleaned_texts = encode_pdf_and_get_split_documents(path)


# Create a bm25 index for retrieving documents by keywords
# Function to create BM25 index for keyword retrieval
def create_bm25_index(documents: List[Document]) -> BM25Okapi:
    """
    Create a BM25 index from the given documents.

    BM25 (Best Matching 25) is a ranking function used in information retrieval.
    It's based on the probabilistic retrieval framework and is an improvement over TF-IDF.

    Args:
        documents (List[Document]): List of documents to index.

    Returns:
        BM25Okapi: An index that can be used for BM25 scoring.
    """
    # Tokenize each document by splitting on whitespace
    # This is a simple approach and could be improved with more sophisticated tokenization
    tokenized_docs = [doc.page_content.split() for doc in documents]
    return BM25Okapi(tokenized_docs)
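# Illustrative sketch (not part of the diff): how a BM25Okapi index scores a query against a toy
# corpus. The variable names are assumptions for demonstration only.
#   toy_corpus = ["climate change impacts ecosystems", "the ocean absorbs heat", "forests store carbon"]
#   toy_index = BM25Okapi([text.split() for text in toy_corpus])
#   toy_scores = toy_index.get_scores("climate impacts".split())  # one score per document; higher = better keyword match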
bm25 = create_bm25_index(cleaned_texts)  # Create BM25 index from the cleaned texts (chunks)


# Define a function that retrieves both semantically and by keyword, normalizes the scores and gets the top k documents
# Function for fusion retrieval combining keyword-based (BM25) and vector-based search
def fusion_retrieval(vectorstore, bm25, query: str, k: int = 5, alpha: float = 0.5) -> List[Document]:
    """
    Perform fusion retrieval combining keyword-based (BM25) and vector-based search.
@@ -96,36 +72,72 @@ def fusion_retrieval(vectorstore, bm25, query: str, k: int = 5, alpha: float = 0
    Returns:
        List[Document]: The top k documents based on the combined scores.
    """
    # Step 1: Get all documents from the vectorstore
    all_docs = vectorstore.similarity_search("", k=vectorstore.index.ntotal)

    # Step 2: Perform BM25 search
    bm25_scores = bm25.get_scores(query.split())

    # Step 3: Perform vector search
    vector_results = vectorstore.similarity_search_with_score(query, k=len(all_docs))

    # Step 4: Normalize scores
    vector_scores = np.array([score for _, score in vector_results])
    vector_scores = 1 - (vector_scores - np.min(vector_scores)) / (np.max(vector_scores) - np.min(vector_scores))

    bm25_scores = (bm25_scores - np.min(bm25_scores)) / (np.max(bm25_scores) - np.min(bm25_scores))

    # Step 5: Combine scores
    combined_scores = alpha * vector_scores + (1 - alpha) * bm25_scores

    # Step 6: Rank documents
    sorted_indices = np.argsort(combined_scores)[::-1]

    # Step 7: Return top k documents
    return [all_docs[i] for i in sorted_indices[:k]]
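# Worked example (not part of the diff) of the score fusion above, with made-up numbers:
#   vector distances [0.2, 0.5, 0.8] -> min-max normalized and inverted -> [1.0, 0.5, 0.0]
#   bm25 scores      [1.0, 3.0, 2.0] -> min-max normalized              -> [0.0, 1.0, 0.5]
#   alpha = 0.5      -> combined = 0.5 * [1.0, 0.5, 0.0] + 0.5 * [0.0, 1.0, 0.5] = [0.5, 0.75, 0.25]
#   so the second chunk ranks first, the first chunk second, and the third chunk last.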
# Use Case example
# Query
query = "What are the impacts of climate change on the environment?"
class FusionRetrievalRAG:
    def __init__(self, path: str, chunk_size: int = 1000, chunk_overlap: int = 200):
        """
        Initializes the FusionRetrievalRAG class by setting up the vector store and BM25 index.

# Perform fusion retrieval
top_docs = fusion_retrieval(vectorstore, bm25, query, k=5, alpha=0.5)
docs_content = [doc.page_content for doc in top_docs]
show_context(docs_content)
        Args:
            path (str): Path to the PDF file.
            chunk_size (int): The size of each text chunk.
            chunk_overlap (int): The overlap between consecutive chunks.
        """
        self.vectorstore, self.cleaned_texts = encode_pdf_and_get_split_documents(path, chunk_size, chunk_overlap)
        self.bm25 = create_bm25_index(self.cleaned_texts)

    def run(self, query: str, k: int = 5, alpha: float = 0.5):
        """
        Executes the fusion retrieval for the given query.

        Args:
            query (str): The search query.
            k (int): The number of documents to retrieve.
            alpha (float): The weight of vector search vs. BM25 search.

        Returns:
            List[Document]: The top k retrieved documents.
        """
        top_docs = fusion_retrieval(self.vectorstore, self.bm25, query, k, alpha)
        docs_content = [doc.page_content for doc in top_docs]
        show_context(docs_content)


def parse_args():
    """
    Parses command-line arguments.

    Returns:
        args: The parsed arguments.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Fusion Retrieval RAG Script")
    parser.add_argument('--path', type=str, default="../data/Understanding_Climate_Change.pdf",
                        help='Path to the PDF file.')
    parser.add_argument('--chunk_size', type=int, default=1000, help='Size of each chunk.')
    parser.add_argument('--chunk_overlap', type=int, default=200, help='Overlap between consecutive chunks.')
    parser.add_argument('--query', type=str, default='What are the impacts of climate change on the environment?',
                        help='Query to retrieve documents.')
    parser.add_argument('--k', type=int, default=5, help='Number of documents to retrieve.')
    parser.add_argument('--alpha', type=float, default=0.5, help='Weight for vector search vs. BM25.')

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    retriever = FusionRetrievalRAG(path=args.path, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
    retriever.run(query=args.query, k=args.k, alpha=args.alpha)
@@ -19,6 +19,7 @@ from nltk.tokenize import word_tokenize
import nltk
import spacy
import heapq
import argparse

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@@ -729,11 +730,14 @@ class Visualizer:

# Define the graph RAG class
class GraphRAG:
    def __init__(self):
    def __init__(self, documents):
        """
        Initializes the GraphRAG system with components for document processing, knowledge graph construction,
        querying, and visualization.

        Args:
        - documents (list of str): A list of documents to be processed.

        Attributes:
        - llm: An instance of a large language model (LLM) for generating responses.
        - embedding_model: An instance of an embedding model for document embeddings.
@@ -748,6 +752,7 @@ class GraphRAG:
        self.knowledge_graph = KnowledgeGraph()
        self.query_engine = None
        self.visualizer = Visualizer()
        self.process_documents(documents)

    def process_documents(self, documents):
        """
@@ -783,20 +788,29 @@ class GraphRAG:
        return response


# Define documents path
path = "../data/Understanding_Climate_Change.pdf"
# Argument parsing
def parse_args():
    parser = argparse.ArgumentParser(description="GraphRAG system")
    parser.add_argument('--path', type=str, default="../data/Understanding_Climate_Change.pdf",
                        help='Path to the PDF file.')
    parser.add_argument('--query', type=str, default='what is the main cause of climate change?',
                        help='Query to retrieve documents.')
    return parser.parse_args()

# Load the documents
loader = PyPDFLoader(path)
documents = loader.load()
documents = documents[:10]

# Create a graph RAG instance
graph_rag = GraphRAG()
if __name__ == '__main__':
    args = parse_args()

# Process the documents and create the graph
graph_rag.process_documents(documents)
    # Load the documents
    loader = PyPDFLoader(args.path)
    documents = loader.load()
    documents = documents[:10]

# Input a query and get the retrieved information from the graph RAG
query = "what is the main cause of climate change?"
response = graph_rag.query(query)
    # Create a graph RAG instance
    graph_rag = GraphRAG(documents)

    # Process the documents and create the graph
    graph_rag.process_documents(documents)

    # Input a query and get the retrieved information from the graph RAG
    response = graph_rag.query(args.query)
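# Hedged usage note (not in the diff): assuming this script is saved as graph_rag.py, a run would look like
#   python graph_rag.py --path ../data/Understanding_Climate_Change.pdf --query "what is the main cause of climate change?"
# The file name is an assumption; the flags mirror the parse_args definition above.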
@@ -7,8 +7,7 @@ from langchain.chains.summarize.chain import load_summarize_chain
from langchain.docstore.document import Document
from helper_functions import encode_pdf, encode_from_string

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path
from helper_functions import *
from evaluation.evalute_rag import *

@@ -18,105 +17,49 @@ load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define document path
path = "../data/Understanding_Climate_Change.pdf"


# Function to encode to both summary and chunk levels, sharing the page metadata
async def encode_pdf_hierarchical(path, chunk_size=1000, chunk_overlap=200, is_string=False):
    """
    Asynchronously encodes a PDF book into a hierarchical vector store using OpenAI embeddings.
    Includes rate limit handling with exponential backoff.

    Args:
        path: The path to the PDF file.
        chunk_size: The desired size of each text chunk.
        chunk_overlap: The amount of overlap between consecutive chunks.

    Returns:
        A tuple containing two FAISS vector stores:
        1. Document-level summaries
        2. Detailed chunks
    """

    # Load PDF documents
    if not is_string:
        loader = PyPDFLoader(path)
        documents = await asyncio.to_thread(loader.load)
    else:
        text_splitter = RecursiveCharacterTextSplitter(
            # Set a really small chunk size, just to show.
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False,
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len, is_separator_regex=False
        )
        documents = text_splitter.create_documents([path])

    # Create document-level summaries
    summary_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
    summary_chain = load_summarize_chain(summary_llm, chain_type="map_reduce")

    async def summarize_doc(doc):
        """
        Summarizes a single document with rate limit handling.

        Args:
            doc: The document to be summarized.

        Returns:
            A summarized Document object.
        """
        # Retry the summarization with exponential backoff
        summary_output = await retry_with_exponential_backoff(summary_chain.ainvoke([doc]))
        summary = summary_output['output_text']
        return Document(
            page_content=summary,
            metadata={"source": path, "page": doc.metadata["page"], "summary": True}
        )
        return Document(page_content=summary, metadata={"source": path, "page": doc.metadata["page"], "summary": True})

    # Process documents in smaller batches to avoid rate limits
    batch_size = 5  # Adjust this based on your rate limits
    summaries = []
    batch_size = 5
    for i in range(0, len(documents), batch_size):
        batch = documents[i:i + batch_size]
        batch_summaries = await asyncio.gather(*[summarize_doc(doc) for doc in batch])
        summaries.extend(batch_summaries)
        await asyncio.sleep(1)  # Short pause between batches
        await asyncio.sleep(1)

    # Split documents into detailed chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len)
    detailed_chunks = await asyncio.to_thread(text_splitter.split_documents, documents)

    # Update metadata for detailed chunks
    for i, chunk in enumerate(detailed_chunks):
        chunk.metadata.update({
            "chunk_id": i,
            "summary": False,
            "page": int(chunk.metadata.get("page", 0))
        })
        chunk.metadata.update({"chunk_id": i, "summary": False, "page": int(chunk.metadata.get("page", 0))})

    # Create embeddings
    embeddings = OpenAIEmbeddings()

    # Create vector stores asynchronously with rate limit handling
    async def create_vectorstore(docs):
        """
        Creates a vector store from a list of documents with rate limit handling.
        return await retry_with_exponential_backoff(asyncio.to_thread(FAISS.from_documents, docs, embeddings))

        Args:
            docs: The list of documents to be embedded.

        Returns:
            A FAISS vector store containing the embedded documents.
        """
        return await retry_with_exponential_backoff(
            asyncio.to_thread(FAISS.from_documents, docs, embeddings)
        )

    # Generate vector stores for summaries and detailed chunks concurrently
    summary_vectorstore, detailed_vectorstore = await asyncio.gather(
        create_vectorstore(summaries),
        create_vectorstore(detailed_chunks)
@@ -125,64 +68,57 @@ async def encode_pdf_hierarchical(path, chunk_size=1000, chunk_overlap=200, is_s
    return summary_vectorstore, detailed_vectorstore


# Retrieve information according to summary level, and then retrieve information from the chunk level vector store and filter according to the summary level pages
def retrieve_hierarchical(query, summary_vectorstore, detailed_vectorstore, k_summaries=3, k_chunks=5):
    """
    Performs a hierarchical retrieval using the query.

    Args:
        query: The search query.
        summary_vectorstore: The vector store containing document summaries.
        detailed_vectorstore: The vector store containing detailed chunks.
        k_summaries: The number of top summaries to retrieve.
        k_chunks: The number of detailed chunks to retrieve per summary.

    Returns:
        A list of relevant detailed chunks.
    """

    # Retrieve top summaries
    top_summaries = summary_vectorstore.similarity_search(query, k=k_summaries)

    relevant_chunks = []
    for summary in top_summaries:
        # For each summary, retrieve relevant detailed chunks
        page_number = summary.metadata["page"]
        page_filter = lambda metadata: metadata["page"] == page_number
        page_chunks = detailed_vectorstore.similarity_search(
            query,
            k=k_chunks,
            filter=page_filter
        )
        page_chunks = detailed_vectorstore.similarity_search(query, k=k_chunks, filter=page_filter)
        relevant_chunks.extend(page_chunks)

    return relevant_chunks


async def main():
    # Encode the PDF book to both document-level summaries and detailed chunks if the vector stores do not exist
class HierarchicalRAG:
    def __init__(self, pdf_path, chunk_size=1000, chunk_overlap=200):
        self.pdf_path = pdf_path
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.summary_store = None
        self.detailed_store = None

    async def run(self, query):
        if os.path.exists("../vector_stores/summary_store") and os.path.exists("../vector_stores/detailed_store"):
            embeddings = OpenAIEmbeddings()
            summary_store = FAISS.load_local("../vector_stores/summary_store", embeddings, allow_dangerous_deserialization=True)
            detailed_store = FAISS.load_local("../vector_stores/detailed_store", embeddings,
                                              allow_dangerous_deserialization=True)

            self.summary_store = FAISS.load_local("../vector_stores/summary_store", embeddings, allow_dangerous_deserialization=True)
            self.detailed_store = FAISS.load_local("../vector_stores/detailed_store", embeddings, allow_dangerous_deserialization=True)
        else:
            summary_store, detailed_store = await encode_pdf_hierarchical(path)
            summary_store.save_local("../vector_stores/summary_store")
            detailed_store.save_local("../vector_stores/detailed_store")
            self.summary_store, self.detailed_store = await encode_pdf_hierarchical(self.pdf_path, self.chunk_size, self.chunk_overlap)
            self.summary_store.save_local("../vector_stores/summary_store")
            self.detailed_store.save_local("../vector_stores/detailed_store")

        # Demonstrate on a use case
        query = "What is the greenhouse effect?"
        results = retrieve_hierarchical(query, summary_store, detailed_store)

        # Print results
        results = retrieve_hierarchical(query, self.summary_store, self.detailed_store)
        for chunk in results:
            print(f"Page: {chunk.metadata['page']}")
            print(f"Content: {chunk.page_content}...")  # Print first 100 characters
            print(f"Content: {chunk.page_content}...")
            print("---")


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Run Hierarchical RAG on a given PDF.")
    parser.add_argument("--pdf_path", type=str, default="../data/Understanding_Climate_Change.pdf", help="Path to the PDF document.")
    parser.add_argument("--chunk_size", type=int, default=1000, help="Size of each text chunk.")
    parser.add_argument("--chunk_overlap", type=int, default=200, help="Overlap between consecutive chunks.")
    parser.add_argument("--query", type=str, default='What is the greenhouse effect',
                        help="Query to search in the document.")
    return parser.parse_args()


if __name__ == "__main__":
    asyncio.run(main())
    args = parse_args()
    rag = HierarchicalRAG(args.pdf_path, args.chunk_size, args.chunk_overlap)
    asyncio.run(rag.run(args.query))
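# Hedged usage note (not in the diff): with the assumed file name hierarchical_indices.py, the script
# could be run as
#   python hierarchical_indices.py --pdf_path ../data/Understanding_Climate_Change.pdf --query "What is the greenhouse effect?"
# Each query returns at most k_summaries * k_chunks detailed chunks (3 * 5 = 15 with the defaults above).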
@@ -1,8 +1,7 @@
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate

import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate

# Load environment variables from a .env file
load_dotenv()
@@ -10,135 +9,132 @@ load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# 1 - Query Rewriting: Reformulating queries to improve retrieval.
re_write_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)

# Create a prompt template for query rewriting
query_rewrite_template = """You are an AI assistant tasked with reformulating user queries to improve retrieval in a RAG system.
Given the original query, rewrite it to be more specific, detailed, and likely to retrieve relevant information.

Original query: {original_query}

Rewritten query:"""

query_rewrite_prompt = PromptTemplate(
    input_variables=["original_query"],
    template=query_rewrite_template
)

# Create an LLMChain for query rewriting
query_rewriter = query_rewrite_prompt | re_write_llm


def rewrite_query(original_query):
# Function for rewriting a query to improve retrieval
def rewrite_query(original_query, llm_chain):
    """
    Rewrite the original query to improve retrieval.

    Args:
        original_query (str): The original user query
        llm_chain: The chain used to generate the rewritten query

    Returns:
        str: The rewritten query
    """
    response = query_rewriter.invoke(original_query)
    response = llm_chain.invoke(original_query)
    return response.content


# Demonstrate on a use case
# example query over the understanding climate change dataset
original_query = "What are the impacts of climate change on the environment?"
rewritten_query = rewrite_query(original_query)
print("Original query:", original_query)
print("\nRewritten query:", rewritten_query)

# 2 - Step-back Prompting: Generating broader queries for better context retrieval.


step_back_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)

# Create a prompt template for step-back prompting
step_back_template = """You are an AI assistant tasked with generating broader, more general queries to improve context retrieval in a RAG system.
Given the original query, generate a step-back query that is more general and can help retrieve relevant background information.

Original query: {original_query}

Step-back query:"""

step_back_prompt = PromptTemplate(
    input_variables=["original_query"],
    template=step_back_template
)

# Create an LLMChain for step-back prompting
step_back_chain = step_back_prompt | step_back_llm


def generate_step_back_query(original_query):
# Function for generating a step-back query to retrieve broader context
def generate_step_back_query(original_query, llm_chain):
    """
    Generate a step-back query to retrieve broader context.

    Args:
        original_query (str): The original user query
        llm_chain: The chain used to generate the step-back query

    Returns:
        str: The step-back query
    """
    response = step_back_chain.invoke(original_query)
    response = llm_chain.invoke(original_query)
    return response.content


# Demonstrate on a use case
# example query over the understanding climate change dataset
original_query = "What are the impacts of climate change on the environment?"
step_back_query = generate_step_back_query(original_query)
print("Original query:", original_query)
print("\nStep-back query:", step_back_query)

# 3- Sub-query Decomposition: Breaking complex queries into simpler sub-queries.
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)

# Create a prompt template for sub-query decomposition
subquery_decomposition_template = """You are an AI assistant tasked with breaking down complex queries into simpler sub-queries for a RAG system.
Given the original query, decompose it into 2-4 simpler sub-queries that, when answered together, would provide a comprehensive response to the original query.

Original query: {original_query}

example: What are the impacts of climate change on the environment?

Sub-queries:
1. What are the impacts of climate change on biodiversity?
2. How does climate change affect the oceans?
3. What are the effects of climate change on agriculture?
4. What are the impacts of climate change on human health?"""

subquery_decomposition_prompt = PromptTemplate(
    input_variables=["original_query"],
    template=subquery_decomposition_template
)

# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = subquery_decomposition_prompt | sub_query_llm


def decompose_query(original_query: str):
# Function for decomposing a query into simpler sub-queries
def decompose_query(original_query, llm_chain):
    """
    Decompose the original query into simpler sub-queries.

    Args:
        original_query (str): The original complex query
        llm_chain: The chain used to generate sub-queries

    Returns:
        List[str]: A list of simpler sub-queries
    """
    response = subquery_decomposer_chain.invoke(original_query).content
    response = llm_chain.invoke(original_query).content
    sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
    return sub_queries


# Demonstrate on a use case
# example query over the understanding climate change dataset
original_query = "What are the impacts of climate change on the environment?"
sub_queries = decompose_query(original_query)
print("\nSub-queries:")
for i, sub_query in enumerate(sub_queries, 1):
    print(sub_query)
# Main class for the RAG method
|
||||
class RAGQueryProcessor:
|
||||
def __init__(self):
|
||||
# Initialize LLM models
|
||||
self.re_write_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
|
||||
self.step_back_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
|
||||
self.sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
|
||||
|
||||
# Initialize prompt templates
|
||||
query_rewrite_template = """You are an AI assistant tasked with reformulating user queries to improve retrieval in a RAG system.
|
||||
Given the original query, rewrite it to be more specific, detailed, and likely to retrieve relevant information.
|
||||
|
||||
Original query: {original_query}
|
||||
|
||||
Rewritten query:"""
|
||||
step_back_template = """You are an AI assistant tasked with generating broader, more general queries to improve context retrieval in a RAG system.
|
||||
Given the original query, generate a step-back query that is more general and can help retrieve relevant background information.
|
||||
|
||||
Original query: {original_query}
|
||||
|
||||
Step-back query:"""
|
||||
subquery_decomposition_template = """You are an AI assistant tasked with breaking down complex queries into simpler sub-queries for a RAG system.
|
||||
Given the original query, decompose it into 2-4 simpler sub-queries that, when answered together, would provide a comprehensive response to the original query.
|
||||
|
||||
Original query: {original_query}
|
||||
|
||||
example: What are the impacts of climate change on the environment?
|
||||
|
||||
Sub-queries:
|
||||
1. What are the impacts of climate change on biodiversity?
|
||||
2. How does climate change affect the oceans?
|
||||
3. What are the effects of climate change on agriculture?
|
||||
4. What are the impacts of climate change on human health?"""
|
||||
|
||||
# Create LLMChains
|
||||
self.query_rewriter = PromptTemplate(input_variables=["original_query"],
|
||||
template=query_rewrite_template) | self.re_write_llm
|
||||
self.step_back_chain = PromptTemplate(input_variables=["original_query"],
|
||||
template=step_back_template) | self.step_back_llm
|
||||
self.subquery_decomposer_chain = PromptTemplate(input_variables=["original_query"],
|
||||
template=subquery_decomposition_template) | self.sub_query_llm
|
||||
|
||||
def run(self, original_query):
|
||||
"""
|
||||
Run the full RAG query processing pipeline.
|
||||
|
||||
Args:
|
||||
original_query (str): The original query to be processed
|
||||
"""
|
||||
# Rewrite the query
|
||||
rewritten_query = rewrite_query(original_query, self.query_rewriter)
|
||||
print("Original query:", original_query)
|
||||
print("\nRewritten query:", rewritten_query)
|
||||
|
||||
# Generate step-back query
|
||||
step_back_query = generate_step_back_query(original_query, self.step_back_chain)
|
||||
print("\nStep-back query:", step_back_query)
|
||||
|
||||
# Decompose the query into sub-queries
|
||||
sub_queries = decompose_query(original_query, self.subquery_decomposer_chain)
|
||||
print("\nSub-queries:")
|
||||
for i, sub_query in enumerate(sub_queries, 1):
|
||||
print(f"{i}. {sub_query}")
|
||||
|
||||
|
||||
# Argument parsing
|
||||
def parse_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Process a query using the RAG method.")
|
||||
parser.add_argument("--query", type=str, default='What are the impacts of climate change on the environment?',
|
||||
help="The original query to be processed")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# Main execution
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
processor = RAGQueryProcessor()
|
||||
processor.run(args.query)
|
||||
|
||||
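# Editor's sketch (not part of the original diff): the rewrite, step-back and decomposition chains
# above are plain `PromptTemplate | ChatOpenAI` runnables, so they can also be exercised directly.
# The names below (demo_llm, demo_chain) and the query are illustrative assumptions, not repo code.
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

demo_llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
demo_chain = PromptTemplate(
    input_variables=["original_query"],
    template="Rewrite the query '{original_query}' to be more specific and detailed.",
) | demo_llm

# Invoking the runnable returns an AIMessage; its .content attribute holds the generated text.
print(demo_chain.invoke({"original_query": "What are the impacts of climate change?"}).content)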
@@ -18,8 +18,7 @@ import os
import sys
from dotenv import load_dotenv

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path
from helper_functions import *
from evaluation.evalute_rag import *

@@ -29,15 +28,8 @@ load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define logging, llm and embeddings
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

embeddings = OpenAIEmbeddings()
llm = ChatOpenAI(model_name="gpt-4o-mini")


# Helper Functions
# Helper functions

def extract_text(item):
    """Extract text content from either a string or an AIMessage object."""

@@ -48,6 +40,7 @@ def extract_text(item):

def embed_texts(texts: List[str]) -> List[List[float]]:
    """Embed texts using OpenAIEmbeddings."""
    embeddings = OpenAIEmbeddings()
    logging.info(f"Embedding {len(texts)} texts")
    return embeddings.embed_documents([extract_text(text) for text in texts])

@@ -59,7 +52,7 @@ def perform_clustering(embeddings: np.ndarray, n_clusters: int = 10) -> np.ndarr
    return gm.fit_predict(embeddings)


def summarize_texts(texts: List[str]) -> str:
def summarize_texts(texts: List[str], llm: ChatOpenAI) -> str:
    """Summarize a list of texts using OpenAI."""
    logging.info(f"Summarizing {len(texts)} texts")
    prompt = ChatPromptTemplate.from_template(

@@ -85,15 +78,59 @@ def visualize_clusters(embeddings: np.ndarray, labels: np.ndarray, level: int):
    plt.show()
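# Editor's sketch (not part of the original diff): one RAPTOR "level" in isolation — embed, cluster
# with a Gaussian mixture, then summarize each cluster, mirroring perform_clustering() and
# summarize_texts() above. The sample texts, the number of components and `sketch_llm` are
# illustrative assumptions only.
import numpy as np
from sklearn.mixture import GaussianMixture
from langchain_openai import ChatOpenAI

sample_texts = ["Solar power basics.", "Wind power basics.", "Glacier melt trends.", "Sea level rise."]
sample_embeddings = np.array(embed_texts(sample_texts))

# fit_predict assigns each text to one Gaussian component, as perform_clustering does.
labels = GaussianMixture(n_components=2, random_state=42).fit_predict(sample_embeddings)

sketch_llm = ChatOpenAI(model_name="gpt-4o-mini")
for cluster_id in np.unique(labels):
    cluster_texts = [text for text, label in zip(sample_texts, labels) if label == cluster_id]
    # Each cluster is collapsed into a single summary, which becomes a node at the next tree level.
    print(cluster_id, summarize_texts(cluster_texts, sketch_llm)[:80])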
# RAPTOR Core Function
def build_vectorstore(tree_results: Dict[int, pd.DataFrame], embeddings) -> FAISS:
    """Build a FAISS vectorstore from all texts in the RAPTOR tree."""
    all_texts = []
    all_embeddings = []
    all_metadatas = []

def build_raptor_tree(texts: List[str], max_levels: int = 3) -> Dict[int, pd.DataFrame]:
    for level, df in tree_results.items():
        all_texts.extend([str(text) for text in df['text'].tolist()])
        all_embeddings.extend([embedding.tolist() if isinstance(embedding, np.ndarray) else embedding for embedding in
                               df['embedding'].tolist()])
        all_metadatas.extend(df['metadata'].tolist())

    logging.info(f"Building vectorstore with {len(all_texts)} texts")
    documents = [Document(page_content=str(text), metadata=metadata)
                 for text, metadata in zip(all_texts, all_metadatas)]
    return FAISS.from_documents(documents, embeddings)


def create_retriever(vectorstore: FAISS, llm: ChatOpenAI) -> ContextualCompressionRetriever:
    """Create a retriever with contextual compression."""
    logging.info("Creating contextual compression retriever")
    base_retriever = vectorstore.as_retriever()

    prompt = ChatPromptTemplate.from_template(
        "Given the following context and question, extract only the relevant information for answering the question:\n\n"
        "Context: {context}\n"
        "Question: {question}\n\n"
        "Relevant Information:"
    )

    extractor = LLMChainExtractor.from_llm(llm, prompt=prompt)
    return ContextualCompressionRetriever(
        base_compressor=extractor,
        base_retriever=base_retriever
    )
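# Editor's sketch (not part of the original diff): how the compression retriever returned by
# create_retriever() above is typically used. The tiny one-document index, `sketch_llm` and the query
# are illustrative assumptions; the sketch reuses FAISS, Document and OpenAIEmbeddings as imported
# elsewhere in this script.
sketch_llm = ChatOpenAI(model_name="gpt-4o-mini")
sketch_store = FAISS.from_documents(
    [Document(page_content="The greenhouse effect traps heat in the atmosphere.", metadata={"level": 0})],
    OpenAIEmbeddings(),
)
compression_retriever = create_retriever(sketch_store, sketch_llm)
for doc in compression_retriever.get_relevant_documents("What is the greenhouse effect?"):
    # Each returned document keeps only the passages the extractor LLM judged relevant to the question.
    print(doc.page_content[:120])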
# Main class RAPTORMethod
class RAPTORMethod:
    def __init__(self, texts: List[str], max_levels: int = 3):
        self.texts = texts
        self.max_levels = max_levels
        self.embeddings = OpenAIEmbeddings()
        self.llm = ChatOpenAI(model_name="gpt-4o-mini")
        self.tree_results = self.build_raptor_tree()

    def build_raptor_tree(self) -> Dict[int, pd.DataFrame]:
        """Build the RAPTOR tree structure with level metadata and parent-child relationships."""
        results = {}
    current_texts = [extract_text(text) for text in texts]
    current_metadata = [{"level": 0, "origin": "original", "parent_id": None} for _ in texts]
        current_texts = [extract_text(text) for text in self.texts]
        current_metadata = [{"level": 0, "origin": "original", "parent_id": None} for _ in self.texts]

    for level in range(1, max_levels + 1):
        for level in range(1, self.max_levels + 1):
            logging.info(f"Processing level {level}")

            embeddings = embed_texts(current_texts)

@@ -115,7 +152,7 @@ def build_raptor_tree(texts: List[str], max_levels: int = 3) -> Dict[int, pd.Dat
                cluster_docs = df[df['cluster'] == cluster]
                cluster_texts = cluster_docs['text'].tolist()
                cluster_metadata = cluster_docs['metadata'].tolist()
                summary = summarize_texts(cluster_texts)
                summary = summarize_texts(cluster_texts, self.llm)
                summaries.append(summary)
                new_metadata.append({
                    "level": level,

@@ -139,202 +176,58 @@ def build_raptor_tree(texts: List[str], max_levels: int = 3) -> Dict[int, pd.Dat

        return results

    def run(self, query: str, k: int = 3) -> Dict[str, Any]:
        """Run the RAPTOR query pipeline."""
        vectorstore = build_vectorstore(self.tree_results, self.embeddings)
        retriever = create_retriever(vectorstore, self.llm)


# Vectorstore Function

def build_vectorstore(tree_results: Dict[int, pd.DataFrame]) -> FAISS:
    """Build a FAISS vectorstore from all texts in the RAPTOR tree."""
    all_texts = []
    all_embeddings = []
    all_metadatas = []

    for level, df in tree_results.items():
        all_texts.extend([str(text) for text in df['text'].tolist()])
        all_embeddings.extend([embedding.tolist() if isinstance(embedding, np.ndarray) else embedding for embedding in
                               df['embedding'].tolist()])
        all_metadatas.extend(df['metadata'].tolist())

    logging.info(f"Building vectorstore with {len(all_texts)} texts")

    # Create Document objects manually to ensure correct types
    documents = [Document(page_content=str(text), metadata=metadata)
                 for text, metadata in zip(all_texts, all_metadatas)]

    return FAISS.from_documents(documents, embeddings)


# Define tree traversal retrieval
def tree_traversal_retrieval(query: str, vectorstore: FAISS, k: int = 3) -> List[Document]:
    """Perform tree traversal retrieval."""
    query_embedding = embeddings.embed_query(query)

    def retrieve_level(level: int, parent_ids: List[str] = None) -> List[Document]:
        if parent_ids:
            docs = vectorstore.similarity_search_by_vector_with_relevance_scores(
                query_embedding,
                k=k,
                filter=lambda meta: meta['level'] == level and meta['id'] in parent_ids
            )
        else:
            docs = vectorstore.similarity_search_by_vector_with_relevance_scores(
                query_embedding,
                k=k,
                filter=lambda meta: meta['level'] == level
            )

        if not docs or level == 0:
            return docs

        child_ids = [doc.metadata.get('child_ids', []) for doc, _ in docs]
        child_ids = [item for sublist in child_ids for item in sublist]  # Flatten the list

        child_docs = retrieve_level(level - 1, child_ids)
        return docs + child_docs

    max_level = max(doc.metadata['level'] for doc in vectorstore.docstore.values())
    return retrieve_level(max_level)


# Create Retriever

def create_retriever(vectorstore: FAISS) -> ContextualCompressionRetriever:
    """Create a retriever with contextual compression."""
    logging.info("Creating contextual compression retriever")
    base_retriever = vectorstore.as_retriever()

    prompt = ChatPromptTemplate.from_template(
        "Given the following context and question, extract only the relevant information for answering the question:\n\n"
        "Context: {context}\n"
        "Question: {question}\n\n"
        "Relevant Information:"
    )

    extractor = LLMChainExtractor.from_llm(llm, prompt=prompt)

    return ContextualCompressionRetriever(
        base_compressor=extractor,
        base_retriever=base_retriever
    )


# Define hierarchical retrieval
def hierarchical_retrieval(query: str, retriever: ContextualCompressionRetriever, max_level: int) -> List[Document]:
    """Perform hierarchical retrieval starting from the highest level, handling potential None values."""
    all_retrieved_docs = []

    for level in range(max_level, -1, -1):
        # Retrieve documents from the current level
        level_docs = retriever.get_relevant_documents(
            query,
            filter=lambda meta: meta['level'] == level
        )
        all_retrieved_docs.extend(level_docs)

        # If we've found documents, retrieve their children from the next level down
        if level_docs and level > 0:
            child_ids = [doc.metadata.get('child_ids', []) for doc in level_docs]
            child_ids = [item for sublist in child_ids for item in sublist if
                         item is not None]  # Flatten and filter None

            if child_ids:  # Only modify query if there are valid child IDs
                child_query = f" AND id:({' OR '.join(str(id) for id in child_ids)})"
                query += child_query

    return all_retrieved_docs


# RAPTOR Query Process (Online Process)
def raptor_query(query: str, retriever: ContextualCompressionRetriever, max_level: int) -> Dict[str, Any]:
    """Process a query using the RAPTOR system with hierarchical retrieval."""
    logging.info(f"Processing query: {query}")
    relevant_docs = retriever.get_relevant_documents(query)

    relevant_docs = hierarchical_retrieval(query, retriever, max_level)

    doc_details = []
    for i, doc in enumerate(relevant_docs, 1):
        doc_details.append({
            "index": i,
            "content": doc.page_content,
            "metadata": doc.metadata,
            "level": doc.metadata.get('level', 'Unknown'),
            "similarity_score": doc.metadata.get('score', 'N/A')
        })
    doc_details = [{"content": doc.page_content, "metadata": doc.metadata} for doc in relevant_docs]

    context = "\n\n".join([doc.page_content for doc in relevant_docs])

    prompt = ChatPromptTemplate.from_template(
        "Given the following context, please answer the question:\n\n"
        "Context: {context}\n\n"
        "Question: {question}\n\n"
        "Answer:"
    )
    chain = LLMChain(llm=llm, prompt=prompt)
        chain = LLMChain(llm=self.llm, prompt=prompt)
    answer = chain.run(context=context, question=query)

    logging.info("Query processing completed")

    result = {
        return {
        "query": query,
        "retrieved_documents": doc_details,
        "num_docs_retrieved": len(relevant_docs),
        "context_used": context,
        "answer": answer,
        "model_used": llm.model_name,
            "model_used": self.llm.model_name,
    }

    return result


# Argument Parsing and Validation
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Run RAPTORMethod")
    parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
                        help="Path to the PDF file to process.")
    parser.add_argument("--query", type=str, default="What is the greenhouse effect?",
                        help="Query to test the retriever (default: 'What is the main topic of the document?').")
    parser.add_argument('--max_levels', type=int, default=3, help="Max levels for RAPTOR tree")
    return parser.parse_args()


def print_query_details(result: Dict[str, Any]):
    """Print detailed information about the query process, including tree level metadata."""
# Main Execution
if __name__ == "__main__":
    args = parse_args()
    loader = PyPDFLoader(args.path)
    documents = loader.load()
    texts = [doc.page_content for doc in documents]

    raptor_method = RAPTORMethod(texts, max_levels=args.max_levels)
    result = raptor_method.run(args.query)

    print(f"Query: {result['query']}")
    print(f"\nNumber of documents retrieved: {result['num_docs_retrieved']}")
    print(f"\nRetrieved Documents:")
    for doc in result['retrieved_documents']:
        print(f"  Document {doc['index']}:")
        print(f"  Content: {doc['content'][:100]}...")  # Show first 100 characters
        print(f"  Similarity Score: {doc['similarity_score']}")
        print(f"  Tree Level: {doc['metadata'].get('level', 'Unknown')}")
        print(f"  Origin: {doc['metadata'].get('origin', 'Unknown')}")
        if 'child_docs' in doc['metadata']:
            print(f"  Number of Child Documents: {len(doc['metadata']['child_docs'])}")
        print()

    print(f"\nContext used for answer generation:")
    print(result['context_used'])

    print(f"\nGenerated Answer:")
    print(result['answer'])

    print(f"\nModel Used: {result['model_used']}")


# ## Example Usage and Visualization
#

# ## Define data folder

path = "../data/Understanding_Climate_Change.pdf"

# Process texts
loader = PyPDFLoader(path)
documents = loader.load()
texts = [doc.page_content for doc in documents]

# Create RAPTOR components instances
# Build the RAPTOR tree
tree_results = build_raptor_tree(texts)

# Build vectorstore
vectorstore = build_vectorstore(tree_results)

# Create retriever
retriever = create_retriever(vectorstore)

# Run a query and observe where it got the data from + results
# Run the pipeline
max_level = 3  # Adjust based on your tree depth
query = "What is the greenhouse effect?"
result = raptor_query(query, retriever, max_level)
print_query_details(result)
print(f"Context Used: {result['context_used']}")
print(f"Answer: {result['answer']}")
print(f"Model Used: {result['model_used']}")
@@ -2,14 +2,15 @@ import os
import sys
from dotenv import load_dotenv
from langchain.docstore.document import Document
from typing import List, Dict, Any, Tuple
from typing import List, Any
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain_core.retrievers import BaseRetriever
from sentence_transformers import CrossEncoder
from pydantic import BaseModel, Field
import argparse

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from helper_functions import *
from evaluation.evalute_rag import *

@@ -19,16 +20,8 @@ load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define the document's path
path = "../data/Understanding_Climate_Change.pdf"

# Create a vector store
vectorstore = encode_pdf(path)


# ## Method 1: LLM based function to rerank the retrieved documents

# Create a custom reranking function
# Helper Classes and Functions

class RatingScore(BaseModel):
    relevance_score: float = Field(..., description="The relevance score of a document to a query.")

@@ -60,28 +53,6 @@ def rerank_documents(query: str, docs: List[Document], top_n: int = 3) -> List[D
    return [doc for doc, _ in reranked_docs[:top_n]]
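# Editor's sketch (not part of the original diff): the body of rerank_documents sits outside this
# hunk. A typical LLM-based scoring step, assuming the RatingScore model above and a prompt that
# rates one document at a time, could look like the hypothetical helper below; the prompt wording
# and helper name are illustrative assumptions.
from langchain.prompts import PromptTemplate


def score_document(query: str, doc_text: str) -> float:
    # Hypothetical helper: ask the LLM for a structured relevance score for a single document.
    prompt = PromptTemplate(
        input_variables=["query", "doc"],
        template="Rate the relevance of this document to the query on a scale of 1-10.\n"
                 "Query: {query}\nDocument: {doc}\nRelevance score:",
    )
    scorer = prompt | ChatOpenAI(temperature=0, model_name="gpt-4o").with_structured_output(RatingScore)
    return scorer.invoke({"query": query, "doc": doc_text}).relevance_score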
# Example usage of the reranking function with a sample query relevant to the document

query = "What are the impacts of climate change on biodiversity?"
initial_docs = vectorstore.similarity_search(query, k=15)
reranked_docs = rerank_documents(query, initial_docs)

# print first 3 initial documents
print("Top initial documents:")
for i, doc in enumerate(initial_docs[:3]):
    print(f"\nDocument {i + 1}:")
    print(doc.page_content[:200] + "...")  # Print first 200 characters of each document

# Print results
print(f"Query: {query}\n")
print("Top reranked documents:")
for i, doc in enumerate(reranked_docs):
    print(f"\nDocument {i + 1}:")
    print(doc.page_content[:200] + "...")  # Print first 200 characters of each document


# Create a custom retriever based on our reranker
# Create a custom retriever class
class CustomRetriever(BaseRetriever, BaseModel):
    vectorstore: Any = Field(description="Vector store for initial retrieval")

@@ -93,44 +64,27 @@ class CustomRetriever(BaseRetriever, BaseModel):
        return rerank_documents(query, initial_docs, top_n=num_docs)


# Create the custom retriever
custom_retriever = CustomRetriever(vectorstore=vectorstore)
class CrossEncoderRetriever(BaseRetriever, BaseModel):
    vectorstore: Any = Field(description="Vector store for initial retrieval")
    cross_encoder: Any = Field(description="Cross-encoder model for reranking")
    k: int = Field(default=5, description="Number of documents to retrieve initially")
    rerank_top_k: int = Field(default=3, description="Number of documents to return after reranking")

# Create an LLM for answering questions
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
    class Config:
        arbitrary_types_allowed = True

# Create the RetrievalQA chain with the custom retriever
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=custom_retriever,
    return_source_documents=True
)
    def get_relevant_documents(self, query: str) -> List[Document]:
        initial_docs = self.vectorstore.similarity_search(query, k=self.k)
        pairs = [[query, doc.page_content] for doc in initial_docs]
        scores = self.cross_encoder.predict(pairs)
        scored_docs = sorted(zip(initial_docs, scores), key=lambda x: x[1], reverse=True)
        return [doc for doc, _ in scored_docs[:self.rerank_top_k]]

# Example query

result = qa_chain({"query": query})

print(f"\nQuestion: {query}")
print(f"Answer: {result['result']}")
print("\nRelevant source documents:")
for i, doc in enumerate(result["source_documents"]):
    print(f"\nDocument {i + 1}:")
    print(doc.page_content[:200] + "...")  # Print first 200 characters of each document

# Example that demonstrates why we should use reranking
chunks = [
    "The capital of France is great.",
    "The capital of France is huge.",
    "The capital of France is beautiful.",
    """Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower.
    I really enjoyed all the cities in france, but its capital with the Eiffel Tower is my favorite city.""",
    "I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. Such a great capital city."
]
docs = [Document(page_content=sentence) for sentence in chunks]
    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("Async retrieval not implemented")


def compare_rag_techniques(query: str, docs: List[Document] = docs) -> None:
def compare_rag_techniques(query: str, docs: List[Document]) -> None:
    embeddings = OpenAIEmbeddings()
    vectorstore = FAISS.from_documents(docs, embeddings)

@@ -152,76 +106,68 @@ def compare_rag_techniques(query: str, docs: List[Document] = docs) -> None:
        print(doc.page_content)


query = "what is the capital of france?"
compare_rag_techniques(query, docs)
# Main class
class RAGPipeline:
    def __init__(self, path: str):
        self.vectorstore = encode_pdf(path)
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

# ## Method 2: Cross Encoder models

# <div style="text-align: center;">
#
# <img src="../images/rerank_cross_encoder.svg" alt="rerank cross encoder" style="width:40%; height:auto;">
# </div>

# Define the cross encoder class
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


class CrossEncoderRetriever(BaseRetriever, BaseModel):
    vectorstore: Any = Field(description="Vector store for initial retrieval")
    cross_encoder: Any = Field(description="Cross-encoder model for reranking")
    k: int = Field(default=5, description="Number of documents to retrieve initially")
    rerank_top_k: int = Field(default=3, description="Number of documents to return after reranking")

    class Config:
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        # Initial retrieval
        initial_docs = self.vectorstore.similarity_search(query, k=self.k)

        # Prepare pairs for cross-encoder
        pairs = [[query, doc.page_content] for doc in initial_docs]

        # Get cross-encoder scores
        scores = self.cross_encoder.predict(pairs)

        # Sort documents by score
        scored_docs = sorted(zip(initial_docs, scores), key=lambda x: x[1], reverse=True)

        # Return top reranked documents
        return [doc for doc, _ in scored_docs[:self.rerank_top_k]]

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError("Async retrieval not implemented")


# Create an instance and showcase over an example
# Create the cross-encoder retriever
cross_encoder_retriever = CrossEncoderRetriever(
    vectorstore=vectorstore,
    def run(self, query: str, retriever_type: str = "reranker"):
        if retriever_type == "reranker":
            retriever = CustomRetriever(vectorstore=self.vectorstore)
        elif retriever_type == "cross_encoder":
            cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
            retriever = CrossEncoderRetriever(
                vectorstore=self.vectorstore,
                cross_encoder=cross_encoder,
    k=10,  # Retrieve 10 documents initially
    rerank_top_k=5  # Return top 5 after reranking
)
                k=10,
                rerank_top_k=5
            )
        else:
            raise ValueError("Unknown retriever type. Use 'reranker' or 'cross_encoder'.")

# Set up the LLM
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

# Create the RetrievalQA chain with the cross-encoder retriever
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
    retriever=cross_encoder_retriever,
            retriever=retriever,
            return_source_documents=True
)
        )

# Example query
query = "What are the impacts of climate change on biodiversity?"
result = qa_chain({"query": query})
        result = qa_chain({"query": query})

print(f"\nQuestion: {query}")
print(f"Answer: {result['result']}")
print("\nRelevant source documents:")
for i, doc in enumerate(result["source_documents"]):
        print(f"\nQuestion: {query}")
        print(f"Answer: {result['result']}")
        print("\nRelevant source documents:")
        for i, doc in enumerate(result["source_documents"]):
            print(f"\nDocument {i + 1}:")
    print(doc.page_content[:200] + "...")  # Print first 200 characters of each document
            print(doc.page_content[:200] + "...")


# Argument Parsing
def parse_args():
    parser = argparse.ArgumentParser(description="RAG Pipeline")
    parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf", help="Path to the document")
    parser.add_argument("--query", type=str, default='What are the impacts of climate change?', help="Query to ask")
    parser.add_argument("--retriever_type", type=str, default="reranker", choices=["reranker", "cross_encoder"],
                        help="Type of retriever to use")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    pipeline = RAGPipeline(path=args.path)
    pipeline.run(query=args.query, retriever_type=args.retriever_type)

    # Demonstrate the reranking comparison
    # Example that demonstrates why we should use reranking
    chunks = [
        "The capital of France is great.",
        "The capital of France is huge.",
        "The capital of France is beautiful.",
        """Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower.
        I really enjoyed all the cities in France, but its capital with the Eiffel Tower is my favorite city.""",
        "I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. Such a great capital city."
    ]
    docs = [Document(page_content=sentence) for sentence in chunks]

    compare_rag_techniques(query="what is the capital of france?", docs=docs)
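# Editor's sketch (not part of the original diff): the cross-encoder reranking step in isolation.
# The model name matches the one used above; the query and candidate passages are illustrative.
from sentence_transformers import CrossEncoder

encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
candidates = ["Paris is the capital of France.", "France is famous for its cheese."]
# predict() scores each (query, passage) pair; a higher score means a better match.
scores = encoder.predict([["what is the capital of france?", passage] for passage in candidates])
for passage, score in sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True):
    print(f"{score:.3f}  {passage}")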
@@ -1,14 +1,16 @@
import os
import sys
import json
from typing import List, Dict, Any
from dotenv import load_dotenv
from pydantic import BaseModel, Field

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
import json
from typing import List, Dict, Any
from langchain.prompts import PromptTemplate

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path
from helper_functions import *
from evaluation.evalute_rag import *

@@ -19,19 +21,13 @@ load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Define documents path
path = "../data/Understanding_Climate_Change.pdf"

# Create vector store and retrieval QA chain
content = read_pdf_to_string(path)
vectorstore = encode_from_string(content)
retriever = vectorstore.as_retriever()

llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
qa_chain = RetrievalQA.from_chain_type(llm, retriever=retriever)
# Define the Response class
class Response(BaseModel):
    answer: str = Field(..., title="The answer to the question. The options can be only 'Yes' or 'No'")


# Function to format user feedback in a dictionary
# Define utility functions
def get_user_feedback(query, response, relevance, quality, comments=""):
    return {
        "query": query,

@@ -42,14 +38,12 @@ def get_user_feedback(query, response, relevance, quality, comments=""):
    }


# Function to store the feedback in a json file
def store_feedback(feedback):
    with open("../data/feedback_data.json", "a") as f:
        json.dump(feedback, f)
        f.write("\n")


# Function to read the feedback file
def load_feedback_data():
    feedback_data = []
    try:

@@ -61,13 +55,7 @@ def load_feedback_data():
    return feedback_data
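# Editor's sketch (not part of the original diff): the feedback store above is a JSON-lines file,
# one feedback record per line. A quick round trip with the helpers defined above shows the on-disk
# format that load_feedback_data() expects; the query, response and scores are illustrative.
sample_feedback = get_user_feedback(
    query="What is the greenhouse effect?",
    response="It is the trapping of heat by atmospheric gases.",
    relevance=5,
    quality=4,
    comments="Concise and accurate.",
)
store_feedback(sample_feedback)           # appends one JSON object plus a newline
print(load_feedback_data()[-1]["query"])  # the most recent record round-trips intact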
# Function to adjust files relevancy based on the feedbacks file
class Response(BaseModel):
    answer: str = Field(..., title="The answer to the question. The options can be only 'Yes' or 'No'")


def adjust_relevance_scores(query: str, docs: List[Any], feedback_data: List[Dict[str, Any]]) -> List[Any]:
    # Create a prompt template for relevance checking
    relevance_prompt = PromptTemplate(
        input_variables=["query", "feedback_query", "doc_content", "feedback_response"],
        template="""

@@ -82,15 +70,11 @@ def adjust_relevance_scores(query: str, docs: List[Any], feedback_data: List[Dic
        """
    )
    llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)

    # Create an LLMChain for relevance checking
    relevance_chain = relevance_prompt | llm.with_structured_output(Response)

    for doc in docs:
        relevant_feedback = []

        for feedback in feedback_data:
            # Use LLM to check relevance
            input_data = {
                "query": query,
                "feedback_query": feedback['query'],

@@ -102,60 +86,63 @@ def adjust_relevance_scores(query: str, docs: List[Any], feedback_data: List[Dic
            if result == 'yes':
                relevant_feedback.append(feedback)

        # Adjust the relevance score based on feedback
        if relevant_feedback:
            avg_relevance = sum(f['relevance'] for f in relevant_feedback) / len(relevant_feedback)
            doc.metadata['relevance_score'] *= (avg_relevance / 3)  # Assuming a 1-5 scale, 3 is neutral
            doc.metadata['relevance_score'] *= (avg_relevance / 3)

    # Re-rank documents based on adjusted scores
    return sorted(docs, key=lambda x: x.metadata['relevance_score'], reverse=True)


# Function to fine tune the vector index to include also queries + answers that received good feedbacks
def fine_tune_index(feedback_data: List[Dict[str, Any]], texts: List[str]) -> Any:
    # Filter high-quality responses
    good_responses = [f for f in feedback_data if f['relevance'] >= 4 and f['quality'] >= 4]

    # Extract queries and responses, and create new documents
    additional_texts = []
    for f in good_responses:
        combined_text = f['query'] + " " + f['response']
        additional_texts.append(combined_text)

    # make the list a string
    additional_texts = " ".join(additional_texts)

    # Create a new index with original and high-quality texts
    additional_texts = " ".join([f['query'] + " " + f['response'] for f in good_responses])
    all_texts = texts + additional_texts
    new_vectorstore = encode_from_string(all_texts)

    return new_vectorstore


# Demonstration of how to retrieve answers with respect to user feedbacks
query = "What is the greenhouse effect?"
# Define the main RAG class
class RetrievalAugmentedGeneration:
    def __init__(self, path: str):
        self.path = path
        self.content = read_pdf_to_string(self.path)
        self.vectorstore = encode_from_string(self.content)
        self.retriever = self.vectorstore.as_retriever()
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
        self.qa_chain = RetrievalQA.from_chain_type(self.llm, retriever=self.retriever)

# Get response from RAG system
response = qa_chain(query)["result"]
    def run(self, query: str, relevance: int, quality: int):
        response = self.qa_chain(query)["result"]
        feedback = get_user_feedback(query, response, relevance, quality)
        store_feedback(feedback)

relevance = 5
quality = 5
        docs = self.retriever.get_relevant_documents(query)
        adjusted_docs = adjust_relevance_scores(query, docs, load_feedback_data())
        self.retriever.search_kwargs['k'] = len(adjusted_docs)
        self.retriever.search_kwargs['docs'] = adjusted_docs

# Collect feedback
feedback = get_user_feedback(query, response, relevance, quality)
        return response

# Store feedback
store_feedback(feedback)

# Adjust relevance scores for future retrievals
docs = retriever.get_relevant_documents(query)
adjusted_docs = adjust_relevance_scores(query, docs, load_feedback_data())
# Argument parsing
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Run the RAG system with feedback integration.")
    parser.add_argument('--path', type=str, default="../data/Understanding_Climate_Change.pdf",
                        help="Path to the document.")
    parser.add_argument('--query', type=str, default='What is the greenhouse effect?',
                        help="Query to ask the RAG system.")
    parser.add_argument('--relevance', type=int, default=5, help="Relevance score for the feedback.")
    parser.add_argument('--quality', type=int, default=5, help="Quality score for the feedback.")
    return parser.parse_args()

# Update the retriever with adjusted docs
retriever.search_kwargs['k'] = len(adjusted_docs)
retriever.search_kwargs['docs'] = adjusted_docs

# Finetune the vectorstore periodicly
# Periodically (e.g., daily or weekly), fine-tune the index
new_vectorstore = fine_tune_index(load_feedback_data(), content)
retriever = new_vectorstore.as_retriever()
if __name__ == "__main__":
    args = parse_args()
    rag = RetrievalAugmentedGeneration(args.path)
    result = rag.run(args.query, args.relevance, args.quality)
    print(f"Response: {result}")

    # Fine-tune the vectorstore periodically
    new_vectorstore = fine_tune_index(load_feedback_data(), rag.content)
    rag.retriever = new_vectorstore.as_retriever()
@@ -6,7 +6,7 @@ from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path sicnce we work with notebooks
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path since we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *

@@ -16,91 +16,85 @@ load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

# Define files path
path = "../data/Understanding_Climate_Change.pdf"

# Create a vector store
vectorstore = encode_pdf(path)

# Initialize the language model

llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=1000, temperature=0)


# Defining prompt templates
# Define all relevant classes/functions
class RetrievalResponse(BaseModel):
    response: str = Field(..., title="Determines if retrieval is necessary", description="Output only 'Yes' or 'No'.")


retrieval_prompt = PromptTemplate(
    input_variables=["query"],
    template="Given the query '{query}', determine if retrieval is necessary. Output only 'Yes' or 'No'."
)


class RelevanceResponse(BaseModel):
    response: str = Field(..., title="Determines if context is relevant",
                          description="Output only 'Relevant' or 'Irrelevant'.")


relevance_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="Given the query '{query}' and the context '{context}', determine if the context is relevant. Output only 'Relevant' or 'Irrelevant'."
)


class GenerationResponse(BaseModel):
    response: str = Field(..., title="Generated response", description="The generated response.")


generation_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="Given the query '{query}' and the context '{context}', generate a response."
)


class SupportResponse(BaseModel):
    response: str = Field(..., title="Determines if response is supported",
                          description="Output 'Fully supported', 'Partially supported', or 'No support'.")


class UtilityResponse(BaseModel):
    response: int = Field(..., title="Utility rating", description="Rate the utility of the response from 1 to 5.")


# Define prompt templates
retrieval_prompt = PromptTemplate(
    input_variables=["query"],
    template="Given the query '{query}', determine if retrieval is necessary. Output only 'Yes' or 'No'."
)

relevance_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="Given the query '{query}' and the context '{context}', determine if the context is relevant. Output only 'Relevant' or 'Irrelevant'."
)

generation_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="Given the query '{query}' and the context '{context}', generate a response."
)

support_prompt = PromptTemplate(
    input_variables=["response", "context"],
    template="Given the response '{response}' and the context '{context}', determine if the response is supported by the context. Output 'Fully supported', 'Partially supported', or 'No support'."
)


class UtilityResponse(BaseModel):
    response: int = Field(..., title="Utility rating", description="Rate the utility of the response from 1 to 5.")


utility_prompt = PromptTemplate(
    input_variables=["query", "response"],
    template="Given the query '{query}' and the response '{response}', rate the utility of the response from 1 to 5."
)

# Create LLMChains for each step
retrieval_chain = retrieval_prompt | llm.with_structured_output(RetrievalResponse)
relevance_chain = relevance_prompt | llm.with_structured_output(RelevanceResponse)
generation_chain = generation_prompt | llm.with_structured_output(GenerationResponse)
support_chain = support_prompt | llm.with_structured_output(SupportResponse)
utility_chain = utility_prompt | llm.with_structured_output(UtilityResponse)
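# Editor's sketch (not part of the original diff): each Self-RAG step is a prompt piped into
# `llm.with_structured_output(<pydantic model>)`, so the decision comes back as a typed field rather
# than free text. The query below and the variable names are illustrative assumptions.
from langchain_openai import ChatOpenAI

gating_llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=1000, temperature=0)
retrieval_decision_chain = retrieval_prompt | gating_llm.with_structured_output(RetrievalResponse)
decision = retrieval_decision_chain.invoke({"query": "What is the greenhouse effect?"})
print(decision.response)  # expected to be the literal string 'Yes' or 'No'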
# Define main class

# Defining the self RAG logic flow
def self_rag(query, vectorstore, top_k=3):
class SelfRAG:
    def __init__(self, path, top_k=3):
        self.vectorstore = encode_pdf(path)
        self.top_k = top_k
        self.llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=1000, temperature=0)

        # Create LLMChains for each step
        self.retrieval_chain = retrieval_prompt | self.llm.with_structured_output(RetrievalResponse)
        self.relevance_chain = relevance_prompt | self.llm.with_structured_output(RelevanceResponse)
        self.generation_chain = generation_prompt | self.llm.with_structured_output(GenerationResponse)
        self.support_chain = support_prompt | self.llm.with_structured_output(SupportResponse)
        self.utility_chain = utility_prompt | self.llm.with_structured_output(UtilityResponse)

    def run(self, query):
        print(f"\nProcessing query: {query}")

        # Step 1: Determine if retrieval is necessary
        print("Step 1: Determining if retrieval is necessary...")
        input_data = {"query": query}
    retrieval_decision = retrieval_chain.invoke(input_data).response.strip().lower()
        retrieval_decision = self.retrieval_chain.invoke(input_data).response.strip().lower()
        print(f"Retrieval decision: {retrieval_decision}")

        if retrieval_decision == 'yes':
            # Step 2: Retrieve relevant documents
            print("Step 2: Retrieving relevant documents...")
        docs = vectorstore.similarity_search(query, k=top_k)
            docs = self.vectorstore.similarity_search(query, k=self.top_k)
            contexts = [doc.page_content for doc in docs]
            print(f"Retrieved {len(contexts)} documents")

@@ -109,7 +103,7 @@ def self_rag(query, vectorstore, top_k=3):
            relevant_contexts = []
            for i, context in enumerate(contexts):
                input_data = {"query": query, "context": context}
            relevance = relevance_chain.invoke(input_data).response.strip().lower()
                relevance = self.relevance_chain.invoke(input_data).response.strip().lower()
                print(f"Document {i + 1} relevance: {relevance}")
                if relevance == 'relevant':
                    relevant_contexts.append(context)

@@ -120,7 +114,7 @@ def self_rag(query, vectorstore, top_k=3):
            if not relevant_contexts:
                print("No relevant contexts found. Generating without retrieval...")
                input_data = {"query": query, "context": "No relevant context found."}
            return generation_chain.invoke(input_data).response
                return self.generation_chain.invoke(input_data).response

            # Step 4: Generate response using relevant contexts
            print("Step 4: Generating responses using relevant contexts...")

@@ -128,18 +122,18 @@ def self_rag(query, vectorstore, top_k=3):
            for i, context in enumerate(relevant_contexts):
                print(f"Generating response for context {i + 1}...")
                input_data = {"query": query, "context": context}
            response = generation_chain.invoke(input_data).response
                response = self.generation_chain.invoke(input_data).response

                # Step 5: Assess support
                print(f"Step 5: Assessing support for response {i + 1}...")
                input_data = {"response": response, "context": context}
            support = support_chain.invoke(input_data).response.strip().lower()
                support = self.support_chain.invoke(input_data).response.strip().lower()
                print(f"Support assessment: {support}")

                # Step 6: Evaluate utility
                print(f"Step 6: Evaluating utility for response {i + 1}...")
                input_data = {"query": query, "response": response}
            utility = int(utility_chain.invoke(input_data).response)
                utility = int(self.utility_chain.invoke(input_data).response)
                print(f"Utility score: {utility}")

                responses.append((response, support, utility))

@@ -153,21 +147,24 @@ def self_rag(query, vectorstore, top_k=3):
            # Generate without retrieval
            print("Generating without retrieval...")
            input_data = {"query": query, "context": "No retrieval necessary."}
    return generation_chain.invoke(input_data).response
            return self.generation_chain.invoke(input_data).response


# Test the self-RAG function easy query with high relevance
# Argument parsing functions
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Self-RAG method")
    parser.add_argument('--path', type=str, default='../data/Understanding_Climate_Change.pdf',
                        help='Path to the PDF file for vector store')
    parser.add_argument('--query', type=str, default='What is the impact of climate change on the environment?',
                        help='Query to be processed')
    return parser.parse_args()

query = "What is the impact of climate change on the environment?"
response = self_rag(query, vectorstore)

print("\nFinal response:")
print(response)

# Test the self-RAG function with a more challenging query with low relevance

query = "how did harry beat quirrell?"
response = self_rag(query, vectorstore)

print("\nFinal response:")
print(response)
# Main entry point
if __name__ == "__main__":
    args = parse_args()
    rag = SelfRAG(path=args.path)
    response = rag.run(args.query)
    print("\nFinal response:")
    print(response)