RAG_Techniques/all_rag_techniques/self_rag.ipynb
2024-08-04 17:37:25 +03:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Self-RAG: A Dynamic Approach to Retrieval-Augmented Generation\n",
"\n",
"## Overview\n",
"\n",
"Self-RAG is a retrieval-augmented generation technique that decides, for each query, whether retrieval is needed at all, filters the retrieved documents for relevance, and then evaluates its own candidate responses for support and utility before returning one. The goal is output that is more accurate, relevant, and useful than retrieval or generation alone would produce.\n",
"\n",
"## Motivation\n",
"\n",
"Traditional question-answering systems often struggle with balancing the use of retrieved information and the generation of new content. Some systems might rely too heavily on retrieved data, leading to responses that lack flexibility, while others might generate responses without sufficient grounding in factual information. Self-RAG addresses these issues by implementing a multi-step process that carefully evaluates the necessity and relevance of retrieved information, and assesses the quality of generated responses.\n",
"\n",
"## Key Components\n",
"\n",
"1. **Retrieval Decision**: Determines if retrieval is necessary for a given query.\n",
"2. **Document Retrieval**: Fetches potentially relevant documents from a vector store.\n",
"3. **Relevance Evaluation**: Assesses the relevance of retrieved documents to the query.\n",
"4. **Response Generation**: Generates responses based on relevant contexts.\n",
"5. **Support Assessment**: Evaluates how well the generated response is supported by the context.\n",
"6. **Utility Evaluation**: Rates the usefulness of the generated response.\n",
"\n",
"## Method Details\n",
"\n",
"1. **Retrieval Decision**: The algorithm first decides if retrieval is necessary for the given query. This step prevents unnecessary retrieval for queries that can be answered directly.\n",
"\n",
"2. **Document Retrieval**: If retrieval is deemed necessary, the algorithm fetches the top-k most similar documents from a vector store.\n",
"\n",
"3. **Relevance Evaluation**: Each retrieved document is evaluated for its relevance to the query. This step filters out irrelevant information, ensuring that only pertinent context is used for generation.\n",
"\n",
"4. **Response Generation**: The algorithm generates one candidate response for each relevant context. If no relevant contexts are found, it generates a response without retrieval.\n",
"\n",
"5. **Support Assessment**: Each generated response is evaluated to determine how well it is supported by the context. This step helps in identifying responses that are grounded in the provided information.\n",
"\n",
"6. **Utility Evaluation**: The utility of each response is rated, considering how well it addresses the original query.\n",
"\n",
"7. **Response Selection**: The final step selects the best candidate: responses assessed as fully supported are preferred, and among those the one with the highest utility score is returned.\n",
"\n",
"## Benefits of the Approach\n",
"\n",
"1. **Dynamic Retrieval**: By deciding whether retrieval is necessary, the system can adapt to different types of queries efficiently.\n",
"\n",
"2. **Relevance Filtering**: The relevance evaluation step ensures that only pertinent information is used, reducing noise in the generation process.\n",
"\n",
"3. **Quality Assurance**: The support assessment and utility evaluation provide a way to gauge the quality of generated responses.\n",
"\n",
"4. **Flexibility**: The system can generate responses with or without retrieval, adapting to the available information.\n",
"\n",
"5. **Improved Accuracy**: By grounding responses in relevant retrieved information and assessing their support, the system can produce more accurate outputs.\n",
"\n",
"## Conclusion\n",
"\n",
"Self-RAG wraps retrieval-augmented generation in a series of checks: it decides whether to retrieve, filters retrieved documents for relevance, and scores each candidate response for support and utility before selecting one. The aim is a response that is relevant, grounded in the retrieved context, and useful to the end user."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/self_rag.svg\" alt=\"Self RAG\" style=\"width:80%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import relevant libraries"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path since we work with notebooks\n",
"from helper_functions import *\n",
"from evaluation.evalute_rag import *\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"\n",
"# Fail early with a clear message if the OpenAI API key was not loaded\n",
"if not os.getenv('OPENAI_API_KEY'):\n",
"    raise EnvironmentError('OPENAI_API_KEY is not set; add it to your .env file')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the file path"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"path = \"../data/Understanding_Climate_Change.pdf\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a vector store "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"vectorstore = encode_pdf(path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the language model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model=\"gpt-4o-mini\", max_tokens=1000, temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining prompt templates"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class RetrievalResponse(BaseModel):\n",
"    response: str = Field(..., title=\"Determines if retrieval is necessary\", description=\"Output only 'Yes' or 'No'.\")\n",
"retrieval_prompt = PromptTemplate(\n",
"    input_variables=[\"query\"],\n",
"    template=\"Given the query '{query}', determine if retrieval is necessary. Output only 'Yes' or 'No'.\"\n",
")\n",
"\n",
"class RelevanceResponse(BaseModel):\n",
"    response: str = Field(..., title=\"Determines if context is relevant\", description=\"Output only 'Relevant' or 'Irrelevant'.\")\n",
"relevance_prompt = PromptTemplate(\n",
"    input_variables=[\"query\", \"context\"],\n",
"    template=\"Given the query '{query}' and the context '{context}', determine if the context is relevant. Output only 'Relevant' or 'Irrelevant'.\"\n",
")\n",
"\n",
"class GenerationResponse(BaseModel):\n",
"    response: str = Field(..., title=\"Generated response\", description=\"The generated response.\")\n",
"generation_prompt = PromptTemplate(\n",
"    input_variables=[\"query\", \"context\"],\n",
"    template=\"Given the query '{query}' and the context '{context}', generate a response.\"\n",
")\n",
"\n",
"class SupportResponse(BaseModel):\n",
"    response: str = Field(..., title=\"Determines if response is supported\", description=\"Output 'Fully supported', 'Partially supported', or 'No support'.\")\n",
"support_prompt = PromptTemplate(\n",
"    input_variables=[\"response\", \"context\"],\n",
"    template=\"Given the response '{response}' and the context '{context}', determine if the response is supported by the context. Output 'Fully supported', 'Partially supported', or 'No support'.\"\n",
")\n",
"\n",
"class UtilityResponse(BaseModel):\n",
"    response: int = Field(..., title=\"Utility rating\", description=\"Rate the utility of the response from 1 to 5.\")\n",
"utility_prompt = PromptTemplate(\n",
"    input_variables=[\"query\", \"response\"],\n",
"    template=\"Given the query '{query}' and the response '{response}', rate the utility of the response from 1 to 5.\"\n",
")\n",
"\n",
"# Pipe each prompt into the LLM, constrained to the matching structured output schema\n",
"retrieval_chain = retrieval_prompt | llm.with_structured_output(RetrievalResponse)\n",
"relevance_chain = relevance_prompt | llm.with_structured_output(RelevanceResponse)\n",
"generation_chain = generation_prompt | llm.with_structured_output(GenerationResponse)\n",
"support_chain = support_prompt | llm.with_structured_output(SupportResponse)\n",
"utility_chain = utility_prompt | llm.with_structured_output(UtilityResponse)"
]
},
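{
"cell_type": "markdown",
"metadata": {},
"source": [
"The chains above return free-text labels such as `Yes` or `Relevant`, and the flow below compares them by exact match after `strip().lower()`. A label that comes back as `Yes.` would therefore not match. The snippet below is a minimal sketch of one way to guard against that; `normalize_label` is a hypothetical helper that is not part of this notebook or of LangChain.\n",
"\n",
"```python\n",
"def normalize_label(text, allowed):\n",
"    # Lowercase, trim whitespace, and drop trailing punctuation such as 'Yes.'\n",
"    cleaned = text.strip().lower().rstrip('.!')\n",
"    # Return None for anything outside the expected label set\n",
"    return cleaned if cleaned in allowed else None\n",
"\n",
"print(normalize_label(' Yes. ', {'yes', 'no'}))  # yes\n",
"print(normalize_label('Maybe', {'yes', 'no'}))   # None\n",
"```\n",
"\n",
"A `None` result could then be handled explicitly, for example by falling back to retrieval, rather than silently taking the no-retrieval branch."
]
},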
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining the Self-RAG logic flow"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def self_rag(query, vectorstore, top_k=3):\n",
"    print(f\"\\nProcessing query: {query}\")\n",
"\n",
"    # Step 1: Determine if retrieval is necessary\n",
"    print(\"Step 1: Determining if retrieval is necessary...\")\n",
"    input_data = {\"query\": query}\n",
"    retrieval_decision = retrieval_chain.invoke(input_data).response.strip().lower()\n",
"    print(f\"Retrieval decision: {retrieval_decision}\")\n",
"\n",
"    if retrieval_decision == 'yes':\n",
"        # Step 2: Retrieve relevant documents\n",
"        print(\"Step 2: Retrieving relevant documents...\")\n",
"        docs = vectorstore.similarity_search(query, k=top_k)\n",
"        contexts = [doc.page_content for doc in docs]\n",
"        print(f\"Retrieved {len(contexts)} documents\")\n",
"\n",
"        # Step 3: Evaluate relevance of retrieved documents\n",
"        print(\"Step 3: Evaluating relevance of retrieved documents...\")\n",
"        relevant_contexts = []\n",
"        for i, context in enumerate(contexts):\n",
"            input_data = {\"query\": query, \"context\": context}\n",
"            relevance = relevance_chain.invoke(input_data).response.strip().lower()\n",
"            print(f\"Document {i+1} relevance: {relevance}\")\n",
"            if relevance == 'relevant':\n",
"                relevant_contexts.append(context)\n",
"\n",
"        print(f\"Number of relevant contexts: {len(relevant_contexts)}\")\n",
"\n",
"        # If no relevant contexts found, generate without retrieval\n",
"        if not relevant_contexts:\n",
"            print(\"No relevant contexts found. Generating without retrieval...\")\n",
"            input_data = {\"query\": query, \"context\": \"No relevant context found.\"}\n",
"            return generation_chain.invoke(input_data).response\n",
"\n",
"        # Step 4: Generate response using relevant contexts\n",
"        print(\"Step 4: Generating responses using relevant contexts...\")\n",
"        responses = []\n",
"        for i, context in enumerate(relevant_contexts):\n",
"            print(f\"Generating response for context {i+1}...\")\n",
"            input_data = {\"query\": query, \"context\": context}\n",
"            response = generation_chain.invoke(input_data).response\n",
"\n",
"            # Step 5: Assess support\n",
"            print(f\"Step 5: Assessing support for response {i+1}...\")\n",
"            input_data = {\"response\": response, \"context\": context}\n",
"            support = support_chain.invoke(input_data).response.strip().lower()\n",
"            print(f\"Support assessment: {support}\")\n",
"\n",
"            # Step 6: Evaluate utility\n",
"            print(f\"Step 6: Evaluating utility for response {i+1}...\")\n",
"            input_data = {\"query\": query, \"response\": response}\n",
"            utility = int(utility_chain.invoke(input_data).response)\n",
"            print(f\"Utility score: {utility}\")\n",
"\n",
"            responses.append((response, support, utility))\n",
"\n",
"        # Select the best response based on support and utility\n",
"        print(\"Selecting the best response...\")\n",
"        best_response = max(responses, key=lambda x: (x[1] == 'fully supported', x[2]))\n",
"        print(f\"Best response support: {best_response[1]}, utility: {best_response[2]}\")\n",
"        return best_response[0]\n",
"    else:\n",
"        # Generate without retrieval\n",
"        print(\"Generating without retrieval...\")\n",
"        input_data = {\"query\": query, \"context\": \"No retrieval necessary.\"}\n",
"        return generation_chain.invoke(input_data).response"
]
},
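{
"cell_type": "markdown",
"metadata": {},
"source": [
"The selection step in `self_rag` ranks candidates with the tuple key `(support == 'fully supported', utility)`, so `partially supported` and `no support` are treated as equal. The sketch below illustrates that behavior on made-up candidates, next to an assumed variant that orders all three support levels. `SUPPORT_RANK` and `pick_best` are hypothetical names introduced here for illustration only.\n",
"\n",
"```python\n",
"# Made-up candidates in the same (response, support, utility) shape that self_rag builds\n",
"candidates = [\n",
"    ('Answer A', 'no support', 5),\n",
"    ('Answer B', 'partially supported', 4),\n",
"]\n",
"\n",
"# Rule used in self_rag: fully supported first, then higher utility\n",
"best = max(candidates, key=lambda x: (x[1] == 'fully supported', x[2]))\n",
"print(best[0])  # Answer A (neither is fully supported, so utility decides)\n",
"\n",
"# Assumed variant: rank every support level before comparing utility\n",
"SUPPORT_RANK = {'fully supported': 2, 'partially supported': 1, 'no support': 0}\n",
"\n",
"def pick_best(responses):\n",
"    # Unknown labels fall back to the lowest rank\n",
"    return max(responses, key=lambda x: (SUPPORT_RANK.get(x[1], 0), x[2]))\n",
"\n",
"print(pick_best(candidates)[0])  # Answer B\n",
"```\n",
"\n",
"Whether an unsupported but high-utility answer should outrank a partially supported one depends on the application, so treat the variant as one option rather than a correction."
]
},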
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test the self-RAG function with an easy, high-relevance query\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What is the impact of climate change on the environment?\"\n",
"response = self_rag(query, vectorstore)\n",
"\n",
"print(\"\\nFinal response:\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test the self-RAG function with a more challenging, low-relevance query\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"how did harry beat quirrell?\"\n",
"response = self_rag(query, vectorstore)\n",
"\n",
"print(\"\\nFinal response:\")\n",
"print(response)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}