{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Overview: \n",
    "This code implements one approach to multi-modal RAG: it extracts and processes text and images from PDFs, then uses a multi-modal Retrieval-Augmented Generation (RAG) pipeline to summarize and retrieve content for question answering.\n",
    "\n",
    "### Key Components:\n",
    "   - **PyMuPDF**: For extracting text and images from PDFs.\n",
    "   - **Gemini 1.5 Flash model**: To summarize images and tables.\n",
    "   - **Cohere Embeddings**: For embedding document splits.\n",
    "   - **Chroma Vectorstore**: To store and retrieve document embeddings.\n",
    "   - **LangChain**: To orchestrate the retrieval and generation pipeline."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Diagram:\n",
    "   <img src=\"../images/multi_model_rag_with_captioning.svg\" alt=\"multi-modal-RAG-with-captioning\" width=\"300\">\n",
    "\n",
    "### Motivation: \n",
    "Efficiently summarize complex documents to facilitate easy retrieval and concise responses for multi-modal data.\n",
    "\n",
    "### Method Details:\n",
    "   - Text and images are extracted from the PDF using PyMuPDF.\n",
    "   - Summarization is performed on extracted images and tables using Gemini.\n",
    "   - Embeddings are generated via Cohere for storage in Chroma.\n",
    "   - A similarity-based retriever fetches relevant sections based on the query.\n",
    "\n",
    "### Benefits:\n",
    "   - Simplified retrieval from complex, multi-modal documents.\n",
    "   - Streamlined Q&A process for both text and images.\n",
    "   - Flexible architecture for expanding to more document types.\n",
    "\n",
    "### Implementation:\n",
    "   - Documents are split into chunks with overlap using a text splitter.\n",
    "   - Summarized text and image content are stored as vectors.\n",
    "   - Queries are handled by retrieving relevant document segments and generating concise answers.\n",
    "\n",
    "### Summary: \n",
    "The project enables multi-modal document processing and retrieval, providing concise, relevant responses by combining state-of-the-art LLMs and vector-based retrieval systems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import fitz  # PyMuPDF\n",
    "from PIL import Image\n",
    "import io\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "import google.generativeai as genai\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.documents import Document\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain_cohere import ChatCohere, CohereEmbeddings\n",
    "\n",
    "# Load API keys (GOOGLE_API_KEY, COHERE_API_KEY) from a .env file\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the \"Attention is all you need\" paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-09-20 19:19:26--  https://arxiv.org/pdf/1706.03762\n",
      "Resolving arxiv.org (arxiv.org)... 151.101.195.42, 151.101.3.42, 151.101.67.42, ...\n",
      "Connecting to arxiv.org (arxiv.org)|151.101.195.42|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 2215244 (2.1M) [application/pdf]\n",
      "Saving to: ‘1706.03762’\n",
      "\n",
      "1706.03762          100%[===================>]   2.11M  13.3MB/s    in 0.2s    \n",
      "\n",
      "2024-09-20 19:19:26 (13.3 MB/s) - ‘1706.03762’ saved [2215244/2215244]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://arxiv.org/pdf/1706.03762\n",
    "!mv 1706.03762 attention_is_all_you_need.pdf"
   ]
  },
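  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `wget` is unavailable (e.g. on Windows), the cell below is a minimal sketch that downloads the paper with the `requests` library instead; it assumes `requests` is installed and arxiv.org is reachable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "# Sketch: download the paper without relying on wget\n",
    "url = \"https://arxiv.org/pdf/1706.03762\"\n",
    "response = requests.get(url, timeout=60)\n",
    "response.raise_for_status()  # fail loudly on HTTP errors\n",
    "with open(\"attention_is_all_you_need.pdf\", \"wb\") as f:\n",
    "    f.write(response.content)"
   ]
  },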
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_data = []  # page-level text, one entry per page\n",
    "img_data = []   # Gemini-generated summaries of the extracted images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with fitz.open('attention_is_all_you_need.pdf') as pdf_file:\n",
    "    # Create a directory to store the images\n",
    "    if not os.path.exists(\"extracted_images\"):\n",
    "        os.makedirs(\"extracted_images\")\n",
    "\n",
    "    # Loop through every page in the PDF\n",
    "    for page_number in range(len(pdf_file)):\n",
    "        page = pdf_file[page_number]\n",
    "\n",
    "        # Get the text on the page\n",
    "        text = page.get_text().strip()\n",
    "        text_data.append({\"response\": text, \"name\": page_number+1})\n",
    "\n",
    "        # Get the list of images on the page\n",
    "        images = page.get_images(full=True)\n",
    "\n",
    "        # Loop through all images found on the page\n",
    "        for image_index, img in enumerate(images, start=1):\n",
    "            xref = img[0]  # Get the XREF of the image\n",
    "            base_image = pdf_file.extract_image(xref)  # Extract the image\n",
    "            image_bytes = base_image[\"image\"]  # Get the image bytes\n",
    "            image_ext = base_image[\"ext\"]  # Get the image extension\n",
    "\n",
    "            # Load the image using PIL and save it\n",
    "            image = Image.open(io.BytesIO(image_bytes))\n",
    "            image.save(f\"extracted_images/image_{page_number+1}_{image_index}.{image_ext}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure Gemini with the API key loaded from .env\n",
    "genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))\n",
    "model = genai.GenerativeModel(model_name=\"gemini-1.5-flash\")"
   ]
  },
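  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small guard, as a sketch: if `GOOGLE_API_KEY` was not loaded from `.env`, the captioning calls below fail with an opaque authentication error, so it can help to fail early with a clear message. This check is an addition, not part of the original flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: fail early with a clear message if the key is missing\n",
    "if not os.getenv(\"GOOGLE_API_KEY\"):\n",
    "    raise EnvironmentError(\"GOOGLE_API_KEY not found; add it to your .env file\")"
   ]
  },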
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image Captioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for img in os.listdir(\"extracted_images\"):\n",
    "    image = Image.open(f\"extracted_images/{img}\")\n",
    "    response = model.generate_content([image, \"You are an assistant tasked with summarizing tables, images and text for retrieval. \\\n",
    "    These summaries will be embedded and used to retrieve the raw text or table elements. \\\n",
    "    Give a concise summary of the table or text that is well optimized for retrieval. Table or text or image:\"])\n",
    "    img_data.append({\"response\": response.text, \"name\": img})"
   ]
  },
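  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above assumes every file in `extracted_images` opens cleanly and every Gemini call succeeds. The variant below is a minimal sketch with retry logic; the three attempts and five-second backoff are illustrative choices, not part of the original notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Sketch: captioning loop with minimal error handling.\n",
    "# The retry count and backoff duration are illustrative assumptions.\n",
    "caption_prompt = (\n",
    "    \"You are an assistant tasked with summarizing tables, images and text for retrieval. \"\n",
    "    \"These summaries will be embedded and used to retrieve the raw text or table elements. \"\n",
    "    \"Give a concise summary of the table or text that is well optimized for retrieval. \"\n",
    "    \"Table or text or image:\"\n",
    ")\n",
    "\n",
    "for img in os.listdir(\"extracted_images\"):\n",
    "    image = Image.open(f\"extracted_images/{img}\")\n",
    "    for attempt in range(3):\n",
    "        try:\n",
    "            response = model.generate_content([image, caption_prompt])\n",
    "            img_data.append({\"response\": response.text, \"name\": img})\n",
    "            break\n",
    "        except Exception as e:\n",
    "            print(f\"Attempt {attempt + 1} failed for {img}: {e}\")\n",
    "            time.sleep(5)"
   ]
  },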
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorstore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set embeddings\n",
    "embedding_model = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
    "\n",
    "# Wrap the extracted page text and image summaries as Documents\n",
    "docs_list = [Document(page_content=text['response'], metadata={\"name\": text['name']}) for text in text_data]\n",
    "img_list = [Document(page_content=img['response'], metadata={\"name\": img['name']}) for img in img_data]\n",
    "\n",
    "# Split into overlapping chunks, sized in tiktoken tokens\n",
    "text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
    "    chunk_size=400, chunk_overlap=50\n",
    ")\n",
    "\n",
    "doc_splits = text_splitter.split_documents(docs_list)\n",
    "img_splits = text_splitter.split_documents(img_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add to vectorstore\n",
    "vectorstore = Chroma.from_documents(\n",
    "    documents=doc_splits + img_splits,  # add both the text and image splits\n",
    "    collection_name=\"multi_model_rag\",\n",
    "    embedding=embedding_model,\n",
    ")\n",
    "\n",
    "retriever = vectorstore.as_retriever(\n",
    "    search_type=\"similarity\",\n",
    "    search_kwargs={'k': 1},  # number of documents to retrieve\n",
    ")"
   ]
  },
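  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The collection above lives in memory and disappears when the kernel restarts. As a sketch, Chroma's `persist_directory` argument writes the collection to disk so it can be reloaded without re-embedding; the `chroma_db` path is an illustrative choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: persist the collection to disk (the path is illustrative)\n",
    "persistent_vectorstore = Chroma.from_documents(\n",
    "    documents=doc_splits + img_splits,\n",
    "    collection_name=\"multi_model_rag\",\n",
    "    embedding=embedding_model,\n",
    "    persist_directory=\"chroma_db\",\n",
    ")\n",
    "\n",
    "# Reload later without re-embedding\n",
    "reloaded_vectorstore = Chroma(\n",
    "    collection_name=\"multi_model_rag\",\n",
    "    embedding_function=embedding_model,\n",
    "    persist_directory=\"chroma_db\",\n",
    ")"
   ]
  },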
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is the BLEU score of the Transformer (base model)?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs = retriever.invoke(query)"
   ]
  },
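  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before generating an answer it can help to check what was actually retrieved; with `k=1` above, `docs` holds a single chunk. A quick inspection:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect each retrieved chunk: its source name and a content preview\n",
    "for doc in docs:\n",
    "    print(doc.metadata.get(\"name\"), \"->\", doc.page_content[:200])"
   ]
  },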
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Transformer (base model) achieves a BLEU score of 27.3.\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "# Prompt\n",
    "system = \"\"\"You are an assistant for question-answering tasks. Answer the question based on the retrieved documents. \n",
    "Use three-to-five sentences maximum and keep the answer concise.\"\"\"\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", system),\n",
    "        (\"human\", \"Retrieved documents: \\n\\n <docs>{documents}</docs> \\n\\n User question: <question>{question}</question>\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# LLM\n",
    "llm = ChatCohere(model=\"command-r-plus\", temperature=0)\n",
    "\n",
    "# Chain\n",
    "rag_chain = prompt | llm | StrOutputParser()\n",
    "\n",
    "# Run\n",
    "generation = rag_chain.invoke({\"documents\": docs[0].page_content, \"question\": query})\n",
    "print(generation)"
   ]
  },
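  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a final sketch, the `name` metadata attached during extraction and captioning can be surfaced alongside the answer, so a reader can trace it back to a page number or image file. This cell is illustrative and was not part of the original pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: trace the answer back to its source via the \"name\" metadata\n",
    "source = docs[0].metadata.get(\"name\")\n",
    "print(f\"Answer: {generation}\")\n",
    "print(f\"Source (page number or image file): {source}\")"
   ]
  }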
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}