RAG_Techniques/all_rag_techniques/reranking.ipynb
2024-08-30 13:57:03 +03:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reranking Methods in RAG Systems\n",
"\n",
"## Overview\n",
"Reranking is a crucial step in Retrieval-Augmented Generation (RAG) systems that aims to improve the relevance and quality of retrieved documents. It involves reassessing and reordering initially retrieved documents to ensure that the most pertinent information is prioritized for subsequent processing or presentation.\n",
"\n",
"## Motivation\n",
"The primary motivation for reranking in RAG systems is to overcome limitations of initial retrieval methods, which often rely on simpler similarity metrics. Reranking allows for more sophisticated relevance assessment, taking into account nuanced relationships between queries and documents that might be missed by traditional retrieval techniques. This process aims to enhance the overall performance of RAG systems by ensuring that the most relevant information is used in the generation phase.\n",
"\n",
"## Key Components\n",
"Reranking systems typically include the following components:\n",
"\n",
"1. Initial Retriever: Often a vector store using embedding-based similarity search.\n",
"2. Reranking Model: This can be either:\n",
" - A Large Language Model (LLM) for scoring relevance\n",
" - A Cross-Encoder model specifically trained for relevance assessment\n",
"3. Scoring Mechanism: A method to assign relevance scores to documents\n",
"4. Sorting and Selection Logic: To reorder documents based on new scores\n",
"\n",
"## Method Details\n",
"The reranking process generally follows these steps:\n",
"\n",
"1. Initial Retrieval: Fetch an initial set of potentially relevant documents.\n",
"2. Pair Creation: Form query-document pairs for each retrieved document.\n",
"3. Scoring: \n",
" - LLM Method: Use prompts to ask the LLM to rate document relevance.\n",
" - Cross-Encoder Method: Feed query-document pairs directly into the model.\n",
"4. Score Interpretation: Parse and normalize the relevance scores.\n",
"5. Reordering: Sort documents based on their new relevance scores.\n",
"6. Selection: Choose the top K documents from the reordered list.\n",
"\n",
"## Benefits of this Approach\n",
"Reranking offers several advantages:\n",
"\n",
"1. Improved Relevance: By using more sophisticated models, reranking can capture subtle relevance factors.\n",
"2. Flexibility: Different reranking methods can be applied based on specific needs and resources.\n",
"3. Enhanced Context Quality: Providing more relevant documents to the RAG system improves the quality of generated responses.\n",
"4. Reduced Noise: Reranking helps filter out less relevant information, focusing on the most pertinent content.\n",
"\n",
"## Conclusion\n",
"Reranking is a powerful technique in RAG systems that significantly enhances the quality of retrieved information. Whether using LLM-based scoring or specialized Cross-Encoder models, reranking allows for more nuanced and accurate assessment of document relevance. This improved relevance translates directly to better performance in downstream tasks, making reranking an essential component in advanced RAG implementations.\n",
"\n",
"The choice between LLM-based and Cross-Encoder reranking methods depends on factors such as required accuracy, available computational resources, and specific application needs. Both approaches offer substantial improvements over basic retrieval methods and contribute to the overall effectiveness of RAG systems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/reranking-visualization.svg\" alt=\"rerank llm\" style=\"width:100%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/reranking_comparison.svg\" alt=\"rerank llm\" style=\"width:100%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import relevant libraries"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from langchain.docstore.document import Document\n",
"from typing import List, Dict, Any, Tuple\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain.chains import RetrievalQA\n",
"from langchain_core.retrievers import BaseRetriever\n",
"from sentence_transformers import CrossEncoder\n",
"\n",
"\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks\n",
"from helper_functions import *\n",
"from evaluation.evalute_rag import *\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"\n",
"# Set the OpenAI API key environment variable\n",
"os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the document's path"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"path = \"../data/Understanding_Climate_Change.pdf\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a vector store"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
"vectorstore = encode_pdf(path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Method 1: LLM based function to rerank the retrieved documents"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/rerank_llm.svg\" alt=\"rerank llm\" style=\"width:40%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a custom reranking function\n"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"class RatingScore(BaseModel):\n",
" relevance_score: float = Field(..., description=\"The relevance score of a document to a query.\")\n",
"\n",
"def rerank_documents(query: str, docs: List[Document], top_n: int = 3) -> List[Document]:\n",
" prompt_template = PromptTemplate(\n",
" input_variables=[\"query\", \"doc\"],\n",
" template=\"\"\"On a scale of 1-10, rate the relevance of the following document to the query. Consider the specific context and intent of the query, not just keyword matches.\n",
" Query: {query}\n",
" Document: {doc}\n",
" Relevance Score:\"\"\"\n",
" )\n",
" \n",
" llm = ChatOpenAI(temperature=0, model_name=\"gpt-4o\", max_tokens=4000)\n",
" llm_chain = prompt_template | llm.with_structured_output(RatingScore)\n",
" \n",
" scored_docs = []\n",
" for doc in docs:\n",
" input_data = {\"query\": query, \"doc\": doc.page_content}\n",
" score = llm_chain.invoke(input_data).relevance_score\n",
" try:\n",
" score = float(score)\n",
" except ValueError:\n",
" score = 0 # Default score if parsing fails\n",
" scored_docs.append((doc, score))\n",
" \n",
" reranked_docs = sorted(scored_docs, key=lambda x: x[1], reverse=True)\n",
" return [doc for doc, _ in reranked_docs[:top_n]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example usage of the reranking function with a sample query relevant to the document\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What are the impacts of climate change on biodiversity?\"\n",
"initial_docs = vectorstore.similarity_search(query, k=15)\n",
"reranked_docs = rerank_documents(query, initial_docs)\n",
"\n",
"# print first 3 initial documents\n",
"print(\"Top initial documents:\")\n",
"for i, doc in enumerate(initial_docs[:3]):\n",
" print(f\"\\nDocument {i+1}:\")\n",
" print(doc.page_content[:200] + \"...\") # Print first 200 characters of each document\n",
"\n",
"\n",
"# Print results\n",
"print(f\"Query: {query}\\n\")\n",
"print(\"Top reranked documents:\")\n",
"for i, doc in enumerate(reranked_docs):\n",
" print(f\"\\nDocument {i+1}:\")\n",
" print(doc.page_content[:200] + \"...\") # Print first 200 characters of each document"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a custom retriever based on our reranker"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
"# Create a custom retriever class\n",
"class CustomRetriever(BaseRetriever, BaseModel):\n",
" \n",
" vectorstore: Any = Field(description=\"Vector store for initial retrieval\")\n",
"\n",
" class Config:\n",
" arbitrary_types_allowed = True\n",
"\n",
" def get_relevant_documents(self, query: str, num_docs=2) -> List[Document]:\n",
" initial_docs = self.vectorstore.similarity_search(query, k=30)\n",
" return rerank_documents(query, initial_docs, top_n=num_docs)\n",
"\n",
"\n",
"# Create the custom retriever\n",
"custom_retriever = CustomRetriever(vectorstore=vectorstore)\n",
"\n",
"# Create an LLM for answering questions\n",
"llm = ChatOpenAI(temperature=0, model_name=\"gpt-4o\")\n",
"\n",
"# Create the RetrievalQA chain with the custom retriever\n",
"qa_chain = RetrievalQA.from_chain_type(\n",
" llm=llm,\n",
" chain_type=\"stuff\",\n",
" retriever=custom_retriever,\n",
" return_source_documents=True\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example query\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result = qa_chain({\"query\": query})\n",
"\n",
"print(f\"\\nQuestion: {query}\")\n",
"print(f\"Answer: {result['result']}\")\n",
"print(\"\\nRelevant source documents:\")\n",
"for i, doc in enumerate(result[\"source_documents\"]):\n",
" print(f\"\\nDocument {i+1}:\")\n",
" print(doc.page_content[:200] + \"...\") # Print first 200 characters of each document"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example that demonstrates why we should use reranking "
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Comparison of Retrieval Techniques\n",
"==================================\n",
"Query: what is the capital of france?\n",
"\n",
"Baseline Retrieval Result:\n",
"\n",
"Document 1:\n",
"The capital of France is great.\n",
"\n",
"Document 2:\n",
"The capital of France is beautiful.\n",
"\n",
"Advanced Retrieval Result:\n",
"\n",
"Document 1:\n",
"I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. Such a great capital city.\n",
"\n",
"Document 2:\n",
"Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower. \n",
" I really enjoyed all the cities in france, but its capital with the Eiffel Tower is my favorite city.\n"
]
}
],
"source": [
"chunks = [\n",
" \"The capital of France is great.\",\n",
" \"The capital of France is huge.\",\n",
" \"The capital of France is beautiful.\",\n",
" \"\"\"Have you ever visited Paris? It is a beautiful city where you can eat delicious food and see the Eiffel Tower. \n",
" I really enjoyed all the cities in france, but its capital with the Eiffel Tower is my favorite city.\"\"\", \n",
" \"I really enjoyed my trip to Paris, France. The city is beautiful and the food is delicious. I would love to visit again. Such a great capital city.\"\n",
"]\n",
"docs = [Document(page_content=sentence) for sentence in chunks]\n",
"\n",
"\n",
"def compare_rag_techniques(query: str, docs: List[Document] = docs) -> None:\n",
" embeddings = OpenAIEmbeddings()\n",
" vectorstore = FAISS.from_documents(docs, embeddings)\n",
"\n",
" print(\"Comparison of Retrieval Techniques\")\n",
" print(\"==================================\")\n",
" print(f\"Query: {query}\\n\")\n",
" \n",
" print(\"Baseline Retrieval Result:\")\n",
" baseline_docs = vectorstore.similarity_search(query, k=2)\n",
" for i, doc in enumerate(baseline_docs):\n",
" print(f\"\\nDocument {i+1}:\")\n",
" print(doc.page_content)\n",
"\n",
" print(\"\\nAdvanced Retrieval Result:\")\n",
" custom_retriever = CustomRetriever(vectorstore=vectorstore)\n",
" advanced_docs = custom_retriever.get_relevant_documents(query)\n",
" for i, doc in enumerate(advanced_docs):\n",
" print(f\"\\nDocument {i+1}:\")\n",
" print(doc.page_content)\n",
"\n",
"\n",
"query = \"what is the capital of france?\"\n",
"compare_rag_techniques(query, docs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Method 2: Cross Encoder models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/rerank_cross_encoder.svg\" alt=\"rerank cross encoder\" style=\"width:40%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the cross encoder class"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')\n",
"\n",
"class CrossEncoderRetriever(BaseRetriever, BaseModel):\n",
" vectorstore: Any = Field(description=\"Vector store for initial retrieval\")\n",
" cross_encoder: Any = Field(description=\"Cross-encoder model for reranking\")\n",
" k: int = Field(default=5, description=\"Number of documents to retrieve initially\")\n",
" rerank_top_k: int = Field(default=3, description=\"Number of documents to return after reranking\")\n",
"\n",
" class Config:\n",
" arbitrary_types_allowed = True\n",
"\n",
" def get_relevant_documents(self, query: str) -> List[Document]:\n",
" # Initial retrieval\n",
" initial_docs = self.vectorstore.similarity_search(query, k=self.k)\n",
" \n",
" # Prepare pairs for cross-encoder\n",
" pairs = [[query, doc.page_content] for doc in initial_docs]\n",
" \n",
" # Get cross-encoder scores\n",
" scores = self.cross_encoder.predict(pairs)\n",
" \n",
" # Sort documents by score\n",
" scored_docs = sorted(zip(initial_docs, scores), key=lambda x: x[1], reverse=True)\n",
" \n",
" # Return top reranked documents\n",
" return [doc for doc, _ in scored_docs[:self.rerank_top_k]]\n",
"\n",
" async def aget_relevant_documents(self, query: str) -> List[Document]:\n",
" raise NotImplementedError(\"Async retrieval not implemented\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create an instance and showcase over an example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create the cross-encoder retriever\n",
"cross_encoder_retriever = CrossEncoderRetriever(\n",
" vectorstore=vectorstore,\n",
" cross_encoder=cross_encoder,\n",
" k=10, # Retrieve 10 documents initially\n",
" rerank_top_k=5 # Return top 5 after reranking\n",
")\n",
"\n",
"# Set up the LLM\n",
"llm = ChatOpenAI(temperature=0, model_name=\"gpt-4o\")\n",
"\n",
"# Create the RetrievalQA chain with the cross-encoder retriever\n",
"qa_chain = RetrievalQA.from_chain_type(\n",
" llm=llm,\n",
" chain_type=\"stuff\",\n",
" retriever=cross_encoder_retriever,\n",
" return_source_documents=True\n",
")\n",
"\n",
"# Example query\n",
"query = \"What are the impacts of climate change on biodiversity?\"\n",
"result = qa_chain({\"query\": query})\n",
"\n",
"print(f\"\\nQuestion: {query}\")\n",
"print(f\"Answer: {result['result']}\")\n",
"print(\"\\nRelevant source documents:\")\n",
"for i, doc in enumerate(result[\"source_documents\"]):\n",
" print(f\"\\nDocument {i+1}:\")\n",
" print(doc.page_content[:200] + \"...\") # Print first 200 characters of each document"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}