Brought all lightrag structure to the minirag app
This commit is contained in:
26 .gitignore vendored Normal file
@@ -0,0 +1,26 @@
__pycache__
*.egg-info
dickens/
book.txt
lightrag-dev/
.idea/
dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
.venv/
*.ignore.*
.ruff_cache/
gui/
*.log
.vscode
inputs
rag_storage
.env
venv/
examples/input/
examples/output/
.DS_Store
# Remove config.ini from repo
*.ini
22 .pre-commit-config.yaml Normal file
@@ -0,0 +1,22 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: requirements-txt-fixer


  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.4
    hooks:
      - id: ruff-format
      - id: ruff
        args: [--fix, --ignore=E402]


  - repo: https://github.com/mgedmin/check-manifest
    rev: "0.49"
    hooks:
      - id: check-manifest
        stages: [manual]
BIN assets/logo.png Normal file
Binary file not shown.
After Width: | Height: | Size: 1.8 MiB
7 minirag/api/.env.aoi.example Normal file
@@ -0,0 +1,7 @@
AZURE_OPENAI_API_VERSION=2024-08-01-preview
AZURE_OPENAI_DEPLOYMENT=gpt-4o
AZURE_OPENAI_API_KEY=myapikey
AZURE_OPENAI_ENDPOINT=https://myendpoint.openai.azure.com

AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
AZURE_EMBEDDING_API_VERSION=2023-05-15
2 minirag/api/.gitignore vendored Normal file
@@ -0,0 +1,2 @@
inputs
rag_storage
475 minirag/api/README.md Normal file
@@ -0,0 +1,475 @@
## Install with API Support

LightRAG provides optional API support through FastAPI servers that add RAG capabilities to existing LLM services. You can install LightRAG with API support in two ways:

### 1. Installation from PyPI

```bash
pip install "lightrag-hku[api]"
```

### 2. Installation from Source (Development)

```bash
# Clone the repository
git clone https://github.com/HKUDS/lightrag.git

# Change to the repository directory
cd lightrag

# Create a Python virtual environment if necessary
# Install in editable mode with API support
pip install -e ".[api]"
```

### Prerequisites

Before running any of the servers, ensure you have the corresponding backend service running for both the LLM and the embedding model.
The API allows you to mix different bindings for the LLM and the embeddings: for example, you can use ollama for the embedding model and openai for the LLM.

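As a concrete illustration of mixing bindings, the sketch below starts the server with openai for the LLM and ollama for the embeddings, using only flags documented in the options table further down. The model names and the API key are placeholders, not recommendations:

```bash
# Hypothetical mixed setup: openai LLM + ollama embeddings (adjust models and keys to your environment)
lightrag-server \
    --llm-binding openai --llm-model gpt-4o-mini --llm-binding-api-key sk-your-key \
    --embedding-binding ollama --embedding-model bge-m3:latest --embedding-dim 1024
```
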
#### For LoLLMs Server
- LoLLMs must be running and accessible
- Default connection: http://localhost:9600
- Configure using --llm-binding-host and/or --embedding-binding-host if running on a different host/port

#### For Ollama Server
- Ollama must be running and accessible
- Requires environment variables or command line arguments to be provided
- Environment variables: LLM_BINDING=ollama, LLM_BINDING_HOST, LLM_MODEL
- Command line arguments: --llm-binding=ollama, --llm-binding-host, --llm-model
- Default connection is http://localhost:11434 if not provided

> The default MAX_TOKENS (num_ctx) for Ollama is 32768. If your Ollama server is short on GPU memory, set it to a lower value.

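For instance, a lower context size can be set either on the command line or through the environment; both knobs are listed in the configuration sections below (the value 8192 is only an illustrative choice):

```bash
# Option 1: command line flag
lightrag-server --max-tokens 8192

# Option 2: environment variable (also settable in .env)
MAX_TOKENS=8192 lightrag-server
```
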
#### For OpenAI Alike Server
- Requires environment variables or command line arguments to be provided
- Environment variables: LLM_BINDING=openai, LLM_BINDING_HOST, LLM_MODEL, LLM_BINDING_API_KEY
- Command line arguments: --llm-binding=openai, --llm-binding-host, --llm-model, --llm-binding-api-key
- Default connection is https://api.openai.com/v1 if not provided

#### For Azure OpenAI Server
An Azure OpenAI resource can be created with the following Azure CLI commands (install the Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
```bash
# Change the resource group name, location and OpenAI resource name as needed
RESOURCE_GROUP_NAME=LightRAG
LOCATION=swedencentral
RESOURCE_NAME=LightRAG-OpenAI

az login
az group create --name $RESOURCE_GROUP_NAME --location $LOCATION
az cognitiveservices account create --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --kind OpenAI --sku S0 --location swedencentral
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name gpt-4o --model-name gpt-4o --model-version "2024-08-06" --sku-capacity 100 --sku-name "Standard"
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name text-embedding-3-large --model-name text-embedding-3-large --model-version "1" --sku-capacity 80 --sku-name "Standard"
az cognitiveservices account show --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --query "properties.endpoint"
az cognitiveservices account keys list --name $RESOURCE_NAME -g $RESOURCE_GROUP_NAME
```
The output of the last command gives you the endpoint and the key for the OpenAI API. Use these values to set the environment variables in the `.env` file:

```
LLM_BINDING=azure_openai
LLM_BINDING_HOST=endpoint_of_azure_ai
LLM_MODEL=model_name_of_azure_ai
LLM_BINDING_API_KEY=api_key_of_azure_ai
```

### About Ollama API

We provide an Ollama-compatible interface for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends that support Ollama, such as Open WebUI, to access LightRAG easily.

#### Choose Query mode in chat

A query prefix in the query string determines which LightRAG query mode is used to generate the response. The supported prefixes are:

/local
/global
/hybrid
/naive
/mix

For example, the chat message "/mix 唐僧有几个徒弟" ("How many disciples does Tang Seng have?") triggers a mix mode query in LightRAG. A chat message without a query prefix triggers a hybrid mode query by default.

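If you want to check the prefix handling without a chat frontend, the same kind of message can be sent straight to the Ollama-compatible chat endpoint described under "Ollama Emulation Endpoints" below; the query text here is only an example:

```bash
# Ask LightRAG in /local mode through the Ollama-compatible API
curl -N -X POST http://localhost:9721/api/chat -H "Content-Type: application/json" -d \
    '{"model":"lightrag:latest","messages":[{"role":"user","content":"/local Who are the main characters?"}],"stream":true}'
```
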
#### Connect Open WebUI to LightRAG

After starting lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.

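Before wiring up Open WebUI, it can be worth confirming that the emulated Ollama endpoints answer at all; the two calls below are the same ones documented in the "Ollama Emulation Endpoints" section:

```bash
# Should report the emulated Ollama version and list lightrag:latest
curl http://localhost:9721/api/version
curl http://localhost:9721/api/tags
```
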
## Configuration

LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.

For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD stayed at LightRAG's default of 0.2, many irrelevant entities and relations would be retrieved and sent to the LLM.

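Both values can be overridden like any other setting; the numbers below are arbitrary examples, not tuned recommendations:

```bash
# Tighten retrieval via flags...
lightrag-server --top-k 30 --cosine-threshold 0.5

# ...or via environment variables / .env
TOP_K=30 COSINE_THRESHOLD=0.5 lightrag-server
```
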
### Environment Variables

You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of the available environment variables:

```env
# Server Configuration
HOST=0.0.0.0
PORT=9721

# Directory Configuration
WORKING_DIR=/app/data/rag_storage
INPUT_DIR=/app/data/inputs

# RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4
#TOP_K=50

# LLM Configuration
LLM_BINDING=ollama
LLM_BINDING_HOST=http://localhost:11434
LLM_MODEL=mistral-nemo:latest

# Must be set if using an OpenAI LLM (LLM_MODEL must also be set, here or via command line parameters)
OPENAI_API_KEY=your_api_key

# Embedding Configuration
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://localhost:11434
EMBEDDING_MODEL=bge-m3:latest

# Security
#LIGHTRAG_API_KEY=your-api-key-for-accessing-LightRAG

# Logging
LOG_LEVEL=INFO

# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem

# Optional Timeout
#TIMEOUT=30
```

### Configuration Priority

The configuration values are loaded in the following order (highest priority first):
1. Command-line arguments
2. Environment variables
3. Default values

For example:
```bash
# This command-line argument will override both the environment variable and the default value
python lightrag.py --port 8080

# The environment variable will override the default value but not the command-line argument
PORT=7000 python lightrag.py
```

#### LightRag Server Options

| Parameter | Default | Description |
|-----------|---------|-------------|
| --host | 0.0.0.0 | Server host |
| --port | 9721 | Server port |
| --llm-binding | ollama | LLM binding to be used. Supported: lollms, ollama, openai |
| --llm-binding-host | (dynamic) | LLM server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
| --llm-model | mistral-nemo:latest | LLM model name |
| --llm-binding-api-key | None | API Key for OpenAI Alike LLM |
| --embedding-binding | ollama | Embedding binding to be used. Supported: lollms, ollama, openai |
| --embedding-binding-host | (dynamic) | Embedding server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
| --embedding-model | bge-m3:latest | Embedding model name |
| --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum async operations |
| --max-tokens | 32768 | Maximum token size |
| --embedding-dim | 1024 | Embedding dimensions |
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --key | None | API key for authentication. Protects the lightrag server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. |
| --cosine-threshold | 0.4 | The cosine threshold for node and relation retrieval; works with top-k to control the retrieval of nodes and relations. |

### Example Usage

#### Running a Lightrag server with ollama default local server as llm and embedding backends

Ollama is the default backend for both the LLM and the embeddings, so by default you can run lightrag-server with no parameters and the defaults will be used. Make sure ollama is installed, running, and that the default models are already pulled into ollama.

```bash
# Run lightrag with ollama, mistral-nemo:latest for llm, and bge-m3:latest for embedding
lightrag-server

# Using specific models (ensure they are installed in your ollama instance)
lightrag-server --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-model nomic-embed-text --embedding-dim 1024

# Using an authentication key
lightrag-server --key my-key

# Using lollms for llm and ollama for embedding
lightrag-server --llm-binding lollms
```

#### Running a Lightrag server with lollms default local server as llm and embedding backends

```bash
# Run lightrag with lollms, mistral-nemo:latest for llm, and bge-m3:latest for embedding; use lollms for both llm and embedding
lightrag-server --llm-binding lollms --embedding-binding lollms

# Using specific models (ensure they are installed in your lollms instance)
lightrag-server --llm-binding lollms --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-binding lollms --embedding-model nomic-embed-text --embedding-dim 1024

# Using an authentication key
lightrag-server --key my-key

# Using lollms for llm and openai for embedding
lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
```

#### Running a Lightrag server with openai server as llm and embedding backends

```bash
# Run lightrag with openai, GPT-4o-mini for llm, and text-embedding-3-small for embedding; use openai for both llm and embedding
lightrag-server --llm-binding openai --llm-model GPT-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small

# Using an authentication key
lightrag-server --llm-binding openai --llm-model GPT-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small --key my-key

# Using lollms for llm and openai for embedding
lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
```

#### Running a Lightrag server with azure openai server as llm and embedding backends

```bash
# Run lightrag with azure_openai for the llm (GPT-4o-mini) and openai for the embedding (text-embedding-3-small)
lightrag-server --llm-binding azure_openai --llm-model GPT-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small

# Using an authentication key, with azure_openai for both llm and embedding
lightrag-server --llm-binding azure_openai --llm-model GPT-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small --key my-key

# Using lollms for llm and azure_openai for embedding
lightrag-server --llm-binding lollms --embedding-binding azure_openai --embedding-model text-embedding-3-small
```

**Important Notes:**
- For LoLLMs: Make sure the specified models are installed in your LoLLMs instance
- For Ollama: Make sure the specified models are installed in your Ollama instance
- For OpenAI: Ensure you have set up your OPENAI_API_KEY environment variable
- For Azure OpenAI: Build and configure your server as described in the Prerequisites section

For help on any server, use the --help flag:
```bash
lightrag-server --help
```

Note: If you don't need the API functionality, you can install the base package without API support using:
```bash
pip install lightrag-hku
```

## API Endpoints

All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality.

### Query Endpoints

#### POST /query
Query the RAG system with options for different search modes.

```bash
curl -X POST "http://localhost:9721/query" \
    -H "Content-Type: application/json" \
    -d '{"query": "Your question here", "mode": "hybrid"}'
```

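When the server is started with `--key` (or LIGHTRAG_API_KEY), requests have to carry that key. The bundled web UI in `minirag/api/static/js/api.js` sends it as a Bearer token, so an equivalent authenticated call would look roughly like the sketch below (header usage assumed from that UI code, key value is a placeholder):

```bash
curl -X POST "http://localhost:9721/query" \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer my-key" \
    -d '{"query": "Your question here", "mode": "hybrid"}'
```
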
#### POST /query/stream
Stream responses from the RAG system.

```bash
curl -X POST "http://localhost:9721/query/stream" \
    -H "Content-Type: application/json" \
    -d '{"query": "Your question here", "mode": "hybrid"}'
```

### Document Management Endpoints

#### POST /documents/text
Insert text directly into the RAG system.

```bash
curl -X POST "http://localhost:9721/documents/text" \
    -H "Content-Type: application/json" \
    -d '{"text": "Your text content here", "description": "Optional description"}'
```

#### POST /documents/file
Upload a single file to the RAG system.

```bash
curl -X POST "http://localhost:9721/documents/file" \
    -F "file=@/path/to/your/document.txt" \
    -F "description=Optional description"
```

#### POST /documents/batch
Upload multiple files at once.

```bash
curl -X POST "http://localhost:9721/documents/batch" \
    -F "files=@/path/to/doc1.txt" \
    -F "files=@/path/to/doc2.txt"
```

#### POST /documents/scan

Trigger a document scan for new files in the input directory.

```bash
curl -X POST "http://localhost:9721/documents/scan" --max-time 1800
```

> Adjust max-time according to the estimated indexing time for all new files.

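A typical refresh cycle is therefore to drop new files into the input directory, trigger a scan, and confirm the result via /health. The paths below assume the default `./inputs` directory and a local server; adjust them to your setup:

```bash
# Add new documents, re-scan, then check the health/indexing status
cp ~/new-docs/*.txt ./inputs/
curl -X POST "http://localhost:9721/documents/scan" --max-time 1800
curl -s "http://localhost:9721/health"
```
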
### Ollama Emulation Endpoints

#### GET /api/version

Get Ollama version information.

```bash
curl http://localhost:9721/api/version
```

#### GET /api/tags

Get the available Ollama models.

```bash
curl http://localhost:9721/api/tags
```

#### POST /api/chat

Handle chat completion requests.

```shell
curl -N -X POST http://localhost:9721/api/chat -H "Content-Type: application/json" -d \
    '{"model":"lightrag:latest","messages":[{"role":"user","content":"猪八戒是谁"}],"stream":true}'
```

> For more information about the Ollama API, please visit the [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md).

#### DELETE /documents

Clear all documents from the RAG system.

```bash
curl -X DELETE "http://localhost:9721/documents"
```

### Utility Endpoints

#### GET /health
Check server health and configuration.

```bash
curl "http://localhost:9721/health"
```

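The exact response schema is defined by the server implementation, but the bundled web UI (`minirag/api/static/js/api.js`) reads at least `status`, `working_directory`, `input_directory`, `indexed_files`, `indexed_files_count` and `configuration` from it, so a quick smoke test can simply pretty-print the JSON:

```bash
# Pretty-print the health payload (assumes python3 is available)
curl -s "http://localhost:9721/health" | python3 -m json.tool
```
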
## Development
Contribute to the project: [Guide](contributor-readme.MD)

### Running in Development Mode

For LoLLMs:
```bash
uvicorn lollms_lightrag_server:app --reload --port 9721
```

For Ollama:
```bash
uvicorn ollama_lightrag_server:app --reload --port 9721
```

For OpenAI:
```bash
uvicorn openai_lightrag_server:app --reload --port 9721
```

For Azure OpenAI:
```bash
uvicorn azure_openai_lightrag_server:app --reload --port 9721
```

### API Documentation

When any server is running, visit:
- Swagger UI: http://localhost:9721/docs
- ReDoc: http://localhost:9721/redoc

### Testing API Endpoints

You can test the API endpoints using the provided curl commands or through the Swagger UI interface (a combined sketch follows this list). Make sure to:
1. Start the appropriate backend service (LoLLMs, Ollama, or OpenAI)
2. Start the RAG server
3. Upload some documents using the document management endpoints
4. Query the system using the query endpoints
5. Trigger a document scan if new files are put into the inputs directory

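Putting steps 3-5 together, a minimal smoke test against a locally running server might look like the sketch below. It only reuses the endpoints documented above; the sample text and file paths are placeholders:

```bash
#!/bin/bash
BASE=http://localhost:9721

# 3. Insert a small piece of text
curl -s -X POST "$BASE/documents/text" \
    -H "Content-Type: application/json" \
    -d '{"text": "LightRAG is a retrieval-augmented generation toolkit.", "description": "smoke test"}'

# 4. Query it back
curl -s -X POST "$BASE/query" \
    -H "Content-Type: application/json" \
    -d '{"query": "What is LightRAG?", "mode": "hybrid"}'

# 5. Pick up anything new dropped into the inputs directory
curl -s -X POST "$BASE/documents/scan" --max-time 1800

# Confirm the server is still healthy
curl -s "$BASE/health"
```
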
### Important Features

#### Automatic Document Vectorization
When starting any of the servers with the `--input-dir` parameter, the system will automatically:
1. Check for existing vectorized content in the database
2. Only vectorize new documents that aren't already in the database
3. Make all content immediately available for RAG queries

This intelligent caching mechanism:
- Prevents unnecessary re-vectorization of existing documents
- Reduces startup time for subsequent runs
- Preserves system resources
- Maintains consistency across restarts

**Important Notes:**
- The `--input-dir` parameter enables automatic document processing at startup
- Documents already in the database are not re-vectorized
- Only new documents in the input directory will be processed
- This optimization significantly reduces startup time for subsequent runs
- The working directory (`--working-dir`) stores the vectorized documents database

## Install Lightrag as a Linux Service

Create your service file `lightrag-server.service`, and modify the following lines from `lightrag-server.service.example`:

```text
Description=LightRAG Ollama Service
WorkingDirectory=<lightrag installed directory>
ExecStart=<lightrag installed directory>/lightrag/api/start_lightrag.sh
```

Create your service startup script `start_lightrag.sh`, and change the Python virtual environment activation method as needed:

```shell
#!/bin/bash

# python virtual environment activation
source /home/netman/lightrag-xyj/venv/bin/activate
# start lightrag api server
lightrag-server
```

Install lightrag-server.service in Linux. Sample commands on an Ubuntu server look like:

```shell
sudo cp lightrag-server.service /etc/systemd/system/
sudo systemctl daemon-reload
sudo systemctl start lightrag-server.service
sudo systemctl status lightrag-server.service
sudo systemctl enable lightrag-server.service
```

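Once the unit is enabled, standard systemd tooling applies; for example, the service logs can be followed with journalctl (unit name as installed above):

```bash
sudo journalctl -u lightrag-server.service -f
```
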
1 minirag/api/__init__.py Normal file
@@ -0,0 +1 @@
__api_version__ = "1.0.3"
1939 minirag/api/minirag_server.py Normal file
File diff suppressed because it is too large.
12 minirag/api/requirements.txt Normal file
@@ -0,0 +1,12 @@
ascii_colors
fastapi
nest_asyncio
numpy
pipmaster
python-dotenv
python-multipart
tenacity
tiktoken
torch
tqdm
uvicorn
2 minirag/api/static/README.md Normal file
@@ -0,0 +1,2 @@
# LightRag Webui
A simple webui to interact with the lightrag datalake
BIN minirag/api/static/favicon.ico Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 MiB
104 minirag/api/static/index.html Normal file
@@ -0,0 +1,104 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>MiniRag Interface</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||
<style>
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s ease-in;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
}
|
||||
|
||||
.spin {
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
from { transform: rotate(0deg); }
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.slide-in {
|
||||
animation: slideIn 0.3s ease-out;
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
from { transform: translateX(-100%); }
|
||||
to { transform: translateX(0); }
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body class="bg-gray-50">
|
||||
<div class="flex h-screen">
|
||||
<!-- Sidebar -->
|
||||
<div class="w-64 bg-white shadow-lg">
|
||||
<div class="p-4">
|
||||
<h1 class="text-xl font-bold text-gray-800 mb-6">MiniRag</h1>
|
||||
<nav class="space-y-2">
|
||||
<a href="#" class="nav-item" data-page="file-manager">
|
||||
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
|
||||
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z"/>
|
||||
</svg>
|
||||
File Manager
|
||||
</div>
|
||||
</a>
|
||||
<a href="#" class="nav-item" data-page="query">
|
||||
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
|
||||
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"/>
|
||||
</svg>
|
||||
Query Database
|
||||
</div>
|
||||
</a>
|
||||
<a href="#" class="nav-item" data-page="knowledge-graph">
|
||||
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
|
||||
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"/>
|
||||
</svg>
|
||||
Knowledge Graph
|
||||
</div>
|
||||
</a>
|
||||
<a href="#" class="nav-item" data-page="status">
|
||||
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
|
||||
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"/>
|
||||
</svg>
|
||||
Status
|
||||
</div>
|
||||
</a>
|
||||
<a href="#" class="nav-item" data-page="settings">
|
||||
<div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
|
||||
<svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"/>
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"/>
|
||||
</svg>
|
||||
Settings
|
||||
</div>
|
||||
</a>
|
||||
</nav>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Main Content -->
|
||||
<div class="flex-1 overflow-auto p-6">
|
||||
<div id="content" class="fade-in"></div>
|
||||
</div>
|
||||
|
||||
<!-- Toast Notification -->
|
||||
<div id="toast" class="fixed bottom-4 right-4 hidden">
|
||||
<div class="bg-gray-800 text-white px-6 py-3 rounded-lg shadow-lg"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="/js/api.js"></script>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
404 minirag/api/static/js/api.js Normal file
@@ -0,0 +1,404 @@
|
||||
// State management
|
||||
const state = {
|
||||
apiKey: localStorage.getItem('apiKey') || '',
|
||||
files: [],
|
||||
indexedFiles: [],
|
||||
currentPage: 'file-manager'
|
||||
};
|
||||
|
||||
// Utility functions
|
||||
const showToast = (message, duration = 3000) => {
|
||||
const toast = document.getElementById('toast');
|
||||
toast.querySelector('div').textContent = message;
|
||||
toast.classList.remove('hidden');
|
||||
setTimeout(() => toast.classList.add('hidden'), duration);
|
||||
};
|
||||
|
||||
const fetchWithAuth = async (url, options = {}) => {
|
||||
const headers = {
|
||||
...(options.headers || {}),
|
||||
...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {})
|
||||
};
|
||||
return fetch(url, { ...options, headers });
|
||||
};
|
||||
|
||||
// Page renderers
|
||||
const pages = {
|
||||
'file-manager': () => `
|
||||
<div class="space-y-6">
|
||||
<h2 class="text-2xl font-bold text-gray-800">File Manager</h2>
|
||||
|
||||
<div class="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center hover:border-gray-400 transition-colors">
|
||||
<input type="file" id="fileInput" multiple accept=".txt,.md,.doc,.docx,.pdf,.pptx" class="hidden">
|
||||
<label for="fileInput" class="cursor-pointer">
|
||||
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"/>
|
||||
</svg>
|
||||
<p class="mt-2 text-gray-600">Drag files here or click to select</p>
|
||||
<p class="text-sm text-gray-500">Supported formats: TXT, MD, DOC, PDF, PPTX</p>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div id="fileList" class="space-y-2">
|
||||
<h3 class="text-lg font-semibold text-gray-700">Selected Files</h3>
|
||||
<div class="space-y-2"></div>
|
||||
</div>
|
||||
<div id="uploadProgress" class="hidden mt-4">
|
||||
<div class="w-full bg-gray-200 rounded-full h-2.5">
|
||||
<div class="bg-blue-600 h-2.5 rounded-full" style="width: 0%"></div>
|
||||
</div>
|
||||
<p class="text-sm text-gray-600 mt-2"><span id="uploadStatus">0</span> files processed</p>
|
||||
</div>
|
||||
<button id="rescanBtn" class="flex items-center bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="20" height="20" fill="currentColor" class="mr-2">
|
||||
<path d="M12 4a8 8 0 1 1-8 8H2.5a9.5 9.5 0 1 0 2.8-6.7L2 3v6h6L5.7 6.7A7.96 7.96 0 0 1 12 4z"/>
|
||||
</svg>
|
||||
Rescan Files
|
||||
</button>
|
||||
<button id="uploadBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
|
||||
Upload & Index Files
|
||||
</button>
|
||||
|
||||
<div id="indexedFiles" class="space-y-2">
|
||||
<h3 class="text-lg font-semibold text-gray-700">Indexed Files</h3>
|
||||
<div class="space-y-2"></div>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
`,
|
||||
|
||||
'query': () => `
|
||||
<div class="space-y-6">
|
||||
<h2 class="text-2xl font-bold text-gray-800">Query Database</h2>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-gray-700">Query Mode</label>
|
||||
<select id="queryMode" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
|
||||
<option value="light">Light</option>
|
||||
<option value="naive">Naive</option>
|
||||
<option value="mini">Mini</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-gray-700">Query</label>
|
||||
<textarea id="queryInput" rows="4" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500"></textarea>
|
||||
</div>
|
||||
|
||||
<button id="queryBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
|
||||
Send Query
|
||||
</button>
|
||||
|
||||
<div id="queryResult" class="mt-4 p-4 bg-white rounded-lg shadow"></div>
|
||||
</div>
|
||||
</div>
|
||||
`,
|
||||
|
||||
'knowledge-graph': () => `
|
||||
<div class="flex items-center justify-center h-full">
|
||||
<div class="text-center">
|
||||
<svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10"/>
|
||||
</svg>
|
||||
<h3 class="mt-2 text-sm font-medium text-gray-900">Under Construction</h3>
|
||||
<p class="mt-1 text-sm text-gray-500">Knowledge graph visualization will be available in a future update.</p>
|
||||
</div>
|
||||
</div>
|
||||
`,
|
||||
|
||||
'status': () => `
|
||||
<div class="space-y-6">
|
||||
<h2 class="text-2xl font-bold text-gray-800">System Status</h2>
|
||||
<div id="statusContent" class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||
<div class="p-6 bg-white rounded-lg shadow-sm">
|
||||
<h3 class="text-lg font-semibold mb-4">System Health</h3>
|
||||
<div id="healthStatus"></div>
|
||||
</div>
|
||||
<div class="p-6 bg-white rounded-lg shadow-sm">
|
||||
<h3 class="text-lg font-semibold mb-4">Configuration</h3>
|
||||
<div id="configStatus"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`,
|
||||
|
||||
'settings': () => `
|
||||
<div class="space-y-6">
|
||||
<h2 class="text-2xl font-bold text-gray-800">Settings</h2>
|
||||
|
||||
<div class="max-w-xl">
|
||||
<div class="space-y-4">
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-gray-700">API Key</label>
|
||||
<input type="password" id="apiKeyInput" value="${state.apiKey}"
|
||||
class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
|
||||
</div>
|
||||
|
||||
<button id="saveSettings" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
|
||||
Save Settings
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`
|
||||
};
|
||||
|
||||
// Page handlers
|
||||
const handlers = {
|
||||
'file-manager': () => {
|
||||
const fileInput = document.getElementById('fileInput');
|
||||
const dropZone = fileInput.parentElement.parentElement;
|
||||
const fileList = document.querySelector('#fileList div');
|
||||
const indexedFiles = document.querySelector('#indexedFiles div');
|
||||
const uploadBtn = document.getElementById('uploadBtn');
const rescanBtn = document.getElementById('rescanBtn'); // declared explicitly instead of relying on the implicit global created from the element id
|
||||
|
||||
const updateFileList = () => {
|
||||
fileList.innerHTML = state.files.map(file => `
|
||||
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
|
||||
<span>${file.name}</span>
|
||||
<button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
`).join('');
|
||||
};
|
||||
|
||||
const updateIndexedFiles = async () => {
|
||||
const response = await fetchWithAuth('/health');
|
||||
const data = await response.json();
|
||||
indexedFiles.innerHTML = data.indexed_files.map(file => `
|
||||
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
|
||||
<span>${file}</span>
|
||||
</div>
|
||||
`).join('');
|
||||
};
|
||||
|
||||
dropZone.addEventListener('dragover', (e) => {
|
||||
e.preventDefault();
|
||||
dropZone.classList.add('border-blue-500');
|
||||
});
|
||||
|
||||
dropZone.addEventListener('dragleave', () => {
|
||||
dropZone.classList.remove('border-blue-500');
|
||||
});
|
||||
|
||||
dropZone.addEventListener('drop', (e) => {
|
||||
e.preventDefault();
|
||||
dropZone.classList.remove('border-blue-500');
|
||||
const files = Array.from(e.dataTransfer.files);
|
||||
state.files.push(...files);
|
||||
updateFileList();
|
||||
});
|
||||
|
||||
fileInput.addEventListener('change', () => {
|
||||
state.files.push(...Array.from(fileInput.files));
|
||||
updateFileList();
|
||||
});
|
||||
|
||||
uploadBtn.addEventListener('click', async () => {
|
||||
if (state.files.length === 0) {
|
||||
showToast('Please select files to upload');
|
||||
return;
|
||||
}
|
||||
let apiKey = localStorage.getItem('apiKey') || '';
|
||||
const progress = document.getElementById('uploadProgress');
|
||||
const progressBar = progress.querySelector('div');
|
||||
const statusText = document.getElementById('uploadStatus');
|
||||
progress.classList.remove('hidden');
|
||||
|
||||
for (let i = 0; i < state.files.length; i++) {
|
||||
const formData = new FormData();
|
||||
formData.append('file', state.files[i]);
|
||||
|
||||
try {
|
||||
await fetch('/documents/upload', {
|
||||
method: 'POST',
|
||||
headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {},
|
||||
body: formData
|
||||
});
|
||||
|
||||
const percentage = ((i + 1) / state.files.length) * 100;
|
||||
progressBar.style.width = `${percentage}%`;
|
||||
statusText.textContent = `${i + 1}/${state.files.length}`;
|
||||
} catch (error) {
|
||||
console.error('Upload error:', error);
|
||||
}
|
||||
}
|
||||
progress.classList.add('hidden');
|
||||
});
|
||||
|
||||
rescanBtn.addEventListener('click', async () => {
|
||||
const progress = document.getElementById('uploadProgress');
|
||||
const progressBar = progress.querySelector('div');
|
||||
const statusText = document.getElementById('uploadStatus');
|
||||
progress.classList.remove('hidden');
|
||||
|
||||
try {
|
||||
// Start the scanning process
|
||||
const scanResponse = await fetch('/documents/scan', {
|
||||
method: 'POST',
|
||||
});
|
||||
|
||||
if (!scanResponse.ok) {
|
||||
throw new Error('Scan failed to start');
|
||||
}
|
||||
|
||||
// Start polling for progress
|
||||
const pollInterval = setInterval(async () => {
|
||||
const progressResponse = await fetch('/documents/scan-progress');
|
||||
const progressData = await progressResponse.json();
|
||||
|
||||
// Update progress bar
|
||||
progressBar.style.width = `${progressData.progress}%`;
|
||||
|
||||
// Update status text
|
||||
if (progressData.total_files > 0) {
|
||||
statusText.textContent = `Processing ${progressData.current_file} (${progressData.indexed_count}/${progressData.total_files})`;
|
||||
}
|
||||
|
||||
// Check if scanning is complete
|
||||
if (!progressData.is_scanning) {
|
||||
clearInterval(pollInterval);
|
||||
progress.classList.add('hidden');
|
||||
statusText.textContent = 'Scan complete!';
|
||||
}
|
||||
}, 1000); // Poll every second
|
||||
|
||||
} catch (error) {
|
||||
console.error('Upload error:', error);
|
||||
progress.classList.add('hidden');
|
||||
statusText.textContent = 'Error during scanning process';
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
updateIndexedFiles();
|
||||
},
|
||||
|
||||
'query': () => {
|
||||
const queryBtn = document.getElementById('queryBtn');
|
||||
const queryInput = document.getElementById('queryInput');
|
||||
const queryMode = document.getElementById('queryMode');
|
||||
const queryResult = document.getElementById('queryResult');
|
||||
|
||||
let apiKey = localStorage.getItem('apiKey') || '';
|
||||
|
||||
queryBtn.addEventListener('click', async () => {
|
||||
const query = queryInput.value.trim();
|
||||
if (!query) {
|
||||
showToast('Please enter a query');
|
||||
return;
|
||||
}
|
||||
|
||||
queryBtn.disabled = true;
|
||||
queryBtn.innerHTML = `
|
||||
<svg class="animate-spin h-5 w-5 mr-3" viewBox="0 0 24 24">
|
||||
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4" fill="none"/>
|
||||
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"/>
|
||||
</svg>
|
||||
Processing...
|
||||
`;
|
||||
|
||||
try {
|
||||
const response = await fetchWithAuth('/query', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
query,
|
||||
mode: queryMode.value,
|
||||
stream: false,
|
||||
only_need_context: false
|
||||
})
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
queryResult.innerHTML = marked.parse(data.response);
|
||||
} catch (error) {
|
||||
showToast('Error processing query');
|
||||
} finally {
|
||||
queryBtn.disabled = false;
|
||||
queryBtn.textContent = 'Send Query';
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
'status': async () => {
|
||||
const healthStatus = document.getElementById('healthStatus');
|
||||
const configStatus = document.getElementById('configStatus');
|
||||
|
||||
try {
|
||||
const response = await fetchWithAuth('/health');
|
||||
const data = await response.json();
|
||||
|
||||
healthStatus.innerHTML = `
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center">
|
||||
<div class="w-3 h-3 rounded-full ${data.status === 'healthy' ? 'bg-green-500' : 'bg-red-500'} mr-2"></div>
|
||||
<span class="font-medium">${data.status}</span>
|
||||
</div>
|
||||
<div>
|
||||
<p class="text-sm text-gray-600">Working Directory: ${data.working_directory}</p>
|
||||
<p class="text-sm text-gray-600">Input Directory: ${data.input_directory}</p>
|
||||
<p class="text-sm text-gray-600">Indexed Files: ${data.indexed_files_count}</p>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
configStatus.innerHTML = Object.entries(data.configuration)
|
||||
.map(([key, value]) => `
|
||||
<div class="mb-2">
|
||||
<span class="text-sm font-medium text-gray-700">${key}:</span>
|
||||
<span class="text-sm text-gray-600 ml-2">${value}</span>
|
||||
</div>
|
||||
`).join('');
|
||||
} catch (error) {
|
||||
showToast('Error fetching status');
|
||||
}
|
||||
},
|
||||
|
||||
'settings': () => {
|
||||
const saveBtn = document.getElementById('saveSettings');
|
||||
const apiKeyInput = document.getElementById('apiKeyInput');
|
||||
|
||||
saveBtn.addEventListener('click', () => {
|
||||
state.apiKey = apiKeyInput.value;
|
||||
localStorage.setItem('apiKey', state.apiKey);
|
||||
showToast('Settings saved successfully');
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Navigation handling
|
||||
document.querySelectorAll('.nav-item').forEach(item => {
|
||||
item.addEventListener('click', (e) => {
|
||||
e.preventDefault();
|
||||
const page = item.dataset.page;
|
||||
document.getElementById('content').innerHTML = pages[page]();
|
||||
if (handlers[page]) handlers[page]();
|
||||
state.currentPage = page;
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize with file manager
|
||||
document.getElementById('content').innerHTML = pages['file-manager']();
|
||||
handlers['file-manager']();
|
||||
|
||||
// Global functions
|
||||
window.removeFile = (fileName) => {
|
||||
state.files = state.files.filter(file => file.name !== fileName);
|
||||
document.querySelector('#fileList div').innerHTML = state.files.map(file => `
|
||||
<div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
|
||||
<span>${file.name}</span>
|
||||
<button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
|
||||
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
`).join('');
|
||||
};
|
||||
211 minirag/api/static/js/graph.js Normal file
@@ -0,0 +1,211 @@
|
||||
// js/graph.js
|
||||
function openGraphModal(label) {
|
||||
const modal = document.getElementById("graph-modal");
|
||||
const graphTitle = document.getElementById("graph-title");
|
||||
|
||||
if (!modal || !graphTitle) {
|
||||
console.error("Key element not found");
|
||||
return;
|
||||
}
|
||||
|
||||
graphTitle.textContent = `Knowledge Graph - ${label}`;
|
||||
modal.style.display = "flex";
|
||||
|
||||
renderGraph(label);
|
||||
}
|
||||
|
||||
function closeGraphModal() {
|
||||
const modal = document.getElementById("graph-modal");
|
||||
modal.style.display = "none";
|
||||
clearGraph();
|
||||
}
|
||||
|
||||
function clearGraph() {
|
||||
const svg = document.getElementById("graph-svg");
|
||||
svg.innerHTML = "";
|
||||
}
|
||||
|
||||
|
||||
async function getGraph(label) {
|
||||
try {
|
||||
const response = await fetch(`/graphs?label=${label}`);
|
||||
const rawData = await response.json();
|
||||
console.log({data: JSON.parse(JSON.stringify(rawData))});
|
||||
|
||||
const nodes = rawData.nodes
|
||||
|
||||
nodes.forEach(node => {
|
||||
node.id = Date.now().toString(36) + Math.random().toString(36).substring(2); // generate a unique id per node (crypto.randomUUID() would also work)
|
||||
});
|
||||
|
||||
// Strictly verify edge data
|
||||
const edges = (rawData.edges || []).map(edge => {
|
||||
const sourceNode = nodes.find(n => n.labels.includes(edge.source));
|
||||
const targetNode = nodes.find(n => n.labels.includes(edge.target));
|
||||
if (!sourceNode || !targetNode) {
|
||||
console.warn("NOT VALID EDGE:", edge);
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
source: sourceNode,
|
||||
target: targetNode,
|
||||
type: edge.type || ""
|
||||
};
|
||||
}).filter(edge => edge !== null);
|
||||
|
||||
return {nodes, edges};
|
||||
} catch (error) {
|
||||
console.error("Loading graph failed:", error);
|
||||
return {nodes: [], edges: []};
|
||||
}
|
||||
}
|
||||
|
||||
async function renderGraph(label) {
|
||||
const data = await getGraph(label);
|
||||
|
||||
|
||||
if (!data.nodes || data.nodes.length === 0) {
|
||||
d3.select("#graph-svg")
|
||||
.html(`<text x="50%" y="50%" text-anchor="middle">No valid nodes</text>`);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const svg = d3.select("#graph-svg");
|
||||
const width = svg.node().clientWidth;
|
||||
const height = svg.node().clientHeight;
|
||||
|
||||
svg.selectAll("*").remove();
|
||||
|
||||
// Create a force oriented diagram layout
|
||||
const simulation = d3.forceSimulation(data.nodes)
|
||||
.force("charge", d3.forceManyBody().strength(-300))
|
||||
.force("center", d3.forceCenter(width / 2, height / 2));
|
||||
|
||||
// Add a connection (if there are valid edges)
|
||||
if (data.edges.length > 0) {
|
||||
simulation.force("link",
|
||||
d3.forceLink(data.edges)
|
||||
.id(d => d.id)
|
||||
.distance(100)
|
||||
);
|
||||
}
|
||||
|
||||
// Draw nodes
|
||||
const nodes = svg.selectAll(".node")
|
||||
.data(data.nodes)
|
||||
.enter()
|
||||
.append("circle")
|
||||
.attr("class", "node")
|
||||
.attr("r", 10)
|
||||
.call(d3.drag()
|
||||
.on("start", dragStarted)
|
||||
.on("drag", dragged)
|
||||
.on("end", dragEnded)
|
||||
);
|
||||
|
||||
|
||||
svg.append("defs")
|
||||
.append("marker")
|
||||
.attr("id", "arrow-out")
|
||||
.attr("viewBox", "0 0 10 10")
|
||||
.attr("refX", 8)
|
||||
.attr("refY", 5)
|
||||
.attr("markerWidth", 6)
|
||||
.attr("markerHeight", 6)
|
||||
.attr("orient", "auto")
|
||||
.append("path")
|
||||
.attr("d", "M0,0 L10,5 L0,10 Z")
|
||||
.attr("fill", "#999");
|
||||
|
||||
// Draw edges (with arrows)
|
||||
const links = svg.selectAll(".link")
|
||||
.data(data.edges)
|
||||
.enter()
|
||||
.append("line")
|
||||
.attr("class", "link")
|
||||
.attr("marker-end", "url(#arrow-out)"); // Always draw arrows on the target side
|
||||
|
||||
// Edge style configuration
|
||||
links
|
||||
.attr("stroke", "#999")
|
||||
.attr("stroke-width", 2)
|
||||
.attr("stroke-opacity", 0.8);
|
||||
|
||||
// Draw label (with background box)
|
||||
const labels = svg.selectAll(".label")
|
||||
.data(data.nodes)
|
||||
.enter()
|
||||
.append("text")
|
||||
.attr("class", "label")
|
||||
.text(d => d.labels[0] || "")
|
||||
.attr("text-anchor", "start")
|
||||
.attr("dy", "0.3em")
|
||||
.attr("fill", "#333");
|
||||
|
||||
// Update Location
|
||||
simulation.on("tick", () => {
|
||||
links
|
||||
.attr("x1", d => {
|
||||
// Calculate the direction vector from the source node to the target node
|
||||
const dx = d.target.x - d.source.x;
|
||||
const dy = d.target.y - d.source.y;
|
||||
const distance = Math.sqrt(dx * dx + dy * dy);
|
||||
if (distance === 0) return d.source.x; // avoid dividing by zero
|
||||
// Adjust the starting point coordinates (source node edge) based on radius 10
|
||||
return d.source.x + (dx / distance) * 10;
|
||||
})
|
||||
.attr("y1", d => {
|
||||
const dx = d.target.x - d.source.x;
|
||||
const dy = d.target.y - d.source.y;
|
||||
const distance = Math.sqrt(dx * dx + dy * dy);
|
||||
if (distance === 0) return d.source.y;
|
||||
return d.source.y + (dy / distance) * 10;
|
||||
})
|
||||
.attr("x2", d => {
|
||||
// Adjust the endpoint coordinates (target node edge) based on a radius of 10
|
||||
const dx = d.target.x - d.source.x;
|
||||
const dy = d.target.y - d.source.y;
|
||||
const distance = Math.sqrt(dx * dx + dy * dy);
|
||||
if (distance === 0) return d.target.x;
|
||||
return d.target.x - (dx / distance) * 10;
|
||||
})
|
||||
.attr("y2", d => {
|
||||
const dx = d.target.x - d.source.x;
|
||||
const dy = d.target.y - d.source.y;
|
||||
const distance = Math.sqrt(dx * dx + dy * dy);
|
||||
if (distance === 0) return d.target.y;
|
||||
return d.target.y - (dy / distance) * 10;
|
||||
});
|
||||
|
||||
// Update the position of nodes and labels (keep unchanged)
|
||||
nodes
|
||||
.attr("cx", d => d.x)
|
||||
.attr("cy", d => d.y);
|
||||
|
||||
labels
|
||||
.attr("x", d => d.x + 12)
|
||||
.attr("y", d => d.y + 4);
|
||||
});
|
||||
|
||||
// Drag and drop logic
|
||||
function dragStarted(event, d) {
|
||||
if (!event.active) simulation.alphaTarget(0.3).restart();
|
||||
d.fx = d.x;
|
||||
d.fy = d.y;
|
||||
}
|
||||
|
||||
function dragged(event, d) {
|
||||
d.fx = event.x;
|
||||
d.fy = event.y;
|
||||
simulation.alpha(0.3).restart();
|
||||
}
|
||||
|
||||
function dragEnded(event, d) {
|
||||
if (!event.active) simulation.alphaTarget(0);
|
||||
d.fx = null;
|
||||
d.fy = null;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar

import os
import numpy as np

from .utils import EmbeddingFunc
@@ -17,16 +17,28 @@ T = TypeVar("T")
class QueryParam:
    mode: Literal["light", "naive", "mini"] = "mini"
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 5
    top_k: int = int(os.getenv("TOP_K", "60"))
    # Number of document chunks to retrieve.
    # top_n: int = 10
    # Number of tokens for the original chunks.
    max_token_for_text_unit: int = 2000
    max_token_for_text_unit: int = 4000
    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 2000
    max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions
    max_token_for_local_context: int = 2000  # For Light/Graph
    max_token_for_node_context: int = 500  # For Mini; if too long, the SLM may fail to generate any response
    max_token_for_local_context: int = 4000
    hl_keywords: list[str] = field(default_factory=list)
    ll_keywords: list[str] = field(default_factory=list)
    # Conversation history support
    conversation_history: list[dict] = field(
        default_factory=list
    )  # Format: [{"role": "user/assistant", "content": "message"}]
    history_turns: int = (
        3  # Number of complete conversation turns (user-assistant pairs) to consider
    )

@dataclass
class StorageNameSpace:

621 minirag/kg/age_impl.py Normal file
@@ -0,0 +1,621 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("psycopg-pool"):
|
||||
pm.install("psycopg-pool")
|
||||
pm.install("psycopg[binary,pool]")
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio.windows_events
|
||||
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
class AGEQueryException(Exception):
|
||||
"""Exception for the AGE queries."""
|
||||
|
||||
def __init__(self, exception: Union[str, Dict]) -> None:
|
||||
if isinstance(exception, dict):
|
||||
self.message = exception["message"] if "message" in exception else "unknown"
|
||||
self.details = exception["details"] if "details" in exception else "unknown"
|
||||
else:
|
||||
self.message = exception
|
||||
self.details = "unknown"
|
||||
|
||||
def get_message(self) -> str:
|
||||
return self.message
|
||||
|
||||
def get_details(self) -> Any:
|
||||
return self.details
|
||||
|
||||
|
||||
@dataclass
|
||||
class AGEStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with AGE in production")
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self._driver = None
|
||||
self._driver_lock = asyncio.Lock()
|
||||
DB = os.environ["AGE_POSTGRES_DB"].replace("\\", "\\\\").replace("'", "\\'")
|
||||
USER = os.environ["AGE_POSTGRES_USER"].replace("\\", "\\\\").replace("'", "\\'")
|
||||
PASSWORD = (
|
||||
os.environ["AGE_POSTGRES_PASSWORD"]
|
||||
.replace("\\", "\\\\")
|
||||
.replace("'", "\\'")
|
||||
)
|
||||
HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'")
|
||||
PORT = int(os.environ["AGE_POSTGRES_PORT"])
|
||||
self.graph_name = os.environ["AGE_GRAPH_NAME"]
|
||||
|
||||
connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
|
||||
|
||||
self._driver = AsyncConnectionPool(connection_string, open=False)
|
||||
|
||||
return None
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
self._driver = None
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
if self._driver:
|
||||
await self._driver.close()
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
|
||||
@staticmethod
|
||||
def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a record returned from an age query to a dictionary
|
||||
|
||||
Args:
|
||||
record (): a record from an age query result
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: a dictionary representation of the record where
|
||||
the dictionary key is the field name and the value is the
|
||||
value converted to a python type
|
||||
"""
|
||||
# result holder
|
||||
d = {}
|
||||
|
||||
# prebuild a mapping of vertex_id to vertex mappings to be used
|
||||
# later to build edges
|
||||
vertices = {}
|
||||
for k in record._fields:
|
||||
v = getattr(record, k)
|
||||
# agtype comes back '{key: value}::type' which must be parsed
|
||||
if isinstance(v, str) and "::" in v:
|
||||
dtype = v.split("::")[-1]
|
||||
v = v.split("::")[0]
|
||||
if dtype == "vertex":
|
||||
vertex = json.loads(v)
|
||||
vertices[vertex["id"]] = vertex.get("properties")
|
||||
|
||||
# iterate returned fields and parse appropriately
|
||||
for k in record._fields:
|
||||
v = getattr(record, k)
|
||||
if isinstance(v, str) and "::" in v:
|
||||
dtype = v.split("::")[-1]
|
||||
v = v.split("::")[0]
|
||||
else:
|
||||
dtype = ""
|
||||
|
||||
if dtype == "vertex":
|
||||
vertex = json.loads(v)
|
||||
field = vertex.get("properties")
|
||||
if not field:
|
||||
field = {}
|
||||
field["label"] = AGEStorage._decode_graph_label(vertex["label"])
|
||||
d[k] = field
|
||||
# convert edge from id-label->id by replacing id with node information
|
||||
# we only do this if the vertex was also returned in the query
|
||||
# this is an attempt to be consistent with neo4j implementation
|
||||
elif dtype == "edge":
|
||||
edge = json.loads(v)
|
||||
d[k] = (
|
||||
vertices.get(edge["start_id"], {}),
|
||||
edge[
|
||||
"label"
|
||||
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
||||
vertices.get(edge["end_id"], {}),
|
||||
)
|
||||
else:
|
||||
d[k] = json.loads(v) if isinstance(v, str) else v
|
||||
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def _format_properties(
|
||||
properties: Dict[str, Any], _id: Union[str, None] = None
|
||||
) -> str:
|
||||
"""
|
||||
Convert a dictionary of properties to a string representation that
|
||||
can be used in a cypher query insert/merge statement.
|
||||
|
||||
Args:
|
||||
properties (Dict[str,str]): a dictionary containing node/edge properties
|
||||
_id (Union[str, None]): the id of the node, or None if none exists
|
||||
|
||||
Returns:
|
||||
str: the properties dictionary as a properly formatted string
|
||||
"""
|
||||
props = []
|
||||
# wrap property key in backticks to escape
|
||||
for k, v in properties.items():
|
||||
prop = f"`{k}`: {json.dumps(v)}"
|
||||
props.append(prop)
|
||||
if _id is not None and "id" not in properties:
|
||||
props.append(
|
||||
f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
|
||||
)
|
||||
return "{" + ", ".join(props) + "}"
|
||||
|
||||
@staticmethod
|
||||
def _encode_graph_label(label: str) -> str:
|
||||
"""
|
||||
Since AGE supports only alphanumeric labels, we encode a generic label as a HEX string
|
||||
|
||||
Args:
|
||||
label (str): the original label
|
||||
|
||||
Returns:
|
||||
str: the encoded label
|
||||
"""
|
||||
return "x" + label.encode().hex()
|
||||
|
||||
@staticmethod
|
||||
def _decode_graph_label(encoded_label: str) -> str:
|
||||
"""
|
||||
Since AGE supports only alphanumeric labels, we encode a generic label as a HEX string
|
||||
|
||||
Args:
|
||||
encoded_label (str): the encoded label
|
||||
|
||||
Returns:
|
||||
str: the decoded label
|
||||
"""
|
||||
return bytes.fromhex(encoded_label.removeprefix("x")).decode()
|
||||
|
||||
@staticmethod
|
||||
def _get_col_name(field: str, idx: int) -> str:
|
||||
"""
|
||||
Convert a cypher return field to a pgsql select field
|
||||
If possible keep the cypher column name, but create a generic name if necessary
|
||||
|
||||
Args:
|
||||
field (str): a return field from a cypher query to be formatted for pgsql
|
||||
idx (int): the position of the field in the return statement
|
||||
|
||||
Returns:
|
||||
str: the field to be used in the pgsql select statement
|
||||
"""
|
||||
# remove white space
|
||||
field = field.strip()
|
||||
# if an alias is provided for the field, use it
|
||||
if " as " in field:
|
||||
return field.split(" as ")[-1].strip()
|
||||
# if the return value is an unnamed primitive, give it a generic name
|
||||
if field.isnumeric() or field in ("true", "false", "null"):
|
||||
return f"column_{idx}"
|
||||
# otherwise return the value stripping out some common special chars
|
||||
return field.replace("(", "_").replace(")", "")
|
||||
|
||||
@staticmethod
|
||||
def _wrap_query(query: str, graph_name: str, **params: str) -> str:
|
||||
"""
|
||||
Convert a cypher query to an Apache Age compatible
|
||||
sql query by wrapping the cypher query in ag_catalog.cypher,
|
||||
casting results to agtype and building a select statement
|
||||
|
||||
Args:
|
||||
query (str): a valid cypher query
|
||||
graph_name (str): the name of the graph to query
|
||||
params (dict): parameters for the query
|
||||
|
||||
Returns:
|
||||
str: an equivalent pgsql query
|
||||
"""
|
||||
|
||||
# pgsql template
|
||||
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
||||
{query}
|
||||
$$) AS ({fields});"""
|
||||
|
||||
# if there are any returned fields they must be added to the pgsql query
|
||||
if "return" in query.lower():
|
||||
# parse return statement to identify returned fields
|
||||
fields = (
|
||||
query.lower()
|
||||
.split("return")[-1]
|
||||
.split("distinct")[-1]
|
||||
.split("order by")[0]
|
||||
.split("skip")[0]
|
||||
.split("limit")[0]
|
||||
.split(",")
|
||||
)
|
||||
|
||||
# raise exception if RETURN * is found as we can't resolve the fields
|
||||
if "*" in [x.strip() for x in fields]:
|
||||
raise ValueError(
|
||||
"AGE graph does not support 'RETURN *'"
|
||||
+ " statements in Cypher queries"
|
||||
)
|
||||
|
||||
# get pgsql formatted field names
|
||||
fields = [
|
||||
AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
|
||||
]
|
||||
|
||||
# build resulting pgsql relation
|
||||
fields_str = ", ".join(
|
||||
[field.split(".")[-1] + " agtype" for field in fields]
|
||||
)
|
||||
|
||||
# if no return statement we still need to return a single field of type agtype
|
||||
else:
|
||||
fields_str = "a agtype"
|
||||
|
||||
select_str = "*"
|
||||
|
||||
return template.format(
|
||||
graph_name=graph_name,
|
||||
query=query.format(**params),
|
||||
fields=fields_str,
|
||||
projection=select_str,
|
||||
)
|
||||
|
||||
async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query the graph by taking a cypher query, converting it to an
|
||||
age compatible query, executing it and converting the result
|
||||
|
||||
Args:
|
||||
query (str): a cypher query to be executed
|
||||
params (dict): parameters for the query
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: a list of dictionaries containing the result set
|
||||
"""
|
||||
# convert cypher query to pgsql/age query
|
||||
wrapped_query = self._wrap_query(query, self.graph_name, **params)
|
||||
|
||||
await self._driver.open()
|
||||
|
||||
# create graph if it doesn't exist
|
||||
async with self._get_pool_connection() as conn:
|
||||
async with conn.cursor() as curs:
|
||||
try:
|
||||
await curs.execute('SET search_path = ag_catalog, "$user", public')
|
||||
await curs.execute(f"SELECT create_graph('{self.graph_name}')")
|
||||
await conn.commit()
|
||||
except (
|
||||
psycopg.errors.InvalidSchemaName,
|
||||
psycopg.errors.UniqueViolation,
|
||||
):
|
||||
await conn.rollback()
|
||||
|
||||
# execute the query, rolling back on an error
|
||||
async with self._get_pool_connection() as conn:
|
||||
async with conn.cursor(row_factory=namedtuple_row) as curs:
|
||||
try:
|
||||
await curs.execute('SET search_path = ag_catalog, "$user", public')
|
||||
await curs.execute(wrapped_query)
|
||||
await conn.commit()
|
||||
except psycopg.Error as e:
|
||||
await conn.rollback()
|
||||
raise AGEQueryException(
|
||||
{
|
||||
"message": f"Error executing graph query: {query.format(**params)}",
|
||||
"detail": str(e),
|
||||
}
|
||||
) from e
|
||||
|
||||
data = await curs.fetchall()
|
||||
if data is None:
|
||||
result = []
|
||||
# decode records
|
||||
else:
|
||||
result = [AGEStorage._record_to_dict(d) for d in data]
|
||||
|
||||
return result
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
|
||||
"""
|
||||
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
|
||||
single_result = (await self._query(query, **params))[0]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
single_result["node_exists"],
|
||||
)
|
||||
|
||||
return single_result["node_exists"]
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
|
||||
RETURN COUNT(r) > 0 AS edge_exists
|
||||
"""
|
||||
params = {
|
||||
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
|
||||
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
|
||||
}
|
||||
single_result = (await self._query(query, **params))[0]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
single_result["edge_exists"],
|
||||
)
|
||||
return single_result["edge_exists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = """
|
||||
MATCH (n:`{label}`) RETURN n
|
||||
"""
|
||||
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
|
||||
record = await self._query(query, **params)
|
||||
if record:
|
||||
node = record[0]
|
||||
node_dict = node["n"]
|
||||
logger.debug(
|
||||
"{%s}: query: {%s}, result: {%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
node_dict,
|
||||
)
|
||||
return node_dict
|
||||
return None
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (n:`{label}`)-[]->(x)
|
||||
RETURN count(x) AS total_edge_count
|
||||
"""
|
||||
params = {"label": AGEStorage._encode_graph_label(entity_name_label)}
|
||||
record = (await self._query(query, **params))[0]
|
||||
if record:
|
||||
edge_count = int(record["total_edge_count"])
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
edge_count,
|
||||
)
|
||||
return edge_count
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
entity_name_label_source = src_id.strip('"')
|
||||
entity_name_label_target = tgt_id.strip('"')
|
||||
src_degree = await self.node_degree(entity_name_label_source)
|
||||
trg_degree = await self.node_degree(entity_name_label_target)
|
||||
|
||||
# Convert None to 0 for addition
|
||||
src_degree = 0 if src_degree is None else src_degree
|
||||
trg_degree = 0 if trg_degree is None else trg_degree
|
||||
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
logger.debug(
|
||||
"{%s}:query:src_Degree+trg_degree:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
degrees,
|
||||
)
|
||||
return degrees
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Find the first edge between nodes of two given labels

Args:
source_node_id (str): Label of the source node
target_node_id (str): Label of the target node

Returns:
dict|None: Properties of the first matching edge, or None if not found
|
||||
"""
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
"""
|
||||
params = {
|
||||
"src_label": AGEStorage._encode_graph_label(entity_name_label_source),
|
||||
"tgt_label": AGEStorage._encode_graph_label(entity_name_label_target),
|
||||
}
|
||||
record = await self._query(query, **params)
|
||||
if record and record[0] and record[0]["edge_properties"]:
|
||||
result = record[0]["edge_properties"]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
result,
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
:return: List of dictionaries containing edge information
|
||||
"""
|
||||
node_label = source_node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (n:`{label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected
|
||||
"""
|
||||
params = {"label": AGEStorage._encode_graph_label(node_label)}
|
||||
results = await self._query(query, **params)
|
||||
edges = []
|
||||
for record in results:
|
||||
source_node = record["n"] if record["n"] else None
|
||||
connected_node = record["connected"] if record["connected"] else None
|
||||
|
||||
source_label = (
|
||||
source_node["label"] if source_node and source_node["label"] else None
|
||||
)
|
||||
target_label = (
|
||||
connected_node["label"]
|
||||
if connected_node and connected_node["label"]
|
||||
else None
|
||||
)
|
||||
|
||||
if source_label and target_label:
|
||||
edges.append((source_label, target_label))
|
||||
|
||||
return edges
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((AGEQueryException,)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
"""
|
||||
Upsert a node in the AGE database.
|
||||
|
||||
Args:
|
||||
node_id: The unique identifier for the node (used as label)
|
||||
node_data: Dictionary of node properties
|
||||
"""
|
||||
label = node_id.strip('"')
|
||||
properties = node_data
|
||||
|
||||
query = """
|
||||
MERGE (n:`{label}`)
|
||||
SET n += {properties}
|
||||
"""
|
||||
params = {
|
||||
"label": AGEStorage._encode_graph_label(label),
|
||||
"properties": AGEStorage._format_properties(properties),
|
||||
}
|
||||
try:
|
||||
await self._query(query, **params)
|
||||
logger.debug(
|
||||
"Upserted node with label '{%s}' and properties: {%s}",
|
||||
label,
|
||||
properties,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error during upsert: {%s}", e)
|
||||
raise
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((AGEQueryException,)),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their labels.
|
||||
|
||||
Args:
|
||||
source_node_id (str): Label of the source node (used as identifier)
|
||||
target_node_id (str): Label of the target node (used as identifier)
|
||||
edge_data (dict): Dictionary of properties to set on the edge
|
||||
"""
|
||||
source_node_label = source_node_id.strip('"')
|
||||
target_node_label = target_node_id.strip('"')
|
||||
edge_properties = edge_data
|
||||
|
||||
query = """
|
||||
MATCH (source:`{src_label}`)
|
||||
WITH source
|
||||
MATCH (target:`{tgt_label}`)
|
||||
MERGE (source)-[r:DIRECTED]->(target)
|
||||
SET r += {properties}
|
||||
RETURN r
|
||||
"""
|
||||
params = {
|
||||
"src_label": AGEStorage._encode_graph_label(source_node_label),
|
||||
"tgt_label": AGEStorage._encode_graph_label(target_node_label),
|
||||
"properties": AGEStorage._format_properties(edge_properties),
|
||||
}
|
||||
try:
|
||||
await self._query(query, **params)
|
||||
logger.debug(
|
||||
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
||||
source_node_label,
|
||||
target_node_label,
|
||||
edge_properties,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error during edge upsert: {%s}", e)
|
||||
raise
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_pool_connection(self, timeout: Optional[float] = None):
|
||||
"""Workaround for a psycopg_pool bug"""
|
||||
|
||||
try:
|
||||
connection = await self._driver.getconn(timeout=timeout)
|
||||
except PoolTimeout:
|
||||
await self._driver._add_connection(None) # workaround...
|
||||
connection = await self._driver.getconn(timeout=timeout)
|
||||
|
||||
try:
|
||||
async with connection:
|
||||
yield connection
|
||||
finally:
|
||||
await self._driver.putconn(connection)
|
||||
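For orientation, here is a minimal usage sketch of the AGE-backed graph storage above. The import path `minirag.kg.age_impl` is inferred from the other `minirag/kg/*` files in this commit, the `AGE_POSTGRES_*`/`AGE_GRAPH_NAME` variables are assumed to point at a running PostgreSQL instance with the Apache AGE extension, and the constructor arguments mirror the `__init__` shown above; this is an illustration, not part of the diff.

```python
import asyncio
import os

# Hypothetical import path, inferred from the other minirag/kg/* files in this commit
from minirag.kg.age_impl import AGEStorage

# Assumed connection settings for a local PostgreSQL instance with Apache AGE installed
os.environ.setdefault("AGE_POSTGRES_DB", "postgres")
os.environ.setdefault("AGE_POSTGRES_USER", "postgres")
os.environ.setdefault("AGE_POSTGRES_PASSWORD", "postgres")
os.environ.setdefault("AGE_POSTGRES_HOST", "localhost")
os.environ.setdefault("AGE_POSTGRES_PORT", "5432")
os.environ.setdefault("AGE_GRAPH_NAME", "minirag_demo")


async def main() -> None:
    storage = AGEStorage(
        namespace="chunk_entity_relation",
        global_config={},      # the AGE backend reads its settings from the environment
        embedding_func=None,   # not used by the graph operations exercised here
    )
    try:
        # Node labels are hex-encoded internally by _encode_graph_label(), so any string works
        await storage.upsert_node("Alice", {"entity_type": "person"})
        await storage.upsert_node("Bob", {"entity_type": "person"})
        await storage.upsert_edge("Alice", "Bob", {"relation": "knows"})

        print(await storage.has_node("Alice"))          # True
        print(await storage.get_edge("Alice", "Bob"))   # properties of the DIRECTED edge
    finally:
        await storage.close()


if __name__ == "__main__":
    asyncio.run(main())
```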
173
minirag/kg/chroma_impl.py
Normal file
173
minirag/kg/chroma_impl.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import os
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
from chromadb import HttpClient
|
||||
from chromadb.config import Settings
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from lightrag.utils import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
"""ChromaDB vector storage implementation."""
|
||||
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
try:
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
|
||||
user_collection_settings = config.get("collection_settings", {})
|
||||
# Default HNSW index settings for ChromaDB
|
||||
default_collection_settings = {
|
||||
# Distance metric used for similarity search (cosine similarity)
|
||||
"hnsw:space": "cosine",
|
||||
# Number of nearest neighbors to explore during index construction
|
||||
# Higher values = better recall but slower indexing
|
||||
"hnsw:construction_ef": 128,
|
||||
# Number of nearest neighbors to explore during search
|
||||
# Higher values = better recall but slower search
|
||||
"hnsw:search_ef": 128,
|
||||
# Number of connections per node in the HNSW graph
|
||||
# Higher values = better recall but more memory usage
|
||||
"hnsw:M": 16,
|
||||
# Number of vectors to process in one batch during indexing
|
||||
"hnsw:batch_size": 100,
|
||||
# Number of updates before forcing index synchronization
|
||||
# Lower values = more frequent syncs but slower indexing
|
||||
"hnsw:sync_threshold": 1000,
|
||||
}
|
||||
collection_settings = {
|
||||
**default_collection_settings,
|
||||
**user_collection_settings,
|
||||
}
|
||||
|
||||
auth_provider = config.get(
|
||||
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
|
||||
)
|
||||
auth_credentials = config.get("auth_token", "secret-token")
|
||||
headers = {}
|
||||
|
||||
if "token_authn" in auth_provider:
|
||||
headers = {
|
||||
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
|
||||
}
|
||||
elif "basic_authn" in auth_provider:
|
||||
auth_credentials = config.get("auth_credentials", "admin:admin")
|
||||
|
||||
self._client = HttpClient(
|
||||
host=config.get("host", "localhost"),
|
||||
port=config.get("port", 8000),
|
||||
headers=headers,
|
||||
settings=Settings(
|
||||
chroma_api_impl="rest",
|
||||
chroma_client_auth_provider=auth_provider,
|
||||
chroma_client_auth_credentials=auth_credentials,
|
||||
allow_reset=True,
|
||||
anonymized_telemetry=False,
|
||||
),
|
||||
)
|
||||
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
name=self.namespace,
|
||||
metadata={
|
||||
**collection_settings,
|
||||
"dimension": self.embedding_func.embedding_dim,
|
||||
},
|
||||
)
|
||||
# Use batch size from collection settings if specified
|
||||
self._max_batch_size = self.global_config.get(
|
||||
"embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"ChromaDB initialization failed: {str(e)}")
|
||||
raise
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
if not data:
|
||||
logger.warning("Empty data provided to vector DB")
|
||||
return []
|
||||
|
||||
try:
|
||||
ids = list(data.keys())
|
||||
documents = [v["content"] for v in data.values()]
|
||||
metadatas = [
|
||||
{k: v for k, v in item.items() if k in self.meta_fields}
|
||||
or {"_default": "true"}
|
||||
for item in data.values()
|
||||
]
|
||||
|
||||
# Process in batches
|
||||
batches = [
|
||||
documents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(documents), self._max_batch_size)
|
||||
]
|
||||
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
# Gather all embedding batches concurrently; results come back in task order
embeddings_list = list(await asyncio.gather(*embedding_tasks))
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
|
||||
# Upsert in batches
|
||||
for i in range(0, len(ids), self._max_batch_size):
|
||||
batch_slice = slice(i, i + self._max_batch_size)
|
||||
|
||||
self._collection.upsert(
|
||||
ids=ids[batch_slice],
|
||||
embeddings=embeddings[batch_slice].tolist(),
|
||||
documents=documents[batch_slice],
|
||||
metadatas=metadatas[batch_slice],
|
||||
)
|
||||
|
||||
return ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
||||
try:
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
results = self._collection.query(
|
||||
query_embeddings=embedding.tolist(),
|
||||
n_results=top_k * 2, # Request more results to allow for filtering
|
||||
include=["metadatas", "distances", "documents"],
|
||||
)
|
||||
|
||||
# Filter results by cosine similarity threshold and take top k
|
||||
# We request 2x results initially to have enough after filtering
|
||||
# ChromaDB returns cosine distances (0 = identical, 1 = orthogonal) for the "cosine" space
# We convert each back to a similarity via (1 - distance)
# Only keep results whose similarity is at or above the threshold, then take top k
|
||||
return [
|
||||
{
|
||||
"id": results["ids"][0][i],
|
||||
"distance": 1 - results["distances"][0][i],
|
||||
"content": results["documents"][0][i],
|
||||
**results["metadatas"][0][i],
|
||||
}
|
||||
for i in range(len(results["ids"][0]))
|
||||
if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
|
||||
][:top_k]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ChromaDB query: {str(e)}")
|
||||
raise
|
||||
|
||||
async def index_done_callback(self):
|
||||
# ChromaDB handles persistence automatically
|
||||
pass
|
||||
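As a quick illustration of the threshold logic in `query()` above: ChromaDB reports cosine distances, which the code turns back into similarities before filtering against `cosine_better_than_threshold`. The snippet below reproduces only that post-processing step on made-up numbers; every value is illustrative.

```python
# Standalone sketch of the post-filtering done in ChromaVectorDBStorage.query()

cosine_better_than_threshold = 0.2   # same default as COSINE_THRESHOLD above
top_k = 2

# Pretend these came back from self._collection.query(n_results=top_k * 2),
# already sorted by ascending cosine distance as ChromaDB returns them
results = {
    "ids": [["doc-a", "doc-b", "doc-c", "doc-d"]],
    "distances": [[0.10, 0.30, 0.55, 0.85]],   # cosine distances (0 = identical)
    "documents": [["text a", "text b", "text c", "text d"]],
    "metadatas": [[{}, {}, {}, {}]],
}

filtered = [
    {
        "id": results["ids"][0][i],
        # 1 - distance = cosine similarity; stored under "distance" to match the code above
        "distance": 1 - results["distances"][0][i],
        "content": results["documents"][0][i],
        **results["metadatas"][0][i],
    }
    for i in range(len(results["ids"][0]))
    if (1 - results["distances"][0][i]) >= cosine_better_than_threshold
][:top_k]

# Similarities are 0.90, 0.70, 0.45, 0.15 -> doc-d falls below the 0.2 threshold,
# and the [:top_k] slice keeps doc-a and doc-b.
print(filtered)
```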
397
minirag/kg/gremlin_impl.py
Normal file
397
minirag/kg/gremlin_impl.py
Normal file
@@ -0,0 +1,397 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from gremlin_python.driver import client, serializer
|
||||
from gremlin_python.driver.aiohttp.transport import AiohttpTransport
|
||||
from gremlin_python.driver.protocol import GremlinServerError
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
|
||||
@dataclass
|
||||
class GremlinStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with Gremlin in production")
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
|
||||
self._driver = None
|
||||
self._driver_lock = asyncio.Lock()
|
||||
|
||||
USER = os.environ.get("GREMLIN_USER", "")
|
||||
PASSWORD = os.environ.get("GREMLIN_PASSWORD", "")
|
||||
HOST = os.environ["GREMLIN_HOST"]
|
||||
PORT = int(os.environ["GREMLIN_PORT"])
|
||||
|
||||
# TraversalSource: a custom one has to be created manually;
# defaults to "g"
|
||||
SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g")
|
||||
|
||||
# All vertices will have graph={GRAPH} property, so that we can
|
||||
# have several logical graphs for one source
|
||||
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
|
||||
|
||||
self.graph_name = GRAPH
|
||||
|
||||
self._driver = client.Client(
|
||||
f"ws://{HOST}:{PORT}/gremlin",
|
||||
SOURCE,
|
||||
username=USER,
|
||||
password=PASSWORD,
|
||||
message_serializer=serializer.GraphSONSerializersV3d0(),
|
||||
transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
if self._driver:
|
||||
self._driver.close()
|
||||
self._driver = None
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
if self._driver:
|
||||
self._driver.close()
|
||||
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
|
||||
@staticmethod
|
||||
def _to_value_map(value: Any) -> str:
|
||||
"""Dump supported Python object as Gremlin valueMap"""
|
||||
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
|
||||
parsed_str = json_str.replace("'", r"\'")
|
||||
|
||||
# walk over the string and replace curly brackets with square brackets
|
||||
# outside of strings, as well as replace double quotes with single quotes
|
||||
# and "deescape" double quotes inside of strings
|
||||
outside_str = True
|
||||
escaped = False
|
||||
remove_indices = []
|
||||
for i, c in enumerate(parsed_str):
|
||||
if escaped:
|
||||
# previous character was an "odd" backslash
|
||||
escaped = False
|
||||
if c == '"':
|
||||
# we want to "deescape" double quotes: store indices to delete
|
||||
remove_indices.insert(0, i - 1)
|
||||
elif c == "\\":
|
||||
escaped = True
|
||||
elif c == '"':
|
||||
outside_str = not outside_str
|
||||
parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :]
|
||||
elif c == "{" and outside_str:
|
||||
parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :]
|
||||
elif c == "}" and outside_str:
|
||||
parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :]
|
||||
for idx in remove_indices:
|
||||
parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :]
|
||||
return parsed_str
|
||||
|
||||
@staticmethod
|
||||
def _convert_properties(properties: Dict[str, Any]) -> str:
|
||||
"""Create chained .property() commands from properties dict"""
|
||||
props = []
|
||||
for k, v in properties.items():
|
||||
prop_name = GremlinStorage._to_value_map(k)
|
||||
props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
|
||||
return "".join(props)
|
||||
|
||||
@staticmethod
|
||||
def _fix_name(name: str) -> str:
|
||||
"""Strip double quotes and format as a proper field name"""
|
||||
name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
|
||||
|
||||
return name
|
||||
|
||||
async def _query(self, query: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query the Gremlin graph
|
||||
|
||||
Args:
|
||||
query (str): a query to be executed
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: a list of dictionaries containing the result set
|
||||
"""
|
||||
|
||||
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
|
||||
if result:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name = GremlinStorage._fix_name(node_id)
|
||||
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name})
|
||||
.limit(1)
|
||||
.count()
|
||||
.project('has_node')
|
||||
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
|
||||
"""
|
||||
result = await self._query(query)
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query,
|
||||
result[0]["has_node"],
|
||||
)
|
||||
|
||||
return result[0]["has_node"]
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
||||
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
||||
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name_source})
|
||||
.outE()
|
||||
.inV().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name_target})
|
||||
.limit(1)
|
||||
.count()
|
||||
.project('has_edge')
|
||||
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
|
||||
"""
|
||||
result = await self._query(query)
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query,
|
||||
result[0]["has_edge"],
|
||||
)
|
||||
|
||||
return result[0]["has_edge"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
entity_name = GremlinStorage._fix_name(node_id)
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name})
|
||||
.limit(1)
|
||||
.project('properties')
|
||||
.by(elementMap())
|
||||
"""
|
||||
result = await self._query(query)
|
||||
if result:
|
||||
node = result[0]
|
||||
node_dict = node["properties"]
|
||||
logger.debug(
|
||||
"{%s}: query: {%s}, result: {%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query,
|
||||
node_dict,
|
||||
)
|
||||
return node_dict
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name = GremlinStorage._fix_name(node_id)
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name})
|
||||
.outE()
|
||||
.inV().has('graph', {self.graph_name})
|
||||
.count()
|
||||
.project('total_edge_count')
|
||||
.by()
|
||||
"""
|
||||
result = await self._query(query)
|
||||
edge_count = result[0]["total_edge_count"]
|
||||
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query,
|
||||
edge_count,
|
||||
)
|
||||
|
||||
return edge_count
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
src_degree = await self.node_degree(src_id)
|
||||
trg_degree = await self.node_degree(tgt_id)
|
||||
|
||||
# Convert None to 0 for addition
|
||||
src_degree = 0 if src_degree is None else src_degree
|
||||
trg_degree = 0 if trg_degree is None else trg_degree
|
||||
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
logger.debug(
|
||||
"{%s}:query:src_Degree+trg_degree:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
degrees,
|
||||
)
|
||||
return degrees
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Find all edges between nodes of two given names
|
||||
|
||||
Args:
|
||||
source_node_id (str): Name of the source nodes
|
||||
target_node_id (str): Name of the target nodes
|
||||
|
||||
Returns:
|
||||
dict|None: Dict of found edge properties, or None if not found
|
||||
"""
|
||||
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
||||
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name_source})
|
||||
.outE()
|
||||
.inV().has('graph', {self.graph_name})
|
||||
.has('entity_name', {entity_name_target})
|
||||
.limit(1)
|
||||
.project('edge_properties')
|
||||
.by(__.bothE().elementMap())
|
||||
"""
|
||||
result = await self._query(query)
|
||||
if result:
|
||||
edge_properties = result[0]["edge_properties"]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query,
|
||||
edge_properties,
|
||||
)
|
||||
return edge_properties
|
||||
|
||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Retrieves all edges (relationships) for a particular node identified by its name.
|
||||
:return: List of tuples containing edge sources and targets
|
||||
"""
|
||||
node_name = GremlinStorage._fix_name(source_node_id)
|
||||
query = f"""g
|
||||
.E()
|
||||
.filter(
|
||||
__.or(
|
||||
__.outV().has('graph', {self.graph_name})
|
||||
.has('entity_name', {node_name}),
|
||||
__.inV().has('graph', {self.graph_name})
|
||||
.has('entity_name', {node_name})
|
||||
)
|
||||
)
|
||||
.project('source_name', 'target_name')
|
||||
.by(__.outV().values('entity_name'))
|
||||
.by(__.inV().values('entity_name'))
|
||||
"""
|
||||
result = await self._query(query)
|
||||
edges = [(res["source_name"], res["target_name"]) for res in result]
|
||||
|
||||
return edges
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((GremlinServerError,)),
|
||||
)
|
||||
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
|
||||
"""
|
||||
Upsert a node in the Gremlin graph.
|
||||
|
||||
Args:
|
||||
node_id: The unique identifier for the node (used as name)
|
||||
node_data: Dictionary of node properties
|
||||
"""
|
||||
name = GremlinStorage._fix_name(node_id)
|
||||
properties = GremlinStorage._convert_properties(node_data)
|
||||
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {name})
|
||||
.fold()
|
||||
.coalesce(
|
||||
__.unfold(),
|
||||
__.addV('ENTITY')
|
||||
.property('graph', {self.graph_name})
|
||||
.property('entity_name', {name})
|
||||
)
|
||||
{properties}
|
||||
"""
|
||||
|
||||
try:
|
||||
await self._query(query)
|
||||
logger.debug(
|
||||
"Upserted node with name {%s} and properties: {%s}",
|
||||
name,
|
||||
properties,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error during upsert: {%s}", e)
|
||||
raise
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((GremlinServerError,)),
|
||||
)
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||
):
|
||||
"""
|
||||
Upsert an edge and its properties between two nodes identified by their names.
|
||||
|
||||
Args:
|
||||
source_node_id (str): Name of the source node (used as identifier)
|
||||
target_node_id (str): Name of the target node (used as identifier)
|
||||
edge_data (dict): Dictionary of properties to set on the edge
|
||||
"""
|
||||
source_node_name = GremlinStorage._fix_name(source_node_id)
|
||||
target_node_name = GremlinStorage._fix_name(target_node_id)
|
||||
edge_properties = GremlinStorage._convert_properties(edge_data)
|
||||
|
||||
query = f"""g
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {source_node_name}).as('source')
|
||||
.V().has('graph', {self.graph_name})
|
||||
.has('entity_name', {target_node_name}).as('target')
|
||||
.coalesce(
|
||||
__.select('source').outE('DIRECTED').where(__.inV().as('target')),
|
||||
__.select('source').addE('DIRECTED').to(__.select('target'))
|
||||
)
|
||||
.property('graph', {self.graph_name})
|
||||
{edge_properties}
|
||||
"""
|
||||
try:
|
||||
await self._query(query)
|
||||
logger.debug(
|
||||
"Upserted edge from {%s} to {%s} with properties: {%s}",
|
||||
source_node_name,
|
||||
target_node_name,
|
||||
edge_properties,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error during edge upsert: {%s}", e)
|
||||
raise
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
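The `_to_value_map` / `_convert_properties` helpers above carry the only non-obvious string handling in this file, so here is a small sketch of what they produce. It calls only the two static methods, so no Gremlin server is needed (the `gremlin_python` package must still be importable); the import path is a guess based on the file name in this commit.

```python
# Hypothetical import path, inferred from minirag/kg/gremlin_impl.py in this commit
from minirag.kg.gremlin_impl import GremlinStorage

# json.dumps output with curly braces swapped for brackets and double quotes for single
# quotes (outside of string contents), i.e. a Gremlin/Groovy map literal
print(GremlinStorage._to_value_map({"name": "Alice", "hops": 2}))
# -> ['name': 'Alice', 'hops': 2]

# Plain strings become single-quoted Gremlin string literals
print(GremlinStorage._to_value_map("rag_graph"))
# -> 'rag_graph'

# _convert_properties chains .property() calls for use after addV()/addE()
print(GremlinStorage._convert_properties({"entity_type": "person", "weight": 1.0}))
# -> .property('entity_type', 'person').property('weight', 1.0)
```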
134
minirag/kg/json_kv_impl.py
Normal file
134
minirag/kg/json_kv_impl.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
JsonKV Storage Module
=====================

This module provides a simple key-value storage backend that persists its data as JSON files on disk.

The `JsonKVStorage` class extends the `BaseKVStorage` class from the LightRAG library, providing methods to load, query, upsert, filter, and delete key-value records.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
- LightRAG

Features:
- Load and persist key-value data as JSON files
- Look up records by ID, with optional field projection
- Identify keys that are not yet stored
- Upsert, filter, and delete records

Usage:
from minirag.kg.json_kv_impl import JsonKVStorage
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
load_json,
|
||||
write_json,
|
||||
)
|
||||
|
||||
from lightrag.base import (
|
||||
BaseKVStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = load_json(self._file_name) or {}
|
||||
self._lock = asyncio.Lock()
|
||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
return list(self._data.keys())
|
||||
|
||||
async def index_done_callback(self):
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def get_by_id(self, id):
|
||||
return self._data.get(id, None)
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
if fields is None:
|
||||
return [self._data.get(id, None) for id in ids]
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items() if k in fields}
|
||||
if self._data.get(id, None)
|
||||
else None
|
||||
)
|
||||
for id in ids
|
||||
]
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
return set([s for s in data if s not in self._data])
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
return left_data
|
||||
|
||||
async def drop(self):
|
||||
self._data = {}
|
||||
|
||||
async def filter(self, filter_func):
|
||||
"""Filter key-value pairs based on a filter function
|
||||
|
||||
Args:
|
||||
filter_func: The filter function, which takes a value as an argument and returns a boolean value
|
||||
|
||||
Returns:
|
||||
Dict: Key-value pairs that meet the condition
|
||||
"""
|
||||
result = {}
|
||||
async with self._lock:
|
||||
for key, value in self._data.items():
|
||||
if filter_func(value):
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete data with specified IDs
|
||||
|
||||
Args:
|
||||
ids: List of IDs to delete
|
||||
"""
|
||||
async with self._lock:
|
||||
for id in ids:
|
||||
if id in self._data:
|
||||
del self._data[id]
|
||||
await self.index_done_callback()
|
||||
logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}")
|
||||
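A minimal sketch of driving the JSON key-value store above on its own. It assumes `BaseKVStorage` is a dataclass accepting `namespace`, `global_config` and `embedding_func` (as the `__post_init__` above implies) and that the import path matches the file name in this commit.

```python
import asyncio
import tempfile

# Hypothetical import path, inferred from minirag/kg/json_kv_impl.py in this commit
from minirag.kg.json_kv_impl import JsonKVStorage


async def main() -> None:
    workdir = tempfile.mkdtemp()

    kv = JsonKVStorage(
        namespace="text_chunks",
        global_config={"working_dir": workdir},  # data lands in kv_store_text_chunks.json
        embedding_func=None,                     # unused by the plain KV operations
    )

    # upsert() only adds keys that are not already present
    await kv.upsert({"doc-1": {"content": "hello"}, "doc-2": {"content": "world"}})

    print(await kv.get_by_id("doc-1"))               # {'content': 'hello'}
    print(await kv.filter_keys(["doc-1", "doc-3"]))  # {'doc-3'}

    # Persist the in-memory dict to disk
    await kv.index_done_callback()


asyncio.run(main())
```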
128
minirag/kg/jsondocstatus_impl.py
Normal file
128
minirag/kg/jsondocstatus_impl.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
JsonDocStatus Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a JSON-file-backed store for tracking the processing status of ingested documents.

The `JsonDocStatusStorage` class extends the `DocStatusStorage` class from the LightRAG library, providing methods to record, query, and persist per-document processing status.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
- LightRAG

Features:
- Persist document processing status as a JSON file
- Filter out documents that have already been processed successfully
- Count documents per status and list pending or failed documents
- Upsert and delete status records

Usage:
from minirag.kg.jsondocstatus_impl import JsonDocStatusStorage
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Union, Dict
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
load_json,
|
||||
write_json,
|
||||
)
|
||||
|
||||
from lightrag.base import (
|
||||
DocStatus,
|
||||
DocProcessingStatus,
|
||||
DocStatusStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonDocStatusStorage(DocStatusStorage):
|
||||
"""JSON implementation of document status storage"""
|
||||
|
||||
def __post_init__(self):
|
||||
working_dir = self.global_config["working_dir"]
|
||||
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
||||
self._data = load_json(self._file_name) or {}
|
||||
logger.info(f"Loaded document status storage with {len(self._data)} records")
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
"""Return keys that should be processed (not in storage or not successfully processed)"""
|
||||
return set(
|
||||
[
|
||||
k
|
||||
for k in data
|
||||
if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
|
||||
]
|
||||
)
|
||||
|
||||
async def get_status_counts(self) -> Dict[str, int]:
|
||||
"""Get counts of documents in each status"""
|
||||
counts = {status: 0 for status in DocStatus}
|
||||
for doc in self._data.values():
|
||||
counts[doc["status"]] += 1
|
||||
return counts
|
||||
|
||||
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
|
||||
"""Get all failed documents"""
|
||||
return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
|
||||
|
||||
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
|
||||
"""Get all pending documents"""
|
||||
return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
|
||||
|
||||
async def index_done_callback(self):
|
||||
"""Save data to file after indexing"""
|
||||
write_json(self._data, self._file_name)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""Update or insert document status
|
||||
|
||||
Args:
|
||||
data: Dictionary of document IDs and their status data
|
||||
"""
|
||||
self._data.update(data)
|
||||
await self.index_done_callback()
|
||||
return data
|
||||
|
||||
async def get_by_id(self, id: str):
|
||||
return self._data.get(id)
|
||||
|
||||
async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
|
||||
"""Get document status by ID"""
|
||||
return self._data.get(doc_id)
|
||||
|
||||
async def delete(self, doc_ids: list[str]):
|
||||
"""Delete document status by IDs"""
|
||||
for doc_id in doc_ids:
|
||||
self._data.pop(doc_id, None)
|
||||
await self.index_done_callback()
|
||||
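A similarly hedged sketch for the document-status store above, assuming `DocStatusStorage` exposes the same constructor fields and that `DocStatus` is the string-valued enum with the PENDING/PROCESSED/FAILED members referenced in the code.

```python
import asyncio
import tempfile

# Hypothetical import path, inferred from minirag/kg/jsondocstatus_impl.py in this commit
from minirag.kg.jsondocstatus_impl import JsonDocStatusStorage
from lightrag.base import DocStatus


async def main() -> None:
    store = JsonDocStatusStorage(
        namespace="doc_status",
        global_config={"working_dir": tempfile.mkdtemp()},
        embedding_func=None,
    )

    # Record one processed and one failed document (upsert also persists to disk)
    await store.upsert(
        {
            "doc-1": {"status": DocStatus.PROCESSED, "content_summary": "ok"},
            "doc-2": {"status": DocStatus.FAILED, "error": "parse error"},
        }
    )

    print(await store.get_status_counts())                       # counts per DocStatus
    print(await store.filter_keys(["doc-1", "doc-2", "doc-3"]))  # {'doc-2', 'doc-3'}


asyncio.run(main())
```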
94
minirag/kg/milvus_impl.py
Normal file
94
minirag/kg/milvus_impl.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import asyncio
|
||||
import os
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseVectorStorage
|
||||
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("pymilvus"):
|
||||
pm.install("pymilvus")
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class MilvusVectorDBStorge(BaseVectorStorage):
|
||||
@staticmethod
|
||||
def create_collection_if_not_exist(
|
||||
client: MilvusClient, collection_name: str, **kwargs
|
||||
):
|
||||
if client.has_collection(collection_name):
|
||||
return
|
||||
client.create_collection(
|
||||
collection_name, max_length=64, id_type="string", **kwargs
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self._client = MilvusClient(
|
||||
uri=os.environ.get(
|
||||
"MILVUS_URI",
|
||||
os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
|
||||
),
|
||||
user=os.environ.get("MILVUS_USER", ""),
|
||||
password=os.environ.get("MILVUS_PASSWORD", ""),
|
||||
token=os.environ.get("MILVUS_TOKEN", ""),
|
||||
db_name=os.environ.get("MILVUS_DB_NAME", ""),
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
MilvusVectorDBStorge.create_collection_if_not_exist(
|
||||
self._client,
|
||||
self.namespace,
|
||||
dimension=self.embedding_func.embedding_dim,
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
list_data = [
|
||||
{
|
||||
"id": k,
|
||||
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
async def wrapped_task(batch):
|
||||
result = await self.embedding_func(batch)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
||||
pbar = tqdm_async(
|
||||
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
||||
)
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["vector"] = embeddings[i]
|
||||
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
||||
return results
|
||||
|
||||
async def query(self, query, top_k=5):
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
data=embedding,
|
||||
limit=top_k,
|
||||
output_fields=list(self.meta_fields),
|
||||
search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
|
||||
)
|
||||
logger.debug(f"Milvus query results: {results}")
|
||||
return [
|
||||
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
|
||||
for dp in results[0]
|
||||
]
|
||||
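To round off the vector backends, a rough sketch of the Milvus storage above running against the bundled milvus-lite file database. The wrapper expects `embedding_func` to be awaitable and to expose an `embedding_dim` attribute, and the base dataclass is assumed to accept a `meta_fields` set; the stub embedding below returns random vectors purely for illustration.

```python
import asyncio
import tempfile

import numpy as np

# Hypothetical import path, inferred from minirag/kg/milvus_impl.py in this commit
from minirag.kg.milvus_impl import MilvusVectorDBStorge


class FakeEmbedding:
    """Awaitable stand-in for a real embedding function (an assumption, not a library API)."""

    embedding_dim = 8

    async def __call__(self, texts):
        return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)


async def main() -> None:
    store = MilvusVectorDBStorge(
        namespace="entities",
        global_config={
            "working_dir": tempfile.mkdtemp(),  # milvus_lite.db is created here
            "embedding_batch_num": 16,
        },
        embedding_func=FakeEmbedding(),
        meta_fields={"entity_name"},
    )

    await store.upsert(
        {
            "ent-1": {"content": "Alice is a person", "entity_name": "Alice"},
            "ent-2": {"content": "Bob is a person", "entity_name": "Bob"},
        }
    )

    # Each hit carries its id, distance and the requested meta fields
    print(await store.query("who is Alice?", top_k=2))


asyncio.run(main())
```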
440
minirag/kg/mongo_impl.py
Normal file
440
minirag/kg/mongo_impl.py
Normal file
@@ -0,0 +1,440 @@
|
||||
import os
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import pipmaster as pm
|
||||
import numpy as np
|
||||
|
||||
if not pm.is_installed("pymongo"):
|
||||
pm.install("pymongo")
|
||||
|
||||
from pymongo import MongoClient
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from typing import Union, List, Tuple
|
||||
from lightrag.utils import logger
|
||||
|
||||
from lightrag.base import BaseKVStorage
|
||||
from lightrag.base import BaseGraphStorage
|
||||
|
||||
|
||||
@dataclass
|
||||
class MongoKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
client = MongoClient(
|
||||
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
|
||||
)
|
||||
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
|
||||
self._data = database.get_collection(self.namespace)
|
||||
logger.info(f"Use MongoDB as KV {self.namespace}")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
||||
|
||||
async def get_by_id(self, id):
|
||||
return self._data.find_one({"_id": id})
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
if fields is None:
|
||||
return list(self._data.find({"_id": {"$in": ids}}))
|
||||
return list(
|
||||
self._data.find(
|
||||
{"_id": {"$in": ids}},
|
||||
{field: 1 for field in fields},
|
||||
)
|
||||
)
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
existing_ids = [
|
||||
str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
|
||||
]
|
||||
return set([s for s in data if s not in existing_ids])
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
if self.namespace == "llm_response_cache":
|
||||
for mode, items in data.items():
|
||||
for k, v in tqdm_async(items.items(), desc="Upserting"):
|
||||
key = f"{mode}_{k}"
|
||||
result = self._data.update_one(
|
||||
{"_id": key}, {"$setOnInsert": v}, upsert=True
|
||||
)
|
||||
if result.upserted_id:
|
||||
logger.debug(f"\nInserted new document with key: {key}")
|
||||
data[mode][k]["_id"] = key
|
||||
else:
|
||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||
self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
|
||||
data[k]["_id"] = k
|
||||
return data
|
||||
|
||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||
if "llm_response_cache" == self.namespace:
|
||||
res = {}
|
||||
v = self._data.find_one({"_id": mode + "_" + id})
|
||||
if v:
|
||||
res[id] = v
|
||||
logger.debug(f"llm_response_cache find one by:{id}")
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
async def drop(self):
|
||||
""" """
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MongoGraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self.client = AsyncIOMotorClient(
|
||||
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
|
||||
)
|
||||
self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
|
||||
self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# HELPER: $graphLookup pipeline
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def _graph_lookup(
|
||||
self, start_node_id: str, max_depth: int = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Performs a $graphLookup starting from 'start_node_id' and returns
|
||||
all reachable documents (including the start node itself).
|
||||
|
||||
Pipeline Explanation:
|
||||
- 1) $match: We match the start node document by _id = start_node_id.
|
||||
- 2) $graphLookup:
|
||||
"from": same collection,
|
||||
"startWith": "$edges.target" (the immediate neighbors in 'edges'),
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "reachableNodes",
|
||||
"maxDepth": max_depth (if provided),
|
||||
"depthField": "depth" (used for debugging or filtering).
|
||||
- 3) We add a $project or $unwind stage as needed to extract data.
|
||||
"""
|
||||
pipeline = [
|
||||
{"$match": {"_id": start_node_id}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "reachableNodes",
|
||||
"depthField": "depth",
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
# If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth
|
||||
if max_depth is not None:
|
||||
pipeline[1]["$graphLookup"]["maxDepth"] = max_depth
|
||||
|
||||
# Return the matching doc plus a field "reachableNodes"
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
results = await cursor.to_list(None)
|
||||
|
||||
# If there's no matching node, results = [].
|
||||
# Otherwise, results[0] is the start node doc,
|
||||
# plus results[0]["reachableNodes"] is the array of connected docs.
|
||||
return results
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# BASIC QUERIES
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if node_id is present in the collection by looking up its doc.
|
||||
No real need for $graphLookup here, but let's keep it direct.
|
||||
"""
|
||||
doc = await self.collection.find_one({"_id": node_id}, {"_id": 1})
|
||||
return doc is not None
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""
|
||||
Check if there's a direct single-hop edge from source_node_id to target_node_id.
|
||||
|
||||
We run a $graphLookup with maxDepth=0 from the source node, which expands only the
immediate edges, and then check whether the target node appears in "reachableNodes".

For a plain direct-edge check a simple find_one would also work; the pipeline below
is kept as a $graphLookup demonstration.
|
||||
"""
|
||||
|
||||
# We can do a single-hop graphLookup (maxDepth=0 or 1).
|
||||
# Then check if the target_node appears among the edges array.
|
||||
pipeline = [
|
||||
{"$match": {"_id": source_node_id}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "reachableNodes",
|
||||
"depthField": "depth",
|
||||
"maxDepth": 0, # means: do not follow beyond immediate edges
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 0,
|
||||
"reachableNodes._id": 1, # only keep the _id from the subdocs
|
||||
}
|
||||
},
|
||||
]
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
results = await cursor.to_list(None)
|
||||
if not results:
|
||||
return False
|
||||
|
||||
# results[0]["reachableNodes"] are the immediate neighbors
|
||||
reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])]
|
||||
return target_node_id in reachable_ids
|
||||
|
||||
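As noted in the docstring, a plain lookup is simpler than the pipeline above; a minimal sketch of that alternative (not what the method actually runs):

```python
async def has_edge_direct(self, source_node_id: str, target_node_id: str) -> bool:
    # Match a source document that has at least one edge pointing at the target.
    doc = await self.collection.find_one(
        {"_id": source_node_id, "edges.target": target_node_id}, {"_id": 1}
    )
    return doc is not None
```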
#
|
||||
# -------------------------------------------------------------------------
|
||||
# DEGREES
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""
|
||||
Returns the total number of edges connected to node_id (inbound plus outbound).

This is done in two steps:
1) Outbound edges: the length of the node's own 'edges' array.
2) Inbound edges: count how many edges in other documents reference node_id
via 'edges.target'. A "reverse" $graphLookup would require storing reversed
edges, so a direct $match / $project aggregation is used instead.
|
||||
"""
|
||||
# --- 1) Outbound edges (direct from doc) ---
|
||||
doc = await self.collection.find_one({"_id": node_id}, {"edges": 1})
|
||||
if not doc:
|
||||
return 0
|
||||
outbound_count = len(doc.get("edges", []))
|
||||
|
||||
# --- 2) Inbound edges:
|
||||
# Find all docs whose "edges.target" references node_id and count those
# matching edges. A "reverse" $graphLookup would require storing reversed
# edges or a more advanced pipeline, so a direct match is used here.
|
||||
inbound_count_pipeline = [
|
||||
{"$match": {"edges.target": node_id}},
|
||||
{
|
||||
"$project": {
|
||||
"matchingEdgesCount": {
|
||||
"$size": {
|
||||
"$filter": {
|
||||
"input": "$edges",
|
||||
"as": "edge",
|
||||
"cond": {"$eq": ["$$edge.target", node_id]},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}},
|
||||
]
|
||||
inbound_cursor = self.collection.aggregate(inbound_count_pipeline)
|
||||
inbound_result = await inbound_cursor.to_list(None)
|
||||
inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0
|
||||
|
||||
return outbound_count + inbound_count
|
||||
|
||||
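A small worked example of this counting, reusing the hypothetical documents sketched earlier:

```python
# "A" has outgoing edges to "B" and "C" and is referenced by no one:
#   outbound_count = 2, inbound_count = 0  ->  node_degree("A") == 2
# "C" has no outgoing edges but is referenced by both "A" and "B":
#   outbound_count = 0, inbound_count = 2  ->  node_degree("C") == 2
```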
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
"""
|
||||
If your graph can hold multiple edges from the same src to the same tgt
|
||||
(e.g. different 'relation' values), you can sum them. If it's always
|
||||
one edge, this is typically 1 or 0.
|
||||
|
||||
We run a single-hop $graphLookup from src_id and then count how many entries in
the source document's 'edges' array reference tgt_id.
|
||||
"""
|
||||
pipeline = [
|
||||
{"$match": {"_id": src_id}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "neighbors",
|
||||
"depthField": "depth",
|
||||
"maxDepth": 0,
|
||||
}
|
||||
},
|
||||
{"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}},
|
||||
]
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
results = await cursor.to_list(None)
|
||||
if not results:
|
||||
return 0
|
||||
|
||||
# We can simply count how many edges in `results[0].edges` have target == tgt_id.
|
||||
edges = results[0].get("edges", [])
|
||||
count = sum(1 for e in edges if e.get("target") == tgt_id)
|
||||
return count
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# GETTERS
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""
|
||||
Return the full node document (including "edges"), or None if missing.
|
||||
"""
|
||||
return await self.collection.find_one({"_id": node_id})
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Return the first edge dict from source_node_id to target_node_id if it exists.
|
||||
Uses a single-hop $graphLookup as demonstration, though a direct find is simpler.
|
||||
"""
|
||||
pipeline = [
|
||||
{"$match": {"_id": source_node_id}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "neighbors",
|
||||
"depthField": "depth",
|
||||
"maxDepth": 0,
|
||||
}
|
||||
},
|
||||
{"$project": {"edges": 1}},
|
||||
]
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
docs = await cursor.to_list(None)
|
||||
if not docs:
|
||||
return None
|
||||
|
||||
for e in docs[0].get("edges", []):
|
||||
if e.get("target") == target_node_id:
|
||||
return e
|
||||
return None
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> Union[List[Tuple[str, str]], None]:
|
||||
"""
|
||||
Return a list of (target_id, relation) for direct edges from source_node_id.
|
||||
Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
|
||||
"""
|
||||
pipeline = [
|
||||
{"$match": {"_id": source_node_id}},
|
||||
{
|
||||
"$graphLookup": {
|
||||
"from": self.collection.name,
|
||||
"startWith": "$edges.target",
|
||||
"connectFromField": "edges.target",
|
||||
"connectToField": "_id",
|
||||
"as": "neighbors",
|
||||
"depthField": "depth",
|
||||
"maxDepth": 0,
|
||||
}
|
||||
},
|
||||
{"$project": {"_id": 0, "edges": 1}},
|
||||
]
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
result = await cursor.to_list(None)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
edges = result[0].get("edges", [])
|
||||
return [(e["target"], e["relation"]) for e in edges]
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# UPSERTS
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict):
|
||||
"""
|
||||
Insert or update a node document. If new, create an empty edges array.
|
||||
"""
|
||||
# By default, preserve existing 'edges'.
|
||||
# We'll only set 'edges' to [] on insert (no overwrite).
|
||||
update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}}
|
||||
await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict
|
||||
):
|
||||
"""
|
||||
Upsert an edge from source_node_id -> target_node_id with optional 'relation'.
|
||||
If an edge with the same target exists, we remove it and re-insert with updated data.
|
||||
"""
|
||||
# Ensure source node exists
|
||||
await self.upsert_node(source_node_id, {})
|
||||
|
||||
# Remove existing edge (if any)
|
||||
await self.collection.update_one(
|
||||
{"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}}
|
||||
)
|
||||
|
||||
# Insert new edge
|
||||
new_edge = {"target": target_node_id}
|
||||
new_edge.update(edge_data)
|
||||
await self.collection.update_one(
|
||||
{"_id": source_node_id}, {"$push": {"edges": new_edge}}
|
||||
)
|
||||
|
||||
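A sketch of what two successive upserts leave in the source document, under the same assumed schema as above:

```python
# await storage.upsert_edge("A", "B", {"relation": "knows"})
# await storage.upsert_edge("A", "B", {"relation": "cites", "weight": 2})
#
# The first call pushes {"target": "B", "relation": "knows"}. The second call
# $pulls that entry and pushes the updated one, so "A" ends up with exactly one
# edge to "B": {"target": "B", "relation": "cites", "weight": 2}.
```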
#
|
||||
# -------------------------------------------------------------------------
|
||||
# DELETION
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
"""
|
||||
1) Remove inbound edges from any doc that references node_id.
2) Remove the node's doc entirely.
|
||||
"""
|
||||
# Remove inbound edges from all other docs
|
||||
await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}})
|
||||
|
||||
# Remove the node doc
|
||||
await self.collection.delete_one({"_id": node_id})
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
# EMBEDDINGS (NOT IMPLEMENTED)
|
||||
# -------------------------------------------------------------------------
|
||||
#
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]:
|
||||
"""
|
||||
Placeholder for demonstration, raises NotImplementedError.
|
||||
"""
|
||||
raise NotImplementedError("Node embedding is not used in lightrag.")
|
||||
213
minirag/kg/nano_vector_db_impl.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
NanoVectorDB Storage Module
===========================

This module provides a vector storage interface using NanoVectorDB, a lightweight, file-based vector database for storing and querying embedding vectors.

The `NanoVectorDBStorage` class extends the `BaseVectorStorage` class from the LightRAG library, providing methods to upsert, query, and delete embedding vectors.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
- nano-vectordb
- NumPy
- LightRAG

Features:
- Upsert embedding vectors in batches
- Query vectors by cosine similarity with a configurable threshold
- Delete vectors, entities, and entity relations
- Persist the vector store to a JSON file

Usage:
from minirag.kg.nano_vector_db_impl import NanoVectorDBStorage
|
||||
|
||||
"""
|
||||
|
||||
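A minimal usage sketch; the keyword arguments assume the `BaseVectorStorage` dataclass exposes the `namespace`, `global_config`, `embedding_func`, and `meta_fields` fields used in this module, and the embedding function and config values below are made up for illustration:

```python
import asyncio
import os
import numpy as np

async def fake_embed(texts):
    # Stand-in embedding function returning 8-dimensional random vectors.
    return np.random.rand(len(texts), 8)

fake_embed.embedding_dim = 8  # the storage reads embedding_func.embedding_dim

os.makedirs("./rag_storage", exist_ok=True)
storage = NanoVectorDBStorage(
    namespace="chunks",
    global_config={"working_dir": "./rag_storage", "embedding_batch_num": 4},
    embedding_func=fake_embed,
    meta_fields={"full_doc_id"},
)

async def main():
    await storage.upsert({"chunk-1": {"content": "hello world", "full_doc_id": "doc-1"}})
    print(await storage.query("hello", top_k=1))

asyncio.run(main())
```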
import asyncio
|
||||
import os
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("nano-vectordb"):
|
||||
pm.install("nano-vectordb")
|
||||
|
||||
from nano_vectordb import NanoVectorDB
|
||||
import time
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
compute_mdhash_id,
|
||||
)
|
||||
|
||||
from lightrag.base import (
|
||||
BaseVectorStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NanoVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
self._client = NanoVectorDB(
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
)
|
||||
|
||||
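For reference, a sketch of how the threshold could be overridden through the global config read above; the other keys shown are only illustrative:

```python
global_config = {
    "working_dir": "./rag_storage",
    "embedding_batch_num": 32,
    # Overrides the COSINE_THRESHOLD env / 0.2 default used in this class.
    "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.35},
}
```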
async def upsert(self, data: dict[str, dict]):
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
|
||||
current_time = time.time()
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
"__created_at__": current_time,
|
||||
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
|
||||
async def wrapped_task(batch):
|
||||
result = await self.embedding_func(batch)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
||||
pbar = tqdm_async(
|
||||
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
||||
)
|
||||
embeddings_list = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
if len(embeddings) == len(list_data):
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
results = self._client.upsert(datas=list_data)
|
||||
return results
|
||||
else:
|
||||
# Sometimes the embeddings are not returned correctly; just log the mismatch.
|
||||
logger.error(
|
||||
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k=5):
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
logger.info(
|
||||
f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}"
|
||||
)
|
||||
results = self._client.query(
|
||||
query=embedding,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
results = [
|
||||
{
|
||||
**dp,
|
||||
"id": dp["__id__"],
|
||||
"distance": dp["__metrics__"],
|
||||
"created_at": dp.get("__created_at__"),
|
||||
}
|
||||
for dp in results
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def client_storage(self):
|
||||
return getattr(self._client, "_NanoVectorDB__storage")
|
||||
|
||||
async def delete(self, ids: list[str]):
|
||||
"""Delete vectors with specified IDs
|
||||
|
||||
Args:
|
||||
ids: List of vector IDs to be deleted
|
||||
"""
|
||||
try:
|
||||
self._client.delete(ids)
|
||||
logger.info(
|
||||
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str):
|
||||
try:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
f"Attempting to delete entity {entity_name} with ID {entity_id}"
|
||||
)
|
||||
# Check if the entity exists
|
||||
if self._client.get([entity_id]):
|
||||
await self.delete([entity_id])
|
||||
logger.debug(f"Successfully deleted entity {entity_name}")
|
||||
else:
|
||||
logger.debug(f"Entity {entity_name} not found in storage")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entity {entity_name}: {e}")
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str):
|
||||
try:
|
||||
relations = [
|
||||
dp
|
||||
for dp in self.client_storage["data"]
|
||||
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
|
||||
]
|
||||
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
|
||||
ids_to_delete = [relation["__id__"] for relation in relations]
|
||||
|
||||
if ids_to_delete:
|
||||
await self.delete(ids_to_delete)
|
||||
logger.debug(
|
||||
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No relations found for entity {entity_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
||||
|
||||
async def index_done_callback(self):
|
||||
self._client.save()
|
||||
@@ -1,18 +1,20 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, Tuple, List, Dict
|
||||
import inspect
|
||||
from minirag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("neo4j"):
|
||||
pm.install("neo4j")
|
||||
|
||||
from neo4j import (
|
||||
AsyncGraphDatabase,
|
||||
exceptions as neo4jExceptions,
|
||||
AsyncDriver,
|
||||
AsyncManagedTransaction,
|
||||
GraphDatabase,
|
||||
)
|
||||
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
@@ -20,6 +22,9 @@ from tenacity import (
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
@@ -27,17 +32,63 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
def load_nx_graph(file_name):
|
||||
print("no preloading of graph with neo4j in production")
|
||||
|
||||
def __init__(self, namespace, global_config):
|
||||
super().__init__(namespace=namespace, global_config=global_config)
|
||||
def __init__(self, namespace, global_config, embedding_func):
|
||||
super().__init__(
|
||||
namespace=namespace,
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
self._driver = None
|
||||
self._driver_lock = asyncio.Lock()
|
||||
URI = os.environ["NEO4J_URI"]
|
||||
USERNAME = os.environ["NEO4J_USERNAME"]
|
||||
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
||||
MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
|
||||
DATABASE = os.environ.get(
|
||||
"NEO4J_DATABASE"
|
||||
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
|
||||
self._DATABASE = DATABASE
|
||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||
URI, auth=(USERNAME, PASSWORD)
|
||||
)
|
||||
return None
|
||||
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
|
||||
with GraphDatabase.driver(
|
||||
URI,
|
||||
auth=(USERNAME, PASSWORD),
|
||||
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
||||
) as _sync_driver:
|
||||
try:
|
||||
with _sync_driver.session(database=DATABASE) as session:
|
||||
try:
|
||||
session.run("MATCH (n) RETURN n LIMIT 0")
|
||||
logger.info(f"Connected to {DATABASE} at {URI}")
|
||||
except neo4jExceptions.ServiceUnavailable as e:
|
||||
logger.error(
|
||||
f"{DATABASE} at {URI} is not available".capitalize()
|
||||
)
|
||||
raise e
|
||||
except neo4jExceptions.AuthError as e:
|
||||
logger.error(f"Authentication failed for {DATABASE} at {URI}")
|
||||
raise e
|
||||
except neo4jExceptions.ClientError as e:
|
||||
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
||||
logger.info(
|
||||
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
|
||||
)
|
||||
try:
|
||||
with _sync_driver.session() as session:
|
||||
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
|
||||
logger.info(f"{DATABASE} at {URI} created".capitalize())
|
||||
except neo4jExceptions.ClientError as e:
|
||||
if (
|
||||
e.code
|
||||
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
||||
):
|
||||
logger.warning(
|
||||
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
|
||||
)
|
||||
logger.error(f"Failed to create {DATABASE} at {URI}")
|
||||
raise e
|
||||
|
||||
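The constructor above reads its connection settings from the environment; a sketch of the variables involved, with placeholder values:

```python
import os

os.environ["NEO4J_URI"] = "neo4j://localhost:7687"     # required
os.environ["NEO4J_USERNAME"] = "neo4j"                  # required
os.environ["NEO4J_PASSWORD"] = "password"               # required
os.environ["NEO4J_MAX_CONNECTION_POOL_SIZE"] = "800"    # optional, defaults to 800
os.environ["NEO4J_DATABASE"] = "minirag"                # optional; home database if unset
```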
def __post_init__(self):
|
||||
self._node_embed_algorithms = {
|
||||
@@ -59,7 +110,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
||||
)
|
||||
@@ -74,7 +125,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = (
|
||||
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
||||
"RETURN COUNT(r) > 0 AS edgeExists"
|
||||
@@ -86,11 +137,8 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
return single_result["edgeExists"]
|
||||
|
||||
def close(self):
|
||||
self._driver.close()
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
||||
result = await session.run(query)
|
||||
@@ -107,7 +155,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (n:`{entity_name_label}`)
|
||||
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
||||
@@ -154,7 +202,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
query = f"""
|
||||
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
@@ -185,7 +233,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
query = f"""MATCH (n:`{node_label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected"""
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
results = await session.run(query)
|
||||
edges = []
|
||||
async for record in results:
|
||||
@@ -214,6 +262,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
neo4jExceptions.ServiceUnavailable,
|
||||
neo4jExceptions.TransientError,
|
||||
neo4jExceptions.WriteServiceUnavailable,
|
||||
neo4jExceptions.ClientError,
|
||||
)
|
||||
),
|
||||
)
|
||||
@@ -239,7 +288,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during upsert: {str(e)}")
|
||||
@@ -286,7 +335,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._driver.session() as session:
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
await session.execute_write(_do_upsert_edge)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during edge upsert: {str(e)}")
|
||||
@@ -294,3 +343,177 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def _node2vec_embed(self):
|
||||
print("Implemented but never called.")
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> Dict[str, List[Dict]]:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
|
||||
Key fixes:
|
||||
1. Include the starting node itself
|
||||
2. Handle multi-label nodes
|
||||
3. Clarify relationship directions
|
||||
4. Add depth control
|
||||
"""
|
||||
label = node_label.strip('"')
|
||||
result = {"nodes": [], "edges": []}
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
try:
|
||||
# Critical debug step: first verify if starting node exists
|
||||
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
||||
validate_result = await session.run(validate_query)
|
||||
if not await validate_result.single():
|
||||
logger.warning(f"Starting node {label} does not exist!")
|
||||
return result
|
||||
|
||||
# Optimized query (including direction handling and self-loops)
|
||||
main_query = f"""
|
||||
MATCH (start:`{label}`)
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {{
|
||||
relationshipFilter: '>',
|
||||
minLevel: 0,
|
||||
maxLevel: {max_depth},
|
||||
bfs: true
|
||||
}})
|
||||
YIELD nodes, relationships
|
||||
RETURN nodes, relationships
|
||||
"""
|
||||
result_set = await session.run(main_query)
|
||||
record = await result_set.single()
|
||||
|
||||
if record:
|
||||
# Handle nodes (compatible with multi-label cases)
|
||||
for node in record["nodes"]:
|
||||
# Use node ID + label combination as unique identifier
|
||||
node_id = f"{node.id}_{'_'.join(node.labels)}"
|
||||
if node_id not in seen_nodes:
|
||||
node_data = dict(node)
|
||||
node_data["labels"] = list(node.labels) # Keep all labels
|
||||
result["nodes"].append(node_data)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Handle relationships (including direction information)
|
||||
for rel in record["relationships"]:
|
||||
edge_id = f"{rel.id}_{rel.type}"
|
||||
if edge_id not in seen_edges:
|
||||
start = rel.start_node
|
||||
end = rel.end_node
|
||||
edge_data = dict(rel)
|
||||
edge_data.update(
|
||||
{
|
||||
"source": f"{start.id}_{'_'.join(start.labels)}",
|
||||
"target": f"{end.id}_{'_'.join(end.labels)}",
|
||||
"type": rel.type,
|
||||
"direction": rel.element_id.split(
|
||||
"->" if rel.end_node == end else "<-"
|
||||
)[1],
|
||||
}
|
||||
)
|
||||
result["edges"].append(edge_data)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}"
|
||||
)
|
||||
|
||||
except neo4jExceptions.ClientError as e:
|
||||
logger.error(f"APOC query failed: {str(e)}")
|
||||
return await self._robust_fallback(label, max_depth)
|
||||
|
||||
return result
|
||||
|
||||
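A sketch of how this method might be called and the shape of the result, based on the code above; the label and property names are illustrative:

```python
# subgraph = await storage.get_knowledge_graph('"PERSON"', max_depth=2)
# subgraph == {
#     "nodes": [{"labels": ["PERSON"], ...node properties...}, ...],
#     "edges": [{"source": "<id>_PERSON", "target": "<id>_COMPANY",
#                "type": "WORKS_AT", "direction": ..., ...edge properties...}, ...],
# }
```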
async def _robust_fallback(
|
||||
self, label: str, max_depth: int
|
||||
) -> Dict[str, List[Dict]]:
|
||||
"""Enhanced fallback query solution"""
|
||||
result = {"nodes": [], "edges": []}
|
||||
visited_nodes = set()
|
||||
visited_edges = set()
|
||||
|
||||
async def traverse(current_label: str, current_depth: int):
|
||||
if current_depth > max_depth:
|
||||
return
|
||||
|
||||
# Get current node details
|
||||
node = await self.get_node(current_label)
|
||||
if not node:
|
||||
return
|
||||
|
||||
node_id = f"{current_label}"
|
||||
if node_id in visited_nodes:
|
||||
return
|
||||
visited_nodes.add(node_id)
|
||||
|
||||
# Add node data (with complete labels)
|
||||
node_data = {k: v for k, v in node.items()}
|
||||
node_data["labels"] = [
|
||||
current_label
|
||||
] # Assume get_node method returns label information
|
||||
result["nodes"].append(node_data)
|
||||
|
||||
# Get all outgoing and incoming edges
|
||||
query = f"""
|
||||
MATCH (a)-[r]-(b)
|
||||
WHERE a:`{current_label}` OR b:`{current_label}`
|
||||
RETURN a, r, b,
|
||||
CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
results = await session.run(query)
|
||||
async for record in results:
|
||||
# Handle edges
|
||||
rel = record["r"]
|
||||
edge_id = f"{rel.id}_{rel.type}"
|
||||
if edge_id not in visited_edges:
|
||||
edge_data = dict(rel)
|
||||
edge_data.update(
|
||||
{
|
||||
"source": list(record["a"].labels)[0],
|
||||
"target": list(record["b"].labels)[0],
|
||||
"type": rel.type,
|
||||
"direction": record["direction"],
|
||||
}
|
||||
)
|
||||
result["edges"].append(edge_data)
|
||||
visited_edges.add(edge_id)
|
||||
|
||||
# Recursively traverse adjacent nodes
|
||||
next_label = (
|
||||
list(record["b"].labels)[0]
|
||||
if record["direction"] == "OUTGOING"
|
||||
else list(record["a"].labels)[0]
|
||||
)
|
||||
await traverse(next_label, current_depth + 1)
|
||||
|
||||
await traverse(label, 0)
|
||||
return result
|
||||
|
||||
async def get_all_labels(self) -> List[str]:
|
||||
"""
|
||||
Get all existing node labels in the database
|
||||
Returns:
|
||||
["Person", "Company", ...] # Alphabetically sorted label list
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
# Method 1: Direct metadata query (Available for Neo4j 4.3+)
|
||||
# query = "CALL db.labels() YIELD label RETURN label"
|
||||
|
||||
# Method 2: Query compatible with older versions
|
||||
query = """
|
||||
MATCH (n)
|
||||
WITH DISTINCT labels(n) AS node_labels
|
||||
UNWIND node_labels AS label
|
||||
RETURN DISTINCT label
|
||||
ORDER BY label
|
||||
"""
|
||||
|
||||
result = await session.run(query)
|
||||
labels = []
|
||||
async for record in result:
|
||||
labels.append(record["label"])
|
||||
return labels
|
||||
|
||||
228
minirag/kg/networkx_impl.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
NetworkX Storage Module
|
||||
=======================
|
||||
|
||||
This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks.
|
||||
|
||||
The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX.
|
||||
|
||||
Author: lightrag team
|
||||
Created: 2024-01-25
|
||||
License: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Dependencies:
|
||||
- NetworkX
|
||||
- NumPy
|
||||
- LightRAG
|
||||
- graspologic
|
||||
|
||||
Features:
|
||||
- Load and save graphs in various formats (e.g., GEXF, GraphML, JSON)
|
||||
- Query graph nodes and edges
|
||||
- Calculate node and edge degrees
|
||||
- Embed nodes using various algorithms (e.g., Node2Vec)
|
||||
- Remove nodes and edges from the graph
|
||||
|
||||
Usage:
|
||||
from minirag.kg.networkx_impl import NetworkXStorage
|
||||
|
||||
"""
|
||||
|
||||
import html
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
from lightrag.utils import (
|
||||
logger,
|
||||
)
|
||||
|
||||
from lightrag.base import (
|
||||
BaseGraphStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
def load_nx_graph(file_name) -> nx.Graph:
|
||||
if os.path.exists(file_name):
|
||||
return nx.read_graphml(file_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def write_nx_graph(graph: nx.Graph, file_name):
|
||||
logger.info(
|
||||
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
|
||||
)
|
||||
nx.write_graphml(graph, file_name)
|
||||
|
||||
@staticmethod
|
||||
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
|
||||
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
|
||||
"""
|
||||
from graspologic.utils import largest_connected_component
|
||||
|
||||
graph = graph.copy()
|
||||
graph = cast(nx.Graph, largest_connected_component(graph))
|
||||
node_mapping = {
|
||||
node: html.unescape(node.upper().strip()) for node in graph.nodes()
|
||||
} # type: ignore
|
||||
graph = nx.relabel_nodes(graph, node_mapping)
|
||||
return NetworkXStorage._stabilize_graph(graph)
|
||||
|
||||
@staticmethod
|
||||
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
|
||||
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
||||
Ensure an undirected graph with the same relationships will always be read the same way.
|
||||
"""
|
||||
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
|
||||
|
||||
sorted_nodes = graph.nodes(data=True)
|
||||
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
|
||||
|
||||
fixed_graph.add_nodes_from(sorted_nodes)
|
||||
edges = list(graph.edges(data=True))
|
||||
|
||||
if not graph.is_directed():
|
||||
|
||||
def _sort_source_target(edge):
|
||||
source, target, edge_data = edge
|
||||
if source > target:
|
||||
temp = source
|
||||
source = target
|
||||
target = temp
|
||||
return source, target, edge_data
|
||||
|
||||
edges = [_sort_source_target(edge) for edge in edges]
|
||||
|
||||
def _get_edge_key(source: Any, target: Any) -> str:
|
||||
return f"{source} -> {target}"
|
||||
|
||||
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
|
||||
|
||||
fixed_graph.add_edges_from(edges)
|
||||
return fixed_graph
|
||||
|
||||
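A small illustration of the normalization above: for an undirected graph, both orderings of the same edge collapse to one sorted key, so repeated loads write identical GraphML.

```python
import networkx as nx

g = nx.Graph()
g.add_edge("B", "A", weight=1.0)  # inserted in "reverse" order
stable = NetworkXStorage._stabilize_graph(g)
print(list(stable.edges(data=True)))  # [('A', 'B', {'weight': 1.0})]
```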
def __post_init__(self):
|
||||
self._graphml_xml_file = os.path.join(
|
||||
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
|
||||
)
|
||||
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
|
||||
if preloaded_graph is not None:
|
||||
logger.info(
|
||||
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
|
||||
)
|
||||
self._graph = preloaded_graph or nx.Graph()
|
||||
self._node_embed_algorithms = {
|
||||
"node2vec": self._node2vec_embed,
|
||||
}
|
||||
|
||||
async def index_done_callback(self):
|
||||
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
return self._graph.has_node(node_id)
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
return self._graph.has_edge(source_node_id, target_node_id)
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
return self._graph.nodes.get(node_id)
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
return self._graph.degree(node_id)
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
return self._graph.degree(src_id) + self._graph.degree(tgt_id)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
return self._graph.edges.get((source_node_id, target_node_id))
|
||||
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
if self._graph.has_node(source_node_id):
|
||||
return list(self._graph.edges(source_node_id))
|
||||
return None
|
||||
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def delete_node(self, node_id: str):
|
||||
"""
|
||||
Delete a node from the graph based on the specified node_id.
|
||||
|
||||
:param node_id: The node_id to delete
|
||||
"""
|
||||
if self._graph.has_node(node_id):
|
||||
self._graph.remove_node(node_id)
|
||||
logger.info(f"Node {node_id} deleted from the graph.")
|
||||
else:
|
||||
logger.warning(f"Node {node_id} not found in the graph for deletion.")
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
# @TODO: NOT USED
|
||||
async def _node2vec_embed(self):
|
||||
from graspologic import embed
|
||||
|
||||
embeddings, nodes = embed.node2vec_embed(
|
||||
self._graph,
|
||||
**self.global_config["node2vec_params"],
|
||||
)
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
def remove_nodes(self, nodes: list[str]):
|
||||
"""Delete multiple nodes
|
||||
|
||||
Args:
|
||||
nodes: List of node IDs to be deleted
|
||||
"""
|
||||
for node in nodes:
|
||||
if self._graph.has_node(node):
|
||||
self._graph.remove_node(node)
|
||||
|
||||
def remove_edges(self, edges: list[tuple[str, str]]):
|
||||
"""Delete multiple edges
|
||||
|
||||
Args:
|
||||
edges: List of edges to be deleted, each edge is a (source, target) tuple
|
||||
"""
|
||||
for source, target in edges:
|
||||
if self._graph.has_edge(source, target):
|
||||
self._graph.remove_edge(source, target)
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
# import html
|
||||
@@ -6,6 +7,11 @@ from dataclasses import dataclass
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import array
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("oracledb"):
|
||||
pm.install("oracledb")
|
||||
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import (
|
||||
@@ -114,16 +120,19 @@ class OracleDB:
|
||||
|
||||
logger.info("Finished check all tables in Oracle database")
|
||||
|
||||
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
|
||||
async def query(
|
||||
self, sql: str, params: dict = None, multirows: bool = False
|
||||
) -> Union[dict, None]:
|
||||
async with self.pool.acquire() as connection:
|
||||
connection.inputtypehandler = self.input_type_handler
|
||||
connection.outputtypehandler = self.output_type_handler
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
await cursor.execute(sql)
|
||||
await cursor.execute(sql, params)
|
||||
except Exception as e:
|
||||
logger.error(f"Oracle database error: {e}")
|
||||
print(sql)
|
||||
print(params)
|
||||
raise
|
||||
columns = [column[0].lower() for column in cursor.description]
|
||||
if multirows:
|
||||
@@ -140,7 +149,7 @@ class OracleDB:
|
||||
data = None
|
||||
return data
|
||||
|
||||
async def execute(self, sql: str, data: list = None):
|
||||
async def execute(self, sql: str, data: Union[list, dict] = None):
|
||||
# logger.info("go into OracleDB execute method")
|
||||
try:
|
||||
async with self.pool.acquire() as connection:
|
||||
@@ -150,8 +159,6 @@ class OracleDB:
|
||||
if data is None:
|
||||
await cursor.execute(sql)
|
||||
else:
|
||||
# print(data)
|
||||
# print(sql)
|
||||
await cursor.execute(sql, data)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
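This hunk switches from interpolating values into the SQL text to passing them as bind parameters. A sketch of the intended pattern with python-oracledb named binds; the exact SQL text is illustrative:

```python
SQL = (
    "SELECT id, content FROM LIGHTRAG_DOC_FULL "
    "WHERE workspace = :workspace AND id = :id"
)
params = {"workspace": "default", "id": "doc-1"}
row = await db.query(SQL, params)  # values are bound by name, not string-formatted
```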
@@ -164,34 +171,64 @@ class OracleDB:
|
||||
@dataclass
|
||||
class OracleKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
meta_fields = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
"""根据 id 获取 doc_full 数据."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
|
||||
workspace=self.db.workspace, id=id
|
||||
)
|
||||
"""get doc_full data based on id."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "id": id}
|
||||
# print("get_by_id:"+SQL)
|
||||
res = await self.db.query(SQL)
|
||||
if "llm_response_cache" == self.namespace:
|
||||
array_res = await self.db.query(SQL, params, multirows=True)
|
||||
res = {}
|
||||
for row in array_res:
|
||||
res[row["id"]] = row
|
||||
else:
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
data = res # {"data":res}
|
||||
# print (data)
|
||||
return data
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
|
||||
if "llm_response_cache" == self.namespace:
|
||||
array_res = await self.db.query(SQL, params, multirows=True)
|
||||
res = {}
|
||||
for row in array_res:
|
||||
res[row["id"]] = row
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||
"""根据 id 获取 doc_chunks 数据"""
|
||||
"""get doc_chunks data based on id"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
params = {"workspace": self.db.workspace}
|
||||
# print("get_by_ids:"+SQL)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
res = await self.db.query(SQL, params, multirows=True)
|
||||
if "llm_response_cache" == self.namespace:
|
||||
modes = set()
|
||||
dict_res: dict[str, dict] = {}
|
||||
for row in res:
|
||||
modes.add(row["mode"])
|
||||
for mode in modes:
|
||||
if mode not in dict_res:
|
||||
dict_res[mode] = {}
|
||||
for row in res:
|
||||
dict_res[row["mode"]][row["id"]] = row
|
||||
res = [{k: v} for k, v in dict_res.items()]
|
||||
if res:
|
||||
data = res # [{"data":i} for i in res]
|
||||
# print(data)
|
||||
@@ -199,33 +236,43 @@ class OracleKVStorage(BaseKVStorage):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str, ids: list[str]
|
||||
) -> Union[list[dict], None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
if ids is not None:
|
||||
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
else:
|
||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
res = await self.db.query(SQL, params, multirows=True)
|
||||
if res:
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""过滤掉重复内容"""
|
||||
"""Return keys that don't exist in storage"""
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||
table_name=N_T[self.namespace],
|
||||
workspace=self.db.workspace,
|
||||
ids=",".join([f"'{k}'" for k in keys]),
|
||||
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
|
||||
)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
data = None
|
||||
params = {"workspace": self.db.workspace}
|
||||
res = await self.db.query(SQL, params, multirows=True)
|
||||
if res:
|
||||
exist_keys = [key["id"] for key in res]
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
return data
|
||||
else:
|
||||
exist_keys = []
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
return data
|
||||
return set(keys)
|
||||
|
||||
################ INSERT METHODS ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
# print(self._data)
|
||||
# values = []
|
||||
if self.namespace == "text_chunks":
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
"id": k,
|
||||
**{k1: v1 for k1, v1 in v.items()},
|
||||
}
|
||||
for k, v in data.items()
|
||||
@@ -241,32 +288,50 @@ class OracleKVStorage(BaseKVStorage):
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
# print(list_data)
|
||||
|
||||
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
||||
for item in list_data:
|
||||
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
|
||||
|
||||
values = [
|
||||
item["__id__"],
|
||||
item["content"],
|
||||
self.db.workspace,
|
||||
item["tokens"],
|
||||
item["chunk_order_index"],
|
||||
item["full_doc_id"],
|
||||
item["__vector__"],
|
||||
]
|
||||
# print(merge_sql)
|
||||
await self.db.execute(merge_sql, values)
|
||||
|
||||
_data = {
|
||||
"id": item["id"],
|
||||
"content": item["content"],
|
||||
"workspace": self.db.workspace,
|
||||
"tokens": item["tokens"],
|
||||
"chunk_order_index": item["chunk_order_index"],
|
||||
"full_doc_id": item["full_doc_id"],
|
||||
"content_vector": item["__vector__"],
|
||||
"status": item["status"],
|
||||
}
|
||||
await self.db.execute(merge_sql, _data)
|
||||
if self.namespace == "full_docs":
|
||||
for k, v in self._data.items():
|
||||
for k, v in data.items():
|
||||
# values.clear()
|
||||
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
||||
check_id=k,
|
||||
)
|
||||
values = [k, self._data[k]["content"], self.db.workspace]
|
||||
# print(merge_sql)
|
||||
await self.db.execute(merge_sql, values)
|
||||
return left_data
|
||||
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
||||
_data = {
|
||||
"id": k,
|
||||
"content": v["content"],
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
await self.db.execute(merge_sql, _data)
|
||||
|
||||
if self.namespace == "llm_response_cache":
|
||||
for mode, items in data.items():
|
||||
for k, v in items.items():
|
||||
upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
|
||||
_data = {
|
||||
"workspace": self.db.workspace,
|
||||
"id": k,
|
||||
"original_prompt": v["original_prompt"],
|
||||
"return_value": v["return"],
|
||||
"cache_mode": mode,
|
||||
}
|
||||
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
return None
|
||||
|
||||
async def change_status(self, id: str, status: str):
|
||||
SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
|
||||
params = {"workspace": self.db.workspace, "id": id, "status": status}
|
||||
await self.db.execute(SQL, params)
|
||||
|
||||
async def index_done_callback(self):
|
||||
if self.namespace in ["full_docs", "text_chunks"]:
|
||||
@@ -275,10 +340,16 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = 0.2
|
||||
# should pass db object to self.db
|
||||
db: OracleDB = None
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""向向量数据库中插入数据"""
|
||||
@@ -295,18 +366,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# Convert precision
|
||||
dtype = str(embedding.dtype).upper()
|
||||
dimension = embedding.shape[0]
|
||||
embedding_string = ", ".join(map(str, embedding.tolist()))
|
||||
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
||||
|
||||
SQL = SQL_TEMPLATES[self.namespace].format(
|
||||
embedding_string=embedding_string,
|
||||
dimension=dimension,
|
||||
dtype=dtype,
|
||||
workspace=self.db.workspace,
|
||||
top_k=top_k,
|
||||
better_than_threshold=self.cosine_better_than_threshold,
|
||||
)
|
||||
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
||||
params = {
|
||||
"embedding_string": embedding_string,
|
||||
"workspace": self.db.workspace,
|
||||
"top_k": top_k,
|
||||
"better_than_threshold": self.cosine_better_than_threshold,
|
||||
}
|
||||
# print(SQL)
|
||||
results = await self.db.query(SQL, multirows=True)
|
||||
results = await self.db.query(SQL, params=params, multirows=True)
|
||||
# print("vector search result:",results)
|
||||
return results
|
||||
|
||||
@@ -317,7 +387,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
def __post_init__(self):
|
||||
"""从graphml文件加载图"""
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
|
||||
|
||||
#################### insert method ################
|
||||
|
||||
@@ -328,6 +398,8 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
source_id = node_data["source_id"]
|
||||
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
|
||||
|
||||
content = entity_name + description
|
||||
contents = [content]
|
||||
batches = [
|
||||
@@ -339,22 +411,17 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
||||
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
|
||||
)
|
||||
# print(merge_sql)
|
||||
await self.db.execute(
|
||||
merge_sql,
|
||||
[
|
||||
self.db.workspace,
|
||||
entity_name,
|
||||
entity_type,
|
||||
description,
|
||||
source_id,
|
||||
content,
|
||||
content_vector,
|
||||
],
|
||||
)
|
||||
merge_sql = SQL_TEMPLATES["merge_node"]
|
||||
data = {
|
||||
"workspace": self.db.workspace,
|
||||
"name": entity_name,
|
||||
"entity_type": entity_type,
|
||||
"description": description,
|
||||
"source_chunk_id": source_id,
|
||||
"content": content,
|
||||
"content_vector": content_vector,
|
||||
}
|
||||
await self.db.execute(merge_sql, data)
|
||||
# self._graph.add_node(node_id, **node_data)
|
||||
|
||||
async def upsert_edge(
|
||||
@@ -368,6 +435,10 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
logger.debug(
|
||||
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
||||
)
|
||||
|
||||
content = keywords + source_name + target_name + description
|
||||
contents = [content]
|
||||
batches = [
|
||||
@@ -379,27 +450,20 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
||||
workspace=self.db.workspace,
|
||||
source_name=source_name,
|
||||
target_name=target_name,
|
||||
source_chunk_id=source_chunk_id,
|
||||
)
|
||||
merge_sql = SQL_TEMPLATES["merge_edge"]
|
||||
data = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_name": source_name,
|
||||
"target_name": target_name,
|
||||
"weight": weight,
|
||||
"keywords": keywords,
|
||||
"description": description,
|
||||
"source_chunk_id": source_chunk_id,
|
||||
"content": content,
|
||||
"content_vector": content_vector,
|
||||
}
|
||||
# print(merge_sql)
|
||||
await self.db.execute(
|
||||
merge_sql,
|
||||
[
|
||||
self.db.workspace,
|
||||
source_name,
|
||||
target_name,
|
||||
weight,
|
||||
keywords,
|
||||
description,
|
||||
source_chunk_id,
|
||||
content,
|
||||
content_vector,
|
||||
],
|
||||
)
|
||||
await self.db.execute(merge_sql, data)
|
||||
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
@@ -429,12 +493,11 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
#################### query method #################
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
"""根据节点id检查节点是否存在"""
|
||||
SQL = SQL_TEMPLATES["has_node"].format(
|
||||
workspace=self.db.workspace, node_id=node_id
|
||||
)
|
||||
SQL = SQL_TEMPLATES["has_node"]
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(SQL)
|
||||
# print(self.db.workspace, node_id)
|
||||
res = await self.db.query(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
# print("Node exist!",res)
|
||||
return True
|
||||
@@ -444,13 +507,14 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
"""根据源和目标节点id检查边是否存在"""
|
||||
SQL = SQL_TEMPLATES["has_edge"].format(
|
||||
workspace=self.db.workspace,
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id,
|
||||
)
|
||||
SQL = SQL_TEMPLATES["has_edge"]
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_node_id": source_node_id,
|
||||
"target_node_id": target_node_id,
|
||||
}
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
# print("Edge exist!",res)
|
||||
return True
|
||||
@@ -460,11 +524,10 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
"""根据节点id获取节点的度"""
|
||||
SQL = SQL_TEMPLATES["node_degree"].format(
|
||||
workspace=self.db.workspace, node_id=node_id
|
||||
)
|
||||
SQL = SQL_TEMPLATES["node_degree"]
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
# print("Node degree",res["degree"])
|
||||
return res["degree"]
|
||||
@@ -480,12 +543,11 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
"""根据节点id获取节点数据"""
|
||||
SQL = SQL_TEMPLATES["get_node"].format(
|
||||
workspace=self.db.workspace, node_id=node_id
|
||||
)
|
||||
SQL = SQL_TEMPLATES["get_node"]
|
||||
params = {"workspace": self.db.workspace, "node_id": node_id}
|
||||
# print(self.db.workspace, node_id)
|
||||
# print(SQL)
|
||||
res = await self.db.query(SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
# print("Get node!",self.db.workspace, node_id,res)
|
||||
return res
|
||||
@@ -497,12 +559,13 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
"""根据源和目标节点id获取边"""
|
||||
SQL = SQL_TEMPLATES["get_edge"].format(
|
||||
workspace=self.db.workspace,
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id,
|
||||
)
|
||||
res = await self.db.query(SQL)
|
||||
SQL = SQL_TEMPLATES["get_edge"]
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_node_id": source_node_id,
|
||||
"target_node_id": target_node_id,
|
||||
}
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
||||
return res
|
||||
@@ -513,10 +576,9 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
async def get_node_edges(self, source_node_id: str):
|
||||
"""根据节点id获取节点的所有边"""
|
||||
if await self.has_node(source_node_id):
|
||||
SQL = SQL_TEMPLATES["get_node_edges"].format(
|
||||
workspace=self.db.workspace, source_node_id=source_node_id
|
||||
)
|
||||
res = await self.db.query(sql=SQL, multirows=True)
|
||||
SQL = SQL_TEMPLATES["get_node_edges"]
|
||||
params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
|
||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||
if res:
|
||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
||||
@@ -525,6 +587,29 @@ class OracleGraphStorage(BaseGraphStorage):
|
||||
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
||||
return []
|
||||
|
||||
async def get_all_nodes(self, limit: int):
|
||||
"""查询所有节点"""
|
||||
SQL = SQL_TEMPLATES["get_all_nodes"]
|
||||
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||
if res:
|
||||
return res
|
||||
|
||||
async def get_all_edges(self, limit: int):
|
||||
"""查询所有边"""
|
||||
SQL = SQL_TEMPLATES["get_all_edges"]
|
||||
params = {"workspace": self.db.workspace, "limit": str(limit)}
|
||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||
if res:
|
||||
return res
|
||||
|
||||
async def get_statistics(self):
|
||||
SQL = SQL_TEMPLATES["get_statistics"]
|
||||
params = {"workspace": self.db.workspace}
|
||||
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
||||
if res:
|
||||
return res
|
||||
|
||||
|
||||
N_T = {
|
||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||
@@ -537,20 +622,26 @@ N_T = {
|
||||
TABLES = {
|
||||
"LIGHTRAG_DOC_FULL": {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||
id varchar(256)PRIMARY KEY,
|
||||
id varchar(256),
|
||||
workspace varchar(1024),
|
||||
doc_name varchar(1024),
|
||||
content CLOB,
|
||||
meta JSON,
|
||||
content_summary varchar(1024),
|
||||
content_length NUMBER,
|
||||
status varchar(256),
|
||||
chunks_count NUMBER,
|
||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updatetime TIMESTAMP DEFAULT NULL
|
||||
updatetime TIMESTAMP DEFAULT NULL,
|
||||
error varchar(4096)
|
||||
)"""
|
||||
},
|
||||
"LIGHTRAG_DOC_CHUNKS": {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||
id varchar(256) PRIMARY KEY,
|
||||
id varchar(256),
|
||||
workspace varchar(1024),
|
||||
full_doc_id varchar(256),
|
||||
status varchar(256),
|
||||
chunk_order_index NUMBER,
|
||||
tokens NUMBER,
|
||||
content CLOB,
|
||||
@@ -592,9 +683,15 @@ TABLES = {
|
||||
"LIGHTRAG_LLM_CACHE": {
|
||||
"ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
|
||||
id varchar(256) PRIMARY KEY,
|
||||
send clob,
|
||||
return clob,
|
||||
model varchar(1024),
|
||||
workspace varchar(1024),
|
||||
cache_mode varchar(256),
|
||||
model_name varchar(256),
|
||||
original_prompt clob,
|
||||
return_value clob,
|
||||
embedding CLOB,
|
||||
embedding_shape NUMBER,
|
||||
embedding_min NUMBER,
|
||||
embedding_max NUMBER,
|
||||
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updatetime TIMESTAMP DEFAULT NULL
|
||||
)"""
|
||||
@@ -619,82 +716,141 @@ TABLES = {
|
||||
|
||||
SQL_TEMPLATES = {
|
||||
# SQL for KVStorage
|
||||
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
|
||||
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
|
||||
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
|
||||
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
|
||||
"filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
|
||||
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
||||
USING DUAL
|
||||
ON (a.id = '{check_id}')
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(id,content,workspace) values(:1,:2,:3)
|
||||
""",
|
||||
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
||||
USING DUAL
|
||||
ON (a.id = '{check_id}')
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
||||
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
|
||||
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
|
||||
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
|
||||
"get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
|
||||
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
|
||||
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
|
||||
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
|
||||
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
|
||||
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
|
||||
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
|
||||
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
|
||||
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
|
||||
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
||||
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
|
||||
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
|
||||
USING DUAL
|
||||
ON (a.id = :id and a.workspace = :workspace)
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
|
||||
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
|
||||
USING DUAL
|
||||
ON (id = :id and workspace = :workspace)
|
||||
WHEN NOT MATCHED THEN INSERT
|
||||
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
|
||||
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
|
||||
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
|
||||
USING DUAL
|
||||
ON (a.id = :id)
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT (workspace,id,original_prompt,return_value,cache_mode)
|
||||
VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
|
||||
WHEN MATCHED THEN UPDATE
|
||||
SET original_prompt = :original_prompt,
|
||||
return_value = :return_value,
|
||||
cache_mode = :cache_mode,
|
||||
updatetime = SYSDATE""",
|
||||
# SQL for VectorStorage
|
||||
"entities": """SELECT name as entity_name FROM
|
||||
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
|
||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
|
||||
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
||||
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
||||
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
|
||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
|
||||
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
||||
"chunks": """SELECT id FROM
|
||||
(SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
|
||||
WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
|
||||
(SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
|
||||
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
||||
# SQL for GraphStorage
|
||||
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)
|
||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
||||
WHERE a.workspace=:workspace AND a.name=:node_id
|
||||
COLUMNS (a.name))""",
|
||||
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a) -[e]-> (b)
|
||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{source_node_id}' AND b.name='{target_node_id}'
|
||||
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||
AND a.name=:source_node_id AND b.name=:target_node_id
|
||||
COLUMNS (e.source_name,e.target_name) )""",
|
||||
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b)
|
||||
WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{node_id}' or b.name = '{node_id}'
|
||||
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||
AND a.name=:node_id or b.name = :node_id
|
||||
COLUMNS (a.name))""",
|
||||
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
||||
FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)
|
||||
WHERE a.workspace='{workspace}' AND a.name='{node_id}'
|
||||
WHERE a.workspace=:workspace AND a.name=:node_id
|
||||
COLUMNS (a.name)
|
||||
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
||||
WHERE t2.workspace='{workspace}'""",
|
||||
WHERE t2.workspace=:workspace""",
|
||||
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
||||
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
||||
FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b)
|
||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{source_node_id}' and b.name = '{target_node_id}'
|
||||
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||
AND a.name=:source_node_id and b.name = :target_node_id
|
||||
COLUMNS (e.id,a.name as source_id)
|
||||
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
||||
"get_node_edges": """SELECT source_name,target_name
|
||||
FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b)
|
||||
WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
|
||||
AND a.name='{source_node_id}'
|
||||
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
||||
AND a.name=:source_node_id
|
||||
COLUMNS (a.name as source_name,b.name as target_name))""",
|
||||
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
||||
USING DUAL
|
||||
ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
|
||||
ON (a.workspace=:workspace and a.name=:name)
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
||||
values (:1,:2,:3,:4,:5,:6,:7) """,
|
||||
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
|
||||
WHEN MATCHED THEN
|
||||
UPDATE SET
|
||||
entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
|
||||
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
||||
USING DUAL
|
||||
ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
|
||||
ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
|
||||
WHEN NOT MATCHED THEN
|
||||
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
||||
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
|
||||
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
|
||||
WHEN MATCHED THEN
|
||||
UPDATE SET
|
||||
weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
|
||||
"get_all_nodes": """WITH t0 AS (
|
||||
SELECT name AS id, entity_type AS label, entity_type, description,
|
||||
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
|
||||
FROM lightrag_graph_nodes
|
||||
WHERE workspace = :workspace
|
||||
ORDER BY createtime DESC fetch first :limit rows only
|
||||
), t1 AS (
|
||||
SELECT t0.id, source_chunk_id
|
||||
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
|
||||
), t2 AS (
|
||||
SELECT t1.id, LISTAGG(t2.content, '\n') content
|
||||
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
|
||||
GROUP BY t1.id
|
||||
)
|
||||
SELECT t0.id, label, entity_type, description, t2.content
|
||||
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
|
||||
"get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
||||
t1.weight,t1.DESCRIPTION,t2.content
|
||||
FROM LIGHTRAG_GRAPH_EDGES t1
|
||||
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
||||
WHERE t1.workspace=:workspace
|
||||
order by t1.CREATETIME DESC
|
||||
fetch first :limit rows only""",
|
||||
"get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
|
||||
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
|
||||
FROM (
|
||||
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
|
||||
UNION
|
||||
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
|
||||
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
|
||||
)""",
|
||||
}
|
||||
|
||||
1182
minirag/kg/postgres_impl.py
Normal file
File diff suppressed because it is too large
136
minirag/kg/postgres_impl_test.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("psycopg-pool"):
|
||||
pm.install("psycopg-pool")
|
||||
pm.install("psycopg[binary,pool]")
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
import asyncpg
|
||||
import psycopg
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
|
||||
|
||||
DB = "rag"
|
||||
USER = "rag"
|
||||
PASSWORD = "rag"
|
||||
HOST = "localhost"
|
||||
PORT = "15432"
|
||||
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
import asyncio.windows_events
|
||||
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
async def get_pool():
|
||||
return await asyncpg.create_pool(
|
||||
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
|
||||
min_size=10,
|
||||
max_size=10,
|
||||
max_queries=5000,
|
||||
max_inactive_connection_lifetime=300.0,
|
||||
)
|
||||
|
||||
|
||||
async def main1():
|
||||
connection_string = (
|
||||
f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
|
||||
)
|
||||
pool = AsyncConnectionPool(connection_string, open=False)
|
||||
await pool.open()
|
||||
|
||||
try:
|
||||
conn = await pool.getconn(timeout=10)
|
||||
async with conn.cursor() as curs:
|
||||
try:
|
||||
await curs.execute('SET search_path = ag_catalog, "$user", public')
|
||||
await curs.execute("SELECT create_graph('dickens-2')")
|
||||
await conn.commit()
|
||||
print("create_graph success")
|
||||
except (
|
||||
psycopg.errors.InvalidSchemaName,
|
||||
psycopg.errors.UniqueViolation,
|
||||
):
|
||||
print("create_graph already exists")
|
||||
await conn.rollback()
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
db = PostgreSQLDB(
|
||||
config={
|
||||
"host": "localhost",
|
||||
"port": 15432,
|
||||
"user": "rag",
|
||||
"password": "rag",
|
||||
"database": "r1",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def query_with_age():
|
||||
await db.initdb()
|
||||
graph = PGGraphStorage(
|
||||
namespace="chunk_entity_relation",
|
||||
global_config={},
|
||||
embedding_func=None,
|
||||
)
|
||||
graph.db = db
|
||||
res = await graph.get_node('"A CHRISTMAS CAROL"')
|
||||
print("Node is: ", res)
|
||||
res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
|
||||
print("Edge is: ", res)
|
||||
res = await graph.get_node_edges('"SCROOGE"')
|
||||
print("Node Edges are: ", res)
|
||||
|
||||
|
||||
async def create_edge_with_age():
|
||||
await db.initdb()
|
||||
graph = PGGraphStorage(
|
||||
namespace="chunk_entity_relation",
|
||||
global_config={},
|
||||
embedding_func=None,
|
||||
)
|
||||
graph.db = db
|
||||
await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
|
||||
await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
|
||||
await graph.upsert_edge(
|
||||
'"THE CRATCHITS"',
|
||||
'"THE GIRLS"',
|
||||
edge_data={
|
||||
"weight": 7.0,
|
||||
"description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
|
||||
"keywords": '"family, collective effort"',
|
||||
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
|
||||
},
|
||||
)
|
||||
res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
|
||||
print("Edge is: ", res)
|
||||
|
||||
|
||||
async def main():
|
||||
pool = await get_pool()
|
||||
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
|
||||
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
|
||||
async with pool.acquire() as conn:
|
||||
try:
|
||||
await conn.execute(
|
||||
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
|
||||
)
|
||||
except asyncpg.exceptions.InvalidSchemaNameError:
|
||||
print("create_graph already exists")
|
||||
# stmt = await conn.prepare(sql)
|
||||
row = await conn.fetch(sql)
|
||||
print("row is: ", row)
|
||||
|
||||
row = await conn.fetchrow("select '100'::int + 200 as result")
|
||||
print(row) # <Record result=300>
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(query_with_age())
|
||||
70
minirag/kg/redis_impl.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("redis"):
|
||||
pm.install("redis")
|
||||
|
||||
# aioredis is a deprecated library, replaced with redis
|
||||
from redis.asyncio import Redis
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseKVStorage
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
|
||||
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
||||
logger.info(f"Use Redis as KV {self.namespace}")
|
||||
|
||||
async def all_keys(self) -> list[str]:
|
||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||
return [key.split(":", 1)[-1] for key in keys]
|
||||
|
||||
async def get_by_id(self, id):
|
||||
data = await self._redis.get(f"{self.namespace}:{id}")
|
||||
return json.loads(data) if data else None
|
||||
|
||||
async def get_by_ids(self, ids, fields=None):
|
||||
pipe = self._redis.pipeline()
|
||||
for id in ids:
|
||||
pipe.get(f"{self.namespace}:{id}")
|
||||
results = await pipe.execute()
|
||||
|
||||
if fields:
|
||||
# Filter fields if specified
|
||||
return [
|
||||
{field: value.get(field) for field in fields if field in value}
|
||||
if (value := json.loads(result))
|
||||
else None
|
||||
for result in results
|
||||
]
|
||||
|
||||
return [json.loads(result) if result else None for result in results]
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
pipe = self._redis.pipeline()
|
||||
for key in data:
|
||||
pipe.exists(f"{self.namespace}:{key}")
|
||||
results = await pipe.execute()
|
||||
|
||||
existing_ids = {data[i] for i, exists in enumerate(results) if exists}
|
||||
return set(data) - existing_ids
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
pipe = self._redis.pipeline()
|
||||
for k, v in tqdm_async(data.items(), desc="Upserting"):
|
||||
pipe.set(f"{self.namespace}:{k}", json.dumps(v))
|
||||
await pipe.execute()
|
||||
|
||||
for k in data:
|
||||
data[k]["_id"] = k
|
||||
return data
|
||||
|
||||
async def drop(self):
|
||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||
if keys:
|
||||
await self._redis.delete(*keys)
|
||||
665
minirag/kg/tidb_impl.py
Normal file
@@ -0,0 +1,665 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("pymysql"):
|
||||
pm.install("pymysql")
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from tqdm import tqdm
|
||||
|
||||
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
|
||||
from lightrag.utils import logger
|
||||
|
||||
|
||||
class TiDB(object):
|
||||
def __init__(self, config, **kwargs):
|
||||
self.host = config.get("host", None)
|
||||
self.port = config.get("port", None)
|
||||
self.user = config.get("user", None)
|
||||
self.password = config.get("password", None)
|
||||
self.database = config.get("database", None)
|
||||
self.workspace = config.get("workspace", None)
|
||||
connection_string = (
|
||||
f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
f"?ssl_verify_cert=true&ssl_verify_identity=true"
|
||||
)
|
||||
|
||||
try:
|
||||
self.engine = create_engine(connection_string)
|
||||
logger.info(f"Connected to TiDB database at {self.database}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to TiDB database at {self.database}")
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
raise
|
||||
|
||||
async def check_tables(self):
|
||||
for k, v in TABLES.items():
|
||||
try:
|
||||
await self.query(f"SELECT 1 FROM {k}".format(k=k))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check table {k} in TiDB database")
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
try:
|
||||
# print(v["ddl"])
|
||||
await self.execute(v["ddl"])
|
||||
logger.info(f"Created table {k} in TiDB database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create table {k} in TiDB database")
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
|
||||
async def query(
|
||||
self, sql: str, params: dict = None, multirows: bool = False
|
||||
) -> Union[dict, None]:
|
||||
if params is None:
|
||||
params = {"workspace": self.workspace}
|
||||
else:
|
||||
params.update({"workspace": self.workspace})
|
||||
with self.engine.connect() as conn, conn.begin():
|
||||
try:
|
||||
result = conn.execute(text(sql), params)
|
||||
except Exception as e:
|
||||
logger.error(f"Tidb database error: {e}")
|
||||
print(sql)
|
||||
print(params)
|
||||
raise
|
||||
if multirows:
|
||||
rows = result.all()
|
||||
if rows:
|
||||
data = [dict(zip(result.keys(), row)) for row in rows]
|
||||
else:
|
||||
data = []
|
||||
else:
|
||||
row = result.first()
|
||||
if row:
|
||||
data = dict(zip(result.keys(), row))
|
||||
else:
|
||||
data = None
|
||||
return data
|
||||
|
||||
async def execute(self, sql: str, data: list | dict = None):
|
||||
# logger.info("go into TiDBDB execute method")
|
||||
try:
|
||||
with self.engine.connect() as conn, conn.begin():
|
||||
if data is None:
|
||||
conn.execute(text(sql))
|
||||
else:
|
||||
conn.execute(text(sql), parameters=data)
|
||||
except Exception as e:
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
print(sql)
|
||||
print(data)
|
||||
raise
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
"""根据 id 获取 doc_full 数据."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"id": id}
|
||||
# print("get_by_id:"+SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
data = res # {"data":res}
|
||||
# print (data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||
"""根据 id 获取 doc_chunks 数据"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
# print("get_by_ids:"+SQL)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
if res:
|
||||
data = res # [{"data":i} for i in res]
|
||||
# print(data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""过滤掉重复内容"""
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||
table_name=N_T[self.namespace],
|
||||
id_field=N_ID[self.namespace],
|
||||
ids=",".join([f"'{id}'" for id in keys]),
|
||||
)
|
||||
try:
|
||||
await self.db.query(SQL)
|
||||
except Exception as e:
|
||||
logger.error(f"Tidb database error: {e}")
|
||||
print(SQL)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
if res:
|
||||
exist_keys = [key["id"] for key in res]
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
else:
|
||||
exist_keys = []
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
return data
|
||||
|
||||
################ INSERT full_doc AND chunks ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
if self.namespace == "text_chunks":
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
**{k1: v1 for k1, v1 in v.items()},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
|
||||
merge_sql = SQL_TEMPLATES["upsert_chunk"]
|
||||
data = []
|
||||
for item in list_data:
|
||||
data.append(
|
||||
{
|
||||
"id": item["__id__"],
|
||||
"content": item["content"],
|
||||
"tokens": item["tokens"],
|
||||
"chunk_order_index": item["chunk_order_index"],
|
||||
"full_doc_id": item["full_doc_id"],
|
||||
"content_vector": f"{item["__vector__"].tolist()}",
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
)
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
if self.namespace == "full_docs":
|
||||
merge_sql = SQL_TEMPLATES["upsert_doc_full"]
|
||||
data = []
|
||||
for k, v in self._data.items():
|
||||
data.append(
|
||||
{
|
||||
"id": k,
|
||||
"content": v["content"],
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
)
|
||||
await self.db.execute(merge_sql, data)
|
||||
return left_data
|
||||
|
||||
async def index_done_callback(self):
|
||||
if self.namespace in ["full_docs", "text_chunks"]:
|
||||
logger.info("full doc and chunk data had been saved into TiDB db!")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||
"""search from tidb vector"""
|
||||
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
|
||||
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
||||
|
||||
params = {
|
||||
"embedding_string": embedding_string,
|
||||
"top_k": top_k,
|
||||
"better_than_threshold": self.cosine_better_than_threshold,
|
||||
}
|
||||
|
||||
results = await self.db.query(
|
||||
SQL_TEMPLATES[self.namespace], params=params, multirows=True
|
||||
)
|
||||
print("vector search result:", results)
|
||||
if not results:
|
||||
return []
|
||||
return results
|
||||
|
||||
###### INSERT entities And relationships ######
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
# chunks are ignored here; they are already upserted by TiDBKVStorage
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
if self.namespace == "chunks":
|
||||
return []
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
|
||||
list_data = [
|
||||
{
|
||||
"id": k,
|
||||
**{k1: v1 for k1, v1 in v.items()},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = []
|
||||
for f in tqdm(
|
||||
asyncio.as_completed(embedding_tasks),
|
||||
total=len(embedding_tasks),
|
||||
desc="Generating embeddings",
|
||||
unit="batch",
|
||||
):
|
||||
embeddings = await f
|
||||
embeddings_list.append(embeddings)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["content_vector"] = embeddings[i]
|
||||
|
||||
if self.namespace == "entities":
|
||||
data = []
|
||||
for item in list_data:
|
||||
param = {
|
||||
"id": item["id"],
|
||||
"name": item["entity_name"],
|
||||
"content": item["content"],
|
||||
"content_vector": f"{item["content_vector"].tolist()}",
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# update entity_id if node inserted by graph_storage_instance before
|
||||
has = await self.db.query(SQL_TEMPLATES["has_entity"], param)
|
||||
if has["cnt"] != 0:
|
||||
await self.db.execute(SQL_TEMPLATES["update_entity"], param)
|
||||
continue
|
||||
|
||||
data.append(param)
|
||||
if data:
|
||||
merge_sql = SQL_TEMPLATES["insert_entity"]
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
elif self.namespace == "relationships":
|
||||
data = []
|
||||
for item in list_data:
|
||||
param = {
|
||||
"id": item["id"],
|
||||
"source_name": item["src_id"],
|
||||
"target_name": item["tgt_id"],
|
||||
"content": item["content"],
|
||||
"content_vector": f"{item["content_vector"].tolist()}",
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# update relation_id if node inserted by graph_storage_instance before
|
||||
has = await self.db.query(SQL_TEMPLATES["has_relationship"], param)
|
||||
if has["cnt"] != 0:
|
||||
await self.db.execute(SQL_TEMPLATES["update_relationship"], param)
|
||||
continue
|
||||
|
||||
data.append(param)
|
||||
if data:
|
||||
merge_sql = SQL_TEMPLATES["insert_relationship"]
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
#################### upsert method ################
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
source_id = node_data["source_id"]
|
||||
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
|
||||
content = entity_name + description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
sql = SQL_TEMPLATES["upsert_node"]
|
||||
data = {
|
||||
"workspace": self.db.workspace,
|
||||
"name": entity_name,
|
||||
"entity_type": entity_type,
|
||||
"description": description,
|
||||
"source_chunk_id": source_id,
|
||||
"content": content,
|
||||
"content_vector": f"{content_vector.tolist()}",
|
||||
}
|
||||
await self.db.execute(sql, data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
logger.debug(
|
||||
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
||||
)
|
||||
|
||||
content = keywords + source_name + target_name + description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["upsert_edge"]
|
||||
data = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_name": source_name,
|
||||
"target_name": target_name,
|
||||
"weight": weight,
|
||||
"keywords": keywords,
|
||||
"description": description,
|
||||
"source_chunk_id": source_chunk_id,
|
||||
"content": content,
|
||||
"content_vector": f"{content_vector.tolist()}",
|
||||
}
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
# Query
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
sql = SQL_TEMPLATES["has_entity"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
has = await self.db.query(sql, param)
|
||||
return has["cnt"] != 0
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
sql = SQL_TEMPLATES["has_relationship"]
|
||||
param = {
|
||||
"source_name": source_node_id,
|
||||
"target_name": target_node_id,
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
has = await self.db.query(sql, param)
|
||||
return has["cnt"] != 0
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
sql = SQL_TEMPLATES["node_degree"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
result = await self.db.query(sql, param)
|
||||
return result["cnt"]
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||
return degree
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
sql = SQL_TEMPLATES["get_node"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
sql = SQL_TEMPLATES["get_edge"]
|
||||
param = {
|
||||
"source_name": source_node_id,
|
||||
"target_name": target_node_id,
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> Union[list[tuple[str, str]], None]:
|
||||
sql = SQL_TEMPLATES["get_node_edges"]
|
||||
param = {"source_name": source_node_id, "workspace": self.db.workspace}
|
||||
res = await self.db.query(sql, param, multirows=True)
|
||||
if res:
|
||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||
return data
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
N_T = {
|
||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"entities": "LIGHTRAG_GRAPH_NODES",
|
||||
"relationships": "LIGHTRAG_GRAPH_EDGES",
|
||||
}
|
||||
N_ID = {
|
||||
"full_docs": "doc_id",
|
||||
"text_chunks": "chunk_id",
|
||||
"chunks": "chunk_id",
|
||||
"entities": "entity_id",
|
||||
"relationships": "relation_id",
|
||||
}
|
||||
|
||||
TABLES = {
|
||||
"LIGHTRAG_DOC_FULL": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`doc_id` VARCHAR(256) NOT NULL,
|
||||
`workspace` varchar(1024),
|
||||
`content` LONGTEXT,
|
||||
`meta` JSON,
|
||||
`createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` TIMESTAMP DEFAULT NULL,
|
||||
UNIQUE KEY (`doc_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_DOC_CHUNKS": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`chunk_id` VARCHAR(256) NOT NULL,
|
||||
`full_doc_id` VARCHAR(256) NOT NULL,
|
||||
`workspace` varchar(1024),
|
||||
`chunk_order_index` INT,
|
||||
`tokens` INT,
|
||||
`content` LONGTEXT,
|
||||
`content_vector` VECTOR,
|
||||
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` DATETIME DEFAULT NULL,
|
||||
UNIQUE KEY (`chunk_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_GRAPH_NODES": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`entity_id` VARCHAR(256),
|
||||
`workspace` varchar(1024),
|
||||
`name` VARCHAR(2048),
|
||||
`entity_type` VARCHAR(1024),
|
||||
`description` LONGTEXT,
|
||||
`source_chunk_id` VARCHAR(256),
|
||||
`content` LONGTEXT,
|
||||
`content_vector` VECTOR,
|
||||
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` DATETIME DEFAULT NULL,
|
||||
KEY (`entity_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_GRAPH_EDGES": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`relation_id` VARCHAR(256),
|
||||
`workspace` varchar(1024),
|
||||
`source_name` VARCHAR(2048),
|
||||
`target_name` VARCHAR(2048),
|
||||
`weight` DECIMAL,
|
||||
`keywords` TEXT,
|
||||
`description` LONGTEXT,
|
||||
`source_chunk_id` varchar(256),
|
||||
`content` LONGTEXT,
|
||||
`content_vector` VECTOR,
|
||||
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` DATETIME DEFAULT NULL,
|
||||
KEY (`relation_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_LLM_CACHE": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_LLM_CACHE (
|
||||
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
send TEXT,
|
||||
`return` TEXT,
|
||||
model VARCHAR(1024),
|
||||
createtime DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updatetime DATETIME DEFAULT NULL
|
||||
);
|
||||
"""
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
SQL_TEMPLATES = {
|
||||
# SQL for KVStorage
|
||||
"get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
|
||||
"get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
|
||||
"get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
|
||||
"get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
|
||||
"filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
|
||||
# SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
|
||||
"upsert_doc_full": """
|
||||
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
|
||||
VALUES (:id, :content, :workspace)
|
||||
ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||
""",
|
||||
"upsert_chunk": """
|
||||
INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
|
||||
VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
|
||||
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||
""",
|
||||
# SQL for VectorStorage
|
||||
"entities": """SELECT n.name as entity_name FROM
|
||||
(SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
|
||||
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k
|
||||
""",
|
||||
"relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
|
||||
(SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
|
||||
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k
|
||||
""",
|
||||
"chunks": """SELECT c.id FROM
|
||||
(SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
|
||||
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k
|
||||
""",
|
||||
"has_entity": """
|
||||
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
|
||||
""",
|
||||
"has_relationship": """
|
||||
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
|
||||
""",
|
||||
"update_entity": """
|
||||
UPDATE LIGHTRAG_GRAPH_NODES SET
|
||||
entity_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
|
||||
WHERE workspace = :workspace AND name = :name
|
||||
""",
|
||||
"update_relationship": """
|
||||
UPDATE LIGHTRAG_GRAPH_EDGES SET
|
||||
relation_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
|
||||
WHERE workspace = :workspace AND source_name = :source_name AND target_name = :target_name
|
||||
""",
|
||||
"insert_entity": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
|
||||
VALUES(:id, :name, :content, :content_vector, :workspace)
|
||||
""",
|
||||
"insert_relationship": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
|
||||
VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
|
||||
""",
|
||||
# SQL for GraphStorage
|
||||
"get_node": """
|
||||
SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
|
||||
""",
|
||||
"get_edge": """
|
||||
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
|
||||
""",
|
||||
"get_node_edges": """
|
||||
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace
|
||||
""",
|
||||
"node_degree": """
|
||||
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name)
|
||||
""",
|
||||
"upsert_node": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description)
|
||||
VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
|
||||
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
|
||||
source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description)
|
||||
""",
|
||||
"upsert_edge": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector,
|
||||
workspace, weight, keywords, description, source_chunk_id)
|
||||
VALUES(:source_name, :target_name, :content, :content_vector,
|
||||
:workspace, :weight, :keywords, :description, :source_chunk_id)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
|
||||
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
|
||||
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
|
||||
source_chunk_id = VALUES(source_chunk_id)
|
||||
""",
|
||||
}
|
||||
730
minirag/llm.py
@@ -1,729 +1,5 @@
|
||||
import os
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import aioboto3
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import ollama
|
||||
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
AsyncAzureOpenAI,
|
||||
)
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Callable, Any
|
||||
from .base import BaseKVStorage
|
||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def openai_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def azure_openai_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
**kwargs,
|
||||
):
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
if prompt is not None:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Generic error for issues related to Amazon Bedrock"""
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, max=60),
|
||||
retry=retry_if_exception_type((BedrockError)),
|
||||
)
|
||||
async def bedrock_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||
)
|
||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
# Fix message history format
|
||||
messages = []
|
||||
for history_message in history_messages:
|
||||
message = copy.copy(history_message)
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
messages.append(message)
|
||||
|
||||
# Add user prompt
|
||||
messages.append({"role": "user", "content": [{"text": prompt}]})
|
||||
|
||||
# Initialize Converse API arguments
|
||||
args = {"modelId": model, "messages": messages}
|
||||
|
||||
# Define system prompt
|
||||
if system_prompt:
|
||||
args["system"] = [{"text": system_prompt}]
|
||||
|
||||
# Map and set up inference parameters
|
||||
inference_params_map = {
|
||||
"max_tokens": "maxTokens",
|
||||
"top_p": "topP",
|
||||
"stop_sequences": "stopSequences",
|
||||
}
|
||||
if inference_params := list(
|
||||
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
||||
):
|
||||
args["inferenceConfig"] = {}
|
||||
for param in inference_params:
|
||||
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
||||
kwargs.pop(param)
|
||||
)
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
# Call model via Converse API
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
try:
|
||||
response = await bedrock_async_client.converse(**args, **kwargs)
|
||||
except Exception as e:
|
||||
raise BedrockError(e)
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{
|
||||
args_hash: {
|
||||
"return": response["output"]["message"]["content"][0]["text"],
|
||||
"model": model,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_hf_model(model_name):
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True#False
|
||||
)
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
if hf_tokenizer.pad_token is None:
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
|
||||
return hf_model, hf_tokenizer
|
||||
|
||||
|
||||
async def hf_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = model
|
||||
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
input_prompt = ""
|
||||
try:
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
ori_message = copy.deepcopy(messages)
|
||||
if messages[0]["role"] == "system":
|
||||
messages[1]["content"] = (
|
||||
"<system>"
|
||||
+ messages[0]["content"]
|
||||
+ "</system>\n"
|
||||
+ messages[1]["content"]
|
||||
)
|
||||
messages = messages[1:]
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
len_message = len(ori_message)
|
||||
for msgid in range(len_message):
|
||||
input_prompt = (
|
||||
input_prompt
|
||||
+ "<"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">"
|
||||
+ ori_message[msgid]["content"]
|
||||
+ "</"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">\n"
|
||||
)
|
||||
|
||||
input_ids = hf_tokenizer(
|
||||
input_prompt, return_tensors="pt", padding=True, truncation=True
|
||||
).to("cuda")
|
||||
torch.cuda.empty_cache()
|
||||
# inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
|
||||
output = hf_model.generate(
|
||||
**input_ids, max_new_tokens=500, num_return_sequences=1, early_stopping=True
|
||||
)
|
||||
response_text = hf_tokenizer.decode(
|
||||
output[0][len(input_ids[0]) :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
|
||||
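# Trim everything generated after the <|COMPLETE|> delimiter (the marker itself is kept),
# so stray text past the completion marker is dropped.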
FINDSTRING = "<|COMPLETE|>"
|
||||
last_assistant_index = response_text.find(FINDSTRING)
|
||||
|
||||
if last_assistant_index != -1:
response_text = response_text[:last_assistant_index + len(FINDSTRING)]
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
async def ollama_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
kwargs.pop("max_tokens", None)
|
||||
kwargs.pop("response_format", None)
|
||||
host = kwargs.pop("host", None)
|
||||
timeout = kwargs.pop("timeout", None)
|
||||
|
||||
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
|
||||
|
||||
result = response["message"]["content"]
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_lmdeploy_pipeline(
|
||||
model,
|
||||
tp=1,
|
||||
chat_template=None,
|
||||
log_level="WARNING",
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
):
|
||||
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
|
||||
|
||||
lmdeploy_pipe = pipeline(
|
||||
model_path=model,
|
||||
backend_config=TurbomindEngineConfig(
|
||||
tp=tp, model_format=model_format, quant_policy=quant_policy
|
||||
),
|
||||
chat_template_config=ChatTemplateConfig(model_name=chat_template)
|
||||
if chat_template
|
||||
else None,
|
||||
log_level="WARNING",
|
||||
)
|
||||
return lmdeploy_pipe
|
||||
|
||||
|
||||
async def lmdeploy_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
chat_template=None,
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
model (str): The path to the model.
|
||||
It could be one of the following options:
|
||||
- i) A local directory path of a turbomind model which is
|
||||
converted by `lmdeploy convert` command or download
|
||||
from ii) and iii).
|
||||
- ii) The model_id of a lmdeploy-quantized model hosted
|
||||
inside a model repo on huggingface.co, such as
|
||||
"InternLM/internlm-chat-20b-4bit",
|
||||
"lmdeploy/llama2-chat-70b-4bit", etc.
|
||||
- iii) The model_id of a model hosted inside a model repo
|
||||
on huggingface.co, such as "internlm/internlm-chat-7b",
|
||||
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
|
||||
and so on.
|
||||
chat_template (str): needed when model is a pytorch model on
|
||||
huggingface.co, such as "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
|
||||
and when the model name of the local path does not match the original model name on HF.
|
||||
tp (int): tensor parallel
|
||||
prompt (Union[str, List[str]]): input texts to be completed.
|
||||
do_preprocess (bool): whether to pre-process the messages. Defaults to
True, which means chat_template will be applied.
skip_special_tokens (bool): whether or not to remove special tokens
in the decoding. Defaults to True.
do_sample (bool): whether or not to use sampling; greedy decoding is used otherwise.
Defaults to False, which means greedy decoding will be applied.
|
||||
"""
|
||||
try:
|
||||
import lmdeploy
|
||||
from lmdeploy import version_info, GenerationConfig
|
||||
except Exception:
|
||||
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
|
||||
|
||||
kwargs.pop("response_format", None)
|
||||
max_new_tokens = kwargs.pop("max_tokens", 512)
|
||||
tp = kwargs.pop("tp", 1)
|
||||
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
|
||||
do_preprocess = kwargs.pop("do_preprocess", True)
|
||||
do_sample = kwargs.pop("do_sample", False)
|
||||
gen_params = kwargs
|
||||
|
||||
version = version_info
|
||||
if do_sample is not None and version < (0, 6, 0):
|
||||
raise RuntimeError(
|
||||
"`do_sample` parameter is not supported by lmdeploy until "
|
||||
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
|
||||
)
|
||||
else:
|
||||
do_sample = True
|
||||
gen_params.update(do_sample=do_sample)
|
||||
|
||||
lmdeploy_pipe = initialize_lmdeploy_pipeline(
|
||||
model=model,
|
||||
tp=tp,
|
||||
chat_template=chat_template,
|
||||
model_format=model_format,
|
||||
quant_policy=quant_policy,
|
||||
log_level="WARNING",
|
||||
)
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**gen_params,
|
||||
)
|
||||
|
||||
response = ""
|
||||
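# With stream_response=False the pipeline's async generator yields the finished result
# rather than incremental chunks; concatenate the .response text of whatever it emits.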
async for res in lmdeploy_pipe.generate(
|
||||
messages,
|
||||
gen_config=gen_config,
|
||||
do_preprocess=do_preprocess,
|
||||
stream_response=False,
|
||||
session_id=1,
|
||||
):
|
||||
response += res.response
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
|
||||
return response
|
||||
|
||||
|
||||
async def gpt_4o_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def gpt_4o_mini_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o-mini",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await azure_openai_complete_if_cache(
|
||||
"conversation-4o-mini",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await bedrock_complete_if_cache(
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def hf_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await hf_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def ollama_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await ollama_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def openai_embedding(
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def azure_openai_embedding(
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def siliconcloud_embedding(
|
||||
texts: list[str],
|
||||
model: str = "netease-youdao/bce-embedding-base_v1",
|
||||
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
|
||||
max_token_size: int = 512,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key and not api_key.startswith("Bearer "):
|
||||
api_key = "Bearer " + api_key
|
||||
|
||||
headers = {"Authorization": api_key, "Content-Type": "application/json"}
|
||||
|
||||
truncate_texts = [text[0:max_token_size] for text in texts]
|
||||
|
||||
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
|
||||
|
||||
base64_strings = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(base_url, headers=headers, json=payload) as response:
|
||||
content = await response.json()
|
||||
if "code" in content:
|
||||
raise ValueError(content)
|
||||
base64_strings = [item["embedding"] for item in content["data"]]
|
||||
|
||||
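# The API returns each embedding as a base64 string of little-endian float32 values;
# decode and unpack them back into plain float sequences.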
embeddings = []
|
||||
for string in base64_strings:
|
||||
decode_bytes = base64.b64decode(string)
|
||||
n = len(decode_bytes) // 4
|
||||
float_array = struct.unpack("<" + "f" * n, decode_bytes)
|
||||
embeddings.append(float_array)
|
||||
return np.array(embeddings)
|
||||
|
||||
|
||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
# @retry(
|
||||
# stop=stop_after_attempt(3),
|
||||
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
||||
# )
|
||||
async def bedrock_embedding(
|
||||
texts: list[str],
|
||||
model: str = "amazon.titan-embed-text-v2:0",
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
) -> np.ndarray:
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||
)
|
||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
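# Bedrock embeddings are invoked per provider: Titan (amazon.*) models take one
# inputText per request, while Cohere models accept the whole batch in a single call.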
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
if "v2" in model:
|
||||
body = json.dumps(
|
||||
{
|
||||
"inputText": text,
|
||||
# 'dimensions': embedding_dim,
|
||||
"embeddingTypes": ["float"],
|
||||
}
|
||||
)
|
||||
elif "v1" in model:
|
||||
body = json.dumps({"inputText": text})
|
||||
else:
|
||||
raise ValueError(f"Model {model} is not supported!")
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
modelId=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = await response.get("body").json()
|
||||
|
||||
embed_texts.append(response_body["embedding"])
|
||||
elif model_provider == "cohere":
|
||||
body = json.dumps(
|
||||
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
||||
)
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
model=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
embed_texts = response_body["embeddings"]
|
||||
else:
|
||||
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
||||
|
||||
return np.array(embed_texts)
|
||||
|
||||
|
||||
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||
embed_model.to('cuda:0')
|
||||
input_ids = tokenizer(
|
||||
texts, return_tensors="pt", padding=True, truncation=True
|
||||
).input_ids.cuda()
|
||||
with torch.no_grad():
|
||||
outputs = embed_model(input_ids)
|
||||
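# Mean-pool the last hidden state over the sequence dimension to get one vector per input text.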
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
return embeddings.detach().cpu().numpy()
|
||||
|
||||
|
||||
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||
embed_text = []
|
||||
ollama_client = ollama.Client(**kwargs)
|
||||
for text in texts:
|
||||
data = ollama_client.embeddings(model=embed_model, prompt=text)
|
||||
embed_text.append(data["embedding"])
|
||||
|
||||
return np.array(embed_text)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
@@ -793,6 +69,8 @@ class MultiModel:
|
||||
self, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
kwargs.pop("model", None) # stop from overwriting the custom model name
|
||||
kwargs.pop("keyword_extraction", None)
|
||||
kwargs.pop("mode", None)
|
||||
next_model = self._next_model()
|
||||
args = dict(
|
||||
prompt=prompt,
|
||||
@@ -809,6 +87,8 @@ if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def main():
|
||||
from lightrag.llm.openai import gpt_4o_mini_complete
|
||||
|
||||
result = await gpt_4o_mini_complete("How are you?")
|
||||
print(result)
|
||||
|
||||
|
||||
0
minirag/llm/__init__.py
Normal file
0
minirag/llm/__init__.py
Normal file
189
minirag/llm/azure_openai.py
Normal file
189
minirag/llm/azure_openai.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Azure OpenAI LLM Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with Azure OpenAI's language models,
|
||||
including text generation and embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-01-24
|
||||
License: MIT License
|
||||
|
||||
Copyright (c) 2024 Lightrag
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-01-24): Initial release
|
||||
* Added async chat completion support
|
||||
* Added embedding generation
|
||||
* Added stream response capability
|
||||
|
||||
Dependencies:
|
||||
- openai
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.azure_openai import azure_openai_model_complete, azure_openai_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
|
||||
import os
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("openai"):
|
||||
pm.install("openai")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
from openai import (
|
||||
AsyncAzureOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
locate_json_string_body_from_string,
|
||||
safe_unicode_decode,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def azure_openai_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
api_version=None,
|
||||
**kwargs,
|
||||
):
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
if api_version:
|
||||
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
|
||||
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
if prompt is not None:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
if "response_format" in kwargs:
|
||||
response = await openai_async_client.beta.chat.completions.parse(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
else:
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
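# If the SDK returned an async stream (stream=True passed through kwargs), wrap it in a
# generator that yields content deltas, decoding any escaped \u sequences; otherwise
# return the full message content directly.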
if hasattr(response, "__aiter__"):
|
||||
|
||||
async def inner():
|
||||
async for chunk in response:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
content = chunk.choices[0].delta.content
|
||||
if content is None:
|
||||
continue
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
yield content
|
||||
|
||||
return inner()
|
||||
else:
|
||||
content = response.choices[0].message.content
|
||||
if r"\u" in content:
|
||||
content = safe_unicode_decode(content.encode("utf-8"))
|
||||
return content
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
result = await azure_openai_complete_if_cache(
|
||||
os.getenv("LLM_MODEL", "gpt-4o-mini"),
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction: # TODO: use JSON API
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def azure_openai_embed(
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
api_version: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
if api_version:
|
||||
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
|
||||
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
225
minirag/llm/bedrock.py
Normal file
225
minirag/llm/bedrock.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Bedrock LLM Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with Bedrock's language models,
|
||||
including text generation and embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-01-24
|
||||
License: MIT License
|
||||
|
||||
Copyright (c) 2024 Lightrag
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-01-24): Initial release
|
||||
* Added async chat completion support
|
||||
* Added embedding generation
|
||||
* Added stream response capability
|
||||
|
||||
Dependencies:
|
||||
- aioboto3, tenacity
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.bedrock import bedrock_complete, bedrock_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
|
||||
import copy
|
||||
import os
|
||||
import json
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
if not pm.is_installed("aioboto3"):
|
||||
pm.install("aioboto3")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
import aioboto3
|
||||
import numpy as np
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
locate_json_string_body_from_string,
|
||||
)
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Generic error for issues related to Amazon Bedrock"""
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, max=60),
|
||||
retry=retry_if_exception_type((BedrockError)),
|
||||
)
|
||||
async def bedrock_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||
)
|
||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
kwargs.pop("hashing_kv", None)
|
||||
# Fix message history format
|
||||
messages = []
|
||||
for history_message in history_messages:
|
||||
message = copy.copy(history_message)
|
||||
message["content"] = [{"text": message["content"]}]
|
||||
messages.append(message)
|
||||
|
||||
# Add user prompt
|
||||
messages.append({"role": "user", "content": [{"text": prompt}]})
|
||||
|
||||
# Initialize Converse API arguments
|
||||
args = {"modelId": model, "messages": messages}
|
||||
|
||||
# Define system prompt
|
||||
if system_prompt:
|
||||
args["system"] = [{"text": system_prompt}]
|
||||
|
||||
# Map and set up inference parameters
|
||||
inference_params_map = {
|
||||
"max_tokens": "maxTokens",
|
||||
"top_p": "topP",
|
||||
"stop_sequences": "stopSequences",
|
||||
}
|
||||
if inference_params := list(
|
||||
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
||||
):
|
||||
args["inferenceConfig"] = {}
|
||||
for param in inference_params:
|
||||
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
||||
kwargs.pop(param)
|
||||
)
|
||||
|
||||
# Call model via Converse API
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
try:
|
||||
response = await bedrock_async_client.converse(**args, **kwargs)
|
||||
except Exception as e:
|
||||
raise BedrockError(e)
|
||||
|
||||
return response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
result = await bedrock_complete_if_cache(
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction: # TODO: use JSON API
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
# @retry(
|
||||
# stop=stop_after_attempt(3),
|
||||
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
||||
# )
|
||||
async def bedrock_embed(
|
||||
texts: list[str],
|
||||
model: str = "amazon.titan-embed-text-v2:0",
|
||||
aws_access_key_id=None,
|
||||
aws_secret_access_key=None,
|
||||
aws_session_token=None,
|
||||
) -> np.ndarray:
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||
)
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||
)
|
||||
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||
"AWS_SESSION_TOKEN", aws_session_token
|
||||
)
|
||||
|
||||
session = aioboto3.Session()
|
||||
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||
if (model_provider := model.split(".")[0]) == "amazon":
|
||||
embed_texts = []
|
||||
for text in texts:
|
||||
if "v2" in model:
|
||||
body = json.dumps(
|
||||
{
|
||||
"inputText": text,
|
||||
# 'dimensions': embedding_dim,
|
||||
"embeddingTypes": ["float"],
|
||||
}
|
||||
)
|
||||
elif "v1" in model:
|
||||
body = json.dumps({"inputText": text})
|
||||
else:
|
||||
raise ValueError(f"Model {model} is not supported!")
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
modelId=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = await response.get("body").json()
|
||||
|
||||
embed_texts.append(response_body["embedding"])
|
||||
elif model_provider == "cohere":
|
||||
body = json.dumps(
|
||||
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
||||
)
|
||||
|
||||
response = await bedrock_async_client.invoke_model(
|
||||
model=model,
|
||||
body=body,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get("body").read())
|
||||
|
||||
embed_texts = response_body["embeddings"]
|
||||
else:
|
||||
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
||||
|
||||
return np.array(embed_texts)
|
||||
188
minirag/llm/hf.py
Normal file
188
minirag/llm/hf.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Hugging face LLM Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with Hugging face's language models,
|
||||
including text generation and embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-01-24
|
||||
License: MIT License
|
||||
|
||||
Copyright (c) 2024 Lightrag
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-01-24): Initial release
|
||||
* Added async chat completion support
|
||||
* Added embedding generation
|
||||
* Added stream response capability
|
||||
|
||||
Dependencies:
|
||||
- transformers
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.hf import hf_model_complete, hf_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
import copy
|
||||
import os
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("transformers"):
|
||||
pm.install("transformers")
|
||||
if not pm.is_installed("torch"):
|
||||
pm.install("torch")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from functools import lru_cache
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from lightrag.utils import (
|
||||
locate_json_string_body_from_string,
|
||||
)
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_hf_model(model_name):
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
if hf_tokenizer.pad_token is None:
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
|
||||
return hf_model, hf_tokenizer
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def hf_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
**kwargs,
|
||||
) -> str:
|
||||
model_name = model
|
||||
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
kwargs.pop("hashing_kv", None)
|
||||
input_prompt = ""
|
||||
try:
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
ori_message = copy.deepcopy(messages)
|
||||
if messages[0]["role"] == "system":
|
||||
messages[1]["content"] = (
|
||||
"<system>"
|
||||
+ messages[0]["content"]
|
||||
+ "</system>\n"
|
||||
+ messages[1]["content"]
|
||||
)
|
||||
messages = messages[1:]
|
||||
input_prompt = hf_tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
len_message = len(ori_message)
|
||||
for msgid in range(len_message):
|
||||
input_prompt = (
|
||||
input_prompt
|
||||
+ "<"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">"
|
||||
+ ori_message[msgid]["content"]
|
||||
+ "</"
|
||||
+ ori_message[msgid]["role"]
|
||||
+ ">\n"
|
||||
)
|
||||
|
||||
input_ids = hf_tokenizer(
|
||||
input_prompt, return_tensors="pt", padding=True, truncation=True
|
||||
).to("cuda")
|
||||
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
|
||||
output = hf_model.generate(
|
||||
**inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
|
||||
)
|
||||
response_text = hf_tokenizer.decode(
|
||||
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
async def hf_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> str:
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
result = await hf_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction: # TODO: use JSON API
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||
device = next(embed_model.parameters()).device
|
||||
input_ids = tokenizer(
|
||||
texts, return_tensors="pt", padding=True, truncation=True
|
||||
).input_ids.to(device)
|
||||
with torch.no_grad():
|
||||
outputs = embed_model(input_ids)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
if embeddings.dtype == torch.bfloat16:
|
||||
return embeddings.detach().to(torch.float32).cpu().numpy()
|
||||
else:
|
||||
return embeddings.detach().cpu().numpy()
|
||||
86
minirag/llm/jina.py
Normal file
86
minirag/llm/jina.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Jina Embedding Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with the Jina embedding service,
including embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-01-24
|
||||
License: MIT License
|
||||
|
||||
Copyright (c) 2024 Lightrag
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-01-24): Initial release
|
||||
* Added embedding generation
|
||||
|
||||
Dependencies:
|
||||
- tenacity
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.jina import jina_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
import os
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("lmdeploy"):
|
||||
pm.install("lmdeploy")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
|
||||
import numpy as np
|
||||
import aiohttp
|
||||
|
||||
|
||||
async def fetch_data(url, headers, data):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
response_json = await response.json()
|
||||
data_list = response_json.get("data", [])
|
||||
return data_list
|
||||
|
||||
|
||||
async def jina_embed(
|
||||
texts: list[str],
|
||||
dimensions: int = 1024,
|
||||
late_chunking: bool = False,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["JINA_API_KEY"] = api_key
|
||||
url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
||||
}
|
||||
data = {
|
||||
"model": "jina-embeddings-v3",
|
||||
"normalized": True,
|
||||
"embedding_type": "float",
|
||||
"dimensions": f"{dimensions}",
|
||||
"late_chunking": late_chunking,
|
||||
"input": texts,
|
||||
}
|
||||
data_list = await fetch_data(url, headers, data)
|
||||
return np.array([dp["embedding"] for dp in data_list])
|
||||
191
minirag/llm/lmdeploy.py
Normal file
191
minirag/llm/lmdeploy.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
LMDeploy LLM Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with LMDeploy's language models,
|
||||
including text generation and embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-01-24
|
||||
License: MIT License
|
||||
|
||||
Copyright (c) 2024 Lightrag
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-01-24): Initial release
|
||||
* Added async chat completion support
|
||||
* Added embedding generation
|
||||
* Added stream response capability
|
||||
|
||||
Dependencies:
|
||||
- tenacity
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.lmdeploy import lmdeploy_model_complete, lmdeploy_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("lmdeploy"):
|
||||
pm.install("lmdeploy[all]")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_lmdeploy_pipeline(
|
||||
model,
|
||||
tp=1,
|
||||
chat_template=None,
|
||||
log_level="WARNING",
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
):
|
||||
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
|
||||
|
||||
lmdeploy_pipe = pipeline(
|
||||
model_path=model,
|
||||
backend_config=TurbomindEngineConfig(
|
||||
tp=tp, model_format=model_format, quant_policy=quant_policy
|
||||
),
|
||||
chat_template_config=(
|
||||
ChatTemplateConfig(model_name=chat_template) if chat_template else None
|
||||
),
|
||||
log_level="WARNING",
|
||||
)
|
||||
return lmdeploy_pipe
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def lmdeploy_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
chat_template=None,
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
model (str): The path to the model.
|
||||
It could be one of the following options:
|
||||
- i) A local directory path of a turbomind model which is
|
||||
converted by `lmdeploy convert` command or download
|
||||
from ii) and iii).
|
||||
- ii) The model_id of a lmdeploy-quantized model hosted
|
||||
inside a model repo on huggingface.co, such as
|
||||
"InternLM/internlm-chat-20b-4bit",
|
||||
"lmdeploy/llama2-chat-70b-4bit", etc.
|
||||
- iii) The model_id of a model hosted inside a model repo
|
||||
on huggingface.co, such as "internlm/internlm-chat-7b",
|
||||
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
|
||||
and so on.
|
||||
chat_template (str): needed when model is a pytorch model on
|
||||
huggingface.co, such as "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
|
||||
and when the model name of the local path does not match the original model name on HF.
|
||||
tp (int): tensor parallel
|
||||
prompt (Union[str, List[str]]): input texts to be completed.
|
||||
do_preprocess (bool): whether to pre-process the messages. Defaults to
True, which means chat_template will be applied.
skip_special_tokens (bool): whether or not to remove special tokens
in the decoding. Defaults to True.
do_sample (bool): whether or not to use sampling; greedy decoding is used otherwise.
Defaults to False, which means greedy decoding will be applied.
|
||||
"""
|
||||
try:
|
||||
import lmdeploy
|
||||
from lmdeploy import version_info, GenerationConfig
|
||||
except Exception:
|
||||
raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
|
||||
kwargs.pop("hashing_kv", None)
|
||||
kwargs.pop("response_format", None)
|
||||
max_new_tokens = kwargs.pop("max_tokens", 512)
|
||||
tp = kwargs.pop("tp", 1)
|
||||
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
|
||||
do_preprocess = kwargs.pop("do_preprocess", True)
|
||||
do_sample = kwargs.pop("do_sample", False)
|
||||
gen_params = kwargs
|
||||
|
||||
version = version_info
|
||||
if do_sample is not None and version < (0, 6, 0):
|
||||
raise RuntimeError(
|
||||
"`do_sample` parameter is not supported by lmdeploy until "
|
||||
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
|
||||
)
|
||||
else:
|
||||
do_sample = True
|
||||
gen_params.update(do_sample=do_sample)
|
||||
|
||||
lmdeploy_pipe = initialize_lmdeploy_pipeline(
|
||||
model=model,
|
||||
tp=tp,
|
||||
chat_template=chat_template,
|
||||
model_format=model_format,
|
||||
quant_policy=quant_policy,
|
||||
log_level="WARNING",
|
||||
)
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**gen_params,
|
||||
)
|
||||
|
||||
response = ""
|
||||
async for res in lmdeploy_pipe.generate(
|
||||
messages,
|
||||
gen_config=gen_config,
|
||||
do_preprocess=do_preprocess,
|
||||
stream_response=False,
|
||||
session_id=1,
|
||||
):
|
||||
response += res.response
|
||||
return response
|
||||
224
minirag/llm/lollms.py
Normal file
224
minirag/llm/lollms.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
LoLLMs (Lord of Large Language Models) Interface Module
|
||||
=====================================================
|
||||
|
||||
This module provides the official interface for interacting with LoLLMs (Lord of Large Language and multimodal Systems),
|
||||
a unified framework for AI model interaction and deployment.
|
||||
|
||||
LoLLMs is designed as a "one tool to rule them all" solution, providing seamless integration
|
||||
with various AI models while maintaining high performance and user-friendly interfaces.
|
||||
|
||||
Author: ParisNeo
|
||||
Created: 2024-01-24
|
||||
License: Apache 2.0
|
||||
|
||||
Copyright (c) 2024 ParisNeo
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
Version: 2.0.0
|
||||
|
||||
Change Log:
|
||||
- 2.0.0 (2024-01-24):
|
||||
* Added async support for model inference
|
||||
* Implemented streaming capabilities
|
||||
* Added embedding generation functionality
|
||||
* Enhanced parameter handling
|
||||
* Improved error handling and timeout management
|
||||
|
||||
Dependencies:
|
||||
- aiohttp
|
||||
- numpy
|
||||
- Python >= 3.10
|
||||
|
||||
Features:
|
||||
- Async text generation with streaming support
|
||||
- Embedding generation
|
||||
- Configurable model parameters
|
||||
- System prompt and chat history support
|
||||
- Timeout handling
|
||||
- API key authentication
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.lollms import lollms_model_complete, lollms_embed
|
||||
|
||||
Project Repository: https://github.com/ParisNeo/lollms
|
||||
Documentation: https://github.com/ParisNeo/lollms/docs
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "ParisNeo"
|
||||
__status__ = "Production"
|
||||
__project_url__ = "https://github.com/ParisNeo/lollms"
|
||||
__doc_url__ = "https://github.com/ParisNeo/lollms/docs"
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing import AsyncIterator
|
||||
else:
|
||||
from collections.abc import AsyncIterator
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
if not pm.is_installed("aiohttp"):
|
||||
pm.install("aiohttp")
|
||||
if not pm.is_installed("tenacity"):
|
||||
pm.install("tenacity")
|
||||
|
||||
import aiohttp
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
|
||||
from typing import Union, List
|
||||
import numpy as np
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def lollms_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
base_url="http://localhost:9600",
|
||||
**kwargs,
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
"""Client implementation for lollms generation."""
|
||||
|
||||
stream = True if kwargs.get("stream") else False
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
headers = (
|
||||
{"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
if api_key
|
||||
else {"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
# Extract lollms specific parameters
|
||||
request_data = {
|
||||
"prompt": prompt,
|
||||
"model_name": model,
|
||||
"personality": kwargs.get("personality", -1),
|
||||
"n_predict": kwargs.get("n_predict", None),
|
||||
"stream": stream,
|
||||
"temperature": kwargs.get("temperature", 0.1),
|
||||
"top_k": kwargs.get("top_k", 50),
|
||||
"top_p": kwargs.get("top_p", 0.95),
|
||||
"repeat_penalty": kwargs.get("repeat_penalty", 0.8),
|
||||
"repeat_last_n": kwargs.get("repeat_last_n", 40),
|
||||
"seed": kwargs.get("seed", None),
|
||||
"n_threads": kwargs.get("n_threads", 8),
|
||||
}
|
||||
|
||||
# Prepare the full prompt including history
|
||||
full_prompt = ""
|
||||
if system_prompt:
|
||||
full_prompt += f"{system_prompt}\n"
|
||||
for msg in history_messages:
|
||||
full_prompt += f"{msg['role']}: {msg['content']}\n"
|
||||
full_prompt += prompt
|
||||
|
||||
request_data["prompt"] = full_prompt
|
||||
timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
|
||||
if stream:
|
||||
|
||||
async def inner():
|
||||
async with session.post(
|
||||
f"{base_url}/lollms_generate", json=request_data
|
||||
) as response:
|
||||
async for line in response.content:
|
||||
yield line.decode().strip()
|
||||
|
||||
return inner()
|
||||
else:
|
||||
async with session.post(
|
||||
f"{base_url}/lollms_generate", json=request_data
|
||||
) as response:
|
||||
return await response.text()
|
||||
|
||||
|
||||
async def lollms_model_complete(
|
||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
||||
) -> Union[str, AsyncIterator[str]]:
|
||||
"""Complete function for lollms model generation."""
|
||||
|
||||
# Extract and remove keyword_extraction from kwargs if present
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
|
||||
# Get model name from config
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
|
||||
# If keyword extraction is needed, we might need to modify the prompt
|
||||
# or add specific parameters for JSON output (if lollms supports it)
|
||||
if keyword_extraction:
|
||||
# Note: You might need to adjust this based on how lollms handles structured output
|
||||
pass
|
||||
|
||||
return await lollms_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def lollms_embed(
|
||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of texts using lollms server.
|
||||
|
||||
Args:
|
||||
texts: List of strings to embed
|
||||
embed_model: Model name (not used directly as lollms uses configured vectorizer)
|
||||
base_url: URL of the lollms server
|
||||
**kwargs: Additional arguments passed to the request
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings
|
||||
"""
|
||||
api_key = kwargs.pop("api_key", None)
|
||||
headers = (
|
||||
{"Content-Type": "application/json", "Authorization": api_key}
|
||||
if api_key
|
||||
else {"Content-Type": "application/json"}
|
||||
)
|
||||
async with aiohttp.ClientSession(headers=headers) as session:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
request_data = {"text": text}
|
||||
|
||||
async with session.post(
|
||||
f"{base_url}/lollms_embed",
|
||||
json=request_data,
|
||||
) as response:
|
||||
result = await response.json()
|
||||
embeddings.append(result["vector"])
|
||||
|
||||
return np.array(embeddings)
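# Example usage (assumes a LoLLMs server on the default port with a configured vectorizer):
#   vectors = await lollms_embed(["some text"], base_url="http://localhost:9600")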
|
||||
108
minirag/llm/nvidia_openai.py
Normal file
108
minirag/llm/nvidia_openai.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
OpenAI LLM Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with NVIDIA's OpenAI-compatible language models,
|
||||
including text generation and embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-01-24
|
||||
License: MIT License
|
||||
|
||||
Copyright (c) 2024 Lightrag
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-01-24): Initial release
|
||||
* Added async chat completion support
|
||||
* Added embedding generation
|
||||
* Added stream response capability
|
||||
|
||||
Dependencies:
|
||||
- openai
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from llm_interfaces.nvidia_openai import nvidia_openai_model_complete, nvidia_openai_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
if not pm.is_installed("openai"):
|
||||
pm.install("openai")
|
||||
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
)
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def nvidia_openai_embed(
|
||||
texts: list[str],
|
||||
model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
|
||||
# refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
api_key: str = None,
|
||||
input_type: str = "passage", # query for retrieval, passage for embedding
|
||||
trunc: str = "NONE", # NONE or START or END
|
||||
encode: str = "float", # float or base64
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
|
||||
openai_async_client = (
|
||||
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||
)
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model,
|
||||
input=texts,
|
||||
encoding_format=encode,
|
||||
extra_body={"input_type": input_type, "truncate": trunc},
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
158
minirag/llm/ollama.py
Normal file
158
minirag/llm/ollama.py
Normal file
@@ -0,0 +1,158 @@
"""
Ollama LLM Interface Module
==========================

This module provides interfaces for interacting with Ollama's language models,
including text generation and embedding capabilities.

Author: Lightrag team
Created: 2024-01-24
License: MIT License

Copyright (c) 2024 Lightrag

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

Version: 1.0.0

Change Log:
- 1.0.0 (2024-01-24): Initial release
    * Added async chat completion support
    * Added embedding generation
    * Added stream response capability

Dependencies:
    - ollama
    - numpy
    - pipmaster
    - Python >= 3.10

Usage:
    from llm_interfaces.ollama_interface import ollama_model_complete, ollama_embed
"""

__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("ollama"):
    pm.install("ollama")
if not pm.is_installed("tenacity"):
    pm.install("tenacity")

import ollama
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.exceptions import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
import numpy as np
from typing import Union


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def ollama_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    stream = True if kwargs.get("stream") else False
    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)
    kwargs.pop("hashing_kv", None)
    api_key = kwargs.pop("api_key", None)
    headers = (
        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
        if api_key
        else {"Content-Type": "application/json"}
    )
    ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """cannot cache stream response"""

        async def inner():
            async for chunk in response:
                yield chunk["message"]["content"]

        return inner()
    else:
        return response["message"]["content"]


async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await ollama_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    """
    Deprecated in favor of `embed`.
    """
    embed_text = []
    ollama_client = ollama.Client(**kwargs)
    for text in texts:
        data = ollama_client.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])

    return embed_text


async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    api_key = kwargs.pop("api_key", None)
    headers = (
        {"Content-Type": "application/json", "Authorization": api_key}
        if api_key
        else {"Content-Type": "application/json"}
    )
    kwargs["headers"] = headers
    ollama_client = ollama.Client(**kwargs)
    data = ollama_client.embed(model=embed_model, input=texts)
    return data["embeddings"]
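A hedged usage sketch for the Ollama helpers above (not part of the commit); it assumes a local Ollama server with the named models already pulled, and the model names are placeholders.

```python
import asyncio
from minirag.llm.ollama import ollama_model_if_cache, ollama_embed  # assumed module path

async def main():
    # Chat completion against a locally served model.
    answer = await ollama_model_if_cache(
        "qwen2.5:3b",                              # placeholder model name
        "What does MiniRAG build from text chunks?",
        system_prompt="Answer in one sentence.",
        host="http://localhost:11434",
    )
    # Embeddings via the non-deprecated ollama_embed helper.
    vectors = await ollama_embed(
        ["graph-based retrieval"],
        embed_model="nomic-embed-text",            # placeholder embedding model
        host="http://localhost:11434",
    )
    print(answer, len(vectors[0]))

asyncio.run(main())
```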
230
minirag/llm/openai.py
Normal file
230
minirag/llm/openai.py
Normal file
@@ -0,0 +1,230 @@
"""
OpenAI LLM Interface Module
==========================

This module provides interfaces for interacting with openai's language models,
including text generation and embedding capabilities.

Author: Lightrag team
Created: 2024-01-24
License: MIT License

Copyright (c) 2024 Lightrag

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

Version: 1.0.0

Change Log:
- 1.0.0 (2024-01-24): Initial release
    * Added async chat completion support
    * Added embedding generation
    * Added stream response capability

Dependencies:
    - openai
    - numpy
    - pipmaster
    - Python >= 3.10

Usage:
    from llm_interfaces.openai import openai_model_complete, openai_embed
"""

__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"


import sys
import os

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("openai"):
    pm.install("openai")

from openai import (
    AsyncOpenAI,
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
    logger,
)
from lightrag.types import GPTKeywordExtractionFormat

import numpy as np
from typing import Union


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    base_url=None,
    api_key=None,
    **kwargs,
) -> str:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Add debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")
    logger.debug("Full context:")
    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs
        )
    else:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )

    if hasattr(response, "__aiter__"):

        async def inner():
            async for chunk in response:
                content = chunk.choices[0].delta.content
                if content is None:
                    continue
                if r"\u" in content:
                    content = safe_unicode_decode(content.encode("utf-8"))
                yield content

        return inner()
    else:
        content = response.choices[0].message.content
        if r"\u" in content:
            content = safe_unicode_decode(content.encode("utf-8"))
        return content


async def openai_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["response_format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await openai_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def nvidia_openai_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def openai_embed(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    response = await openai_async_client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])
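A hedged usage sketch for the OpenAI helpers above (not part of the commit); it assumes `OPENAI_API_KEY` is set in the environment.

```python
import asyncio
from minirag.llm.openai import openai_complete_if_cache, openai_embed  # assumed module path

async def main():
    text = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize in one sentence what a KV storage backend stores.",
        system_prompt="You are a concise technical writer.",
    )
    vectors = await openai_embed(["retrieval-augmented generation"])
    print(text)
    print(vectors.shape)  # (1, 1536) for text-embedding-3-small

asyncio.run(main())
```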
109
minirag/llm/siliconcloud.py
Normal file
109
minirag/llm/siliconcloud.py
Normal file
@@ -0,0 +1,109 @@
"""
SiliconCloud Embedding Interface Module
==========================

This module provides interfaces for interacting with SiliconCloud system,
including embedding capabilities.

Author: Lightrag team
Created: 2024-01-24
License: MIT License

Copyright (c) 2024 Lightrag

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

Version: 1.0.0

Change Log:
- 1.0.0 (2024-01-24): Initial release
    * Added embedding generation

Dependencies:
    - tenacity
    - numpy
    - pipmaster
    - Python >= 3.10

Usage:
    from llm_interfaces.siliconcloud import siliconcloud_model_complete, siliconcloud_embed
"""

__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"

import sys

if sys.version_info < (3, 9):
    pass
else:
    pass
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("lmdeploy"):
    pm.install("lmdeploy")

from openai import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)


import numpy as np
import aiohttp
import base64
import struct


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def siliconcloud_embedding(
    texts: list[str],
    model: str = "netease-youdao/bce-embedding-base_v1",
    base_url: str = "https://api.siliconflow.cn/v1/embeddings",
    max_token_size: int = 512,
    api_key: str = None,
) -> np.ndarray:
    if api_key and not api_key.startswith("Bearer "):
        api_key = "Bearer " + api_key

    headers = {"Authorization": api_key, "Content-Type": "application/json"}

    truncate_texts = [text[0:max_token_size] for text in texts]

    payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}

    base64_strings = []
    async with aiohttp.ClientSession() as session:
        async with session.post(base_url, headers=headers, json=payload) as response:
            content = await response.json()
            if "code" in content:
                raise ValueError(content)
            base64_strings = [item["embedding"] for item in content["data"]]

    embeddings = []
    for string in base64_strings:
        decode_bytes = base64.b64decode(string)
        n = len(decode_bytes) // 4
        float_array = struct.unpack("<" + "f" * n, decode_bytes)
        embeddings.append(float_array)
    return np.array(embeddings)
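A hedged usage sketch for `siliconcloud_embedding` (not part of the commit); the key is a placeholder, and the "Bearer " prefix is added by the function itself.

```python
import asyncio
from minirag.llm.siliconcloud import siliconcloud_embedding  # assumed module path

async def main():
    vectors = await siliconcloud_embedding(
        ["MiniRAG stores entities and relations in a graph."],
        api_key="<your-siliconflow-key>",  # placeholder
    )
    print(vectors.shape)

asyncio.run(main())
```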
246
minirag/llm/zhipu.py
Normal file
246
minirag/llm/zhipu.py
Normal file
@@ -0,0 +1,246 @@
"""
Zhipu LLM Interface Module
==========================

This module provides interfaces for interacting with Zhipu's language models,
including text generation and embedding capabilities.

Author: Lightrag team
Created: 2024-01-24
License: MIT License

Copyright (c) 2024 Lightrag

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

Version: 1.0.0

Change Log:
- 1.0.0 (2024-01-24): Initial release
    * Added async chat completion support
    * Added embedding generation
    * Added stream response capability

Dependencies:
    - tenacity
    - numpy
    - pipmaster
    - Python >= 3.10

Usage:
    from llm_interfaces.zhipu import zhipu_model_complete, zhipu_embed
"""

__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"

import sys
import re
import json

if sys.version_info < (3, 9):
    pass
else:
    pass
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("zhipuai"):
    pm.install("zhipuai")

from openai import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    logger,
)

from lightrag.types import GPTKeywordExtractionFormat

import numpy as np
from typing import Union, List, Optional, Dict


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # The most cost/performance balanced model in the glm-4 series
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: List[Dict[str, str]] = [],
    **kwargs,
) -> str:
    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # please set ZHIPUAI_API_KEY in your environment
        # os.environ["ZHIPUAI_API_KEY"]
        client = ZhipuAI()

    messages = []

    if not system_prompt:
        system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"

    # Add system prompt if provided
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Add debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")

    # Remove unsupported kwargs
    kwargs = {
        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
    }

    response = client.chat.completions.create(model=model, messages=messages, **kwargs)

    return response.choices[0].message.content


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
    # Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
    keyword_extraction = kwargs.pop("keyword_extraction", None)

    if keyword_extraction:
        # Add a system prompt to guide the model to return JSON format
        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
        Please analyze the content and extract two types of keywords:
        1. High-level keywords: Important concepts and main themes
        2. Low-level keywords: Specific details and supporting elements

        Return your response in this exact JSON format:
        {
            "high_level_keywords": ["keyword1", "keyword2"],
            "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
        }

        Only return the JSON, no other text."""

        # Combine with existing system prompt if any
        if system_prompt:
            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
        else:
            system_prompt = extraction_prompt

        try:
            response = await zhipu_complete_if_cache(
                prompt=prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                **kwargs,
            )

            # Try to parse as JSON
            try:
                data = json.loads(response)
                return GPTKeywordExtractionFormat(
                    high_level_keywords=data.get("high_level_keywords", []),
                    low_level_keywords=data.get("low_level_keywords", []),
                )
            except json.JSONDecodeError:
                # If direct JSON parsing fails, try to extract JSON from text
                match = re.search(r"\{[\s\S]*\}", response)
                if match:
                    try:
                        data = json.loads(match.group())
                        return GPTKeywordExtractionFormat(
                            high_level_keywords=data.get("high_level_keywords", []),
                            low_level_keywords=data.get("low_level_keywords", []),
                        )
                    except json.JSONDecodeError:
                        pass

                # If all parsing fails, log warning and return empty format
                logger.warning(
                    f"Failed to parse keyword extraction response: {response}"
                )
                return GPTKeywordExtractionFormat(
                    high_level_keywords=[], low_level_keywords=[]
                )
        except Exception as e:
            logger.error(f"Error during keyword extraction: {str(e)}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )
    else:
        # For non-keyword-extraction, just return the raw response string
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def zhipu_embedding(
    texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
) -> np.ndarray:
    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # please set ZHIPUAI_API_KEY in your environment
        # os.environ["ZHIPUAI_API_KEY"]
        client = ZhipuAI()

    # Convert single text to list if needed
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            response = client.embeddings.create(model=model, input=[text], **kwargs)
            embeddings.append(response.data[0].embedding)
        except Exception as e:
            raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")

    return np.array(embeddings)
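A hedged usage sketch for the Zhipu helpers above (not part of the commit); it assumes `ZHIPUAI_API_KEY` is set (or an explicit `api_key` is passed) and that the `zhipuai` package is installed.

```python
import asyncio
from minirag.llm.zhipu import zhipu_complete, zhipu_embedding  # assumed module path

async def main():
    answer = await zhipu_complete(
        "Give one sentence on why MiniRAG keeps its index small."
    )
    vectors = await zhipu_embedding(["heterogeneous graph index"])
    print(answer)
    print(vectors.shape)  # (1, 1024) per the declared embedding_dim

asyncio.run(main())
```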
@@ -33,15 +33,31 @@ from .base import (
    QueryParam,
)

from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
)

from .kg.neo4j_impl import Neo4JStorage

from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
    "JsonDocStatusStorage": ".kg.jsondocstatus_impl",
    "Neo4JStorage": ".kg.neo4j_impl",
    "OracleKVStorage": ".kg.oracle_impl",
    "OracleGraphStorage": ".kg.oracle_impl",
    "OracleVectorDBStorage": ".kg.oracle_impl",
    "MilvusVectorDBStorge": ".kg.milvus_impl",
    "MongoKVStorage": ".kg.mongo_impl",
    "MongoGraphStorage": ".kg.mongo_impl",
    "RedisKVStorage": ".kg.redis_impl",
    "ChromaVectorDBStorage": ".kg.chroma_impl",
    "TiDBKVStorage": ".kg.tidb_impl",
    "TiDBVectorDBStorage": ".kg.tidb_impl",
    "TiDBGraphStorage": ".kg.tidb_impl",
    "PGKVStorage": ".kg.postgres_impl",
    "PGVectorStorage": ".kg.postgres_impl",
    "AGEStorage": ".kg.age_impl",
    "PGGraphStorage": ".kg.postgres_impl",
    "GremlinStorage": ".kg.gremlin_impl",
    "PGDocStatusStorage": ".kg.postgres_impl",
}

# future KG integrations

@@ -49,17 +65,50 @@ from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBS
#     GraphStorage as ArangoDBStorage
# )

def lazy_external_import(module_name: str, class_name: str):
    """Lazily import a class from an external module based on the package of the caller."""

    # Get the caller's module and package
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args, **kwargs):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        return asyncio.get_event_loop()
        # Try to get the current event loop
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop

    except RuntimeError:
        # If no event loop exists or it is closed, create a new one
        logger.info("Creating a new event loop in main thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop

        return loop


@dataclass
@@ -100,12 +149,12 @@ class MiniRAG:
        }
    )

    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_func: EmbeddingFunc = None
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = hf_model_complete  # gpt_4o_mini_complete #
    llm_model_func: callable = None
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B' # 'google/gemma-2-2b-it'
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
@@ -120,27 +169,55 @@ class MiniRAG:
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    # Add new field for document status storage type
    doc_status_storage: str = field(default="JsonDocStatusStorage")

    # Custom Chunking Function
    chunking_func: callable = chunking_by_token_size
    chunking_func_kwargs: dict = field(default_factory=dict)

    def __post_init__(self):
        log_file = os.path.join(self.working_dir, "minirag.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")
        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        # show config
        global_config = asdict(self)
        _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
        logger.debug(f"MiniRAG init with param:\n  {_print_config}\n")

        # @TODO: should move all storage setup here to leverage initial start params attached to self.

        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            self._get_storage_class()[self.kv_storage]
            self._get_storage_class(self.kv_storage)
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
        )
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
            self.graph_storage
        ]
        )

        self.key_string_value_json_storage_cls = partial(
            self.key_string_value_json_storage_cls, global_config=global_config
        )

        self.vector_db_storage_cls = partial(
            self.vector_db_storage_cls, global_config=global_config
        )

        self.graph_storage_cls = partial(
            self.graph_storage_cls, global_config=global_config
        )
        self.json_doc_status_storage = self.key_string_value_json_storage_cls(
            namespace="json_doc_status_storage",
            embedding_func=None,
        )

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
@@ -218,21 +295,39 @@ class MiniRAG:
                **self.llm_model_kwargs,
            )
        )
        # Initialize document status storage
        self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
        self.doc_status = self.doc_status_storage_cls(
            namespace="doc_status",
            global_config=global_config,
            embedding_func=None,
        )

    def _get_storage_class(self, storage_name: str) -> dict:
        import_path = STORAGES[storage_name]
        storage_class = lazy_external_import(import_path, storage_name)
        return storage_class

    def set_storage_client(self, db_client):
        # Now only tested on Oracle Database
        for storage in [
            self.vector_db_storage_cls,
            self.graph_storage_cls,
            self.doc_status,
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.key_string_value_json_storage_cls,
            self.chunks_vdb,
            self.relationships_vdb,
            self.entities_vdb,
            self.graph_storage_cls,
            self.chunk_entity_relation_graph,
            self.llm_response_cache,
        ]:
            # set client
            storage.db = db_client

    def _get_storage_class(self) -> Type[BaseGraphStorage]:
        return {
            # kv storage
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            # vector storage
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            # graph storage
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            # "ArangoDBStorage": ArangoDBStorage
        }

    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()

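The hunks above replace the hard-coded storage imports with the `STORAGES` registry plus `lazy_external_import`, so a backend module is imported only when its class is actually instantiated and optional dependencies (Neo4j, Oracle, Milvus, ...) stay optional. A minimal sketch of the mechanism, using a stdlib module purely for illustration (not part of the commit):

```python
# lazy_external_import returns a factory; the import happens at call time, not at lookup time.
OrderedDictFactory = lazy_external_import("collections", "OrderedDict")

# Nothing has been imported yet. The module is imported here, when the factory is called:
od = OrderedDictFactory(a=1, b=2)
print(od)  # OrderedDict([('a', 1), ('b', 2)])
```

MiniRAG applies the same factory to each entry in `STORAGES`, e.g. `lazy_external_import(".kg.networkx_impl", "NetworkXStorage")`.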
@@ -1,354 +0,0 @@
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
import copy
from .utils import (
    logger,
    load_json,
    write_json,
    compute_mdhash_id,
    merge_tuples,
)

from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
)


@dataclass
class JsonKVStorage(BaseKVStorage):
    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def index_done_callback(self):
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        return [
            (
                {k: v for k, v in self._data[id].items() if k in fields}
                if self._data.get(id, None)
                else None
            )
            for id in ids
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return set([s for s in data if s not in self._data])

    async def upsert(self, data: dict[str, dict]):
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        return left_data

    async def drop(self):
        self._data = {}


@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = 0.  # 2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self._client = NanoVectorDB(
            self.embedding_func.embedding_dim, storage_file=self._client_file_name
        )
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        list_data = [
            {
                "__id__": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        results = self._client.upsert(datas=list_data)
        return results

    async def query(self, query: str, top_k=5):
        embedding = await self.embedding_func([query])
        embedding = embedding[0]
        results = self._client.query(
            query=embedding,
            top_k=top_k,
            better_than_threshold=self.cosine_better_than_threshold,
        )

        results = [
            {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
        ]
        return results

    @property
    def client_storage(self):
        return getattr(self._client, "_NanoVectorDB__storage")

    async def delete_entity(self, entity_name: str):
        try:
            entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]

            if self._client.get(entity_id):
                self._client.delete(entity_id)
                logger.info(f"Entity {entity_name} have been deleted.")
            else:
                logger.info(f"No entity found with name {entity_name}.")
        except Exception as e:
            logger.error(f"Error while deleting entity {entity_name}: {e}")

    async def delete_relation(self, entity_name: str):
        try:
            relations = [
                dp
                for dp in self.client_storage["data"]
                if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
            ]
            ids_to_delete = [relation["__id__"] for relation in relations]

            if ids_to_delete:
                self._client.delete(ids_to_delete)
                logger.info(
                    f"All relations related to entity {entity_name} have been deleted."
                )
            else:
                logger.info(f"No relations found for entity {entity_name}.")
        except Exception as e:
            logger.error(
                f"Error while deleting relations for entity {entity_name}: {e}"
            )

    async def index_done_callback(self):
        self._client.save()


@dataclass
class NetworkXStorage(BaseGraphStorage):
    @staticmethod
    def load_nx_graph(file_name) -> nx.Graph:
        if os.path.exists(file_name):
            return nx.read_graphml(file_name)
        return None

    @staticmethod
    def write_nx_graph(graph: nx.Graph, file_name):
        logger.info(
            f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
        )
        nx.write_graphml(graph, file_name)

    @staticmethod
    def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
        """
        from graspologic.utils import largest_connected_component

        graph = graph.copy()
        graph = cast(nx.Graph, largest_connected_component(graph))
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = graph.nodes(data=True)
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

        fixed_graph.add_nodes_from(sorted_nodes)
        edges = list(graph.edges(data=True))

        if not graph.is_directed():

            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    temp = source
                    source = target
                    target = temp
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        self._graph = preloaded_graph or nx.Graph()
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        return self._graph.nodes.get(node_id)

    async def get_types(self) -> list:
        all_entity_type = []
        all_type_w_name = {}
        for n in self._graph.nodes(data=True):
            key = n[1]['entity_type'].strip('\"')
            all_entity_type.append(key)
            if key not in all_type_w_name:
                all_type_w_name[key] = []
                all_type_w_name[key].append(n[0].strip('\"'))
            else:
                if len(all_type_w_name[key]) <= 1:
                    all_type_w_name[key].append(n[0].strip('\"'))

        return list(set(all_entity_type)), all_type_w_name



    async def get_node_from_types(self, type_list) -> Union[dict, None]:
        node_list = []
        for name, arrt in self._graph.nodes(data=True):
            node_type = arrt.get('entity_type').strip('\"')
            if node_type in type_list:
                node_list.append(name)
        node_datas = await asyncio.gather(
            *[self.get_node(name) for name in node_list]
        )
        node_datas = [
            {**n, "entity_name": k}
            for k, n in zip(node_list, node_datas)
            if n is not None
        ]
        return node_datas  # ,node_dict


    async def get_neighbors_within_k_hops(self, source_node_id: str, k):
        count = 0
        if await self.has_node(source_node_id):
            source_edge = list(self._graph.edges(source_node_id))
        else:
            print("NO THIS ID:", source_node_id)
            return []
        count = count + 1
        while count < k:
            count = count + 1
            sc_edge = copy.deepcopy(source_edge)
            source_edge = []
            for pair in sc_edge:
                append_edge = list(self._graph.edges(pair[-1]))
                for tuples in merge_tuples([pair], append_edge):
                    source_edge.append(tuples)
        return source_edge

    async def node_degree(self, node_id: str) -> int:
        return self._graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        return self._graph.degree(src_id) + self._graph.degree(tgt_id)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self._graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def delete_node(self, node_id: str):
        """
        Delete a node from the graph based on the specified node_id.

        :param node_id: The node_id to delete
        """
        if self._graph.has_node(node_id):
            self._graph.remove_node(node_id)
            logger.info(f"Node {node_id} deleted from the graph.")
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    # @TODO: NOT USED
    async def _node2vec_embed(self):
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids
@@ -1,38 +1,33 @@
accelerate
aioboto3
aiofiles
aiohttp
asyncpg
configparser
graspologic

# database packages
graspologic
gremlinpython
hnswlib
nano-vectordb
neo4j
networkx

# Basic modules
numpy
ollama
openai
oracledb
psycopg-pool
psycopg[binary,pool]
pipmaster
pydantic
pymilvus
pymongo
pymysql

# File manipulation libraries
PyPDF2
python-docx
python-dotenv
pyvis
python-pptx

setuptools
# lmdeploy[all]
sqlalchemy
tenacity
json_repair
rouge
nltk


# LLM packages
tiktoken
torch
tqdm
transformers
xxhash
xxhash

# Extra libraries are installed when needed using pipmaster

19
setup.py
19
setup.py
@@ -22,6 +22,15 @@ def read_requirements():
        )
    return deps

def read_api_requirements():
    api_deps = []
    try:
        with open("./minirag/api/requirements.txt") as f:
            api_deps = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        print("Warning: API requirements.txt not found.")
    return api_deps


long_description = read_long_description()
requirements = read_requirements()
@@ -29,7 +38,7 @@ requirements = read_requirements()
setuptools.setup(
    name="minirag-hku",
    url="https://github.com/HKUDS/MiniRAG",
    version="0.0.1",
    version="0.0.2",
    author="HKUDS",
    description="MiniRAG: Towards Extremely Simple Retrieval-Augmented Generation",
    long_description=long_description,
@@ -48,4 +57,12 @@ setuptools.setup(
    python_requires=">=3.9",  # rec: 3.9.19
    install_requires=requirements,
    include_package_data=True,  # Includes non-code files from MANIFEST.in
    extras_require={
        "api": read_api_requirements(),  # API requirements as optional
    },
    entry_points={
        "console_scripts": [
            "minirag-server=minirag.api.minirag_server:main [api]",
        ],
    },
)