# mirror of https://github.com/NirDiamant/RAG_Techniques.git (synced 2025-04-07 00:48:52 +03:00)
import networkx as nx
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.callbacks import get_openai_callback

from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import sys
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # OpenAIEmbeddings is used by DocumentProcessor/GraphRAG below
from pydantic import BaseModel, Field  # used by the Concepts and AnswerCheck schemas below
from typing import List, Tuple, Dict
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import nltk
import spacy
import heapq
import argparse

from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np

from spacy.cli import download
from spacy.lang.en import English

sys.path.append(os.path.abspath(
    os.path.join(os.getcwd(), '..')))  # Add the parent directory to the path since we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *

# Load environment variables from a .env file
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
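
# Setup note: the .env file loaded above is assumed to define the API key, e.g. (illustrative value only):
#   OPENAI_API_KEY=sk-...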

nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)


# Define the DocumentProcessor class
class DocumentProcessor:
    def __init__(self):
        """
        Initializes the DocumentProcessor with a text splitter and OpenAI embeddings.

        Attributes:
        - text_splitter: An instance of RecursiveCharacterTextSplitter with specified chunk size and overlap.
        - embeddings: An instance of OpenAIEmbeddings used for embedding documents.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.embeddings = OpenAIEmbeddings()

    def process_documents(self, documents):
        """
        Processes a list of documents by splitting them into smaller chunks and creating a vector store.

        Args:
        - documents (list of str): A list of documents to be processed.

        Returns:
        - tuple: A tuple containing:
          - splits (list of str): The list of split document chunks.
          - vector_store (FAISS): A FAISS vector store created from the split document chunks and their embeddings.
        """
        splits = self.text_splitter.split_documents(documents)
        vector_store = FAISS.from_documents(splits, self.embeddings)
        return splits, vector_store

    def create_embeddings_batch(self, texts, batch_size=32):
        """
        Creates embeddings for a list of texts in batches.

        Args:
        - texts (list of str): A list of texts to be embedded.
        - batch_size (int, optional): The number of texts to process in each batch. Default is 32.

        Returns:
        - numpy.ndarray: An array of embeddings for the input texts.
        """
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.embeddings.embed_documents(batch)
            embeddings.extend(batch_embeddings)
        return np.array(embeddings)

    def compute_similarity_matrix(self, embeddings):
        """
        Computes a cosine similarity matrix for a given set of embeddings.

        Args:
        - embeddings (numpy.ndarray): An array of embeddings.

        Returns:
        - numpy.ndarray: A cosine similarity matrix for the input embeddings.
        """
        return cosine_similarity(embeddings)
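
# Usage sketch (hypothetical; assumes OPENAI_API_KEY is set and `docs` is a list of LangChain
# Document objects, e.g. loaded with PyPDFLoader as in the __main__ block below):
#
#   processor = DocumentProcessor()
#   splits, vector_store = processor.process_documents(docs)
#   embeddings = processor.create_embeddings_batch([s.page_content for s in splits], batch_size=16)
#   sim_matrix = processor.compute_similarity_matrix(embeddings)  # (n_chunks, n_chunks) cosine similarities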


# Define the knowledge graph class

# Define the Concepts class
class Concepts(BaseModel):
    concepts_list: List[str] = Field(description="List of concepts")
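
# Structured-output sketch (illustrative): an LLM wrapped with .with_structured_output(Concepts)
# returns a validated Concepts instance rather than free text, e.g.:
#
#   Concepts(concepts_list=["greenhouse gases", "global temperature rise", "fossil fuels"])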


# Define the KnowledgeGraph class
class KnowledgeGraph:
    def __init__(self):
        """
        Initializes the KnowledgeGraph with a graph, lemmatizer, and NLP model.

        Attributes:
        - graph: An instance of a networkx Graph.
        - lemmatizer: An instance of WordNetLemmatizer.
        - concept_cache: A dictionary to cache extracted concepts.
        - nlp: An instance of a spaCy NLP model.
        - edges_threshold: A float value that sets the threshold for adding edges based on similarity.
        """
        self.graph = nx.Graph()
        self.lemmatizer = WordNetLemmatizer()
        self.concept_cache = {}
        self.nlp = self._load_spacy_model()
        self.edges_threshold = 0.8

    def build_graph(self, splits, llm, embedding_model):
        """
        Builds the knowledge graph by adding nodes, creating embeddings, extracting concepts, and adding edges.

        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.
        - embedding_model: An instance of an embedding model.

        Returns:
        - None
        """
        self._add_nodes(splits)
        embeddings = self._create_embeddings(splits, embedding_model)
        self._extract_concepts(splits, llm)
        self._add_edges(embeddings)

    def _add_nodes(self, splits):
        """
        Adds nodes to the graph from the document splits.

        Args:
        - splits (list): A list of document splits.

        Returns:
        - None
        """
        for i, split in enumerate(splits):
            self.graph.add_node(i, content=split.page_content)

    def _create_embeddings(self, splits, embedding_model):
        """
        Creates embeddings for the document splits using the embedding model.

        Args:
        - splits (list): A list of document splits.
        - embedding_model: An instance of an embedding model.

        Returns:
        - numpy.ndarray: An array of embeddings for the document splits.
        """
        texts = [split.page_content for split in splits]
        return embedding_model.embed_documents(texts)

    def _compute_similarities(self, embeddings):
        """
        Computes the cosine similarity matrix for the embeddings.

        Args:
        - embeddings (numpy.ndarray): An array of embeddings.

        Returns:
        - numpy.ndarray: A cosine similarity matrix for the embeddings.
        """
        return cosine_similarity(embeddings)

    def _load_spacy_model(self):
        """
        Loads the spaCy NLP model, downloading it if necessary.

        Args:
        - None

        Returns:
        - spacy.Language: An instance of a spaCy NLP model.
        """
        try:
            return spacy.load("en_core_web_sm")
        except OSError:
            print("Downloading spaCy model...")
            download("en_core_web_sm")
            return spacy.load("en_core_web_sm")

    def _extract_concepts_and_entities(self, content, llm):
        """
        Extracts concepts and named entities from the content using spaCy and a large language model.

        Args:
        - content (str): The content from which to extract concepts and entities.
        - llm: An instance of a large language model.

        Returns:
        - list: A list of extracted concepts and entities.
        """
        if content in self.concept_cache:
            return self.concept_cache[content]

        # Extract named entities using spaCy
        doc = self.nlp(content)
        named_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]

        # Extract general concepts using LLM
        concept_extraction_prompt = PromptTemplate(
            input_variables=["text"],
            template="Extract key concepts (excluding named entities) from the following text:\n\n{text}\n\nKey concepts:"
        )
        concept_chain = concept_extraction_prompt | llm.with_structured_output(Concepts)
        general_concepts = concept_chain.invoke({"text": content}).concepts_list

        # Combine named entities and general concepts
        all_concepts = list(set(named_entities + general_concepts))

        self.concept_cache[content] = all_concepts
        return all_concepts
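
    # Illustrative example (not actual output): for a chunk discussing IPCC findings on CO2 emissions,
    # this method might return something like:
    #   ['IPCC', 'carbon dioxide emissions', 'fossil fuel combustion', 'global warming']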

    def _extract_concepts(self, splits, llm):
        """
        Extracts concepts for all document splits using multi-threading.

        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.

        Returns:
        - None
        """
        with ThreadPoolExecutor() as executor:
            future_to_node = {executor.submit(self._extract_concepts_and_entities, split.page_content, llm): i
                              for i, split in enumerate(splits)}

            for future in tqdm(as_completed(future_to_node), total=len(splits),
                               desc="Extracting concepts and entities"):
                node = future_to_node[future]
                concepts = future.result()
                self.graph.nodes[node]['concepts'] = concepts

    def _add_edges(self, embeddings):
        """
        Adds edges to the graph based on the similarity of embeddings and shared concepts.

        Args:
        - embeddings (numpy.ndarray): An array of embeddings for the document splits.

        Returns:
        - None
        """
        similarity_matrix = self._compute_similarities(embeddings)
        num_nodes = len(self.graph.nodes)

        for node1 in tqdm(range(num_nodes), desc="Adding edges"):
            for node2 in range(node1 + 1, num_nodes):
                similarity_score = similarity_matrix[node1][node2]
                if similarity_score > self.edges_threshold:
                    shared_concepts = set(self.graph.nodes[node1]['concepts']) & set(
                        self.graph.nodes[node2]['concepts'])
                    edge_weight = self._calculate_edge_weight(node1, node2, similarity_score, shared_concepts)
                    self.graph.add_edge(node1, node2, weight=edge_weight,
                                        similarity=similarity_score,
                                        shared_concepts=list(shared_concepts))

    def _calculate_edge_weight(self, node1, node2, similarity_score, shared_concepts, alpha=0.7, beta=0.3):
        """
        Calculates the weight of an edge based on similarity score and shared concepts.

        Args:
        - node1 (int): The first node.
        - node2 (int): The second node.
        - similarity_score (float): The similarity score between the nodes.
        - shared_concepts (set): The set of shared concepts between the nodes.
        - alpha (float, optional): The weight of the similarity score. Default is 0.7.
        - beta (float, optional): The weight of the shared concepts. Default is 0.3.

        Returns:
        - float: The calculated weight of the edge.
        """
        max_possible_shared = min(len(self.graph.nodes[node1]['concepts']), len(self.graph.nodes[node2]['concepts']))
        normalized_shared_concepts = len(shared_concepts) / max_possible_shared if max_possible_shared > 0 else 0
        return alpha * similarity_score + beta * normalized_shared_concepts
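
    # Worked example (illustrative): with the default alpha=0.7, beta=0.3, a pair of chunks with
    # cosine similarity 0.9 that share 2 concepts out of at most 4 possible gets
    #   weight = 0.7 * 0.9 + 0.3 * (2 / 4) = 0.63 + 0.15 = 0.78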

    def _lemmatize_concept(self, concept):
        """
        Lemmatizes a given concept.

        Args:
        - concept (str): The concept to be lemmatized.

        Returns:
        - str: The lemmatized concept.
        """
        return ' '.join([self.lemmatizer.lemmatize(word) for word in concept.lower().split()])
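
# Standalone usage sketch (hypothetical; `splits` and `llm` as produced inside GraphRAG below):
#
#   kg = KnowledgeGraph()
#   kg.build_graph(splits, llm, OpenAIEmbeddings())
#   print(kg.graph.number_of_nodes(), "nodes,", kg.graph.number_of_edges(), "edges")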


# Define the Query Engine class

# Define the AnswerCheck class
class AnswerCheck(BaseModel):
    is_complete: bool = Field(description="Whether the current context provides a complete answer to the query")
    answer: str = Field(description="The current answer based on the context, if any")


# Define the QueryEngine class
class QueryEngine:
    def __init__(self, vector_store, knowledge_graph, llm):
        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph
        self.llm = llm
        self.max_context_length = 4000
        self.answer_check_chain = self._create_answer_check_chain()

    def _create_answer_check_chain(self):
        """
        Creates a chain to check if the context provides a complete answer to the query.

        Args:
        - None

        Returns:
        - Chain: A chain to check if the context provides a complete answer.
        """
        answer_check_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the query: '{query}'\n\nAnd the current context:\n{context}\n\nDoes this context provide a complete answer to the query? If yes, provide the answer. If no, state that the answer is incomplete.\n\nIs complete answer (Yes/No):\nAnswer (if complete):"
        )
        return answer_check_prompt | self.llm.with_structured_output(AnswerCheck)

    def _check_answer(self, query: str, context: str) -> Tuple[bool, str]:
        """
        Checks if the current context provides a complete answer to the query.

        Args:
        - query (str): The query to be answered.
        - context (str): The current context.

        Returns:
        - tuple: A tuple containing:
          - is_complete (bool): Whether the context provides a complete answer.
          - answer (str): The answer based on the context, if complete.
        """
        response = self.answer_check_chain.invoke({"query": query, "context": context})
        return response.is_complete, response.answer

    def _expand_context(self, query: str, relevant_docs) -> Tuple[str, List[int], Dict[int, str], str]:
        """
        Expands the context by traversing the knowledge graph using a Dijkstra-like approach.

        This method implements a modified version of Dijkstra's algorithm to explore the knowledge graph,
        prioritizing the most relevant and strongly connected information. The algorithm works as follows:

        1. Initialize:
           - Start with nodes corresponding to the most relevant documents.
           - Use a priority queue to manage the traversal order, where priority is based on connection strength.
           - Maintain a dictionary of best known "distances" (inverse of connection strengths) to each node.

        2. Traverse:
           - Always explore the node with the highest priority (strongest connection) next.
           - For each node, check if we've found a complete answer.
           - Explore the node's neighbors, updating their priorities if a stronger connection is found.

        3. Concept Handling:
           - Track visited concepts to guide the exploration towards new, relevant information.
           - Expand to neighbors only if they introduce new concepts.

        4. Termination:
           - Stop if a complete answer is found.
           - Continue until the priority queue is empty (all reachable nodes explored).

        This approach ensures that:
        - We prioritize the most relevant and strongly connected information.
        - We explore new concepts systematically.
        - We find the most relevant answer by following the strongest connections in the knowledge graph.

        Args:
        - query (str): The query to be answered.
        - relevant_docs (List[Document]): A list of relevant documents to start the traversal.

        Returns:
        - tuple: A tuple containing:
          - expanded_context (str): The accumulated context from traversed nodes.
          - traversal_path (List[int]): The sequence of node indices visited.
          - filtered_content (Dict[int, str]): A mapping of node indices to their content.
          - final_answer (str): The final answer found, if any.
        """
        # Initialize variables
        expanded_context = ""
        traversal_path = []
        visited_concepts = set()
        filtered_content = {}
        final_answer = ""

        priority_queue = []
        distances = {}  # Stores the best known "distance" (inverse of connection strength) to each node

        print("\nTraversing the knowledge graph:")

        # Initialize priority queue with closest nodes from relevant docs
        for doc in relevant_docs:
            # Find the most similar node in the knowledge graph for each relevant document
            closest_nodes = self.vector_store.similarity_search_with_score(doc.page_content, k=1)
            closest_node_content, similarity_score = closest_nodes[0]

            # Get the corresponding node in our knowledge graph
            closest_node = next(n for n in self.knowledge_graph.graph.nodes if
                                self.knowledge_graph.graph.nodes[n]['content'] == closest_node_content.page_content)

            # Initialize priority (inverse of similarity score for min-heap behavior)
            priority = 1 / similarity_score
            heapq.heappush(priority_queue, (priority, closest_node))
            distances[closest_node] = priority

        step = 0
        while priority_queue:
            # Get the node with the highest priority (lowest distance value)
            current_priority, current_node = heapq.heappop(priority_queue)

            # Skip if we've already found a better path to this node
            if current_priority > distances.get(current_node, float('inf')):
                continue

            if current_node not in traversal_path:
                step += 1
                traversal_path.append(current_node)
                node_content = self.knowledge_graph.graph.nodes[current_node]['content']
                node_concepts = self.knowledge_graph.graph.nodes[current_node]['concepts']

                # Add node content to our accumulated context
                filtered_content[current_node] = node_content
                expanded_context += "\n" + node_content if expanded_context else node_content

                # Log the current step for debugging and visualization
                print(f"\nStep {step} - Node {current_node}:")
                print(f"Content: {node_content[:100]}...")
                print(f"Concepts: {', '.join(node_concepts)}")
                print("-" * 50)

                # Check if we have a complete answer with the current context
                is_complete, answer = self._check_answer(query, expanded_context)
                if is_complete:
                    final_answer = answer
                    break

                # Process the concepts of the current node
                node_concepts_set = set(self.knowledge_graph._lemmatize_concept(c) for c in node_concepts)
                if not node_concepts_set.issubset(visited_concepts):
                    visited_concepts.update(node_concepts_set)

                    # Explore neighbors
                    for neighbor in self.knowledge_graph.graph.neighbors(current_node):
                        edge_data = self.knowledge_graph.graph[current_node][neighbor]
                        edge_weight = edge_data['weight']

                        # Calculate new distance (priority) to the neighbor
                        # Note: We use 1 / edge_weight because higher weights mean stronger connections
                        distance = current_priority + (1 / edge_weight)

                        # If we've found a stronger connection to the neighbor, update its distance
                        if distance < distances.get(neighbor, float('inf')):
                            distances[neighbor] = distance
                            heapq.heappush(priority_queue, (distance, neighbor))

                            # Process the neighbor node if it's not already in our traversal path
                            if neighbor not in traversal_path:
                                step += 1
                                traversal_path.append(neighbor)
                                neighbor_content = self.knowledge_graph.graph.nodes[neighbor]['content']
                                neighbor_concepts = self.knowledge_graph.graph.nodes[neighbor]['concepts']

                                filtered_content[neighbor] = neighbor_content
                                expanded_context += "\n" + neighbor_content if expanded_context else neighbor_content

                                # Log the neighbor node information
                                print(f"\nStep {step} - Node {neighbor} (neighbor of {current_node}):")
                                print(f"Content: {neighbor_content[:100]}...")
                                print(f"Concepts: {', '.join(neighbor_concepts)}")
                                print("-" * 50)

                                # Check if we have a complete answer after adding the neighbor's content
                                is_complete, answer = self._check_answer(query, expanded_context)
                                if is_complete:
                                    final_answer = answer
                                    break

                                # Process the neighbor's concepts
                                neighbor_concepts_set = set(
                                    self.knowledge_graph._lemmatize_concept(c) for c in neighbor_concepts)
                                if not neighbor_concepts_set.issubset(visited_concepts):
                                    visited_concepts.update(neighbor_concepts_set)

                # If we found a final answer, break out of the main loop
                if final_answer:
                    break

        # If we haven't found a complete answer, generate one using the LLM
        if not final_answer:
            print("\nGenerating final answer...")
            response_prompt = PromptTemplate(
                input_variables=["query", "context"],
                template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
            )
            response_chain = response_prompt | self.llm
            input_data = {"query": query, "context": expanded_context}
            final_answer = response_chain.invoke(input_data)

        return expanded_context, traversal_path, filtered_content, final_answer
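
    # Note on the priority scheme above (illustrative numbers): similarities/weights are inverted so the
    # min-heap pops the strongest connection first, e.g. scores 0.9 and 0.5 become priorities
    # 1/0.9 ~ 1.11 and 1/0.5 = 2.0, so the 0.9 node is explored first.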

    def query(self, query: str) -> Tuple[str, List[int], Dict[int, str]]:
        """
        Processes a query by retrieving relevant documents, expanding the context, and generating the final answer.

        Args:
        - query (str): The query to be answered.

        Returns:
        - tuple: A tuple containing:
          - final_answer (str): The final answer to the query.
          - traversal_path (list): The traversal path of nodes in the knowledge graph.
          - filtered_content (dict): The filtered content of nodes.
        """
        with get_openai_callback() as cb:
            print(f"\nProcessing query: {query}")
            relevant_docs = self._retrieve_relevant_documents(query)
            expanded_context, traversal_path, filtered_content, final_answer = self._expand_context(query,
                                                                                                    relevant_docs)

            if not final_answer:
                print("\nGenerating final answer...")
                response_prompt = PromptTemplate(
                    input_variables=["query", "context"],
                    template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
                )

                response_chain = response_prompt | self.llm
                input_data = {"query": query, "context": expanded_context}
                response = response_chain.invoke(input_data)
                final_answer = response
            else:
                print("\nComplete answer found during traversal.")

            print(f"\nFinal Answer: {final_answer}")
            print(f"\nTotal Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Total Cost (USD): ${cb.total_cost}")

        return final_answer, traversal_path, filtered_content

    def _retrieve_relevant_documents(self, query: str):
        """
        Retrieves relevant documents based on the query using the vector store.

        Args:
        - query (str): The query to be answered.

        Returns:
        - list: A list of relevant documents.
        """
        print("\nRetrieving relevant documents...")
        retriever = self.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        compressor = LLMChainExtractor.from_llm(self.llm)
        compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
        return compression_retriever.invoke(query)
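
# Direct usage sketch (hypothetical; components as wired up by GraphRAG below):
#
#   engine = QueryEngine(vector_store, knowledge_graph, llm)
#   answer, traversal_path, filtered_content = engine.query("what is the main cause of climate change?")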


# Import necessary libraries
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches


# Define the Visualizer class
class Visualizer:
    @staticmethod
    def visualize_traversal(graph, traversal_path):
        """
        Visualizes the traversal path on the knowledge graph with nodes, edges, and traversal path highlighted.

        Args:
        - graph (networkx.Graph): The knowledge graph containing nodes and edges.
        - traversal_path (list of int): The list of node indices representing the traversal path.

        Returns:
        - None
        """
        traversal_graph = nx.DiGraph()

        # Add nodes and edges from the original graph
        for node in graph.nodes():
            traversal_graph.add_node(node)
        for u, v, data in graph.edges(data=True):
            traversal_graph.add_edge(u, v, **data)

        fig, ax = plt.subplots(figsize=(16, 12))

        # Generate positions for all nodes
        pos = nx.spring_layout(traversal_graph, k=1, iterations=50)

        # Draw regular edges with color based on weight
        edges = traversal_graph.edges()
        edge_weights = [traversal_graph[u][v].get('weight', 0.5) for u, v in edges]
        nx.draw_networkx_edges(traversal_graph, pos,
                               edgelist=edges,
                               edge_color=edge_weights,
                               edge_cmap=plt.cm.Blues,
                               width=2,
                               ax=ax)

        # Draw nodes
        nx.draw_networkx_nodes(traversal_graph, pos,
                               node_color='lightblue',
                               node_size=3000,
                               ax=ax)

        # Draw traversal path with curved arrows
        edge_offset = 0.1
        for i in range(len(traversal_path) - 1):
            start = traversal_path[i]
            end = traversal_path[i + 1]
            start_pos = pos[start]
            end_pos = pos[end]

            # Calculate control point for curve
            mid_point = ((start_pos[0] + end_pos[0]) / 2, (start_pos[1] + end_pos[1]) / 2)
            control_point = (mid_point[0] + edge_offset, mid_point[1] + edge_offset)

            # Draw curved arrow
            arrow = patches.FancyArrowPatch(start_pos, end_pos,
                                            connectionstyle=f"arc3,rad={0.3}",
                                            color='red',
                                            arrowstyle="->",
                                            mutation_scale=20,
                                            linestyle='--',
                                            linewidth=2,
                                            zorder=4)
            ax.add_patch(arrow)

        # Prepare labels for the nodes
        labels = {}
        for i, node in enumerate(traversal_path):
            concepts = graph.nodes[node].get('concepts', [])
            label = f"{i + 1}. {concepts[0] if concepts else ''}"
            labels[node] = label

        for node in traversal_graph.nodes():
            if node not in labels:
                concepts = graph.nodes[node].get('concepts', [])
                labels[node] = concepts[0] if concepts else ''

        # Draw labels
        nx.draw_networkx_labels(traversal_graph, pos, labels, font_size=8, font_weight="bold", ax=ax)

        # Highlight start and end nodes
        start_node = traversal_path[0]
        end_node = traversal_path[-1]

        nx.draw_networkx_nodes(traversal_graph, pos,
                               nodelist=[start_node],
                               node_color='lightgreen',
                               node_size=3000,
                               ax=ax)

        nx.draw_networkx_nodes(traversal_graph, pos,
                               nodelist=[end_node],
                               node_color='lightcoral',
                               node_size=3000,
                               ax=ax)

        ax.set_title("Graph Traversal Flow")
        ax.axis('off')

        # Add colorbar for edge weights
        sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues,
                                   norm=plt.Normalize(vmin=min(edge_weights), vmax=max(edge_weights)))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        cbar.set_label('Edge Weight', rotation=270, labelpad=15)

        # Add legend
        regular_line = plt.Line2D([0], [0], color='blue', linewidth=2, label='Regular Edge')
        traversal_line = plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Traversal Path')
        start_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=15,
                                 label='Start Node')
        end_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral', markersize=15,
                               label='End Node')
        legend = plt.legend(handles=[regular_line, traversal_line, start_point, end_point], loc='upper left',
                            bbox_to_anchor=(0, 1), ncol=2)
        legend.get_frame().set_alpha(0.8)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def print_filtered_content(traversal_path, filtered_content):
        """
        Prints the filtered content of visited nodes in the order of traversal.

        Args:
        - traversal_path (list of int): The list of node indices representing the traversal path.
        - filtered_content (dict of int: str): A dictionary mapping node indices to their filtered content.

        Returns:
        - None
        """
        print("\nFiltered content of visited nodes in order of traversal:")
        for i, node in enumerate(traversal_path):
            print(f"\nStep {i + 1} - Node {node}:")
            print(
                f"Filtered Content: {filtered_content.get(node, 'No filtered content available')[:200]}...")  # Print first 200 characters
            print("-" * 50)


# Define the graph RAG class
class GraphRAG:
    def __init__(self, documents):
        """
        Initializes the GraphRAG system with components for document processing, knowledge graph construction,
        querying, and visualization.

        Args:
        - documents (list of str): A list of documents to be processed.

        Attributes:
        - llm: An instance of a large language model (LLM) for generating responses.
        - embedding_model: An instance of an embedding model for document embeddings.
        - document_processor: An instance of the DocumentProcessor class for processing documents.
        - knowledge_graph: An instance of the KnowledgeGraph class for building and managing the knowledge graph.
        - query_engine: An instance of the QueryEngine class for handling queries (initialized as None).
        - visualizer: An instance of the Visualizer class for visualizing the knowledge graph traversal.
        """
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
        self.embedding_model = OpenAIEmbeddings()
        self.document_processor = DocumentProcessor()
        self.knowledge_graph = KnowledgeGraph()
        self.query_engine = None
        self.visualizer = Visualizer()
        self.process_documents(documents)

    def process_documents(self, documents):
        """
        Processes a list of documents by splitting them into chunks, embedding them, and building a knowledge graph.

        Args:
        - documents (list of str): A list of documents to be processed.

        Returns:
        - None
        """
        splits, vector_store = self.document_processor.process_documents(documents)
        self.knowledge_graph.build_graph(splits, self.llm, self.embedding_model)
        self.query_engine = QueryEngine(vector_store, self.knowledge_graph, self.llm)

    def query(self, query: str):
        """
        Handles a query by retrieving relevant information from the knowledge graph and visualizing the traversal path.

        Args:
        - query (str): The query to be answered.

        Returns:
        - str: The response to the query.
        """
        response, traversal_path, filtered_content = self.query_engine.query(query)

        if traversal_path:
            self.visualizer.visualize_traversal(self.knowledge_graph.graph, traversal_path)
        else:
            print("No traversal path to visualize.")

        return response


# Argument parsing
def parse_args():
    parser = argparse.ArgumentParser(description="GraphRAG system")
    parser.add_argument('--path', type=str, default="../data/Understanding_Climate_Change.pdf",
                        help='Path to the PDF file.')
    parser.add_argument('--query', type=str, default='what is the main cause of climate change?',
                        help='Query to retrieve documents.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    # Load the documents (PyPDFLoader is expected to be provided via helper_functions)
    loader = PyPDFLoader(args.path)
    documents = loader.load()
    documents = documents[:10]

    # Create a GraphRAG instance; its constructor already processes the documents and builds the graph
    graph_rag = GraphRAG(documents)

    # Input a query and get the retrieved information from the graph RAG
    response = graph_rag.query(args.query)
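
# Example invocation (script name is illustrative; the defaults shown match parse_args above):
#   python graph_rag.py --path ../data/Understanding_Climate_Change.pdf --query "what is the main cause of climate change?"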