mirror of
https://github.com/NirDiamant/RAG_Techniques.git
synced 2025-04-07 00:48:52 +03:00
174 lines
6.8 KiB
Python
174 lines
6.8 KiB
Python
import os
|
|
import sys
|
|
from dotenv import load_dotenv
|
|
from langchain.docstore.document import Document
|
|
from typing import List, Any
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain.chains import RetrievalQA
|
|
from langchain_core.retrievers import BaseRetriever
|
|
from sentence_transformers import CrossEncoder
|
|
from pydantic import BaseModel, Field
|
|
import argparse
|
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
|
|
from helper_functions import *
|
|
from evaluation.evalute_rag import *
|
|
|
|
# Load environment variables from a .env file
|
|
load_dotenv()
|
|
|
|
# Set the OpenAI API key environment variable
|
|
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
|
|
|
|
|
|
# Helper Classes and Functions
|
|
|
|
class RatingScore(BaseModel):
|
|
relevance_score: float = Field(..., description="The relevance score of a document to a query.")
|
|
|
|
|
|
def rerank_documents(query: str, docs: List[Document], top_n: int = 3) -> List[Document]:
|
|
prompt_template = PromptTemplate(
|
|
input_variables=["query", "doc"],
|
|
template="""On a scale of 1-10, rate the relevance of the following document to the query. Consider the specific context and intent of the query, not just keyword matches.
|
|
Query: {query}
|
|
Document: {doc}
|
|
Relevance Score:"""
|
|
)
|
|
|
|
llm = ChatOpenAI(temperature=0, model_name="gpt-4o", max_tokens=4000)
|
|
llm_chain = prompt_template | llm.with_structured_output(RatingScore)
|
|
|
|
scored_docs = []
|
|
for doc in docs:
|
|
input_data = {"query": query, "doc": doc.page_content}
|
|
score = llm_chain.invoke(input_data).relevance_score
|
|
try:
|
|
score = float(score)
|
|
except ValueError:
|
|
score = 0 # Default score if parsing fails
|
|
scored_docs.append((doc, score))
|
|
|
|
reranked_docs = sorted(scored_docs, key=lambda x: x[1], reverse=True)
|
|
return [doc for doc, _ in reranked_docs[:top_n]]
|
|
|
|
|
|
class CustomRetriever(BaseRetriever, BaseModel):
|
|
vectorstore: Any = Field(description="Vector store for initial retrieval")
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
def get_relevant_documents(self, query: str, num_docs=2) -> List[Document]:
|
|
initial_docs = self.vectorstore.similarity_search(query, k=30)
|
|
return rerank_documents(query, initial_docs, top_n=num_docs)
|
|
|
|
|
|
class CrossEncoderRetriever(BaseRetriever, BaseModel):
|
|
vectorstore: Any = Field(description="Vector store for initial retrieval")
|
|
cross_encoder: Any = Field(description="Cross-encoder model for reranking")
|
|
k: int = Field(default=5, description="Number of documents to retrieve initially")
|
|
rerank_top_k: int = Field(default=3, description="Number of documents to return after reranking")
|
|
|
|
class Config:
|
|
arbitrary_types_allowed = True
|
|
|
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
|
initial_docs = self.vectorstore.similarity_search(query, k=self.k)
|
|
pairs = [[query, doc.page_content] for doc in initial_docs]
|
|
scores = self.cross_encoder.predict(pairs)
|
|
scored_docs = sorted(zip(initial_docs, scores), key=lambda x: x[1], reverse=True)
|
|
return [doc for doc, _ in scored_docs[:self.rerank_top_k]]
|
|
|
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
|
raise NotImplementedError("Async retrieval not implemented")
|
|
|
|
|
|
def compare_rag_techniques(query: str, docs: List[Document]) -> None:
|
|
embeddings = OpenAIEmbeddings()
|
|
vectorstore = FAISS.from_documents(docs, embeddings)
|
|
|
|
print("Comparison of Retrieval Techniques")
|
|
print("==================================")
|
|
print(f"Query: {query}\n")
|
|
|
|
print("Baseline Retrieval Result:")
|
|
baseline_docs = vectorstore.similarity_search(query, k=2)
|
|
for i, doc in enumerate(baseline_docs):
|
|
print(f"\nDocument {i + 1}:")
|
|
print(doc.page_content)
|
|
|
|
print("\nAdvanced Retrieval Result:")
|
|
custom_retriever = CustomRetriever(vectorstore=vectorstore)
|
|
advanced_docs = custom_retriever.get_relevant_documents(query)
|
|
for i, doc in enumerate(advanced_docs):
|
|
print(f"\nDocument {i + 1}:")
|
|
print(doc.page_content)
|
|
|
|
|
|
# Main class
|
|
class RAGPipeline:
|
|
def __init__(self, path: str):
|
|
self.vectorstore = encode_pdf(path)
|
|
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
|
|
|
|
def run(self, query: str, retriever_type: str = "reranker"):
|
|
if retriever_type == "reranker":
|
|
retriever = CustomRetriever(vectorstore=self.vectorstore)
|
|
elif retriever_type == "cross_encoder":
|
|
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
|
retriever = CrossEncoderRetriever(
|
|
vectorstore=self.vectorstore,
|
|
cross_encoder=cross_encoder,
|
|
k=10,
|
|
rerank_top_k=5
|
|
)
|
|
else:
|
|
raise ValueError("Unknown retriever type. Use 'reranker' or 'cross_encoder'.")
|
|
|
|
qa_chain = RetrievalQA.from_chain_type(
|
|
llm=self.llm,
|
|
chain_type="stuff",
|
|
retriever=retriever,
|
|
return_source_documents=True
|
|
)
|
|
|
|
result = qa_chain({"query": query})
|
|
|
|
print(f"\nQuestion: {query}")
|
|
print(f"Answer: {result['result']}")
|
|
print("\nRelevant source documents:")
|
|
for i, doc in enumerate(result["source_documents"]):
|
|
print(f"\nDocument {i + 1}:")
|
|
print(doc.page_content[:200] + "...")
|
|
|
|
|
|
# Argument Parsing
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description="RAG Pipeline")
|
|
parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf", help="Path to the document")
|
|
parser.add_argument("--query", type=str, default='What are the impacts of climate change?', help="Query to ask")
|
|
parser.add_argument("--retriever_type", type=str, default="reranker", choices=["reranker", "cross_encoder"],
|
|
help="Type of retriever to use")
|
|
return parser.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args()
|
|
pipeline = RAGPipeline(path=args.path)
|
|
pipeline.run(query=args.query, retriever_type=args.retriever_type)
|
|
|
|
# Demonstrate the reranking comparison
|
|
# Example that demonstrates why we should use reranking
|
|
chunks = [
|
|
"The capital of France is great.",
|
|
"The capital of France is huge.",
|
|
"The capital of France is beautiful.",
|
|
"""Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower.
|
|
I really enjoyed all the cities in France, but its capital with the Eiffel Tower is my favorite city.""",
|
|
"I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. Such a great capital city."
|
|
]
|
|
docs = [Document(page_content=sentence) for sentence in chunks]
|
|
|
|
compare_rag_techniques(query="what is the capital of france?", docs=docs)
|