Code updates:

1. added semantic_chunking.py
2. changed the structure of simple_rag.py
3. added missing imports to helper_functions.py and raptor.ipynb
This commit is contained in:
eliavs
2024-09-05 10:37:09 +03:00
parent 9456c27419
commit bc8e374510
4 changed files with 201 additions and 34 deletions

View File

@@ -85,6 +85,7 @@
"import pandas as pd\n",
"from typing import List, Dict, Any\n",
"from sklearn.mixture import GaussianMixture\n",
"from langchain.chains.llm import LLMChain\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.vectorstores import FAISS\n",
"from langchain_openai import ChatOpenAI\n",
@@ -92,6 +93,7 @@
"from langchain.retrievers import ContextualCompressionRetriever\n",
"from langchain.retrievers.document_compressors import LLMChainExtractor\n",
"from langchain.schema import AIMessage\n",
"from langchain.docstore.document import Document\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",

View File

@@ -0,0 +1,123 @@
import time
import os
import sys
import argparse
from dotenv import load_dotenv
from helper_functions import *
from langchain_experimental.text_splitter import SemanticChunker, BreakpointThresholdType
from langchain_openai.embeddings import OpenAIEmbeddings
# Add the parent directory to the path since we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# Load environment variables from a .env file (e.g., OpenAI API key)
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
# Function to run semantic chunking and return chunking and retrieval times
class SemanticChunkingRAG:
"""
A class to handle the Semantic Chunking RAG process for document chunking and query retrieval.
"""
def __init__(self, path, n_retrieved=2, embeddings=None, breakpoint_type: BreakpointThresholdType = "percentile",
breakpoint_amount=90):
"""
Initializes the SemanticChunkingRAG by encoding the content using a semantic chunker.
Args:
path (str): Path to the PDF file to encode.
n_retrieved (int): Number of chunks to retrieve for each query (default: 2).
embeddings: Embedding model to use.
breakpoint_type (str): Type of semantic breakpoint threshold.
breakpoint_amount (float): Amount for the semantic breakpoint threshold.
"""
print("\n--- Initializing Semantic Chunking RAG ---")
# Read PDF to string
content = read_pdf_to_string(path)
# Use provided embeddings or initialize OpenAI embeddings
self.embeddings = embeddings if embeddings else OpenAIEmbeddings()
# Initialize the semantic chunker
self.semantic_chunker = SemanticChunker(
self.embeddings,
breakpoint_threshold_type=breakpoint_type,
breakpoint_threshold_amount=breakpoint_amount
)
# Measure time for semantic chunking
start_time = time.time()
self.semantic_docs = self.semantic_chunker.create_documents([content])
self.time_records = {'Chunking': time.time() - start_time}
print(f"Semantic Chunking Time: {self.time_records['Chunking']:.2f} seconds")
# Create a vector store and retriever from the semantic chunks
self.semantic_vectorstore = FAISS.from_documents(self.semantic_docs, self.embeddings)
self.semantic_retriever = self.semantic_vectorstore.as_retriever(search_kwargs={"k": n_retrieved})
def run(self, query):
"""
Retrieves and displays the context for the given query.
Args:
query (str): The query to retrieve context for.
Returns:
tuple: The retrieval time.
"""
# Measure time for semantic retrieval
start_time = time.time()
semantic_context = retrieve_context_per_question(query, self.semantic_retriever)
self.time_records['Retrieval'] = time.time() - start_time
print(f"Semantic Retrieval Time: {self.time_records['Retrieval']:.2f} seconds")
# Display the retrieved context
show_context(semantic_context)
return self.time_records
# Function to parse command line arguments
def parse_args():
parser = argparse.ArgumentParser(
description="Process a PDF document with semantic chunking RAG.")
parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
help="Path to the PDF file to encode.")
parser.add_argument("--n_retrieved", type=int, default=2,
help="Number of chunks to retrieve for each query (default: 2).")
parser.add_argument("--breakpoint_threshold_type", type=str,
choices=["percentile", "standard_deviation", "interquartile", "gradient"],
default="percentile",
help="Type of breakpoint threshold to use for chunking (default: percentile).")
parser.add_argument("--breakpoint_threshold_amount", type=float, default=90,
help="Amount of the breakpoint threshold to use (default: 90).")
parser.add_argument("--chunk_size", type=int, default=1000,
help="Size of each text chunk in simple chunking (default: 1000).")
parser.add_argument("--chunk_overlap", type=int, default=200,
help="Overlap between consecutive chunks in simple chunking (default: 200).")
parser.add_argument("--query", type=str, default="What is the main cause of climate change?",
help="Query to test the retriever (default: 'What is the main cause of climate change?').")
parser.add_argument("--experiment", action="store_true",
help="Run the experiment to compare performance between semantic chunking and simple chunking.")
return parser.parse_args()
# Main function to process PDF, chunk text, and test retriever
def main(args):
# Initialize SemanticChunkingRAG
semantic_rag = SemanticChunkingRAG(
path=args.path,
n_retrieved=args.n_retrieved,
breakpoint_type=args.breakpoint_threshold_type,
breakpoint_amount=args.breakpoint_threshold_amount
)
# Run a query
semantic_rag.run(args.query)
if __name__ == '__main__':
# Call the main function with parsed arguments
main(parse_args())

View File

@@ -1,21 +1,65 @@
from helper_functions import *
from evaluation.evalute_rag import *
import os
import sys
import argparse
import time
from dotenv import load_dotenv
from helper_functions import *
from evaluation.evalute_rag import *
# Add the parent directory to the path since we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# Load environment variables from a .env file (e.g., OpenAI API key)
load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
class SimpleRAG:
"""
A class to handle the Simple RAG process for document chunking and query retrieval.
"""
def __init__(self, path, chunk_size=1000, chunk_overlap=200, n_retrieved=2):
"""
Initializes the SimpleRAGRetriever by encoding the PDF document and creating the retriever.
Args:
path (str): Path to the PDF file to encode.
chunk_size (int): Size of each text chunk (default: 1000).
chunk_overlap (int): Overlap between consecutive chunks (default: 200).
n_retrieved (int): Number of chunks to retrieve for each query (default: 2).
"""
print("\n--- Initializing Simple RAG Retriever ---")
# Encode the PDF document into a vector store using OpenAI embeddings
start_time = time.time()
self.vector_store = encode_pdf(path, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
self.time_records = {'Chunking': time.time() - start_time}
print(f"Chunking Time: {self.time_records['Chunking']:.2f} seconds")
# Create a retriever from the vector store
self.chunks_query_retriever = self.vector_store.as_retriever(search_kwargs={"k": n_retrieved})
def run(self, query):
"""
Retrieves and displays the context for the given query.
Args:
query (str): The query to retrieve context for.
Returns:
tuple: The retrieval time.
"""
# Measure time for retrieval
start_time = time.time()
context = retrieve_context_per_question(query, self.chunks_query_retriever)
self.time_records['Retrieval'] = time.time() - start_time
print(f"Retrieval Time: {self.time_records['Retrieval']:.2f} seconds")
# Display the retrieved context
show_context(context)
# Function to validate command line inputs
def validate_args(args):
if args.chunk_size <= 0:
@@ -29,7 +73,7 @@ def validate_args(args):
# Function to parse command line arguments
def parse_args():
parser = argparse.ArgumentParser(description="Encode a PDF document and test a retriever.")
parser = argparse.ArgumentParser(description="Encode a PDF document and test a simple RAG.")
parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
help="Path to the PDF file to encode.")
parser.add_argument("--chunk_size", type=int, default=1000,
@@ -47,21 +91,22 @@ def parse_args():
return validate_args(parser.parse_args())
# Main function to encode PDF, retrieve context, and optionally evaluate retriever
# Main function to handle argument parsing and call the SimpleRAGRetriever class
def main(args):
# Encode the PDF document into a vector store using OpenAI embeddings
chunks_vector_store = encode_pdf(args.path, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
# Initialize the SimpleRAGRetriever
simple_rag = SimpleRAG(
path=args.path,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
n_retrieved=args.n_retrieved
)
# Create a retriever from the vector store, specifying how many chunks to retrieve
chunks_query_retriever = chunks_vector_store.as_retriever(search_kwargs={"k": args.n_retrieved})
# Test the retriever with the user's query
context = retrieve_context_per_question(args.query, chunks_query_retriever)
show_context(context) # Display the context retrieved for the query
# Retrieve context based on the query
simple_rag.run(args.query)
# Evaluate the retriever's performance on the query (if requested)
if args.evaluate:
evaluate_rag(chunks_query_retriever)
evaluate_rag(simple_rag.chunks_query_retriever)
if __name__ == '__main__':

View File

@@ -1,21 +1,20 @@
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain import PromptTemplate
import fitz
from openai import RateLimitError
from typing import List
from rank_bm25 import BM25Okapi
import fitz
import asyncio
import random
import textwrap
import numpy as np
def replace_t_with_space(list_of_documents):
"""
Replaces all tab characters ('\t') with spaces in the page content of each document.
@@ -46,8 +45,6 @@ def text_wrap(text, width=120):
return textwrap.fill(text, width=width)
def encode_pdf(path, chunk_size=1000, chunk_overlap=200):
"""
Encodes a PDF book into a vector store using OpenAI embeddings.
@@ -78,6 +75,7 @@ def encode_pdf(path, chunk_size=1000, chunk_overlap=200):
return vectorstore
def encode_from_string(content, chunk_size=1000, chunk_overlap=200):
"""
Encodes a string into a vector store using OpenAI embeddings.
@@ -94,7 +92,7 @@ def encode_from_string(content, chunk_size=1000, chunk_overlap=200):
ValueError: If the input content is not valid.
RuntimeError: If there is an error during the encoding process.
"""
if not isinstance(content, str) or not content.strip():
raise ValueError("Content must be a non-empty string.")
@@ -148,9 +146,9 @@ def retrieve_context_per_question(question, chunks_query_retriever):
# context = " ".join(doc.page_content for doc in docs)
context = [doc.page_content for doc in docs]
return context
class QuestionAnswerFromContext(BaseModel):
"""
Model to generate an answer to a query based on a given context.
@@ -160,8 +158,8 @@ class QuestionAnswerFromContext(BaseModel):
"""
answer_based_on_content: str = Field(description="Generates an answer to a query based on a given context.")
def create_question_answer_from_context_chain(llm):
def create_question_answer_from_context_chain(llm):
# Initialize the ChatOpenAI model with specific parameters
question_answer_from_context_llm = llm
@@ -180,11 +178,11 @@ def create_question_answer_from_context_chain(llm):
)
# Create a chain by combining the prompt template and the language model
question_answer_from_context_cot_chain = question_answer_from_context_prompt | question_answer_from_context_llm.with_structured_output(QuestionAnswerFromContext)
question_answer_from_context_cot_chain = question_answer_from_context_prompt | question_answer_from_context_llm.with_structured_output(
QuestionAnswerFromContext)
return question_answer_from_context_cot_chain
def answer_question_from_context(question, context, question_answer_from_context_chain):
"""
Answer a question using the given context by invoking a chain of reasoning.
@@ -217,7 +215,7 @@ def show_context(context):
Prints each context item in the list with a heading indicating its position.
"""
for i, c in enumerate(context):
print(f"Context {i+1}:")
print(f"Context {i + 1}:")
print(c)
print("\n")
@@ -247,7 +245,6 @@ def read_pdf_to_string(path):
return content
def bm25_retrieval(bm25: BM25Okapi, cleaned_texts: List[str], query: str, k: int = 5) -> List[str]:
"""
Perform BM25 retrieval and return the top k cleaned text chunks.
@@ -276,7 +273,6 @@ def bm25_retrieval(bm25: BM25Okapi, cleaned_texts: List[str], query: str, k: int
return top_k_texts
async def exponential_backoff(attempt):
"""
Implements exponential backoff with a jitter.
@@ -290,10 +286,11 @@ async def exponential_backoff(attempt):
# Calculate the wait time with exponential backoff and jitter
wait_time = (2 ** attempt) + random.uniform(0, 1)
print(f"Rate limit hit. Retrying in {wait_time:.2f} seconds...")
# Asynchronously sleep for the calculated wait time
await asyncio.sleep(wait_time)
async def retry_with_exponential_backoff(coroutine, max_retries=5):
"""
Retries a coroutine using exponential backoff upon encountering a RateLimitError.
@@ -316,9 +313,9 @@ async def retry_with_exponential_backoff(coroutine, max_retries=5):
# If the last attempt also fails, raise the exception
if attempt == max_retries - 1:
raise e
# Wait for an exponential backoff period before retrying
await exponential_backoff(attempt)
# If max retries are reached without success, raise an exception
raise Exception("Max retries reached")