diff --git a/all_rag_techniques/crag.ipynb b/all_rag_techniques/crag.ipynb new file mode 100644 index 0000000..2bf9135 --- /dev/null +++ b/all_rag_techniques/crag.ipynb @@ -0,0 +1,400 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import relevant libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from dotenv import load_dotenv\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain_openai import ChatOpenAI\n", + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "\n", + "\n", + "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks\n", + "from helper_functions import *\n", + "from evaluation.evalute_rag import *\n", + "\n", + "# Load environment variables from a .env file\n", + "load_dotenv()\n", + "\n", + "# Set the OpenAI API key environment variable\n", + "os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')\n", + "from langchain.tools import DuckDuckGoSearchResults\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define files path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../data/Understanding_Climate_Change.pdf\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a vector store" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "vectorstore = encode_pdf(path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize OpenAI language model\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(model=\"gpt-4o-mini\", max_tokens=1000, temperature=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize search tool" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "search = DuckDuckGoSearchResults()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define retrieval evaluator, knowledge refinement and query rewriter llm chains" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Retrieval Evaluator\n", + "class RetrievalEvaluatorInput(BaseModel):\n", + " relevance_score: float = Field(..., description=\"The relevance score of the document to the query. the score should be between 0 and 1.\")\n", + "def retrieval_evaluator(query: str, document: str) -> float:\n", + " prompt = PromptTemplate(\n", + " input_variables=[\"query\", \"document\"],\n", + " template=\"On a scale from 0 to 1, how relevant is the following document to the query? Query: {query}\\nDocument: {document}\\nRelevance score:\"\n", + " )\n", + " chain = prompt | llm.with_structured_output(RetrievalEvaluatorInput)\n", + " input_variables = {\"query\": query, \"document\": document}\n", + " result = chain.invoke(input_variables).relevance_score\n", + " return result\n", + "\n", + "# Knowledge Refinement\n", + "class KnowledgeRefinementInput(BaseModel):\n", + " key_points: str = Field(..., description=\"The document to extract key information from.\")\n", + "def knowledge_refinement(document: str) -> List[str]:\n", + " prompt = PromptTemplate(\n", + " input_variables=[\"document\"],\n", + " template=\"Extract the key information from the following document in bullet points:\\n{document}\\nKey points:\"\n", + " )\n", + " chain = prompt | llm.with_structured_output(KnowledgeRefinementInput)\n", + " input_variables = {\"document\": document}\n", + " result = chain.invoke(input_variables).key_points\n", + " return [point.strip() for point in result.split('\\n') if point.strip()]\n", + "\n", + "# Web Search Query Rewriter\n", + "class QueryRewriterInput(BaseModel):\n", + " query: str = Field(..., description=\"The query to rewrite.\")\n", + "def rewrite_query(query: str) -> str:\n", + " prompt = PromptTemplate(\n", + " input_variables=[\"query\"],\n", + " template=\"Rewrite the following query to make it more suitable for a web search:\\n{query}\\nRewritten query:\"\n", + " )\n", + " chain = prompt | llm.with_structured_output(QueryRewriterInput)\n", + " input_variables = {\"query\": query}\n", + " return chain.invoke(input_variables).query.strip()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Helper function to parse search results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def parse_search_results(results_string: str) -> List[Tuple[str, str]]:\n", + " \"\"\"\n", + " Parse a JSON string of search results into a list of title-link tuples.\n", + "\n", + " Args:\n", + " results_string (str): A JSON-formatted string containing search results.\n", + "\n", + " Returns:\n", + " List[Tuple[str, str]]: A list of tuples, where each tuple contains the title and link of a search result.\n", + " If parsing fails, an empty list is returned.\n", + " \"\"\"\n", + " try:\n", + " # Attempt to parse the JSON string\n", + " results = json.loads(results_string)\n", + " # Extract and return the title and link from each result\n", + " return [(result.get('title', 'Untitled'), result.get('link', '')) for result in results]\n", + " except json.JSONDecodeError:\n", + " # Handle JSON decoding errors by returning an empty list\n", + " print(\"Error parsing search results. Returning empty list.\")\n", + " return []" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define sub functions for the CRAG process" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "def retrieve_documents(query: str, faiss_index: FAISS, k: int = 3) -> List[str]:\n", + " \"\"\"\n", + " Retrieve documents based on a query using a FAISS index.\n", + "\n", + " Args:\n", + " query (str): The query string to search for.\n", + " faiss_index (FAISS): The FAISS index used for similarity search.\n", + " k (int): The number of top documents to retrieve. Defaults to 3.\n", + "\n", + " Returns:\n", + " List[str]: A list of the retrieved document contents.\n", + " \"\"\"\n", + " docs = faiss_index.similarity_search(query, k=k)\n", + " return [doc.page_content for doc in docs]\n", + "\n", + "def evaluate_documents(query: str, documents: List[str]) -> List[float]:\n", + " \"\"\"\n", + " Evaluate the relevance of documents based on a query.\n", + "\n", + " Args:\n", + " query (str): The query string.\n", + " documents (List[str]): A list of document contents to evaluate.\n", + "\n", + " Returns:\n", + " List[float]: A list of relevance scores for each document.\n", + " \"\"\"\n", + " return [retrieval_evaluator(query, doc) for doc in documents]\n", + "\n", + "def perform_web_search(query: str) -> Tuple[List[str], List[Tuple[str, str]]]:\n", + " \"\"\"\n", + " Perform a web search based on a query.\n", + "\n", + " Args:\n", + " query (str): The query string to search for.\n", + "\n", + " Returns:\n", + " Tuple[List[str], List[Tuple[str, str]]]: \n", + " - A list of refined knowledge obtained from the web search.\n", + " - A list of tuples containing titles and links of the sources.\n", + " \"\"\"\n", + " rewritten_query = rewrite_query(query)\n", + " web_results = search.run(rewritten_query)\n", + " web_knowledge = knowledge_refinement(web_results)\n", + " sources = parse_search_results(web_results)\n", + " return web_knowledge, sources\n", + "\n", + "def generate_response(query: str, knowledge: str, sources: List[Tuple[str, str]]) -> str:\n", + " \"\"\"\n", + " Generate a response to a query using knowledge and sources.\n", + "\n", + " Args:\n", + " query (str): The query string.\n", + " knowledge (str): The refined knowledge to use in the response.\n", + " sources (List[Tuple[str, str]]): A list of tuples containing titles and links of the sources.\n", + "\n", + " Returns:\n", + " str: The generated response.\n", + " \"\"\"\n", + " response_prompt = PromptTemplate(\n", + " input_variables=[\"query\", \"knowledge\", \"sources\"],\n", + " template=\"Based on the following knowledge, answer the query. Include the sources with their links (if available) at the end of your answer:\\nQuery: {query}\\nKnowledge: {knowledge}\\nSources: {sources}\\nAnswer:\"\n", + " )\n", + " input_variables = {\n", + " \"query\": query,\n", + " \"knowledge\": knowledge,\n", + " \"sources\": \"\\n\".join([f\"{title}: {link}\" if link else title for title, link in sources])\n", + " }\n", + " response_chain = response_prompt | llm\n", + " return response_chain.invoke(input_variables).content\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CRAG process\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "def crag_process(query: str, faiss_index: FAISS) -> str:\n", + " \"\"\"\n", + " Process a query by retrieving, evaluating, and using documents or performing a web search to generate a response.\n", + "\n", + " Args:\n", + " query (str): The query string to process.\n", + " faiss_index (FAISS): The FAISS index used for document retrieval.\n", + "\n", + " Returns:\n", + " str: The generated response based on the query.\n", + " \"\"\"\n", + " print(f\"\\nProcessing query: {query}\")\n", + "\n", + " # Retrieve and evaluate documents\n", + " retrieved_docs = retrieve_documents(query, faiss_index)\n", + " eval_scores = evaluate_documents(query, retrieved_docs)\n", + " \n", + " print(f\"\\nRetrieved {len(retrieved_docs)} documents\")\n", + " print(f\"Evaluation scores: {eval_scores}\")\n", + "\n", + " # Determine action based on evaluation scores\n", + " max_score = max(eval_scores)\n", + " sources = []\n", + " \n", + " if max_score > 0.7:\n", + " print(\"\\nAction: Correct - Using retrieved document\")\n", + " best_doc = retrieved_docs[eval_scores.index(max_score)]\n", + " final_knowledge = best_doc\n", + " sources.append((\"Retrieved document\", \"\"))\n", + " elif max_score < 0.3:\n", + " print(\"\\nAction: Incorrect - Performing web search\")\n", + " final_knowledge, sources = perform_web_search(query)\n", + " else:\n", + " print(\"\\nAction: Ambiguous - Combining retrieved document and web search\")\n", + " best_doc = retrieved_docs[eval_scores.index(max_score)]\n", + " # Refine the retrieved knowledge\n", + " retrieved_knowledge = knowledge_refinement(best_doc)\n", + " web_knowledge, web_sources = perform_web_search(query)\n", + " final_knowledge = \"\\n\".join(retrieved_knowledge + web_knowledge)\n", + " sources = [(\"Retrieved document\", \"\")] + web_sources\n", + "\n", + " print(\"\\nFinal knowledge:\")\n", + " print(final_knowledge)\n", + " \n", + " print(\"\\nSources:\")\n", + " for title, link in sources:\n", + " print(f\"{title}: {link}\" if link else title)\n", + "\n", + " # Generate response\n", + " print(\"\\nGenerating response...\")\n", + " response = generate_response(query, final_knowledge, sources)\n", + "\n", + " print(\"\\nResponse generated\")\n", + " return response" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example query with high relevance to the document\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What are the main causes of climate change?\"\n", + "result = crag_process(query, vectorstore)\n", + "print(f\"Query: {query}\")\n", + "print(f\"Answer: {result}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example query with low relevance to the document\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"how did harry beat quirrell?\"\n", + "result = crag_process(query, vectorstore)\n", + "print(f\"Query: {query}\")\n", + "print(f\"Answer: {result}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/all_rag_techniques/self_rag.ipynb b/all_rag_techniques/self_rag.ipynb index 98f869a..badbcfa 100644 --- a/all_rag_techniques/self_rag.ipynb +++ b/all_rag_techniques/self_rag.ipynb @@ -62,7 +62,7 @@ "source": [ "
\n", "\n", - "\"self\n", + "\"Self\n", "
" ] }, diff --git a/images/crag.svg b/images/crag.svg new file mode 100644 index 0000000..585e839 --- /dev/null +++ b/images/crag.svg @@ -0,0 +1,3 @@ + + +
Response Generation
Knowledge Refinement
Web Search Process
Yes
No
Yes
No
Combine Knowledge and Sources
Generate Response
Generate Final Response
Extract Key Information
Refine Knowledge
Rewrite Query
Perform Web Search
Perform Search
Refine Web Knowledge
Parse Search Results
Start: Query Input
Retrieve Documents from FAISS Index
Evaluate Document Relevance
Max Score > 0.7?
Use Retrieved Document
Max Score < 0.3?
Combine Retrieved Document and Web Search
End: Return Response
\ No newline at end of file