mirror of
https://github.com/NirDiamant/RAG_Techniques.git
synced 2025-04-07 00:48:52 +03:00
Code updates:
1. Added semantic_chunking.py
2. Changed the structure of simple_rag.py
3. Added missing imports to helper_functions.py and raptor.ipynb
raptor.ipynb
@@ -85,6 +85,7 @@
"import pandas as pd\n",
"from typing import List, Dict, Any\n",
"from sklearn.mixture import GaussianMixture\n",
"from langchain.chains.llm import LLMChain\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.vectorstores import FAISS\n",
"from langchain_openai import ChatOpenAI\n",
@@ -92,6 +93,7 @@
"from langchain.retrievers import ContextualCompressionRetriever\n",
"from langchain.retrievers.document_compressors import LLMChainExtractor\n",
"from langchain.schema import AIMessage\n",
"from langchain.docstore.document import Document\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
123 all_rag_techniques_runnable_scripts/semantic_chunking.py Normal file
@@ -0,0 +1,123 @@
import time
import os
import sys
import argparse
from dotenv import load_dotenv
from helper_functions import *
from langchain_experimental.text_splitter import SemanticChunker, BreakpointThresholdType
from langchain_openai.embeddings import OpenAIEmbeddings

# Add the parent directory to the path since we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Load environment variables from a .env file (e.g., OpenAI API key)
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')


# Class to run semantic chunking and measure chunking and retrieval times
class SemanticChunkingRAG:
    """
    A class to handle the Semantic Chunking RAG process for document chunking and query retrieval.
    """

    def __init__(self, path, n_retrieved=2, embeddings=None, breakpoint_type: BreakpointThresholdType = "percentile",
                 breakpoint_amount=90):
        """
        Initializes the SemanticChunkingRAG by encoding the content using a semantic chunker.

        Args:
            path (str): Path to the PDF file to encode.
            n_retrieved (int): Number of chunks to retrieve for each query (default: 2).
            embeddings: Embedding model to use.
            breakpoint_type (str): Type of semantic breakpoint threshold.
            breakpoint_amount (float): Amount for the semantic breakpoint threshold.
        """
        print("\n--- Initializing Semantic Chunking RAG ---")
        # Read PDF to string
        content = read_pdf_to_string(path)

        # Use provided embeddings or initialize OpenAI embeddings
        self.embeddings = embeddings if embeddings else OpenAIEmbeddings()

        # Initialize the semantic chunker; with the default "percentile" strategy, the text
        # is split wherever the embedding distance between adjacent sentences exceeds the
        # given percentile (here the 90th) of all such distances
        self.semantic_chunker = SemanticChunker(
            self.embeddings,
            breakpoint_threshold_type=breakpoint_type,
            breakpoint_threshold_amount=breakpoint_amount
        )

        # Measure time for semantic chunking
        start_time = time.time()
        self.semantic_docs = self.semantic_chunker.create_documents([content])
        self.time_records = {'Chunking': time.time() - start_time}
        print(f"Semantic Chunking Time: {self.time_records['Chunking']:.2f} seconds")

        # Create a vector store and retriever from the semantic chunks
        self.semantic_vectorstore = FAISS.from_documents(self.semantic_docs, self.embeddings)
        self.semantic_retriever = self.semantic_vectorstore.as_retriever(search_kwargs={"k": n_retrieved})

    def run(self, query):
        """
        Retrieves and displays the context for the given query.

        Args:
            query (str): The query to retrieve context for.

        Returns:
            dict: The recorded chunking and retrieval times.
        """
        # Measure time for semantic retrieval
        start_time = time.time()
        semantic_context = retrieve_context_per_question(query, self.semantic_retriever)
        self.time_records['Retrieval'] = time.time() - start_time
        print(f"Semantic Retrieval Time: {self.time_records['Retrieval']:.2f} seconds")

        # Display the retrieved context
        show_context(semantic_context)
        return self.time_records


# Function to parse command line arguments
def parse_args():
    parser = argparse.ArgumentParser(
        description="Process a PDF document with semantic chunking RAG.")
    parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
                        help="Path to the PDF file to encode.")
    parser.add_argument("--n_retrieved", type=int, default=2,
                        help="Number of chunks to retrieve for each query (default: 2).")
    parser.add_argument("--breakpoint_threshold_type", type=str,
                        choices=["percentile", "standard_deviation", "interquartile", "gradient"],
                        default="percentile",
                        help="Type of breakpoint threshold to use for chunking (default: percentile).")
    parser.add_argument("--breakpoint_threshold_amount", type=float, default=90,
                        help="Amount of the breakpoint threshold to use (default: 90).")
    parser.add_argument("--chunk_size", type=int, default=1000,
                        help="Size of each text chunk in simple chunking (default: 1000).")
    parser.add_argument("--chunk_overlap", type=int, default=200,
                        help="Overlap between consecutive chunks in simple chunking (default: 200).")
    parser.add_argument("--query", type=str, default="What is the main cause of climate change?",
                        help="Query to test the retriever (default: 'What is the main cause of climate change?').")
    parser.add_argument("--experiment", action="store_true",
                        help="Run the experiment to compare performance between semantic chunking and simple chunking.")

    return parser.parse_args()


# Main function to process PDF, chunk text, and test retriever
def main(args):
    # Initialize SemanticChunkingRAG
    semantic_rag = SemanticChunkingRAG(
        path=args.path,
        n_retrieved=args.n_retrieved,
        breakpoint_type=args.breakpoint_threshold_type,
        breakpoint_amount=args.breakpoint_threshold_amount
    )

    # Run a query
    semantic_rag.run(args.query)


if __name__ == '__main__':
    # Call the main function with parsed arguments
    main(parse_args())
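For quick reference, a minimal usage sketch of the new class (hypothetical session; it assumes the repo's default PDF exists and OPENAI_API_KEY is available via .env):

from semantic_chunking import SemanticChunkingRAG

# Build the semantic index once, then query it; run() returns the time records
rag = SemanticChunkingRAG(path="../data/Understanding_Climate_Change.pdf", n_retrieved=2)
times = rag.run("What is the main cause of climate change?")
print(times)  # e.g. {'Chunking': 12.3, 'Retrieval': 0.4}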
all_rag_techniques_runnable_scripts/simple_rag.py
@@ -1,21 +1,65 @@
-from helper_functions import *
-from evaluation.evalute_rag import *
+import os
+import sys
+import argparse
+import time
+from dotenv import load_dotenv
+from helper_functions import *
+from evaluation.evalute_rag import *

# Add the parent directory to the path since we work with notebooks
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Load environment variables from a .env file (e.g., OpenAI API key)
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
+
+
+class SimpleRAG:
+    """
+    A class to handle the Simple RAG process for document chunking and query retrieval.
+    """
+
+    def __init__(self, path, chunk_size=1000, chunk_overlap=200, n_retrieved=2):
+        """
+        Initializes the SimpleRAG retriever by encoding the PDF document and creating the retriever.
+
+        Args:
+            path (str): Path to the PDF file to encode.
+            chunk_size (int): Size of each text chunk (default: 1000).
+            chunk_overlap (int): Overlap between consecutive chunks (default: 200).
+            n_retrieved (int): Number of chunks to retrieve for each query (default: 2).
+        """
+        print("\n--- Initializing Simple RAG Retriever ---")
+
+        # Encode the PDF document into a vector store using OpenAI embeddings
+        start_time = time.time()
+        self.vector_store = encode_pdf(path, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        self.time_records = {'Chunking': time.time() - start_time}
+        print(f"Chunking Time: {self.time_records['Chunking']:.2f} seconds")
+
+        # Create a retriever from the vector store
+        self.chunks_query_retriever = self.vector_store.as_retriever(search_kwargs={"k": n_retrieved})
+
+    def run(self, query):
+        """
+        Retrieves and displays the context for the given query.
+
+        Args:
+            query (str): The query to retrieve context for.
+
+        Returns:
+            None. Prints the retrieval time and displays the retrieved context.
+        """
+        # Measure time for retrieval
+        start_time = time.time()
+        context = retrieve_context_per_question(query, self.chunks_query_retriever)
+        self.time_records['Retrieval'] = time.time() - start_time
+        print(f"Retrieval Time: {self.time_records['Retrieval']:.2f} seconds")
+
+        # Display the retrieved context
+        show_context(context)


# Function to validate command line inputs
def validate_args(args):
    if args.chunk_size <= 0:
@@ -29,7 +73,7 @@ def validate_args(args):

# Function to parse command line arguments
def parse_args():
-    parser = argparse.ArgumentParser(description="Encode a PDF document and test a retriever.")
+    parser = argparse.ArgumentParser(description="Encode a PDF document and test a simple RAG.")
    parser.add_argument("--path", type=str, default="../data/Understanding_Climate_Change.pdf",
                        help="Path to the PDF file to encode.")
    parser.add_argument("--chunk_size", type=int, default=1000,
@@ -47,21 +91,22 @@ def parse_args():
    return validate_args(parser.parse_args())


-# Main function to encode PDF, retrieve context, and optionally evaluate retriever
+# Main function to handle argument parsing and call the SimpleRAG class
def main(args):
-    # Encode the PDF document into a vector store using OpenAI embeddings
-    chunks_vector_store = encode_pdf(args.path, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
+    # Initialize the SimpleRAG retriever
+    simple_rag = SimpleRAG(
+        path=args.path,
+        chunk_size=args.chunk_size,
+        chunk_overlap=args.chunk_overlap,
+        n_retrieved=args.n_retrieved
+    )

-    # Create a retriever from the vector store, specifying how many chunks to retrieve
-    chunks_query_retriever = chunks_vector_store.as_retriever(search_kwargs={"k": args.n_retrieved})
-
-    # Test the retriever with the user's query
-    context = retrieve_context_per_question(args.query, chunks_query_retriever)
-    show_context(context)  # Display the context retrieved for the query
+    # Retrieve context based on the query
+    simple_rag.run(args.query)

    # Evaluate the retriever's performance on the query (if requested)
    if args.evaluate:
-        evaluate_rag(chunks_query_retriever)
+        evaluate_rag(simple_rag.chunks_query_retriever)


if __name__ == '__main__':
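A minimal usage sketch of the refactored class, mirroring the semantic variant above (hypothetical session; same assumptions about the data file and API key):

from simple_rag import SimpleRAG

# Fixed-size chunking with overlap, then top-k retrieval over a FAISS index
rag = SimpleRAG(path="../data/Understanding_Climate_Change.pdf", chunk_size=1000, chunk_overlap=200, n_retrieved=2)
rag.run("What is the main cause of climate change?")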
helper_functions.py
@@ -1,21 +1,20 @@
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain import PromptTemplate
-import fitz
from openai import RateLimitError
from typing import List
from rank_bm25 import BM25Okapi

import fitz
import asyncio
import random
import textwrap
import numpy as np


def replace_t_with_space(list_of_documents):
    """
    Replaces all tab characters ('\t') with spaces in the page content of each document.
@@ -46,8 +45,6 @@ def text_wrap(text, width=120):
    return textwrap.fill(text, width=width)


def encode_pdf(path, chunk_size=1000, chunk_overlap=200):
    """
    Encodes a PDF book into a vector store using OpenAI embeddings.
@@ -78,6 +75,7 @@ def encode_pdf(path, chunk_size=1000, chunk_overlap=200):

    return vectorstore
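The encode_pdf body is elided in this hunk; for orientation, a sketch of the load/split/embed pipeline it wraps (an illustrative reconstruction from the imports above, not the verbatim function, which this commit leaves unchanged):

# Illustrative sketch only
loader = PyPDFLoader(path)                       # load the PDF into Documents
documents = loader.load()
splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len)
texts = splitter.split_documents(documents)      # fixed-size chunks with overlap
cleaned_texts = replace_t_with_space(texts)      # normalize tab characters
vectorstore = FAISS.from_documents(cleaned_texts, OpenAIEmbeddings())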
def encode_from_string(content, chunk_size=1000, chunk_overlap=200):
    """
    Encodes a string into a vector store using OpenAI embeddings.
@@ -94,7 +92,7 @@ def encode_from_string(content, chunk_size=1000, chunk_overlap=200):
        ValueError: If the input content is not valid.
        RuntimeError: If there is an error during the encoding process.
    """
    if not isinstance(content, str) or not content.strip():
        raise ValueError("Content must be a non-empty string.")
@@ -148,9 +146,9 @@ def retrieve_context_per_question(question, chunks_query_retriever):
    # context = " ".join(doc.page_content for doc in docs)
    context = [doc.page_content for doc in docs]

    return context


class QuestionAnswerFromContext(BaseModel):
    """
    Model to generate an answer to a query based on a given context.
@@ -160,8 +158,8 @@ class QuestionAnswerFromContext(BaseModel):
    """
    answer_based_on_content: str = Field(description="Generates an answer to a query based on a given context.")

-def create_question_answer_from_context_chain(llm):
+
+def create_question_answer_from_context_chain(llm):
    # Initialize the ChatOpenAI model with specific parameters
    question_answer_from_context_llm = llm
@@ -180,11 +178,11 @@ def create_question_answer_from_context_chain(llm):
    )

    # Create a chain by combining the prompt template and the language model
-    question_answer_from_context_cot_chain = question_answer_from_context_prompt | question_answer_from_context_llm.with_structured_output(QuestionAnswerFromContext)
+    question_answer_from_context_cot_chain = question_answer_from_context_prompt | question_answer_from_context_llm.with_structured_output(
+        QuestionAnswerFromContext)
    return question_answer_from_context_cot_chain


def answer_question_from_context(question, context, question_answer_from_context_chain):
    """
    Answer a question using the given context by invoking a chain of reasoning.
@@ -217,7 +215,7 @@ def show_context(context):
    Prints each context item in the list with a heading indicating its position.
    """
    for i, c in enumerate(context):
-        print(f"Context {i+1}:")
+        print(f"Context {i + 1}:")
        print(c)
        print("\n")
@@ -247,7 +245,6 @@ def read_pdf_to_string(path):
    return content
def bm25_retrieval(bm25: BM25Okapi, cleaned_texts: List[str], query: str, k: int = 5) -> List[str]:
    """
    Perform BM25 retrieval and return the top k cleaned text chunks.
@@ -276,7 +273,6 @@ def bm25_retrieval(bm25: BM25Okapi, cleaned_texts: List[str], query: str, k: int
    return top_k_texts
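A hedged usage sketch for this helper (assumes whitespace tokenization when building the BM25Okapi index, which is what rank_bm25 expects; the exact tokenizer used elsewhere in the repo may differ):

from rank_bm25 import BM25Okapi

cleaned_texts = ["climate change is driven by greenhouse gases", "oceans absorb heat"]
bm25 = BM25Okapi([t.split() for t in cleaned_texts])  # index over tokenized chunks
top_chunks = bm25_retrieval(bm25, cleaned_texts, "main cause of climate change", k=1)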
async def exponential_backoff(attempt):
    """
    Implements exponential backoff with a jitter.
@@ -290,10 +286,11 @@ async def exponential_backoff(attempt):
    # Calculate the wait time with exponential backoff and jitter
    wait_time = (2 ** attempt) + random.uniform(0, 1)
    print(f"Rate limit hit. Retrying in {wait_time:.2f} seconds...")

    # Asynchronously sleep for the calculated wait time
    await asyncio.sleep(wait_time)


async def retry_with_exponential_backoff(coroutine, max_retries=5):
    """
    Retries a coroutine using exponential backoff upon encountering a RateLimitError.
@@ -316,9 +313,9 @@ async def retry_with_exponential_backoff(coroutine, max_retries=5):
    # If the last attempt also fails, raise the exception
    if attempt == max_retries - 1:
        raise e

    # Wait for an exponential backoff period before retrying
    await exponential_backoff(attempt)

    # If max retries are reached without success, raise an exception
    raise Exception("Max retries reached")
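For context, with wait_time = (2 ** attempt) + random.uniform(0, 1), the five default attempts wait roughly 1, 2, 4, 8, and 16 seconds plus up to one second of jitter. A sketch of how a caller might combine these helpers (call_llm is a hypothetical coroutine, not part of this file):

import asyncio

async def call_llm():
    ...  # hypothetical coroutine that may raise RateLimitError

# Retry the call up to 5 times, backing off exponentially between attempts
result = asyncio.run(retry_with_exponential_backoff(call_llm()))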