{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Overview:\n",
"This notebook implements one of several approaches to multi-modal RAG. It extracts text and images from a PDF, summarizes the visual content, and uses a multi-modal Retrieval-Augmented Generation (RAG) pipeline to retrieve the relevant material and answer questions about it.\n",
"\n",
"### Key Components:\n",
" - **PyMuPDF**: For extracting text and images from PDFs.\n",
" - **Gemini 1.5 Flash**: To summarize images and tables.\n",
" - **Cohere Embeddings**: For embedding document splits.\n",
" - **Chroma Vectorstore**: To store and retrieve document embeddings.\n",
" - **LangChain**: To orchestrate the retrieval and generation pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Diagram:\n",
" <img src=\"../images/multi_model_rag_with_captioning.svg\" alt=\"multi-modal-rag-with-captioning\" width=\"300\">\n",
"\n",
"### Motivation:\n",
"Summarize complex, multi-modal documents so that both their text and their visual content can be retrieved efficiently and used to produce concise answers.\n",
"\n",
"### Method Details:\n",
" - Text and images are extracted from the PDF using PyMuPDF.\n",
" - Extracted images and tables are summarized using Gemini.\n",
" - Embeddings are generated with Cohere and stored in Chroma.\n",
" - A similarity-based retriever fetches the relevant sections for a query.\n",
"\n",
"### Benefits:\n",
" - Simplified retrieval from complex, multi-modal documents.\n",
" - Streamlined Q&A process for both text and images.\n",
" - Flexible architecture for expanding to more document types.\n",
"\n",
"### Implementation:\n",
" - Documents are split into overlapping chunks using a text splitter.\n",
" - Page text and image summaries are stored as vectors.\n",
" - Queries are handled by retrieving relevant document segments and generating concise answers.\n",
"\n",
"### Summary:\n",
"The project enables multi-modal document processing and retrieval, providing concise, relevant responses by combining state-of-the-art LLMs and vector-based retrieval systems."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import fitz  # PyMuPDF\n",
"from PIL import Image\n",
"import io\n",
"import os\n",
"from dotenv import load_dotenv\n",
"\n",
"import google.generativeai as genai\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.documents import Document\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.vectorstores import Chroma\n",
"from langchain_cohere import ChatCohere, CohereEmbeddings\n",
"\n",
"load_dotenv()"
]
},
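{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below assume two API keys are available as environment variables (loaded by `load_dotenv()` above): `GOOGLE_API_KEY` for Gemini and a Cohere key (typically `COHERE_API_KEY`) for the embeddings and chat model. The exact Cohere variable name depends on your `langchain_cohere` setup, so treat the check below as a rough sketch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: warn about missing API keys before making any calls.\n",
"# COHERE_API_KEY is an assumption about the variable name your setup uses.\n",
"for key in [\"GOOGLE_API_KEY\", \"COHERE_API_KEY\"]:\n",
"    if not os.getenv(key):\n",
"        print(f\"Warning: {key} is not set\")"
]
},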
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download the \"Attention Is All You Need\" paper"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2024-09-20 19:19:26-- https://arxiv.org/pdf/1706.03762\n",
"Resolving arxiv.org (arxiv.org)... 151.101.195.42, 151.101.3.42, 151.101.67.42, ...\n",
"Connecting to arxiv.org (arxiv.org)|151.101.195.42|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2215244 (2.1M) [application/pdf]\n",
"Saving to: ‘1706.03762’\n",
"\n",
"1706.03762 100%[===================>] 2.11M 13.3MB/s in 0.2s \n",
"\n",
"2024-09-20 19:19:26 (13.3 MB/s) - ‘1706.03762’ saved [2215244/2215244]\n",
"\n"
]
}
],
"source": [
"!wget https://arxiv.org/pdf/1706.03762\n",
"!mv 1706.03762 attention_is_all_you_need.pdf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Extraction"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"text_data = []\n",
"img_data = []"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"with fitz.open('attention_is_all_you_need.pdf') as pdf_file:\n",
"    # Create a directory to store the images\n",
"    if not os.path.exists(\"extracted_images\"):\n",
"        os.makedirs(\"extracted_images\")\n",
"\n",
"    # Loop through every page in the PDF\n",
"    for page_number in range(len(pdf_file)):\n",
"        page = pdf_file[page_number]\n",
"\n",
"        # Get the text on the page\n",
"        text = page.get_text().strip()\n",
"        text_data.append({\"response\": text, \"name\": page_number+1})\n",
"        # Get the list of images on the page\n",
"        images = page.get_images(full=True)\n",
"\n",
"        # Loop through all images found on the page\n",
"        for image_index, img in enumerate(images, start=0):\n",
"            xref = img[0]  # Get the XREF of the image\n",
"            base_image = pdf_file.extract_image(xref)  # Extract the image\n",
"            image_bytes = base_image[\"image\"]  # Get the image bytes\n",
"            image_ext = base_image[\"ext\"]  # Get the image extension\n",
"\n",
"            # Load the image using PIL and save it\n",
"            image = Image.open(io.BytesIO(image_bytes))\n",
"            image.save(f\"extracted_images/image_{page_number+1}_{image_index+1}.{image_ext}\")"
]
},
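{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (not part of the original flow): confirm how many pages of text were captured and how many image files were written."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count extracted pages and saved image files.\n",
"print(f\"Extracted text from {len(text_data)} pages\")\n",
"print(f\"Saved {len(os.listdir('extracted_images'))} images to extracted_images/\")"
]
},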
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))\n",
"model = genai.GenerativeModel(model_name=\"gemini-1.5-flash\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Image Captioning"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"for img in os.listdir(\"extracted_images\"):\n",
"    image = Image.open(f\"extracted_images/{img}\")\n",
"    response = model.generate_content([image, \"You are an assistant tasked with summarizing tables, images and text for retrieval. \\\n",
"        These summaries will be embedded and used to retrieve the raw text or table elements. \\\n",
"        Give a concise summary of the table or text that is well optimized for retrieval. Table or text or image:\"])\n",
"    img_data.append({\"response\": response.text, \"name\": img})"
]
},
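{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each image triggers a separate Gemini call, so on a free-tier key this loop can run into per-minute rate limits. The helper below is a minimal, hypothetical sketch (not used by the pipeline above) that retries a caption request after a pause; adjust the wait time to your quota."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def caption_with_retry(image, prompt, retries=3, wait_seconds=30):\n",
"    # Hypothetical helper: retry generate_content after a pause if the API\n",
"    # rejects the call (e.g. a rate-limit error). Uses the `model` defined above.\n",
"    for attempt in range(retries):\n",
"        try:\n",
"            return model.generate_content([image, prompt]).text\n",
"        except Exception:\n",
"            if attempt == retries - 1:\n",
"                raise\n",
"            time.sleep(wait_seconds)"
]
},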
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectorstore"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Set up the embedding model\n",
"embedding_model = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
"\n",
"# Wrap the extracted page text and image summaries as Documents\n",
"docs_list = [Document(page_content=text['response'], metadata={\"name\": text['name']}) for text in text_data]\n",
"img_list = [Document(page_content=img['response'], metadata={\"name\": img['name']}) for img in img_data]\n",
"\n",
"# Split into overlapping chunks\n",
"text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
"    chunk_size=400, chunk_overlap=50\n",
")\n",
"\n",
"doc_splits = text_splitter.split_documents(docs_list)\n",
"img_splits = text_splitter.split_documents(img_list)"
]
},
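{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `from_tiktoken_encoder` measures `chunk_size` and `chunk_overlap` in tokens rather than characters. A quick optional check of how many chunks were produced:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# How many chunks ended up in each list (page text vs. image captions).\n",
"print(f\"{len(doc_splits)} text chunks, {len(img_splits)} caption chunks\")"
]
},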
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Add to vectorstore\n",
"vectorstore = Chroma.from_documents(\n",
"    documents=doc_splits + img_splits,  # add both the text and image splits\n",
"    collection_name=\"multi_model_rag\",\n",
"    embedding=embedding_model,\n",
")\n",
"\n",
"retriever = vectorstore.as_retriever(\n",
"    search_type=\"similarity\",\n",
"    search_kwargs={'k': 1},  # number of documents to retrieve\n",
")"
]
},
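{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each chunk keeps a `name` entry in its metadata: a page number for text chunks and an image filename (under `extracted_images/`) for caption chunks. The sketch below, using a hypothetical test query, traces a retrieved chunk back to its source:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Retrieve a chunk and report where it came from (page number or image file).\n",
"for doc in retriever.invoke(\"Transformer architecture\"):\n",
"    source = doc.metadata[\"name\"]\n",
"    if isinstance(source, str):\n",
"        print(f\"Image caption from extracted_images/{source}\")\n",
"    else:\n",
"        print(f\"Text from page {source}\")\n",
"    print(doc.page_content[:200])"
]
},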
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Query"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"query = \"What is the BLEU score of the Transformer (base model)?\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"docs = retriever.invoke(query)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Output"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Transformer (base model) achieves a BLEU score of 27.3.\n"
]
}
],
"source": [
"from langchain_core.output_parsers import StrOutputParser\n",
"\n",
"# Prompt\n",
"system = \"\"\"You are an assistant for question-answering tasks. Answer the question based on the retrieved documents.\n",
"Use three-to-five sentences maximum and keep the answer concise.\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\"system\", system),\n",
"        (\"human\", \"Retrieved documents: \\n\\n <docs>{documents}</docs> \\n\\n User question: <question>{question}</question>\"),\n",
"    ]\n",
")\n",
"\n",
"# LLM\n",
"llm = ChatCohere(model=\"command-r-plus\", temperature=0)\n",
"\n",
"# Chain\n",
"rag_chain = prompt | llm | StrOutputParser()\n",
"\n",
"# Run\n",
"generation = rag_chain.invoke({\"documents\": docs[0].page_content, \"question\": query})\n",
"print(generation)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}