# graphrag-ancient-history/gui/rag_service.py

import os
from typing import List, Optional

import numpy as np
from loguru import logger
from openai import AsyncAzureOpenAI

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_embed
from lightrag.utils import EmbeddingFunc
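
# Required environment variables (all read below):
#   AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, AZURE_OPENAI_ENDPOINT
#       -- Azure OpenAI credentials and endpoint for the chat model.
#   AZURE_OPENAI_DEPLOYMENT -- name of the Azure chat-model deployment.
#   EMBEDDING_MODEL, VLLM_EMBED_HOST -- model name and base URL of the
#       OpenAI-compatible vLLM embedding endpoint.
#   KNOWLEDGE_GRAPH_PATH -- LightRAG working directory.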


class RAGService:
    """Service class for RAG operations."""

    def __init__(self):
        self.rag: Optional[LightRAG] = None

    # Azure OpenAI for the LLM. LightRAG invokes this with
    # (prompt, system_prompt, history_messages, **kwargs).
    async def llm_model_func(
        self,
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        # Async client so the call does not block the event loop.
        client = AsyncAzureOpenAI(
            api_key=os.environ["AZURE_OPENAI_API_KEY"],
            api_version=os.environ["AZURE_OPENAI_API_VERSION"],
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        )

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if history_messages:
            messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})

        chat_completion = await client.chat.completions.create(
            model=os.environ["AZURE_OPENAI_DEPLOYMENT"],
            messages=messages,
            temperature=kwargs.get("temperature", 0),
            top_p=kwargs.get("top_p", 1),
            n=kwargs.get("n", 1),
        )
        return chat_completion.choices[0].message.content

    # vLLM for embeddings: the server speaks the OpenAI embeddings API,
    # so LightRAG's stock openai_embed helper works against it.
    async def embedding_func(self, texts: List[str]) -> np.ndarray:
        try:
            return await openai_embed(
                texts,
                model=os.environ["EMBEDDING_MODEL"],
                api_key="anything",  # vLLM ignores the key unless one is configured
                base_url=os.environ["VLLM_EMBED_HOST"],
            )
        except Exception as e:
            logger.error(f"Error in embedding call: {e}")
            raise

    async def get_embedding_dim(self) -> int:
        """Get embedding dimension by testing with a sample text."""
        test_text = ["This is a test sentence."]
        embedding = await self.embedding_func(test_text)
        return embedding.shape[1]

    async def initialize(self):
        """Initialize the RAG system."""
        try:
            knowledge_graph_path = os.environ["KNOWLEDGE_GRAPH_PATH"]

            # Probe the embedding dimension dynamically so the vector store
            # matches whatever model the embedding endpoint serves.
            embedding_dimension = await self.get_embedding_dim()
            logger.info(f"Detected embedding dimension: {embedding_dimension}")

            self.rag = LightRAG(
                working_dir=knowledge_graph_path,
                graph_storage="NetworkXStorage",
                kv_storage="JsonKVStorage",
                vector_storage="FaissVectorDBStorage",
                vector_db_storage_cls_kwargs={
                    # Discard vector matches with cosine similarity below 0.2.
                    "cosine_better_than_threshold": 0.2
                },
                embedding_func=EmbeddingFunc(
                    embedding_dim=embedding_dimension,
                    max_token_size=8192,
                    func=self.embedding_func,
                ),
                llm_model_func=self.llm_model_func,
                enable_llm_cache=True,
                enable_llm_cache_for_entity_extract=False,
                embedding_cache_config={
                    "enabled": False,
                    "similarity_threshold": 0.95,
                    "use_llm_check": False,
                },
            )

            # Initialize storages
            await self.rag.initialize_storages()
            await initialize_pipeline_status()
            logger.success("RAG system initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing RAG: {e}")
            raise
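
    # LightRAG query modes include "naive" (vector-only), "local", "global",
    # "hybrid", and "mix" (combined knowledge-graph and vector retrieval);
    # see the LightRAG docs for the full list in your installed version.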
    async def query(
        self,
        question: str,
        mode: str = "mix",
        response_type: str = "Multiple Paragraphs",
    ) -> str:
        """Query the RAG system."""
        if not self.rag:
            raise RuntimeError("RAG system not initialized")
        try:
            response = await self.rag.aquery(
                question,
                param=QueryParam(
                    mode=mode,
                    response_type=response_type,
                    only_need_context=False,
                ),
            )
            return response
        except Exception as e:
            logger.error(f"Error processing query: {e}")
            raise

    def is_initialized(self) -> bool:
        """Check if RAG is initialized."""
        return self.rag is not None


rag_service = RAGService()
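

# Minimal usage sketch (illustrative, not part of the service itself):
# assumes the environment variables listed at the top are set and that a
# LightRAG knowledge graph already exists under KNOWLEDGE_GRAPH_PATH. The
# question is a hypothetical example.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        await rag_service.initialize()
        answer = await rag_service.query("Who was Hammurabi?", mode="mix")
        print(answer)

    asyncio.run(_demo())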