RAG_Techniques/all_rag_techniques/document_augmentation.ipynb
2024-08-22 19:21:06 +03:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Document Augmentation through Question Generation for Enhanced Retrieval\n",
"\n",
"## Overview\n",
"\n",
"This implementation demonstrates a text augmentation technique that leverages additional question generation to improve document retrieval within a vector database. By generating and incorporating various questions related to each text fragment, the system enhances the standard retrieval process, thus increasing the likelihood of finding relevant documents that can be utilized as context for generative question answering.\n",
"\n",
"## Motivation\n",
"\n",
"By enriching text fragments with related questions, we aim to significantly enhance the accuracy of identifying the most relevant sections of a document that contain answers to user queries.\n",
"\n",
"## Prerequisites\n",
"\n",
"This approach utilizes OpenAI's language models and embeddings. You'll need an OpenAI API key to use this implementation. Make sure you have the required Python packages installed:\n",
"\n",
"```\n",
"pip install langchain openai faiss-cpu PyPDF2 pydantic\n",
"```\n",
"\n",
"## Key Components\n",
"\n",
"1. **PDF Processing and Text Chunking**: Handling PDF documents and dividing them into manageable text fragments.\n",
"2. **Question Augmentation**: Generating relevant questions at both the document and fragment levels using OpenAI's language models.\n",
"3. **Vector Store Creation**: Calculating embeddings for documents using OpenAI's embedding model and creating a FAISS vector store.\n",
"4. **Retrieval and Answer Generation**: Finding the most relevant document using FAISS and generating answers based on the context provided.\n",
"\n",
"## Method Details\n",
"\n",
"### Document Preprocessing\n",
"\n",
"1. Convert the PDF to a string using PyPDFLoader from LangChain.\n",
"2. Split the text into overlapping text documents (text_document) for building context purpose and then each document to overlapping text fragments (text_fragment) for retrieval and semantic search purpose.\n",
"\n",
"### Document Augmentation\n",
"\n",
"1. Generate questions at the document or text fragment level using OpenAI's language models.\n",
"2. Configure the number of questions to generate using the QUESTIONS_PER_DOCUMENT constant.\n",
"\n",
"### Vector Store Creation\n",
"\n",
"1. Use the OpenAIEmbeddings class to compute document embeddings.\n",
"2. Create a FAISS vector store from these embeddings.\n",
"\n",
"### Retrieval and Generation\n",
"\n",
"1. Retrieve the most relevant document from the FAISS store based on the given query.\n",
"2. Use the retrieved document as context for generating answers with OpenAI's language models.\n",
"\n",
"## Benefits of This Approach\n",
"\n",
"1. **Enhanced Retrieval Process**: Increases the probability of finding the most relevant FAISS document for a given query.\n",
"2. **Flexible Context Adjustment**: Allows for easy adjustment of the context window size for both text documents and fragments.\n",
"3. **High-Quality Language Understanding**: Leverages OpenAI's powerful language models for question generation and answer production.\n",
"\n",
"## Implementation Details\n",
"\n",
"- The `OpenAIEmbeddingsWrapper` class provides a consistent interface for embedding generation.\n",
"- The `generate_questions` function uses OpenAI's chat models to create relevant questions from the text.\n",
"- The `process_documents` function handles the core logic of document splitting, question generation, and vector store creation.\n",
"- The main execution demonstrates loading a PDF, processing its content, and performing a sample query.\n",
"\n",
"## Conclusion\n",
"\n",
"This technique provides a method to improve the quality of information retrieval in vector-based document search systems. By generating additional questions similar to user queries and utilizing OpenAI's advanced language models, it potentially leads to better comprehension and more accurate responses in subsequent tasks, such as question answering.\n",
"\n",
"## Note on API Usage\n",
"\n",
"Be aware that this implementation uses OpenAI's API, which may incur costs based on usage. Make sure to monitor your API usage and set appropriate limits in your OpenAI account settings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import libraries and set constants"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import re\n",
"from langchain.docstore.document import Document\n",
"from langchain.vectorstores import FAISS\n",
"from enum import Enum\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain_openai import ChatOpenAI\n",
"from typing import Any, Dict, List, Tuple\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = os.getenv('OPENAI_API_KEY')\n",
"\n",
"\n",
"sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path sicnce we work with notebooks\n",
"\n",
"from helper_functions import *\n",
"\n",
"\n",
"class QuestionGeneration(Enum):\n",
" \"\"\"\n",
" Enum class to specify the level of question generation for document processing.\n",
"\n",
" Attributes:\n",
" DOCUMENT_LEVEL (int): Represents question generation at the entire document level.\n",
" FRAGMENT_LEVEL (int): Represents question generation at the individual text fragment level.\n",
" \"\"\"\n",
" DOCUMENT_LEVEL = 1\n",
" FRAGMENT_LEVEL = 2\n",
"\n",
"#Depending on the model, for Mitral 7B it can be max 8000, for Llama 3.1 8B 128k\n",
"DOCUMENT_MAX_TOKENS = 4000\n",
"DOCUMENT_OVERLAP_TOKENS = 100\n",
"\n",
"#Embeddings and text similarity calculated on shorter texts\n",
"FRAGMENT_MAX_TOKENS = 128\n",
"FRAGMENT_OVERLAP_TOKENS = 16\n",
"\n",
"#Questions generated on document or fragment level\n",
"QUESTION_GENERATION = QuestionGeneration.DOCUMENT_LEVEL\n",
"#how many questions will be generated for specific document or fragment\n",
"QUESTIONS_PER_DOCUMENT = 40"
]
},
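{
"cell_type": "markdown",
"metadata": {},
"source": [
"With these defaults, the number of entries indexed per text document can be estimated with a little arithmetic. The helper below is a hypothetical illustration (`chunk_count` is not part of the notebook) and assumes the sliding-window split used by `split_document`, where each window advances by `chunk_size - chunk_overlap` words. The question count is nominal, since the model may return more or fewer questions than `QUESTIONS_PER_DOCUMENT`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def chunk_count(n_words, chunk_size, chunk_overlap):\n",
"    # Number of windows produced by a sliding split with step = chunk_size - chunk_overlap\n",
"    if n_words == 0:\n",
"        return 0\n",
"    if n_words <= chunk_size:\n",
"        return 1\n",
"    step = chunk_size - chunk_overlap\n",
"    return math.ceil((n_words - chunk_size) / step) + 1\n",
"\n",
"fragments = chunk_count(4000, 128, 16)  # fragments in one full text document\n",
"print(fragments, 'fragments +', 40, 'questions =', fragments + 40, 'entries')\n",
"```\n",
"\n",
"For a full 4000-word text document this gives 36 fragments, so roughly 76 entries once the 40 document-level questions are added."
]
},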
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define classes and functions used by this pipeline"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class QuestionList(BaseModel):\n",
" question_list: List[str] = Field(..., title=\"List of questions generated for the document or fragment\")\n",
"\n",
"\n",
"class OpenAIEmbeddingsWrapper(OpenAIEmbeddings):\n",
" \"\"\"\n",
" A wrapper class for OpenAI embeddings, providing a similar interface to the original OllamaEmbeddings.\n",
" \"\"\"\n",
" \n",
" def __call__(self, query: str) -> List[float]:\n",
" \"\"\"\n",
" Allows the instance to be used as a callable to generate an embedding for a query.\n",
"\n",
" Args:\n",
" query (str): The query string to be embedded.\n",
"\n",
" Returns:\n",
" List[float]: The embedding for the query as a list of floats.\n",
" \"\"\"\n",
" return self.embed_query(query)\n",
"\n",
"def clean_and_filter_questions(questions: List[str]) -> List[str]:\n",
" \"\"\"\n",
" Cleans and filters a list of questions.\n",
"\n",
" Args:\n",
" questions (List[str]): A list of questions to be cleaned and filtered.\n",
"\n",
" Returns:\n",
" List[str]: A list of cleaned and filtered questions that end with a question mark.\n",
" \"\"\"\n",
" cleaned_questions = []\n",
" for question in questions:\n",
" cleaned_question = re.sub(r'^\\d+\\.\\s*', '', question.strip())\n",
" if cleaned_question.endswith('?'):\n",
" cleaned_questions.append(cleaned_question)\n",
" return cleaned_questions\n",
"\n",
"def generate_questions(text: str) -> List[str]:\n",
" \"\"\"\n",
" Generates a list of questions based on the provided text using OpenAI.\n",
"\n",
" Args:\n",
" text (str): The context data from which questions are generated.\n",
"\n",
" Returns:\n",
" List[str]: A list of unique, filtered questions.\n",
" \"\"\"\n",
" llm = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"context\", \"num_questions\"],\n",
" template=\"Using the context data: {context}\\n\\nGenerate a list of at least {num_questions} \"\n",
" \"possible questions that can be asked about this context. Ensure the questions are \"\n",
" \"directly answerable within the context and do not include any answers or headers. \"\n",
" \"Separate the questions with a new line character.\"\n",
" )\n",
" chain = prompt | llm.with_structured_output(QuestionList)\n",
" input_data = {\"context\": text, \"num_questions\": QUESTIONS_PER_DOCUMENT}\n",
" result = chain.invoke(input_data)\n",
" \n",
" # Extract the list of questions from the QuestionList object\n",
" questions = result.question_list\n",
" \n",
" filtered_questions = clean_and_filter_questions(questions)\n",
" return list(set(filtered_questions))\n",
"\n",
"def generate_answer(content: str, question: str) -> str:\n",
" \"\"\"\n",
" Generates an answer to a given question based on the provided context using OpenAI.\n",
"\n",
" Args:\n",
" content (str): The context data used to generate the answer.\n",
" question (str): The question for which the answer is generated.\n",
"\n",
" Returns:\n",
" str: The precise answer to the question based on the provided context.\n",
" \"\"\"\n",
" llm = ChatOpenAI(model=\"gpt-4o-mini\",temperature=0)\n",
" prompt = PromptTemplate(\n",
" input_variables=[\"context\", \"question\"],\n",
" template=\"Using the context data: {context}\\n\\nProvide a brief and precise answer to the question: {question}\"\n",
" )\n",
" chain = prompt | llm\n",
" input_data = {\"context\": content, \"question\": question}\n",
" return chain.invoke(input_data)\n",
"\n",
"def split_document(document: str, chunk_size: int, chunk_overlap: int) -> List[str]:\n",
" \"\"\"\n",
" Splits a document into smaller chunks of text.\n",
"\n",
" Args:\n",
" document (str): The text of the document to be split.\n",
" chunk_size (int): The size of each chunk in terms of the number of tokens.\n",
" chunk_overlap (int): The number of overlapping tokens between consecutive chunks.\n",
"\n",
" Returns:\n",
" List[str]: A list of text chunks, where each chunk is a string of the document content.\n",
" \"\"\"\n",
" tokens = re.findall(r'\\b\\w+\\b', document)\n",
" chunks = []\n",
" for i in range(0, len(tokens), chunk_size - chunk_overlap):\n",
" chunk_tokens = tokens[i:i + chunk_size]\n",
" chunks.append(chunk_tokens)\n",
" if i + chunk_size >= len(tokens):\n",
" break\n",
" return [\" \".join(chunk) for chunk in chunks]\n",
"\n",
"def print_document(comment: str, document: Any) -> None:\n",
" \"\"\"\n",
" Prints a comment followed by the content of a document.\n",
"\n",
" Args:\n",
" comment (str): The comment or description to print before the document details.\n",
" document (Any): The document whose content is to be printed.\n",
"\n",
" Returns:\n",
" None\n",
" \"\"\"\n",
" print(f'{comment} (type: {document.metadata[\"type\"]}, index: {document.metadata[\"index\"]}): {document.page_content}')"
]
},
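{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cleaning step in `clean_and_filter_questions` can be seen on a small, made-up input. This sketch repeats the same regular expression and filter on hypothetical strings; it does not call the model.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"raw = ['1. What drives sea level rise?', 'Key causes', '  2.How are glaciers affected? ']\n",
"# Remove surrounding whitespace and leading list numbers\n",
"cleaned = [re.sub(r'^\\d+\\.\\s*', '', q.strip()) for q in raw]\n",
"# Keep only entries that look like questions\n",
"cleaned = [q for q in cleaned if q.endswith('?')]\n",
"print(cleaned)  # ['What drives sea level rise?', 'How are glaciers affected?']\n",
"```"
]
},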
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example usage\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAIEmbeddings\n",
"embeddings = OpenAIEmbeddingsWrapper()\n",
"\n",
"# Example document\n",
"example_text = \"This is an example document. It contains information about various topics.\"\n",
"\n",
"# Generate questions\n",
"questions = generate_questions(example_text)\n",
"print(\"Generated Questions:\")\n",
"for q in questions:\n",
"    print(f\"- {q}\")\n",
"\n",
"# Generate an answer\n",
"sample_question = questions[0] if questions else \"What is this document about?\"\n",
"answer = generate_answer(example_text, sample_question)\n",
"print(f\"\\nQuestion: {sample_question}\")\n",
"print(f\"Answer: {answer}\")\n",
"\n",
"# Split document\n",
"chunks = split_document(example_text, chunk_size=10, chunk_overlap=2)\n",
"print(\"\\nDocument Chunks:\")\n",
"for i, chunk in enumerate(chunks):\n",
"    print(f\"Chunk {i + 1}: {chunk}\")\n",
"\n",
"# Example of using OpenAIEmbeddings\n",
"doc_embedding = embeddings.embed_documents([example_text])\n",
"query_embedding = embeddings.embed_query(\"What is the main topic?\")\n",
"print(\"\\nDocument Embedding (first 5 elements):\", doc_embedding[0][:5])\n",
"print(\"Query Embedding (first 5 elements):\", query_embedding[:5])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Main pipeline"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def process_documents(content: str, embedding_model: OpenAIEmbeddings):\n",
" \"\"\"\n",
" Process the document content, split it into fragments, generate questions,\n",
" create a FAISS vector store, and return a retriever.\n",
"\n",
" Args:\n",
" content (str): The content of the document to process.\n",
" embedding_model (OpenAIEmbeddings): The embedding model to use for vectorization.\n",
"\n",
" Returns:\n",
" VectorStoreRetriever: A retriever for the most relevant FAISS document.\n",
" \"\"\"\n",
" # Split the whole text content into text documents\n",
" text_documents = split_document(content, DOCUMENT_MAX_TOKENS, DOCUMENT_OVERLAP_TOKENS)\n",
" print(f'Text content split into: {len(text_documents)} documents')\n",
"\n",
" documents = []\n",
" counter = 0\n",
" for i, text_document in enumerate(text_documents):\n",
" text_fragments = split_document(text_document, FRAGMENT_MAX_TOKENS, FRAGMENT_OVERLAP_TOKENS)\n",
" print(f'Text document {i} - split into: {len(text_fragments)} fragments')\n",
" \n",
" for j, text_fragment in enumerate(text_fragments):\n",
" documents.append(Document(\n",
" page_content=text_fragment,\n",
" metadata={\"type\": \"ORIGINAL\", \"index\": counter, \"text\": text_document}\n",
" ))\n",
" counter += 1\n",
" \n",
" if QUESTION_GENERATION == QuestionGeneration.FRAGMENT_LEVEL:\n",
" questions = generate_questions(text_fragment)\n",
" documents.extend([\n",
" Document(page_content=question, metadata={\"type\": \"AUGMENTED\", \"index\": counter + idx, \"text\": text_document})\n",
" for idx, question in enumerate(questions)\n",
" ])\n",
" counter += len(questions)\n",
" print(f'Text document {i} Text fragment {j} - generated: {len(questions)} questions')\n",
" \n",
" if QUESTION_GENERATION == QuestionGeneration.DOCUMENT_LEVEL:\n",
" questions = generate_questions(text_document)\n",
" documents.extend([\n",
" Document(page_content=question, metadata={\"type\": \"AUGMENTED\", \"index\": counter + idx, \"text\": text_document})\n",
" for idx, question in enumerate(questions)\n",
" ])\n",
" counter += len(questions)\n",
" print(f'Text document {i} - generated: {len(questions)} questions')\n",
"\n",
" for document in documents:\n",
" print_document(\"Dataset\", document)\n",
"\n",
" print(f'Creating store, calculating embeddings for {len(documents)} FAISS documents')\n",
" vectorstore = FAISS.from_documents(documents, embedding_model)\n",
"\n",
" print(\"Creating retriever returning the most relevant FAISS document\")\n",
" return vectorstore.as_retriever(search_kwargs={\"k\": 1})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Load sample PDF document to string variable\n",
"path = \"../data/Understanding_Climate_Change.pdf\"\n",
"content = read_pdf_to_string(path)\n",
"\n",
"# Instantiate OpenAI Embeddings class that will be used by FAISS\n",
"embedding_model = OpenAIEmbeddings()\n",
"\n",
"# Process documents and create retriever\n",
"document_query_retriever = process_documents(content, embedding_model)\n",
"\n",
"# Example usage of the retriever\n",
"query = \"What is climate change?\"\n",
"retrieved_docs = document_query_retriever.get_relevant_documents(query)\n",
"print(f\"\\nQuery: {query}\")\n",
"print(f\"Retrieved document: {retrieved_docs[0].page_content}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Find the most relevant FAISS document in the store. In most cases, this will be an augmented question rather than the original text document."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"How do freshwater ecosystems change due to alterations in climatic factors?\"\n",
"print (f'Question:{os.linesep}{query}{os.linesep}')\n",
"retrieved_documents = document_query_retriever.invoke(query)\n",
"\n",
"for doc in retrieved_documents:\n",
" print_document(\"Relevant fragment retrieved\", doc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Find the parent text document and use it as context for the generative model to generate an answer to the question."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"context = doc.metadata['text']\n",
"print (f'{os.linesep}Context:{os.linesep}{context}')\n",
"answer = generate_answer(context, query)\n",
"print(f'{os.linesep}Answer:{os.linesep}{answer}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}