mirror of
https://github.com/NirDiamant/RAG_Techniques.git
synced 2025-04-07 00:48:52 +03:00
87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
import os
|
|
import sys
|
|
from dotenv import load_dotenv
|
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path
|
|
from helper_functions import *
|
|
from evaluation.evalute_rag import *
|
|
|
|
# Load environment variables from a .env file
|
|
load_dotenv()
|
|
|
|
# Set the OpenAI API key environment variable
|
|
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
|
|
|
|
|
|
# Define utility classes/functions
|
|
class ExplainableRetriever:
|
|
def __init__(self, texts):
|
|
self.embeddings = OpenAIEmbeddings()
|
|
self.vectorstore = FAISS.from_texts(texts, self.embeddings)
|
|
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
|
|
self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
|
|
|
|
explain_prompt = PromptTemplate(
|
|
input_variables=["query", "context"],
|
|
template="""
|
|
Analyze the relationship between the following query and the retrieved context.
|
|
Explain why this context is relevant to the query and how it might help answer the query.
|
|
|
|
Query: {query}
|
|
|
|
Context: {context}
|
|
|
|
Explanation:
|
|
"""
|
|
)
|
|
self.explain_chain = explain_prompt | self.llm
|
|
|
|
def retrieve_and_explain(self, query):
|
|
docs = self.retriever.get_relevant_documents(query)
|
|
explained_results = []
|
|
|
|
for doc in docs:
|
|
input_data = {"query": query, "context": doc.page_content}
|
|
explanation = self.explain_chain.invoke(input_data).content
|
|
explained_results.append({
|
|
"content": doc.page_content,
|
|
"explanation": explanation
|
|
})
|
|
return explained_results
|
|
|
|
|
|
class ExplainableRAGMethod:
|
|
def __init__(self, texts):
|
|
self.explainable_retriever = ExplainableRetriever(texts)
|
|
|
|
def run(self, query):
|
|
return self.explainable_retriever.retrieve_and_explain(query)
|
|
|
|
|
|
# Argument Parsing
|
|
def parse_args():
|
|
import argparse
|
|
parser = argparse.ArgumentParser(description="Explainable RAG Method")
|
|
parser.add_argument('--query', type=str, default='Why is the sky blue?', help="Query for the retriever")
|
|
return parser.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args()
|
|
|
|
# Sample texts (these can be replaced by actual data)
|
|
texts = [
|
|
"The sky is blue because of the way sunlight interacts with the atmosphere.",
|
|
"Photosynthesis is the process by which plants use sunlight to produce energy.",
|
|
"Global warming is caused by the increase of greenhouse gases in Earth's atmosphere."
|
|
]
|
|
|
|
explainable_rag = ExplainableRAGMethod(texts)
|
|
results = explainable_rag.run(args.query)
|
|
|
|
for i, result in enumerate(results, 1):
|
|
print(f"Result {i}:")
|
|
print(f"Content: {result['content']}")
|
|
print(f"Explanation: {result['explanation']}")
|
|
print()
|