# graphrag-ancient-history/inference.py
import os
import asyncio
import numpy as np
from typing import List
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger, EmbeddingFunc
from loguru import logger
from openai import AzureOpenAI
# Setup environment and logging
setup_logger("lightrag", level="INFO")


def get_required_env(name):
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Missing required environment variable: {name}")
    return value
""" LLM vLLM
async def llm_model_func(prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs):
try:
return await openai_complete_if_cache(
model=os.environ["LLM_MODEL"],
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key="anything",
base_url=os.environ["VLLM_LLM_HOST"],
**kwargs,
)
except Exception as e:
logger.error(f"Error in LLM call: {e}")
raise
"""
""" LLM Azure OpenAI"""
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
client = AzureOpenAI(
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_version=os.environ["AZURE_OPENAI_API_VERSION"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
chat_completion = client.chat.completions.create(
model=os.environ["AZURE_OPENAI_DEPLOYMENT"],
messages=messages,
temperature=kwargs.get("temperature", 0),
top_p=kwargs.get("top_p", 1),
n=kwargs.get("n", 1),
)
return chat_completion.choices[0].message.content


# Embeddings come from a separate vLLM OpenAI-compatible embedding endpoint.
async def embedding_func(texts: List[str]) -> np.ndarray:
    try:
        return await openai_embed(
            texts,
            model=os.environ["EMBEDDING_MODEL"],
            api_key="anything",
            base_url=os.environ["VLLM_EMBED_HOST"],
        )
    except Exception as e:
        logger.error(f"Error in embedding call: {e}")
        raise


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def initialize_rag():
    try:
        knowledge_graph_path = get_required_env("KNOWLEDGE_GRAPH_PATH")
        # Get embedding dimension dynamically
        embedding_dimension = await get_embedding_dim()
        logger.info(f"Detected embedding dimension: {embedding_dimension}")
        rag = LightRAG(
            working_dir=knowledge_graph_path,
            graph_storage="NetworkXStorage",
            kv_storage="JsonKVStorage",
            vector_storage="FaissVectorDBStorage",
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": 0.2
            },
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=8192,
                func=embedding_func,
            ),
            llm_model_func=llm_model_func,
            enable_llm_cache=True,
            enable_llm_cache_for_entity_extract=False,
            embedding_cache_config={
                "enabled": False,
                "similarity_threshold": 0.95,
                "use_llm_check": False,
            },
        )
        # Initialize storages properly
        await rag.initialize_storages()
        await initialize_pipeline_status()
        return rag
    except Exception as e:
        logger.error(f"Error initializing RAG: {e}")
        raise


def main():
    rag = asyncio.run(initialize_rag())
    mode = "mix"
    response = rag.query(
        "Giants in Holy texts? In terms of monotheistic, polytheistic, atheistic, agnostic and deistic approaches",
        param=QueryParam(
            mode=mode,
            response_type="in bullet points and description for each bullet point",
            only_need_context=False,
            # conversation_history=,
            # history_turns=5,
        ),
    )
    print(response)


if __name__ == "__main__":
    main()
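

# Illustrative sketch (not exercised by this script): the conversation_history /
# history_turns parameters commented out in main() enable multi-turn queries.
# Per LightRAG's QueryParam, conversation_history is a list of {"role", "content"}
# messages; other supported modes include "local", "global", "hybrid" and "naive".
# The strings below are placeholders only.
#
# history = [
#     {"role": "user", "content": "previous question"},
#     {"role": "assistant", "content": "previous answer from rag.query()"},
# ]
# followup = rag.query(
#     "follow-up question",
#     param=QueryParam(mode="mix", conversation_history=history, history_turns=5),
# )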