{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Overview: \n",
"This code implements one of the multiple ways of multi-model RAG. It extracts and processes text and images from PDFs, utilizing a multi-modal Retrieval-Augmented Generation (RAG) system for summarizing and retrieving content for question answering.\n",
"\n",
"### Key Components:\n",
" - **PyMuPDF**: For extracting text and images from PDFs.\n",
" - **Gemini 1.5-flash model**: To summarize images and tables.\n",
" - **Cohere Embeddings**: For embedding document splits.\n",
" - **Chroma Vectorstore**: To store and retrieve document embeddings.\n",
" - **LangChain**: To orchestrate the retrieval and generation pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Diagram:\n",
" <img src=\"../images/multi_model_rag_with_captioning.svg\" alt=\"Reliable-RAG\" width=\"300\">\n",
"\n",
"### Motivation: \n",
"Efficiently summarize complex documents to facilitate easy retrieval and concise responses for multi-modal data.\n",
"\n",
"### Method Details:\n",
" - Text and images are extracted from the PDF using PyMuPDF.\n",
" - Summarization is performed on extracted images and tables using Gemini.\n",
" - Embeddings are generated via Cohere for storage in Chroma.\n",
" - A similarity-based retriever fetches relevant sections based on the query.\n",
"\n",
"### Benefits:\n",
" - Simplified retrieval from complex, multi-modal documents.\n",
" - Streamlined Q&A process for both text and images.\n",
" - Flexible architecture for expanding to more document types.\n",
"\n",
"### Implementation:\n",
" - Documents are split into chunks with overlap using a text splitter.\n",
" - Summarized text and image content are stored as vectors.\n",
" - Queries are handled by retrieving relevant document segments and generating concise answers.\n",
"\n",
"### Summary: \n",
"The project enables multi-modal document processing and retrieval, providing concise, relevant responses by combining state-of-the-art LLMs and vector-based retrieval systems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
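{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional setup: install the dependencies imported below.\n",
"# The package names are the usual PyPI names for these imports (assumed here;\n",
"# adjust to your environment if they differ).\n",
"%pip install -q pymupdf pillow python-dotenv google-generativeai langchain langchain-community langchain-cohere chromadb tiktoken"
]
},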
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import fitz # PyMuPDF\n",
"from PIL import Image\n",
"import io\n",
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"import google.generativeai as genai\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.documents import Document\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_cohere import ChatCohere, CohereEmbeddings\n",
"\n",
"load_dotenv()"
]
},
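{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.env` file loaded above is expected to provide the API keys used later in the notebook: `GOOGLE_API_KEY` for the Gemini captioning model and a Cohere key (typically `COHERE_API_KEY`) for the embeddings and chat model."
]
},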
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download the \"Attention is all you need\" paper"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2024-09-20 19:19:26-- https://arxiv.org/pdf/1706.03762\n",
"Resolving arxiv.org (arxiv.org)... 151.101.195.42, 151.101.3.42, 151.101.67.42, ...\n",
"Connecting to arxiv.org (arxiv.org)|151.101.195.42|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2215244 (2.1M) [application/pdf]\n",
"Saving to: 1706.03762\n",
"\n",
"1706.03762 100%[===================>] 2.11M 13.3MB/s in 0.2s \n",
"\n",
"2024-09-20 19:19:26 (13.3 MB/s) - 1706.03762 saved [2215244/2215244]\n",
"\n"
]
}
],
"source": [
"!wget https://arxiv.org/pdf/1706.03762\n",
"!mv 1706.03762 attention_is_all_you_need.pdf"
]
},
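{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative to the wget cell above for environments without wget (e.g. Windows):\n",
"# download the same PDF with Python's standard library.\n",
"import urllib.request\n",
"\n",
"if not os.path.exists(\"attention_is_all_you_need.pdf\"):\n",
"    urllib.request.urlretrieve(\"https://arxiv.org/pdf/1706.03762\", \"attention_is_all_you_need.pdf\")"
]
},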
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Extraction"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"text_data = []\n",
"img_data = []"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"with fitz.open('attention_is_all_you_need.pdf') as pdf_file:\n",
" # Create a directory to store the images\n",
" if not os.path.exists(\"extracted_images\"):\n",
" os.makedirs(\"extracted_images\")\n",
"\n",
" # Loop through every page in the PDF\n",
" for page_number in range(len(pdf_file)):\n",
" page = pdf_file[page_number]\n",
" \n",
" # Get the text on page\n",
" text = page.get_text().strip()\n",
" text_data.append({\"response\": text, \"name\": page_number+1})\n",
" # Get the list of images on the page\n",
" images = page.get_images(full=True)\n",
"\n",
" # Loop through all images found on the page\n",
" for image_index, img in enumerate(images, start=0):\n",
" xref = img[0] # Get the XREF of the image\n",
" base_image = pdf_file.extract_image(xref) # Extract the image\n",
" image_bytes = base_image[\"image\"] # Get the image bytes\n",
" image_ext = base_image[\"ext\"] # Get the image extension\n",
" \n",
" # Load the image using PIL and save it\n",
" image = Image.open(io.BytesIO(image_bytes))\n",
" image.save(f\"extracted_images/image_{page_number+1}_{image_index+1}.{image_ext}\")"
]
},
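{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check on the extraction step: one text entry per page is expected,\n",
"# plus the image files written to extracted_images/.\n",
"print(f\"Extracted text from {len(text_data)} pages\")\n",
"print(f\"Saved {len(os.listdir('extracted_images'))} images to extracted_images/\")"
]
},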
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))\n",
"model = genai.GenerativeModel(model_name=\"gemini-1.5-flash\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Image Captioning"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"for img in os.listdir(\"extracted_images\"):\n",
" image = Image.open(f\"extracted_images/{img}\")\n",
" response = model.generate_content([image, \"You are an assistant tasked with summarizing tables, images and text for retrieval. \\\n",
" These summaries will be embedded and used to retrieve the raw text or table elements \\\n",
" Give a concise summary of the table or text that is well optimized for retrieval. Table or text or image:\"])\n",
" img_data.append({\"response\": response.text, \"name\": img})"
]
},
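{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: inspect one generated caption to confirm the summaries look reasonable\n",
"# before they are embedded (the exact wording varies between runs).\n",
"if img_data:\n",
"    print(img_data[0][\"name\"])\n",
"    print(img_data[0][\"response\"])"
]
},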
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectostore"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Set embeddings\n",
"embedding_model = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
"\n",
"# Load the document\n",
"docs_list = [Document(page_content=text['response'], metadata={\"name\": text['name']}) for text in text_data]\n",
"img_list = [Document(page_content=img['response'], metadata={\"name\": img['name']}) for img in img_data]\n",
"\n",
"# Split\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
" chunk_size=400, chunk_overlap=50\n",
")\n",
"\n",
"doc_splits = text_splitter.split_documents(docs_list)\n",
"img_splits = text_splitter.split_documents(img_list)"
]
},
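{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: check how many chunks will be embedded from each source.\n",
"print(f\"{len(doc_splits)} text chunks, {len(img_splits)} caption chunks\")"
]
},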
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Add to vectorstore\n",
"vectorstore = Chroma.from_documents(\n",
" documents=doc_splits + img_splits, # adding the both text and image splits\n",
" collection_name=\"multi_model_rag\",\n",
" embedding=embedding_model,\n",
")\n",
"\n",
"retriever = vectorstore.as_retriever(\n",
" search_type=\"similarity\",\n",
" search_kwargs={'k': 1}, # number of documents to retrieve\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Query"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"query = \"What is the BLEU score of the Transformer (base model)?\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.invoke(query)"
]
},
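{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: inspect what the retriever returned before generating the final answer.\n",
"# With k=1 this is a single document; its metadata indicates whether it came from a\n",
"# page of text (page number) or an image caption (image file name).\n",
"for doc in docs:\n",
"    print(doc.metadata)\n",
"    print(doc.page_content[:300])"
]
},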
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Output"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Transformer (base model) achieves a BLEU score of 27.3.\n"
]
}
],
"source": [
"from langchain_core.output_parsers import StrOutputParser\n",
"\n",
"# Prompt\n",
"system = \"\"\"You are an assistant for question-answering tasks. Answer the question based upon your knowledge. \n",
"Use three-to-five sentences maximum and keep the answer concise.\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", system),\n",
" (\"human\", \"Retrieved documents: \\n\\n <docs>{documents}</docs> \\n\\n User question: <question>{question}</question>\"),\n",
" ]\n",
")\n",
"\n",
"# LLM\n",
"llm = ChatCohere(model=\"command-r-plus\", temperature=0)\n",
"\n",
"# Chain\n",
"rag_chain = prompt | llm | StrOutputParser()\n",
"\n",
"# Run\n",
"generation = rag_chain.invoke({\"documents\":docs[0].page_content, \"question\": query})\n",
"print(generation)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}