added corrective RAG

This commit is contained in:
nird
2024-08-04 17:37:25 +03:00
parent 696f9095e8
commit 3d1c786933
3 changed files with 404 additions and 1 deletion


@@ -0,0 +1,400 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import relevant libraries"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks\n",
"from helper_functions import *\n",
"from evaluation.evalute_rag import *\n",
"\n",
"# Load environment variables from a .env file\n",
"load_dotenv()\n",
"\n",
"# Set the OpenAI API key environment variable\n",
"os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')\n",
"from langchain.tools import DuckDuckGoSearchResults\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define files path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = \"../data/Understanding_Climate_Change.pdf\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a vector store"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"vectorstore = encode_pdf(path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize OpenAI language model\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(model=\"gpt-4o-mini\", max_tokens=1000, temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize search tool"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"search = DuckDuckGoSearchResults()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define retrieval evaluator, knowledge refinement and query rewriter llm chains"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Retrieval Evaluator\n",
"class RetrievalEvaluatorInput(BaseModel):\n",
" relevance_score: float = Field(..., description=\"The relevance score of the document to the query. the score should be between 0 and 1.\")\n",
"def retrieval_evaluator(query: str, document: str) -> float:\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"query\", \"document\"],\n",
" template=\"On a scale from 0 to 1, how relevant is the following document to the query? Query: {query}\\nDocument: {document}\\nRelevance score:\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(RetrievalEvaluatorInput)\n",
" input_variables = {\"query\": query, \"document\": document}\n",
" result = chain.invoke(input_variables).relevance_score\n",
" return result\n",
"\n",
"# Knowledge Refinement\n",
"class KnowledgeRefinementInput(BaseModel):\n",
" key_points: str = Field(..., description=\"The document to extract key information from.\")\n",
"def knowledge_refinement(document: str) -> List[str]:\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"document\"],\n",
" template=\"Extract the key information from the following document in bullet points:\\n{document}\\nKey points:\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(KnowledgeRefinementInput)\n",
" input_variables = {\"document\": document}\n",
" result = chain.invoke(input_variables).key_points\n",
" return [point.strip() for point in result.split('\\n') if point.strip()]\n",
"\n",
"# Web Search Query Rewriter\n",
"class QueryRewriterInput(BaseModel):\n",
" query: str = Field(..., description=\"The query to rewrite.\")\n",
"def rewrite_query(query: str) -> str:\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"query\"],\n",
" template=\"Rewrite the following query to make it more suitable for a web search:\\n{query}\\nRewritten query:\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(QueryRewriterInput)\n",
" input_variables = {\"query\": query}\n",
" return chain.invoke(input_variables).query.strip()"
]
},
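{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of two of the chains above on a toy query and a made-up document snippet (both are illustrative inputs only, not taken from the PDF). It assumes the `OPENAI_API_KEY` loaded earlier is valid."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy inputs (hypothetical) to exercise the evaluator and the query rewriter\n",
"toy_query = \"What causes climate change?\"\n",
"toy_doc = \"Burning fossil fuels releases greenhouse gases that trap heat in the atmosphere.\"\n",
"\n",
"print(\"Relevance score:\", retrieval_evaluator(toy_query, toy_doc))\n",
"print(\"Rewritten query:\", rewrite_query(toy_query))"
]
},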
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper function to parse search results\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def parse_search_results(results_string: str) -> List[Tuple[str, str]]:\n",
" \"\"\"\n",
" Parse a JSON string of search results into a list of title-link tuples.\n",
"\n",
" Args:\n",
" results_string (str): A JSON-formatted string containing search results.\n",
"\n",
" Returns:\n",
" List[Tuple[str, str]]: A list of tuples, where each tuple contains the title and link of a search result.\n",
" If parsing fails, an empty list is returned.\n",
" \"\"\"\n",
" try:\n",
" # Attempt to parse the JSON string\n",
" results = json.loads(results_string)\n",
" # Extract and return the title and link from each result\n",
" return [(result.get('title', 'Untitled'), result.get('link', '')) for result in results]\n",
" except json.JSONDecodeError:\n",
" # Handle JSON decoding errors by returning an empty list\n",
" print(\"Error parsing search results. Returning empty list.\")\n",
" return []"
]
},
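{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal offline check of the parser on a mock JSON results string (the titles and links below are made up). It only assumes the search tool returns a JSON array of objects with `title` and `link` keys, which is the shape `parse_search_results` expects."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Mock search results (hypothetical data) in the JSON shape the parser expects\n",
"mock_results = json.dumps([\n",
"    {\"title\": \"IPCC Sixth Assessment Report\", \"link\": \"https://example.org/ipcc-ar6\"},\n",
"    {\"title\": \"NASA: Causes of Climate Change\", \"link\": \"https://example.org/nasa-causes\"}\n",
"])\n",
"print(parse_search_results(mock_results))"
]
},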
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define sub functions for the CRAG process"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def retrieve_documents(query: str, faiss_index: FAISS, k: int = 3) -> List[str]:\n",
" \"\"\"\n",
" Retrieve documents based on a query using a FAISS index.\n",
"\n",
" Args:\n",
" query (str): The query string to search for.\n",
" faiss_index (FAISS): The FAISS index used for similarity search.\n",
" k (int): The number of top documents to retrieve. Defaults to 3.\n",
"\n",
" Returns:\n",
" List[str]: A list of the retrieved document contents.\n",
" \"\"\"\n",
" docs = faiss_index.similarity_search(query, k=k)\n",
" return [doc.page_content for doc in docs]\n",
"\n",
"def evaluate_documents(query: str, documents: List[str]) -> List[float]:\n",
" \"\"\"\n",
" Evaluate the relevance of documents based on a query.\n",
"\n",
" Args:\n",
" query (str): The query string.\n",
" documents (List[str]): A list of document contents to evaluate.\n",
"\n",
" Returns:\n",
" List[float]: A list of relevance scores for each document.\n",
" \"\"\"\n",
" return [retrieval_evaluator(query, doc) for doc in documents]\n",
"\n",
"def perform_web_search(query: str) -> Tuple[List[str], List[Tuple[str, str]]]:\n",
" \"\"\"\n",
" Perform a web search based on a query.\n",
"\n",
" Args:\n",
" query (str): The query string to search for.\n",
"\n",
" Returns:\n",
" Tuple[List[str], List[Tuple[str, str]]]: \n",
" - A list of refined knowledge obtained from the web search.\n",
" - A list of tuples containing titles and links of the sources.\n",
" \"\"\"\n",
" rewritten_query = rewrite_query(query)\n",
" web_results = search.run(rewritten_query)\n",
" web_knowledge = knowledge_refinement(web_results)\n",
" sources = parse_search_results(web_results)\n",
" return web_knowledge, sources\n",
"\n",
"def generate_response(query: str, knowledge: str, sources: List[Tuple[str, str]]) -> str:\n",
" \"\"\"\n",
" Generate a response to a query using knowledge and sources.\n",
"\n",
" Args:\n",
" query (str): The query string.\n",
" knowledge (str): The refined knowledge to use in the response.\n",
" sources (List[Tuple[str, str]]): A list of tuples containing titles and links of the sources.\n",
"\n",
" Returns:\n",
" str: The generated response.\n",
" \"\"\"\n",
" response_prompt = PromptTemplate(\n",
" input_variables=[\"query\", \"knowledge\", \"sources\"],\n",
" template=\"Based on the following knowledge, answer the query. Include the sources with their links (if available) at the end of your answer:\\nQuery: {query}\\nKnowledge: {knowledge}\\nSources: {sources}\\nAnswer:\"\n",
" )\n",
" input_variables = {\n",
" \"query\": query,\n",
" \"knowledge\": knowledge,\n",
" \"sources\": \"\\n\".join([f\"{title}: {link}\" if link else title for title, link in sources])\n",
" }\n",
" response_chain = response_prompt | llm\n",
" return response_chain.invoke(input_variables).content\n"
]
},
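{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch that runs just the retrieval and evaluation steps against the vector store built above, before wiring them into the full CRAG loop. The query is an illustrative one for the climate-change PDF; the scores may vary between runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Retrieve a couple of chunks and score their relevance (illustrative query)\n",
"demo_query = \"What are greenhouse gases?\"\n",
"demo_docs = retrieve_documents(demo_query, vectorstore, k=2)\n",
"demo_scores = evaluate_documents(demo_query, demo_docs)\n",
"for doc, score in zip(demo_docs, demo_scores):\n",
"    print(f\"score={score:.2f} | {doc[:80]}...\")"
]
},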
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CRAG process\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def crag_process(query: str, faiss_index: FAISS) -> str:\n",
" \"\"\"\n",
" Process a query by retrieving, evaluating, and using documents or performing a web search to generate a response.\n",
"\n",
" Args:\n",
" query (str): The query string to process.\n",
" faiss_index (FAISS): The FAISS index used for document retrieval.\n",
"\n",
" Returns:\n",
" str: The generated response based on the query.\n",
" \"\"\"\n",
" print(f\"\\nProcessing query: {query}\")\n",
"\n",
" # Retrieve and evaluate documents\n",
" retrieved_docs = retrieve_documents(query, faiss_index)\n",
" eval_scores = evaluate_documents(query, retrieved_docs)\n",
" \n",
" print(f\"\\nRetrieved {len(retrieved_docs)} documents\")\n",
" print(f\"Evaluation scores: {eval_scores}\")\n",
"\n",
" # Determine action based on evaluation scores\n",
" max_score = max(eval_scores)\n",
" sources = []\n",
" \n",
" if max_score > 0.7:\n",
" print(\"\\nAction: Correct - Using retrieved document\")\n",
" best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
" final_knowledge = best_doc\n",
" sources.append((\"Retrieved document\", \"\"))\n",
" elif max_score < 0.3:\n",
" print(\"\\nAction: Incorrect - Performing web search\")\n",
" final_knowledge, sources = perform_web_search(query)\n",
" else:\n",
" print(\"\\nAction: Ambiguous - Combining retrieved document and web search\")\n",
" best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
" # Refine the retrieved knowledge\n",
" retrieved_knowledge = knowledge_refinement(best_doc)\n",
" web_knowledge, web_sources = perform_web_search(query)\n",
" final_knowledge = \"\\n\".join(retrieved_knowledge + web_knowledge)\n",
" sources = [(\"Retrieved document\", \"\")] + web_sources\n",
"\n",
" print(\"\\nFinal knowledge:\")\n",
" print(final_knowledge)\n",
" \n",
" print(\"\\nSources:\")\n",
" for title, link in sources:\n",
" print(f\"{title}: {link}\" if link else title)\n",
"\n",
" # Generate response\n",
" print(\"\\nGenerating response...\")\n",
" response = generate_response(query, final_knowledge, sources)\n",
"\n",
" print(\"\\nResponse generated\")\n",
" return response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example query with high relevance to the document\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What are the main causes of climate change?\"\n",
"result = crag_process(query, vectorstore)\n",
"print(f\"Query: {query}\")\n",
"print(f\"Answer: {result}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example query with low relevance to the document\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"how did harry beat quirrell?\"\n",
"result = crag_process(query, vectorstore)\n",
"print(f\"Query: {query}\")\n",
"print(f\"Answer: {result}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -62,7 +62,7 @@
"source": [
"<div style=\"text-align: center;\">\n",
"\n",
"<img src=\"../images/self_rag.svg\" alt=\"self RAG\" style=\"width:80%; height:auto;\">\n",
"<img src=\"../images/self_rag.svg\" alt=\"Self RAG\" style=\"width:80%; height:auto;\">\n",
"</div>"
]
},

images/crag.svg Normal file (27 KiB)

File diff suppressed because one or more lines are too long