diff --git a/all_rag_techniques/crag.ipynb b/all_rag_techniques/crag.ipynb
new file mode 100644
index 0000000..2bf9135
--- /dev/null
+++ b/all_rag_techniques/crag.ipynb
@@ -0,0 +1,400 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Import relevant libraries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "import json  # used to parse the web search results\n",
+    "from typing import List, Tuple\n",
+    "from dotenv import load_dotenv\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+    "from langchain.tools import DuckDuckGoSearchResults\n",
+    "\n",
+    "\n",
+    "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks\n",
+    "from helper_functions import *\n",
+    "from evaluation.evalute_rag import *\n",
+    "\n",
+    "# Load environment variables from a .env file\n",
+    "load_dotenv()\n",
+    "\n",
+    "# Set the OpenAI API key environment variable\n",
+    "os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Define file path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "path = \"../data/Understanding_Climate_Change.pdf\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Create a vector store"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vectorstore = encode_pdf(path)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Initialize OpenAI language model\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = ChatOpenAI(model=\"gpt-4o-mini\", max_tokens=1000, temperature=0)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Initialize search tool"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "search = DuckDuckGoSearchResults()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Define retrieval evaluator, knowledge refinement and query rewriter LLM chains"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Retrieval Evaluator\n",
+    "class RetrievalEvaluatorInput(BaseModel):\n",
+    "    relevance_score: float = Field(..., description=\"The relevance score of the document to the query. The score should be between 0 and 1.\")\n",
+    "def retrieval_evaluator(query: str, document: str) -> float:\n",
+    "    prompt = PromptTemplate(\n",
+    "        input_variables=[\"query\", \"document\"],\n",
+    "        template=\"On a scale from 0 to 1, how relevant is the following document to the query? Query: {query}\\nDocument: {document}\\nRelevance score:\"\n",
+    "    )\n",
+    "    chain = prompt | llm.with_structured_output(RetrievalEvaluatorInput)\n",
+    "    input_variables = {\"query\": query, \"document\": document}\n",
+    "    result = chain.invoke(input_variables).relevance_score\n",
+    "    return result\n",
+    "\n",
+    "# Knowledge Refinement\n",
+    "class KnowledgeRefinementInput(BaseModel):\n",
+    "    key_points: str = Field(..., description=\"The key points extracted from the document.\")\n",
+    "def knowledge_refinement(document: str) -> List[str]:\n",
+    "    prompt = PromptTemplate(\n",
+    "        input_variables=[\"document\"],\n",
+    "        template=\"Extract the key information from the following document in bullet points:\\n{document}\\nKey points:\"\n",
+    "    )\n",
+    "    chain = prompt | llm.with_structured_output(KnowledgeRefinementInput)\n",
+    "    input_variables = {\"document\": document}\n",
+    "    result = chain.invoke(input_variables).key_points\n",
+    "    return [point.strip() for point in result.split('\\n') if point.strip()]\n",
+    "\n",
+    "# Web Search Query Rewriter\n",
+    "class QueryRewriterInput(BaseModel):\n",
+    "    query: str = Field(..., description=\"The query to rewrite.\")\n",
+    "def rewrite_query(query: str) -> str:\n",
+    "    prompt = PromptTemplate(\n",
+    "        input_variables=[\"query\"],\n",
+    "        template=\"Rewrite the following query to make it more suitable for a web search:\\n{query}\\nRewritten query:\"\n",
+    "    )\n",
+    "    chain = prompt | llm.with_structured_output(QueryRewriterInput)\n",
+    "    input_variables = {\"query\": query}\n",
+    "    return chain.invoke(input_variables).query.strip()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Helper function to parse search results\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def parse_search_results(results_string: str) -> List[Tuple[str, str]]:\n",
+    "    \"\"\"\n",
+    "    Parse a JSON string of search results into a list of title-link tuples.\n",
+    "\n",
+    "    Args:\n",
+    "        results_string (str): A JSON-formatted string containing search results.\n",
+    "\n",
+    "    Returns:\n",
+    "        List[Tuple[str, str]]: A list of tuples, where each tuple contains the title and link of a search result.\n",
+    "                               If parsing fails, an empty list is returned.\n",
+    "    \"\"\"\n",
+    "    try:\n",
+    "        # Attempt to parse the JSON string\n",
+    "        results = json.loads(results_string)\n",
+    "        # Extract and return the title and link from each result\n",
+    "        return [(result.get('title', 'Untitled'), result.get('link', '')) for result in results]\n",
+    "    except json.JSONDecodeError:\n",
+    "        # Handle JSON decoding errors by returning an empty list\n",
+    "        print(\"Error parsing search results. Returning empty list.\")\n",
+    "        return []"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Define sub-functions for the CRAG process"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def retrieve_documents(query: str, faiss_index: FAISS, k: int = 3) -> List[str]:\n",
+    "    \"\"\"\n",
+    "    Retrieve documents based on a query using a FAISS index.\n",
+    "\n",
+    "    Args:\n",
+    "        query (str): The query string to search for.\n",
+    "        faiss_index (FAISS): The FAISS index used for similarity search.\n",
+    "        k (int): The number of top documents to retrieve. Defaults to 3.\n",
+    "\n",
+    "    Returns:\n",
+    "        List[str]: A list of the retrieved document contents.\n",
+    "    \"\"\"\n",
+    "    docs = faiss_index.similarity_search(query, k=k)\n",
+    "    return [doc.page_content for doc in docs]\n",
+    "\n",
+    "def evaluate_documents(query: str, documents: List[str]) -> List[float]:\n",
+    "    \"\"\"\n",
+    "    Evaluate the relevance of documents based on a query.\n",
+    "\n",
+    "    Args:\n",
+    "        query (str): The query string.\n",
+    "        documents (List[str]): A list of document contents to evaluate.\n",
+    "\n",
+    "    Returns:\n",
+    "        List[float]: A list of relevance scores for each document.\n",
+    "    \"\"\"\n",
+    "    return [retrieval_evaluator(query, doc) for doc in documents]\n",
+    "\n",
+    "def perform_web_search(query: str) -> Tuple[List[str], List[Tuple[str, str]]]:\n",
+    "    \"\"\"\n",
+    "    Perform a web search based on a query.\n",
+    "\n",
+    "    Args:\n",
+    "        query (str): The query string to search for.\n",
+    "\n",
+    "    Returns:\n",
+    "        Tuple[List[str], List[Tuple[str, str]]]: \n",
+    "            - A list of refined knowledge obtained from the web search.\n",
+    "            - A list of tuples containing titles and links of the sources.\n",
+    "    \"\"\"\n",
+    "    rewritten_query = rewrite_query(query)\n",
+    "    web_results = search.run(rewritten_query)\n",
+    "    web_knowledge = knowledge_refinement(web_results)\n",
+    "    sources = parse_search_results(web_results)\n",
+    "    return web_knowledge, sources\n",
+    "\n",
+    "def generate_response(query: str, knowledge: str, sources: List[Tuple[str, str]]) -> str:\n",
+    "    \"\"\"\n",
+    "    Generate a response to a query using knowledge and sources.\n",
+    "\n",
+    "    Args:\n",
+    "        query (str): The query string.\n",
+    "        knowledge (str): The refined knowledge to use in the response.\n",
+    "        sources (List[Tuple[str, str]]): A list of tuples containing titles and links of the sources.\n",
+    "\n",
+    "    Returns:\n",
+    "        str: The generated response.\n",
+    "    \"\"\"\n",
+    "    response_prompt = PromptTemplate(\n",
+    "        input_variables=[\"query\", \"knowledge\", \"sources\"],\n",
+    "        template=\"Based on the following knowledge, answer the query. Include the sources with their links (if available) at the end of your answer:\\nQuery: {query}\\nKnowledge: {knowledge}\\nSources: {sources}\\nAnswer:\"\n",
+    "    )\n",
+    "    input_variables = {\n",
+    "        \"query\": query,\n",
+    "        \"knowledge\": knowledge,\n",
+    "        \"sources\": \"\\n\".join([f\"{title}: {link}\" if link else title for title, link in sources])\n",
+    "    }\n",
+    "    response_chain = response_prompt | llm\n",
+    "    return response_chain.invoke(input_variables).content\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### CRAG process\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def crag_process(query: str, faiss_index: FAISS) -> str:\n",
+    "    \"\"\"\n",
+    "    Process a query by retrieving, evaluating, and using documents or performing a web search to generate a response.\n",
+    "\n",
+    "    Args:\n",
+    "        query (str): The query string to process.\n",
+    "        faiss_index (FAISS): The FAISS index used for document retrieval.\n",
+    "\n",
+    "    Returns:\n",
+    "        str: The generated response based on the query.\n",
+    "    \"\"\"\n",
+    "    print(f\"\\nProcessing query: {query}\")\n",
+    "\n",
+    "    # Retrieve and evaluate documents\n",
+    "    retrieved_docs = retrieve_documents(query, faiss_index)\n",
+    "    eval_scores = evaluate_documents(query, retrieved_docs)\n",
+    "    \n",
+    "    print(f\"\\nRetrieved {len(retrieved_docs)} documents\")\n",
+    "    print(f\"Evaluation scores: {eval_scores}\")\n",
+    "\n",
+    "    # Determine action based on evaluation scores\n",
+    "    max_score = max(eval_scores)\n",
+    "    sources = []\n",
+    "    \n",
+    "    if max_score > 0.7:\n",
+    "        print(\"\\nAction: Correct - Using retrieved document\")\n",
+    "        best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
+    "        final_knowledge = best_doc\n",
+    "        sources.append((\"Retrieved document\", \"\"))\n",
+    "    elif max_score < 0.3:\n",
+    "        print(\"\\nAction: Incorrect - Performing web search\")\n",
+    "        final_knowledge, sources = perform_web_search(query)\n",
+    "        final_knowledge = \"\\n\".join(final_knowledge)  # knowledge_refinement returns a list of bullet points\n",
+    "    else:\n",
+    "        print(\"\\nAction: Ambiguous - Combining retrieved document and web search\")\n",
+    "        best_doc = retrieved_docs[eval_scores.index(max_score)]\n",
+    "        # Refine the retrieved knowledge\n",
+    "        retrieved_knowledge = knowledge_refinement(best_doc)\n",
+    "        web_knowledge, web_sources = perform_web_search(query)\n",
+    "        final_knowledge = \"\\n\".join(retrieved_knowledge + web_knowledge)\n",
+    "        sources = [(\"Retrieved document\", \"\")] + web_sources\n",
+    "\n",
+    "    print(\"\\nFinal knowledge:\")\n",
+    "    print(final_knowledge)\n",
+    "    \n",
+    "    print(\"\\nSources:\")\n",
+    "    for title, link in sources:\n",
+    "        print(f\"{title}: {link}\" if link else title)\n",
+    "\n",
+    "    # Generate response\n",
+    "    print(\"\\nGenerating response...\")\n",
+    "    response = generate_response(query, final_knowledge, sources)\n",
+    "\n",
+    "    print(\"\\nResponse generated\")\n",
+    "    return response"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Example query with high relevance to the document\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query = \"What are the main causes of climate change?\"\n",
+    "result = crag_process(query, vectorstore)\n",
+    "print(f\"Query: {query}\")\n",
+    "print(f\"Answer: {result}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Example query with low relevance to the document\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
"outputs": [], + "source": [ + "query = \"how did harry beat quirrell?\"\n", + "result = crag_process(query, vectorstore)\n", + "print(f\"Query: {query}\")\n", + "print(f\"Answer: {result}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/all_rag_techniques/self_rag.ipynb b/all_rag_techniques/self_rag.ipynb index 98f869a..badbcfa 100644 --- a/all_rag_techniques/self_rag.ipynb +++ b/all_rag_techniques/self_rag.ipynb @@ -62,7 +62,7 @@ "source": [ "