added corrective RAG

all_rag_techniques/crag.ipynb (new file, 400 lines)

@@ -0,0 +1,400 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import relevant libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json  # needed by parse_search_results below\n",
    "from typing import List, Tuple  # used in the type hints below\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "from langchain.tools import DuckDuckGoSearchResults\n",
    "\n",
    "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path since we work with notebooks\n",
    "from helper_functions import *\n",
    "from evaluation.evalute_rag import *\n",
    "\n",
    "# Load environment variables from a .env file\n",
    "load_dotenv()\n",
    "\n",
    "# Set the OpenAI API key environment variable\n",
    "os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the file path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../data/Understanding_Climate_Change.pdf\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a vector store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode_pdf comes from helper_functions and returns a FAISS vector store\n",
    "vectorstore = encode_pdf(path)"
   ]
  },
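  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch of what `encode_pdf` (imported from `helper_functions`) is assumed to do: load the PDF, split it into chunks, embed the chunks, and build a FAISS index. The loader, splitter, chunk sizes, and embedding model below are illustrative assumptions, not necessarily the repository's exact implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only -- the notebook itself uses helper_functions.encode_pdf.\n",
    "# Chunk sizes and the embedding model are assumptions, not the repo's settings.\n",
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "def encode_pdf_sketch(pdf_path: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> FAISS:\n",
    "    documents = PyPDFLoader(pdf_path).load()\n",
    "    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n",
    "    chunks = splitter.split_documents(documents)\n",
    "    return FAISS.from_documents(chunks, OpenAIEmbeddings())"
   ]
  },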
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize OpenAI language model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatOpenAI(model=\"gpt-4o-mini\", max_tokens=1000, temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize search tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "search = DuckDuckGoSearchResults()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the retrieval evaluator, knowledge refinement, and query rewriter LLM chains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieval Evaluator\n",
    "class RetrievalEvaluatorInput(BaseModel):\n",
    "    relevance_score: float = Field(..., description=\"The relevance score of the document to the query; the score should be between 0 and 1.\")\n",
    "\n",
    "def retrieval_evaluator(query: str, document: str) -> float:\n",
    "    prompt = PromptTemplate(\n",
    "        input_variables=[\"query\", \"document\"],\n",
    "        template=\"On a scale from 0 to 1, how relevant is the following document to the query? Query: {query}\\nDocument: {document}\\nRelevance score:\"\n",
    "    )\n",
    "    chain = prompt | llm.with_structured_output(RetrievalEvaluatorInput)\n",
    "    input_variables = {\"query\": query, \"document\": document}\n",
    "    result = chain.invoke(input_variables).relevance_score\n",
    "    return result\n",
    "\n",
    "# Knowledge Refinement\n",
    "class KnowledgeRefinementInput(BaseModel):\n",
    "    key_points: str = Field(..., description=\"The key points extracted from the document, as bullet points.\")\n",
    "\n",
    "def knowledge_refinement(document: str) -> List[str]:\n",
    "    prompt = PromptTemplate(\n",
    "        input_variables=[\"document\"],\n",
    "        template=\"Extract the key information from the following document in bullet points:\\n{document}\\nKey points:\"\n",
    "    )\n",
    "    chain = prompt | llm.with_structured_output(KnowledgeRefinementInput)\n",
    "    input_variables = {\"document\": document}\n",
    "    result = chain.invoke(input_variables).key_points\n",
    "    return [point.strip() for point in result.split('\\n') if point.strip()]\n",
    "\n",
    "# Web Search Query Rewriter\n",
    "class QueryRewriterInput(BaseModel):\n",
    "    query: str = Field(..., description=\"The query to rewrite.\")\n",
    "\n",
    "def rewrite_query(query: str) -> str:\n",
    "    prompt = PromptTemplate(\n",
    "        input_variables=[\"query\"],\n",
    "        template=\"Rewrite the following query to make it more suitable for a web search:\\n{query}\\nRewritten query:\"\n",
    "    )\n",
    "    chain = prompt | llm.with_structured_output(QueryRewriterInput)\n",
    "    input_variables = {\"query\": query}\n",
    "    return chain.invoke(input_variables).query.strip()"
   ]
  },
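  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, optional sanity check of the three chains (requires a valid `OPENAI_API_KEY`; the sample document and query are illustrative, and exact outputs vary between runs):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check; exact scores and wording vary per run.\n",
    "sample_doc = \"Greenhouse gases such as carbon dioxide trap heat in the atmosphere.\"\n",
    "print(retrieval_evaluator(\"What causes climate change?\", sample_doc))  # e.g. a score near 1.0\n",
    "print(knowledge_refinement(sample_doc))\n",
    "print(rewrite_query(\"climate change causes\"))"
   ]
  },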
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper function to parse search results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_search_results(results_string: str) -> List[Tuple[str, str]]:\n",
    "    \"\"\"\n",
    "    Parse a JSON string of search results into a list of title-link tuples.\n",
    "\n",
    "    Args:\n",
    "        results_string (str): A JSON-formatted string containing search results.\n",
    "\n",
    "    Returns:\n",
    "        List[Tuple[str, str]]: A list of tuples, where each tuple contains the title and link of a search result.\n",
    "                               If parsing fails, an empty list is returned.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Attempt to parse the JSON string\n",
    "        results = json.loads(results_string)\n",
    "        # Extract and return the title and link from each result\n",
    "        return [(result.get('title', 'Untitled'), result.get('link', '')) for result in results]\n",
    "    except json.JSONDecodeError:\n",
    "        # Handle JSON decoding errors by returning an empty list\n",
    "        print(\"Error parsing search results. Returning empty list.\")\n",
    "        return []"
   ]
  },
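  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick illustration with a hand-written, hypothetical results string. Note that the exact string `DuckDuckGoSearchResults` returns can vary across `langchain` versions and is not always valid JSON; when parsing fails, the function deliberately degrades to an empty source list rather than raising."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical results string, hand-written for illustration only.\n",
    "example_results = '[{\"title\": \"IPCC Sixth Assessment Report\", \"link\": \"https://www.ipcc.ch/ar6/\"}]'\n",
    "parse_search_results(example_results)  # -> [('IPCC Sixth Assessment Report', 'https://www.ipcc.ch/ar6/')]"
   ]
  },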
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define sub-functions for the CRAG process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_documents(query: str, faiss_index: FAISS, k: int = 3) -> List[str]:\n",
    "    \"\"\"\n",
    "    Retrieve documents based on a query using a FAISS index.\n",
    "\n",
    "    Args:\n",
    "        query (str): The query string to search for.\n",
    "        faiss_index (FAISS): The FAISS index used for similarity search.\n",
    "        k (int): The number of top documents to retrieve. Defaults to 3.\n",
    "\n",
    "    Returns:\n",
    "        List[str]: A list of the retrieved document contents.\n",
    "    \"\"\"\n",
    "    docs = faiss_index.similarity_search(query, k=k)\n",
    "    return [doc.page_content for doc in docs]\n",
    "\n",
    "def evaluate_documents(query: str, documents: List[str]) -> List[float]:\n",
    "    \"\"\"\n",
    "    Evaluate the relevance of documents based on a query.\n",
    "\n",
    "    Args:\n",
    "        query (str): The query string.\n",
    "        documents (List[str]): A list of document contents to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        List[float]: A list of relevance scores for each document.\n",
    "    \"\"\"\n",
    "    return [retrieval_evaluator(query, doc) for doc in documents]\n",
    "\n",
    "def perform_web_search(query: str) -> Tuple[List[str], List[Tuple[str, str]]]:\n",
    "    \"\"\"\n",
    "    Perform a web search based on a query.\n",
    "\n",
    "    Args:\n",
    "        query (str): The query string to search for.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[List[str], List[Tuple[str, str]]]:\n",
    "            - A list of refined knowledge obtained from the web search.\n",
    "            - A list of tuples containing titles and links of the sources.\n",
    "    \"\"\"\n",
    "    rewritten_query = rewrite_query(query)\n",
    "    web_results = search.run(rewritten_query)\n",
    "    web_knowledge = knowledge_refinement(web_results)\n",
    "    sources = parse_search_results(web_results)\n",
    "    return web_knowledge, sources\n",
    "\n",
    "def generate_response(query: str, knowledge: str, sources: List[Tuple[str, str]]) -> str:\n",
    "    \"\"\"\n",
    "    Generate a response to a query using knowledge and sources.\n",
    "\n",
    "    Args:\n",
    "        query (str): The query string.\n",
    "        knowledge (str): The refined knowledge to use in the response.\n",
    "        sources (List[Tuple[str, str]]): A list of tuples containing titles and links of the sources.\n",
    "\n",
    "    Returns:\n",
    "        str: The generated response.\n",
    "    \"\"\"\n",
    "    response_prompt = PromptTemplate(\n",
    "        input_variables=[\"query\", \"knowledge\", \"sources\"],\n",
    "        template=\"Based on the following knowledge, answer the query. Include the sources with their links (if available) at the end of your answer:\\nQuery: {query}\\nKnowledge: {knowledge}\\nSources: {sources}\\nAnswer:\"\n",
    "    )\n",
    "    input_variables = {\n",
    "        \"query\": query,\n",
    "        \"knowledge\": knowledge,\n",
    "        \"sources\": \"\\n\".join([f\"{title}: {link}\" if link else title for title, link in sources])\n",
    "    }\n",
    "    response_chain = response_prompt | llm\n",
    "    return response_chain.invoke(input_variables).content\n"
   ]
  },
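  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To exercise the first two steps in isolation before running the full loop (the query below is illustrative; scores depend on the model and the index contents):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: retrieve two chunks and score them, outside the full CRAG loop.\n",
    "demo_query = \"What are greenhouse gases?\"\n",
    "demo_docs = retrieve_documents(demo_query, vectorstore, k=2)\n",
    "print(evaluate_documents(demo_query, demo_docs))"
   ]
  },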
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CRAG process\n",
    "\n",
    "The corrective loop: retrieve documents, score each one for relevance, then branch on the best score. Above 0.7 the retrieval is judged **Correct** and the best document is used as-is; below 0.3 it is judged **Incorrect** and the process falls back to a web search; anything in between is **Ambiguous**, so refined retrieved knowledge is combined with web search results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def crag_process(query: str, faiss_index: FAISS) -> str:\n",
    "    \"\"\"\n",
    "    Process a query by retrieving, evaluating, and using documents or performing a web search to generate a response.\n",
    "\n",
    "    Args:\n",
    "        query (str): The query string to process.\n",
    "        faiss_index (FAISS): The FAISS index used for document retrieval.\n",
    "\n",
    "    Returns:\n",
    "        str: The generated response based on the query.\n",
    "    \"\"\"\n",
    "    print(f\"\\nProcessing query: {query}\")\n",
    "\n",
    "    # Retrieve and evaluate documents\n",
    "    retrieved_docs = retrieve_documents(query, faiss_index)\n",
    "    eval_scores = evaluate_documents(query, retrieved_docs)\n",
    "\n",
    "    print(f\"\\nRetrieved {len(retrieved_docs)} documents\")\n",
    "    print(f\"Evaluation scores: {eval_scores}\")\n",
    "\n",
    "    # Determine action based on evaluation scores\n",
    "    max_score = max(eval_scores)\n",
    "    sources = []\n",
    "\n",
    "    if max_score > 0.7:\n",
    "        print(\"\\nAction: Correct - Using retrieved document\")\n",
    "        best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
    "        final_knowledge = best_doc\n",
    "        sources.append((\"Retrieved document\", \"\"))\n",
    "    elif max_score < 0.3:\n",
    "        print(\"\\nAction: Incorrect - Performing web search\")\n",
    "        # perform_web_search returns a list of key points; join them into a single string\n",
    "        web_knowledge, sources = perform_web_search(query)\n",
    "        final_knowledge = \"\\n\".join(web_knowledge)\n",
    "    else:\n",
    "        print(\"\\nAction: Ambiguous - Combining retrieved document and web search\")\n",
    "        best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
    "        # Refine the retrieved knowledge\n",
    "        retrieved_knowledge = knowledge_refinement(best_doc)\n",
    "        web_knowledge, web_sources = perform_web_search(query)\n",
    "        final_knowledge = \"\\n\".join(retrieved_knowledge + web_knowledge)\n",
    "        sources = [(\"Retrieved document\", \"\")] + web_sources\n",
    "\n",
    "    print(\"\\nFinal knowledge:\")\n",
    "    print(final_knowledge)\n",
    "\n",
    "    print(\"\\nSources:\")\n",
    "    for title, link in sources:\n",
    "        print(f\"{title}: {link}\" if link else title)\n",
    "\n",
    "    # Generate response\n",
    "    print(\"\\nGenerating response...\")\n",
    "    response = generate_response(query, final_knowledge, sources)\n",
    "\n",
    "    print(\"\\nResponse generated\")\n",
    "    return response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example query with high relevance to the document\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What are the main causes of climate change?\"\n",
    "result = crag_process(query, vectorstore)\n",
    "print(f\"Query: {query}\")\n",
    "print(f\"Answer: {result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example query with low relevance to the document\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"How did Harry beat Quirrell?\"\n",
    "result = crag_process(query, vectorstore)\n",
    "print(f\"Query: {query}\")\n",
    "print(f\"Answer: {result}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

@@ -62,7 +62,7 @@
     "source": [
     "<div style=\"text-align: center;\">\n",
     "\n",
-    "<img src=\"../images/self_rag.svg\" alt=\"self RAG\" style=\"width:80%; height:auto;\">\n",
+    "<img src=\"../images/self_rag.svg\" alt=\"Self RAG\" style=\"width:80%; height:auto;\">\n",
     "</div>"
     ]
    },

images/crag.svg (new file, 3 lines)
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 27 KiB