RAG_Techniques/all_rag_techniques/retrieval_with_feedback_loop.ipynb
2024-07-28 13:57:31 +03:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RAG System with Feedback Loop: Enhancing Retrieval and Response Quality\n",
"\n",
"## Overview\n",
"\n",
"This system implements a Retrieval-Augmented Generation (RAG) approach with an integrated feedback loop. It aims to improve the quality and relevance of responses over time by incorporating user feedback and dynamically adjusting the retrieval process.\n",
"\n",
"## Motivation\n",
"\n",
"Traditional RAG systems can sometimes produce inconsistent or irrelevant responses due to limitations in the retrieval process or the underlying knowledge base. By implementing a feedback loop, we can:\n",
"\n",
"1. Continuously improve the quality of retrieved documents\n",
"2. Enhance the relevance of generated responses\n",
"3. Adapt the system to user preferences and needs over time\n",
"\n",
"## Key Components\n",
"\n",
"1. **PDF Content Extraction**: Extracts text from PDF documents\n",
"2. **Vectorstore**: Stores and indexes document embeddings for efficient retrieval\n",
"3. **Retriever**: Fetches relevant documents based on user queries\n",
"4. **Language Model**: Generates responses using retrieved documents\n",
"5. **Feedback Collection**: Gathers user feedback on response quality and relevance\n",
"6. **Feedback Storage**: Persists user feedback for future use\n",
"7. **Relevance Score Adjustment**: Modifies document relevance based on feedback\n",
"8. **Index Fine-tuning**: Periodically updates the vectorstore using accumulated feedback\n",
"\n",
"## Method Details\n",
"\n",
"### 1. Initial Setup\n",
"- The system reads PDF content and creates a vectorstore\n",
"- A retriever is initialized using the vectorstore\n",
"- A language model (LLM) is set up for response generation\n",
"\n",
"### 2. Query Processing\n",
"- When a user submits a query, the retriever fetches relevant documents\n",
"- The LLM generates a response based on the retrieved documents\n",
"\n",
"### 3. Feedback Collection\n",
"- The system collects user feedback on the response's relevance and quality\n",
"- Feedback is stored in a JSON file for persistence\n",
"\n",
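"A minimal sketch of this storage step, assuming one JSON object per line (the JSON Lines layout). The temporary file path and the helper names `append_feedback` and `read_feedback` are illustrative assumptions, not names used by this notebook:\n",
"\n",
"```python\n",
"import json\n",
"import os\n",
"import tempfile\n",
"\n",
"def append_feedback(path, feedback):\n",
"    # Append one feedback record as a single JSON line (print adds the newline)\n",
"    with open(path, 'a', encoding='utf-8') as f:\n",
"        print(json.dumps(feedback), file=f)\n",
"\n",
"def read_feedback(path):\n",
"    # Read every non-empty line back into a list of dicts\n",
"    with open(path, encoding='utf-8') as f:\n",
"        return [json.loads(line) for line in f if line.strip()]\n",
"\n",
"# Hypothetical temporary file, used only for this illustration\n",
"path = os.path.join(tempfile.mkdtemp(), 'feedback_demo.json')\n",
"append_feedback(path, {'query': 'What is the greenhouse effect?', 'relevance': 5, 'quality': 4})\n",
"append_feedback(path, {'query': 'What causes sea level rise?', 'relevance': 2, 'quality': 3})\n",
"records = read_feedback(path)\n",
"```\n",
"\n",
"Appending one record per line means existing feedback does not have to be re-read or rewritten when a new entry arrives.\n",
"\n",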
"### 4. Relevance Score Adjustment\n",
"- For subsequent queries, the system loads previous feedback\n",
"- An LLM evaluates the relevance of past feedback to the current query\n",
"- Document relevance scores are adjusted based on this evaluation\n",
"\n",
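"The adjustment in this step can be sketched with plain dictionaries. This is an illustration only: it assumes ratings on a 1-5 scale with 3 treated as neutral, and the helper name `adjust_score` and the sample data are hypothetical:\n",
"\n",
"```python\n",
"def adjust_score(score, ratings, neutral=3):\n",
"    # Leave the score unchanged when no relevant feedback exists\n",
"    if not ratings:\n",
"        return score\n",
"    avg = sum(ratings) / len(ratings)\n",
"    # Ratings above neutral boost the score, ratings below it reduce it\n",
"    return score * (avg / neutral)\n",
"\n",
"docs = [\n",
"    {'id': 'a', 'score': 1.0, 'ratings': [5, 4]},\n",
"    {'id': 'b', 'score': 1.0, 'ratings': [1, 2]},\n",
"    {'id': 'c', 'score': 1.0, 'ratings': []},\n",
"]\n",
"for d in docs:\n",
"    d['score'] = adjust_score(d['score'], d['ratings'])\n",
"\n",
"# Re-rank so the highest adjusted score comes first\n",
"ranked = [d['id'] for d in sorted(docs, key=lambda d: d['score'], reverse=True)]\n",
"```\n",
"\n",
"With these inputs the multipliers are 1.5, 0.5, and 1.0, so the ranking becomes a, c, b.\n",
"\n",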
"### 5. Retriever Update\n",
"- The retriever is updated with the adjusted document scores\n",
"- This ensures that future retrievals benefit from past feedback\n",
"\n",
"### 6. Periodic Index Fine-tuning\n",
"- At regular intervals, the system fine-tunes the index\n",
"- High-quality feedback is used to create additional documents\n",
"- The vectorstore is updated with these new documents, improving overall retrieval quality\n",
"\n",
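"The filtering part of this step can be sketched as follows. The helper name `select_good_feedback`, the `threshold` parameter, and the sample records are assumptions made for illustration:\n",
"\n",
"```python\n",
"def select_good_feedback(feedback_data, threshold=4):\n",
"    # Keep only entries rated at or above the threshold on both axes\n",
"    return [f for f in feedback_data\n",
"            if f['relevance'] >= threshold and f['quality'] >= threshold]\n",
"\n",
"feedback_data = [\n",
"    {'query': 'Q1', 'response': 'A1', 'relevance': 5, 'quality': 5},\n",
"    {'query': 'Q2', 'response': 'A2', 'relevance': 4, 'quality': 2},\n",
"]\n",
"good = select_good_feedback(feedback_data)\n",
"\n",
"# Each surviving query/response pair becomes one additional text to index\n",
"additional_texts = [f['query'] + ' ' + f['response'] for f in good]\n",
"```\n",
"\n",
"Requiring both ratings to pass keeps answers that were relevant but poorly written, or well written but off-topic, out of the index.\n",
"\n",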
"## Benefits of this Approach\n",
"\n",
"1. **Continuous Improvement**: The system learns from each interaction, gradually enhancing its performance.\n",
"2. **Personalization**: By incorporating user feedback, the system can adapt to individual or group preferences over time.\n",
"3. **Increased Relevance**: The feedback loop helps prioritize more relevant documents in future retrievals.\n",
"4. **Quality Control**: Low-quality or irrelevant responses are less likely to be repeated as the system evolves.\n",
"5. **Adaptability**: The system can adjust to changes in user needs or document contents over time.\n",
"\n",
"## Conclusion\n",
"\n",
"This RAG system with a feedback loop extends a standard RAG pipeline by learning from user interactions: stored feedback is used to re-rank retrieved documents, and highly rated question-answer pairs are added back into the index. The approach is most valuable in domains where information accuracy and relevance are critical, and where user needs may evolve over time.\n",
"\n",
"The implementation adds complexity and cost compared to a basic RAG system (for example, the relevance check makes one LLM call for every pair of retrieved document and stored feedback entry). For applications that require high-quality, context-aware retrieval and generation, the gains in response quality can justify that overhead."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/retrieval_with_feedback_loop.svg\" alt=\"retrieval with feedback loop\" style=\"width:40%; height:auto;\">\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import relevant libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain.chains import RetrievalQA\n",
"import json\n",
"from typing import List, Dict, Any\n",
"\n",
"\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks\n",
"from helper_functions import *\n",
"from evaluation.evalute_rag import *\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"\n",
"# Set the OpenAI API key environment variable\n",
"os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')\n",
"os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define documents path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = \"../data/Understanding_Climate_Change.pdf\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create vector store and retrieval QA chain"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"content = read_pdf_to_string(path)\n",
"vectorstore = encode_from_string(content)\n",
"retriever = vectorstore.as_retriever()\n",
"\n",
"llm = ChatOpenAI(temperature=0, model_name=\"gpt-4o\", max_tokens=4000)\n",
"qa_chain = RetrievalQA.from_chain_type(llm, retriever=retriever)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function to format user feedback in a dictionary"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def get_user_feedback(query, response, relevance, quality, comments=\"\"):\n",
"    return {\n",
"        \"query\": query,\n",
"        \"response\": response,\n",
"        \"relevance\": int(relevance),\n",
"        \"quality\": int(quality),\n",
"        \"comments\": comments\n",
"    }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function to store the feedback in a JSON file"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def store_feedback(feedback):\n",
"    with open(\"../data/feedback_data.json\", \"a\") as f:\n",
"        json.dump(feedback, f)\n",
"        f.write(\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function to read the feedback file"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def load_feedback_data():\n",
"    feedback_data = []\n",
"    try:\n",
"        with open(\"../data/feedback_data.json\", \"r\") as f:\n",
"            for line in f:\n",
"                feedback_data.append(json.loads(line.strip()))\n",
"    except FileNotFoundError:\n",
"        print(\"No feedback data file found. Starting with empty feedback.\")\n",
"    return feedback_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function to adjust document relevance based on the feedback file"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class Response(BaseModel):\n",
"    answer: str = Field(..., title=\"The answer to the question. The options can be only 'Yes' or 'No'\")\n",
"\n",
"def adjust_relevance_scores(query: str, docs: List[Any], feedback_data: List[Dict[str, Any]]) -> List[Any]:\n",
"    # Create a prompt template for relevance checking\n",
"    relevance_prompt = PromptTemplate(\n",
"        input_variables=[\"query\", \"feedback_query\", \"doc_content\", \"feedback_response\"],\n",
"        template=\"\"\"\n",
"        Determine if the following feedback response is relevant to the current query and document content.\n",
"        You are also given the original feedback query that was used to generate the feedback response.\n",
"        Current query: {query}\n",
"        Feedback query: {feedback_query}\n",
"        Document content: {doc_content}\n",
"        Feedback response: {feedback_response}\n",
"        \n",
"        Is this feedback relevant? Respond with only 'Yes' or 'No'.\n",
"        \"\"\"\n",
"    )\n",
"    llm = ChatOpenAI(temperature=0, model_name=\"gpt-4o\", max_tokens=4000)\n",
"\n",
"    # Pipe the prompt into the LLM with structured output for relevance checking\n",
"    relevance_chain = relevance_prompt | llm.with_structured_output(Response)\n",
"\n",
"    for doc in docs:\n",
"        relevant_feedback = []\n",
"\n",
"        for feedback in feedback_data:\n",
"            # Use LLM to check relevance\n",
"            input_data = {\n",
"                \"query\": query,\n",
"                \"feedback_query\": feedback['query'],\n",
"                \"doc_content\": doc.page_content[:1000],\n",
"                \"feedback_response\": feedback['response']\n",
"            }\n",
"            result = relevance_chain.invoke(input_data).answer\n",
"\n",
"            # Compare case-insensitively, since the prompt asks for 'Yes' or 'No'\n",
"            if result.strip().lower() == 'yes':\n",
"                relevant_feedback.append(feedback)\n",
"\n",
"        # Adjust the relevance score based on feedback\n",
"        if relevant_feedback:\n",
"            avg_relevance = sum(f['relevance'] for f in relevant_feedback) / len(relevant_feedback)\n",
"            # Assuming a 1-5 scale, 3 is neutral; default to 1.0 if the document has no score yet\n",
"            doc.metadata['relevance_score'] = doc.metadata.get('relevance_score', 1.0) * (avg_relevance / 3)\n",
"\n",
"    # Re-rank documents based on adjusted scores\n",
"    return sorted(docs, key=lambda x: x.metadata.get('relevance_score', 1.0), reverse=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Function to fine-tune the vector index to also include queries and answers that received good feedback"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def fine_tune_index(feedback_data: List[Dict[str, Any]], texts: str) -> Any:\n",
"    # Filter high-quality responses\n",
"    good_responses = [f for f in feedback_data if f['relevance'] >= 4 and f['quality'] >= 4]\n",
"\n",
"    # Combine each query with its response to form additional texts\n",
"    additional_texts = []\n",
"    for f in good_responses:\n",
"        combined_text = f['query'] + \" \" + f['response']\n",
"        additional_texts.append(combined_text)\n",
"\n",
"    # Join the list into a single string\n",
"    additional_texts = \" \".join(additional_texts)\n",
"\n",
"    # Create a new index from the original text plus the high-quality texts\n",
"    all_texts = texts + \" \" + additional_texts\n",
"    new_vectorstore = encode_from_string(all_texts)\n",
"\n",
"    return new_vectorstore"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Demonstration of how to retrieve answers with respect to user feedback"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"\n",
"query = \"What is the greenhouse effect?\"\n",
"\n",
"# Get response from RAG system\n",
"response = qa_chain.invoke(query)[\"result\"]\n",
"\n",
"relevance = 5\n",
"quality = 5\n",
"\n",
"# Collect feedback\n",
"feedback = get_user_feedback(query, response, relevance, quality)\n",
"\n",
"# Store feedback\n",
"store_feedback(feedback)\n",
"\n",
"# Adjust relevance scores for future retrievals\n",
"docs = retriever.invoke(query)\n",
"adjusted_docs = adjust_relevance_scores(query, docs, load_feedback_data())\n",
"\n",
"# Update the retriever with adjusted docs\n",
"retriever.search_kwargs['k'] = len(adjusted_docs)\n",
"retriever.search_kwargs['docs'] = adjusted_docs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fine-tune the vectorstore periodically"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Periodically (e.g., daily or weekly), fine-tune the index\n",
"new_vectorstore = fine_tune_index(load_feedback_data(), content)\n",
"retriever = new_vectorstore.as_retriever()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}