implement basic webapp for anne
64 README.md Normal file
@@ -0,0 +1,64 @@
# Pre-historic Knowledge Assistant Web App

A web application for the RAG-based knowledge assistant.

## Features

- Multiple LLM provider support (Azure OpenAI, OpenAI, Ollama, vLLM, custom endpoints)
- Flexible embedding configuration
- Web interface with real-time responses
- REST API endpoints
- Health check endpoint

## Setup

1. Install dependencies:
```bash
pip install -r requirements.txt
```

2. Create `.env` file from `.env.example`:
```bash
cp .env.example .env
# Edit .env with your configuration
```

3. Run the application:
```bash
python app.py
# or
./run.sh
```

## Configuration

### LLM Providers

The application supports multiple LLM providers:

- **Azure OpenAI**: Set `LLM_PROVIDER=azure_openai` and configure Azure credentials
- **OpenAI**: Set `LLM_PROVIDER=openai` and provide an API key
- **Ollama**: Set `LLM_PROVIDER=ollama` and configure the host URL
- **vLLM**: Set `LLM_PROVIDER=vllm` and configure the vLLM host
- **Custom**: Set `LLM_PROVIDER=custom` for any OpenAI-compatible endpoint

### Environment Variables

See `.env.example` for all available configuration options.
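The variables actually read by the code in this commit include `KNOWLEDGE_GRAPH_PATH`, `KNOWLEDGE_GRAPH_INPUT_DIR_PATH`, `TEMPLATES_DIR`, `WEBUI_PORT`, `LLM_MODEL`, `VLLM_LLM_HOST`, `EMBEDDING_MODEL`, `VLLM_EMBED_HOST`, and the Azure OpenAI credentials (`AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_API_VERSION`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT`); the Ollama code path additionally reads `OLLAMA_LLM_MODEL`, `OLLAMA_LLM_HOST`, `OLLAMA_EMBED_MODEL`, and `OLLAMA_EMBED_HOST`. Treat this list as a reference only; `.env.example` remains the authoritative source.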

## API Endpoints

- `GET /`: Web interface
- `POST /ask`: Process a question via the web form
- `POST /api/ask`: REST API endpoint for questions (see the example below)
- `GET /api/health`: Health check endpoint
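A minimal client sketch for the REST API, assuming the server is running at the default `http://localhost:8000` address mentioned in the Development section; the request fields mirror `QuestionRequest` in `gui/models.py`, and the question text is a placeholder:

```python
# Minimal API usage sketch (illustrative only, not part of the application code).
import requests

BASE_URL = "http://localhost:8000"  # default development address

# Health check: reports whether the RAG system initialized.
print(requests.get(f"{BASE_URL}/api/health").json())

# Ask a question; "mode" and "response_type" are optional and default to
# "mix" and "Multiple Paragraphs" (see gui/models.py).
payload = {"question": "Which topics are covered in the knowledge base?"}
response = requests.post(f"{BASE_URL}/api/ask", json=payload, timeout=300)
response.raise_for_status()
print(response.json()["answer"])
```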

## Development

To run in development mode with auto-reload:

```bash
python app.py
```

The application will be available at `http://localhost:8000`.

220 create.py
@@ -1,143 +1,141 @@
import glob
import os
import asyncio
import statistics
from functools import wraps
from typing import Callable, Dict, List, Any
import asyncio
from typing import Dict, List, Any
import numpy as np

import aiofiles
from lightrag import LightRAG, QueryParam
from lightrag.llm.azure_openai import azure_openai_embed, azure_openai_complete
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger, EmbeddingFunc
from tqdm import tqdm
from loguru import logger

setup_logger("lightrag", level="INFO")
# Setup environment and logging
setup_logger("lightrag", level="DEBUG")


def get_required_env(name):
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Missing required environment variable: {name}")
    return value


def read_text_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()
    return text
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        raise


def get_text_statistics(text):
    """Calculate statistics for the given text."""
    char_count = len(text)
    word_count = len(text.split())
    line_count = len(text.splitlines())
    return {
        'char_count': char_count,
        'word_count': word_count,
        'line_count': line_count
    }
async def llm_model_func(prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs):
    try:
        return await openai_complete_if_cache(
            model=os.environ["LLM_MODEL"],
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            api_key="anything",
            base_url=os.environ["VLLM_LLM_HOST"],
            **kwargs,
        )
    except Exception as e:
        logger.error(f"Error in LLM call: {e}")
        raise


def with_statistics(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(file_path: str, *args, **kwargs) -> Dict[str, Any]:
        # Read the file
        text = read_text_file(file_path)

        # Get text statistics
        stats = get_text_statistics(text)
        file_size = os.path.getsize(file_path)
        stats['file_size'] = file_size
        stats['file_name'] = os.path.basename(file_path)

        # Log individual file statistics
        logger.debug(f"File: {stats['file_name']}")
        logger.debug(f" - Size: {file_size} bytes")
        logger.debug(f" - Characters: {stats['char_count']}")
        logger.debug(f" - Words: {stats['word_count']}")
        logger.debug(f" - Lines: {stats['line_count']}")

        # Call the original function
        result = func(text, *args, **kwargs)

        return {
            'result': result,
            'stats': stats
        }

    return wrapper
async def embedding_func(texts: list[str]) -> np.ndarray:
    try:
        return await openai_embed(
            texts,
            model=os.environ["EMBEDDING_MODEL"],
            api_key="anything",
            base_url=os.environ["VLLM_EMBED_HOST"],
        )
    except Exception as e:
        logger.error(f"Error in embedding call: {e}")
        raise


async def initilize_rag_ollama():
    rag = LightRAG(
        working_dir=os.environ["KNOWLEDGE_GRAPH_PATH"],
        graph_storage="NetworkXStorage", # "Neo4JStorage",
        kv_storage="JsonKVStorage",
        vector_storage="FaissVectorDBStorage",
        vector_db_storage_cls_kwargs={
            "cosine_better_than_threshold": 0.25
        },
        llm_model_func=ollama_model_complete,
        llm_model_name=os.environ["OLLAMA_LLM_MODEL"],
        llm_model_kwargs={
            "host": os.environ["OLLAMA_LLM_HOST"],
            "options": {"num_ctx": 40000},
        },
        enable_llm_cache=False,
        enable_llm_cache_for_entity_extract=False,
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,
            max_token_size=8192,
            func=lambda texts: ollama_embed(
                texts,
                embed_model=os.environ["OLLAMA_EMBED_MODEL"],
                host=os.environ["OLLAMA_EMBED_HOST"]
async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def initialize_rag():
    try:
        knowledge_graph_path = get_required_env("KNOWLEDGE_GRAPH_PATH")

        # Get embedding dimension dynamically
        embedding_dimension = await get_embedding_dim()
        logger.info(f"Detected embedding dimension: {embedding_dimension}")

        rag = LightRAG(
            working_dir=knowledge_graph_path,
            graph_storage="NetworkXStorage",
            kv_storage="JsonKVStorage",
            vector_storage="FaissVectorDBStorage",
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": 0.2
            },
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=8192,
                func=embedding_func
            ),
        ),
        embedding_cache_config={
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False
        },
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag
            llm_model_func=llm_model_func,
            enable_llm_cache=False,
            enable_llm_cache_for_entity_extract=False,
            embedding_cache_config={
                "enabled": False,
                "similarity_threshold": 0.95,
                "use_llm_check": False
            },
        )

        # Initialize storages properly
        await rag.initialize_storages()
        await initialize_pipeline_status()
        return rag
    except Exception as e:
        logger.error(f"Error initializing RAG: {e}")
        raise


def main():
    logger.info("Initializing lightRAG instance")
    rag = asyncio.run(initilize_rag_ollama())
    try:
        # Initialize RAG
        logger.info("Initializing LightRAG instance")
        rag = asyncio.run(initialize_rag())

        input_dir_path = "/Users/tcudikel/Dev/ancient-history/data/input/transcripts"
        txt_files = glob.glob(f"{input_dir_path}/*.txt")
        logger.info(f"Found {len(txt_files)} text files in {input_dir_path}")
        # Find text files
        input_dir_path = get_required_env("KNOWLEDGE_GRAPH_INPUT_DIR_PATH")
        texts = []
        txt_files = glob.glob(f"{input_dir_path}/*.txt")
        for txt_file in txt_files:
            txt = read_text_file(txt_file)
            texts.append(txt)

        # Collect statistics
        all_stats = []
        logger.info(f"Found {len(txt_files)} text files in {input_dir_path}")

        @with_statistics
        def process_file(text, rag):
            return rag.insert(text)
        if not txt_files:
            logger.warning(f"No text files found in {input_dir_path}")
            return

        for file_path in tqdm(txt_files, desc="Processing files", unit="file", miniters=1, ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'):
            result_with_stats = process_file(file_path, rag)
            all_stats.append(result_with_stats['stats'])
        # results = await process_files(rag, txt_files)

        # Calculate and log summary statistics
        if all_stats:
            char_counts = [stat['char_count'] for stat in all_stats]
            word_counts = [stat['word_count'] for stat in all_stats]
            line_counts = [stat['line_count'] for stat in all_stats]
        results = rag.insert(texts)

            logger.info("Text statistics summary:")
            logger.info(f" - Total characters: {sum(char_counts)}")
            logger.info(f" - Total words: {sum(word_counts)}")
            logger.info(f" - Total lines: {sum(line_counts)}")
            logger.info(f" - Average characters per file: {statistics.mean(char_counts):.2f}")
            logger.info(f" - Average words per file: {statistics.mean(word_counts):.2f}")
            logger.info(f" - Average lines per file: {statistics.mean(line_counts):.2f}")

        logger.success(f"{len(txt_files)} files inserted into the knowledge graph.")
    except Exception as e:
        logger.error(f"Error in main process: {e}")
        raise


if __name__ == "__main__":
    main()
    main()

0 gui/__init__.py Normal file

125 gui/app.py Normal file
@@ -0,0 +1,125 @@
import os
from fastapi import FastAPI, Request, Form, HTTPException
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
from loguru import logger
import uvicorn

from rag_service import rag_service
from models import QuestionRequest, QuestionResponse, HealthResponse


# Template configuration
templates = Jinja2Templates(directory=os.environ["TEMPLATES_DIR"])


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler."""
    try:
        logger.info("Starting application...")
        logger.info("Initializing RAG system...")
        await rag_service.initialize()
        logger.success("Application started successfully")
    except Exception as e:
        logger.error(f"Failed to initialize RAG system: {e}")
        # Don't exit - allow the app to run for diagnostics

    yield  # Application runs here

    logger.info("Shutting down application...")


# Create FastAPI app
app = FastAPI(
    title="Pre-historic Knowledge Assistant",
    description="A RAG-based knowledge assistant",
    version="1.0.0",
    lifespan=lifespan
)


# Web interface routes
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the main page."""
    return templates.TemplateResponse(
        "index.html", {"request": request}
    )


@app.post("/ask", response_class=HTMLResponse)
async def ask_question(request: Request, question: str = Form(...)):
    """Process a question and return the response."""
    if not rag_service.is_initialized():
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": "RAG system not available. Please restart the server."
            }
        )

    try:
        response = await rag_service.query(question)

        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "question": question,
                "answer": response
            }
        )

    except Exception as e:
        logger.error(f"Error processing question: {e}")
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "question": question,
                "error": f"Error: {str(e)}"
            }
        )


# API routes
@app.post("/api/ask", response_model=QuestionResponse)
async def ask_question_api(request: QuestionRequest):
    """API endpoint for question processing."""
    if not rag_service.is_initialized():
        raise HTTPException(
            status_code=503,
            detail="RAG system not available"
        )

    try:
        response = await rag_service.query(
            request.question,
            mode=request.mode,
            response_type=request.response_type
        )
        return QuestionResponse(question=request.question, answer=response)
    except Exception as e:
        logger.error(f"Error processing question: {e}")
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy" if rag_service.is_initialized() else "unhealthy",
        rag_initialized=rag_service.is_initialized()
    )


if __name__ == "__main__":
    uvicorn.run(
        "gui.app:app",
        host="0.0.0.0",
        port=int(os.environ["WEBUI_PORT"]),
        reload=True
    )

188 gui/index.html Normal file
@@ -0,0 +1,188 @@
<!DOCTYPE html>
<html lang="en" class="h-full">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Pre-historic Knowledge Assistant</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <style>
        /* Custom animations */
        @keyframes fadeIn {
            from { opacity: 0; transform: translateY(10px); }
            to { opacity: 1; transform: translateY(0); }
        }
        .fade-in {
            animation: fadeIn 0.5s ease-out;
        }
        /* Custom scrollbar */
        ::-webkit-scrollbar {
            width: 6px;
        }
        ::-webkit-scrollbar-track {
            background: #f1f5f9;
        }
        ::-webkit-scrollbar-thumb {
            background: #cbd5e1;
            border-radius: 3px;
        }
        ::-webkit-scrollbar-thumb:hover {
            background: #94a3b8;
        }
    </style>
</head>
<body class="h-full bg-gradient-to-br from-slate-50 to-slate-100">
    <div class="min-h-full flex flex-col">
        <!-- Header -->
        <header class="bg-white border-b border-slate-200 sticky top-0 z-10 backdrop-blur-sm bg-white/90">
            <div class="max-w-3xl mx-auto px-4 py-4">
                <h1 class="text-2xl font-bold text-slate-900 flex items-center gap-2">
                    <svg class="w-8 h-8 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                        <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9.663 17h4.673M12 3v1m6.364 1.636l-.707.707M21 12h-1M4 12H3m3.343-5.657l-.707-.707m2.828 9.9a5 5 0 117.072 0l-.548.547A3.374 3.374 0 0014 18.469V19a2 2 0 11-4 0v-.531c0-.895-.356-1.754-.988-2.386l-.548-.547z"></path>
                    </svg>
                    Pre-historic Knowledge Assistant
                </h1>
            </div>
        </header>

        <!-- Main Content -->
        <main class="flex-1 max-w-3xl w-full mx-auto px-4 py-8">
            {% if error %}
            <div class="mb-6 p-4 bg-red-50 border border-red-200 rounded-lg fade-in">
                <p class="text-red-700 flex items-center gap-2">
                    <svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                        <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8v4m0 4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"></path>
                    </svg>
                    {{ error }}
                </p>
            </div>
            {% endif %}

            <!-- Question Form -->
            <div class="bg-white rounded-xl shadow-sm border border-slate-200 p-6 mb-6">
                <form action="/ask" method="POST" class="space-y-4">
                    <div>
                        <label for="question" class="block text-sm font-medium text-slate-700 mb-2">
                            Ask your question
                        </label>
                        <textarea
                            id="question"
                            name="question"
                            rows="3"
                            required
                            placeholder="What would you like to know?"
                            class="w-full px-4 py-3 border border-slate-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-blue-500 resize-none transition duration-200 placeholder-slate-400"
                        >{{ question if question else '' }}</textarea>
                    </div>
                    <button
                        type="submit"
                        class="w-full sm:w-auto px-6 py-3 bg-blue-600 text-white font-medium rounded-lg hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2 transition duration-200 flex items-center justify-center gap-2"
                    >
                        <svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"></path>
                        </svg>
                        Ask Question
                    </button>
                </form>
            </div>

            <!-- Loading State -->
            <div id="loading-state" class="bg-white rounded-xl shadow-sm border border-slate-200 p-6 hidden">
                <div class="flex items-center justify-center gap-3">
                    <svg class="animate-spin h-5 w-5 text-blue-600" fill="none" viewBox="0 0 24 24">
                        <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
                        <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
                    </svg>
                    <p class="text-slate-600">Processing your question...</p>
                </div>
            </div>

            <!-- Answer Display -->
            {% if answer %}
            <div id="answer-container" class="bg-white rounded-xl shadow-sm border border-slate-200 p-6 fade-in">
                <div class="flex items-start gap-3 mb-3">
                    <div class="flex-shrink-0">
                        <svg class="w-6 h-6 text-green-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"></path>
                        </svg>
                    </div>
                    <div class="flex-1">
                        <h3 class="text-lg font-semibold text-slate-900 mb-2">Answer</h3>
                        <div class="prose prose-slate max-w-none">
                            <p class="text-slate-700 leading-relaxed whitespace-pre-wrap">{{ answer }}</p>
                        </div>
                    </div>
                </div>
            </div>
            {% endif %}
        </main>

        <!-- Footer -->
        <footer class="bg-white border-t border-slate-200 mt-8">
            <div class="max-w-3xl mx-auto px-4 py-4">
                <p class="text-center text-sm text-slate-500">
                    Powered by GraphRAG based LLM Agents 🤖
                </p>
            </div>
        </footer>
    </div>

    <script>
        // Show loading state on form submission
        document.querySelector('form').addEventListener('submit', function(e) {
            // Find or create loading state
            let loadingState = document.getElementById('loading-state');

            // If loading state doesn't exist, create it
            if (!loadingState) {
                const loadingHTML = `
                    <div id="loading-state" class="bg-white rounded-xl shadow-sm border border-slate-200 p-6 fade-in">
                        <div class="flex items-center justify-center gap-3">
                            <svg class="animate-spin h-5 w-5 text-blue-600" fill="none" viewBox="0 0 24 24">
                                <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
                                <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
                            </svg>
                            <p class="text-slate-600">Processing your question...</p>
                        </div>
                    </div>
                `;

                // Find the main element and append loading state
                const main = document.querySelector('main');
                const formContainer = document.querySelector('.bg-white.rounded-xl.shadow-sm.border.border-slate-200.p-6.mb-6');
                formContainer.insertAdjacentHTML('afterend', loadingHTML);
                loadingState = document.getElementById('loading-state');
            } else {
                // Show existing loading state
                loadingState.style.display = 'block';
                loadingState.classList.remove('hidden');
            }

            // Hide any existing answer container
            const answerContainer = document.getElementById('answer-container');
            if (answerContainer) {
                answerContainer.style.display = 'none';
            }

            // Smooth scroll to the loading state
            setTimeout(() => {
                loadingState.scrollIntoView({ behavior: 'smooth', block: 'center' });
            }, 100);
        });

        // Auto-resize textarea
        const textarea = document.querySelector('textarea');
        textarea.addEventListener('input', function() {
            this.style.height = 'auto';
            this.style.height = this.scrollHeight + 'px';
        });

        // Handle Enter key for form submission and Shift+Enter for new line
        textarea.addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                this.closest('form').submit();
            }
        });
    </script>
</body>
</html>

26 gui/models.py Normal file
@@ -0,0 +1,26 @@
from pydantic import BaseModel
from typing import Optional


class QuestionRequest(BaseModel):
    """Request model for questions."""
    question: str
    mode: str = "mix"
    response_type: str = "Multiple Paragraphs"


class QuestionResponse(BaseModel):
    """Response model for questions."""
    question: str
    answer: str


class HealthResponse(BaseModel):
    """Response model for health check."""
    status: str
    rag_initialized: bool


class ErrorResponse(BaseModel):
    """Response model for errors."""
    detail: str

136 gui/rag_service.py Normal file
@@ -0,0 +1,136 @@
import os
import numpy as np
from typing import Optional, List
from loguru import logger
from openai import AzureOpenAI

from lightrag import LightRAG, QueryParam
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import EmbeddingFunc
from lightrag.llm.openai import openai_complete_if_cache, openai_embed


class RAGService:
    """Service class for RAG operations."""

    def __init__(self):
        self.rag: Optional[LightRAG] = None

    # Azure OpenAI for LLM
    async def llm_model_func(
        self, prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
    ) -> str:
        client = AzureOpenAI(
            api_key=os.environ["AZURE_OPENAI_API_KEY"],
            api_version=os.environ["AZURE_OPENAI_API_VERSION"],
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        )

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if history_messages:
            messages.extend(history_messages)
        messages.append({"role": "user", "content": prompt})

        chat_completion = client.chat.completions.create(
            model=os.environ["AZURE_OPENAI_DEPLOYMENT"],
            messages=messages,
            temperature=kwargs.get("temperature", 0),
            top_p=kwargs.get("top_p", 1),
            n=kwargs.get("n", 1),
        )
        return chat_completion.choices[0].message.content

    # vLLM for embeddings
    async def embedding_func(self, texts: List[str]) -> np.ndarray:
        try:
            return await openai_embed(
                texts,
                model=os.environ["EMBEDDING_MODEL"],
                api_key="anything",
                base_url=os.environ["VLLM_EMBED_HOST"],
            )
        except Exception as e:
            logger.error(f"Error in embedding call: {e}")
            raise

    async def get_embedding_dim(self) -> int:
        """Get embedding dimension by testing with a sample text."""
        test_text = ["This is a test sentence."]
        embedding = await self.embedding_func(test_text)
        return embedding.shape[1]

    async def initialize(self):
        """Initialize the RAG system."""
        try:
            knowledge_graph_path = os.environ["KNOWLEDGE_GRAPH_PATH"]

            # Get embedding dimension dynamically
            embedding_dimension = await self.get_embedding_dim()
            logger.info(f"Detected embedding dimension: {embedding_dimension}")

            self.rag = LightRAG(
                working_dir=knowledge_graph_path,
                graph_storage="NetworkXStorage",
                kv_storage="JsonKVStorage",
                vector_storage="FaissVectorDBStorage",
                vector_db_storage_cls_kwargs={
                    "cosine_better_than_threshold": 0.2
                },
                embedding_func=EmbeddingFunc(
                    embedding_dim=embedding_dimension,
                    max_token_size=8192,
                    func=self.embedding_func
                ),
                llm_model_func=self.llm_model_func,
                enable_llm_cache=True,
                enable_llm_cache_for_entity_extract=False,
                embedding_cache_config={
                    "enabled": False,
                    "similarity_threshold": 0.95,
                    "use_llm_check": False
                },
            )

            # Initialize storages
            await self.rag.initialize_storages()
            await initialize_pipeline_status()

            logger.success("RAG system initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing RAG: {e}")
            raise

    async def query(
        self,
        question: str,
        mode: str = "mix",
        response_type: str = "Multiple Paragraphs"
    ) -> str:
        """Query the RAG system."""
        if not self.rag:
            raise RuntimeError("RAG system not initialized")

        try:
            response = await self.rag.aquery(
                question,
                param=QueryParam(
                    mode=mode,
                    response_type=response_type,
                    only_need_context=False,
                )
            )
            return response
        except Exception as e:
            logger.error(f"Error processing query: {e}")
            raise

    def is_initialized(self) -> bool:
        """Check if RAG is initialized."""
        return self.rag is not None


rag_service = RAGService()

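For context, a minimal standalone sketch of how the `rag_service` singleton above is meant to be consumed, mirroring what `gui/app.py` does in its lifespan handler and routes (the question string is a placeholder):

```python
# Usage sketch for RAGService (mirrors gui/app.py; not part of this commit).
import asyncio

from rag_service import rag_service  # assumes gui/ is the working directory


async def demo() -> None:
    # Builds the LightRAG instance, detects the embedding dimension,
    # and initializes storages and the pipeline status.
    await rag_service.initialize()

    if rag_service.is_initialized():
        answer = await rag_service.query(
            "Which topics are covered in the knowledge base?",
            mode="mix",
            response_type="Multiple Paragraphs",
        )
        print(answer)


if __name__ == "__main__":
    asyncio.run(demo())
```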
138 inference.py
@@ -1,51 +1,139 @@
import glob
import os
import asyncio
import numpy as np
from typing import List

import aiofiles
from lightrag import LightRAG, QueryParam
from lightrag.llm.azure_openai import azure_openai_embed, azure_openai_complete
from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.utils import setup_logger, EmbeddingFunc
from tqdm import tqdm
from loguru import logger
from openai import AzureOpenAI

# Setup environment and logging
setup_logger("lightrag", level="INFO")


def get_required_env(name):
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Missing required environment variable: {name}")
    return value

async def initialize_rag():
    rag = LightRAG(
        working_dir="/Users/tcudikel/Dev/ancient-history/data/storage/base_gpt4o",
        graph_storage="NetworkXStorage",
        vector_storage="ChromaVectorDBStorage",
        vector_db_storage_cls_kwargs={
            "local_path": "/Users/tcudikel/Dev/ancient-history/data/storage/base_gpt4o/vdb",
            "cosine_better_than_threshold": 0.5,
        },
        embedding_func=EmbeddingFunc(
            embedding_dim=3072,
            max_token_size=8192,
            func=lambda texts: azure_openai_embed(texts)
        ),
        llm_model_func=azure_openai_complete

""" LLM vLLM
async def llm_model_func(prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs):
    try:
        return await openai_complete_if_cache(
            model=os.environ["LLM_MODEL"],
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            api_key="anything",
            base_url=os.environ["VLLM_LLM_HOST"],
            **kwargs,
        )
    except Exception as e:
        logger.error(f"Error in LLM call: {e}")
        raise
"""

""" LLM Azure OpenAI"""
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    client = AzureOpenAI(
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        api_version=os.environ["AZURE_OPENAI_API_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    return rag
    chat_completion = client.chat.completions.create(
        model=os.environ["AZURE_OPENAI_DEPLOYMENT"],
        messages=messages,
        temperature=kwargs.get("temperature", 0),
        top_p=kwargs.get("top_p", 1),
        n=kwargs.get("n", 1),
    )
    return chat_completion.choices[0].message.content

async def embedding_func(texts: List[str]) -> np.ndarray:
    try:
        return await openai_embed(
            texts,
            model=os.environ["EMBEDDING_MODEL"],
            api_key="anything",
            base_url=os.environ["VLLM_EMBED_HOST"],
        )
    except Exception as e:
        logger.error(f"Error in embedding call: {e}")
        raise


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def initialize_rag():
    try:
        knowledge_graph_path = get_required_env("KNOWLEDGE_GRAPH_PATH")

        # Get embedding dimension dynamically
        embedding_dimension = await get_embedding_dim()
        logger.info(f"Detected embedding dimension: {embedding_dimension}")

        rag = LightRAG(
            working_dir=knowledge_graph_path,
            graph_storage="NetworkXStorage",
            kv_storage="JsonKVStorage",
            vector_storage="FaissVectorDBStorage",
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": 0.2
            },
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=8192,
                func=embedding_func
            ),
            llm_model_func=llm_model_func,
            enable_llm_cache=True,
            enable_llm_cache_for_entity_extract=False,
            embedding_cache_config={
                "enabled": False,
                "similarity_threshold": 0.95,
                "use_llm_check": False
            },
        )

        # Initialize storages properly
        await rag.initialize_storages()
        await initialize_pipeline_status()
        return rag
    except Exception as e:
        logger.error(f"Error initializing RAG: {e}")
        raise


def main():
    rag = asyncio.run(initialize_rag())

    mode = "mix"
    response = rag.query(
        "Which prophets exist before Noah?",
        "Giants in Holy texts? In terms of monotheistic, polytesitic, ateistic, agnostic and deistic approaches",
        param=QueryParam(
            mode=mode,
            response_type="Single Paragraphs",
            response_type="in bullet points and description for each bullet point",
            only_need_context=False,
            # conversation_history=,
            # history_turns=5,

requirements.txt
@@ -1,22 +1,87 @@
aiofiles==24.1.0
annotated-types==0.7.0
anyio==4.9.0
anytree==2.13.0
ascii_colors==0.8.0
autograd==1.7.0
beartype==0.18.5
blobfile==3.0.0
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.2.0
contourpy==1.3.2
cycler==0.12.1
distro==1.9.0
faiss-cpu==1.10.0
fastapi==0.115.12
filelock==3.18.0
fonttools==4.57.0
gensim==4.3.3
glcontext==3.0.0
glfw==2.9.0
graspologic==3.4.1
graspologic-native==1.2.5
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
hyppo==0.4.0
idna==3.10
imgui-bundle==1.6.2
Jinja2==3.1.6
jiter==0.9.0
joblib==1.4.2
kiwisolver==1.4.8
lightrag-hku==1.3.2
llvmlite==0.44.0
loguru==0.7.3
lxml==5.4.0
MarkupSafe==3.0.2
matplotlib==3.10.1
moderngl==5.12.0
munch==4.0.0
networkx==3.4.2
numpy==2.2.4
numba==0.61.2
numpy==1.26.4
openai==1.76.0
packaging==25.0
pandas==2.2.3
patsy==1.0.1
pillow==11.2.1
pipmaster==0.5.4
POT==0.9.5
pycryptodomex==3.22.0
pydantic==2.11.3
pydantic-settings==2.9.1
pydantic_core==2.33.1
pyglm==2.8.2
pynndescent==0.5.13
PyOpenGL==3.1.9
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-louvain==0.16
scipy==1.15.2
python-multipart==0.0.20
pytz==2025.2
regex==2024.11.6
requests==2.32.3
scikit-learn==1.6.1
scipy==1.12.0
seaborn==0.13.2
setuptools==79.0.1
six==1.17.0
smart-open==7.1.0
sniffio==1.3.1
starlette==0.46.2
statsmodels==0.14.4
tenacity==9.1.2
threadpoolctl==3.6.0
tiktoken==0.9.0
tk==0.1.0
tqdm==4.67.1
typing-inspection==0.4.0
typing_extensions==4.13.2
tzdata==2025.2
umap-learn==0.5.7
urllib3==2.4.0
uvicorn==0.34.2
wrapt==1.17.2