Brought all lightrag structure to the minirag app

This commit is contained in:
Saifeddine ALOUI
2025-01-30 23:05:43 +01:00
parent d07d3a9fea
commit c88e682324
47 changed files with 10403 additions and 1339 deletions

26
.gitignore vendored Normal file

@@ -0,0 +1,26 @@
__pycache__
*.egg-info
dickens/
book.txt
lightrag-dev/
.idea/
dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
.venv/
*.ignore.*
.ruff_cache/
gui/
*.log
.vscode
inputs
rag_storage
.env
venv/
examples/input/
examples/output/
.DS_Store
#Remove config.ini from repo
*.ini

22
.pre-commit-config.yaml Normal file

@@ -0,0 +1,22 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: requirements-txt-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
hooks:
- id: ruff-format
- id: ruff
args: [--fix, --ignore=E402]
- repo: https://github.com/mgedmin/check-manifest
rev: "0.49"
hooks:
- id: check-manifest
stages: [manual]

BIN
assets/logo.png Normal file

Binary file not shown.



@@ -0,0 +1,7 @@
AZURE_OPENAI_API_VERSION=2024-08-01-preview
AZURE_OPENAI_DEPLOYMENT=gpt-4o
AZURE_OPENAI_API_KEY=myapikey
AZURE_OPENAI_ENDPOINT=https://myendpoint.openai.azure.com
AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
AZURE_EMBEDDING_API_VERSION=2023-05-15

2
minirag/api/.gitignore vendored Normal file

@@ -0,0 +1,2 @@
inputs
rag_storage

475
minirag/api/README.md Normal file

@@ -0,0 +1,475 @@
## Install with API Support
LightRAG provides optional API support through FastAPI servers that add RAG capabilities to existing LLM services. You can install LightRAG with API support in two ways:
### 1. Installation from PyPI
```bash
pip install "lightrag-hku[api]"
```
### 2. Installation from Source (Development)
```bash
# Clone the repository
git clone https://github.com/HKUDS/lightrag.git
# Change to the repository directory
cd lightrag
# create a Python virtual environment if necessary
# Install in editable mode with API support
pip install -e ".[api]"
```
### Prerequisites
Before running any of the servers, ensure the corresponding backend services are running for both the LLM and the embeddings.
The API allows you to mix different bindings for the LLM and the embeddings; for example, you can use Ollama for the embeddings and OpenAI for the LLM.
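Such a mixed setup could look like the following sketch (the model names are placeholders and must exist on the chosen backends):
```bash
# OpenAI for the LLM, Ollama for the embeddings
lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding ollama --embedding-model bge-m3:latest
```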
#### For LoLLMs Server
- LoLLMs must be running and accessible
- Default connection: http://localhost:9600
- Configure using --llm-binding-host and/or --embedding-binding-host if running on a different host/port
#### For Ollama Server
- Ollama must be running and accessible
- Requires environment variables setup or command line argument provided
- Environment variables: LLM_BINDING=ollama, LLM_BINDING_HOST, LLM_MODEL
- Command line arguments: --llm-binding=ollama, --llm-binding-host, --llm-model
- Default connection is http://localhost:11434 if not provided
> The default MAX_TOKENS (num_ctx) for Ollama is 32768. If your Ollama server is short on GPU memory, set it to a lower value.
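For example, to lower it on a memory-constrained GPU (a sketch using the `--max-tokens` option described below):
```bash
# Reduce the context size requested from Ollama
lightrag-server --max-tokens 8192
```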
#### For OpenAI-Compatible Servers
- Requires environment variables or command line arguments to be provided
- Environment variables: LLM_BINDING=openai, LLM_BINDING_HOST, LLM_MODEL, LLM_BINDING_API_KEY (see the example below)
- Command line arguments: --llm-binding=openai, --llm-binding-host, --llm-model, --llm-binding-api-key
- Default connection is https://api.openai.com/v1 if not provided
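A minimal `.env` sketch for an OpenAI-compatible endpoint (all values are placeholders):
```env
LLM_BINDING=openai
LLM_BINDING_HOST=https://api.openai.com/v1
LLM_MODEL=gpt-4o-mini
LLM_BINDING_API_KEY=your_api_key
```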
#### For Azure OpenAI Server
Azure OpenAI API can be created using the following commands in Azure CLI (you need to install Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
```bash
# Change the resource group name, location and OpenAI resource name as needed
RESOURCE_GROUP_NAME=LightRAG
LOCATION=swedencentral
RESOURCE_NAME=LightRAG-OpenAI
az login
az group create --name $RESOURCE_GROUP_NAME --location $LOCATION
az cognitiveservices account create --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --kind OpenAI --sku S0 --location swedencentral
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name gpt-4o --model-name gpt-4o --model-version "2024-08-06" --sku-capacity 100 --sku-name "Standard"
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name text-embedding-3-large --model-name text-embedding-3-large --model-version "1" --sku-capacity 80 --sku-name "Standard"
az cognitiveservices account show --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --query "properties.endpoint"
az cognitiveservices account keys list --name $RESOURCE_NAME -g $RESOURCE_GROUP_NAME
```
The output of the last command will give you the endpoint and the key for the OpenAI API. You can use these values to set the environment variables in the `.env` file.
```
LLM_BINDING=azure_openai
LLM_BINDING_HOST=endpoint_of_azure_ai
LLM_MODEL=model_name_of_azure_ai
LLM_BINDING_API_KEY=api_key_of_azure_ai
```
### About Ollama API
We provide an Ollama-compatible interface for LightRAG that emulates LightRAG as an Ollama chat model. This allows AI chat frontends supporting Ollama, such as Open WebUI, to access LightRAG easily.
#### Choose Query mode in chat
A query prefix in the query string determines which LightRAG query mode is used to generate the response. The supported prefixes include:
- /local
- /global
- /hybrid
- /naive
- /mix
For example, the chat message "/mix 唐僧有几个徒弟" ("How many disciples does Tang Seng have?") will trigger a mix mode query in LightRAG. A chat message without a query prefix will trigger a hybrid mode query by default.
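Through the Ollama-compatible chat endpoint described below, such a prefixed message could be sent as in this sketch:
```bash
curl -N -X POST http://localhost:9721/api/chat -H "Content-Type: application/json" -d \
'{"model":"lightrag:latest","messages":[{"role":"user","content":"/local What is this document about?"}],"stream":true}'
```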
#### Connect Open WebUI to LightRAG
After starting lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.
## Configuration
LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.
For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are 50 and 0.4 respectively. If COSINE_THRESHOLD were left at LightRAG's library default of 0.2, many irrelevant entities and relations would be retrieved and sent to the LLM.
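Both values can also be overridden explicitly, for example (a sketch using the options documented below):
```bash
# Retrieve more candidates with a looser similarity cut-off
lightrag-server --top-k 60 --cosine-threshold 0.3
```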
### Environment Variables
You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of available environment variables:
```env
# Server Configuration
HOST=0.0.0.0
PORT=9721
# Directory Configuration
WORKING_DIR=/app/data/rag_storage
INPUT_DIR=/app/data/inputs
# RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4
#TOP_K=50
# LLM Configuration
LLM_BINDING=ollama
LLM_BINDING_HOST=http://localhost:11434
LLM_MODEL=mistral-nemo:latest
# must be set if using an OpenAI LLM (LLM_MODEL must be set here or via command line params)
OPENAI_API_KEY=your_api_key
# Embedding Configuration
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://localhost:11434
EMBEDDING_MODEL=bge-m3:latest
# Security
#LIGHTRAG_API_KEY=your-api-key-for-accessing-LightRAG
# Logging
LOG_LEVEL=INFO
# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem
# Optional Timeout
#TIMEOUT=30
```
### Configuration Priority
The configuration values are loaded in the following order (highest priority first):
1. Command-line arguments
2. Environment variables
3. Default values
For example:
```bash
# This command-line argument will override both the environment variable and default value
python lightrag.py --port 8080
# The environment variable will override the default value but not the command-line argument
PORT=7000 python lightrag.py
```
#### LightRAG Server Options
| Parameter | Default | Description |
|-----------|---------|-------------|
| --host | 0.0.0.0 | Server host |
| --port | 9721 | Server port |
| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai |
| --llm-binding-host | (dynamic) | LLM server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
| --llm-model | mistral-nemo:latest | LLM model name |
| --llm-binding-api-key | None | API key for an OpenAI-compatible LLM |
| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai |
| --embedding-binding-host | (dynamic) | Embedding server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
| --embedding-model | bge-m3:latest | Embedding model name |
| --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum async operations |
| --max-tokens | 32768 | Maximum token size |
| --embedding-dim | 1024 | Embedding dimensions |
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --key | None | API key for authentication. Protects lightrag server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
| --cosine-threshold | 0.4 | Cosine similarity threshold for node and relation retrieval; works with top-k to control how many nodes and relations are retrieved. |
### Example Usage
#### Running a LightRAG server with the default local Ollama server as LLM and embedding backend
Ollama is the default backend for both the LLM and the embeddings, so lightrag-server can be run without parameters and the defaults will be used. Make sure Ollama is installed and running, and that the default models are already pulled.
```bash
# Run lightrag with ollama, mistral-nemo:latest for llm, and bge-m3:latest for embedding
lightrag-server
# Using specific models (ensure they are installed in your ollama instance)
lightrag-server --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-model nomic-embed-text --embedding-dim 1024
# Using an authentication key
lightrag-server --key my-key
# Using lollms for llm and ollama for embedding
lightrag-server --llm-binding lollms
```
#### Running a LightRAG server with the default local LoLLMs server as LLM and embedding backend
```bash
# Run lightrag with lollms for both llm and embedding, mistral-nemo:latest for llm and bge-m3:latest for embedding
lightrag-server --llm-binding lollms --embedding-binding lollms
# Using specific models (ensure they are installed in your lollms instance)
lightrag-server --llm-binding lollms --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-binding lollms --embedding-model nomic-embed-text --embedding-dim 1024
# Using an authentication key
lightrag-server --key my-key
# Using lollms for llm and openai for embedding
lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
```
#### Running a LightRAG server with OpenAI as LLM and embedding backend
```bash
# Run lightrag with openai for both llm and embedding, gpt-4o-mini for llm and text-embedding-3-small for embedding
lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small
# Using an authentication key
lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small --key my-key
# Using lollms for llm and openai for embedding
lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
```
#### Running a LightRAG server with Azure OpenAI as LLM and embedding backend
```bash
# Run lightrag with azure_openai for llm and openai for embedding, gpt-4o-mini for llm and text-embedding-3-small for embedding
lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small
# Using an authentication key, with azure_openai for both llm and embedding
lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small --key my-key
# Using lollms for llm and azure_openai for embedding
lightrag-server --llm-binding lollms --embedding-binding azure_openai --embedding-model text-embedding-3-small
```
**Important Notes:**
- For LoLLMs: Make sure the specified models are installed in your LoLLMs instance
- For Ollama: Make sure the specified models are installed in your Ollama instance
- For OpenAI: Ensure you have set up your OPENAI_API_KEY environment variable
- For Azure OpenAI: Build and configure your server as stated in the Prerequisites section
For help on any server, use the --help flag:
```bash
lightrag-server --help
```
Note: If you don't need the API functionality, you can install the base package without API support using:
```bash
pip install lightrag-hku
```
## API Endpoints
All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality.
### Query Endpoints
#### POST /query
Query the RAG system with options for different search modes.
```bash
curl -X POST "http://localhost:9721/query" \
-H "Content-Type: application/json" \
-d '{"query": "Your question here", "mode": "hybrid", ""}'
```
#### POST /query/stream
Stream responses from the RAG system.
```bash
curl -X POST "http://localhost:9721/query/stream" \
-H "Content-Type: application/json" \
-d '{"query": "Your question here", "mode": "hybrid"}'
```
### Document Management Endpoints
#### POST /documents/text
Insert text directly into the RAG system.
```bash
curl -X POST "http://localhost:9721/documents/text" \
-H "Content-Type: application/json" \
-d '{"text": "Your text content here", "description": "Optional description"}'
```
#### POST /documents/file
Upload a single file to the RAG system.
```bash
curl -X POST "http://localhost:9721/documents/file" \
-F "file=@/path/to/your/document.txt" \
-F "description=Optional description"
```
#### POST /documents/batch
Upload multiple files at once.
```bash
curl -X POST "http://localhost:9721/documents/batch" \
-F "files=@/path/to/doc1.txt" \
-F "files=@/path/to/doc2.txt"
```
#### POST /documents/scan
Trigger document scan for new files in the Input directory.
```bash
curl -X POST "http://localhost:9721/documents/scan" --max-time 1800
```
> Adjust max-time according to the estimated indexing time for all new files.
### Ollama Emulation Endpoints
#### GET /api/version
Get Ollama version information
```bash
curl http://localhost:9721/api/version
```
#### GET /api/tags
Get Ollama available models
```bash
curl http://localhost:9721/api/tags
```
#### POST /api/chat
Handle chat completion requests
```shell
curl -N -X POST http://localhost:9721/api/chat -H "Content-Type: application/json" -d \
'{"model":"lightrag:latest","messages":[{"role":"user","content":"猪八戒是谁"}],"stream":true}'
```
> For more information about the Ollama API, please visit: [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)
#### DELETE /documents
Clear all documents from the RAG system.
```bash
curl -X DELETE "http://localhost:9721/documents"
```
### Utility Endpoints
#### GET /health
Check server health and configuration.
```bash
curl "http://localhost:9721/health"
```
## Development
Contribute to the project: [Guide](contributor-readme.MD)
### Running in Development Mode
For LoLLMs:
```bash
uvicorn lollms_lightrag_server:app --reload --port 9721
```
For Ollama:
```bash
uvicorn ollama_lightrag_server:app --reload --port 9721
```
For OpenAI:
```bash
uvicorn openai_lightrag_server:app --reload --port 9721
```
For Azure OpenAI:
```bash
uvicorn azure_openai_lightrag_server:app --reload --port 9721
```
### API Documentation
When any server is running, visit:
- Swagger UI: http://localhost:9721/docs
- ReDoc: http://localhost:9721/redoc
### Testing API Endpoints
You can test the API endpoints using the provided curl commands or through the Swagger UI interface. Make sure to:
1. Start the appropriate backend service (LoLLMs, Ollama, or OpenAI)
2. Start the RAG server
3. Upload some documents using the document management endpoints
4. Query the system using the query endpoints
5. Trigger a document scan if new files are added to the inputs directory
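A minimal smoke test, assuming the default port and no API key, could look like this sketch:
```bash
# Insert a small text snippet
curl -X POST "http://localhost:9721/documents/text" \
     -H "Content-Type: application/json" \
     -d '{"text": "LightRAG is a retrieval-augmented generation library.", "description": "smoke test"}'
# Query it back
curl -X POST "http://localhost:9721/query" \
     -H "Content-Type: application/json" \
     -d '{"query": "What is LightRAG?", "mode": "hybrid"}'
```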
### Important Features
#### Automatic Document Vectorization
When starting any of the servers with the `--input-dir` parameter, the system will automatically:
1. Check for existing vectorized content in the database
2. Only vectorize new documents that aren't already in the database
3. Make all content immediately available for RAG queries
This intelligent caching mechanism:
- Prevents unnecessary re-vectorization of existing documents
- Reduces startup time for subsequent runs
- Preserves system resources
- Maintains consistency across restarts
**Important Notes:**
- The `--input-dir` parameter enables automatic document processing at startup
- Documents already in the database are not re-vectorized
- Only new documents in the input directory will be processed
- This optimization significantly reduces startup time for subsequent runs
- The working directory (`--working-dir`) stores the vectorized documents database
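For example (a sketch; the directory paths are placeholders):
```bash
# The first run vectorizes everything under ./inputs; later runs only process new files
lightrag-server --input-dir ./inputs --working-dir ./rag_storage
```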
## Install LightRAG as a Linux Service
Create your service file `lightrag-server.service` by modifying the following lines from `lightrag-server.service.example`:
```text
Description=LightRAG Ollama Service
WorkingDirectory=<lightrag installed directory>
ExecStart=<lightrag installed directory>/lightrag/api/start_lightrag.sh
```
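For reference, a complete unit file might look like the following sketch (the user and paths are assumptions taken from the startup script below and must be adapted to your installation):
```text
[Unit]
Description=LightRAG Ollama Service
After=network.target

[Service]
Type=simple
User=netman
WorkingDirectory=/home/netman/lightrag-xyj
ExecStart=/home/netman/lightrag-xyj/lightrag/api/start_lightrag.sh
Restart=on-failure

[Install]
WantedBy=multi-user.target
```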
Create your service startup script `start_lightrag.sh`, and change the Python virtual environment activation method as needed:
```shell
#!/bin/bash
# python virtual environment activation
source /home/netman/lightrag-xyj/venv/bin/activate
# start lightrag api server
lightrag-server
```
Install lightrag-server.service on Linux. Sample commands on an Ubuntu server look like this:
```shell
sudo cp lightrag-server.service /etc/systemd/system/
sudo systemctl daemon-reload
sudo systemctl start lightrag-server.service
sudo systemctl status lightrag-server.service
sudo systemctl enable lightrag-server.service
```

1
minirag/api/__init__.py Normal file

@@ -0,0 +1 @@
__api_version__ = "1.0.3"

File diff suppressed because it is too large.


@@ -0,0 +1,12 @@
ascii_colors
fastapi
nest_asyncio
numpy
pipmaster
python-dotenv
python-multipart
tenacity
tiktoken
torch
tqdm
uvicorn


@@ -0,0 +1,2 @@
# LightRAG WebUI
A simple web UI to interact with the LightRAG datalake

Binary file not shown.



@@ -0,0 +1,104 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MiniRag Interface</title>
<script src="https://cdn.tailwindcss.com"></script>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<style>
.fade-in {
animation: fadeIn 0.3s ease-in;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
.spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
.slide-in {
animation: slideIn 0.3s ease-out;
}
@keyframes slideIn {
from { transform: translateX(-100%); }
to { transform: translateX(0); }
}
</style>
</head>
<body class="bg-gray-50">
<div class="flex h-screen">
<!-- Sidebar -->
<div class="w-64 bg-white shadow-lg">
<div class="p-4">
<h1 class="text-xl font-bold text-gray-800 mb-6">MiniRag</h1>
<nav class="space-y-2">
<a href="#" class="nav-item" data-page="file-manager">
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z"/>
</svg>
File Manager
</div>
</a>
<a href="#" class="nav-item" data-page="query">
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"/>
</svg>
Query Database
</div>
</a>
<a href="#" class="nav-item" data-page="knowledge-graph">
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"/>
</svg>
Knowledge Graph
</div>
</a>
<a href="#" class="nav-item" data-page="status">
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"/>
</svg>
Status
</div>
</a>
<a href="#" class="nav-item" data-page="settings">
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"/>
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"/>
</svg>
Settings
</div>
</a>
</nav>
</div>
</div>
<!-- Main Content -->
<div class="flex-1 overflow-auto p-6">
<div id="content" class="fade-in"></div>
</div>
<!-- Toast Notification -->
<div id="toast" class="fixed bottom-4 right-4 hidden">
<div class="bg-gray-800 text-white px-6 py-3 rounded-lg shadow-lg"></div>
</div>
</div>
<script src="/js/api.js"></script>
</body>
</html>


@@ -0,0 +1,404 @@
// State management
const state = {
apiKey: localStorage.getItem('apiKey') || '',
files: [],
indexedFiles: [],
currentPage: 'file-manager'
};
// Utility functions
const showToast = (message, duration = 3000) => {
const toast = document.getElementById('toast');
toast.querySelector('div').textContent = message;
toast.classList.remove('hidden');
setTimeout(() => toast.classList.add('hidden'), duration);
};
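// Attach the stored API key as a Bearer token on every request, when one is configured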
const fetchWithAuth = async (url, options = {}) => {
const headers = {
...(options.headers || {}),
...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {})
};
return fetch(url, { ...options, headers });
};
// Page renderers
const pages = {
'file-manager': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">File Manager</h2>
<div class="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center hover:border-gray-400 transition-colors">
<input type="file" id="fileInput" multiple accept=".txt,.md,.doc,.docx,.pdf,.pptx" class="hidden">
<label for="fileInput" class="cursor-pointer">
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"/>
</svg>
<p class="mt-2 text-gray-600">Drag files here or click to select</p>
<p class="text-sm text-gray-500">Supported formats: TXT, MD, DOC, PDF, PPTX</p>
</label>
</div>
<div id="fileList" class="space-y-2">
<h3 class="text-lg font-semibold text-gray-700">Selected Files</h3>
<div class="space-y-2"></div>
</div>
<div id="uploadProgress" class="hidden mt-4">
<div class="w-full bg-gray-200 rounded-full h-2.5">
<div class="bg-blue-600 h-2.5 rounded-full" style="width: 0%"></div>
</div>
<p class="text-sm text-gray-600 mt-2"><span id="uploadStatus">0</span> files processed</p>
</div>
<button id="rescanBtn" class="flex items-center bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="20" height="20" fill="currentColor" class="mr-2">
<path d="M12 4a8 8 0 1 1-8 8H2.5a9.5 9.5 0 1 0 2.8-6.7L2 3v6h6L5.7 6.7A7.96 7.96 0 0 1 12 4z"/>
</svg>
Rescan Files
</button>
<button id="uploadBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Upload & Index Files
</button>
<div id="indexedFiles" class="space-y-2">
<h3 class="text-lg font-semibold text-gray-700">Indexed Files</h3>
<div class="space-y-2"></div>
</div>
</div>
`,
'query': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">Query Database</h2>
<div class="space-y-4">
<div>
<label class="block text-sm font-medium text-gray-700">Query Mode</label>
<select id="queryMode" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
<option value="light">Light</option>
<option value="naive">Naive</option>
<option value="mini">Mini</option>
</select>
</div>
<div>
<label class="block text-sm font-medium text-gray-700">Query</label>
<textarea id="queryInput" rows="4" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500"></textarea>
</div>
<button id="queryBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Send Query
</button>
<div id="queryResult" class="mt-4 p-4 bg-white rounded-lg shadow"></div>
</div>
</div>
`,
'knowledge-graph': () => `
<div class="flex items-center justify-center h-full">
<div class="text-center">
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10"/>
</svg>
<h3 class="mt-2 text-sm font-medium text-gray-900">Under Construction</h3>
<p class="mt-1 text-sm text-gray-500">Knowledge graph visualization will be available in a future update.</p>
</div>
</div>
`,
'status': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">System Status</h2>
<div id="statusContent" class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div class="p-6 bg-white rounded-lg shadow-sm">
<h3 class="text-lg font-semibold mb-4">System Health</h3>
<div id="healthStatus"></div>
</div>
<div class="p-6 bg-white rounded-lg shadow-sm">
<h3 class="text-lg font-semibold mb-4">Configuration</h3>
<div id="configStatus"></div>
</div>
</div>
</div>
`,
'settings': () => `
<div class="space-y-6">
<h2 class="text-2xl font-bold text-gray-800">Settings</h2>
<div class="max-w-xl">
<div class="space-y-4">
<div>
<label class="block text-sm font-medium text-gray-700">API Key</label>
<input type="password" id="apiKeyInput" value="${state.apiKey}"
class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
</div>
<button id="saveSettings" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
Save Settings
</button>
</div>
</div>
</div>
`
};
// Page handlers
const handlers = {
'file-manager': () => {
const fileInput = document.getElementById('fileInput');
const dropZone = fileInput.parentElement.parentElement;
const fileList = document.querySelector('#fileList div');
const indexedFiles = document.querySelector('#indexedFiles div');
const uploadBtn = document.getElementById('uploadBtn');
const rescanBtn = document.getElementById('rescanBtn');
const updateFileList = () => {
fileList.innerHTML = state.files.map(file => `
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
<span>${file.name}</span>
<button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
</svg>
</button>
</div>
`).join('');
};
const updateIndexedFiles = async () => {
const response = await fetchWithAuth('/health');
const data = await response.json();
indexedFiles.innerHTML = data.indexed_files.map(file => `
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
<span>${file}</span>
</div>
`).join('');
};
dropZone.addEventListener('dragover', (e) => {
e.preventDefault();
dropZone.classList.add('border-blue-500');
});
dropZone.addEventListener('dragleave', () => {
dropZone.classList.remove('border-blue-500');
});
dropZone.addEventListener('drop', (e) => {
e.preventDefault();
dropZone.classList.remove('border-blue-500');
const files = Array.from(e.dataTransfer.files);
state.files.push(...files);
updateFileList();
});
fileInput.addEventListener('change', () => {
state.files.push(...Array.from(fileInput.files));
updateFileList();
});
uploadBtn.addEventListener('click', async () => {
if (state.files.length === 0) {
showToast('Please select files to upload');
return;
}
let apiKey = localStorage.getItem('apiKey') || '';
const progress = document.getElementById('uploadProgress');
const progressBar = progress.querySelector('div');
const statusText = document.getElementById('uploadStatus');
progress.classList.remove('hidden');
for (let i = 0; i < state.files.length; i++) {
const formData = new FormData();
formData.append('file', state.files[i]);
try {
await fetch('/documents/upload', {
method: 'POST',
headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {},
body: formData
});
const percentage = ((i + 1) / state.files.length) * 100;
progressBar.style.width = `${percentage}%`;
statusText.textContent = `${i + 1}/${state.files.length}`;
} catch (error) {
console.error('Upload error:', error);
}
}
progress.classList.add('hidden');
});
rescanBtn.addEventListener('click', async () => {
const progress = document.getElementById('uploadProgress');
const progressBar = progress.querySelector('div');
const statusText = document.getElementById('uploadStatus');
progress.classList.remove('hidden');
try {
// Start the scanning process
const scanResponse = await fetch('/documents/scan', {
method: 'POST',
});
if (!scanResponse.ok) {
throw new Error('Scan failed to start');
}
// Start polling for progress
const pollInterval = setInterval(async () => {
const progressResponse = await fetch('/documents/scan-progress');
const progressData = await progressResponse.json();
// Update progress bar
progressBar.style.width = `${progressData.progress}%`;
// Update status text
if (progressData.total_files > 0) {
statusText.textContent = `Processing ${progressData.current_file} (${progressData.indexed_count}/${progressData.total_files})`;
}
// Check if scanning is complete
if (!progressData.is_scanning) {
clearInterval(pollInterval);
progress.classList.add('hidden');
statusText.textContent = 'Scan complete!';
}
}, 1000); // Poll every second
} catch (error) {
console.error('Upload error:', error);
progress.classList.add('hidden');
statusText.textContent = 'Error during scanning process';
}
});
updateIndexedFiles();
},
'query': () => {
const queryBtn = document.getElementById('queryBtn');
const queryInput = document.getElementById('queryInput');
const queryMode = document.getElementById('queryMode');
const queryResult = document.getElementById('queryResult');
let apiKey = localStorage.getItem('apiKey') || '';
queryBtn.addEventListener('click', async () => {
const query = queryInput.value.trim();
if (!query) {
showToast('Please enter a query');
return;
}
queryBtn.disabled = true;
queryBtn.innerHTML = `
<svg class="animate-spin h-5 w-5 mr-3" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4" fill="none"/>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"/>
</svg>
Processing...
`;
try {
const response = await fetchWithAuth('/query', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
query,
mode: queryMode.value,
stream: false,
only_need_context: false
})
});
const data = await response.json();
queryResult.innerHTML = marked.parse(data.response);
} catch (error) {
showToast('Error processing query');
} finally {
queryBtn.disabled = false;
queryBtn.textContent = 'Send Query';
}
});
},
'status': async () => {
const healthStatus = document.getElementById('healthStatus');
const configStatus = document.getElementById('configStatus');
try {
const response = await fetchWithAuth('/health');
const data = await response.json();
healthStatus.innerHTML = `
<div class="space-y-2">
<div class="flex items-center">
<div class="w-3 h-3 rounded-full ${data.status === 'healthy' ? 'bg-green-500' : 'bg-red-500'} mr-2"></div>
<span class="font-medium">${data.status}</span>
</div>
<div>
<p class="text-sm text-gray-600">Working Directory: ${data.working_directory}</p>
<p class="text-sm text-gray-600">Input Directory: ${data.input_directory}</p>
<p class="text-sm text-gray-600">Indexed Files: ${data.indexed_files_count}</p>
</div>
</div>
`;
configStatus.innerHTML = Object.entries(data.configuration)
.map(([key, value]) => `
<div class="mb-2">
<span class="text-sm font-medium text-gray-700">${key}:</span>
<span class="text-sm text-gray-600 ml-2">${value}</span>
</div>
`).join('');
} catch (error) {
showToast('Error fetching status');
}
},
'settings': () => {
const saveBtn = document.getElementById('saveSettings');
const apiKeyInput = document.getElementById('apiKeyInput');
saveBtn.addEventListener('click', () => {
state.apiKey = apiKeyInput.value;
localStorage.setItem('apiKey', state.apiKey);
showToast('Settings saved successfully');
});
}
};
// Navigation handling
document.querySelectorAll('.nav-item').forEach(item => {
item.addEventListener('click', (e) => {
e.preventDefault();
const page = item.dataset.page;
document.getElementById('content').innerHTML = pages[page]();
if (handlers[page]) handlers[page]();
state.currentPage = page;
});
});
// Initialize with file manager
document.getElementById('content').innerHTML = pages['file-manager']();
handlers['file-manager']();
// Global functions
window.removeFile = (fileName) => {
state.files = state.files.filter(file => file.name !== fileName);
document.querySelector('#fileList div').innerHTML = state.files.map(file => `
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
<span>${file.name}</span>
<button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
</svg>
</button>
</div>
`).join('');
};


@@ -0,0 +1,211 @@
// js/graph.js
function openGraphModal(label) {
const modal = document.getElementById("graph-modal");
const graphTitle = document.getElementById("graph-title");
if (!modal || !graphTitle) {
console.error("Key element not found");
return;
}
graphTitle.textContent = `Knowledge Graph - ${label}`;
modal.style.display = "flex";
renderGraph(label);
}
function closeGraphModal() {
const modal = document.getElementById("graph-modal");
modal.style.display = "none";
clearGraph();
}
function clearGraph() {
const svg = document.getElementById("graph-svg");
svg.innerHTML = "";
}
async function getGraph(label) {
try {
const response = await fetch(`/graphs?label=${label}`);
const rawData = await response.json();
console.log({data: JSON.parse(JSON.stringify(rawData))});
const nodes = rawData.nodes
nodes.forEach(node => {
node.id = Date.now().toString(36) + Math.random().toString(36).substring(2); // generate a unique id for each node (crypto.randomUUID() could also be used)
});
// Strictly verify edge data
const edges = (rawData.edges || []).map(edge => {
const sourceNode = nodes.find(n => n.labels.includes(edge.source));
const targetNode = nodes.find(n => n.labels.includes(edge.target));
if (!sourceNode || !targetNode) {
console.warn("NOT VALID EDGE:", edge);
return null;
}
return {
source: sourceNode,
target: targetNode,
type: edge.type || ""
};
}).filter(edge => edge !== null);
return {nodes, edges};
} catch (error) {
console.error("Loading graph failed:", error);
return {nodes: [], edges: []};
}
}
async function renderGraph(label) {
const data = await getGraph(label);
if (!data.nodes || data.nodes.length === 0) {
d3.select("#graph-svg")
.html(`<text x="50%" y="50%" text-anchor="middle">No valid nodes</text>`);
return;
}
const svg = d3.select("#graph-svg");
const width = svg.node().clientWidth;
const height = svg.node().clientHeight;
svg.selectAll("*").remove();
// Create a force oriented diagram layout
const simulation = d3.forceSimulation(data.nodes)
.force("charge", d3.forceManyBody().strength(-300))
.force("center", d3.forceCenter(width / 2, height / 2));
// Add a connection (if there are valid edges)
if (data.edges.length > 0) {
simulation.force("link",
d3.forceLink(data.edges)
.id(d => d.id)
.distance(100)
);
}
// Draw nodes
const nodes = svg.selectAll(".node")
.data(data.nodes)
.enter()
.append("circle")
.attr("class", "node")
.attr("r", 10)
.call(d3.drag()
.on("start", dragStarted)
.on("drag", dragged)
.on("end", dragEnded)
);
svg.append("defs")
.append("marker")
.attr("id", "arrow-out")
.attr("viewBox", "0 0 10 10")
.attr("refX", 8)
.attr("refY", 5)
.attr("markerWidth", 6)
.attr("markerHeight", 6)
.attr("orient", "auto")
.append("path")
.attr("d", "M0,0 L10,5 L0,10 Z")
.attr("fill", "#999");
// Draw edges (with arrows)
const links = svg.selectAll(".link")
.data(data.edges)
.enter()
.append("line")
.attr("class", "link")
.attr("marker-end", "url(#arrow-out)"); // Always draw arrows on the target side
// Edge style configuration
links
.attr("stroke", "#999")
.attr("stroke-width", 2)
.attr("stroke-opacity", 0.8);
// Draw label (with background box)
const labels = svg.selectAll(".label")
.data(data.nodes)
.enter()
.append("text")
.attr("class", "label")
.text(d => d.labels[0] || "")
.attr("text-anchor", "start")
.attr("dy", "0.3em")
.attr("fill", "#333");
// Update Location
simulation.on("tick", () => {
links
.attr("x1", d => {
// Calculate the direction vector from the source node to the target node
const dx = d.target.x - d.source.x;
const dy = d.target.y - d.source.y;
const distance = Math.sqrt(dx * dx + dy * dy);
if (distance === 0) return d.source.x; // Avoid dividing by zero
// Adjust the starting point coordinates (source node edge) based on radius 10
return d.source.x + (dx / distance) * 10;
})
.attr("y1", d => {
const dx = d.target.x - d.source.x;
const dy = d.target.y - d.source.y;
const distance = Math.sqrt(dx * dx + dy * dy);
if (distance === 0) return d.source.y;
return d.source.y + (dy / distance) * 10;
})
.attr("x2", d => {
// Adjust the endpoint coordinates (target node edge) based on a radius of 10
const dx = d.target.x - d.source.x;
const dy = d.target.y - d.source.y;
const distance = Math.sqrt(dx * dx + dy * dy);
if (distance === 0) return d.target.x;
return d.target.x - (dx / distance) * 10;
})
.attr("y2", d => {
const dx = d.target.x - d.source.x;
const dy = d.target.y - d.source.y;
const distance = Math.sqrt(dx * dx + dy * dy);
if (distance === 0) return d.target.y;
return d.target.y - (dy / distance) * 10;
});
// Update the position of nodes and labels (keep unchanged)
nodes
.attr("cx", d => d.x)
.attr("cy", d => d.y);
labels
.attr("x", d => d.x + 12)
.attr("y", d => d.y + 4);
});
// Drag and drop logic
function dragStarted(event, d) {
if (!event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
}
function dragged(event, d) {
d.fx = event.x;
d.fy = event.y;
simulation.alpha(0.3).restart();
}
function dragEnded(event, d) {
if (!event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
}
}


@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar
import os
import numpy as np
from .utils import EmbeddingFunc
@@ -17,16 +17,28 @@ T = TypeVar("T")
class QueryParam:
mode: Literal["light", "naive","mini"] = "mini"
only_need_context: bool = False
only_need_prompt: bool = False
response_type: str = "Multiple Paragraphs"
stream: bool = False
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 5
top_k: int = int(os.getenv("TOP_K", "60"))
# Number of document chunks to retrieve.
# top_n: int = 10
# Number of tokens for the original chunks.
max_token_for_text_unit: int = 2000
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
max_token_for_global_context: int = 2000
max_token_for_global_context: int = 4000
# Number of tokens for the entity descriptions
max_token_for_local_context: int = 2000  # For Light/Graph
max_token_for_node_context: int = 500  # For Mini; if too long, the SLM may fail to generate any response
max_token_for_local_context: int = 4000
hl_keywords: list[str] = field(default_factory=list)
ll_keywords: list[str] = field(default_factory=list)
# Conversation history support
conversation_history: list[dict] = field(
default_factory=list
) # Format: [{"role": "user/assistant", "content": "message"}]
history_turns: int = (
3 # Number of complete conversation turns (user-assistant pairs) to consider
)
@dataclass
class StorageNameSpace:

621
minirag/kg/age_impl.py Normal file

@@ -0,0 +1,621 @@
import asyncio
import inspect
import json
import os
import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import psycopg
from psycopg.rows import namedtuple_row
from psycopg_pool import AsyncConnectionPool, PoolTimeout
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from lightrag.utils import logger
from ..base import BaseGraphStorage
if sys.platform.startswith("win"):
import asyncio.windows_events
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class AGEQueryException(Exception):
"""Exception for the AGE queries."""
def __init__(self, exception: Union[str, Dict]) -> None:
if isinstance(exception, dict):
self.message = exception["message"] if "message" in exception else "unknown"
self.details = exception["details"] if "details" in exception else "unknown"
else:
self.message = exception
self.details = "unknown"
def get_message(self) -> str:
return self.message
def get_details(self) -> Any:
return self.details
@dataclass
class AGEStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
DB = os.environ["AGE_POSTGRES_DB"].replace("\\", "\\\\").replace("'", "\\'")
USER = os.environ["AGE_POSTGRES_USER"].replace("\\", "\\\\").replace("'", "\\'")
PASSWORD = (
os.environ["AGE_POSTGRES_PASSWORD"]
.replace("\\", "\\\\")
.replace("'", "\\'")
)
HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
PORT = int(os.environ["AGE_POSTGRES_PORT"])
self.graph_name = os.environ["AGE_GRAPH_NAME"]
connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
self._driver = AsyncConnectionPool(connection_string, open=False)
return None
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self):
if self._driver:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
await self._driver.close()
async def index_done_callback(self):
print("KG successfully indexed.")
@staticmethod
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
"""
Convert a record returned from an age query to a dictionary
Args:
record (): a record from an age query result
Returns:
Dict[str, Any]: a dictionary representation of the record where
the dictionary key is the field name and the value is the
value converted to a python type
"""
# result holder
d = {}
# prebuild a mapping of vertex_id to vertex mappings to be used
# later to build edges
vertices = {}
for k in record._fields:
v = getattr(record, k)
# agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
# iterate returned fields and parse appropriately
for k in record._fields:
v = getattr(record, k)
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
else:
dtype = ""
if dtype == "vertex":
vertex = json.loads(v)
field = json.loads(v).get("properties")
if not field:
field = {}
field["label"] = AGEStorage._decode_graph_label(vertex["label"])
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
else:
d[k] = json.loads(v) if isinstance(v, str) else v
return d
@staticmethod
def _format_properties(
properties: Dict[str, Any], _id: Union[str, None] = None
) -> str:
"""
Convert a dictionary of properties to a string representation that
can be used in a cypher query insert/merge statement.
Args:
properties (Dict[str,str]): a dictionary containing node/edge properties
id (Union[str, None]): the id of the node or None if none exists
Returns:
str: the properties dictionary as a properly formatted string
"""
props = []
# wrap property key in backticks to escape
for k, v in properties.items():
prop = f"`{k}`: {json.dumps(v)}"
props.append(prop)
if _id is not None and "id" not in properties:
props.append(
f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
)
return "{" + ", ".join(props) + "}"
@staticmethod
def _encode_graph_label(label: str) -> str:
"""
Since AGE supports only alphanumeric labels, we encode a generic label as a HEX string
Args:
label (str): the original label
Returns:
str: the encoded label
"""
return "x" + label.encode().hex()
@staticmethod
def _decode_graph_label(encoded_label: str) -> str:
"""
Since AGE supports only alphanumeric labels, we encode a generic label as a HEX string
Args:
encoded_label (str): the encoded label
Returns:
str: the decoded label
"""
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
@staticmethod
def _get_col_name(field: str, idx: int) -> str:
"""
Convert a cypher return field to a pgsql select field
If possible keep the cypher column name, but create a generic name if necessary
Args:
field (str): a return field from a cypher query to be formatted for pgsql
idx (int): the position of the field in the return statement
Returns:
str: the field to be used in the pgsql select statement
"""
# remove white space
field = field.strip()
# if an alias is provided for the field, use it
if " as " in field:
return field.split(" as ")[-1].strip()
# if the return value is an unnamed primitive, give it a generic name
if field.isnumeric() or field in ("true", "false", "null"):
return f"column_{idx}"
# otherwise return the value stripping out some common special chars
return field.replace("(", "_").replace(")", "")
@staticmethod
def _wrap_query(query: str, graph_name: str, **params: str) -> str:
"""
Convert a cypher query to an Apache Age compatible
sql query by wrapping the cypher query in ag_catalog.cypher,
casting results to agtype and building a select statement
Args:
query (str): a valid cypher query
graph_name (str): the name of the graph to query
params (dict): parameters for the query
Returns:
str: an equivalent pgsql query
"""
# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields});"""
# if there are any returned fields they must be added to the pgsql query
if "return" in query.lower():
# parse return statement to identify returned fields
fields = (
query.lower()
.split("return")[-1]
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)
# raise exception if RETURN * is found as we can't resolve the fields
if "*" in [x.strip() for x in fields]:
raise ValueError(
"AGE graph does not support 'RETURN *'"
+ " statements in Cypher queries"
)
# get pgsql formatted field names
fields = [
AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
]
# build resulting pgsql relation
fields_str = ", ".join(
[field.split(".")[-1] + " agtype" for field in fields]
)
# if no return statement we still need to return a single field of type agtype
else:
fields_str = "a agtype"
select_str = "*"
return template.format(
graph_name=graph_name,
query=query.format(**params),
fields=fields_str,
projection=select_str,
)
async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]:
"""
Query the graph by taking a cypher query, converting it to an
age compatible query, executing it and converting the result
Args:
query (str): a cypher query to be executed
params (dict): parameters for the query
Returns:
List[Dict[str, Any]]: a list of dictionaries containing the result set
"""
# convert cypher query to pgsql/age query
wrapped_query = self._wrap_query(query, self.graph_name, **params)
await self._driver.open()
# create graph if it doesn't exist
async with self._get_pool_connection() as conn:
async with conn.cursor() as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute(f"SELECT create_graph('{self.graph_name}')")
await conn.commit()
except (
psycopg.errors.InvalidSchemaName,
psycopg.errors.UniqueViolation,
):
await conn.rollback()
# execute the query, rolling back on an error
async with self._get_pool_connection() as conn:
async with conn.cursor(row_factory=namedtuple_row) as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute(wrapped_query)
await conn.commit()
except psycopg.Error as e:
await conn.rollback()
raise AGEQueryException(
{
"message": f"Error executing graph query: {query.format(**params)}",
"detail": str(e),
}
) from e
data = await curs.fetchall()
if data is None:
result = []
# decode records
else:
result = [AGEStorage._record_to_dict(d) for d in data]
return result
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
single_result = (await self._query(query, **params))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
single_result["node_exists"],
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
query = """
MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
RETURN COUNT(r) > 0 AS edge_exists
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
single_result = (await self._query(query, **params))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
single_result["edge_exists"],
)
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`) RETURN n
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
record = await self._query(query, **params)
if record:
node = record[0]
node_dict = node["n"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
node_dict,
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')
query = """
MATCH (n:`{label}`)-[]->(x)
RETURN count(x) AS total_edge_count
"""
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
record = (await self._query(query, **params))[0]
if record:
edge_count = int(record["total_edge_count"])
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
edge_count,
)
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
"{%s}:query:src_Degree+trg_degree:result:{%s}",
inspect.currentframe().f_code.co_name,
degrees,
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find all edges between nodes of two given labels
Args:
source_node_label (str): Label of the source nodes
target_node_label (str): Label of the target nodes
Returns:
list: List of all relationships/edges found
"""
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
query = """
MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
RETURN properties(r) as edge_properties
LIMIT 1
"""
params = {
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
}
record = await self._query(query, **params)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
result,
)
return result
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
"""
node_label = source_node_id.strip('"')
query = """
MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
"""
params = {"label": AGEStorage._encode_graph_label(node_label)}
results = await self._query(query, **params)
edges = []
for record in results:
source_node = record["n"] if record["n"] else None
connected_node = record["connected"] if record["connected"] else None
source_label = (
source_node["label"] if source_node and source_node["label"] else None
)
target_label = (
connected_node["label"]
if connected_node and connected_node["label"]
else None
)
if source_label and target_label:
edges.append((source_label, target_label))
return edges
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the AGE database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
properties = node_data
query = """
MERGE (n:`{label}`)
SET n += {properties}
"""
params = {
"label": AGEStorage._encode_graph_label(label),
"properties": AGEStorage._format_properties(properties),
}
try:
await self._query(query, **params)
logger.debug(
"Upserted node with label '{%s}' and properties: {%s}",
label,
properties,
)
except Exception as e:
logger.error("Error during upsert: {%s}", e)
raise
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
"""
Upsert an edge and its properties between two nodes identified by their labels.
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
edge_properties = edge_data
query = """
MATCH (source:`{src_label}`)
WITH source
MATCH (target:`{tgt_label}`)
MERGE (source)-[r:DIRECTED]->(target)
SET r += {properties}
RETURN r
"""
params = {
"src_label": AGEStorage._encode_graph_label(source_node_label),
"tgt_label": AGEStorage._encode_graph_label(target_node_label),
"properties": AGEStorage._format_properties(edge_properties),
}
try:
await self._query(query, **params)
logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
source_node_label,
target_node_label,
edge_properties,
)
except Exception as e:
logger.error("Error during edge upsert: {%s}", e)
raise
async def _node2vec_embed(self):
print("Implemented but never called.")
@asynccontextmanager
async def _get_pool_connection(self, timeout: Optional[float] = None):
"""Workaround for a psycopg_pool bug"""
try:
connection = await self._driver.getconn(timeout=timeout)
except PoolTimeout:
await self._driver._add_connection(None) # workaround...
connection = await self._driver.getconn(timeout=timeout)
try:
async with connection:
yield connection
finally:
await self._driver.putconn(connection)

173
minirag/kg/chroma_impl.py Normal file
View File

@@ -0,0 +1,173 @@
import os
import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger
@dataclass
class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
try:
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
user_collection_settings = config.get("collection_settings", {})
# Default HNSW index settings for ChromaDB
default_collection_settings = {
# Distance metric used for similarity search (cosine similarity)
"hnsw:space": "cosine",
# Number of nearest neighbors to explore during index construction
# Higher values = better recall but slower indexing
"hnsw:construction_ef": 128,
# Number of nearest neighbors to explore during search
# Higher values = better recall but slower search
"hnsw:search_ef": 128,
# Number of connections per node in the HNSW graph
# Higher values = better recall but more memory usage
"hnsw:M": 16,
# Number of vectors to process in one batch during indexing
"hnsw:batch_size": 100,
# Number of updates before forcing index synchronization
# Lower values = more frequent syncs but slower indexing
"hnsw:sync_threshold": 1000,
}
collection_settings = {
**default_collection_settings,
**user_collection_settings,
}
auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}
if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")
self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)
self._collection = self._client.get_or_create_collection(
name=self.namespace,
metadata={
**collection_settings,
"dimension": self.embedding_func.embedding_dim,
},
)
# Use batch size from collection settings if specified
self._max_batch_size = self.global_config.get(
"embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
)
except Exception as e:
logger.error(f"ChromaDB initialization failed: {str(e)}")
raise
async def upsert(self, data: dict[str, dict]):
if not data:
logger.warning("Empty data provided to vector DB")
return []
try:
ids = list(data.keys())
documents = [v["content"] for v in data.values()]
metadatas = [
{k: v for k, v in item.items() if k in self.meta_fields}
or {"_default": "true"}
for item in data.values()
]
# Process in batches
batches = [
documents[i : i + self._max_batch_size]
for i in range(0, len(documents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = []
# Pre-allocate embeddings_list with known size
embeddings_list = [None] * len(embedding_tasks)
# Use asyncio.gather instead of as_completed if order doesn't matter
embeddings_results = await asyncio.gather(*embedding_tasks)
embeddings_list = list(embeddings_results)
embeddings = np.concatenate(embeddings_list)
# Upsert in batches
for i in range(0, len(ids), self._max_batch_size):
batch_slice = slice(i, i + self._max_batch_size)
self._collection.upsert(
ids=ids[batch_slice],
embeddings=embeddings[batch_slice].tolist(),
documents=documents[batch_slice],
metadatas=metadatas[batch_slice],
)
return ids
except Exception as e:
logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
try:
embedding = await self.embedding_func([query])
results = self._collection.query(
query_embeddings=embedding.tolist(),
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)
# Filter results by cosine similarity threshold and take top k
# We request 2x results initially to have enough left after filtering
# With "hnsw:space": "cosine", ChromaDB returns cosine distance (0 = identical)
# We convert it back to similarity via (1 - distance)
# Only keep results whose similarity meets the threshold, then take top k
return [
{
"id": results["ids"][0][i],
"distance": 1 - results["distances"][0][i],
"content": results["documents"][0][i],
**results["metadatas"][0][i],
}
for i in range(len(results["ids"][0]))
if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
][:top_k]
except Exception as e:
logger.error(f"Error during ChromaDB query: {str(e)}")
raise
async def index_done_callback(self):
# ChromaDB handles persistence automatically
pass
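For reference, the post-filtering that `query` performs on raw ChromaDB results can be reduced to the small self-contained sketch below; the helper name and the sample result dict are hypothetical and only mirror the fields the method reads.

```python
# Hypothetical, standalone sketch of the similarity filtering done in
# ChromaVectorDBStorage.query above: keep hits whose cosine similarity
# (1 - returned distance) clears the threshold, then truncate to top_k.
def filter_chroma_hits(results: dict, threshold: float = 0.2, top_k: int = 5) -> list:
    hits = []
    for i in range(len(results["ids"][0])):
        similarity = 1 - results["distances"][0][i]
        if similarity >= threshold:
            hits.append(
                {
                    "id": results["ids"][0][i],
                    "distance": similarity,
                    "content": results["documents"][0][i],
                    **results["metadatas"][0][i],
                }
            )
    return hits[:top_k]


# Fabricated single-query result in ChromaDB's nested layout (one inner list per query).
sample = {
    "ids": [["chunk-a", "chunk-b"]],
    "distances": [[0.1, 0.9]],
    "documents": [["close match", "far match"]],
    "metadatas": [[{"_default": "true"}, {"_default": "true"}]],
}
print(filter_chroma_hits(sample))  # only "chunk-a" survives the 0.2 threshold
```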

397
minirag/kg/gremlin_impl.py Normal file
View File

@@ -0,0 +1,397 @@
import asyncio
import inspect
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from gremlin_python.driver import client, serializer
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
from gremlin_python.driver.protocol import GremlinServerError
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from lightrag.utils import logger
from ..base import BaseGraphStorage
@dataclass
class GremlinStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with Gremlin in production")
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
USER = os.environ.get("GREMLIN_USER", "")
PASSWORD = os.environ.get("GREMLIN_PASSWORD", "")
HOST = os.environ["GREMLIN_HOST"]
PORT = int(os.environ["GREMLIN_PORT"])
# TraversalSource; a custom one has to be created manually,
# the default is "g"
SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g")
# All vertices will have graph={GRAPH} property, so that we can
# have several logical graphs for one source
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
self.graph_name = GRAPH
self._driver = client.Client(
f"ws://{HOST}:{PORT}/gremlin",
SOURCE,
username=USER,
password=PASSWORD,
message_serializer=serializer.GraphSONSerializersV3d0(),
transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
)
def __post_init__(self):
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def close(self):
if self._driver:
self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
if self._driver:
self._driver.close()
async def index_done_callback(self):
print("KG successfully indexed.")
@staticmethod
def _to_value_map(value: Any) -> str:
"""Dump supported Python object as Gremlin valueMap"""
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
parsed_str = json_str.replace("'", r"\'")
# walk over the string and replace curly brackets with square brackets
# outside of strings, as well as replace double quotes with single quotes
# and "deescape" double quotes inside of strings
outside_str = True
escaped = False
remove_indices = []
for i, c in enumerate(parsed_str):
if escaped:
# previous character was an "odd" backslash
escaped = False
if c == '"':
# we want to "deescape" double quotes: store indices to delete
remove_indices.insert(0, i - 1)
elif c == "\\":
escaped = True
elif c == '"':
outside_str = not outside_str
parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :]
elif c == "{" and outside_str:
parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :]
elif c == "}" and outside_str:
parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :]
for idx in remove_indices:
parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :]
return parsed_str
@staticmethod
def _convert_properties(properties: Dict[str, Any]) -> str:
"""Create chained .property() commands from properties dict"""
props = []
for k, v in properties.items():
prop_name = GremlinStorage._to_value_map(k)
props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
return "".join(props)
@staticmethod
def _fix_name(name: str) -> str:
"""Strip double quotes and format as a proper field name"""
name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
return name
async def _query(self, query: str) -> List[Dict[str, Any]]:
"""
Query the Gremlin graph
Args:
query (str): a query to be executed
Returns:
List[Dict[str, Any]]: a list of dictionaries containing the result set
"""
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
if result:
result = result[0]
return result
async def has_node(self, node_id: str) -> bool:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.limit(1)
.count()
.project('has_node')
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
"""
result = await self._query(query)
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result[0]["has_node"],
)
return result[0]["has_node"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target})
.limit(1)
.count()
.project('has_edge')
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
"""
result = await self._query(query)
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
result[0]["has_edge"],
)
return result[0]["has_edge"]
async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.limit(1)
.project('properties')
.by(elementMap())
"""
result = await self._query(query)
if result:
node = result[0]
node_dict = node["properties"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name,
query,
node_dict,
)
return node_dict
async def node_degree(self, node_id: str) -> int:
entity_name = GremlinStorage._fix_name(node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name})
.outE()
.inV().has('graph', {self.graph_name})
.count()
.project('total_edge_count')
.by()
"""
result = await self._query(query)
edge_count = result[0]["total_edge_count"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
edge_count,
)
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
logger.debug(
"{%s}:query:src_Degree+trg_degree:result:{%s}",
inspect.currentframe().f_code.co_name,
degrees,
)
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Find the first edge between nodes of two given names
Args:
source_node_id (str): Name of the source node
target_node_id (str): Name of the target node
Returns:
dict|None: Dict of found edge properties, or None if not found
"""
entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {entity_name_source})
.outE()
.inV().has('graph', {self.graph_name})
.has('entity_name', {entity_name_target})
.limit(1)
.project('edge_properties')
.by(__.bothE().elementMap())
"""
result = await self._query(query)
if result:
edge_properties = result[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query,
edge_properties,
)
return edge_properties
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
"""
Retrieves all edges (relationships) for a particular node identified by its name.
:return: List of tuples containing edge sources and targets
"""
node_name = GremlinStorage._fix_name(source_node_id)
query = f"""g
.E()
.filter(
__.or(
__.outV().has('graph', {self.graph_name})
.has('entity_name', {node_name}),
__.inV().has('graph', {self.graph_name})
.has('entity_name', {node_name})
)
)
.project('source_name', 'target_name')
.by(__.outV().values('entity_name'))
.by(__.inV().values('entity_name'))
"""
result = await self._query(query)
edges = [(res["source_name"], res["target_name"]) for res in result]
return edges
@retry(
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
"""
Upsert a node in the Gremlin graph.
Args:
node_id: The unique identifier for the node (used as name)
node_data: Dictionary of node properties
"""
name = GremlinStorage._fix_name(node_id)
properties = GremlinStorage._convert_properties(node_data)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {name})
.fold()
.coalesce(
__.unfold(),
__.addV('ENTITY')
.property('graph', {self.graph_name})
.property('entity_name', {name})
)
{properties}
"""
try:
await self._query(query)
logger.debug(
"Upserted node with name {%s} and properties: {%s}",
name,
properties,
)
except Exception as e:
logger.error("Error during upsert: {%s}", e)
raise
@retry(
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)),
)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
"""
Upsert an edge and its properties between two nodes identified by their names.
Args:
source_node_id (str): Name of the source node (used as identifier)
target_node_id (str): Name of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_name = GremlinStorage._fix_name(source_node_id)
target_node_name = GremlinStorage._fix_name(target_node_id)
edge_properties = GremlinStorage._convert_properties(edge_data)
query = f"""g
.V().has('graph', {self.graph_name})
.has('entity_name', {source_node_name}).as('source')
.V().has('graph', {self.graph_name})
.has('entity_name', {target_node_name}).as('target')
.coalesce(
__.select('source').outE('DIRECTED').where(__.inV().as('target')),
__.select('source').addE('DIRECTED').to(__.select('target'))
)
.property('graph', {self.graph_name})
{edge_properties}
"""
try:
await self._query(query)
logger.debug(
"Upserted edge from {%s} to {%s} with properties: {%s}",
source_node_name,
target_node_name,
edge_properties,
)
except Exception as e:
logger.error("Error during edge upsert: {%s}", e)
raise
async def _node2vec_embed(self):
print("Implemented but never called.")

134
minirag/kg/json_kv_impl.py Normal file
View File

@@ -0,0 +1,134 @@
"""
JsonKV Storage Module
=====================
This module provides a simple key-value storage interface backed by a JSON file on disk.
The `JsonKVStorage` class extends the `BaseKVStorage` class from the LightRAG library, providing methods to load, save, query, filter, and delete key-value records.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- LightRAG
Features:
- Load and persist key-value data as a JSON file
- Query records by ID, filter keys, and delete entries
Usage:
from minirag.kg.json_kv_impl import JsonKVStorage
"""
import asyncio
import os
from dataclasses import dataclass
from lightrag.utils import (
logger,
load_json,
write_json,
)
from lightrag.base import (
BaseKVStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data = load_json(self._file_name) or {}
self._lock = asyncio.Lock()
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def all_keys(self) -> list[str]:
return list(self._data.keys())
async def index_done_callback(self):
write_json(self._data, self._file_name)
async def get_by_id(self, id):
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
if fields is None:
return [self._data.get(id, None) for id in ids]
return [
(
{k: v for k, v in self._data[id].items() if k in fields}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, data: list[str]) -> set[str]:
return set([s for s in data if s not in self._data])
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
return left_data
async def drop(self):
self._data = {}
async def filter(self, filter_func):
"""Filter key-value pairs based on a filter function
Args:
filter_func: The filter function, which takes a value as an argument and returns a boolean value
Returns:
Dict: Key-value pairs that meet the condition
"""
result = {}
async with self._lock:
for key, value in self._data.items():
if filter_func(value):
result[key] = value
return result
async def delete(self, ids: list[str]):
"""Delete data with specified IDs
Args:
ids: List of IDs to delete
"""
async with self._lock:
for id in ids:
if id in self._data:
del self._data[id]
await self.index_done_callback()
logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")

View File

@@ -0,0 +1,128 @@
"""
JsonDocStatus Storage Module
=======================
This module provides a storage interface for document processing status, backed by a JSON file on disk.
The `JsonDocStatusStorage` class extends the `DocStatusStorage` class from the LightRAG library, providing methods to load, save, query, and update per-document status records.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- LightRAG
Features:
- Track document processing status (pending, processed, failed)
- Count documents by status and list failed or pending documents
- Persist status records as a JSON file
Usage:
Import `JsonDocStatusStorage` from this module.
"""
import os
from dataclasses import dataclass
from typing import Union, Dict
from lightrag.utils import (
logger,
load_json,
write_json,
)
from lightrag.base import (
DocStatus,
DocProcessingStatus,
DocStatusStorage,
)
@dataclass
class JsonDocStatusStorage(DocStatusStorage):
"""JSON implementation of document status storage"""
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data = load_json(self._file_name) or {}
logger.info(f"Loaded document status storage with {len(self._data)} records")
async def filter_keys(self, data: list[str]) -> set[str]:
"""Return keys that should be processed (not in storage or not successfully processed)"""
return set(
[
k
for k in data
if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
]
)
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
counts = {status: 0 for status in DocStatus}
for doc in self._data.values():
counts[doc["status"]] += 1
return counts
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
async def index_done_callback(self):
"""Save data to file after indexing"""
write_json(self._data, self._file_name)
async def upsert(self, data: dict[str, dict]):
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
"""
self._data.update(data)
await self.index_done_callback()
return data
async def get_by_id(self, id: str):
return self._data.get(id)
async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
"""Get document status by ID"""
return self._data.get(doc_id)
async def delete(self, doc_ids: list[str]):
"""Delete document status by IDs"""
for doc_id in doc_ids:
self._data.pop(doc_id, None)
await self.index_done_callback()

94
minirag/kg/milvus_impl.py Normal file
View File

@@ -0,0 +1,94 @@
import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
if not pm.is_installed("pymilvus"):
pm.install("pymilvus")
from pymilvus import MilvusClient
@dataclass
class MilvusVectorDBStorge(BaseVectorStorage):
@staticmethod
def create_collection_if_not_exist(
client: MilvusClient, collection_name: str, **kwargs
):
if client.has_collection(collection_name):
return
client.create_collection(
collection_name, max_length=64, id_type="string", **kwargs
)
def __post_init__(self):
self._client = MilvusClient(
uri=os.environ.get(
"MILVUS_URI",
os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
),
user=os.environ.get("MILVUS_USER", ""),
password=os.environ.get("MILVUS_PASSWORD", ""),
token=os.environ.get("MILVUS_TOKEN", ""),
db_name=os.environ.get("MILVUS_DB_NAME", ""),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
MilvusVectorDBStorge.create_collection_if_not_exist(
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,
)
async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
list_data = [
{
"id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result
embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = embeddings[i]
results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results
async def query(self, query, top_k=5):
embedding = await self.embedding_func([query])
results = self._client.search(
collection_name=self.namespace,
data=embedding,
limit=top_k,
output_fields=list(self.meta_fields),
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
)
print(results)
return [
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
for dp in results[0]
]
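The batched, concurrent embedding pattern used by `upsert` above (and by the other vector storages in this commit) can be reduced to the following self-contained sketch; the embedder is a stand-in, not a real model call.

```python
# Standalone sketch of the batching pattern used by upsert above: split contents
# into fixed-size batches, embed each batch concurrently, then concatenate the
# per-batch arrays back into one matrix. fake_embedding_func is a stand-in.
import asyncio
import numpy as np


async def fake_embedding_func(batch):
    await asyncio.sleep(0)  # pretend to call an embedding server
    return np.zeros((len(batch), 4), dtype=np.float32)


async def embed_in_batches(contents, max_batch_size=32):
    batches = [
        contents[i : i + max_batch_size]
        for i in range(0, len(contents), max_batch_size)
    ]
    embeddings_list = await asyncio.gather(*(fake_embedding_func(b) for b in batches))
    return np.concatenate(embeddings_list)


print(asyncio.run(embed_in_batches([f"chunk {i}" for i in range(70)])).shape)  # (70, 4)
```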

440
minirag/kg/mongo_impl.py Normal file
View File

@@ -0,0 +1,440 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
import numpy as np
if not pm.is_installed("pymongo"):
pm.install("pymongo")
from pymongo import MongoClient
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Union, List, Tuple
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
from lightrag.base import BaseGraphStorage
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
)
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as KV {self.namespace}")
async def all_keys(self) -> list[str]:
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
async def get_by_id(self, id):
return self._data.find_one({"_id": id})
async def get_by_ids(self, ids, fields=None):
if fields is None:
return list(self._data.find({"_id": {"$in": ids}}))
return list(
self._data.find(
{"_id": {"$in": ids}},
{field: 1 for field in fields},
)
)
async def filter_keys(self, data: list[str]) -> set[str]:
existing_ids = [
str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
]
return set([s for s in data if s not in existing_ids])
async def upsert(self, data: dict[str, dict]):
if self.namespace == "llm_response_cache":
for mode, items in data.items():
for k, v in tqdm_async(items.items(), desc="Upserting"):
key = f"{mode}_{k}"
result = self._data.update_one(
{"_id": key}, {"$setOnInsert": v}, upsert=True
)
if result.upserted_id:
logger.debug(f"\nInserted new document with key: {key}")
data[mode][k]["_id"] = key
else:
for k, v in tqdm_async(data.items(), desc="Upserting"):
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
data[k]["_id"] = k
return data
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
if "llm_response_cache" == self.namespace:
res = {}
v = self._data.find_one({"_id": mode + "_" + id})
if v:
res[id] = v
logger.debug(f"llm_response_cache find one by:{id}")
return res
else:
return None
else:
return None
async def drop(self):
""" """
pass
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
"""
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self.client = AsyncIOMotorClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
)
self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]
#
# -------------------------------------------------------------------------
# HELPER: $graphLookup pipeline
# -------------------------------------------------------------------------
#
async def _graph_lookup(
self, start_node_id: str, max_depth: int = None
) -> List[dict]:
"""
Performs a $graphLookup starting from 'start_node_id' and returns
all reachable documents (including the start node itself).
Pipeline Explanation:
- 1) $match: We match the start node document by _id = start_node_id.
- 2) $graphLookup:
"from": same collection,
"startWith": "$edges.target" (the immediate neighbors in 'edges'),
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"maxDepth": max_depth (if provided),
"depthField": "depth" (used for debugging or filtering).
- 3) We add a $project or $unwind stage as needed to extract data.
"""
pipeline = [
{"$match": {"_id": start_node_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"depthField": "depth",
}
},
]
# If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
if max_depth is not None:
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
# Return the matching doc plus a field "reachableNodes"
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
# If there's no matching node, results = [].
# Otherwise, results[0] is the start node doc,
# plus results[0]["reachableNodes"] is the array of connected docs.
return results
#
# -------------------------------------------------------------------------
# BASIC QUERIES
# -------------------------------------------------------------------------
#
async def has_node(self, node_id: str) -> bool:
"""
Check if node_id is present in the collection by looking up its doc.
No real need for $graphLookup here, but let's keep it direct.
"""
doc = await self.collection.find_one({"_id": node_id}, {"_id": 1})
return doc is not None
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if there is a direct single-hop edge from source_node_id to target_node_id.
We do a $graphLookup with maxDepth=0 from the source node, i.e. "look up only
the immediate neighbours", and then check whether the target node appears in
"reachableNodes". A plain find_one on the edges array would also work; this is
a demonstration of the $graphLookup approach.
"""
# We can do a single-hop graphLookup (maxDepth=0 or 1).
# Then check if the target_node appears among the edges array.
pipeline = [
{"$match": {"_id": source_node_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "reachableNodes",
"depthField": "depth",
"maxDepth": 0, # means: do not follow beyond immediate edges
}
},
{
"$project": {
"_id": 0,
"reachableNodes._id": 1, # only keep the _id from the subdocs
}
},
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
if not results:
return False
# results[0]["reachableNodes"] are the immediate neighbors
reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])]
return target_node_id in reachable_ids
#
# -------------------------------------------------------------------------
# DEGREES
# -------------------------------------------------------------------------
#
async def node_degree(self, node_id: str) -> int:
"""
Returns the total number of edges connected to node_id (both inbound and outbound).
This is done in two steps:
1) Outbound edges: the size of the node's own edges array.
2) Inbound edges: the number of edges in other docs whose edges.target is node_id.
A reverse $graphLookup would require storing reversed edges, so a direct
match plus a small aggregation is used for the inbound count instead.
"""
# --- 1) Outbound edges (direct from doc) ---
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
if not doc:
return 0
outbound_count = len(doc.get("edges", []))
# --- 2) Inbound edges:
# A simple way is: find all docs where "edges.target" == node_id.
# But let's do a $graphLookup from `node_id` in REVERSE.
# There's a trick to do "reverse" graphLookups: you'd store
# reversed edges or do a more advanced pipeline. Typically you'd do
# a direct match. We'll just do a direct match for inbound.
inbound_count_pipeline = [
{"$match": {"edges.target": node_id}},
{
"$project": {
"matchingEdgesCount": {
"$size": {
"$filter": {
"input": "$edges",
"as": "edge",
"cond": {"$eq": ["$$edge.target", node_id]},
}
}
}
}
},
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
]
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
inbound_result = await inbound_cursor.to_list(None)
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
return outbound_count + inbound_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""
If your graph can hold multiple edges from the same src to the same tgt
(e.g. different 'relation' values), you can sum them. If it's always
one edge, this is typically 1 or 0.
We'll do a single-hop $graphLookup from src_id,
then count how many edges reference tgt_id at depth=0.
"""
pipeline = [
{"$match": {"_id": src_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}},
]
cursor = self.collection.aggregate(pipeline)
results = await cursor.to_list(None)
if not results:
return 0
# We can simply count how many edges in `results[0].edges` have target == tgt_id.
edges = results[0].get("edges", [])
count = sum(1 for e in edges if e.get("target") == tgt_id)
return count
#
# -------------------------------------------------------------------------
# GETTERS
# -------------------------------------------------------------------------
#
async def get_node(self, node_id: str) -> Union[dict, None]:
"""
Return the full node document (including "edges"), or None if missing.
"""
return await self.collection.find_one({"_id": node_id})
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""
Return the first edge dict from source_node_id to target_node_id if it exists.
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
"""
pipeline = [
{"$match": {"_id": source_node_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"edges": 1}},
]
cursor = self.collection.aggregate(pipeline)
docs = await cursor.to_list(None)
if not docs:
return None
for e in docs[0].get("edges", []):
if e.get("target") == target_node_id:
return e
return None
async def get_node_edges(
self, source_node_id: str
) -> Union[List[Tuple[str, str]], None]:
"""
Return a list of (target_id, relation) for direct edges from source_node_id.
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
"""
pipeline = [
{"$match": {"_id": source_node_id}},
{
"$graphLookup": {
"from": self.collection.name,
"startWith": "$edges.target",
"connectFromField": "edges.target",
"connectToField": "_id",
"as": "neighbors",
"depthField": "depth",
"maxDepth": 0,
}
},
{"$project": {"_id": 0, "edges": 1}},
]
cursor = self.collection.aggregate(pipeline)
result = await cursor.to_list(None)
if not result:
return None
edges = result[0].get("edges", [])
return [(e["target"], e["relation"]) for e in edges]
#
# -------------------------------------------------------------------------
# UPSERTS
# -------------------------------------------------------------------------
#
async def upsert_node(self, node_id: str, node_data: dict):
"""
Insert or update a node document. If new, create an empty edges array.
"""
# By default, preserve existing 'edges'.
# We'll only set 'edges' to [] on insert (no overwrite).
update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}}
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict
):
"""
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
If an edge with the same target exists, we remove it and re-insert with updated data.
"""
# Ensure source node exists
await self.upsert_node(source_node_id, {})
# Remove existing edge (if any)
await self.collection.update_one(
{"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
)
# Insert new edge
new_edge = {"target": target_node_id}
new_edge.update(edge_data)
await self.collection.update_one(
{"_id": source_node_id}, {"$push": {"edges": new_edge}}
)
#
# -------------------------------------------------------------------------
# DELETION
# -------------------------------------------------------------------------
#
async def delete_node(self, node_id: str):
"""
1) Remove the node's doc entirely.
2) Remove inbound edges from any doc that references node_id.
"""
# Remove inbound edges from all other docs
await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}})
# Remove the node doc
await self.collection.delete_one({"_id": node_id})
#
# -------------------------------------------------------------------------
# EMBEDDINGS (NOT IMPLEMENTED)
# -------------------------------------------------------------------------
#
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
"""
Placeholder for demonstration, raises NotImplementedError.
"""
raise NotImplementedError("Node embedding is not used in lightrag.")
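For orientation, the sketch below shows the node document shape this storage assumes and the single-hop $graphLookup pipeline that `has_edge` builds; the sample names and the default collection name "MDB_KG" are illustrative.

```python
# Hypothetical node document in the shape MongoGraphStorage reads and writes:
# the _id is the entity name and "edges" is the adjacency list maintained by upsert_edge.
node_doc = {
    "_id": "Alice",
    "description": "an example entity",
    "edges": [{"target": "Bob", "relation": "KNOWS"}],
}
print(node_doc)


def direct_edge_pipeline(source_id, collection_name="MDB_KG"):
    """Single-hop $graphLookup pipeline, mirroring has_edge above."""
    return [
        {"$match": {"_id": source_id}},
        {
            "$graphLookup": {
                "from": collection_name,
                "startWith": "$edges.target",
                "connectFromField": "edges.target",
                "connectToField": "_id",
                "as": "reachableNodes",
                "depthField": "depth",
                "maxDepth": 0,  # immediate neighbours only
            }
        },
        {"$project": {"_id": 0, "reachableNodes._id": 1}},
    ]


print(direct_edge_pipeline("Alice"))
```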

View File

@@ -0,0 +1,213 @@
"""
NanoVectorDB Storage Module
=======================
This module provides a vector storage interface backed by NanoVectorDB, a lightweight file-based vector database.
The `NanoVectorDBStorage` class extends the `BaseVectorStorage` class from the LightRAG library, providing methods to upsert, query, and delete embedded vectors.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NanoVectorDB
- NumPy
- LightRAG
Features:
- Upsert documents with their embeddings in batches
- Query vectors by cosine similarity with a configurable threshold
- Delete vectors, entities, and entity relations
Usage:
Import `NanoVectorDBStorage` from this module.
"""
import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
if not pm.is_installed("nano-vectordb"):
pm.install("nano-vectordb")
from nano_vectordb import NanoVectorDB
import time
from lightrag.utils import (
logger,
compute_mdhash_id,
)
from lightrag.base import (
BaseVectorStorage,
)
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
current_time = time.time()
list_data = [
{
"__id__": k,
"__created_at__": current_time,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result
embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
logger.error(
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)
async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
embedding = embedding[0]
logger.info(
f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
)
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
results = [
{
**dp,
"id": dp["__id__"],
"distance": dp["__metrics__"],
"created_at": dp.get("__created_at__"),
}
for dp in results
]
return results
@property
def client_storage(self):
return getattr(self._client, "_NanoVectorDB__storage")
async def delete(self, ids: list[str]):
"""Delete vectors with specified IDs
Args:
ids: List of vector IDs to be deleted
"""
try:
self._client.delete(ids)
logger.info(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str):
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Check if the entity exists
if self._client.get([entity_id]):
await self.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}")
else:
logger.debug(f"Entity {entity_name} not found in storage")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str):
try:
relations = [
dp
for dp in self.client_storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
await self.delete(ids_to_delete)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
)
else:
logger.debug(f"No relations found for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self):
self._client.save()

View File

@@ -1,18 +1,20 @@
import asyncio
import inspect
import os
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import inspect
from minirag.utils import logger
from ..base import BaseGraphStorage
import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
GraphDatabase,
)
from tenacity import (
retry,
stop_after_attempt,
@@ -20,6 +22,9 @@ from tenacity import (
retry_if_exception_type,
)
from lightrag.utils import logger
from ..base import BaseGraphStorage
@dataclass
class Neo4JStorage(BaseGraphStorage):
@@ -27,17 +32,63 @@ class Neo4JStorage(BaseGraphStorage):
def load_nx_graph(file_name):
print("no preloading of graph with neo4j in production")
def __init__(self, namespace, global_config):
super().__init__(namespace=namespace, global_config=global_config)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
self._driver_lock = asyncio.Lock()
URI = os.environ["NEO4J_URI"]
USERNAME = os.environ["NEO4J_USERNAME"]
PASSWORD = os.environ["NEO4J_PASSWORD"]
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
DATABASE = os.environ.get(
"NEO4J_DATABASE"
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
self._DATABASE = DATABASE
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI, auth=(USERNAME, PASSWORD)
)
return None
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
with GraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
) as _sync_driver:
try:
with _sync_driver.session(database=DATABASE) as session:
try:
session.run("MATCH (n) RETURN n LIMIT 0")
logger.info(f"Connected to {DATABASE} at {URI}")
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"{DATABASE} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(f"Authentication failed for {DATABASE} at {URI}")
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
)
try:
with _sync_driver.session() as session:
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
logger.info(f"{DATABASE} at {URI} created".capitalize())
except neo4jExceptions.ClientError as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
):
logger.warning(
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
)
logger.error(f"Failed to create {DATABASE} at {URI}")
raise e
def __post_init__(self):
self._node_embed_algorithms = {
@@ -59,7 +110,7 @@ class Neo4JStorage(BaseGraphStorage):
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
)
@@ -74,7 +125,7 @@ class Neo4JStorage(BaseGraphStorage):
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = (
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
"RETURN COUNT(r) > 0 AS edgeExists"
@@ -86,11 +137,8 @@ class Neo4JStorage(BaseGraphStorage):
)
return single_result["edgeExists"]
def close(self):
self._driver.close()
async def get_node(self, node_id: str) -> Union[dict, None]:
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
entity_name_label = node_id.strip('"')
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
result = await session.run(query)
@@ -107,7 +155,7 @@ class Neo4JStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (n:`{entity_name_label}`)
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@@ -154,7 +202,7 @@ class Neo4JStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
query = f"""
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
RETURN properties(r) as edge_properties
@@ -185,7 +233,7 @@ class Neo4JStorage(BaseGraphStorage):
query = f"""MATCH (n:`{node_label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
results = await session.run(query)
edges = []
async for record in results:
@@ -214,6 +262,7 @@ class Neo4JStorage(BaseGraphStorage):
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.WriteServiceUnavailable,
neo4jExceptions.ClientError,
)
),
)
@@ -239,7 +288,7 @@ class Neo4JStorage(BaseGraphStorage):
)
try:
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
@@ -286,7 +335,7 @@ class Neo4JStorage(BaseGraphStorage):
)
try:
async with self._driver.session() as session:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_upsert_edge)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
@@ -294,3 +343,177 @@ class Neo4JStorage(BaseGraphStorage):
async def _node2vec_embed(self):
print("Implemented but never called.")
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> Dict[str, List[Dict]]:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Key fixes:
1. Include the starting node itself
2. Handle multi-label nodes
3. Clarify relationship directions
4. Add depth control
"""
label = node_label.strip('"')
result = {"nodes": [], "edges": []}
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE) as session:
try:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
return result
# Optimized query (including direction handling and self-loops)
main_query = f"""
MATCH (start:`{label}`)
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',
minLevel: 0,
maxLevel: {max_depth},
bfs: true
}})
YIELD nodes, relationships
RETURN nodes, relationships
"""
result_set = await session.run(main_query)
record = await result_set.single()
if record:
# Handle nodes (compatible with multi-label cases)
for node in record["nodes"]:
# Use node ID + label combination as unique identifier
node_id = f"{node.id}_{'_'.join(node.labels)}"
if node_id not in seen_nodes:
node_data = dict(node)
node_data["labels"] = list(node.labels) # Keep all labels
result["nodes"].append(node_data)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = f"{rel.id}_{rel.type}"
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
edge_data = dict(rel)
edge_data.update(
{
"source": f"{start.id}_{'_'.join(start.labels)}",
"target": f"{end.id}_{'_'.join(end.labels)}",
"type": rel.type,
"direction": rel.element_id.split(
"->" if rel.end_node == end else "<-"
)[1],
}
)
result["edges"].append(edge_data)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}"
)
except neo4jExceptions.ClientError as e:
logger.error(f"APOC query failed: {str(e)}")
return await self._robust_fallback(label, max_depth)
return result
async def _robust_fallback(
self, label: str, max_depth: int
) -> Dict[str, List[Dict]]:
"""Enhanced fallback query solution"""
result = {"nodes": [], "edges": []}
visited_nodes = set()
visited_edges = set()
async def traverse(current_label: str, current_depth: int):
if current_depth > max_depth:
return
# Get current node details
node = await self.get_node(current_label)
if not node:
return
node_id = f"{current_label}"
if node_id in visited_nodes:
return
visited_nodes.add(node_id)
# Add node data (with complete labels)
node_data = {k: v for k, v in node.items()}
node_data["labels"] = [
current_label
] # Assume get_node method returns label information
result["nodes"].append(node_data)
# Get all outgoing and incoming edges
query = f"""
MATCH (a)-[r]-(b)
WHERE a:`{current_label}` OR b:`{current_label}`
RETURN a, r, b,
CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
"""
async with self._driver.session(database=self._DATABASE) as session:
results = await session.run(query)
async for record in results:
# Handle edges
rel = record["r"]
edge_id = f"{rel.id}_{rel.type}"
if edge_id not in visited_edges:
edge_data = dict(rel)
edge_data.update(
{
"source": list(record["a"].labels)[0],
"target": list(record["b"].labels)[0],
"type": rel.type,
"direction": record["direction"],
}
)
result["edges"].append(edge_data)
visited_edges.add(edge_id)
# Recursively traverse adjacent nodes
next_label = (
list(record["b"].labels)[0]
if record["direction"] == "OUTGOING"
else list(record["a"].labels)[0]
)
await traverse(next_label, current_depth + 1)
await traverse(label, 0)
return result
async def get_all_labels(self) -> List[str]:
"""
Get all existing node labels in the database
Returns:
["Person", "Company", ...] # Alphabetically sorted label list
"""
async with self._driver.session(database=self._DATABASE) as session:
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
# query = "CALL db.labels() YIELD label RETURN label"
# Method 2: Query compatible with older versions
query = """
MATCH (n)
WITH DISTINCT labels(n) AS node_labels
UNWIND node_labels AS label
RETURN DISTINCT label
ORDER BY label
"""
result = await session.run(query)
labels = []
async for record in result:
labels.append(record["label"])
return labels
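
The APOC-based `get_knowledge_graph` above falls back to `_robust_fallback` when `apoc.path.subgraphAll` is not available. As a rough illustration only, here is a hedged sketch of how a caller might exercise these methods against an already-configured `Neo4JStorage` instance; the construction of that instance is not shown in this diff and is assumed to be handled by the surrounding LightRAG wiring.

```python
import asyncio

async def inspect_graph(storage):
    # storage: a configured Neo4JStorage with an open async driver (assumed).
    labels = await storage.get_all_labels()
    print(f"{len(labels)} labels found:", labels[:10])

    if labels:
        # Pull the connected subgraph around the first label, two hops deep.
        subgraph = await storage.get_knowledge_graph(labels[0], max_depth=2)
        print(f"nodes={len(subgraph['nodes'])}, edges={len(subgraph['edges'])}")

# asyncio.run(inspect_graph(storage))
```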

228
minirag/kg/networkx_impl.py Normal file
View File

@@ -0,0 +1,228 @@
"""
NetworkX Storage Module
=======================
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
Author: lightrag team
Created: 2024-01-25
License: MIT
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Version: 1.0.0
Dependencies:
- NetworkX
- NumPy
- LightRAG
- graspologic
Features:
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
- Query graph nodes and edges
- Calculate node and edge degrees
- Embed nodes using various algorithms (e.g., Node2Vec)
- Remove nodes and edges from the graph
Usage:
from lightrag.storage.networkx_storage import NetworkXStorage
"""
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from lightrag.utils import (
logger,
)
from lightrag.base import (
BaseGraphStorage,
)
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
nx.write_graphml(graph, file_name)
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
"""
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
"""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
return self._graph.nodes.get(node_id)
async def node_degree(self, node_id: str) -> int:
return self._graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
async def _node2vec_embed(self):
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to be deleted
"""
for node in nodes:
if self._graph.has_node(node):
self._graph.remove_node(node)
def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
"""
for source, target in edges:
if self._graph.has_edge(source, target):
self._graph.remove_edge(source, target)
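
As a quick orientation for the class above, here is a hedged, self-contained sketch of driving `NetworkXStorage` directly and persisting the graph via `index_done_callback`. Only `working_dir` is read by `__post_init__`; the constructor fields (`namespace`, `global_config`, `embedding_func`) follow the `BaseGraphStorage` dataclass and may be wired differently in the real application.

```python
import asyncio
import tempfile

async def demo():
    workdir = tempfile.mkdtemp()
    storage = NetworkXStorage(
        namespace="demo",
        global_config={"working_dir": workdir},
        embedding_func=None,  # not needed for plain graph operations
    )
    await storage.upsert_node("ALICE", {"entity_type": "PERSON", "description": "demo"})
    await storage.upsert_node("BOB", {"entity_type": "PERSON", "description": "demo"})
    await storage.upsert_edge("ALICE", "BOB", {"weight": "1", "description": "knows"})
    print(await storage.get_node_edges("ALICE"))  # [('ALICE', 'BOB')]
    await storage.index_done_callback()           # writes graph_demo.graphml to workdir

# asyncio.run(demo())
```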

View File

@@ -1,3 +1,4 @@
import os
import asyncio
# import html
@@ -6,6 +7,11 @@ from dataclasses import dataclass
from typing import Union
import numpy as np
import array
import pipmaster as pm
if not pm.is_installed("oracledb"):
pm.install("oracledb")
from ..utils import logger
from ..base import (
@@ -114,16 +120,19 @@ class OracleDB:
logger.info("Finished check all tables in Oracle database")
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
async def query(
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
async with self.pool.acquire() as connection:
connection.inputtypehandler = self.input_type_handler
connection.outputtypehandler = self.output_type_handler
with connection.cursor() as cursor:
try:
await cursor.execute(sql)
await cursor.execute(sql, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(sql)
print(params)
raise
columns = [column[0].lower() for column in cursor.description]
if multirows:
@@ -140,7 +149,7 @@ class OracleDB:
data = None
return data
async def execute(self, sql: str, data: list = None):
async def execute(self, sql: str, data: Union[list, dict] = None):
# logger.info("go into OracleDB execute method")
try:
async with self.pool.acquire() as connection:
@@ -150,8 +159,6 @@ class OracleDB:
if data is None:
await cursor.execute(sql)
else:
# print(data)
# print(sql)
await cursor.execute(sql, data)
await connection.commit()
except Exception as e:
@@ -164,34 +171,64 @@ class OracleDB:
@dataclass
class OracleKVStorage(BaseKVStorage):
# should pass db object to self.db
db: OracleDB = None
meta_fields = None
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
workspace=self.db.workspace, id=id
)
"""get doc_full data based on id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.db.workspace, "id": id}
# print("get_by_id:"+SQL)
res = await self.db.query(SQL)
if "llm_response_cache" == self.namespace:
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
else:
res = await self.db.query(SQL, params)
if res:
data = res # {"data":res}
# print (data)
return data
return res
else:
return None
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
if "llm_response_cache" == self.namespace:
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
return res
else:
return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
"""get doc_chunks data based on id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
# print("get_by_ids:"+SQL)
res = await self.db.query(SQL, multirows=True)
res = await self.db.query(SQL, params, multirows=True)
if "llm_response_cache" == self.namespace:
modes = set()
dict_res: dict[str, dict] = {}
for row in res:
modes.add(row["mode"])
for mode in modes:
if mode not in dict_res:
dict_res[mode] = {}
for row in res:
dict_res[row["mode"]][row["id"]] = row
res = [{k: v} for k, v in dict_res.items()]
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -199,33 +236,43 @@ class OracleKVStorage(BaseKVStorage):
else:
return None
async def get_by_status_and_ids(
self, status: str, ids: list[str]
) -> Union[list[dict], None]:
"""Specifically for llm_response_cache."""
if ids is not None:
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
else:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, multirows=True)
if res:
return res
else:
return None
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace],
workspace=self.db.workspace,
ids=",".join([f"'{k}'" for k in keys]),
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
)
res = await self.db.query(SQL, multirows=True)
data = None
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, multirows=True)
if res:
exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys])
return data
else:
exist_keys = []
data = set([s for s in keys if s not in exist_keys])
return data
return set(keys)
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
# print(self._data)
# values = []
if self.namespace == "text_chunks":
list_data = [
{
"__id__": k,
"id": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
@@ -241,32 +288,50 @@ class OracleKVStorage(BaseKVStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# print(list_data)
merge_sql = SQL_TEMPLATES["merge_chunk"]
for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
values = [
item["__id__"],
item["content"],
self.db.workspace,
item["tokens"],
item["chunk_order_index"],
item["full_doc_id"],
item["__vector__"],
]
# print(merge_sql)
await self.db.execute(merge_sql, values)
_data = {
"id": item["id"],
"content": item["content"],
"workspace": self.db.workspace,
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": item["__vector__"],
"status": item["status"],
}
await self.db.execute(merge_sql, _data)
if self.namespace == "full_docs":
for k, v in self._data.items():
for k, v in data.items():
# values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
check_id=k,
)
values = [k, self._data[k]["content"], self.db.workspace]
# print(merge_sql)
await self.db.execute(merge_sql, values)
return left_data
merge_sql = SQL_TEMPLATES["merge_doc_full"]
_data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
await self.db.execute(merge_sql, _data)
if self.namespace == "llm_response_cache":
for mode, items in data.items():
for k, v in items.items():
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
_data = {
"workspace": self.db.workspace,
"id": k,
"original_prompt": v["original_prompt"],
"return_value": v["return"],
"cache_mode": mode,
}
await self.db.execute(upsert_sql, _data)
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
@@ -275,10 +340,16 @@ class OracleKVStorage(BaseKVStorage):
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2
# should pass db object to self.db
db: OracleDB = None
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
pass
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据"""
@@ -295,18 +366,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
# Convert precision
dtype = str(embedding.dtype).upper()
dimension = embedding.shape[0]
embedding_string = ", ".join(map(str, embedding.tolist()))
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
SQL = SQL_TEMPLATES[self.namespace].format(
embedding_string=embedding_string,
dimension=dimension,
dtype=dtype,
workspace=self.db.workspace,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
params = {
"embedding_string": embedding_string,
"workspace": self.db.workspace,
"top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold,
}
# print(SQL)
results = await self.db.query(SQL, multirows=True)
results = await self.db.query(SQL, params=params, multirows=True)
# print("vector search result:",results)
return results
@@ -317,7 +387,7 @@ class OracleGraphStorage(BaseGraphStorage):
def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################
@@ -328,6 +398,8 @@ class OracleGraphStorage(BaseGraphStorage):
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
content = entity_name + description
contents = [content]
batches = [
@@ -339,22 +411,17 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["merge_node"].format(
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
)
# print(merge_sql)
await self.db.execute(
merge_sql,
[
self.db.workspace,
entity_name,
entity_type,
description,
source_id,
content,
content_vector,
],
)
merge_sql = SQL_TEMPLATES["merge_node"]
data = {
"workspace": self.db.workspace,
"name": entity_name,
"entity_type": entity_type,
"description": description,
"source_chunk_id": source_id,
"content": content,
"content_vector": content_vector,
}
await self.db.execute(merge_sql, data)
# self._graph.add_node(node_id, **node_data)
async def upsert_edge(
@@ -368,6 +435,10 @@ class OracleGraphStorage(BaseGraphStorage):
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
logger.debug(
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
)
content = keywords + source_name + target_name + description
contents = [content]
batches = [
@@ -379,27 +450,20 @@ class OracleGraphStorage(BaseGraphStorage):
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["merge_edge"].format(
workspace=self.db.workspace,
source_name=source_name,
target_name=target_name,
source_chunk_id=source_chunk_id,
)
merge_sql = SQL_TEMPLATES["merge_edge"]
data = {
"workspace": self.db.workspace,
"source_name": source_name,
"target_name": target_name,
"weight": weight,
"keywords": keywords,
"description": description,
"source_chunk_id": source_chunk_id,
"content": content,
"content_vector": content_vector,
}
# print(merge_sql)
await self.db.execute(
merge_sql,
[
self.db.workspace,
source_name,
target_name,
weight,
keywords,
description,
source_chunk_id,
content,
content_vector,
],
)
await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -429,12 +493,11 @@ class OracleGraphStorage(BaseGraphStorage):
#################### query method #################
async def has_node(self, node_id: str) -> bool:
"""根据节点id检查节点是否存在"""
SQL = SQL_TEMPLATES["has_node"].format(
workspace=self.db.workspace, node_id=node_id
)
SQL = SQL_TEMPLATES["has_node"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
# print(self.db.workspace, node_id)
res = await self.db.query(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Node exist!",res)
return True
@@ -444,13 +507,14 @@ class OracleGraphStorage(BaseGraphStorage):
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""根据源和目标节点id检查边是否存在"""
SQL = SQL_TEMPLATES["has_edge"].format(
workspace=self.db.workspace,
source_node_id=source_node_id,
target_node_id=target_node_id,
)
SQL = SQL_TEMPLATES["has_edge"]
params = {
"workspace": self.db.workspace,
"source_node_id": source_node_id,
"target_node_id": target_node_id,
}
# print(SQL)
res = await self.db.query(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Edge exist!",res)
return True
@@ -460,11 +524,10 @@ class OracleGraphStorage(BaseGraphStorage):
async def node_degree(self, node_id: str) -> int:
"""根据节点id获取节点的度"""
SQL = SQL_TEMPLATES["node_degree"].format(
workspace=self.db.workspace, node_id=node_id
)
SQL = SQL_TEMPLATES["node_degree"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(SQL)
res = await self.db.query(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Node degree",res["degree"])
return res["degree"]
@@ -480,12 +543,11 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node(self, node_id: str) -> Union[dict, None]:
"""根据节点id获取节点数据"""
SQL = SQL_TEMPLATES["get_node"].format(
workspace=self.db.workspace, node_id=node_id
)
SQL = SQL_TEMPLATES["get_node"]
params = {"workspace": self.db.workspace, "node_id": node_id}
# print(self.db.workspace, node_id)
# print(SQL)
res = await self.db.query(SQL)
res = await self.db.query(SQL, params)
if res:
# print("Get node!",self.db.workspace, node_id,res)
return res
@@ -497,12 +559,13 @@ class OracleGraphStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
"""根据源和目标节点id获取边"""
SQL = SQL_TEMPLATES["get_edge"].format(
workspace=self.db.workspace,
source_node_id=source_node_id,
target_node_id=target_node_id,
)
res = await self.db.query(SQL)
SQL = SQL_TEMPLATES["get_edge"]
params = {
"workspace": self.db.workspace,
"source_node_id": source_node_id,
"target_node_id": target_node_id,
}
res = await self.db.query(SQL, params)
if res:
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
return res
@@ -513,10 +576,9 @@ class OracleGraphStorage(BaseGraphStorage):
async def get_node_edges(self, source_node_id: str):
"""根据节点id获取节点的所有边"""
if await self.has_node(source_node_id):
SQL = SQL_TEMPLATES["get_node_edges"].format(
workspace=self.db.workspace, source_node_id=source_node_id
)
res = await self.db.query(sql=SQL, multirows=True)
SQL = SQL_TEMPLATES["get_node_edges"]
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
# print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -525,6 +587,29 @@ class OracleGraphStorage(BaseGraphStorage):
# print("Node Edge not exist!",self.db.workspace, source_node_id)
return []
async def get_all_nodes(self, limit: int):
"""查询所有节点"""
SQL = SQL_TEMPLATES["get_all_nodes"]
params = {"workspace": self.db.workspace, "limit": str(limit)}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
return res
async def get_all_edges(self, limit: int):
"""查询所有边"""
SQL = SQL_TEMPLATES["get_all_edges"]
params = {"workspace": self.db.workspace, "limit": str(limit)}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
return res
async def get_statistics(self):
SQL = SQL_TEMPLATES["get_statistics"]
params = {"workspace": self.db.workspace}
res = await self.db.query(sql=SQL, params=params, multirows=True)
if res:
return res
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
@@ -537,20 +622,26 @@ N_T = {
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
id varchar(256)PRIMARY KEY,
id varchar(256),
workspace varchar(1024),
doc_name varchar(1024),
content CLOB,
meta JSON,
content_summary varchar(1024),
content_length NUMBER,
status varchar(256),
chunks_count NUMBER,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL
updatetime TIMESTAMP DEFAULT NULL,
error varchar(4096)
)"""
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id varchar(256) PRIMARY KEY,
id varchar(256),
workspace varchar(1024),
full_doc_id varchar(256),
status varchar(256),
chunk_order_index NUMBER,
tokens NUMBER,
content CLOB,
@@ -592,9 +683,15 @@ TABLES = {
"LIGHTRAG_LLM_CACHE": {
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
id varchar(256) PRIMARY KEY,
send clob,
return clob,
model varchar(1024),
workspace varchar(1024),
cache_mode varchar(256),
model_name varchar(256),
original_prompt clob,
return_value clob,
embedding CLOB,
embedding_shape NUMBER,
embedding_min NUMBER,
embedding_max NUMBER,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL
)"""
@@ -619,82 +716,141 @@ TABLES = {
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
"filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = '{check_id}')
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:1,:2,:3)
""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
USING DUAL
ON (a.id = '{check_id}')
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
values (:1,:2,:3,:4,:5,:6,:7) """,
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
USING DUAL
ON (id = :id and workspace = :workspace)
WHEN NOT MATCHED THEN INSERT
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL
ON (a.id = :id)
WHEN NOT MATCHED THEN
INSERT (workspace,id,original_prompt,return_value,cache_mode)
VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
WHEN MATCHED THEN UPDATE
SET original_prompt = :original_prompt,
return_value = :return_value,
cache_mode = :cache_mode,
updatetime = SYSDATE""",
# SQL for VectorStorage
"entities": """SELECT name as entity_name FROM
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
"chunks": """SELECT id FROM
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
(SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
# SQL for GraphStorage
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
MATCH (a)
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
WHERE a.workspace=:workspace AND a.name=:node_id
COLUMNS (a.name))""",
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
MATCH (a) -[e]-> (b)
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
AND a.name=:source_node_id AND b.name=:target_node_id
COLUMNS (e.source_name,e.target_name) )""",
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b)
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
AND a.name='{node_id}' or b.name = '{node_id}'
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
AND (a.name=:node_id or b.name = :node_id)
COLUMNS (a.name))""",
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
FROM GRAPH_TABLE (lightrag_graph
MATCH (a)
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
WHERE a.workspace=:workspace AND a.name=:node_id
COLUMNS (a.name)
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
WHERE t2.workspace='{workspace}'""",
WHERE t2.workspace=:workspace""",
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b)
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
AND a.name=:source_node_id and b.name = :target_node_id
COLUMNS (e.id,a.name as source_id)
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
"get_node_edges": """SELECT source_name,target_name
FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b)
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
AND a.name='{source_node_id}'
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
AND a.name=:source_node_id
COLUMNS (a.name as source_name,b.name as target_name))""",
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
USING DUAL
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
ON (a.workspace=:workspace and a.name=:name)
WHEN NOT MATCHED THEN
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
values (:1,:2,:3,:4,:5,:6,:7) """,
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
USING DUAL
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
WHEN NOT MATCHED THEN
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
WHEN MATCHED THEN
UPDATE SET
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
"get_all_nodes": """WITH t0 AS (
SELECT name AS id, entity_type AS label, entity_type, description,
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
FROM lightrag_graph_nodes
WHERE workspace = :workspace
ORDER BY createtime DESC fetch first :limit rows only
), t1 AS (
SELECT t0.id, source_chunk_id
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
), t2 AS (
SELECT t1.id, LISTAGG(t2.content, '\n') content
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
GROUP BY t1.id
)
SELECT t0.id, label, entity_type, description, t2.content
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
"get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
t1.weight,t1.DESCRIPTION,t2.content
FROM LIGHTRAG_GRAPH_EDGES t1
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
WHERE t1.workspace=:workspace
order by t1.CREATETIME DESC
fetch first :limit rows only""",
"get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
FROM (
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
UNION
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
)""",
}
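
The thread running through this file's changes is the move from `str.format()` interpolation to Oracle named bind parameters (`:name`) passed as a dict, which python-oracledb binds at execute time. A standalone sketch of that pattern follows; the connection object and table here are placeholders, not part of this diff.

```python
# Hedged sketch of the named-bind style adopted above: :workspace and :id in
# the SQL are resolved from the dict passed to execute(), so values are never
# spliced into the SQL text.
def fetch_doc(connection, workspace: str, doc_id: str):
    sql = (
        "SELECT id, content FROM LIGHTRAG_DOC_FULL "
        "WHERE workspace = :workspace AND id = :id"
    )
    with connection.cursor() as cursor:
        cursor.execute(sql, {"workspace": workspace, "id": doc_id})
        return cursor.fetchone()
```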

1182
minirag/kg/postgres_impl.py Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,136 @@
import asyncio
import sys
import os
import pipmaster as pm
if not pm.is_installed("psycopg-pool"):
pm.install("psycopg-pool")
pm.install("psycopg[binary,pool]")
if not pm.is_installed("asyncpg"):
pm.install("asyncpg")
import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
DB = "rag"
USER = "rag"
PASSWORD = "rag"
HOST = "localhost"
PORT = "15432"
os.environ["AGE_GRAPH_NAME"] = "dickens"
if sys.platform.startswith("win"):
import asyncio.windows_events
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def get_pool():
return await asyncpg.create_pool(
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
min_size=10,
max_size=10,
max_queries=5000,
max_inactive_connection_lifetime=300.0,
)
async def main1():
connection_string = (
f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
)
pool = AsyncConnectionPool(connection_string, open=False)
await pool.open()
try:
conn = await pool.getconn(timeout=10)
async with conn.cursor() as curs:
try:
await curs.execute('SET search_path = ag_catalog, "$user", public')
await curs.execute("SELECT create_graph('dickens-2')")
await conn.commit()
print("create_graph success")
except (
psycopg.errors.InvalidSchemaName,
psycopg.errors.UniqueViolation,
):
print("create_graph already exists")
await conn.rollback()
finally:
pass
db = PostgreSQLDB(
config={
"host": "localhost",
"port": 15432,
"user": "rag",
"password": "rag",
"database": "r1",
}
)
async def query_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace="chunk_entity_relation",
global_config={},
embedding_func=None,
)
graph.db = db
res = await graph.get_node('"A CHRISTMAS CAROL"')
print("Node is: ", res)
res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
print("Edge is: ", res)
res = await graph.get_node_edges('"SCROOGE"')
print("Node Edges are: ", res)
async def create_edge_with_age():
await db.initdb()
graph = PGGraphStorage(
namespace="chunk_entity_relation",
global_config={},
embedding_func=None,
)
graph.db = db
await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
await graph.upsert_edge(
'"THE CRATCHITS"',
'"THE GIRLS"',
edge_data={
"weight": 7.0,
"description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
"keywords": '"family, collective effort"',
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
},
)
res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
print("Edge is: ", res)
async def main():
pool = await get_pool()
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
async with pool.acquire() as conn:
try:
await conn.execute(
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
)
except asyncpg.exceptions.InvalidSchemaNameError:
print("create_graph already exists")
# stmt = await conn.prepare(sql)
row = await conn.fetch(sql)
print("row is: ", row)
row = await conn.fetchrow("select '100'::int + 200 as result")
print(row) # <Record result=300>
if __name__ == "__main__":
asyncio.run(query_with_age())
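
Note that the `fetchrow` call above interpolates its literals directly into the SQL string; asyncpg's native binding uses `$1`-style placeholders instead. A small sketch reusing the `get_pool` helper defined above:

```python
async def add_ints(a: int, b: int) -> int:
    # $1/$2 are asyncpg positional bind parameters; the values are passed
    # separately rather than formatted into the query text.
    pool = await get_pool()
    async with pool.acquire() as conn:
        return await conn.fetchval("select $1::int + $2::int as result", a, b)

# asyncio.run(add_ints(100, 200))  # -> 300
```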

70
minirag/kg/redis_impl.py Normal file
View File

@@ -0,0 +1,70 @@
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
if not pm.is_installed("redis"):
pm.install("redis")
# aioredis is a deprecated library, replaced with redis
from redis.asyncio import Redis
from lightrag.utils import logger
from lightrag.base import BaseKVStorage
import json
@dataclass
class RedisKVStorage(BaseKVStorage):
def __post_init__(self):
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}")
async def all_keys(self) -> list[str]:
keys = await self._redis.keys(f"{self.namespace}:*")
return [key.split(":", 1)[-1] for key in keys]
async def get_by_id(self, id):
data = await self._redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
async def get_by_ids(self, ids, fields=None):
pipe = self._redis.pipeline()
for id in ids:
pipe.get(f"{self.namespace}:{id}")
results = await pipe.execute()
if fields:
# Filter fields if specified
return [
{field: value.get(field) for field in fields if field in value}
if (value := json.loads(result))
else None
for result in results
]
return [json.loads(result) if result else None for result in results]
async def filter_keys(self, data: list[str]) -> set[str]:
pipe = self._redis.pipeline()
for key in data:
pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute()
existing_ids = {data[i] for i, exists in enumerate(results) if exists}
return set(data) - existing_ids
async def upsert(self, data: dict[str, dict]):
pipe = self._redis.pipeline()
for k, v in tqdm_async(data.items(), desc="Upserting"):
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
await pipe.execute()
for k in data:
data[k]["_id"] = k
return data
async def drop(self):
keys = await self._redis.keys(f"{self.namespace}:*")
if keys:
await self._redis.delete(*keys)
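
A hedged round-trip sketch for `RedisKVStorage` against a local Redis. The constructor fields mirror the `BaseKVStorage` dataclass and the `REDIS_URI` value is simply the module's default; both are assumptions for illustration.

```python
import asyncio
import os

async def demo():
    os.environ.setdefault("REDIS_URI", "redis://localhost:6379")
    kv = RedisKVStorage(namespace="demo", global_config={}, embedding_func=None)
    await kv.upsert({"doc-1": {"content": "hello"}, "doc-2": {"content": "world"}})
    print(await kv.get_by_id("doc-1"))               # {'content': 'hello'}
    print(await kv.filter_keys(["doc-1", "doc-3"]))  # {'doc-3'}
    await kv.drop()

# asyncio.run(demo())
```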

665
minirag/kg/tidb_impl.py Normal file
View File

@@ -0,0 +1,665 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Union
import numpy as np
import pipmaster as pm
if not pm.is_installed("pymysql"):
pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from sqlalchemy import create_engine, text
from tqdm import tqdm
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
from lightrag.utils import logger
class TiDB(object):
def __init__(self, config, **kwargs):
self.host = config.get("host", None)
self.port = config.get("port", None)
self.user = config.get("user", None)
self.password = config.get("password", None)
self.database = config.get("database", None)
self.workspace = config.get("workspace", None)
connection_string = (
f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
f"?ssl_verify_cert=true&ssl_verify_identity=true"
)
try:
self.engine = create_engine(connection_string)
logger.info(f"Connected to TiDB database at {self.database}")
except Exception as e:
logger.error(f"Failed to connect to TiDB database at {self.database}")
logger.error(f"TiDB database error: {e}")
raise
async def check_tables(self):
for k, v in TABLES.items():
try:
await self.query(f"SELECT 1 FROM {k}".format(k=k))
except Exception as e:
logger.error(f"Failed to check table {k} in TiDB database")
logger.error(f"TiDB database error: {e}")
try:
# print(v["ddl"])
await self.execute(v["ddl"])
logger.info(f"Created table {k} in TiDB database")
except Exception as e:
logger.error(f"Failed to create table {k} in TiDB database")
logger.error(f"TiDB database error: {e}")
async def query(
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
if params is None:
params = {"workspace": self.workspace}
else:
params.update({"workspace": self.workspace})
with self.engine.connect() as conn, conn.begin():
try:
result = conn.execute(text(sql), params)
except Exception as e:
logger.error(f"Tidb database error: {e}")
print(sql)
print(params)
raise
if multirows:
rows = result.all()
if rows:
data = [dict(zip(result.keys(), row)) for row in rows]
else:
data = []
else:
row = result.first()
if row:
data = dict(zip(result.keys(), row))
else:
data = None
return data
async def execute(self, sql: str, data: list | dict = None):
# logger.info("go into TiDBDB execute method")
try:
with self.engine.connect() as conn, conn.begin():
if data is None:
conn.execute(text(sql))
else:
conn.execute(text(sql), parameters=data)
except Exception as e:
logger.error(f"TiDB database error: {e}")
print(sql)
print(data)
raise
@dataclass
class TiDBKVStorage(BaseKVStorage):
# should pass db object to self.db
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> Union[dict, None]:
"""根据 id 获取 doc_full 数据."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id}
# print("get_by_id:"+SQL)
res = await self.db.query(SQL, params)
if res:
data = res # {"data":res}
# print (data)
return data
else:
return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""根据 id 获取 doc_chunks 数据"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
# print("get_by_ids:"+SQL)
res = await self.db.query(SQL, multirows=True)
if res:
data = res # [{"data":i} for i in res]
# print(data)
return data
else:
return None
async def filter_keys(self, keys: list[str]) -> set[str]:
"""过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace],
id_field=N_ID[self.namespace],
ids=",".join([f"'{id}'" for id in keys]),
)
try:
await self.db.query(SQL)
except Exception as e:
logger.error(f"Tidb database error: {e}")
print(SQL)
res = await self.db.query(SQL, multirows=True)
if res:
exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys])
else:
exist_keys = []
data = set([s for s in keys if s not in exist_keys])
return data
################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if self.namespace == "text_chunks":
list_data = [
{
"__id__": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
merge_sql = SQL_TEMPLATES["upsert_chunk"]
data = []
for item in list_data:
data.append(
{
"id": item["__id__"],
"content": item["content"],
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": f"{item["__vector__"].tolist()}",
"workspace": self.db.workspace,
}
)
await self.db.execute(merge_sql, data)
if self.namespace == "full_docs":
merge_sql = SQL_TEMPLATES["upsert_doc_full"]
data = []
for k, v in self._data.items():
data.append(
{
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
)
await self.db.execute(merge_sql, data)
return left_data
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
logger.info("full doc and chunk data had been saved into TiDB db!")
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
# Use global config value if specified, otherwise use default
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
self.cosine_better_than_threshold = config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
async def query(self, query: str, top_k: int) -> list[dict]:
"""search from tidb vector"""
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
params = {
"embedding_string": embedding_string,
"top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold,
}
results = await self.db.query(
SQL_TEMPLATES[self.namespace], params=params, multirows=True
)
print("vector search result:", results)
if not results:
return []
return results
###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict]):
# ignore, upsert in TiDBKVStorage already
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
if self.namespace == "chunks":
return []
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
list_data = [
{
"id": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = []
for f in tqdm(
asyncio.as_completed(embedding_tasks),
total=len(embedding_tasks),
desc="Generating embeddings",
unit="batch",
):
embeddings = await f
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["content_vector"] = embeddings[i]
if self.namespace == "entities":
data = []
for item in list_data:
param = {
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace,
}
# update entity_id if node inserted by graph_storage_instance before
has = await self.db.query(SQL_TEMPLATES["has_entity"], param)
if has["cnt"] != 0:
await self.db.execute(SQL_TEMPLATES["update_entity"], param)
continue
data.append(param)
if data:
merge_sql = SQL_TEMPLATES["insert_entity"]
await self.db.execute(merge_sql, data)
elif self.namespace == "relationships":
data = []
for item in list_data:
param = {
"id": item["id"],
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f"{item["content_vector"].tolist()}",
"workspace": self.db.workspace,
}
# update relation_id if node inserted by graph_storage_instance before
has = await self.db.query(SQL_TEMPLATES["has_relationship"], param)
if has["cnt"] != 0:
await self.db.execute(SQL_TEMPLATES["update_relationship"], param)
continue
data.append(param)
if data:
merge_sql = SQL_TEMPLATES["insert_relationship"]
await self.db.execute(merge_sql, data)
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
content = entity_name + description
contents = [content]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
sql = SQL_TEMPLATES["upsert_node"]
data = {
"workspace": self.db.workspace,
"name": entity_name,
"entity_type": entity_type,
"description": description,
"source_chunk_id": source_id,
"content": content,
"content_vector": f"{content_vector.tolist()}",
}
await self.db.execute(sql, data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
logger.debug(
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
)
content = keywords + source_name + target_name + description
contents = [content]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["upsert_edge"]
data = {
"workspace": self.db.workspace,
"source_name": source_name,
"target_name": target_name,
"weight": weight,
"keywords": keywords,
"description": description,
"source_chunk_id": source_chunk_id,
"content": content,
"content_vector": f"{content_vector.tolist()}",
}
await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# Query
async def has_node(self, node_id: str) -> bool:
sql = SQL_TEMPLATES["has_entity"]
param = {"name": node_id, "workspace": self.db.workspace}
has = await self.db.query(sql, param)
return has["cnt"] != 0
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
sql = SQL_TEMPLATES["has_relationship"]
param = {
"source_name": source_node_id,
"target_name": target_node_id,
"workspace": self.db.workspace,
}
has = await self.db.query(sql, param)
return has["cnt"] != 0
async def node_degree(self, node_id: str) -> int:
sql = SQL_TEMPLATES["node_degree"]
param = {"name": node_id, "workspace": self.db.workspace}
result = await self.db.query(sql, param)
return result["cnt"]
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree
async def get_node(self, node_id: str) -> Union[dict, None]:
sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
sql = SQL_TEMPLATES["get_edge"]
param = {
"source_name": source_node_id,
"target_name": target_node_id,
"workspace": self.db.workspace,
}
return await self.db.query(sql, param)
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
return data
else:
return []
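
All TiDB statements above flow through SQLAlchemy's `text()` with named (`:name`) parameters bound from a dict, the same pattern as the Oracle refactor. A standalone sketch of that execution path; the connection string and table are placeholders only.

```python
from sqlalchemy import create_engine, text

def count_chunks(connection_string: str, workspace: str) -> int:
    # Named parameters are bound from the dict at execute time.
    engine = create_engine(connection_string)
    with engine.connect() as conn:
        row = conn.execute(
            text("SELECT COUNT(*) AS cnt FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace"),
            {"workspace": workspace},
        ).first()
        return row.cnt if row else 0
```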
N_T = {
"full_docs": "LIGHTRAG_DOC_FULL",
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
"chunks": "LIGHTRAG_DOC_CHUNKS",
"entities": "LIGHTRAG_GRAPH_NODES",
"relationships": "LIGHTRAG_GRAPH_EDGES",
}
N_ID = {
"full_docs": "doc_id",
"text_chunks": "chunk_id",
"chunks": "chunk_id",
"entities": "entity_id",
"relationships": "relation_id",
}
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """
CREATE TABLE LIGHTRAG_DOC_FULL (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`doc_id` VARCHAR(256) NOT NULL,
`workspace` varchar(1024),
`content` LONGTEXT,
`meta` JSON,
`createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
`updatetime` TIMESTAMP DEFAULT NULL,
UNIQUE KEY (`doc_id`)
);
"""
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """
CREATE TABLE LIGHTRAG_DOC_CHUNKS (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`chunk_id` VARCHAR(256) NOT NULL,
`full_doc_id` VARCHAR(256) NOT NULL,
`workspace` varchar(1024),
`chunk_order_index` INT,
`tokens` INT,
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
`updatetime` DATETIME DEFAULT NULL,
UNIQUE KEY (`chunk_id`)
);
"""
},
"LIGHTRAG_GRAPH_NODES": {
"ddl": """
CREATE TABLE LIGHTRAG_GRAPH_NODES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`entity_id` VARCHAR(256),
`workspace` varchar(1024),
`name` VARCHAR(2048),
`entity_type` VARCHAR(1024),
`description` LONGTEXT,
`source_chunk_id` VARCHAR(256),
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
`updatetime` DATETIME DEFAULT NULL,
KEY (`entity_id`)
);
"""
},
"LIGHTRAG_GRAPH_EDGES": {
"ddl": """
CREATE TABLE LIGHTRAG_GRAPH_EDGES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`relation_id` VARCHAR(256),
`workspace` varchar(1024),
`source_name` VARCHAR(2048),
`target_name` VARCHAR(2048),
`weight` DECIMAL,
`keywords` TEXT,
`description` LONGTEXT,
`source_chunk_id` varchar(256),
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
`updatetime` DATETIME DEFAULT NULL,
KEY (`relation_id`)
);
"""
},
"LIGHTRAG_LLM_CACHE": {
"ddl": """
CREATE TABLE LIGHTRAG_LLM_CACHE (
`id` BIGINT PRIMARY KEY AUTO_INCREMENT,
`send` TEXT,
`return` TEXT,
`model` VARCHAR(1024),
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
`updatetime` DATETIME DEFAULT NULL
);
"""
},
}
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
"get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
"get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
"get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
"filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
# SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
"upsert_doc_full": """
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
VALUES (:id, :content, :workspace)
ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
"upsert_chunk": """
INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
ON DUPLICATE KEY UPDATE
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
# SQL for VectorStorage
"entities": """SELECT n.name as entity_name FROM
(SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k
""",
"relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
(SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k
""",
"chunks": """SELECT c.id FROM
(SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k
""",
"has_entity": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
""",
"has_relationship": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
""",
"update_entity": """
UPDATE LIGHTRAG_GRAPH_NODES SET
entity_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
WHERE workspace = :workspace AND name = :name
""",
"update_relationship": """
UPDATE LIGHTRAG_GRAPH_EDGES SET
relation_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
WHERE workspace = :workspace AND source_name = :source_name AND target_name = :target_name
""",
"insert_entity": """
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
VALUES(:id, :name, :content, :content_vector, :workspace)
""",
"insert_relationship": """
INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
""",
# SQL for GraphStorage
"get_node": """
SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector
FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
""",
"get_edge": """
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
""",
"get_node_edges": """
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace
""",
"node_degree": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name)
""",
"upsert_node": """
INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description)
VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description)
ON DUPLICATE KEY UPDATE
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description)
""",
"upsert_edge": """
INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector,
workspace, weight, keywords, description, source_chunk_id)
VALUES(:source_name, :target_name, :content, :content_vector,
:workspace, :weight, :keywords, :description, :source_chunk_id)
ON DUPLICATE KEY UPDATE
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
source_chunk_id = VALUES(source_chunk_id)
""",
}
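The templates above mix two substitution mechanisms: `:name` placeholders bound by the database driver, and `{ids}` / `{table_name}` slots filled with `str.format` before execution. A minimal sketch (not part of the commit) of how a caller might combine the two, assuming a `db.query(sql, params, multirows)` helper like the one used by the storage methods above:

```python
# Minimal sketch: render the {ids} slot of a template, then execute it with named
# parameters via an assumed `db.query` helper as in the storage methods above.
async def get_full_docs_by_ids(db, doc_ids: list[str], workspace: str):
    ids = ",".join(f"'{doc_id}'" for doc_id in doc_ids)           # inline list for the IN (...) clause
    sql = SQL_TEMPLATES["get_by_ids_full_docs"].format(ids=ids)   # fill the format slot first
    params = {"workspace": workspace}                             # :workspace stays a bound parameter
    return await db.query(sql, params, multirows=True)
```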

View File

@@ -1,729 +1,5 @@
import os
import copy
from functools import lru_cache
import json
import aioboto3
import aiohttp
import numpy as np
import ollama
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)
import base64
import struct
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel, Field
from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
return response.choices[0].message.content
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hashing_kv is not None:
await hashing_kv.upsert(
{args_hash: {"return": response.choices[0].message.content, "model": model}}
)
return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response["output"]["message"]["content"][0]["text"],
"model": model,
}
}
)
return response["output"]["message"]["content"][0]["text"]
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True#False
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
return hf_model, hf_tokenizer
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
torch.cuda.empty_cache()
# inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=500, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(input_ids[0]) :], skip_special_tokens=True
)
FINDSTRING = "<|COMPLETE|>"
last_assistant_index = response_text.find(FINDSTRING)
if last_assistant_index != -1:
response_text = response_text[:last_assistant_index + len(FINDSTRING)]
else:
response_text = response_text
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
result = response["message"]["content"]
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
return result
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=ChatTemplateConfig(model_name=chat_template)
if chat_template
else None,
log_level="WARNING",
)
return lmdeploy_pipe
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
return response
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def siliconcloud_embedding(
texts: list[str],
model: str = "netease-youdao/bce-embedding-base_v1",
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if "code" in content:
raise ValueError(content)
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embedding(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
embed_model.to('cuda:0')
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids.cuda()
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings.detach().cpu().numpy()
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts:
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text
from pydantic import BaseModel, Field
class Model(BaseModel):
@@ -793,6 +69,8 @@ class MultiModel:
self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("keyword_extraction", None)
kwargs.pop("mode", None)
next_model = self._next_model()
args = dict(
prompt=prompt,
@@ -809,6 +87,8 @@ if __name__ == "__main__":
import asyncio
async def main():
from lightrag.llm.openai import gpt_4o_mini_complete
result = await gpt_4o_mini_complete("How are you?")
print(result)

0
minirag/llm/__init__.py Normal file
View File

189
minirag/llm/azure_openai.py Normal file
View File

@@ -0,0 +1,189 @@
"""
Azure OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with Azure OpenAI's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.azure_openai import azure_openai_complete, azure_openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from openai import (
AsyncAzureOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
)
import numpy as np
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
api_version=None,
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
if len(chunk.choices) == 0:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_embed(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
api_version: str = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
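A minimal usage sketch for the module above (not part of the commit); it assumes the `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_API_VERSION` environment variables are set and that `LLM_MODEL` names an existing chat deployment:

```python
# Hedged example: environment variables and deployment names are assumptions.
import asyncio
from minirag.llm.azure_openai import azure_openai_complete, azure_openai_embed

async def demo():
    answer = await azure_openai_complete("Summarize retrieval-augmented generation in one sentence.")
    print(answer)
    vectors = await azure_openai_embed(["hello world", "minirag"])
    print(vectors.shape)  # (2, embedding_dim)

asyncio.run(demo())
```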

225
minirag/llm/bedrock.py Normal file
View File

@@ -0,0 +1,225 @@
"""
Bedrock LLM Interface Module
==========================
This module provides interfaces for interacting with Bedrock's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- aioboto3, tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.bedrock import bedrock_complete, bedrock_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import copy
import os
import json
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aioboto3"):
pm.install("aioboto3")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import aioboto3
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
locate_json_string_body_from_string,
)
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
kwargs.pop("hashing_kv", None)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
return response["output"]["message"]["content"][0]["text"]
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embed(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
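A minimal usage sketch (not part of the commit); it assumes valid AWS credentials for Bedrock are already exported as `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`:

```python
# Hedged example: credentials and model access are assumptions.
import asyncio
from minirag.llm.bedrock import bedrock_complete, bedrock_embed

async def demo():
    text = await bedrock_complete("What is a knowledge graph?")   # claude-3-haiku by default
    print(text)
    vectors = await bedrock_embed(["hello world"])                # amazon.titan-embed-text-v2:0 by default
    print(vectors.shape)

asyncio.run(demo())
```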

188
minirag/llm/hf.py Normal file
View File

@@ -0,0 +1,188 @@
"""
Hugging Face LLM Interface Module
==========================
This module provides interfaces for interacting with Hugging Face's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- transformers
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.hf import hf_model_complete, hf_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import copy
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("transformers"):
pm.install("transformers")
if not pm.is_installed("torch"):
pm.install("torch")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from transformers import AutoTokenizer, AutoModelForCausalLM
from functools import lru_cache
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from lightrag.utils import (
locate_json_string_body_from_string,
)
import torch
import numpy as np
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
return hf_model, hf_tokenizer
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def hf_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
**kwargs,
) -> str:
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
kwargs.pop("hashing_kv", None)
input_prompt = ""
try:
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try:
ori_message = copy.deepcopy(messages)
if messages[0]["role"] == "system":
messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
return response_text
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
result = await hf_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
device = next(embed_model.parameters()).device
input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids.to(device)
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
if embeddings.dtype == torch.bfloat16:
return embeddings.detach().to(torch.float32).cpu().numpy()
else:
return embeddings.detach().cpu().numpy()
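A minimal usage sketch for `hf_embed` (not part of the commit); the model id below is only an illustration, any encoder that exposes `last_hidden_state` works:

```python
# Hedged example: the model id is an assumption used for illustration only.
import asyncio
from transformers import AutoTokenizer, AutoModel
from minirag.llm.hf import hf_embed

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

vectors = asyncio.run(hf_embed(["hello world", "minirag"], tokenizer, embed_model))
print(vectors.shape)  # (2, hidden_size)
```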

86
minirag/llm/jina.py Normal file
View File

@@ -0,0 +1,86 @@
"""
Jina Embedding Interface Module
==========================
This module provides interfaces for interacting with the Jina embedding API,
including embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added embedding generation
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.jina import jina_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import numpy as np
import aiohttp
async def fetch_data(url, headers, data):
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
response_json = await response.json()
data_list = response_json.get("data", [])
return data_list
async def jina_embed(
texts: list[str],
dimensions: int = 1024,
late_chunking: bool = False,
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["JINA_API_KEY"] = api_key
url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
}
data = {
"model": "jina-embeddings-v3",
"normalized": True,
"embedding_type": "float",
"dimensions": f"{dimensions}",
"late_chunking": late_chunking,
"input": texts,
}
data_list = await fetch_data(url, headers, data)
return np.array([dp["embedding"] for dp in data_list])

191
minirag/llm/lmdeploy.py Normal file
View File

@@ -0,0 +1,191 @@
"""
LMDeploy LLM Interface Module
==========================
This module provides interfaces for interacting with LMDeploy's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.lmdeploy import lmdeploy_model_complete, lmdeploy_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy[all]")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from functools import lru_cache
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=(
ChatTemplateConfig(model_name=chat_template) if chat_template else None
),
log_level="WARNING",
)
return lmdeploy_pipe
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
kwargs.pop("hashing_kv", None)
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
return response
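A minimal usage sketch (not part of the commit); it assumes lmdeploy is installed with GPU support and uses one of the Hugging Face model ids mentioned in the docstring above:

```python
# Hedged example: model id and generation settings are illustrative assumptions.
import asyncio
from minirag.llm.lmdeploy import lmdeploy_model_if_cache

async def demo():
    reply = await lmdeploy_model_if_cache(
        "internlm/internlm-chat-7b",          # model id taken from the docstring above
        "What is retrieval-augmented generation?",
        system_prompt="You are a concise assistant.",
        max_tokens=256,                       # popped and mapped to max_new_tokens
    )
    print(reply)

asyncio.run(demo())
```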

224
minirag/llm/lollms.py Normal file
View File

@@ -0,0 +1,224 @@
"""
LoLLMs (Lord of Large Language Models) Interface Module
=====================================================
This module provides the official interface for interacting with LoLLMs (Lord of Large Language and multimodal Systems),
a unified framework for AI model interaction and deployment.
LoLLMs is designed as a "one tool to rule them all" solution, providing seamless integration
with various AI models while maintaining high performance and user-friendly interfaces.
Author: ParisNeo
Created: 2024-01-24
License: Apache 2.0
Copyright (c) 2024 ParisNeo
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Version: 2.0.0
Change Log:
- 2.0.0 (2024-01-24):
* Added async support for model inference
* Implemented streaming capabilities
* Added embedding generation functionality
* Enhanced parameter handling
* Improved error handling and timeout management
Dependencies:
- aiohttp
- numpy
- Python >= 3.10
Features:
- Async text generation with streaming support
- Embedding generation
- Configurable model parameters
- System prompt and chat history support
- Timeout handling
- API key authentication
Usage:
from llm_interfaces.lollms import lollms_model_complete, lollms_embed
Project Repository: https://github.com/ParisNeo/lollms
Documentation: https://github.com/ParisNeo/lollms/docs
"""
__version__ = "1.0.0"
__author__ = "ParisNeo"
__status__ = "Production"
__project_url__ = "https://github.com/ParisNeo/lollms"
__doc_url__ = "https://github.com/ParisNeo/lollms/docs"
import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aiohttp"):
pm.install("aiohttp")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from typing import Union, List
import numpy as np
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def lollms_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url="http://localhost:9600",
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Client implementation for lollms generation."""
stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
if api_key
else {"Content-Type": "application/json"}
)
# Extract lollms specific parameters
request_data = {
"prompt": prompt,
"model_name": model,
"personality": kwargs.get("personality", -1),
"n_predict": kwargs.get("n_predict", None),
"stream": stream,
"temperature": kwargs.get("temperature", 0.1),
"top_k": kwargs.get("top_k", 50),
"top_p": kwargs.get("top_p", 0.95),
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
"repeat_last_n": kwargs.get("repeat_last_n", 40),
"seed": kwargs.get("seed", None),
"n_threads": kwargs.get("n_threads", 8),
}
# Prepare the full prompt including history
full_prompt = ""
if system_prompt:
full_prompt += f"{system_prompt}\n"
for msg in history_messages:
full_prompt += f"{msg['role']}: {msg['content']}\n"
full_prompt += prompt
request_data["prompt"] = full_prompt
timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
if stream:
async def inner():
async with session.post(
f"{base_url}/lollms_generate", json=request_data
) as response:
async for line in response.content:
yield line.decode().strip()
return inner()
else:
async with session.post(
f"{base_url}/lollms_generate", json=request_data
) as response:
return await response.text()
async def lollms_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
# Extract and remove keyword_extraction from kwargs if present
keyword_extraction = kwargs.pop("keyword_extraction", None)
# Get model name from config
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
# If keyword extraction is needed, we might need to modify the prompt
# or add specific parameters for JSON output (if lollms supports it)
if keyword_extraction:
# Note: You might need to adjust this based on how lollms handles structured output
pass
return await lollms_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:
"""
Generate embeddings for a list of texts using lollms server.
Args:
texts: List of strings to embed
embed_model: Model name (not used directly as lollms uses configured vectorizer)
base_url: URL of the lollms server
**kwargs: Additional arguments passed to the request
Returns:
np.ndarray: Array of embeddings
"""
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": api_key}
if api_key
else {"Content-Type": "application/json"}
)
async with aiohttp.ClientSession(headers=headers) as session:
embeddings = []
for text in texts:
request_data = {"text": text}
async with session.post(
f"{base_url}/lollms_embed",
json=request_data,
) as response:
result = await response.json()
embeddings.append(result["vector"])
return np.array(embeddings)
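A minimal usage sketch (not part of the commit); it assumes a LoLLMs server is listening on http://localhost:9600 with a model and a vectorizer configured (the model name below is hypothetical):

```python
# Hedged example: server address and model name are assumptions.
import asyncio
from minirag.llm.lollms import lollms_model_if_cache, lollms_embed

async def demo():
    reply = await lollms_model_if_cache(
        "my-local-model",                       # hypothetical model name on the server
        "Give one use case for a knowledge graph.",
        base_url="http://localhost:9600",
    )
    print(reply)
    vectors = await lollms_embed(["hello world"], base_url="http://localhost:9600")
    print(vectors.shape)

asyncio.run(demo())
```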

108
minirag/llm/nvidia_openai.py Normal file
View File

@@ -0,0 +1,108 @@
"""
NVIDIA OpenAI-Compatible LLM Interface Module
==========================
This module provides interfaces for interacting with NVIDIA's OpenAI-compatible API,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.nvidia_openai import nvidia_openai_model_complete, nvidia_openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
)
import numpy as np
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def nvidia_openai_embed(
texts: list[str],
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
base_url: str = "https://integrate.api.nvidia.com/v1",
api_key: str = None,
    input_type: str = "passage",  # "query" for retrieval queries, "passage" for documents to embed
trunc: str = "NONE", # NONE or START or END
encode: str = "float", # float or base64
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model,
input=texts,
encoding_format=encode,
extra_body={"input_type": input_type, "truncate": trunc},
)
return np.array([dp.embedding for dp in response.data])
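
A hedged usage sketch for `nvidia_openai_embed`; the import path `minirag.llm.nvidia_openai` and the API key are assumptions/placeholders. It illustrates the `input_type` distinction from the signature: `passage` for documents being indexed, `query` for retrieval queries:

```python
import asyncio

# Assumed import path; the file name of this module is not shown in the excerpt
from minirag.llm.nvidia_openai import nvidia_openai_embed

async def demo_nvidia_embed():
    doc_vecs = await nvidia_openai_embed(
        ["MiniRAG indexes small corpora."],
        api_key="YOUR_NVIDIA_API_KEY",  # placeholder
        input_type="passage",
    )
    query_vecs = await nvidia_openai_embed(
        ["How does MiniRAG index documents?"],
        api_key="YOUR_NVIDIA_API_KEY",  # placeholder
        input_type="query",
    )
    print(doc_vecs.shape, query_vecs.shape)

asyncio.run(demo_nvidia_embed())
```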

158
minirag/llm/ollama.py Normal file
View File

@@ -0,0 +1,158 @@
"""
Ollama LLM Interface Module
==========================
This module provides interfaces for interacting with Ollama's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- ollama
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.ollama_interface import ollama_model_complete, ollama_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("ollama"):
pm.install("ollama")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
import ollama
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
import numpy as np
from typing import Union
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def ollama_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
**kwargs,
) -> Union[str, AsyncIterator[str]]:
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
kwargs.pop("hashing_kv", None)
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
if api_key
else {"Content-Type": "application/json"}
)
ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
if stream:
"""cannot cache stream response"""
async def inner():
async for chunk in response:
yield chunk["message"]["content"]
return inner()
else:
return response["message"]["content"]
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
"""
    Deprecated in favor of `ollama_embed`.
"""
embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts:
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
headers = (
{"Content-Type": "application/json", "Authorization": api_key}
if api_key
else {"Content-Type": "application/json"}
)
kwargs["headers"] = headers
ollama_client = ollama.Client(**kwargs)
data = ollama_client.embed(model=embed_model, input=texts)
return data["embeddings"]
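
As a usage illustration for the two Ollama helpers above, a minimal sketch assuming a local Ollama server and that the named models (placeholders) are already pulled:

```python
import asyncio

from minirag.llm.ollama import ollama_embed, ollama_model_if_cache

async def demo_ollama():
    # Assumes a local Ollama server with these (placeholder) models available
    vectors = await ollama_embed(
        ["a short test sentence"],
        embed_model="nomic-embed-text",
        host="http://localhost:11434",
    )
    answer = await ollama_model_if_cache(
        "qwen2.5:3b",
        "Summarize what a knowledge graph is in one sentence.",
        system_prompt="You are a concise assistant.",
        host="http://localhost:11434",
    )
    print(len(vectors[0]), answer)

asyncio.run(demo_ollama())
```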

230
minirag/llm/openai.py Normal file
View File

@@ -0,0 +1,230 @@
"""
OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with OpenAI's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.openai import openai_complete, openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import os
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
from typing import Union
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Add debug logging of the query sent to the LLM
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
logger.debug("Full context:")
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await openai_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def nvidia_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await openai_complete_if_cache(
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def openai_embed(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])
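
A small usage sketch for the OpenAI helpers above, assuming `OPENAI_API_KEY` is set in the environment; `openai_embed` is already wrapped as an `EmbeddingFunc`, so it can be awaited directly:

```python
import asyncio

from minirag.llm.openai import gpt_4o_mini_complete, openai_embed

async def demo_openai():
    # Assumes OPENAI_API_KEY is set in the environment
    vectors = await openai_embed(["MiniRAG stores chunk embeddings in a vector database."])
    reply = await gpt_4o_mini_complete(
        "Name one benefit of retrieval-augmented generation.",
        system_prompt="Answer in one short sentence.",
    )
    print(vectors.shape, reply)

asyncio.run(demo_openai())
```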

109
minirag/llm/siliconcloud.py Normal file
View File

@@ -0,0 +1,109 @@
"""
SiliconCloud Embedding Interface Module
==========================
This module provides interfaces for interacting with the SiliconCloud system,
including embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added embedding generation
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.siliconcloud import siliconcloud_embedding
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
if sys.version_info < (3, 9):
pass
else:
pass
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("lmdeploy"):
pm.install("lmdeploy")
from openai import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import numpy as np
import aiohttp
import base64
import struct
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def siliconcloud_embedding(
texts: list[str],
model: str = "netease-youdao/bce-embedding-base_v1",
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if "code" in content:
raise ValueError(content)
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
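
A hedged usage sketch for `siliconcloud_embedding`; the API key is a placeholder. Note that `max_token_size` is applied as a character cut (`text[0:max_token_size]`) before the request, and the base64 response is decoded back into float vectors:

```python
import asyncio

from minirag.llm.siliconcloud import siliconcloud_embedding

async def demo_siliconcloud():
    # Placeholder key; "Bearer " is prepended automatically when missing
    vectors = await siliconcloud_embedding(
        ["retrieval-augmented generation", "知识图谱"],
        api_key="YOUR_SILICONFLOW_API_KEY",
    )
    print(vectors.shape)

asyncio.run(demo_siliconcloud())
```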

246
minirag/llm/zhipu.py Normal file
View File

@@ -0,0 +1,246 @@
"""
Zhipu LLM Interface Module
==========================
This module provides interfaces for interacting with Zhipu's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- tenacity
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.zhipu import zhipu_complete, zhipu_embedding
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import re
import json
if sys.version_info < (3, 9):
pass
else:
pass
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("zhipuai"):
pm.install("zhipuai")
from openai import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
from typing import Union, List, Optional, Dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_complete_if_cache(
prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # best cost/performance balance in the glm-4 series
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
**kwargs,
) -> str:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
messages = []
if not system_prompt:
system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Add debug logging
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
# Remove unsupported kwargs
kwargs = {
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
}
response = client.chat.completions.create(model=model, messages=messages, **kwargs)
return response.choices[0].message.content
async def zhipu_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
# Add a system prompt to guide the model to return JSON format
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""
# Combine with existing system prompt if any
if system_prompt:
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
else:
system_prompt = extraction_prompt
try:
response = await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
# Try to parse as JSON
try:
data = json.loads(response)
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
# If direct JSON parsing fails, try to extract JSON from text
match = re.search(r"\{[\s\S]*\}", response)
if match:
try:
data = json.loads(match.group())
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
pass
# If all parsing fails, log warning and return empty format
logger.warning(
f"Failed to parse keyword extraction response: {response}"
)
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
except Exception as e:
logger.error(f"Error during keyword extraction: {str(e)}")
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
else:
# For non-keyword-extraction, just return the raw response string
return await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_embedding(
texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
) -> np.ndarray:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
# Convert single text to list if needed
if isinstance(texts, str):
texts = [texts]
embeddings = []
for text in texts:
try:
response = client.embeddings.create(model=model, input=[text], **kwargs)
embeddings.append(response.data[0].embedding)
except Exception as e:
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
return np.array(embeddings)
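
A minimal usage sketch for the Zhipu helpers above, assuming `ZHIPUAI_API_KEY` is set in the environment; without `keyword_extraction`, `zhipu_complete` returns the raw completion string:

```python
import asyncio

from minirag.llm.zhipu import zhipu_complete, zhipu_embedding

async def demo_zhipu():
    # Assumes ZHIPUAI_API_KEY is set in the environment
    reply = await zhipu_complete("Explain entity extraction in one sentence.")
    vectors = await zhipu_embedding(["knowledge graph indexing"])
    print(reply)
    print(vectors.shape)

asyncio.run(demo_zhipu())
```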

View File

@@ -33,15 +33,31 @@ from .base import (
QueryParam,
)
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
)
from .kg.neo4j_impl import Neo4JStorage
from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.jsondocstatus_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorge": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
}
# future KG integrations
@@ -49,17 +65,50 @@ from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBS
# GraphStorage as ArangoDBStorage
# )
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller."""
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
import importlib
module = importlib.import_module(module_name, package=package)
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
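
To illustrate the lazy-loading pattern, the sketch below resolves a backend by name the same way `__post_init__` does further down; it only works where the relative path (e.g. `.kg.json_kv_impl`) can be resolved, i.e. from code inside the `minirag` package, and the constructor arguments mirror those used by the class:

```python
# Resolve a storage backend lazily by name (illustrative; assumes execution
# inside the minirag package so the relative module path can be resolved).
storage_name = "JsonKVStorage"
import_path = STORAGES[storage_name]                 # -> ".kg.json_kv_impl"
storage_factory = lazy_external_import(import_path, storage_name)

# The backend module is imported only when the factory is first called.
kv_store = storage_factory(
    namespace="full_docs",
    global_config={"working_dir": "./rag_storage"},
    embedding_func=None,
)
```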
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
This function tries to get the current event loop. If the current event loop is closed or does not exist,
it creates a new event loop and sets it as the current event loop.
Returns:
asyncio.AbstractEventLoop: The current or newly created event loop.
"""
try:
return asyncio.get_event_loop()
# Try to get the current event loop
current_loop = asyncio.get_event_loop()
if current_loop.is_closed():
raise RuntimeError("Event loop is closed.")
return current_loop
except RuntimeError:
# If no event loop exists or it is closed, create a new one
logger.info("Creating a new event loop in main thread.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
return new_loop
return loop
@dataclass
@@ -100,12 +149,12 @@ class MiniRAG:
}
)
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_func: EmbeddingFunc = None
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
llm_model_func: callable = hf_model_complete#gpt_4o_mini_complete #
llm_model_func: callable = None
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
@@ -120,27 +169,55 @@ class MiniRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
# Add new field for document status storage type
doc_status_storage: str = field(default="JsonDocStatusStorage")
# Custom Chunking Function
chunking_func: callable = chunking_by_token_size
chunking_func_kwargs: dict = field(default_factory=dict)
def __post_init__(self):
log_file = os.path.join(self.working_dir, "minirag.log")
set_logger(log_file)
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
# show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"MiniRAG init with param:\n {_print_config}\n")
# @TODO: should move all storage setup here to leverage initial start params attached to self.
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class()[self.kv_storage]
self._get_storage_class(self.kv_storage)
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
self.vector_storage
]
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
self.graph_storage
]
)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls, global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls, global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls, global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace="json_doc_status_storage",
embedding_func=None,
)
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
@@ -218,21 +295,39 @@ class MiniRAG:
**self.llm_model_kwargs,
)
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls(
namespace="doc_status",
global_config=global_config,
embedding_func=None,
)
def _get_storage_class(self, storage_name: str) -> dict:
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
def set_storage_client(self, db_client):
# Now only tested on Oracle Database
for storage in [
self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status,
self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache,
]:
# set client
storage.db = db_client
def _get_storage_class(self) -> Type[BaseGraphStorage]:
return {
# kv storage
"JsonKVStorage": JsonKVStorage,
"OracleKVStorage": OracleKVStorage,
# vector storage
"NanoVectorDBStorage": NanoVectorDBStorage,
"OracleVectorDBStorage": OracleVectorDBStorage,
# graph storage
"NetworkXStorage": NetworkXStorage,
"Neo4JStorage": Neo4JStorage,
"OracleGraphStorage": OracleGraphStorage,
# "ArangoDBStorage": ArangoDBStorage
}
def insert(self, string_or_strings):
loop = always_get_an_event_loop()
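
Putting the refactor together, here is a hedged sketch of configuring the class with string-named backends; the field names `kv_storage`, `vector_storage` and `graph_storage` are inferred from the code above, and the top-level import of `MiniRAG` assumes the class is re-exported from the package root:

```python
from minirag import MiniRAG  # assumed re-export from the package root
from minirag.llm.openai import gpt_4o_mini_complete, openai_embed

rag = MiniRAG(
    working_dir="./rag_storage",
    kv_storage="JsonKVStorage",            # resolved lazily via STORAGES
    vector_storage="NanoVectorDBStorage",
    graph_storage="NetworkXStorage",
    llm_model_func=gpt_4o_mini_complete,
    embedding_func=openai_embed,
)
rag.insert("MiniRAG builds a small heterogeneous graph from the inserted text.")
```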

View File

@@ -1,354 +0,0 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
import copy
from .utils import (
logger,
load_json,
write_json,
compute_mdhash_id,
merge_tuples,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
)
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data = load_json(self._file_name) or {}
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def all_keys(self) -> list[str]:
return list(self._data.keys())
async def index_done_callback(self):
write_json(self._data, self._file_name)
async def get_by_id(self, id):
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
if fields is None:
return [self._data.get(id, None) for id in ids]
return [
(
{k: v for k, v in self._data[id].items() if k in fields}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, data: list[str]) -> set[str]:
return set([s for s in data if s not in self._data])
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
return left_data
async def drop(self):
self._data = {}
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.#2
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
self.cosine_better_than_threshold = self.global_config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)
async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
list_data = [
{
"__id__": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._client.query(
query=embedding,
top_k=top_k,
better_than_threshold=self.cosine_better_than_threshold,
)
results = [
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
]
return results
@property
def client_storage(self):
return getattr(self._client, "_NanoVectorDB__storage")
async def delete_entity(self, entity_name: str):
try:
entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
if self._client.get(entity_id):
self._client.delete(entity_id)
logger.info(f"Entity {entity_name} have been deleted.")
else:
logger.info(f"No entity found with name {entity_name}.")
except Exception as e:
logger.error(f"Error while deleting entity {entity_name}: {e}")
async def delete_relation(self, entity_name: str):
try:
relations = [
dp
for dp in self.client_storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
self._client.delete(ids_to_delete)
logger.info(
f"All relations related to entity {entity_name} have been deleted."
)
else:
logger.info(f"No relations found for entity {entity_name}.")
except Exception as e:
logger.error(
f"Error while deleting relations for entity {entity_name}: {e}"
)
async def index_done_callback(self):
self._client.save()
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
def load_nx_graph(file_name) -> nx.Graph:
if os.path.exists(file_name):
return nx.read_graphml(file_name)
return None
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
nx.write_graphml(graph, file_name)
@staticmethod
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
"""
from graspologic.utils import largest_connected_component
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
node_mapping = {
node: html.unescape(node.upper().strip()) for node in graph.nodes()
} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
"""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
)
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
self._graph = preloaded_graph or nx.Graph()
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
async def index_done_callback(self):
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
async def has_node(self, node_id: str) -> bool:
return self._graph.has_node(node_id)
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
return self._graph.has_edge(source_node_id, target_node_id)
async def get_node(self, node_id: str) -> Union[dict, None]:
return self._graph.nodes.get(node_id)
async def get_types(self) -> list:
all_entity_type = []
all_type_w_name = {}
for n in self._graph.nodes(data=True):
key = n[1]['entity_type'].strip('\"')
all_entity_type.append(key)
if key not in all_type_w_name:
all_type_w_name[key] = []
all_type_w_name[key].append(n[0].strip('\"'))
else:
if len(all_type_w_name[key])<=1:
all_type_w_name[key].append(n[0].strip('\"'))
return list(set(all_entity_type)),all_type_w_name
async def get_node_from_types(self,type_list) -> Union[dict, None]:
node_list = []
for name, arrt in self._graph.nodes(data = True):
node_type = arrt.get('entity_type').strip('\"')
if node_type in type_list:
node_list.append(name)
node_datas = await asyncio.gather(
*[self.get_node(name) for name in node_list]
)
node_datas = [
{**n, "entity_name": k}
for k, n in zip(node_list, node_datas)
if n is not None
]
return node_datas#,node_dict
async def get_neighbors_within_k_hops(self,source_node_id: str, k):
count = 0
if await self.has_node(source_node_id):
source_edge = list(self._graph.edges(source_node_id))
else:
print("NO THIS ID:",source_node_id)
return []
count = count+1
while count<k:
count = count+1
sc_edge = copy.deepcopy(source_edge)
source_edge =[]
for pair in sc_edge:
append_edge = list(self._graph.edges(pair[-1]))
for tuples in merge_tuples([pair],append_edge):
source_edge.append(tuples)
return source_edge
async def node_degree(self, node_id: str) -> int:
return self._graph.degree(node_id)
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
return self._graph.edges.get((source_node_id, target_node_id))
async def get_node_edges(self, source_node_id: str):
if self._graph.has_node(source_node_id):
return list(self._graph.edges(source_node_id))
return None
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
self._graph.add_node(node_id, **node_data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def delete_node(self, node_id: str):
"""
Delete a node from the graph based on the specified node_id.
:param node_id: The node_id to delete
"""
if self._graph.has_node(node_id):
self._graph.remove_node(node_id)
logger.info(f"Node {node_id} deleted from the graph.")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# @TODO: NOT USED
async def _node2vec_embed(self):
from graspologic import embed
embeddings, nodes = embed.node2vec_embed(
self._graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids

View File

@@ -1,38 +1,33 @@
accelerate
aioboto3
aiofiles
aiohttp
asyncpg
configparser
graspologic
# database packages
graspologic
gremlinpython
hnswlib
nano-vectordb
neo4j
networkx
# Basic modules
numpy
ollama
openai
oracledb
psycopg-pool
psycopg[binary,pool]
pipmaster
pydantic
pymilvus
pymongo
pymysql
# File manipulation libraries
PyPDF2
python-docx
python-dotenv
pyvis
python-pptx
setuptools
# lmdeploy[all]
sqlalchemy
tenacity
json_repair
rouge
nltk
# LLM packages
tiktoken
torch
tqdm
transformers
xxhash
xxhash
# Extra libraries are installed when needed using pipmaster

View File

@@ -22,6 +22,15 @@ def read_requirements():
)
return deps
def read_api_requirements():
api_deps = []
try:
with open("./minirag/api/requirements.txt") as f:
api_deps = [line.strip() for line in f if line.strip()]
except FileNotFoundError:
print("Warning: API requirements.txt not found.")
return api_deps
long_description = read_long_description()
requirements = read_requirements()
@@ -29,7 +38,7 @@ requirements = read_requirements()
setuptools.setup(
name="minirag-hku",
url="https://github.com/HKUDS/MiniRAG",
version="0.0.1",
version="0.0.2",
author="HKUDS",
description="MiniRAG: Towards Extremely Simple Retrieval-Augmented Generation",
long_description=long_description,
@@ -48,4 +57,12 @@ setuptools.setup(
python_requires=">=3.9",#rec: 3.9.19
install_requires=requirements,
include_package_data=True, # Includes non-code files from MANIFEST.in
extras_require={
"api": read_api_requirements(), # API requirements as optional
},
entry_points={
"console_scripts": [
"minirag-server=minirag.api.minirag_server:main [api]",
],
},
)