{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Corrective RAG Process: Retrieval-Augmented Generation with Dynamic Correction\n",
"\n",
"## Overview\n",
"\n",
"The Corrective RAG (Retrieval-Augmented Generation) process is an advanced information retrieval and response generation system. It extends the standard RAG approach by dynamically evaluating and correcting the retrieval process, combining the power of vector databases, web search, and language models to provide accurate and context-aware responses to user queries.\n",
"\n",
"## Motivation\n",
"\n",
"While traditional RAG systems have improved information retrieval and response generation, they can still fall short when the retrieved information is irrelevant or outdated. The Corrective RAG process addresses these limitations by:\n",
"\n",
"1. Leveraging pre-existing knowledge bases\n",
"2. Evaluating the relevance of retrieved information\n",
"3. Dynamically searching the web when necessary\n",
"4. Refining and combining knowledge from multiple sources\n",
"5. Generating human-like responses based on the most appropriate knowledge\n",
"\n",
"## Key Components\n",
"\n",
"1. **FAISS Index**: A vector database for efficient similarity search of pre-existing knowledge.\n",
"2. **Retrieval Evaluator**: Assesses the relevance of retrieved documents to the query.\n",
"3. **Knowledge Refinement**: Extracts key information from documents when necessary.\n",
"4. **Web Search Query Rewriter**: Optimizes queries for web searches when local knowledge is insufficient.\n",
"5. **Response Generator**: Creates human-like responses based on the accumulated knowledge.\n",
"\n",
"## Method Details\n",
"\n",
"1. **Document Retrieval**: \n",
" - Performs similarity search in the FAISS index to find relevant documents.\n",
" - Retrieves top-k documents (default k=3).\n",
"\n",
"2. **Document Evaluation**:\n",
" - Calculates relevance scores for each retrieved document.\n",
" - Determines the best course of action based on the highest relevance score.\n",
"\n",
"3. **Corrective Knowledge Acquisition**:\n",
" - If high relevance (score > 0.7): Uses the most relevant document as-is.\n",
" - If low relevance (score < 0.3): Corrects by performing a web search with a rewritten query.\n",
" - If ambiguous (0.3 ≤ score ≤ 0.7): Corrects by combining the most relevant document with web search results.\n",
"\n",
"4. **Adaptive Knowledge Processing**:\n",
" - For web search results: Refines the knowledge to extract key points.\n",
" - For ambiguous cases: Combines raw document content with refined web search results.\n",
"\n",
"5. **Response Generation**:\n",
" - Uses a language model to generate a human-like response based on the query and acquired knowledge.\n",
" - Includes source information in the response for transparency.\n",
"\n",
"## Benefits of the Corrective RAG Approach\n",
"\n",
"1. **Dynamic Correction**: Adapts to the quality of retrieved information, ensuring relevance and accuracy.\n",
"2. **Flexibility**: Leverages both pre-existing knowledge and web search as needed.\n",
"3. **Accuracy**: Evaluates the relevance of information before using it, ensuring high-quality responses.\n",
"4. **Transparency**: Provides source information, allowing users to verify the origin of the information.\n",
"5. **Efficiency**: Uses vector search for quick retrieval from large knowledge bases.\n",
"6. **Contextual Understanding**: Combines multiple sources of information when necessary to provide comprehensive responses.\n",
"7. **Up-to-date Information**: Can supplement or replace outdated local knowledge with current web information.\n",
"\n",
"## Conclusion\n",
"\n",
"The Corrective RAG process represents a sophisticated evolution of the standard RAG approach. By intelligently evaluating and correcting the retrieval process, it overcomes common limitations of traditional RAG systems. This dynamic approach ensures that responses are based on the most relevant and up-to-date information available, whether from local knowledge bases or the web. The system's ability to adapt its information sourcing strategy based on relevance scores makes it particularly suited for applications requiring high accuracy and current information, such as research assistance, dynamic knowledge bases, and advanced question-answering systems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/crag.svg\" alt=\"Corrective RAG\" style=\"width:80%; height:auto;\">\n",
"</div>"
]
},
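{
"cell_type": "markdown",
"metadata": {},
"source": [
"The corrective decision rule described in step 3 of the Method Details above can be sketched as follows. This is a minimal illustration of the thresholds only, not the full implementation used later in this notebook:\n",
"\n",
"```python\n",
"def decide_action(scores: list[float], high: float = 0.7, low: float = 0.3) -> str:\n",
"    \"\"\"Map the best relevance score to a corrective action.\"\"\"\n",
"    best = max(scores)\n",
"    if best > high:\n",
"        return \"use_retrieved_document\"  # high relevance: use the document as-is\n",
"    if best < low:\n",
"        return \"web_search\"  # low relevance: correct with a rewritten web search\n",
"    return \"combine_document_and_web\"  # ambiguous: combine both sources\n",
"```"
]
},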
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import relevant libraries"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import json  # used by the search-result parsing helper below\n",
"from typing import List, Tuple  # used by the CRAG helper functions below\n",
"\n",
"from dotenv import load_dotenv\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.tools import DuckDuckGoSearchResults\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"# Add the parent directory to the path since we work with notebooks\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))\n",
"from helper_functions import *\n",
"from evaluation.evalute_rag import *\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"\n",
"# Set the OpenAI API key environment variable\n",
"os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the file path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = \"../data/Understanding_Climate_Change.pdf\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a vector store"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"vectorstore = encode_pdf(path)"
]
},
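{
"cell_type": "markdown",
"metadata": {},
"source": [
"`encode_pdf` is a helper imported from `helper_functions`. As a rough sketch, a helper like this typically loads the PDF, splits it into chunks, embeds the chunks, and indexes them in FAISS. The loader, splitter, and parameters below are illustrative assumptions, not necessarily what `helper_functions.encode_pdf` actually uses:\n",
"\n",
"```python\n",
"from langchain_community.document_loaders import PyPDFLoader\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"from langchain_openai import OpenAIEmbeddings\n",
"from langchain_community.vectorstores import FAISS\n",
"\n",
"def encode_pdf_sketch(path: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> FAISS:\n",
"    # Load the PDF and split it into overlapping chunks\n",
"    chunks = RecursiveCharacterTextSplitter(\n",
"        chunk_size=chunk_size, chunk_overlap=chunk_overlap\n",
"    ).split_documents(PyPDFLoader(path).load())\n",
"    # Embed the chunks and build a FAISS index for similarity search\n",
"    return FAISS.from_documents(chunks, OpenAIEmbeddings())\n",
"```"
]
},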
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize OpenAI language model\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model=\"gpt-4o-mini\", max_tokens=1000, temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize search tool"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"search = DuckDuckGoSearchResults()"
]
},
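{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Note:* depending on the installed `langchain`/`duckduckgo-search` versions, `DuckDuckGoSearchResults` may return its results as a plain formatted string rather than JSON. In that case the `parse_search_results` helper defined below will fall back to an empty source list. A quick, illustrative way to inspect the raw output format:\n",
"\n",
"```python\n",
"# Illustrative sanity check of the tool's raw output format\n",
"print(search.run(\"corrective retrieval augmented generation\"))\n",
"```"
]
},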
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the retrieval evaluator, knowledge refinement and query rewriter LLM chains"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Retrieval Evaluator\n",
"class RetrievalEvaluatorInput(BaseModel):\n",
" relevance_score: float = Field(..., description=\"The relevance score of the document to the query. the score should be between 0 and 1.\")\n",
"def retrieval_evaluator(query: str, document: str) -> float:\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"query\", \"document\"],\n",
" template=\"On a scale from 0 to 1, how relevant is the following document to the query? Query: {query}\\nDocument: {document}\\nRelevance score:\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(RetrievalEvaluatorInput)\n",
" input_variables = {\"query\": query, \"document\": document}\n",
" result = chain.invoke(input_variables).relevance_score\n",
" return result\n",
"\n",
"# Knowledge Refinement\n",
"class KnowledgeRefinementInput(BaseModel):\n",
" key_points: str = Field(..., description=\"The document to extract key information from.\")\n",
"def knowledge_refinement(document: str) -> List[str]:\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"document\"],\n",
" template=\"Extract the key information from the following document in bullet points:\\n{document}\\nKey points:\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(KnowledgeRefinementInput)\n",
" input_variables = {\"document\": document}\n",
" result = chain.invoke(input_variables).key_points\n",
" return [point.strip() for point in result.split('\\n') if point.strip()]\n",
"\n",
"# Web Search Query Rewriter\n",
"class QueryRewriterInput(BaseModel):\n",
" query: str = Field(..., description=\"The query to rewrite.\")\n",
"def rewrite_query(query: str) -> str:\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"query\"],\n",
" template=\"Rewrite the following query to make it more suitable for a web search:\\n{query}\\nRewritten query:\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(QueryRewriterInput)\n",
" input_variables = {\"query\": query}\n",
" return chain.invoke(input_variables).query.strip()"
]
},
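{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustrative use of the three chains defined above (the query and document strings are made up, and running this requires a valid `OPENAI_API_KEY`):\n",
"\n",
"```python\n",
"# Score a toy document, refine a toy document into bullet points, and rewrite a query for web search\n",
"score = retrieval_evaluator(\"What causes climate change?\", \"Greenhouse gases trap heat in the atmosphere.\")\n",
"points = knowledge_refinement(\"Greenhouse gases from burning fossil fuels trap heat and warm the planet.\")\n",
"better_query = rewrite_query(\"What causes climate change?\")\n",
"print(score, points, better_query, sep=\"\\n\")\n",
"```"
]
},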
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper function to parse search results\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def parse_search_results(results_string: str) -> List[Tuple[str, str]]:\n",
" \"\"\"\n",
" Parse a JSON string of search results into a list of title-link tuples.\n",
"\n",
" Args:\n",
" results_string (str): A JSON-formatted string containing search results.\n",
"\n",
" Returns:\n",
" List[Tuple[str, str]]: A list of tuples, where each tuple contains the title and link of a search result.\n",
" If parsing fails, an empty list is returned.\n",
" \"\"\"\n",
" try:\n",
" # Attempt to parse the JSON string\n",
" results = json.loads(results_string)\n",
" # Extract and return the title and link from each result\n",
" return [(result.get('title', 'Untitled'), result.get('link', '')) for result in results]\n",
" except json.JSONDecodeError:\n",
" # Handle JSON decoding errors by returning an empty list\n",
" print(\"Error parsing search results. Returning empty list.\")\n",
" return []"
]
},
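{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a JSON-formatted results string (illustrative input) is parsed into `(title, link)` tuples, with fallbacks for missing fields:\n",
"\n",
"```python\n",
"sample = '[{\"title\": \"Example result\", \"link\": \"https://example.com\"}, {\"snippet\": \"no title or link\"}]'\n",
"print(parse_search_results(sample))\n",
"# [('Example result', 'https://example.com'), ('Untitled', '')]\n",
"```"
]
},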
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define sub-functions for the CRAG process"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def retrieve_documents(query: str, faiss_index: FAISS, k: int = 3) -> List[str]:\n",
" \"\"\"\n",
" Retrieve documents based on a query using a FAISS index.\n",
"\n",
" Args:\n",
" query (str): The query string to search for.\n",
" faiss_index (FAISS): The FAISS index used for similarity search.\n",
" k (int): The number of top documents to retrieve. Defaults to 3.\n",
"\n",
" Returns:\n",
" List[str]: A list of the retrieved document contents.\n",
" \"\"\"\n",
" docs = faiss_index.similarity_search(query, k=k)\n",
" return [doc.page_content for doc in docs]\n",
"\n",
"def evaluate_documents(query: str, documents: List[str]) -> List[float]:\n",
" \"\"\"\n",
" Evaluate the relevance of documents based on a query.\n",
"\n",
" Args:\n",
" query (str): The query string.\n",
" documents (List[str]): A list of document contents to evaluate.\n",
"\n",
" Returns:\n",
" List[float]: A list of relevance scores for each document.\n",
" \"\"\"\n",
" return [retrieval_evaluator(query, doc) for doc in documents]\n",
"\n",
"def perform_web_search(query: str) -> Tuple[List[str], List[Tuple[str, str]]]:\n",
" \"\"\"\n",
" Perform a web search based on a query.\n",
"\n",
" Args:\n",
" query (str): The query string to search for.\n",
"\n",
" Returns:\n",
" Tuple[List[str], List[Tuple[str, str]]]: \n",
" - A list of refined knowledge obtained from the web search.\n",
" - A list of tuples containing titles and links of the sources.\n",
" \"\"\"\n",
" rewritten_query = rewrite_query(query)\n",
" web_results = search.run(rewritten_query)\n",
" web_knowledge = knowledge_refinement(web_results)\n",
" sources = parse_search_results(web_results)\n",
" return web_knowledge, sources\n",
"\n",
"def generate_response(query: str, knowledge: str, sources: List[Tuple[str, str]]) -> str:\n",
" \"\"\"\n",
" Generate a response to a query using knowledge and sources.\n",
"\n",
" Args:\n",
" query (str): The query string.\n",
" knowledge (str): The refined knowledge to use in the response.\n",
" sources (List[Tuple[str, str]]): A list of tuples containing titles and links of the sources.\n",
"\n",
" Returns:\n",
" str: The generated response.\n",
" \"\"\"\n",
" response_prompt = PromptTemplate(\n",
" input_variables=[\"query\", \"knowledge\", \"sources\"],\n",
" template=\"Based on the following knowledge, answer the query. Include the sources with their links (if available) at the end of your answer:\\nQuery: {query}\\nKnowledge: {knowledge}\\nSources: {sources}\\nAnswer:\"\n",
" )\n",
" input_variables = {\n",
" \"query\": query,\n",
" \"knowledge\": knowledge,\n",
" \"sources\": \"\\n\".join([f\"{title}: {link}\" if link else title for title, link in sources])\n",
" }\n",
" response_chain = response_prompt | llm\n",
" return response_chain.invoke(input_variables).content\n"
]
},
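{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration (assuming the vector store and API key set up above), the retrieval and evaluation steps can be exercised on their own before running the full CRAG loop:\n",
"\n",
"```python\n",
"docs = retrieve_documents(\"What are the main causes of climate change?\", vectorstore, k=2)\n",
"scores = evaluate_documents(\"What are the main causes of climate change?\", docs)\n",
"print(scores)  # e.g. two relevance scores between 0 and 1\n",
"```"
]
},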
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CRAG process\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def crag_process(query: str, faiss_index: FAISS) -> str:\n",
" \"\"\"\n",
" Process a query by retrieving, evaluating, and using documents or performing a web search to generate a response.\n",
"\n",
" Args:\n",
" query (str): The query string to process.\n",
" faiss_index (FAISS): The FAISS index used for document retrieval.\n",
"\n",
" Returns:\n",
" str: The generated response based on the query.\n",
" \"\"\"\n",
" print(f\"\\nProcessing query: {query}\")\n",
"\n",
" # Retrieve and evaluate documents\n",
" retrieved_docs = retrieve_documents(query, faiss_index)\n",
" eval_scores = evaluate_documents(query, retrieved_docs)\n",
" \n",
" print(f\"\\nRetrieved {len(retrieved_docs)} documents\")\n",
" print(f\"Evaluation scores: {eval_scores}\")\n",
"\n",
" # Determine action based on evaluation scores\n",
" max_score = max(eval_scores)\n",
" sources = []\n",
" \n",
" if max_score > 0.7:\n",
" print(\"\\nAction: Correct - Using retrieved document\")\n",
" best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
" final_knowledge = best_doc\n",
" sources.append((\"Retrieved document\", \"\"))\n",
" elif max_score < 0.3:\n",
" print(\"\\nAction: Incorrect - Performing web search\")\n",
" final_knowledge, sources = perform_web_search(query)\n",
" else:\n",
" print(\"\\nAction: Ambiguous - Combining retrieved document and web search\")\n",
" best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
" # Refine the retrieved knowledge\n",
" retrieved_knowledge = knowledge_refinement(best_doc)\n",
" web_knowledge, web_sources = perform_web_search(query)\n",
" final_knowledge = \"\\n\".join(retrieved_knowledge + web_knowledge)\n",
" sources = [(\"Retrieved document\", \"\")] + web_sources\n",
"\n",
" print(\"\\nFinal knowledge:\")\n",
" print(final_knowledge)\n",
" \n",
" print(\"\\nSources:\")\n",
" for title, link in sources:\n",
" print(f\"{title}: {link}\" if link else title)\n",
"\n",
" # Generate response\n",
" print(\"\\nGenerating response...\")\n",
" response = generate_response(query, final_knowledge, sources)\n",
"\n",
" print(\"\\nResponse generated\")\n",
" return response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example query with high relevance to the document\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What are the main causes of climate change?\"\n",
"result = crag_process(query, vectorstore)\n",
"print(f\"Query: {query}\")\n",
"print(f\"Answer: {result}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example query with low relevance to the document\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"how did harry beat quirrell?\"\n",
"result = crag_process(query, vectorstore)\n",
"print(f\"Query: {query}\")\n",
"print(f\"Answer: {result}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}