{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fusion Retrieval in Document Search\n",
"\n",
"## Overview\n",
"\n",
"This code implements a Fusion Retrieval system that combines vector-based similarity search with keyword-based BM25 retrieval. The approach aims to leverage the strengths of both methods to improve the overall quality and relevance of document retrieval.\n",
"\n",
"## Motivation\n",
"\n",
"Retrieval methods typically rely on either semantic understanding (vector-based search) or keyword matching (BM25), and each has blind spots: vector search captures paraphrases and related concepts but can miss exact terms such as names, acronyms, or rare jargon, while BM25 rewards exact term overlap but cannot recognize synonyms. Fusion retrieval combines the two methods to create a more robust and accurate retrieval system that handles a wider range of queries.\n",
"\n",
"## Key Components\n",
"\n",
"1. PDF processing and text chunking\n",
"2. Vector store creation using FAISS and OpenAI embeddings\n",
"3. BM25 index creation for keyword-based retrieval\n",
"4. Custom fusion retrieval function that combines both methods\n",
"\n",
"## Method Details\n",
"\n",
"### Document Preprocessing\n",
"\n",
"1. The PDF is loaded and split into chunks using RecursiveCharacterTextSplitter.\n",
"2. Chunks are cleaned by replacing tab characters (`\\t`) with spaces (likely addressing a formatting issue in the PDF text extraction).\n",
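"\n",
"The chunking arithmetic can be illustrated with a simplified sketch. It is a stand-in under stated assumptions, not `RecursiveCharacterTextSplitter` (which first tries to split on separators such as paragraphs, lines, and words): it slides a fixed window of `chunk_size` characters forward by `chunk_size - chunk_overlap` characters and replaces tabs with spaces. The helper names `sliding_window_chunks` and `clean_tabs` are hypothetical.\n",
"\n",
"```python\n",
"def clean_tabs(text: str) -> str:\n",
"    # Hypothetical stand-in for the cleaning step: replace tab characters with spaces\n",
"    return text.replace('\\t', ' ')\n",
"\n",
"\n",
"def sliding_window_chunks(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> list[str]:\n",
"    # Simplified fixed-window chunker (NOT the recursive splitter used in this notebook)\n",
"    if not 0 <= chunk_overlap < chunk_size:\n",
"        raise ValueError('chunk_overlap must be in [0, chunk_size)')\n",
"    step = chunk_size - chunk_overlap  # each chunk starts `step` characters after the previous one\n",
"    chunks = []\n",
"    for start in range(0, len(text), step):\n",
"        chunks.append(clean_tabs(text[start:start + chunk_size]))\n",
"        if start + chunk_size >= len(text):\n",
"            break  # this window already reaches the end of the text\n",
"    return chunks\n",
"\n",
"\n",
"chunks = sliding_window_chunks('a\\tb' * 1000)  # 3000 characters of toy text\n",
"print(len(chunks), [len(c) for c in chunks])\n",
"```\n",
"\n",
"With the defaults used here, consecutive chunks start 800 characters apart and share 200 characters.\n",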
"\n",
"### Vector Store Creation\n",
"\n",
"1. OpenAI embeddings are used to create vector representations of the text chunks.\n",
"2. A FAISS vector store is created from these embeddings for efficient similarity search.\n",
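"\n",
"A brute-force NumPy sketch of what an exact (flat) L2 index computes. It assumes the LangChain `FAISS` wrapper is using its default Euclidean distance strategy, in which case `similarity_search_with_score` returns distances where lower means more similar; this is why `fusion_retrieval` inverts the normalized vector scores. The toy 3-dimensional vectors are made up and stand in for real OpenAI embeddings.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Toy embeddings for four chunks (real OpenAI embeddings have far more dimensions)\n",
"doc_vectors = np.array([\n",
"    [0.9, 0.1, 0.0],\n",
"    [0.1, 0.9, 0.0],\n",
"    [0.0, 0.1, 0.9],\n",
"    [0.7, 0.3, 0.0],\n",
"], dtype=np.float32)\n",
"query_vector = np.array([1.0, 0.0, 0.0], dtype=np.float32)\n",
"\n",
"# Squared Euclidean distance from the query to every chunk, as a flat L2 index computes it\n",
"distances = np.sum((doc_vectors - query_vector) ** 2, axis=1)\n",
"\n",
"# Exact nearest-neighbour search: smallest distance first\n",
"ranking = np.argsort(distances)\n",
"print(distances.round(2), ranking)\n",
"```\n",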
"\n",
"### BM25 Index Creation\n",
"\n",
"1. A BM25 index is created from the same text chunks used for the vector store.\n",
"2. This allows for keyword-based retrieval alongside the vector-based method.\n",
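"\n",
"A toy example of the `rank_bm25` calls used in this notebook: `BM25Okapi` is built from pre-tokenized documents, and `get_scores` returns one score per document, in corpus order. The four-sentence corpus is made up for illustration.\n",
"\n",
"```python\n",
"from rank_bm25 import BM25Okapi\n",
"\n",
"corpus = [\n",
"    'rising sea levels threaten coastal cities',\n",
"    'heatwaves stress crops and livestock',\n",
"    'coral reefs bleach as oceans warm',\n",
"    'flooding follows extreme rainfall events',\n",
"]\n",
"# Same whitespace tokenization as create_bm25_index\n",
"toy_bm25 = BM25Okapi([doc.split() for doc in corpus])\n",
"\n",
"# One score per corpus entry; entries containing none of the query terms score 0\n",
"scores = toy_bm25.get_scores('coral reefs'.split())\n",
"best = int(scores.argmax())\n",
"print(scores, corpus[best])\n",
"```\n",
"\n",
"Because the scores are positional, they can only be combined with vector scores that are listed in the same document order.\n",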
"\n",
"### Fusion Retrieval Function\n",
"\n",
"The `fusion_retrieval` function is the core of this implementation:\n",
"\n",
"1. It takes a query and performs both vector-based and BM25-based retrieval.\n",
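"2. Scores from both methods are min-max normalized to the range [0, 1]; vector distances are inverted so that, for both methods, a higher score means a more relevant document.\n",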
"3. A weighted combination of these scores is computed (controlled by the `alpha` parameter).\n",
"4. Documents are ranked based on the combined scores, and the top-k results are returned.\n",
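"\n",
"The normalization, weighting, and ranking steps above can be sketched on plain arrays. This minimal NumPy version mirrors the arithmetic in `fusion_retrieval` (min-max normalization, inverting the vector distances, weighting by `alpha`, sorting) on made-up scores for four documents; `fuse_scores` is a hypothetical helper name, and both input arrays are assumed to be indexed by the same document order.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def fuse_scores(vector_distances, bm25_scores, alpha=0.5, k=2, epsilon=1e-8):\n",
"    # Both inputs must list the documents in the same order\n",
"    d = np.asarray(vector_distances, dtype=float)\n",
"    b = np.asarray(bm25_scores, dtype=float)\n",
"    # Min-max normalize distances to [0, 1], then invert so that higher = more relevant\n",
"    v = 1 - (d - d.min()) / (d.max() - d.min() + epsilon)\n",
"    # Min-max normalize BM25 scores to [0, 1] (already higher = more relevant)\n",
"    b = (b - b.min()) / (b.max() - b.min() + epsilon)\n",
"    # Weighted combination: alpha for vector search, 1 - alpha for BM25\n",
"    combined = alpha * v + (1 - alpha) * b\n",
"    # Indices of the k highest combined scores, best first\n",
"    return np.argsort(combined)[::-1][:k]\n",
"\n",
"\n",
"distances = [0.2, 0.5, 0.9, 0.4]  # vector search: document 0 is closest\n",
"bm25_raw = [0.0, 3.0, 1.0, 0.5]   # BM25: document 1 matches the keywords best\n",
"for alpha in (0.0, 0.5, 1.0):\n",
"    print(alpha, fuse_scores(distances, bm25_raw, alpha=alpha, k=2))\n",
"```\n",
"\n",
"Setting `alpha=0.0` reduces the ranking to pure BM25 and `alpha=1.0` to pure vector search; intermediate values trade one signal off against the other.\n",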
"\n",
"## Benefits of this Approach\n",
"\n",
"1. Improved Retrieval Quality: By combining semantic and keyword-based search, the system can capture both conceptual similarity and exact keyword matches.\n",
"2. Flexibility: The `alpha` parameter allows for adjusting the balance between vector and keyword search based on specific use cases or query types.\n",
"3. Robustness: The combined approach can handle a wider range of queries effectively, mitigating weaknesses of individual methods.\n",
"4. Customizability: The system can be easily adapted to use different vector stores or keyword-based retrieval methods.\n",
"\n",
"## Conclusion\n",
"\n",
"Fusion retrieval represents a powerful approach to document search that combines the strengths of semantic understanding and keyword matching. By leveraging both vector-based and BM25 retrieval methods, it offers a more comprehensive and flexible solution for information retrieval tasks. This approach has potential applications in various fields where both conceptual similarity and keyword relevance are important, such as academic research, legal document search, or general-purpose search engines."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/fusion_retrieval.svg\" alt=\"Fusion Retrieval\" style=\"width:100%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import libraries "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from langchain.docstore.document import Document\n",
"\n",
"from typing import List\n",
"from rank_bm25 import BM25Okapi\n",
"import numpy as np\n",
"\n",
"\n",
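"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks\n",
"# The star imports below are expected to provide PyPDFLoader, RecursiveCharacterTextSplitter, OpenAIEmbeddings,\n",
"# FAISS, replace_t_with_space and show_context, which are used later without explicit imports\n",
"from helper_functions import *\n",
"from evaluation.evalute_rag import *\n",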
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"\n",
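"# load_dotenv() copies OPENAI_API_KEY from .env into os.environ; fail early with a clear message if it is missing\n",
"if not os.getenv('OPENAI_API_KEY'):\n",
"    raise EnvironmentError('OPENAI_API_KEY is not set; add it to your .env file or environment')"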
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define document path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = \"../data/Understanding_Climate_Change.pdf\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Encode the PDF into a vector store and return the split documents, which are reused to build the BM25 index"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def encode_pdf_and_get_split_documents(path, chunk_size=1000, chunk_overlap=200):\n",
"    \"\"\"\n",
"    Encodes a PDF into a FAISS vector store using OpenAI embeddings and also returns the cleaned chunks.\n",
"\n",
"    Args:\n",
"        path: The path to the PDF file.\n",
"        chunk_size: The desired size of each text chunk, in characters.\n",
"        chunk_overlap: The number of characters shared by consecutive chunks.\n",
"\n",
"    Returns:\n",
"        A tuple (vectorstore, cleaned_texts): the FAISS vector store containing the encoded chunks,\n",
"        and the list of cleaned chunk Documents, which can be reused to build the BM25 index.\n",
"    \"\"\"\n",
"\n",
"    # Load PDF documents\n",
"    loader = PyPDFLoader(path)\n",
"    documents = loader.load()\n",
"\n",
"    # Split documents into chunks\n",
"    text_splitter = RecursiveCharacterTextSplitter(\n",
"        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len\n",
"    )\n",
"    texts = text_splitter.split_documents(documents)\n",
"    # Replace tab characters with spaces (helper from helper_functions)\n",
"    cleaned_texts = replace_t_with_space(texts)\n",
"\n",
"    # Create embeddings and vector store\n",
"    embeddings = OpenAIEmbeddings()\n",
"    vectorstore = FAISS.from_documents(cleaned_texts, embeddings)\n",
"\n",
"    return vectorstore, cleaned_texts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create vectorstore and get the chunked documents"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"vectorstore, cleaned_texts = encode_pdf_and_get_split_documents(path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a BM25 index for retrieving documents by keyword"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def create_bm25_index(documents: List[Document]) -> BM25Okapi:\n",
"    \"\"\"\n",
"    Create a BM25 index from the given documents.\n",
"\n",
"    BM25 (Best Matching 25) is a ranking function from the probabilistic retrieval framework.\n",
"    It refines TF-IDF-style weighting with term-frequency saturation and document-length normalization.\n",
"\n",
"    Args:\n",
"        documents (List[Document]): List of documents to index.\n",
"\n",
"    Returns:\n",
"        BM25Okapi: An index that can be used for BM25 scoring.\n",
"    \"\"\"\n",
"    # Tokenize each document by splitting on whitespace\n",
"    # This is a simple approach and could be improved with more sophisticated tokenization\n",
"    tokenized_docs = [doc.page_content.split() for doc in documents]\n",
"    return BM25Okapi(tokenized_docs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"bm25 = create_bm25_index(cleaned_texts) # Create BM25 index from the cleaned texts (chunks)"
]
},
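{
"cell_type": "markdown",
"metadata": {},
"source": [
"The whitespace tokenizer in `create_bm25_index` is case-sensitive and leaves punctuation attached, so `Climate` and `climate.` are counted as different tokens from `climate`. A minimal sketch of a slightly more robust tokenizer (lowercasing, then extracting alphanumeric runs with a regular expression) follows; `simple_tokenize` is a hypothetical helper, and the same function would have to be applied to both the chunks and the query (which `fusion_retrieval` tokenizes with `query.split()`).\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def simple_tokenize(text: str) -> list[str]:\n",
"    # Lowercase, then keep runs of ASCII letters and digits; everything else separates tokens\n",
"    return re.findall(r'[a-z0-9]+', text.lower())\n",
"\n",
"\n",
"sample = 'Climate change, climate-change and CLIMATE.'\n",
"print(sample.split())\n",
"print(simple_tokenize(sample))\n",
"\n",
"# Hypothetical wiring into this notebook:\n",
"# bm25 = BM25Okapi([simple_tokenize(doc.page_content) for doc in cleaned_texts])\n",
"# bm25.get_scores(simple_tokenize(query))\n",
"```"
]
},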
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a function that retrieves both semantically and by keyword, normalizes the scores, and returns the top-k documents"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def fusion_retrieval(vectorstore, bm25, query: str, k: int = 5, alpha: float = 0.5) -> List[Document]:\n",
"    \"\"\"\n",
"    Perform fusion retrieval combining keyword-based (BM25) and vector-based search.\n",
"\n",
"    Args:\n",
"        vectorstore (VectorStore): The FAISS vectorstore containing the documents.\n",
"        bm25 (BM25Okapi): Pre-computed BM25 index built from the same chunks, in the same order, as the vectorstore.\n",
"        query (str): The query string.\n",
"        k (int): The number of documents to retrieve.\n",
"        alpha (float): The weight for vector search scores (1-alpha will be the weight for BM25 scores).\n",
"\n",
"    Returns:\n",
"        List[Document]: The top k documents based on the combined scores.\n",
"    \"\"\"\n",
"    epsilon = 1e-8\n",
"\n",
"    # Step 1: Get all documents in FAISS index order. Position i holds the i-th chunk passed to\n",
"    # FAISS.from_documents, which is also the i-th chunk indexed by create_bm25_index\n",
"    n_docs = vectorstore.index.ntotal\n",
"    all_docs = [vectorstore.docstore.search(vectorstore.index_to_docstore_id[i]) for i in range(n_docs)]\n",
"\n",
"    # Step 2: Perform BM25 search (one score per document, in the same order as all_docs)\n",
"    bm25_scores = bm25.get_scores(query.split())\n",
"\n",
"    # Step 3: Perform vector search over all documents. Results come back sorted by distance,\n",
"    # so map each distance back to its document's position (identical texts share a distance)\n",
"    vector_results = vectorstore.similarity_search_with_score(query, k=n_docs)\n",
"    distance_by_text = {doc.page_content: score for doc, score in vector_results}\n",
"    vector_scores = np.array([distance_by_text[doc.page_content] for doc in all_docs])\n",
"\n",
"    # Step 4: Normalize scores to [0, 1]. With the default Euclidean strategy FAISS returns\n",
"    # distances (lower is better), so the normalized distances are inverted\n",
"    vector_scores = 1 - (vector_scores - np.min(vector_scores)) / (np.max(vector_scores) - np.min(vector_scores) + epsilon)\n",
"    bm25_scores = (bm25_scores - np.min(bm25_scores)) / (np.max(bm25_scores) - np.min(bm25_scores) + epsilon)\n",
"\n",
"    # Step 5: Combine scores\n",
"    combined_scores = alpha * vector_scores + (1 - alpha) * bm25_scores\n",
"\n",
"    # Step 6: Rank documents (highest combined score first)\n",
"    sorted_indices = np.argsort(combined_scores)[::-1]\n",
"\n",
"    # Step 7: Return top k documents\n",
"    return [all_docs[i] for i in sorted_indices[:k]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use case example"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Query\n",
"query = \"What are the impacts of climate change on the environment?\"\n",
"\n",
"# Perform fusion retrieval\n",
"top_docs = fusion_retrieval(vectorstore, bm25, query, k=5, alpha=0.5)\n",
"docs_content = [doc.page_content for doc in top_docs]\n",
"show_context(docs_content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}