mirror of
				https://github.com/NirDiamant/RAG_Techniques.git
				synced 2025-04-07 00:48:52 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			144 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			144 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import sys
 | |
| from dotenv import load_dotenv
 | |
| from langchain.docstore.document import Document
 | |
| from typing import List
 | |
| from rank_bm25 import BM25Okapi
 | |
| import numpy as np
 | |
| 
 | |
| # Add the parent directory to the path
 | |
| sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
 | |
| from helper_functions import *
 | |
| from evaluation.evalute_rag import *
 | |
| 
 | |
# Load environment variables from a .env file, if one exists
load_dotenv()

# Mirror the key into os.environ only when it is actually set:
# os.environ rejects None values, so an absent OPENAI_API_KEY would
# previously raise TypeError at import time instead of failing later
# with a clear "missing key" error from the OpenAI client.
_openai_api_key = os.getenv('OPENAI_API_KEY')
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
| 
 | |
| 
 | |
# Build a FAISS vector store from a PDF and return it with the split chunks
def encode_pdf_and_get_split_documents(path, chunk_size=1000, chunk_overlap=200):
    """
    Load a PDF, split it into overlapping chunks, and embed them into FAISS.

    Args:
        path: Location of the PDF file on disk.
        chunk_size: Target number of characters per chunk.
        chunk_overlap: Number of characters shared by adjacent chunks.

    Returns:
        A (vectorstore, cleaned_texts) pair: the FAISS store built with
        OpenAI embeddings, and the list of cleaned document chunks.
    """
    pages = PyPDFLoader(path).load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    # replace_t_with_space cleans tab artifacts left by PDF extraction.
    chunks = replace_t_with_space(splitter.split_documents(pages))

    store = FAISS.from_documents(chunks, OpenAIEmbeddings())
    return store, chunks
| 
 | |
| 
 | |
# Build a BM25 keyword-search index over the given documents
def create_bm25_index(documents: List[Document]) -> BM25Okapi:
    """
    Create a BM25 index from the given documents.

    Tokenization is a plain whitespace split of each document's text;
    no stemming or lowercasing is applied.

    Args:
        documents (List[Document]): Documents to index.

    Returns:
        BM25Okapi: An index usable for BM25 keyword scoring.
    """
    corpus = [doc.page_content.split() for doc in documents]
    return BM25Okapi(corpus)
| 
 | |
| 
 | |
# Function for fusion retrieval combining keyword-based (BM25) and vector-based search
def _min_max_normalize(scores):
    """Scale scores into [0, 1]; return all zeros when every score is equal."""
    scores = np.asarray(scores, dtype=float)
    lo, hi = np.min(scores), np.max(scores)
    if hi == lo:
        # Degenerate range: the plain (x - min) / (max - min) formula would
        # divide by zero and fill the array with NaNs, making the final
        # argsort ranking arbitrary. Equal scores carry no signal, so 0s.
        return np.zeros_like(scores)
    return (scores - lo) / (hi - lo)


def fusion_retrieval(vectorstore, bm25, query: str, k: int = 5, alpha: float = 0.5) -> List[Document]:
    """
    Perform fusion retrieval combining keyword-based (BM25) and vector-based search.

    Args:
    vectorstore (VectorStore): The vectorstore containing the documents.
    bm25 (BM25Okapi): Pre-computed BM25 index.
    query (str): The query string.
    k (int): The number of documents to retrieve.
    alpha (float): The weight for vector search scores (1-alpha will be the weight for BM25 scores).

    Returns:
    List[Document]: The top k documents based on the combined scores.
    """
    # Pull every stored document (empty query, k = full index size).
    all_docs = vectorstore.similarity_search("", k=vectorstore.index.ntotal)
    if not all_docs:
        # Empty store: nothing to rank (np.min/np.max on [] would raise).
        return []

    # NOTE(review): this assumes bm25.get_scores and the vectorstore results
    # refer to documents in the same order as all_docs — confirm the BM25
    # index was built from the same document list in the same order.
    bm25_scores = bm25.get_scores(query.split())
    vector_results = vectorstore.similarity_search_with_score(query, k=len(all_docs))

    # FAISS reports distances (lower = closer), so invert after normalizing
    # to make higher always mean more relevant, matching BM25's direction.
    vector_scores = 1 - _min_max_normalize([score for _, score in vector_results])
    bm25_scores = _min_max_normalize(bm25_scores)

    # Weighted blend: alpha toward vector similarity, (1 - alpha) toward BM25.
    combined_scores = alpha * vector_scores + (1 - alpha) * bm25_scores

    # Highest combined score first.
    sorted_indices = np.argsort(combined_scores)[::-1]
    return [all_docs[i] for i in sorted_indices[:k]]
| 
 | |
| 
 | |
class FusionRetrievalRAG:
    """Pipeline that answers queries by fusing BM25 keyword and vector search."""

    def __init__(self, path: str, chunk_size: int = 1000, chunk_overlap: int = 200):
        """
        Build the vector store and BM25 index for the given PDF.

        Args:
        path (str): Path to the PDF file.
        chunk_size (int): The size of each text chunk.
        chunk_overlap (int): The overlap between consecutive chunks.
        """
        store, texts = encode_pdf_and_get_split_documents(path, chunk_size, chunk_overlap)
        self.vectorstore = store
        self.cleaned_texts = texts
        # BM25 is built over the same chunks the vector store holds.
        self.bm25 = create_bm25_index(texts)

    def run(self, query: str, k: int = 5, alpha: float = 0.5):
        """
        Retrieve the top documents for a query and display their contents.

        Args:
        query (str): The search query.
        k (int): The number of documents to retrieve.
        alpha (float): The weight of vector search vs. BM25 search.
        """
        retrieved = fusion_retrieval(self.vectorstore, self.bm25, query, k, alpha)
        # Side effect only: prints the retrieved chunks via show_context.
        show_context([doc.page_content for doc in retrieved])
| 
 | |
| 
 | |
def parse_args():
    """
    Parse command-line options for the fusion retrieval script.

    Returns:
    args: Namespace with path, chunk_size, chunk_overlap, query, k and alpha.
    """
    import argparse

    # (flag, keyword arguments) table keeps the option definitions compact.
    option_table = [
        ('--path', dict(type=str, default="../data/Understanding_Climate_Change.pdf",
                        help='Path to the PDF file.')),
        ('--chunk_size', dict(type=int, default=1000, help='Size of each chunk.')),
        ('--chunk_overlap', dict(type=int, default=200,
                                 help='Overlap between consecutive chunks.')),
        ('--query', dict(type=str,
                         default='What are the impacts of climate change on the environment?',
                         help='Query to retrieve documents.')),
        ('--k', dict(type=int, default=5, help='Number of documents to retrieve.')),
        ('--alpha', dict(type=float, default=0.5, help='Weight for vector search vs. BM25.')),
    ]

    parser = argparse.ArgumentParser(description="Fusion Retrieval RAG Script")
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: build the retriever once, then answer the query.
    cli = parse_args()
    rag = FusionRetrievalRAG(
        path=cli.path,
        chunk_size=cli.chunk_size,
        chunk_overlap=cli.chunk_overlap,
    )
    rag.run(query=cli.query, k=cli.k, alpha=cli.alpha)
