Brought all lightrag structure to the minirag app
.gitignore (26 lines added, vendored, Normal file)
@@ -0,0 +1,26 @@
__pycache__
*.egg-info
dickens/
book.txt
lightrag-dev/
.idea/
dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
.venv/
*.ignore.*
.ruff_cache/
gui/
*.log
.vscode
inputs
rag_storage
.env
venv/
examples/input/
examples/output/
.DS_Store
# Remove config.ini from repo
*.ini
							
								
								
									
.pre-commit-config.yaml (22 lines added, Normal file)
@@ -0,0 +1,22 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: requirements-txt-fixer


  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.4
    hooks:
      - id: ruff-format
      - id: ruff
        args: [--fix, --ignore=E402]


  - repo: https://github.com/mgedmin/check-manifest
    rev: "0.49"
    hooks:
      - id: check-manifest
        stages: [manual]
							
								
								
									
										
assets/logo.png (BIN, Normal file)
Binary file not shown. Size: 1.8 MiB
							
								
								
									
minirag/api/.env.aoi.example (7 lines added, Normal file)
@@ -0,0 +1,7 @@
AZURE_OPENAI_API_VERSION=2024-08-01-preview
AZURE_OPENAI_DEPLOYMENT=gpt-4o
AZURE_OPENAI_API_KEY=myapikey
AZURE_OPENAI_ENDPOINT=https://myendpoint.openai.azure.com

AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
AZURE_EMBEDDING_API_VERSION=2023-05-15
							
								
								
									
minirag/api/.gitignore (2 lines added, vendored, Normal file)
@@ -0,0 +1,2 @@
inputs
rag_storage
							
								
								
									
minirag/api/README.md (475 lines added, Normal file)
@@ -0,0 +1,475 @@
## Install with API Support

LightRAG provides optional API support through FastAPI servers that add RAG capabilities to existing LLM services. You can install LightRAG with API support in two ways:

### 1. Installation from PyPI

```bash
pip install "lightrag-hku[api]"
```

### 2. Installation from Source (Development)

```bash
# Clone the repository
git clone https://github.com/HKUDS/lightrag.git

# Change to the repository directory
cd lightrag

# Create a Python virtual environment if necessary
# Install in editable mode with API support
pip install -e ".[api]"
```

### Prerequisites

Before running any of the servers, ensure that the corresponding backend services are running for both the LLM and the embedding model.
The API allows you to mix bindings for the LLM and the embeddings: for example, you can use Ollama for the embeddings and OpenAI for the LLM.

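For example, a server pairing an OpenAI LLM with Ollama embeddings could be started like this (a sketch; the model names are illustrative):

```bash
# LLM from OpenAI, embeddings from a local Ollama instance
lightrag-server \
    --llm-binding openai --llm-model gpt-4o-mini --llm-binding-api-key $OPENAI_API_KEY \
    --embedding-binding ollama --embedding-model bge-m3:latest
```
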
#### For LoLLMs Server
- LoLLMs must be running and accessible
- Default connection: http://localhost:9600
- Configure using --llm-binding-host and/or --embedding-binding-host if running on a different host/port

#### For Ollama Server
- Ollama must be running and accessible
- Requires environment variables or command-line arguments (see the example below)
- Environment variables: LLM_BINDING=ollama, LLM_BINDING_HOST, LLM_MODEL
- Command-line arguments: --llm-binding=ollama, --llm-binding-host, --llm-model
- Default connection is http://localhost:11434 if not provided

> The default MAX_TOKENS (num_ctx) for Ollama is 32768. If your Ollama server is short on GPU memory, set it to a lower value.

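For instance, the Ollama settings above map to a command line like the following (the values shown are the documented defaults):

```bash
lightrag-server --llm-binding ollama \
    --llm-binding-host http://localhost:11434 \
    --llm-model mistral-nemo:latest
```
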
#### For OpenAI-Compatible Servers
- Requires environment variables or command-line arguments
- Environment variables: LLM_BINDING=openai, LLM_BINDING_HOST, LLM_MODEL, LLM_BINDING_API_KEY
- Command-line arguments: --llm-binding=openai, --llm-binding-host, --llm-model, --llm-binding-api-key
- Default connection is https://api.openai.com/v1 if not provided

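As a sketch, the equivalent `.env` entries might look like this (placeholder key and illustrative model name):

```
LLM_BINDING=openai
LLM_BINDING_HOST=https://api.openai.com/v1
LLM_MODEL=gpt-4o-mini
LLM_BINDING_API_KEY=your_api_key
```
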
#### For Azure OpenAI Server
An Azure OpenAI deployment can be created with the following Azure CLI commands (install the Azure CLI first from [https://docs.microsoft.com/en-us/cli/azure/install-azure-cli](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)):
```bash
# Change the resource group name, location and OpenAI resource name as needed
RESOURCE_GROUP_NAME=LightRAG
LOCATION=swedencentral
RESOURCE_NAME=LightRAG-OpenAI

az login
az group create --name $RESOURCE_GROUP_NAME --location $LOCATION
az cognitiveservices account create --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --kind OpenAI --sku S0 --location swedencentral
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name gpt-4o --model-name gpt-4o --model-version "2024-08-06" --sku-capacity 100 --sku-name "Standard"
az cognitiveservices account deployment create --resource-group $RESOURCE_GROUP_NAME --model-format OpenAI --name $RESOURCE_NAME --deployment-name text-embedding-3-large --model-name text-embedding-3-large --model-version "1" --sku-capacity 80 --sku-name "Standard"
az cognitiveservices account show --name $RESOURCE_NAME --resource-group $RESOURCE_GROUP_NAME --query "properties.endpoint"
az cognitiveservices account keys list --name $RESOURCE_NAME -g $RESOURCE_GROUP_NAME
```
The output of the last command gives you the endpoint and the key for the OpenAI API. Use these values to set the environment variables in the `.env` file:

```
LLM_BINDING=azure_openai
LLM_BINDING_HOST=endpoint_of_azure_ai
LLM_MODEL=model_name_of_azure_ai
LLM_BINDING_API_KEY=api_key_of_azure_ai
```

### About Ollama API

We provide an Ollama-compatible interface for LightRAG, aiming to emulate LightRAG as an Ollama chat model. This allows AI chat frontends that support Ollama, such as Open WebUI, to access LightRAG easily.

#### Choose query mode in chat

A prefix in the query string determines which LightRAG query mode is used to generate the response. The supported prefixes are:

/local
/global
/hybrid
/naive
/mix

For example, the chat message "/mix 唐僧有几个徒弟" ("How many disciples does Tang Seng have?") triggers a mix-mode query in LightRAG. A chat message without a prefix triggers a hybrid-mode query by default.

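For example, a prefixed query can be sent through the Ollama-compatible chat endpoint described below (assuming the server runs on the default port):

```bash
curl -N -X POST http://localhost:9721/api/chat -H "Content-Type: application/json" -d \
  '{"model":"lightrag:latest","messages":[{"role":"user","content":"/local Your question here"}],"stream":true}'
```
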
#### Connect Open WebUI to LightRAG

After starting lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.

## Configuration

LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables.

For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD stayed at LightRAG's library default of 0.2, many irrelevant entities and relations would be retrieved and sent to the LLM.

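Both values can be overridden in `.env` or on the command line, e.g. (illustrative values):

```bash
lightrag-server --top-k 60 --cosine-threshold 0.3
```
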
### Environment Variables

You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of the available environment variables:

```env
# Server Configuration
HOST=0.0.0.0
PORT=9721

# Directory Configuration
WORKING_DIR=/app/data/rag_storage
INPUT_DIR=/app/data/inputs

# RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.4
#TOP_K=50

# LLM Configuration
LLM_BINDING=ollama
LLM_BINDING_HOST=http://localhost:11434
LLM_MODEL=mistral-nemo:latest

# Must be set if using an OpenAI LLM (LLM_MODEL must also be set here or via command-line parameters)
OPENAI_API_KEY=your_api_key

# Embedding Configuration
EMBEDDING_BINDING=ollama
EMBEDDING_BINDING_HOST=http://localhost:11434
EMBEDDING_MODEL=bge-m3:latest

# Security
#LIGHTRAG_API_KEY=your-api-key-for-accessing-LightRAG

# Logging
LOG_LEVEL=INFO

# Optional SSL Configuration
#SSL=true
#SSL_CERTFILE=/path/to/cert.pem
#SSL_KEYFILE=/path/to/key.pem

# Optional Timeout
#TIMEOUT=30
```

### Configuration Priority

Configuration values are loaded in the following order (highest priority first):
1. Command-line arguments
2. Environment variables
3. Default values

For example:
```bash
# This command-line argument will override both the environment variable and the default value
python lightrag.py --port 8080

# The environment variable will override the default value but not a command-line argument
PORT=7000 python lightrag.py
```

#### LightRAG Server Options

| Parameter | Default | Description |
|-----------|---------|-------------|
| --host | 0.0.0.0 | Server host |
| --port | 9721 | Server port |
| --llm-binding | ollama | LLM binding to use. Supported: lollms, ollama, openai |
| --llm-binding-host | (dynamic) | LLM server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
| --llm-model | mistral-nemo:latest | LLM model name |
| --llm-binding-api-key | None | API key for OpenAI-compatible LLMs |
| --embedding-binding | ollama | Embedding binding to use. Supported: lollms, ollama, openai |
| --embedding-binding-host | (dynamic) | Embedding server host URL. Defaults based on binding: http://localhost:11434 (ollama), http://localhost:9600 (lollms), https://api.openai.com/v1 (openai) |
| --embedding-model | bge-m3:latest | Embedding model name |
| --working-dir | ./rag_storage | Working directory for RAG storage |
| --input-dir | ./inputs | Directory containing input documents |
| --max-async | 4 | Maximum async operations |
| --max-tokens | 32768 | Maximum token size |
| --embedding-dim | 1024 | Embedding dimensions |
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --timeout | None | Timeout in seconds (useful with slow LLM backends). Use None for an infinite timeout |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --key | None | API key for authentication. Protects the lightrag server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |
| --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) |
| --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode |
| --cosine-threshold | 0.4 | Cosine threshold for node and relation retrieval; works with top-k to control how many nodes and relations are retrieved |

### Example Usage

#### Running a LightRAG server with the default local Ollama server as LLM and embedding backend

Ollama is the default backend for both the LLM and embeddings, so you can run lightrag-server with no parameters and the defaults will be used. Make sure Ollama is installed and running, and that the default models are already pulled.

```bash
# Run lightrag with ollama, mistral-nemo:latest for llm, and bge-m3:latest for embedding
lightrag-server

# Using specific models (ensure they are installed in your ollama instance)
lightrag-server --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-model nomic-embed-text --embedding-dim 1024

# Using an authentication key
lightrag-server --key my-key

# Using lollms for llm and ollama for embedding
lightrag-server --llm-binding lollms
```

#### Running a LightRAG server with the default local lollms server as LLM and embedding backend

```bash
# Run lightrag with lollms, mistral-nemo:latest for llm, and bge-m3:latest for embedding; lollms serves both llm and embedding
lightrag-server --llm-binding lollms --embedding-binding lollms

# Using specific models (ensure they are installed in your lollms instance)
lightrag-server --llm-binding lollms --llm-model adrienbrault/nous-hermes2theta-llama3-8b:f16 --embedding-binding lollms --embedding-model nomic-embed-text --embedding-dim 1024

# Using an authentication key
lightrag-server --key my-key

# Using lollms for llm and openai for embedding
lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
```

#### Running a LightRAG server with OpenAI as LLM and embedding backend

```bash
# Run lightrag with openai, gpt-4o-mini for llm, and text-embedding-3-small for embedding; openai serves both llm and embedding
lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small

# Using an authentication key
lightrag-server --llm-binding openai --llm-model gpt-4o-mini --embedding-binding openai --embedding-model text-embedding-3-small --key my-key

# Using lollms for llm and openai for embedding
lightrag-server --llm-binding lollms --embedding-binding openai --embedding-model text-embedding-3-small
```

#### Running a LightRAG server with Azure OpenAI as LLM and embedding backend

```bash
# Run lightrag with azure_openai, gpt-4o-mini for llm, and text-embedding-3-small for embedding
lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small

# Using an authentication key
lightrag-server --llm-binding azure_openai --llm-model gpt-4o-mini --embedding-binding azure_openai --embedding-model text-embedding-3-small --key my-key

# Using lollms for llm and azure_openai for embedding
lightrag-server --llm-binding lollms --embedding-binding azure_openai --embedding-model text-embedding-3-small
```

**Important Notes:**
- For LoLLMs: Make sure the specified models are installed in your LoLLMs instance
- For Ollama: Make sure the specified models are installed in your Ollama instance
- For OpenAI: Ensure you have set the OPENAI_API_KEY environment variable
- For Azure OpenAI: Build and configure your server as described in the Prerequisites section

For help on any server, use the --help flag:
```bash
lightrag-server --help
```

Note: If you don't need the API functionality, you can install the base package without API support:
```bash
pip install lightrag-hku
```

## API Endpoints

All servers (LoLLMs, Ollama, OpenAI and Azure OpenAI) provide the same REST API endpoints for RAG functionality.

### Query Endpoints

#### POST /query
Query the RAG system with options for different search modes.

```bash
curl -X POST "http://localhost:9721/query" \
    -H "Content-Type: application/json" \
    -d '{"query": "Your question here", "mode": "hybrid"}'
```

#### POST /query/stream
Stream responses from the RAG system.

```bash
curl -X POST "http://localhost:9721/query/stream" \
    -H "Content-Type: application/json" \
    -d '{"query": "Your question here", "mode": "hybrid"}'
```

### Document Management Endpoints

#### POST /documents/text
Insert text directly into the RAG system.

```bash
curl -X POST "http://localhost:9721/documents/text" \
    -H "Content-Type: application/json" \
    -d '{"text": "Your text content here", "description": "Optional description"}'
```

#### POST /documents/file
Upload a single file to the RAG system.

```bash
curl -X POST "http://localhost:9721/documents/file" \
    -F "file=@/path/to/your/document.txt" \
    -F "description=Optional description"
```

#### POST /documents/batch
Upload multiple files at once.

```bash
curl -X POST "http://localhost:9721/documents/batch" \
    -F "files=@/path/to/doc1.txt" \
    -F "files=@/path/to/doc2.txt"
```

#### POST /documents/scan

Trigger a document scan for new files in the input directory.

```bash
curl -X POST "http://localhost:9721/documents/scan" --max-time 1800
```

> Adjust max-time according to the estimated indexing time for all new files.

### Ollama Emulation Endpoints

#### GET /api/version

Get Ollama version information.

```bash
curl http://localhost:9721/api/version
```

#### GET /api/tags

Get the available Ollama models.

```bash
curl http://localhost:9721/api/tags
```

#### POST /api/chat

Handle chat completion requests.

```shell
# The example message asks, in Chinese, "Who is Zhu Bajie?"
curl -N -X POST http://localhost:9721/api/chat -H "Content-Type: application/json" -d \
  '{"model":"lightrag:latest","messages":[{"role":"user","content":"猪八戒是谁"}],"stream":true}'
```

> For more information about the Ollama API, see the [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md)

#### DELETE /documents

Clear all documents from the RAG system.

```bash
curl -X DELETE "http://localhost:9721/documents"
```

### Utility Endpoints

#### GET /health
Check server health and configuration.

```bash
curl "http://localhost:9721/health"
```

## Development
Contribute to the project: [Guide](contributor-readme.MD)

### Running in Development Mode

For LoLLMs:
```bash
uvicorn lollms_lightrag_server:app --reload --port 9721
```

For Ollama:
```bash
uvicorn ollama_lightrag_server:app --reload --port 9721
```

For OpenAI:
```bash
uvicorn openai_lightrag_server:app --reload --port 9721
```

For Azure OpenAI:
```bash
uvicorn azure_openai_lightrag_server:app --reload --port 9721
```

### API Documentation

When any server is running, visit:
- Swagger UI: http://localhost:9721/docs
- ReDoc: http://localhost:9721/redoc

### Testing API Endpoints

You can test the API endpoints using the provided curl commands or through the Swagger UI interface (a compact example follows the list). Make sure to:
1. Start the appropriate backend service (LoLLMs, Ollama, or OpenAI)
2. Start the RAG server
3. Upload some documents using the document management endpoints
4. Query the system using the query endpoints
5. Trigger a document scan if new files are placed in the input directory

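A compact smoke test covering steps 3 and 4 might look like this (assuming the default port and no API key):

```bash
# Step 3: insert a small document
curl -X POST "http://localhost:9721/documents/text" \
    -H "Content-Type: application/json" \
    -d '{"text": "LightRAG adds RAG capabilities to LLM services.", "description": "smoke test"}'

# Step 4: query it back
curl -X POST "http://localhost:9721/query" \
    -H "Content-Type: application/json" \
    -d '{"query": "What does LightRAG do?", "mode": "hybrid"}'
```
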
### Important Features

#### Automatic Document Vectorization
When starting any of the servers with the `--input-dir` parameter, the system will automatically (see the example after this list):
1. Check for existing vectorized content in the database
2. Only vectorize new documents that aren't already in the database
3. Make all content immediately available for RAG queries

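For example (illustrative paths matching the defaults above):

```bash
# New files in ./inputs are vectorized at startup; already-indexed documents are skipped
lightrag-server --input-dir ./inputs --working-dir ./rag_storage
```
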
This intelligent caching mechanism:
- Prevents unnecessary re-vectorization of existing documents
- Reduces startup time for subsequent runs
- Preserves system resources
- Maintains consistency across restarts

**Important Notes:**
- The `--input-dir` parameter enables automatic document processing at startup
- Documents already in the database are not re-vectorized
- Only new documents in the input directory will be processed
- This optimization significantly reduces startup time for subsequent runs
- The working directory (`--working-dir`) stores the vectorized documents database

## Install LightRAG as a Linux Service

Create your service file `lightrag-server.service` by modifying the following lines from `lightrag-server.service.example`:

```text
Description=LightRAG Ollama Service
WorkingDirectory=<lightrag installed directory>
ExecStart=<lightrag installed directory>/lightrag/api/start_lightrag.sh
```

Create your service startup script `start_lightrag.sh`, and change the Python virtual environment activation method as needed:

```shell
#!/bin/bash

# python virtual environment activation
source /home/netman/lightrag-xyj/venv/bin/activate
# start lightrag api server
lightrag-server
```

Install lightrag-server.service on Linux. Sample commands on an Ubuntu server look like:

```shell
sudo cp lightrag-server.service /etc/systemd/system/
sudo systemctl daemon-reload
sudo systemctl start lightrag-server.service
sudo systemctl status lightrag-server.service
sudo systemctl enable lightrag-server.service
```
							
								
								
									
minirag/api/__init__.py (1 line added, Normal file)
@@ -0,0 +1 @@
__api_version__ = "1.0.3"
							
								
								
									
minirag/api/minirag_server.py (1939 lines added, Normal file)
File diff suppressed because it is too large.
											
										
									
								
							
							
								
								
									
minirag/api/requirements.txt (12 lines added, Normal file)
@@ -0,0 +1,12 @@
ascii_colors
fastapi
nest_asyncio
numpy
pipmaster
python-dotenv
python-multipart
tenacity
tiktoken
torch
tqdm
uvicorn
							
								
								
									
minirag/api/static/README.md (2 lines added, Normal file)
@@ -0,0 +1,2 @@
# LightRAG WebUI
A simple web UI to interact with the LightRAG datalake
							
								
								
									
										
minirag/api/static/favicon.ico (BIN, Normal file)
Binary file not shown. Size: 1.9 MiB
							
								
								
									
minirag/api/static/index.html (104 lines added, Normal file)
@@ -0,0 +1,104 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>MiniRag Interface</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
    <style>
        .fade-in {
            animation: fadeIn 0.3s ease-in;
        }

        @keyframes fadeIn {
            from { opacity: 0; }
            to { opacity: 1; }
        }

        .spin {
            animation: spin 1s linear infinite;
        }

        @keyframes spin {
            from { transform: rotate(0deg); }
            to { transform: rotate(360deg); }
        }

        .slide-in {
            animation: slideIn 0.3s ease-out;
        }

        @keyframes slideIn {
            from { transform: translateX(-100%); }
            to { transform: translateX(0); }
        }
    </style>
</head>
<body class="bg-gray-50">
    <div class="flex h-screen">
        <!-- Sidebar -->
        <div class="w-64 bg-white shadow-lg">
            <div class="p-4">
                <h1 class="text-xl font-bold text-gray-800 mb-6">MiniRag</h1>
                <nav class="space-y-2">
                    <a href="#" class="nav-item" data-page="file-manager">
                        <div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
                            <svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                                <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 7v10a2 2 0 002 2h14a2 2 0 002-2V9a2 2 0 00-2-2h-6l-2-2H5a2 2 0 00-2 2z"/>
                            </svg>
                            File Manager
                        </div>
                    </a>
                    <a href="#" class="nav-item" data-page="query">
                        <div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
                            <svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                                <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"/>
                            </svg>
                            Query Database
                        </div>
                    </a>
                    <a href="#" class="nav-item" data-page="knowledge-graph">
                        <div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
                            <svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                                <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"/>
                            </svg>
                            Knowledge Graph
                        </div>
                    </a>
                    <a href="#" class="nav-item" data-page="status">
                        <div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
                            <svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                                <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"/>
                            </svg>
                            Status
                        </div>
                    </a>
                    <a href="#" class="nav-item" data-page="settings">
                        <div class="flex items-center p-2 rounded-lg hover:bg-gray-100 transition-colors">
                            <svg class="w-5 h-5 mr-3" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                                <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"/>
                                <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"/>
                            </svg>
                            Settings
                        </div>
                    </a>
                </nav>
            </div>
        </div>

        <!-- Main Content -->
        <div class="flex-1 overflow-auto p-6">
            <div id="content" class="fade-in"></div>
        </div>

        <!-- Toast Notification -->
        <div id="toast" class="fixed bottom-4 right-4 hidden">
            <div class="bg-gray-800 text-white px-6 py-3 rounded-lg shadow-lg"></div>
        </div>
    </div>

    <script src="/js/api.js"></script>

</body>
</html>
							
								
								
									
minirag/api/static/js/api.js (404 lines added, Normal file)
@@ -0,0 +1,404 @@
// State management
const state = {
    apiKey: localStorage.getItem('apiKey') || '',
    files: [],
    indexedFiles: [],
    currentPage: 'file-manager'
};

// Utility functions
const showToast = (message, duration = 3000) => {
    const toast = document.getElementById('toast');
    toast.querySelector('div').textContent = message;
    toast.classList.remove('hidden');
    setTimeout(() => toast.classList.add('hidden'), duration);
};

const fetchWithAuth = async (url, options = {}) => {
    const headers = {
        ...(options.headers || {}),
        ...(state.apiKey ? { 'Authorization': `Bearer ${state.apiKey}` } : {})
    };
    return fetch(url, { ...options, headers });
};

// Page renderers
const pages = {
    'file-manager': () => `
        <div class="space-y-6">
            <h2 class="text-2xl font-bold text-gray-800">File Manager</h2>

            <div class="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center hover:border-gray-400 transition-colors">
                <input type="file" id="fileInput" multiple accept=".txt,.md,.doc,.docx,.pdf,.pptx" class="hidden">
                <label for="fileInput" class="cursor-pointer">
                    <svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                        <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"/>
                    </svg>
                    <p class="mt-2 text-gray-600">Drag files here or click to select</p>
                    <p class="text-sm text-gray-500">Supported formats: TXT, MD, DOC, DOCX, PDF, PPTX</p>
                </label>
            </div>

            <div id="fileList" class="space-y-2">
                <h3 class="text-lg font-semibold text-gray-700">Selected Files</h3>
                <div class="space-y-2"></div>
            </div>
            <div id="uploadProgress" class="hidden mt-4">
                <div class="w-full bg-gray-200 rounded-full h-2.5">
                    <div class="bg-blue-600 h-2.5 rounded-full" style="width: 0%"></div>
                </div>
                <p class="text-sm text-gray-600 mt-2"><span id="uploadStatus">0</span> files processed</p>
            </div>
            <button id="rescanBtn" class="flex items-center bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="20" height="20" fill="currentColor" class="mr-2">
                    <path d="M12 4a8 8 0 1 1-8 8H2.5a9.5 9.5 0 1 0 2.8-6.7L2 3v6h6L5.7 6.7A7.96 7.96 0 0 1 12 4z"/>
                </svg>
                Rescan Files
            </button>
            <button id="uploadBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
                Upload & Index Files
            </button>

            <div id="indexedFiles" class="space-y-2">
                <h3 class="text-lg font-semibold text-gray-700">Indexed Files</h3>
                <div class="space-y-2"></div>
            </div>
        </div>
    `,

    'query': () => `
        <div class="space-y-6">
            <h2 class="text-2xl font-bold text-gray-800">Query Database</h2>

            <div class="space-y-4">
                <div>
                    <label class="block text-sm font-medium text-gray-700">Query Mode</label>
                    <select id="queryMode" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
                        <option value="light">Light</option>
                        <option value="naive">Naive</option>
                        <option value="mini">Mini</option>
                    </select>
                </div>

                <div>
                    <label class="block text-sm font-medium text-gray-700">Query</label>
                    <textarea id="queryInput" rows="4" class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500"></textarea>
                </div>

                <button id="queryBtn" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
                    Send Query
                </button>

                <div id="queryResult" class="mt-4 p-4 bg-white rounded-lg shadow"></div>
            </div>
        </div>
    `,

    'knowledge-graph': () => `
        <div class="flex items-center justify-center h-full">
            <div class="text-center">
                <svg class="mx-auto h-12 w-12 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                    <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 11H5m14 0a2 2 0 012 2v6a2 2 0 01-2 2H5a2 2 0 01-2-2v-6a2 2 0 012-2m14 0V9a2 2 0 00-2-2M5 11V9a2 2 0 012-2m0 0V5a2 2 0 012-2h6a2 2 0 012 2v2M7 7h10"/>
                </svg>
                <h3 class="mt-2 text-sm font-medium text-gray-900">Under Construction</h3>
                <p class="mt-1 text-sm text-gray-500">Knowledge graph visualization will be available in a future update.</p>
            </div>
        </div>
    `,

    'status': () => `
        <div class="space-y-6">
            <h2 class="text-2xl font-bold text-gray-800">System Status</h2>
            <div id="statusContent" class="grid grid-cols-1 md:grid-cols-2 gap-6">
                <div class="p-6 bg-white rounded-lg shadow-sm">
                    <h3 class="text-lg font-semibold mb-4">System Health</h3>
                    <div id="healthStatus"></div>
                </div>
                <div class="p-6 bg-white rounded-lg shadow-sm">
                    <h3 class="text-lg font-semibold mb-4">Configuration</h3>
                    <div id="configStatus"></div>
                </div>
            </div>
        </div>
    `,

    'settings': () => `
        <div class="space-y-6">
            <h2 class="text-2xl font-bold text-gray-800">Settings</h2>

            <div class="max-w-xl">
                <div class="space-y-4">
                    <div>
                        <label class="block text-sm font-medium text-gray-700">API Key</label>
                        <input type="password" id="apiKeyInput" value="${state.apiKey}"
                            class="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500">
                    </div>

                    <button id="saveSettings" class="bg-blue-600 text-white px-4 py-2 rounded-lg hover:bg-blue-700 transition-colors">
                        Save Settings
                    </button>
                </div>
            </div>
        </div>
    `
};

// Page handlers
const handlers = {
    'file-manager': () => {
        const fileInput = document.getElementById('fileInput');
        const dropZone = fileInput.parentElement.parentElement;
        const fileList = document.querySelector('#fileList div');
        const indexedFiles = document.querySelector('#indexedFiles div');
        const uploadBtn = document.getElementById('uploadBtn');
        // Look the rescan button up explicitly instead of relying on the
        // implicit window-global that browsers create for elements with an id
        const rescanBtn = document.getElementById('rescanBtn');

        const updateFileList = () => {
            fileList.innerHTML = state.files.map(file => `
                <div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
                    <span>${file.name}</span>
                    <button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
                        <svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
                        </svg>
                    </button>
                </div>
            `).join('');
        };

        const updateIndexedFiles = async () => {
            const response = await fetchWithAuth('/health');
            const data = await response.json();
            indexedFiles.innerHTML = data.indexed_files.map(file => `
                <div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
                    <span>${file}</span>
                </div>
            `).join('');
        };

        dropZone.addEventListener('dragover', (e) => {
            e.preventDefault();
            dropZone.classList.add('border-blue-500');
        });

        dropZone.addEventListener('dragleave', () => {
            dropZone.classList.remove('border-blue-500');
        });

        dropZone.addEventListener('drop', (e) => {
            e.preventDefault();
            dropZone.classList.remove('border-blue-500');
            const files = Array.from(e.dataTransfer.files);
            state.files.push(...files);
            updateFileList();
        });

        fileInput.addEventListener('change', () => {
            state.files.push(...Array.from(fileInput.files));
            updateFileList();
        });

        uploadBtn.addEventListener('click', async () => {
            if (state.files.length === 0) {
                showToast('Please select files to upload');
                return;
            }
            let apiKey = localStorage.getItem('apiKey') || '';
            const progress = document.getElementById('uploadProgress');
            const progressBar = progress.querySelector('div');
            const statusText = document.getElementById('uploadStatus');
            progress.classList.remove('hidden');

            for (let i = 0; i < state.files.length; i++) {
                const formData = new FormData();
                formData.append('file', state.files[i]);

                try {
                    await fetch('/documents/upload', {
                        method: 'POST',
                        headers: apiKey ? { 'Authorization': `Bearer ${apiKey}` } : {},
                        body: formData
                    });

                    const percentage = ((i + 1) / state.files.length) * 100;
                    progressBar.style.width = `${percentage}%`;
                    statusText.textContent = `${i + 1}/${state.files.length}`;
                } catch (error) {
                    console.error('Upload error:', error);
                }
            }
            progress.classList.add('hidden');
        });

        rescanBtn.addEventListener('click', async () => {
            const progress = document.getElementById('uploadProgress');
            const progressBar = progress.querySelector('div');
            const statusText = document.getElementById('uploadStatus');
            progress.classList.remove('hidden');

            try {
                // Start the scanning process
                const scanResponse = await fetch('/documents/scan', {
                    method: 'POST',
                });

                if (!scanResponse.ok) {
                    throw new Error('Scan failed to start');
                }

                // Start polling for progress
                const pollInterval = setInterval(async () => {
                    const progressResponse = await fetch('/documents/scan-progress');
                    const progressData = await progressResponse.json();

                    // Update progress bar
                    progressBar.style.width = `${progressData.progress}%`;

                    // Update status text
                    if (progressData.total_files > 0) {
                        statusText.textContent = `Processing ${progressData.current_file} (${progressData.indexed_count}/${progressData.total_files})`;
                    }

                    // Check if scanning is complete
                    if (!progressData.is_scanning) {
                        clearInterval(pollInterval);
                        progress.classList.add('hidden');
                        statusText.textContent = 'Scan complete!';
                    }
                }, 1000); // Poll every second

            } catch (error) {
                console.error('Scan error:', error);
                progress.classList.add('hidden');
                statusText.textContent = 'Error during scanning process';
            }
        });

        updateIndexedFiles();
    },

    'query': () => {
        const queryBtn = document.getElementById('queryBtn');
        const queryInput = document.getElementById('queryInput');
        const queryMode = document.getElementById('queryMode');
        const queryResult = document.getElementById('queryResult');

        queryBtn.addEventListener('click', async () => {
            const query = queryInput.value.trim();
            if (!query) {
                showToast('Please enter a query');
                return;
            }

            queryBtn.disabled = true;
            queryBtn.innerHTML = `
                <svg class="animate-spin h-5 w-5 mr-3" viewBox="0 0 24 24">
                    <circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4" fill="none"/>
                    <path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"/>
                </svg>
                Processing...
            `;

            try {
                const response = await fetchWithAuth('/query', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        query,
                        mode: queryMode.value,
                        stream: false,
                        only_need_context: false
                    })
                });

                const data = await response.json();
                queryResult.innerHTML = marked.parse(data.response);
            } catch (error) {
                showToast('Error processing query');
            } finally {
                queryBtn.disabled = false;
                queryBtn.textContent = 'Send Query';
            }
        });
    },

    'status': async () => {
        const healthStatus = document.getElementById('healthStatus');
        const configStatus = document.getElementById('configStatus');

        try {
            const response = await fetchWithAuth('/health');
            const data = await response.json();

            healthStatus.innerHTML = `
                <div class="space-y-2">
                    <div class="flex items-center">
                        <div class="w-3 h-3 rounded-full ${data.status === 'healthy' ? 'bg-green-500' : 'bg-red-500'} mr-2"></div>
                        <span class="font-medium">${data.status}</span>
                    </div>
                    <div>
                        <p class="text-sm text-gray-600">Working Directory: ${data.working_directory}</p>
                        <p class="text-sm text-gray-600">Input Directory: ${data.input_directory}</p>
                        <p class="text-sm text-gray-600">Indexed Files: ${data.indexed_files_count}</p>
                    </div>
                </div>
            `;

            configStatus.innerHTML = Object.entries(data.configuration)
                .map(([key, value]) => `
                    <div class="mb-2">
                        <span class="text-sm font-medium text-gray-700">${key}:</span>
                        <span class="text-sm text-gray-600 ml-2">${value}</span>
                    </div>
                `).join('');
        } catch (error) {
            showToast('Error fetching status');
        }
    },

    'settings': () => {
        const saveBtn = document.getElementById('saveSettings');
        const apiKeyInput = document.getElementById('apiKeyInput');

        saveBtn.addEventListener('click', () => {
            state.apiKey = apiKeyInput.value;
            localStorage.setItem('apiKey', state.apiKey);
            showToast('Settings saved successfully');
        });
    }
};

// Navigation handling
document.querySelectorAll('.nav-item').forEach(item => {
    item.addEventListener('click', (e) => {
        e.preventDefault();
        const page = item.dataset.page;
        document.getElementById('content').innerHTML = pages[page]();
        if (handlers[page]) handlers[page]();
        state.currentPage = page;
    });
});

// Initialize with file manager
document.getElementById('content').innerHTML = pages['file-manager']();
handlers['file-manager']();

// Global functions
window.removeFile = (fileName) => {
    state.files = state.files.filter(file => file.name !== fileName);
    document.querySelector('#fileList div').innerHTML = state.files.map(file => `
        <div class="flex items-center justify-between bg-white p-3 rounded-lg shadow-sm">
            <span>${file.name}</span>
            <button class="text-red-600 hover:text-red-700" onclick="removeFile('${file.name}')">
                <svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                    <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"/>
                </svg>
            </button>
        </div>
    `).join('');
};
							
								211
								minirag/api/static/js/graph.js
								Normal file
							
							| @@ -0,0 +1,211 @@ | ||||
| // js/graph.js | ||||
| function openGraphModal(label) { | ||||
|     const modal = document.getElementById("graph-modal"); | ||||
|     const graphTitle = document.getElementById("graph-title"); | ||||
|  | ||||
|     if (!modal || !graphTitle) { | ||||
|         console.error("Key element not found"); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     graphTitle.textContent = `Knowledge Graph - ${label}`; | ||||
|     modal.style.display = "flex"; | ||||
|  | ||||
|     renderGraph(label); | ||||
| } | ||||
|  | ||||
| function closeGraphModal() { | ||||
|     const modal = document.getElementById("graph-modal"); | ||||
|     modal.style.display = "none"; | ||||
|     clearGraph(); | ||||
| } | ||||
|  | ||||
| function clearGraph() { | ||||
|     const svg = document.getElementById("graph-svg"); | ||||
|     svg.innerHTML = ""; | ||||
| } | ||||
|  | ||||
|  | ||||
| async function getGraph(label) { | ||||
|     try { | ||||
|         const response = await fetch(`/graphs?label=${label}`); | ||||
|         const rawData = await response.json(); | ||||
|         console.log({data: JSON.parse(JSON.stringify(rawData))}); | ||||
|  | ||||
|         const nodes = rawData.nodes; | ||||
|  | ||||
|         nodes.forEach(node => { | ||||
|             node.id = Date.now().toString(36) + Math.random().toString(36).substring(2); // Generate a unique id from a timestamp plus a random suffix | ||||
|         }); | ||||
|  | ||||
|         // Strictly validate edge data: drop edges whose endpoints are missing | ||||
|         const edges = (rawData.edges || []).map(edge => { | ||||
|             const sourceNode = nodes.find(n => n.labels.includes(edge.source)); | ||||
|             const targetNode = nodes.find(n => n.labels.includes(edge.target)); | ||||
|             if (!sourceNode || !targetNode) { | ||||
|                 console.warn("Invalid edge:", edge); | ||||
|                 return null; | ||||
|             } | ||||
|             return { | ||||
|                 source: sourceNode, | ||||
|                 target: targetNode, | ||||
|                 type: edge.type || "" | ||||
|             }; | ||||
|         }).filter(edge => edge !== null); | ||||
|  | ||||
|         return {nodes, edges}; | ||||
|     } catch (error) { | ||||
|         console.error("Loading graph failed:", error); | ||||
|         return {nodes: [], edges: []}; | ||||
|     } | ||||
| } | ||||
|  | ||||
| async function renderGraph(label) { | ||||
|     const data = await getGraph(label); | ||||
|  | ||||
|  | ||||
|     if (!data.nodes || data.nodes.length === 0) { | ||||
|         d3.select("#graph-svg") | ||||
|             .html(`<text x="50%" y="50%" text-anchor="middle">No valid nodes</text>`); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|  | ||||
|     const svg = d3.select("#graph-svg"); | ||||
|     const width = svg.node().clientWidth; | ||||
|     const height = svg.node().clientHeight; | ||||
|  | ||||
|     svg.selectAll("*").remove(); | ||||
|  | ||||
|     // Create a force-directed graph layout | ||||
|     const simulation = d3.forceSimulation(data.nodes) | ||||
|         .force("charge", d3.forceManyBody().strength(-300)) | ||||
|         .force("center", d3.forceCenter(width / 2, height / 2)); | ||||
|  | ||||
|     // Add link forces (if there are valid edges) | ||||
|     if (data.edges.length > 0) { | ||||
|         simulation.force("link", | ||||
|             d3.forceLink(data.edges) | ||||
|                 .id(d => d.id) | ||||
|                 .distance(100) | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     //  Draw nodes | ||||
|     const nodes = svg.selectAll(".node") | ||||
|         .data(data.nodes) | ||||
|         .enter() | ||||
|         .append("circle") | ||||
|         .attr("class", "node") | ||||
|         .attr("r", 10) | ||||
|         .call(d3.drag() | ||||
|             .on("start", dragStarted) | ||||
|             .on("drag", dragged) | ||||
|             .on("end", dragEnded) | ||||
|         ); | ||||
|  | ||||
|  | ||||
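|     // Define an arrowhead marker for directed edges (referenced below via marker-end) | ||||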
|     svg.append("defs") | ||||
|         .append("marker") | ||||
|         .attr("id", "arrow-out") | ||||
|         .attr("viewBox", "0 0 10 10") | ||||
|         .attr("refX", 8) | ||||
|         .attr("refY", 5) | ||||
|         .attr("markerWidth", 6) | ||||
|         .attr("markerHeight", 6) | ||||
|         .attr("orient", "auto") | ||||
|         .append("path") | ||||
|         .attr("d", "M0,0 L10,5 L0,10 Z") | ||||
|         .attr("fill", "#999"); | ||||
|  | ||||
|     //  Draw edges (with arrows) | ||||
|     const links = svg.selectAll(".link") | ||||
|         .data(data.edges) | ||||
|         .enter() | ||||
|         .append("line") | ||||
|         .attr("class", "link") | ||||
|         .attr("marker-end", "url(#arrow-out)"); //  Always draw arrows on the target side | ||||
|  | ||||
|     //  Edge style configuration | ||||
|     links | ||||
|         .attr("stroke", "#999") | ||||
|         .attr("stroke-width", 2) | ||||
|         .attr("stroke-opacity", 0.8); | ||||
|  | ||||
|     // Draw node labels | ||||
|     const labels = svg.selectAll(".label") | ||||
|         .data(data.nodes) | ||||
|         .enter() | ||||
|         .append("text") | ||||
|         .attr("class", "label") | ||||
|         .text(d => d.labels[0] || "") | ||||
|         .attr("text-anchor", "start") | ||||
|         .attr("dy", "0.3em") | ||||
|         .attr("fill", "#333"); | ||||
|  | ||||
|     // Update element positions on each simulation tick | ||||
|     simulation.on("tick", () => { | ||||
|         links | ||||
|             .attr("x1", d => { | ||||
|                 //  Calculate the direction vector from the source node to the target node | ||||
|                 const dx = d.target.x - d.source.x; | ||||
|                 const dy = d.target.y - d.source.y; | ||||
|                 const distance = Math.sqrt(dx * dx + dy * dy); | ||||
|                 if (distance === 0) return d.source.x; // avoid division by zero | ||||
|                 // Offset the start point to the edge of the source node (radius 10) | ||||
|                 return d.source.x + (dx / distance) * 10; | ||||
|             }) | ||||
|             .attr("y1", d => { | ||||
|                 const dx = d.target.x - d.source.x; | ||||
|                 const dy = d.target.y - d.source.y; | ||||
|                 const distance = Math.sqrt(dx * dx + dy * dy); | ||||
|                 if (distance === 0) return d.source.y; | ||||
|                 return d.source.y + (dy / distance) * 10; | ||||
|             }) | ||||
|             .attr("x2", d => { | ||||
|                 // Offset the end point to the edge of the target node (radius 10) | ||||
|                 const dx = d.target.x - d.source.x; | ||||
|                 const dy = d.target.y - d.source.y; | ||||
|                 const distance = Math.sqrt(dx * dx + dy * dy); | ||||
|                 if (distance === 0) return d.target.x; | ||||
|                 return d.target.x - (dx / distance) * 10; | ||||
|             }) | ||||
|             .attr("y2", d => { | ||||
|                 const dx = d.target.x - d.source.x; | ||||
|                 const dy = d.target.y - d.source.y; | ||||
|                 const distance = Math.sqrt(dx * dx + dy * dy); | ||||
|                 if (distance === 0) return d.target.y; | ||||
|                 return d.target.y - (dy / distance) * 10; | ||||
|             }); | ||||
|  | ||||
|         // Update node and label positions | ||||
|         nodes | ||||
|             .attr("cx", d => d.x) | ||||
|             .attr("cy", d => d.y); | ||||
|  | ||||
|         labels | ||||
|             .attr("x", d => d.x + 12) | ||||
|             .attr("y", d => d.y + 4); | ||||
|     }); | ||||
|  | ||||
|     // Drag interaction handlers | ||||
|     function dragStarted(event, d) { | ||||
|         if (!event.active) simulation.alphaTarget(0.3).restart(); | ||||
|         d.fx = d.x; | ||||
|         d.fy = d.y; | ||||
|     } | ||||
|  | ||||
|     function dragged(event, d) { | ||||
|         d.fx = event.x; | ||||
|         d.fy = event.y; | ||||
|         simulation.alpha(0.3).restart(); | ||||
|     } | ||||
|  | ||||
|     function dragEnded(event, d) { | ||||
|         if (!event.active) simulation.alphaTarget(0); | ||||
|         d.fx = null; | ||||
|         d.fy = null; | ||||
|     } | ||||
| } | ||||
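|  | ||||
| // A minimal sketch of the /graphs response shape this renderer expects, | ||||
| // inferred from getGraph() above (not an official schema; the labels and | ||||
| // edge type are illustrative assumptions): | ||||
| // const exampleResponse = { | ||||
| //     nodes: [ | ||||
| //         { labels: ["PERSON"] }, | ||||
| //         { labels: ["COMPANY"] } | ||||
| //     ], | ||||
| //     edges: [ | ||||
| //         { source: "PERSON", target: "COMPANY", type: "WORKS_AT" } | ||||
| //     ] | ||||
| // }; | ||||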
| @@ -1,6 +1,6 @@ | ||||
| from dataclasses import dataclass, field | ||||
| from typing import TypedDict, Union, Literal, Generic, TypeVar | ||||
|  | ||||
| import os | ||||
| import numpy as np | ||||
|  | ||||
| from .utils import EmbeddingFunc | ||||
| @@ -17,16 +17,28 @@ T = TypeVar("T") | ||||
| class QueryParam: | ||||
|     mode: Literal["light", "naive", "mini"] = "mini" | ||||
|     only_need_context: bool = False | ||||
|     only_need_prompt: bool = False | ||||
|     response_type: str = "Multiple Paragraphs" | ||||
|     stream: bool = False | ||||
|     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. | ||||
|     top_k: int = 5 | ||||
|     top_k: int = int(os.getenv("TOP_K", "60")) | ||||
|     # Number of document chunks to retrieve. | ||||
|     # top_n: int = 10 | ||||
|     # Number of tokens for the original chunks. | ||||
|     max_token_for_text_unit: int = 2000 | ||||
|     max_token_for_text_unit: int = 4000 | ||||
|     # Number of tokens for the relationship descriptions | ||||
|     max_token_for_global_context: int = 2000 | ||||
|     max_token_for_global_context: int = 4000 | ||||
|     # Number of tokens for the entity descriptions | ||||
|     max_token_for_local_context: int = 2000  # For Light/Graph | ||||
|     max_token_for_node_context: int = 500  # For Mini; if too long, the SLM may fail to generate any response | ||||
|     max_token_for_local_context: int = 4000 | ||||
|     hl_keywords: list[str] = field(default_factory=list) | ||||
|     ll_keywords: list[str] = field(default_factory=list) | ||||
|     # Conversation history support | ||||
|     conversation_history: list[dict] = field( | ||||
|         default_factory=list | ||||
|     )  # Format: [{"role": "user/assistant", "content": "message"}] | ||||
|     history_turns: int = ( | ||||
|         3  # Number of complete conversation turns (user-assistant pairs) to consider | ||||
|     ) | ||||
|  | ||||
| @dataclass | ||||
| class StorageNameSpace: | ||||
|   | ||||
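| # A hedged usage sketch for the QueryParam fields added above (the field | ||||
| # names come from this diff; the call sites are not shown here): | ||||
| # | ||||
| #   param = QueryParam( | ||||
| #       mode="mini", | ||||
| #       top_k=20,  # overrides the TOP_K env default of 60 | ||||
| #       conversation_history=[ | ||||
| #           {"role": "user", "content": "Who wrote the report?"}, | ||||
| #           {"role": "assistant", "content": "It was written by ..."}, | ||||
| #       ], | ||||
| #       history_turns=1,  # use only the most recent user-assistant pair | ||||
| #   ) | ||||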
							
								621
								minirag/kg/age_impl.py
								Normal file
							
							| @@ -0,0 +1,621 @@ | ||||
| import asyncio | ||||
| import inspect | ||||
| import json | ||||
| import os | ||||
| import sys | ||||
| from contextlib import asynccontextmanager | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("psycopg-pool"): | ||||
|     pm.install("psycopg-pool") | ||||
|     pm.install("psycopg[binary,pool]") | ||||
| if not pm.is_installed("asyncpg"): | ||||
|     pm.install("asyncpg") | ||||
|  | ||||
|  | ||||
| import psycopg | ||||
| from psycopg.rows import namedtuple_row | ||||
| from psycopg_pool import AsyncConnectionPool, PoolTimeout | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     retry_if_exception_type, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
| ) | ||||
|  | ||||
| from lightrag.utils import logger | ||||
|  | ||||
| from ..base import BaseGraphStorage | ||||
|  | ||||
| if sys.platform.startswith("win"): | ||||
|     import asyncio.windows_events | ||||
|  | ||||
|     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) | ||||
|  | ||||
|  | ||||
| class AGEQueryException(Exception): | ||||
|     """Exception for the AGE queries.""" | ||||
|  | ||||
|     def __init__(self, exception: Union[str, Dict]) -> None: | ||||
|         if isinstance(exception, dict): | ||||
|             self.message = exception["message"] if "message" in exception else "unknown" | ||||
|             self.details = exception["details"] if "details" in exception else "unknown" | ||||
|         else: | ||||
|             self.message = exception | ||||
|             self.details = "unknown" | ||||
|  | ||||
|     def get_message(self) -> str: | ||||
|         return self.message | ||||
|  | ||||
|     def get_details(self) -> Any: | ||||
|         return self.details | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class AGEStorage(BaseGraphStorage): | ||||
|     @staticmethod | ||||
|     def load_nx_graph(file_name): | ||||
|         print("no preloading of graph with AGE in production") | ||||
|  | ||||
|     def __init__(self, namespace, global_config, embedding_func): | ||||
|         super().__init__( | ||||
|             namespace=namespace, | ||||
|             global_config=global_config, | ||||
|             embedding_func=embedding_func, | ||||
|         ) | ||||
|         self._driver = None | ||||
|         self._driver_lock = asyncio.Lock() | ||||
|         DB = os.environ["AGE_POSTGRES_DB"].replace("\\", "\\\\").replace("'", "\\'") | ||||
|         USER = os.environ["AGE_POSTGRES_USER"].replace("\\", "\\\\").replace("'", "\\'") | ||||
|         PASSWORD = ( | ||||
|             os.environ["AGE_POSTGRES_PASSWORD"] | ||||
|             .replace("\\", "\\\\") | ||||
|             .replace("'", "\\'") | ||||
|         ) | ||||
|         HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'") | ||||
|         PORT = int(os.environ["AGE_POSTGRES_PORT"]) | ||||
|         self.graph_name = os.environ["AGE_GRAPH_NAME"] | ||||
|  | ||||
|         connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" | ||||
|  | ||||
|         self._driver = AsyncConnectionPool(connection_string, open=False) | ||||
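|         # Required environment variables (names taken from the lookups above); | ||||
|         # the example values are illustrative assumptions: | ||||
|         #   AGE_POSTGRES_DB=minirag AGE_POSTGRES_USER=postgres | ||||
|         #   AGE_POSTGRES_PASSWORD=... AGE_POSTGRES_HOST=localhost | ||||
|         #   AGE_POSTGRES_PORT=5432 AGE_GRAPH_NAME=minirag | ||||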
|  | ||||
|         return None | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._node_embed_algorithms = { | ||||
|             "node2vec": self._node2vec_embed, | ||||
|         } | ||||
|  | ||||
|     async def close(self): | ||||
|         if self._driver: | ||||
|             await self._driver.close() | ||||
|             self._driver = None | ||||
|  | ||||
|     async def __aexit__(self, exc_type, exc, tb): | ||||
|         if self._driver: | ||||
|             await self._driver.close() | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         print("KG successfully indexed.") | ||||
|  | ||||
|     @staticmethod | ||||
|     def _record_to_dict(record: NamedTuple) -> Dict[str, Any]: | ||||
|         """ | ||||
|         Convert a record returned from an AGE query to a dictionary | ||||
|  | ||||
|         Args: | ||||
|             record (NamedTuple): a record from an AGE query result | ||||
|  | ||||
|         Returns: | ||||
|             Dict[str, Any]: a dictionary representation of the record where | ||||
|                 the dictionary key is the field name and the value is the | ||||
|                 value converted to a python type | ||||
|         """ | ||||
|         # result holder | ||||
|         d = {} | ||||
|  | ||||
|         # prebuild a mapping of vertex_id to vertex mappings to be used | ||||
|         # later to build edges | ||||
|         vertices = {} | ||||
|         for k in record._fields: | ||||
|             v = getattr(record, k) | ||||
|             # agtype comes back '{key: value}::type' which must be parsed | ||||
|             if isinstance(v, str) and "::" in v: | ||||
|                 dtype = v.split("::")[-1] | ||||
|                 v = v.split("::")[0] | ||||
|                 if dtype == "vertex": | ||||
|                     vertex = json.loads(v) | ||||
|                     vertices[vertex["id"]] = vertex.get("properties") | ||||
|  | ||||
|         # iterate returned fields and parse appropriately | ||||
|         for k in record._fields: | ||||
|             v = getattr(record, k) | ||||
|             if isinstance(v, str) and "::" in v: | ||||
|                 dtype = v.split("::")[-1] | ||||
|                 v = v.split("::")[0] | ||||
|             else: | ||||
|                 dtype = "" | ||||
|  | ||||
|             if dtype == "vertex": | ||||
|                 vertex = json.loads(v) | ||||
|                 field = json.loads(v).get("properties") | ||||
|                 if not field: | ||||
|                     field = {} | ||||
|                 field["label"] = AGEStorage._decode_graph_label(vertex["label"]) | ||||
|                 d[k] = field | ||||
|             # convert edge from id-label->id by replacing id with node information | ||||
|             # we only do this if the vertex was also returned in the query | ||||
|             # this is an attempt to be consistent with neo4j implementation | ||||
|             elif dtype == "edge": | ||||
|                 edge = json.loads(v) | ||||
|                 d[k] = ( | ||||
|                     vertices.get(edge["start_id"], {}), | ||||
|                     edge[ | ||||
|                         "label" | ||||
|                     ],  # we don't use decode_graph_label(), since edge label is always "DIRECTED" | ||||
|                     vertices.get(edge["end_id"], {}), | ||||
|                 ) | ||||
|             else: | ||||
|                 d[k] = json.loads(v) if isinstance(v, str) else v | ||||
|  | ||||
|         return d | ||||
|  | ||||
|     @staticmethod | ||||
|     def _format_properties( | ||||
|         properties: Dict[str, Any], _id: Union[str, None] = None | ||||
|     ) -> str: | ||||
|         """ | ||||
|         Convert a dictionary of properties to a string representation that | ||||
|         can be used in a cypher query insert/merge statement. | ||||
|  | ||||
|         Args: | ||||
|             properties (Dict[str,str]): a dictionary containing node/edge properties | ||||
|             _id (Union[str, None]): the id of the node or None if none exists | ||||
|  | ||||
|         Returns: | ||||
|             str: the properties dictionary as a properly formatted string | ||||
|         """ | ||||
|         props = [] | ||||
|         # wrap property key in backticks to escape | ||||
|         for k, v in properties.items(): | ||||
|             prop = f"`{k}`: {json.dumps(v)}" | ||||
|             props.append(prop) | ||||
|         if _id is not None and "id" not in properties: | ||||
|             props.append( | ||||
|                 f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" | ||||
|             ) | ||||
|         return "{" + ", ".join(props) + "}" | ||||
|  | ||||
|     @staticmethod | ||||
|     def _encode_graph_label(label: str) -> str: | ||||
|         """ | ||||
|         Since AGE supports only alphanumeric labels, we encode a generic label as a hex string | ||||
|  | ||||
|         Args: | ||||
|             label (str): the original label | ||||
|  | ||||
|         Returns: | ||||
|             str: the encoded label | ||||
|         """ | ||||
|         return "x" + label.encode().hex() | ||||
|  | ||||
|     @staticmethod | ||||
|     def _decode_graph_label(encoded_label: str) -> str: | ||||
|         """ | ||||
|         Since AGE supports only alphanumeric labels, we encode a generic label as a hex string | ||||
|  | ||||
|         Args: | ||||
|             encoded_label (str): the encoded label | ||||
|  | ||||
|         Returns: | ||||
|             str: the decoded label | ||||
|         """ | ||||
|         return bytes.fromhex(encoded_label.removeprefix("x")).decode() | ||||
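|     # Illustrative round trip (the input label is an assumption; the output | ||||
|     # follows from the hex encoding above): | ||||
|     #   _encode_graph_label("Person-1") -> "x506572736f6e2d31" | ||||
|     #   _decode_graph_label("x506572736f6e2d31") -> "Person-1" | ||||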
|  | ||||
|     @staticmethod | ||||
|     def _get_col_name(field: str, idx: int) -> str: | ||||
|         """ | ||||
|         Convert a cypher return field to a pgsql select field | ||||
|         If possible keep the cypher column name, but create a generic name if necessary | ||||
|  | ||||
|         Args: | ||||
|             field (str): a return field from a cypher query to be formatted for pgsql | ||||
|             idx (int): the position of the field in the return statement | ||||
|  | ||||
|         Returns: | ||||
|             str: the field to be used in the pgsql select statement | ||||
|         """ | ||||
|         # remove white space | ||||
|         field = field.strip() | ||||
|         # if an alias is provided for the field, use it | ||||
|         if " as " in field: | ||||
|             return field.split(" as ")[-1].strip() | ||||
|         # if the return value is an unnamed primitive, give it a generic name | ||||
|         if field.isnumeric() or field in ("true", "false", "null"): | ||||
|             return f"column_{idx}" | ||||
|         # otherwise return the value stripping out some common special chars | ||||
|         return field.replace("(", "_").replace(")", "") | ||||
|  | ||||
|     @staticmethod | ||||
|     def _wrap_query(query: str, graph_name: str, **params: str) -> str: | ||||
|         """ | ||||
|         Convert a cypher query to an Apache Age compatible | ||||
|         sql query by wrapping the cypher query in ag_catalog.cypher, | ||||
|         casting results to agtype and building a select statement | ||||
|  | ||||
|         Args: | ||||
|             query (str): a valid cypher query | ||||
|             graph_name (str): the name of the graph to query | ||||
|             params (dict): parameters for the query | ||||
|  | ||||
|         Returns: | ||||
|             str: an equivalent pgsql query | ||||
|         """ | ||||
|  | ||||
|         # pgsql template | ||||
|         template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$ | ||||
|             {query} | ||||
|         $$) AS ({fields});""" | ||||
|  | ||||
|         # if there are any returned fields they must be added to the pgsql query | ||||
|         if "return" in query.lower(): | ||||
|             # parse return statement to identify returned fields | ||||
|             fields = ( | ||||
|                 query.lower() | ||||
|                 .split("return")[-1] | ||||
|                 .split("distinct")[-1] | ||||
|                 .split("order by")[0] | ||||
|                 .split("skip")[0] | ||||
|                 .split("limit")[0] | ||||
|                 .split(",") | ||||
|             ) | ||||
|  | ||||
|             # raise exception if RETURN * is found as we can't resolve the fields | ||||
|             if "*" in [x.strip() for x in fields]: | ||||
|                 raise ValueError( | ||||
|                     "AGE graph does not support 'RETURN *'" | ||||
|                     + " statements in Cypher queries" | ||||
|                 ) | ||||
|  | ||||
|             # get pgsql formatted field names | ||||
|             fields = [ | ||||
|                 AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields) | ||||
|             ] | ||||
|  | ||||
|             # build resulting pgsql relation | ||||
|             fields_str = ", ".join( | ||||
|                 [field.split(".")[-1] + " agtype" for field in fields] | ||||
|             ) | ||||
|  | ||||
|         # if no return statement we still need to return a single field of type agtype | ||||
|         else: | ||||
|             fields_str = "a agtype" | ||||
|  | ||||
|         select_str = "*" | ||||
|  | ||||
|         return template.format( | ||||
|             graph_name=graph_name, | ||||
|             query=query.format(**params), | ||||
|             fields=fields_str, | ||||
|             projection=select_str, | ||||
|         ) | ||||
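|     # A worked example of the wrapping above (graph name and label are | ||||
|     # assumptions): given query "MATCH (n:`xabc`) RETURN n" and graph | ||||
|     # "minirag", _wrap_query produces: | ||||
|     #   SELECT * FROM ag_catalog.cypher('minirag', $$ | ||||
|     #       MATCH (n:`xabc`) RETURN n | ||||
|     #   $$) AS (n agtype); | ||||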
|  | ||||
|     async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]: | ||||
|         """ | ||||
|         Query the graph by taking a cypher query, converting it to an | ||||
|         age compatible query, executing it and converting the result | ||||
|  | ||||
|         Args: | ||||
|             query (str): a cypher query to be executed | ||||
|             params (dict): parameters for the query | ||||
|  | ||||
|         Returns: | ||||
|             List[Dict[str, Any]]: a list of dictionaries containing the result set | ||||
|         """ | ||||
|         # convert cypher query to pgsql/age query | ||||
|         wrapped_query = self._wrap_query(query, self.graph_name, **params) | ||||
|  | ||||
|         await self._driver.open() | ||||
|  | ||||
|         # create graph if it doesn't exist | ||||
|         async with self._get_pool_connection() as conn: | ||||
|             async with conn.cursor() as curs: | ||||
|                 try: | ||||
|                     await curs.execute('SET search_path = ag_catalog, "$user", public') | ||||
|                     await curs.execute(f"SELECT create_graph('{self.graph_name}')") | ||||
|                     await conn.commit() | ||||
|                 except ( | ||||
|                     psycopg.errors.InvalidSchemaName, | ||||
|                     psycopg.errors.UniqueViolation, | ||||
|                 ): | ||||
|                     await conn.rollback() | ||||
|  | ||||
|         # execute the query, rolling back on an error | ||||
|         async with self._get_pool_connection() as conn: | ||||
|             async with conn.cursor(row_factory=namedtuple_row) as curs: | ||||
|                 try: | ||||
|                     await curs.execute('SET search_path = ag_catalog, "$user", public') | ||||
|                     await curs.execute(wrapped_query) | ||||
|                     await conn.commit() | ||||
|                 except psycopg.Error as e: | ||||
|                     await conn.rollback() | ||||
|                     raise AGEQueryException( | ||||
|                         { | ||||
|                             "message": f"Error executing graph query: {query.format(**params)}", | ||||
|                             "detail": str(e), | ||||
|                         } | ||||
|                     ) from e | ||||
|  | ||||
|                 data = await curs.fetchall() | ||||
|                 if data is None: | ||||
|                     result = [] | ||||
|                 # decode records | ||||
|                 else: | ||||
|                     result = [AGEStorage._record_to_dict(d) for d in data] | ||||
|  | ||||
|                 return result | ||||
|  | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         entity_name_label = node_id.strip('"') | ||||
|  | ||||
|         query = """ | ||||
|                 MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists | ||||
|                 """ | ||||
|         params = {"label": AGEStorage._encode_graph_label(entity_name_label)} | ||||
|         single_result = (await self._query(query, **params))[0] | ||||
|         logger.debug( | ||||
|             "{%s}:query:{%s}:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             query.format(**params), | ||||
|             single_result["node_exists"], | ||||
|         ) | ||||
|  | ||||
|         return single_result["node_exists"] | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         entity_name_label_source = source_node_id.strip('"') | ||||
|         entity_name_label_target = target_node_id.strip('"') | ||||
|  | ||||
|         query = """ | ||||
|                 MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) | ||||
|                 RETURN COUNT(r) > 0 AS edge_exists | ||||
|                 """ | ||||
|         params = { | ||||
|             "src_label": AGEStorage._encode_graph_label(entity_name_label_source), | ||||
|             "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), | ||||
|         } | ||||
|         single_result = (await self._query(query, **params))[0] | ||||
|         logger.debug( | ||||
|             "{%s}:query:{%s}:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             query.format(**params), | ||||
|             single_result["edge_exists"], | ||||
|         ) | ||||
|         return single_result["edge_exists"] | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         entity_name_label = node_id.strip('"') | ||||
|         query = """ | ||||
|                 MATCH (n:`{label}`) RETURN n | ||||
|                 """ | ||||
|         params = {"label": AGEStorage._encode_graph_label(entity_name_label)} | ||||
|         record = await self._query(query, **params) | ||||
|         if record: | ||||
|             node = record[0] | ||||
|             node_dict = node["n"] | ||||
|             logger.debug( | ||||
|                 "{%s}: query: {%s}, result: {%s}", | ||||
|                 inspect.currentframe().f_code.co_name, | ||||
|                 query.format(**params), | ||||
|                 node_dict, | ||||
|             ) | ||||
|             return node_dict | ||||
|         return None | ||||
|  | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         entity_name_label = node_id.strip('"') | ||||
|  | ||||
|         query = """ | ||||
|                 MATCH (n:`{label}`)-[]->(x) | ||||
|                 RETURN count(x) AS total_edge_count | ||||
|                 """ | ||||
|         params = {"label": AGEStorage._encode_graph_label(entity_name_label)} | ||||
|         record = (await self._query(query, **params))[0] | ||||
|         if record: | ||||
|             edge_count = int(record["total_edge_count"]) | ||||
|             logger.debug( | ||||
|                 "{%s}:query:{%s}:result:{%s}", | ||||
|                 inspect.currentframe().f_code.co_name, | ||||
|                 query.format(**params), | ||||
|                 edge_count, | ||||
|             ) | ||||
|             return edge_count | ||||
|  | ||||
|     async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||||
|         entity_name_label_source = src_id.strip('"') | ||||
|         entity_name_label_target = tgt_id.strip('"') | ||||
|         src_degree = await self.node_degree(entity_name_label_source) | ||||
|         trg_degree = await self.node_degree(entity_name_label_target) | ||||
|  | ||||
|         # Convert None to 0 for addition | ||||
|         src_degree = 0 if src_degree is None else src_degree | ||||
|         trg_degree = 0 if trg_degree is None else trg_degree | ||||
|  | ||||
|         degrees = int(src_degree) + int(trg_degree) | ||||
|         logger.debug( | ||||
|             "{%s}:query:src_Degree+trg_degree:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             degrees, | ||||
|         ) | ||||
|         return degrees | ||||
|  | ||||
|     async def get_edge( | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         """ | ||||
|         Find the first edge between nodes of two given labels | ||||
|  | ||||
|         Args: | ||||
|             source_node_id (str): Label of the source node | ||||
|             target_node_id (str): Label of the target node | ||||
|  | ||||
|         Returns: | ||||
|             Union[dict, None]: properties of the first matching edge, or None if not found | ||||
|         """ | ||||
|         entity_name_label_source = source_node_id.strip('"') | ||||
|         entity_name_label_target = target_node_id.strip('"') | ||||
|  | ||||
|         query = """ | ||||
|                 MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) | ||||
|                 RETURN properties(r) as edge_properties | ||||
|                 LIMIT 1 | ||||
|                 """ | ||||
|         params = { | ||||
|             "src_label": AGEStorage._encode_graph_label(entity_name_label_source), | ||||
|             "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), | ||||
|         } | ||||
|         record = await self._query(query, **params) | ||||
|         if record and record[0] and record[0]["edge_properties"]: | ||||
|             result = record[0]["edge_properties"] | ||||
|             logger.debug( | ||||
|                 "{%s}:query:{%s}:result:{%s}", | ||||
|                 inspect.currentframe().f_code.co_name, | ||||
|                 query.format(**params), | ||||
|                 result, | ||||
|             ) | ||||
|             return result | ||||
|  | ||||
|     async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: | ||||
|         """ | ||||
|         Retrieves all edges (relationships) for a particular node identified by its label. | ||||
|         :return: List of (source_label, target_label) tuples for the node's edges | ||||
|         """ | ||||
|         node_label = source_node_id.strip('"') | ||||
|  | ||||
|         query = """ | ||||
|                 MATCH (n:`{label}`) | ||||
|                 OPTIONAL MATCH (n)-[r]-(connected) | ||||
|                 RETURN n, r, connected | ||||
|                 """ | ||||
|         params = {"label": AGEStorage._encode_graph_label(node_label)} | ||||
|         results = await self._query(query, **params) | ||||
|         edges = [] | ||||
|         for record in results: | ||||
|             source_node = record["n"] if record["n"] else None | ||||
|             connected_node = record["connected"] if record["connected"] else None | ||||
|  | ||||
|             source_label = ( | ||||
|                 source_node["label"] if source_node and source_node["label"] else None | ||||
|             ) | ||||
|             target_label = ( | ||||
|                 connected_node["label"] | ||||
|                 if connected_node and connected_node["label"] | ||||
|                 else None | ||||
|             ) | ||||
|  | ||||
|             if source_label and target_label: | ||||
|                 edges.append((source_label, target_label)) | ||||
|  | ||||
|         return edges | ||||
|  | ||||
|     @retry( | ||||
|         stop=stop_after_attempt(3), | ||||
|         wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|         retry=retry_if_exception_type((AGEQueryException,)), | ||||
|     ) | ||||
|     async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): | ||||
|         """ | ||||
|         Upsert a node in the AGE database. | ||||
|  | ||||
|         Args: | ||||
|             node_id: The unique identifier for the node (used as label) | ||||
|             node_data: Dictionary of node properties | ||||
|         """ | ||||
|         label = node_id.strip('"') | ||||
|         properties = node_data | ||||
|  | ||||
|         query = """ | ||||
|                 MERGE (n:`{label}`) | ||||
|                 SET n += {properties} | ||||
|                 """ | ||||
|         params = { | ||||
|             "label": AGEStorage._encode_graph_label(label), | ||||
|             "properties": AGEStorage._format_properties(properties), | ||||
|         } | ||||
|         try: | ||||
|             await self._query(query, **params) | ||||
|             logger.debug( | ||||
|                 "Upserted node with label '{%s}' and properties: {%s}", | ||||
|                 label, | ||||
|                 properties, | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.error("Error during upsert: {%s}", e) | ||||
|             raise | ||||
|  | ||||
|     @retry( | ||||
|         stop=stop_after_attempt(3), | ||||
|         wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|         retry=retry_if_exception_type((AGEQueryException,)), | ||||
|     ) | ||||
|     async def upsert_edge( | ||||
|         self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] | ||||
|     ): | ||||
|         """ | ||||
|         Upsert an edge and its properties between two nodes identified by their labels. | ||||
|  | ||||
|         Args: | ||||
|             source_node_id (str): Label of the source node (used as identifier) | ||||
|             target_node_id (str): Label of the target node (used as identifier) | ||||
|             edge_data (dict): Dictionary of properties to set on the edge | ||||
|         """ | ||||
|         source_node_label = source_node_id.strip('"') | ||||
|         target_node_label = target_node_id.strip('"') | ||||
|         edge_properties = edge_data | ||||
|  | ||||
|         query = """ | ||||
|                 MATCH (source:`{src_label}`) | ||||
|                 WITH source | ||||
|                 MATCH (target:`{tgt_label}`) | ||||
|                 MERGE (source)-[r:DIRECTED]->(target) | ||||
|                 SET r += {properties} | ||||
|                 RETURN r | ||||
|                 """ | ||||
|         params = { | ||||
|             "src_label": AGEStorage._encode_graph_label(source_node_label), | ||||
|             "tgt_label": AGEStorage._encode_graph_label(target_node_label), | ||||
|             "properties": AGEStorage._format_properties(edge_properties), | ||||
|         } | ||||
|         try: | ||||
|             await self._query(query, **params) | ||||
|             logger.debug( | ||||
|                 "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", | ||||
|                 source_node_label, | ||||
|                 target_node_label, | ||||
|                 edge_properties, | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.error("Error during edge upsert: {%s}", e) | ||||
|             raise | ||||
|  | ||||
|     async def _node2vec_embed(self): | ||||
|         print("Implemented but never called.") | ||||
|  | ||||
|     @asynccontextmanager | ||||
|     async def _get_pool_connection(self, timeout: Optional[float] = None): | ||||
|         """Workaround for a psycopg_pool bug""" | ||||
|  | ||||
|         try: | ||||
|             connection = await self._driver.getconn(timeout=timeout) | ||||
|         except PoolTimeout: | ||||
|             await self._driver._add_connection(None)  # workaround... | ||||
|             connection = await self._driver.getconn(timeout=timeout) | ||||
|  | ||||
|         try: | ||||
|             async with connection: | ||||
|                 yield connection | ||||
|         finally: | ||||
|             await self._driver.putconn(connection) | ||||
							
								173
								minirag/kg/chroma_impl.py
								Normal file
							
							| @@ -0,0 +1,173 @@ | ||||
| import os | ||||
| import asyncio | ||||
| from dataclasses import dataclass | ||||
| from typing import Union | ||||
| import numpy as np | ||||
| from chromadb import HttpClient | ||||
| from chromadb.config import Settings | ||||
| from lightrag.base import BaseVectorStorage | ||||
| from lightrag.utils import logger | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class ChromaVectorDBStorage(BaseVectorStorage): | ||||
|     """ChromaDB vector storage implementation.""" | ||||
|  | ||||
|     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         try: | ||||
|             # Use global config value if specified, otherwise use default | ||||
|             config = self.global_config.get("vector_db_storage_cls_kwargs", {}) | ||||
|             self.cosine_better_than_threshold = config.get( | ||||
|                 "cosine_better_than_threshold", self.cosine_better_than_threshold | ||||
|             ) | ||||
|  | ||||
|             user_collection_settings = config.get("collection_settings", {}) | ||||
|             # Default HNSW index settings for ChromaDB | ||||
|             default_collection_settings = { | ||||
|                 # Distance metric used for similarity search (cosine similarity) | ||||
|                 "hnsw:space": "cosine", | ||||
|                 # Number of nearest neighbors to explore during index construction | ||||
|                 # Higher values = better recall but slower indexing | ||||
|                 "hnsw:construction_ef": 128, | ||||
|                 # Number of nearest neighbors to explore during search | ||||
|                 # Higher values = better recall but slower search | ||||
|                 "hnsw:search_ef": 128, | ||||
|                 # Number of connections per node in the HNSW graph | ||||
|                 # Higher values = better recall but more memory usage | ||||
|                 "hnsw:M": 16, | ||||
|                 # Number of vectors to process in one batch during indexing | ||||
|                 "hnsw:batch_size": 100, | ||||
|                 # Number of updates before forcing index synchronization | ||||
|                 # Lower values = more frequent syncs but slower indexing | ||||
|                 "hnsw:sync_threshold": 1000, | ||||
|             } | ||||
|             collection_settings = { | ||||
|                 **default_collection_settings, | ||||
|                 **user_collection_settings, | ||||
|             } | ||||
|  | ||||
|             auth_provider = config.get( | ||||
|                 "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider" | ||||
|             ) | ||||
|             auth_credentials = config.get("auth_token", "secret-token") | ||||
|             headers = {} | ||||
|  | ||||
|             if "token_authn" in auth_provider: | ||||
|                 headers = { | ||||
|                     config.get("auth_header_name", "X-Chroma-Token"): auth_credentials | ||||
|                 } | ||||
|             elif "basic_authn" in auth_provider: | ||||
|                 auth_credentials = config.get("auth_credentials", "admin:admin") | ||||
|  | ||||
|             self._client = HttpClient( | ||||
|                 host=config.get("host", "localhost"), | ||||
|                 port=config.get("port", 8000), | ||||
|                 headers=headers, | ||||
|                 settings=Settings( | ||||
|                     chroma_api_impl="rest", | ||||
|                     chroma_client_auth_provider=auth_provider, | ||||
|                     chroma_client_auth_credentials=auth_credentials, | ||||
|                     allow_reset=True, | ||||
|                     anonymized_telemetry=False, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             self._collection = self._client.get_or_create_collection( | ||||
|                 name=self.namespace, | ||||
|                 metadata={ | ||||
|                     **collection_settings, | ||||
|                     "dimension": self.embedding_func.embedding_dim, | ||||
|                 }, | ||||
|             ) | ||||
|             # Use batch size from collection settings if specified | ||||
|             self._max_batch_size = self.global_config.get( | ||||
|                 "embedding_batch_num", collection_settings.get("hnsw:batch_size", 32) | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.error(f"ChromaDB initialization failed: {str(e)}") | ||||
|             raise | ||||
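|     # A hedged configuration sketch (keys inferred from __post_init__ above; | ||||
|     # the values are illustrative): | ||||
|     #   global_config = { | ||||
|     #       "vector_db_storage_cls_kwargs": { | ||||
|     #           "host": "localhost", | ||||
|     #           "port": 8000, | ||||
|     #           "auth_token": "secret-token", | ||||
|     #           "cosine_better_than_threshold": 0.3, | ||||
|     #           "collection_settings": {"hnsw:search_ef": 256}, | ||||
|     #       }, | ||||
|     #       "embedding_batch_num": 32, | ||||
|     #   } | ||||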
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         if not data: | ||||
|             logger.warning("Empty data provided to vector DB") | ||||
|             return [] | ||||
|  | ||||
|         try: | ||||
|             ids = list(data.keys()) | ||||
|             documents = [v["content"] for v in data.values()] | ||||
|             metadatas = [ | ||||
|                 {k: v for k, v in item.items() if k in self.meta_fields} | ||||
|                 or {"_default": "true"} | ||||
|                 for item in data.values() | ||||
|             ] | ||||
|  | ||||
|             # Process in batches | ||||
|             batches = [ | ||||
|                 documents[i : i + self._max_batch_size] | ||||
|                 for i in range(0, len(documents), self._max_batch_size) | ||||
|             ] | ||||
|  | ||||
|             # Run all embedding batches concurrently; gather preserves input order | ||||
|             embedding_tasks = [self.embedding_func(batch) for batch in batches] | ||||
|             embeddings_list = list(await asyncio.gather(*embedding_tasks)) | ||||
|  | ||||
|             embeddings = np.concatenate(embeddings_list) | ||||
|  | ||||
|             # Upsert in batches | ||||
|             for i in range(0, len(ids), self._max_batch_size): | ||||
|                 batch_slice = slice(i, i + self._max_batch_size) | ||||
|  | ||||
|                 self._collection.upsert( | ||||
|                     ids=ids[batch_slice], | ||||
|                     embeddings=embeddings[batch_slice].tolist(), | ||||
|                     documents=documents[batch_slice], | ||||
|                     metadatas=metadatas[batch_slice], | ||||
|                 ) | ||||
|  | ||||
|             return ids | ||||
|  | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error during ChromaDB upsert: {str(e)}") | ||||
|             raise | ||||
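|     # A hedged input sketch for upsert (shape inferred from the code above; | ||||
|     # "source" is a hypothetical meta field and survives only if listed in | ||||
|     # self.meta_fields): | ||||
|     #   await storage.upsert({ | ||||
|     #       "doc-1": {"content": "some chunk text", "source": "a.txt"}, | ||||
|     #   }) | ||||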
|  | ||||
|     async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: | ||||
|         try: | ||||
|             embedding = await self.embedding_func([query]) | ||||
|  | ||||
|             results = self._collection.query( | ||||
|                 query_embeddings=embedding.tolist(), | ||||
|                 n_results=top_k * 2,  # Request more results to allow for filtering | ||||
|                 include=["metadatas", "distances", "documents"], | ||||
|             ) | ||||
|  | ||||
|             # Filter by cosine similarity threshold, then keep the top k. | ||||
|             # We request 2x results initially so enough remain after filtering. | ||||
|             # With "hnsw:space" = "cosine", ChromaDB returns cosine *distance* | ||||
|             # (0 = identical), so (1 - distance) is the cosine similarity that | ||||
|             # is compared against cosine_better_than_threshold below. | ||||
|             return [ | ||||
|                 { | ||||
|                     "id": results["ids"][0][i], | ||||
|                     "distance": 1 - results["distances"][0][i], | ||||
|                     "content": results["documents"][0][i], | ||||
|                     **results["metadatas"][0][i], | ||||
|                 } | ||||
|                 for i in range(len(results["ids"][0])) | ||||
|                 if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold | ||||
|             ][:top_k] | ||||
|  | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error during ChromaDB query: {str(e)}") | ||||
|             raise | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         # ChromaDB handles persistence automatically | ||||
|         pass | ||||
							
								397
								minirag/kg/gremlin_impl.py
								Normal file
							
							| @@ -0,0 +1,397 @@ | ||||
| import asyncio | ||||
| import inspect | ||||
| import json | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Dict, List, Tuple, Union | ||||
|  | ||||
| from gremlin_python.driver import client, serializer | ||||
| from gremlin_python.driver.aiohttp.transport import AiohttpTransport | ||||
| from gremlin_python.driver.protocol import GremlinServerError | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     retry_if_exception_type, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
| ) | ||||
|  | ||||
| from lightrag.utils import logger | ||||
|  | ||||
| from ..base import BaseGraphStorage | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class GremlinStorage(BaseGraphStorage): | ||||
|     @staticmethod | ||||
|     def load_nx_graph(file_name): | ||||
|         print("no preloading of graph with Gremlin in production") | ||||
|  | ||||
|     def __init__(self, namespace, global_config, embedding_func): | ||||
|         super().__init__( | ||||
|             namespace=namespace, | ||||
|             global_config=global_config, | ||||
|             embedding_func=embedding_func, | ||||
|         ) | ||||
|  | ||||
|         self._driver = None | ||||
|         self._driver_lock = asyncio.Lock() | ||||
|  | ||||
|         USER = os.environ.get("GREMLIN_USER", "") | ||||
|         PASSWORD = os.environ.get("GREMLIN_PASSWORD", "") | ||||
|         HOST = os.environ["GREMLIN_HOST"] | ||||
|         PORT = int(os.environ["GREMLIN_PORT"]) | ||||
|  | ||||
|         # TraversalSource: a custom one has to be created manually; | ||||
|         # defaults to "g" | ||||
|         SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g") | ||||
|  | ||||
|         # All vertices will have graph={GRAPH} property, so that we can | ||||
|         # have several logical graphs for one source | ||||
|         GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"]) | ||||
|  | ||||
|         self.graph_name = GRAPH | ||||
|  | ||||
|         self._driver = client.Client( | ||||
|             f"ws://{HOST}:{PORT}/gremlin", | ||||
|             SOURCE, | ||||
|             username=USER, | ||||
|             password=PASSWORD, | ||||
|             message_serializer=serializer.GraphSONSerializersV3d0(), | ||||
|             transport_factory=lambda: AiohttpTransport(call_from_event_loop=True), | ||||
|         ) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._node_embed_algorithms = { | ||||
|             "node2vec": self._node2vec_embed, | ||||
|         } | ||||
|  | ||||
|     async def close(self): | ||||
|         if self._driver: | ||||
|             self._driver.close() | ||||
|             self._driver = None | ||||
|  | ||||
|     async def __aexit__(self, exc_type, exc, tb): | ||||
|         if self._driver: | ||||
|             self._driver.close() | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         print("KG successfully indexed.") | ||||
|  | ||||
|     @staticmethod | ||||
|     def _to_value_map(value: Any) -> str: | ||||
|         """Dump supported Python object as Gremlin valueMap""" | ||||
|         json_str = json.dumps(value, ensure_ascii=False, sort_keys=False) | ||||
|         parsed_str = json_str.replace("'", r"\'") | ||||
|  | ||||
|         # walk over the string and replace curly brackets with square brackets | ||||
|         # outside of strings, as well as replace double quotes with single quotes | ||||
|         # and "deescape" double quotes inside of strings | ||||
|         outside_str = True | ||||
|         escaped = False | ||||
|         remove_indices = [] | ||||
|         for i, c in enumerate(parsed_str): | ||||
|             if escaped: | ||||
|                 # previous character was an "odd" backslash | ||||
|                 escaped = False | ||||
|                 if c == '"': | ||||
|                     # we want to "deescape" double quotes: store indices to delete | ||||
|                     remove_indices.insert(0, i - 1) | ||||
|             elif c == "\\": | ||||
|                 escaped = True | ||||
|             elif c == '"': | ||||
|                 outside_str = not outside_str | ||||
|                 parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :] | ||||
|             elif c == "{" and outside_str: | ||||
|                 parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :] | ||||
|             elif c == "}" and outside_str: | ||||
|                 parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :] | ||||
|         for idx in remove_indices: | ||||
|             parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :] | ||||
|         return parsed_str | ||||
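|     # Illustrative conversions (inputs are assumptions; the outputs follow | ||||
|     # from the walk above): | ||||
|     #   _to_value_map("my-graph") -> "'my-graph'" | ||||
|     #   _to_value_map({"a": 1})   -> "['a': 1]" | ||||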
|  | ||||
|     @staticmethod | ||||
|     def _convert_properties(properties: Dict[str, Any]) -> str: | ||||
|         """Create chained .property() commands from properties dict""" | ||||
|         props = [] | ||||
|         for k, v in properties.items(): | ||||
|             prop_name = GremlinStorage._to_value_map(k) | ||||
|             props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})") | ||||
|         return "".join(props) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _fix_name(name: str) -> str: | ||||
|         """Strip double quotes and format as a proper field name""" | ||||
|         name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'")) | ||||
|  | ||||
|         return name | ||||
|  | ||||
|     async def _query(self, query: str) -> List[Dict[str, Any]]: | ||||
|         """ | ||||
|         Query the Gremlin graph | ||||
|  | ||||
|         Args: | ||||
|             query (str): a query to be executed | ||||
|  | ||||
|         Returns: | ||||
|             List[Dict[str, Any]]: a list of dictionaries containing the result set | ||||
|         """ | ||||
|  | ||||
|         result = list(await asyncio.wrap_future(self._driver.submit_async(query))) | ||||
|         if result: | ||||
|             result = result[0] | ||||
|  | ||||
|         return result | ||||
|  | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         entity_name = GremlinStorage._fix_name(node_id) | ||||
|  | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name}) | ||||
|                  .limit(1) | ||||
|                  .count() | ||||
|                  .project('has_node') | ||||
|                     .by(__.choose(__.is(gt(0)), constant(true), constant(false))) | ||||
|                  """ | ||||
|         result = await self._query(query) | ||||
|         logger.debug( | ||||
|             "{%s}:query:{%s}:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             query, | ||||
|             result[0]["has_node"], | ||||
|         ) | ||||
|  | ||||
|         return result[0]["has_node"] | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         entity_name_source = GremlinStorage._fix_name(source_node_id) | ||||
|         entity_name_target = GremlinStorage._fix_name(target_node_id) | ||||
|  | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name_source}) | ||||
|                  .outE() | ||||
|                  .inV().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name_target}) | ||||
|                  .limit(1) | ||||
|                  .count() | ||||
|                  .project('has_edge') | ||||
|                     .by(__.choose(__.is(gt(0)), constant(true), constant(false))) | ||||
|                  """ | ||||
|         result = await self._query(query) | ||||
|         logger.debug( | ||||
|             "{%s}:query:{%s}:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             query, | ||||
|             result[0]["has_edge"], | ||||
|         ) | ||||
|  | ||||
|         return result[0]["has_edge"] | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         entity_name = GremlinStorage._fix_name(node_id) | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name}) | ||||
|                  .limit(1) | ||||
|                  .project('properties') | ||||
|                     .by(elementMap()) | ||||
|                  """ | ||||
|         result = await self._query(query) | ||||
|         if result: | ||||
|             node = result[0] | ||||
|             node_dict = node["properties"] | ||||
|             logger.debug( | ||||
|                 "{%s}:query:{%s}:result:{%s}", | ||||
|                 inspect.currentframe().f_code.co_name, | ||||
|                 query, | ||||
|                 node_dict, | ||||
|             ) | ||||
|             return node_dict | ||||
|  | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         entity_name = GremlinStorage._fix_name(node_id) | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name}) | ||||
|                  .outE() | ||||
|                  .inV().has('graph', {self.graph_name}) | ||||
|                  .count() | ||||
|                  .project('total_edge_count') | ||||
|                     .by() | ||||
|                  """ | ||||
|         result = await self._query(query) | ||||
|         edge_count = result[0]["total_edge_count"] | ||||
|  | ||||
|         logger.debug( | ||||
|             "{%s}:query:{%s}:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             query, | ||||
|             edge_count, | ||||
|         ) | ||||
|  | ||||
|         return edge_count | ||||
|  | ||||
|     async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||||
|         src_degree = await self.node_degree(src_id) | ||||
|         trg_degree = await self.node_degree(tgt_id) | ||||
|  | ||||
|         # Convert None to 0 for addition | ||||
|         src_degree = 0 if src_degree is None else src_degree | ||||
|         trg_degree = 0 if trg_degree is None else trg_degree | ||||
|  | ||||
|         degrees = int(src_degree) + int(trg_degree) | ||||
|         logger.debug( | ||||
|             "{%s}:query:src_Degree+trg_degree:result:{%s}", | ||||
|             inspect.currentframe().f_code.co_name, | ||||
|             degrees, | ||||
|         ) | ||||
|         return degrees | ||||
|  | ||||
|     async def get_edge( | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         """ | ||||
|         Find all edges between nodes of two given names | ||||
|  | ||||
|         Args: | ||||
|             source_node_id (str): Name of the source nodes | ||||
|             target_node_id (str): Name of the target nodes | ||||
|  | ||||
|         Returns: | ||||
|             dict|None: Dict of found edge properties, or None if not found | ||||
|         """ | ||||
|         entity_name_source = GremlinStorage._fix_name(source_node_id) | ||||
|         entity_name_target = GremlinStorage._fix_name(target_node_id) | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name_source}) | ||||
|                  .outE() | ||||
|                  .inV().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {entity_name_target}) | ||||
|                  .limit(1) | ||||
|                  .project('edge_properties') | ||||
|                  .by(__.bothE().elementMap()) | ||||
|                  """ | ||||
|         result = await self._query(query) | ||||
|         if result: | ||||
|             edge_properties = result[0]["edge_properties"] | ||||
|             logger.debug( | ||||
|                 "{%s}:query:{%s}:result:{%s}", | ||||
|                 inspect.currentframe().f_code.co_name, | ||||
|                 query, | ||||
|                 edge_properties, | ||||
|             ) | ||||
|             return edge_properties | ||||
|  | ||||
|     async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: | ||||
|         """ | ||||
|         Retrieves all edges (relationships) for a particular node identified by its name. | ||||
|         :return: List of tuples containing edge sources and targets | ||||
|         """ | ||||
|         node_name = GremlinStorage._fix_name(source_node_id) | ||||
|         query = f"""g | ||||
|                  .E() | ||||
|                  .filter( | ||||
|                      __.or( | ||||
|                          __.outV().has('graph', {self.graph_name}) | ||||
|                            .has('entity_name', {node_name}), | ||||
|                          __.inV().has('graph', {self.graph_name}) | ||||
|                            .has('entity_name', {node_name}) | ||||
|                      ) | ||||
|                  ) | ||||
|                  .project('source_name', 'target_name') | ||||
|                  .by(__.outV().values('entity_name')) | ||||
|                  .by(__.inV().values('entity_name')) | ||||
|                  """ | ||||
|         result = await self._query(query) | ||||
|         edges = [(res["source_name"], res["target_name"]) for res in result] | ||||
|  | ||||
|         return edges | ||||
|  | ||||
|     @retry( | ||||
|         stop=stop_after_attempt(10), | ||||
|         wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|         retry=retry_if_exception_type((GremlinServerError,)), | ||||
|     ) | ||||
|     async def upsert_node(self, node_id: str, node_data: Dict[str, Any]): | ||||
|         """ | ||||
|         Upsert a node in the Gremlin graph. | ||||
|  | ||||
|         Args: | ||||
|             node_id: The unique identifier for the node (used as name) | ||||
|             node_data: Dictionary of node properties | ||||
|         """ | ||||
|         name = GremlinStorage._fix_name(node_id) | ||||
|         properties = GremlinStorage._convert_properties(node_data) | ||||
|  | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {name}) | ||||
|                  .fold() | ||||
|                  .coalesce( | ||||
|                      __.unfold(), | ||||
|                      __.addV('ENTITY') | ||||
|                          .property('graph', {self.graph_name}) | ||||
|                          .property('entity_name', {name}) | ||||
|                  ) | ||||
|                  {properties} | ||||
|                  """ | ||||
|  | ||||
|         try: | ||||
|             await self._query(query) | ||||
|             logger.debug( | ||||
|                 "Upserted node with name {%s} and properties: {%s}", | ||||
|                 name, | ||||
|                 properties, | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.error("Error during upsert: {%s}", e) | ||||
|             raise | ||||
|  | ||||
|     @retry( | ||||
|         stop=stop_after_attempt(10), | ||||
|         wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|         retry=retry_if_exception_type((GremlinServerError,)), | ||||
|     ) | ||||
|     async def upsert_edge( | ||||
|         self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] | ||||
|     ): | ||||
|         """ | ||||
|         Upsert an edge and its properties between two nodes identified by their names. | ||||
|  | ||||
|         Args: | ||||
|             source_node_id (str): Name of the source node (used as identifier) | ||||
|             target_node_id (str): Name of the target node (used as identifier) | ||||
|             edge_data (dict): Dictionary of properties to set on the edge | ||||
|         """ | ||||
|         source_node_name = GremlinStorage._fix_name(source_node_id) | ||||
|         target_node_name = GremlinStorage._fix_name(target_node_id) | ||||
|         edge_properties = GremlinStorage._convert_properties(edge_data) | ||||
|  | ||||
|         query = f"""g | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {source_node_name}).as('source') | ||||
|                  .V().has('graph', {self.graph_name}) | ||||
|                  .has('entity_name', {target_node_name}).as('target') | ||||
|                  .coalesce( | ||||
|                       __.select('source').outE('DIRECTED').where(__.inV().as('target')), | ||||
|                       __.select('source').addE('DIRECTED').to(__.select('target')) | ||||
|                   ) | ||||
|                   .property('graph', {self.graph_name}) | ||||
|                  {edge_properties} | ||||
|                  """ | ||||
|         try: | ||||
|             await self._query(query) | ||||
|             logger.debug( | ||||
|                 "Upserted edge from {%s} to {%s} with properties: {%s}", | ||||
|                 source_node_name, | ||||
|                 target_node_name, | ||||
|                 edge_properties, | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.error("Error during edge upsert: {%s}", e) | ||||
|             raise | ||||
|  | ||||
|     async def _node2vec_embed(self): | ||||
|         print("Implemented but never called.") | ||||
134 minirag/kg/json_kv_impl.py Normal file
							| @@ -0,0 +1,134 @@ | ||||
| """ | ||||
| JsonKV Storage Module | ||||
| ===================== | ||||
|  | ||||
| This module provides a lightweight key-value storage backend that keeps its data in memory and persists it to a JSON file in the working directory. | ||||
|  | ||||
| The `JsonKVStorage` class extends the `BaseKVStorage` class from the LightRAG library, providing methods to load, save, query, filter, and delete key-value pairs. | ||||
|  | ||||
| Author: lightrag team | ||||
| Created: 2024-01-25 | ||||
| License: MIT | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Dependencies: | ||||
|     - LightRAG | ||||
|  | ||||
| Features: | ||||
|     - Load and persist key-value data as one JSON file per namespace | ||||
|     - Look up values by a single id or by multiple ids with optional field projection | ||||
|     - Filter out keys that are already present in storage | ||||
|     - Upsert, filter, and delete entries, guarded by an asyncio lock | ||||
|  | ||||
| Usage: | ||||
|     from minirag.kg.json_kv_impl import JsonKVStorage | ||||
|  | ||||
| """ | ||||
|  | ||||
| import asyncio | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     logger, | ||||
|     load_json, | ||||
|     write_json, | ||||
| ) | ||||
|  | ||||
| from lightrag.base import ( | ||||
|     BaseKVStorage, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class JsonKVStorage(BaseKVStorage): | ||||
|     def __post_init__(self): | ||||
|         working_dir = self.global_config["working_dir"] | ||||
|         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") | ||||
|         self._data = load_json(self._file_name) or {} | ||||
|         self._lock = asyncio.Lock() | ||||
|         logger.info(f"Load KV {self.namespace} with {len(self._data)} data") | ||||
|  | ||||
|     async def all_keys(self) -> list[str]: | ||||
|         return list(self._data.keys()) | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         write_json(self._data, self._file_name) | ||||
|  | ||||
|     async def get_by_id(self, id): | ||||
|         return self._data.get(id, None) | ||||
|  | ||||
|     async def get_by_ids(self, ids, fields=None): | ||||
|         if fields is None: | ||||
|             return [self._data.get(id, None) for id in ids] | ||||
|         return [ | ||||
|             ( | ||||
|                 {k: v for k, v in self._data[id].items() if k in fields} | ||||
|                 if self._data.get(id, None) | ||||
|                 else None | ||||
|             ) | ||||
|             for id in ids | ||||
|         ] | ||||
|  | ||||
|     async def filter_keys(self, data: list[str]) -> set[str]: | ||||
|         return set([s for s in data if s not in self._data]) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         left_data = {k: v for k, v in data.items() if k not in self._data} | ||||
|         self._data.update(left_data) | ||||
|         return left_data | ||||
|  | ||||
|     async def drop(self): | ||||
|         self._data = {} | ||||
|  | ||||
|     async def filter(self, filter_func): | ||||
|         """Filter key-value pairs based on a filter function | ||||
|  | ||||
|         Args: | ||||
|             filter_func: The filter function, which takes a value as an argument and returns a boolean value | ||||
|  | ||||
|         Returns: | ||||
|             Dict: Key-value pairs that meet the condition | ||||
|         """ | ||||
|         result = {} | ||||
|         async with self._lock: | ||||
|             for key, value in self._data.items(): | ||||
|                 if filter_func(value): | ||||
|                     result[key] = value | ||||
|         return result | ||||
|  | ||||
|     async def delete(self, ids: list[str]): | ||||
|         """Delete data with specified IDs | ||||
|  | ||||
|         Args: | ||||
|             ids: List of IDs to delete | ||||
|         """ | ||||
|         async with self._lock: | ||||
|             for id in ids: | ||||
|                 if id in self._data: | ||||
|                     del self._data[id] | ||||
|             await self.index_done_callback() | ||||
|             logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}") | ||||
128 minirag/kg/jsondocstatus_impl.py Normal file
							| @@ -0,0 +1,128 @@ | ||||
| """ | ||||
| JsonDocStatus Storage Module | ||||
| ============================ | ||||
|  | ||||
| This module provides a JSON-file-backed implementation of document status storage, used to track the processing state of ingested documents. | ||||
|  | ||||
| The `JsonDocStatusStorage` class extends the `DocStatusStorage` class from the LightRAG library, providing methods to query, count, update, and delete document processing statuses. | ||||
|  | ||||
| Author: lightrag team | ||||
| Created: 2024-01-25 | ||||
| License: MIT | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Dependencies: | ||||
|     - LightRAG | ||||
|  | ||||
| Features: | ||||
|     - Persist document processing status as one JSON file per namespace | ||||
|     - Filter out documents that have already been processed successfully | ||||
|     - Count documents per status and list failed or pending documents | ||||
|     - Upsert and delete status records, saving to disk after each change | ||||
|  | ||||
| Usage: | ||||
|     from minirag.kg.jsondocstatus_impl import JsonDocStatusStorage | ||||
|  | ||||
| """ | ||||
|  | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from typing import Union, Dict | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     logger, | ||||
|     load_json, | ||||
|     write_json, | ||||
| ) | ||||
|  | ||||
| from lightrag.base import ( | ||||
|     DocStatus, | ||||
|     DocProcessingStatus, | ||||
|     DocStatusStorage, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class JsonDocStatusStorage(DocStatusStorage): | ||||
|     """JSON implementation of document status storage""" | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         working_dir = self.global_config["working_dir"] | ||||
|         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") | ||||
|         self._data = load_json(self._file_name) or {} | ||||
|         logger.info(f"Loaded document status storage with {len(self._data)} records") | ||||
|  | ||||
|     async def filter_keys(self, data: list[str]) -> set[str]: | ||||
|         """Return keys that should be processed (not in storage or not successfully processed)""" | ||||
|         return set( | ||||
|             [ | ||||
|                 k | ||||
|                 for k in data | ||||
|                 if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|     async def get_status_counts(self) -> Dict[str, int]: | ||||
|         """Get counts of documents in each status""" | ||||
|         counts = {status: 0 for status in DocStatus} | ||||
|         for doc in self._data.values(): | ||||
|             counts[doc["status"]] += 1 | ||||
|         return counts | ||||
|  | ||||
|     async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: | ||||
|         """Get all failed documents""" | ||||
|         return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} | ||||
|  | ||||
|     async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: | ||||
|         """Get all pending documents""" | ||||
|         return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         """Save data to file after indexing""" | ||||
|         write_json(self._data, self._file_name) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         """Update or insert document status | ||||
|  | ||||
|         Args: | ||||
|             data: Dictionary of document IDs and their status data | ||||
|         """ | ||||
|         self._data.update(data) | ||||
|         await self.index_done_callback() | ||||
|         return data | ||||
|  | ||||
|     async def get_by_id(self, id: str): | ||||
|         return self._data.get(id) | ||||
|  | ||||
|     async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: | ||||
|         """Get document status by ID""" | ||||
|         return self._data.get(doc_id) | ||||
|  | ||||
|     async def delete(self, doc_ids: list[str]): | ||||
|         """Delete document status by IDs""" | ||||
|         for doc_id in doc_ids: | ||||
|             self._data.pop(doc_id, None) | ||||
|         await self.index_done_callback() | ||||
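|  | ||||
|  | ||||
| # A minimal usage sketch (illustrative only, not part of this commit); | ||||
| # status values are DocStatus members stored under the "status" key: | ||||
| # | ||||
| #     statuses = JsonDocStatusStorage(namespace="doc_status", global_config={"working_dir": "."}) | ||||
| #     await statuses.upsert({"doc-1": {"status": DocStatus.PENDING}}) | ||||
| #     pending = await statuses.get_pending_docs() | ||||
| #     counts = await statuses.get_status_counts() | ||||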
94 minirag/kg/milvus_impl.py Normal file
							| @@ -0,0 +1,94 @@ | ||||
| import asyncio | ||||
| import os | ||||
| from tqdm.asyncio import tqdm as tqdm_async | ||||
| from dataclasses import dataclass | ||||
| import numpy as np | ||||
| from lightrag.utils import logger | ||||
| from ..base import BaseVectorStorage | ||||
|  | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("pymilvus"): | ||||
|     pm.install("pymilvus") | ||||
| from pymilvus import MilvusClient | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class MilvusVectorDBStorge(BaseVectorStorage): | ||||
|     @staticmethod | ||||
|     def create_collection_if_not_exist( | ||||
|         client: MilvusClient, collection_name: str, **kwargs | ||||
|     ): | ||||
|         if client.has_collection(collection_name): | ||||
|             return | ||||
|         client.create_collection( | ||||
|             collection_name, max_length=64, id_type="string", **kwargs | ||||
|         ) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._client = MilvusClient( | ||||
|             uri=os.environ.get( | ||||
|                 "MILVUS_URI", | ||||
|                 os.path.join(self.global_config["working_dir"], "milvus_lite.db"), | ||||
|             ), | ||||
|             user=os.environ.get("MILVUS_USER", ""), | ||||
|             password=os.environ.get("MILVUS_PASSWORD", ""), | ||||
|             token=os.environ.get("MILVUS_TOKEN", ""), | ||||
|             db_name=os.environ.get("MILVUS_DB_NAME", ""), | ||||
|         ) | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|         MilvusVectorDBStorge.create_collection_if_not_exist( | ||||
|             self._client, | ||||
|             self.namespace, | ||||
|             dimension=self.embedding_func.embedding_dim, | ||||
|         ) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         logger.info(f"Inserting {len(data)} vectors to {self.namespace}") | ||||
|         if not len(data): | ||||
|             logger.warning("You insert an empty data to vector DB") | ||||
|             return [] | ||||
|         list_data = [ | ||||
|             { | ||||
|                 "id": k, | ||||
|                 **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, | ||||
|             } | ||||
|             for k, v in data.items() | ||||
|         ] | ||||
|         contents = [v["content"] for v in data.values()] | ||||
|         batches = [ | ||||
|             contents[i : i + self._max_batch_size] | ||||
|             for i in range(0, len(contents), self._max_batch_size) | ||||
|         ] | ||||
|  | ||||
|         async def wrapped_task(batch): | ||||
|             result = await self.embedding_func(batch) | ||||
|             pbar.update(1) | ||||
|             return result | ||||
|  | ||||
|         embedding_tasks = [wrapped_task(batch) for batch in batches] | ||||
|         pbar = tqdm_async( | ||||
|             total=len(embedding_tasks), desc="Generating embeddings", unit="batch" | ||||
|         ) | ||||
|         embeddings_list = await asyncio.gather(*embedding_tasks) | ||||
|  | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         for i, d in enumerate(list_data): | ||||
|             d["vector"] = embeddings[i] | ||||
|         results = self._client.upsert(collection_name=self.namespace, data=list_data) | ||||
|         return results | ||||
|  | ||||
|     async def query(self, query, top_k=5): | ||||
|         embedding = await self.embedding_func([query]) | ||||
|         results = self._client.search( | ||||
|             collection_name=self.namespace, | ||||
|             data=embedding, | ||||
|             limit=top_k, | ||||
|             output_fields=list(self.meta_fields), | ||||
|             search_params={"metric_type": "COSINE", "params": {"radius": 0.2}}, | ||||
|         ) | ||||
|         logger.debug(f"Milvus query results: {results}") | ||||
|         return [ | ||||
|             {**dp["entity"], "id": dp["id"], "distance": dp["distance"]} | ||||
|             for dp in results[0] | ||||
|         ] | ||||
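|  | ||||
|  | ||||
| # A minimal usage sketch (illustrative only, not part of this commit); it | ||||
| # assumes a framework-constructed instance `vdb` with an embedding function | ||||
| # and meta_fields already set: | ||||
| # | ||||
| #     await vdb.upsert({"doc-1": {"content": "hello world"}}) | ||||
| #     hits = await vdb.query("greeting", top_k=3) | ||||
| #     # each hit carries the meta fields plus "id" and "distance" | ||||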
440 minirag/kg/mongo_impl.py Normal file
							| @@ -0,0 +1,440 @@ | ||||
| import os | ||||
| from tqdm.asyncio import tqdm as tqdm_async | ||||
| from dataclasses import dataclass | ||||
| import pipmaster as pm | ||||
| import numpy as np | ||||
|  | ||||
| if not pm.is_installed("pymongo"): | ||||
|     pm.install("pymongo") | ||||
|  | ||||
| from pymongo import MongoClient | ||||
| from motor.motor_asyncio import AsyncIOMotorClient | ||||
| from typing import Union, List, Tuple | ||||
| from lightrag.utils import logger | ||||
|  | ||||
| from lightrag.base import BaseKVStorage | ||||
| from lightrag.base import BaseGraphStorage | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class MongoKVStorage(BaseKVStorage): | ||||
|     def __post_init__(self): | ||||
|         client = MongoClient( | ||||
|             os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") | ||||
|         ) | ||||
|         database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) | ||||
|         self._data = database.get_collection(self.namespace) | ||||
|         logger.info(f"Use MongoDB as KV {self.namespace}") | ||||
|  | ||||
|     async def all_keys(self) -> list[str]: | ||||
|         return [x["_id"] for x in self._data.find({}, {"_id": 1})] | ||||
|  | ||||
|     async def get_by_id(self, id): | ||||
|         return self._data.find_one({"_id": id}) | ||||
|  | ||||
|     async def get_by_ids(self, ids, fields=None): | ||||
|         if fields is None: | ||||
|             return list(self._data.find({"_id": {"$in": ids}})) | ||||
|         return list( | ||||
|             self._data.find( | ||||
|                 {"_id": {"$in": ids}}, | ||||
|                 {field: 1 for field in fields}, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     async def filter_keys(self, data: list[str]) -> set[str]: | ||||
|         existing_ids = [ | ||||
|             str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1}) | ||||
|         ] | ||||
|         return set([s for s in data if s not in existing_ids]) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         if self.namespace == "llm_response_cache": | ||||
|             for mode, items in data.items(): | ||||
|                 for k, v in tqdm_async(items.items(), desc="Upserting"): | ||||
|                     key = f"{mode}_{k}" | ||||
|                     result = self._data.update_one( | ||||
|                         {"_id": key}, {"$setOnInsert": v}, upsert=True | ||||
|                     ) | ||||
|                     if result.upserted_id: | ||||
|                         logger.debug(f"\nInserted new document with key: {key}") | ||||
|                     data[mode][k]["_id"] = key | ||||
|         else: | ||||
|             for k, v in tqdm_async(data.items(), desc="Upserting"): | ||||
|                 self._data.update_one({"_id": k}, {"$set": v}, upsert=True) | ||||
|                 data[k]["_id"] = k | ||||
|         return data | ||||
|  | ||||
|     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: | ||||
|         if "llm_response_cache" == self.namespace: | ||||
|             res = {} | ||||
|             v = self._data.find_one({"_id": mode + "_" + id}) | ||||
|             if v: | ||||
|                 res[id] = v | ||||
|                 logger.debug(f"llm_response_cache find one by:{id}") | ||||
|                 return res | ||||
|             else: | ||||
|                 return None | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     async def drop(self): | ||||
|         """Drop all stored data (not implemented for the MongoDB backend).""" | ||||
|         pass | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class MongoGraphStorage(BaseGraphStorage): | ||||
|     """ | ||||
|     A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, namespace, global_config, embedding_func): | ||||
|         super().__init__( | ||||
|             namespace=namespace, | ||||
|             global_config=global_config, | ||||
|             embedding_func=embedding_func, | ||||
|         ) | ||||
|         self.client = AsyncIOMotorClient( | ||||
|             os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") | ||||
|         ) | ||||
|         self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")] | ||||
|         self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")] | ||||
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # HELPER: $graphLookup pipeline | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def _graph_lookup( | ||||
|         self, start_node_id: str, max_depth: int = None | ||||
|     ) -> List[dict]: | ||||
|         """ | ||||
|         Performs a $graphLookup starting from 'start_node_id' and returns | ||||
|         all reachable documents (including the start node itself). | ||||
|  | ||||
|         Pipeline Explanation: | ||||
|         - 1) $match: We match the start node document by _id = start_node_id. | ||||
|         - 2) $graphLookup: | ||||
|             "from": same collection, | ||||
|             "startWith": "$edges.target" (the immediate neighbors in 'edges'), | ||||
|             "connectFromField": "edges.target", | ||||
|             "connectToField": "_id", | ||||
|             "as": "reachableNodes", | ||||
|             "maxDepth": max_depth (if provided), | ||||
|             "depthField": "depth" (used for debugging or filtering). | ||||
|         - 3) We add an $project or $unwind as needed to extract data. | ||||
|         """ | ||||
|         pipeline = [ | ||||
|             {"$match": {"_id": start_node_id}}, | ||||
|             { | ||||
|                 "$graphLookup": { | ||||
|                     "from": self.collection.name, | ||||
|                     "startWith": "$edges.target", | ||||
|                     "connectFromField": "edges.target", | ||||
|                     "connectToField": "_id", | ||||
|                     "as": "reachableNodes", | ||||
|                     "depthField": "depth", | ||||
|                 } | ||||
|             }, | ||||
|         ] | ||||
|  | ||||
|         # If you want a limited depth (e.g., only 1 or 2 hops), set maxDepth | ||||
|         if max_depth is not None: | ||||
|             pipeline[1]["$graphLookup"]["maxDepth"] = max_depth | ||||
|  | ||||
|         # Return the matching doc plus a field "reachableNodes" | ||||
|         cursor = self.collection.aggregate(pipeline) | ||||
|         results = await cursor.to_list(None) | ||||
|  | ||||
|         # If there's no matching node, results = []. | ||||
|         # Otherwise, results[0] is the start node doc, | ||||
|         # plus results[0]["reachableNodes"] is the array of connected docs. | ||||
|         return results | ||||
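|  | ||||
|     # Illustrative document shape assumed by these pipelines (not part of | ||||
|     # the original commit): each node doc embeds its outbound edges, e.g. | ||||
|     #   {"_id": "Alice", "edges": [{"target": "Bob", "relation": "knows"}]} | ||||
|     # so "edges.target" always refers to another document's _id. | ||||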
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # BASIC QUERIES | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         """ | ||||
|         Check if node_id is present in the collection by looking up its doc. | ||||
|         No real need for $graphLookup here, but let's keep it direct. | ||||
|         """ | ||||
|         doc = await self.collection.find_one({"_id": node_id}, {"_id": 1}) | ||||
|         return doc is not None | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         """ | ||||
|         Check if there's a direct single-hop edge from source_node_id to target_node_id. | ||||
|  | ||||
|         We do a $graphLookup from the source node with maxDepth=0, which | ||||
|         expands only the immediate neighbors, and then check whether the | ||||
|         target node appears among the "reachableNodes". | ||||
|  | ||||
|         A direct find_one on the edges array would be simpler; the pipeline | ||||
|         below is kept as a $graphLookup demonstration. | ||||
|         """ | ||||
|  | ||||
|         # We can do a single-hop graphLookup (maxDepth=0 or 1). | ||||
|         # Then check if the target_node appears among the edges array. | ||||
|         pipeline = [ | ||||
|             {"$match": {"_id": source_node_id}}, | ||||
|             { | ||||
|                 "$graphLookup": { | ||||
|                     "from": self.collection.name, | ||||
|                     "startWith": "$edges.target", | ||||
|                     "connectFromField": "edges.target", | ||||
|                     "connectToField": "_id", | ||||
|                     "as": "reachableNodes", | ||||
|                     "depthField": "depth", | ||||
|                     "maxDepth": 0,  # means: do not follow beyond immediate edges | ||||
|                 } | ||||
|             }, | ||||
|             { | ||||
|                 "$project": { | ||||
|                     "_id": 0, | ||||
|                     "reachableNodes._id": 1,  # only keep the _id from the subdocs | ||||
|                 } | ||||
|             }, | ||||
|         ] | ||||
|         cursor = self.collection.aggregate(pipeline) | ||||
|         results = await cursor.to_list(None) | ||||
|         if not results: | ||||
|             return False | ||||
|  | ||||
|         # results[0]["reachableNodes"] are the immediate neighbors | ||||
|         reachable_ids = [d["_id"] for d in results[0].get("reachableNodes", [])] | ||||
|         return target_node_id in reachable_ids | ||||
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # DEGREES | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         """ | ||||
|         Returns the total number of edges connected to node_id (both inbound and outbound). | ||||
|         This is computed in two steps: | ||||
|          - 1) Outbound: count the entries in node_id's own edges array. | ||||
|          - 2) Inbound: aggregate over every doc whose edges.target references | ||||
|               node_id and sum the matching edge counts. | ||||
|         """ | ||||
|         # --- 1) Outbound edges (direct from doc) --- | ||||
|         doc = await self.collection.find_one({"_id": node_id}, {"edges": 1}) | ||||
|         if not doc: | ||||
|             return 0 | ||||
|         outbound_count = len(doc.get("edges", [])) | ||||
|  | ||||
|         # --- 2) Inbound edges --- | ||||
|         # A reverse $graphLookup would require storing reversed edges, so we | ||||
|         # take the direct route: match every doc whose edges.target references | ||||
|         # node_id and sum the matching edge counts. | ||||
|         inbound_count_pipeline = [ | ||||
|             {"$match": {"edges.target": node_id}}, | ||||
|             { | ||||
|                 "$project": { | ||||
|                     "matchingEdgesCount": { | ||||
|                         "$size": { | ||||
|                             "$filter": { | ||||
|                                 "input": "$edges", | ||||
|                                 "as": "edge", | ||||
|                                 "cond": {"$eq": ["$$edge.target", node_id]}, | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             }, | ||||
|             {"$group": {"_id": None, "totalInbound": {"$sum": "$matchingEdgesCount"}}}, | ||||
|         ] | ||||
|         inbound_cursor = self.collection.aggregate(inbound_count_pipeline) | ||||
|         inbound_result = await inbound_cursor.to_list(None) | ||||
|         inbound_count = inbound_result[0]["totalInbound"] if inbound_result else 0 | ||||
|  | ||||
|         return outbound_count + inbound_count | ||||
|  | ||||
|     async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||||
|         """ | ||||
|         If your graph can hold multiple edges from the same src to the same tgt | ||||
|         (e.g. different 'relation' values), you can sum them. If it's always | ||||
|         one edge, this is typically 1 or 0. | ||||
|  | ||||
|         We'll do a single-hop $graphLookup from src_id, | ||||
|         then count how many edges reference tgt_id at depth=0. | ||||
|         """ | ||||
|         pipeline = [ | ||||
|             {"$match": {"_id": src_id}}, | ||||
|             { | ||||
|                 "$graphLookup": { | ||||
|                     "from": self.collection.name, | ||||
|                     "startWith": "$edges.target", | ||||
|                     "connectFromField": "edges.target", | ||||
|                     "connectToField": "_id", | ||||
|                     "as": "neighbors", | ||||
|                     "depthField": "depth", | ||||
|                     "maxDepth": 0, | ||||
|                 } | ||||
|             }, | ||||
|             {"$project": {"edges": 1, "neighbors._id": 1, "neighbors.type": 1}}, | ||||
|         ] | ||||
|         cursor = self.collection.aggregate(pipeline) | ||||
|         results = await cursor.to_list(None) | ||||
|         if not results: | ||||
|             return 0 | ||||
|  | ||||
|         # We can simply count how many edges in `results[0].edges` have target == tgt_id. | ||||
|         edges = results[0].get("edges", []) | ||||
|         count = sum(1 for e in edges if e.get("target") == tgt_id) | ||||
|         return count | ||||
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # GETTERS | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         """ | ||||
|         Return the full node document (including "edges"), or None if missing. | ||||
|         """ | ||||
|         return await self.collection.find_one({"_id": node_id}) | ||||
|  | ||||
|     async def get_edge( | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         """ | ||||
|         Return the first edge dict from source_node_id to target_node_id if it exists. | ||||
|         Uses a single-hop $graphLookup as demonstration, though a direct find is simpler. | ||||
|         """ | ||||
|         pipeline = [ | ||||
|             {"$match": {"_id": source_node_id}}, | ||||
|             { | ||||
|                 "$graphLookup": { | ||||
|                     "from": self.collection.name, | ||||
|                     "startWith": "$edges.target", | ||||
|                     "connectFromField": "edges.target", | ||||
|                     "connectToField": "_id", | ||||
|                     "as": "neighbors", | ||||
|                     "depthField": "depth", | ||||
|                     "maxDepth": 0, | ||||
|                 } | ||||
|             }, | ||||
|             {"$project": {"edges": 1}}, | ||||
|         ] | ||||
|         cursor = self.collection.aggregate(pipeline) | ||||
|         docs = await cursor.to_list(None) | ||||
|         if not docs: | ||||
|             return None | ||||
|  | ||||
|         for e in docs[0].get("edges", []): | ||||
|             if e.get("target") == target_node_id: | ||||
|                 return e | ||||
|         return None | ||||
|  | ||||
|     async def get_node_edges( | ||||
|         self, source_node_id: str | ||||
|     ) -> Union[List[Tuple[str, str]], None]: | ||||
|         """ | ||||
|         Return a list of (target_id, relation) for direct edges from source_node_id. | ||||
|         Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler. | ||||
|         """ | ||||
|         pipeline = [ | ||||
|             {"$match": {"_id": source_node_id}}, | ||||
|             { | ||||
|                 "$graphLookup": { | ||||
|                     "from": self.collection.name, | ||||
|                     "startWith": "$edges.target", | ||||
|                     "connectFromField": "edges.target", | ||||
|                     "connectToField": "_id", | ||||
|                     "as": "neighbors", | ||||
|                     "depthField": "depth", | ||||
|                     "maxDepth": 0, | ||||
|                 } | ||||
|             }, | ||||
|             {"$project": {"_id": 0, "edges": 1}}, | ||||
|         ] | ||||
|         cursor = self.collection.aggregate(pipeline) | ||||
|         result = await cursor.to_list(None) | ||||
|         if not result: | ||||
|             return None | ||||
|  | ||||
|         edges = result[0].get("edges", []) | ||||
|         return [(e["target"], e["relation"]) for e in edges] | ||||
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # UPSERTS | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def upsert_node(self, node_id: str, node_data: dict): | ||||
|         """ | ||||
|         Insert or update a node document. If new, create an empty edges array. | ||||
|         """ | ||||
|         # By default, preserve existing 'edges'. | ||||
|         # We'll only set 'edges' to [] on insert (no overwrite). | ||||
|         update_doc = {"$set": {**node_data}, "$setOnInsert": {"edges": []}} | ||||
|         await self.collection.update_one({"_id": node_id}, update_doc, upsert=True) | ||||
|  | ||||
|     async def upsert_edge( | ||||
|         self, source_node_id: str, target_node_id: str, edge_data: dict | ||||
|     ): | ||||
|         """ | ||||
|         Upsert an edge from source_node_id -> target_node_id with optional 'relation'. | ||||
|         If an edge with the same target exists, we remove it and re-insert with updated data. | ||||
|         """ | ||||
|         # Ensure source node exists | ||||
|         await self.upsert_node(source_node_id, {}) | ||||
|  | ||||
|         # Remove existing edge (if any) | ||||
|         await self.collection.update_one( | ||||
|             {"_id": source_node_id}, {"$pull": {"edges": {"target": target_node_id}}} | ||||
|         ) | ||||
|  | ||||
|         # Insert new edge | ||||
|         new_edge = {"target": target_node_id} | ||||
|         new_edge.update(edge_data) | ||||
|         await self.collection.update_one( | ||||
|             {"_id": source_node_id}, {"$push": {"edges": new_edge}} | ||||
|         ) | ||||
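|  | ||||
|     # Illustrative effect (not part of the original commit): after | ||||
|     #   upsert_edge("Alice", "Bob", {"relation": "knows"}) | ||||
|     # the "Alice" doc holds {"target": "Bob", "relation": "knows"} in its | ||||
|     # edges array, with any previous edge to "Bob" replaced. | ||||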
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # DELETION | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def delete_node(self, node_id: str): | ||||
|         """ | ||||
|         1) Remove node’s doc entirely. | ||||
|         2) Remove inbound edges from any doc that references node_id. | ||||
|         """ | ||||
|         # Remove inbound edges from all other docs | ||||
|         await self.collection.update_many({}, {"$pull": {"edges": {"target": node_id}}}) | ||||
|  | ||||
|         # Remove the node doc | ||||
|         await self.collection.delete_one({"_id": node_id}) | ||||
|  | ||||
|     # | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # EMBEDDINGS (NOT IMPLEMENTED) | ||||
|     # ------------------------------------------------------------------------- | ||||
|     # | ||||
|  | ||||
|     async def embed_nodes(self, algorithm: str) -> Tuple[np.ndarray, List[str]]: | ||||
|         """ | ||||
|         Placeholder for demonstration, raises NotImplementedError. | ||||
|         """ | ||||
|         raise NotImplementedError("Node embedding is not used in lightrag.") | ||||
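|  | ||||
|  | ||||
| # A minimal usage sketch (illustrative only, not part of this commit): | ||||
| # | ||||
| #     graph = MongoGraphStorage(namespace="kg", global_config={}, embedding_func=None) | ||||
| #     await graph.upsert_node("Alice", {"entity_type": "person"}) | ||||
| #     await graph.upsert_node("Bob", {"entity_type": "person"}) | ||||
| #     await graph.upsert_edge("Alice", "Bob", {"relation": "knows"}) | ||||
| #     assert await graph.has_edge("Alice", "Bob") | ||||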
213 minirag/kg/nano_vector_db_impl.py Normal file
							| @@ -0,0 +1,213 @@ | ||||
| """ | ||||
| NanoVectorDB Storage Module | ||||
| =========================== | ||||
|  | ||||
| This module provides a vector storage interface backed by NanoVectorDB, a lightweight, file-based vector database. | ||||
|  | ||||
| The `NanoVectorDBStorage` class extends the `BaseVectorStorage` class from the LightRAG library, providing methods to embed, upsert, query, and delete vectors. | ||||
|  | ||||
| Author: lightrag team | ||||
| Created: 2024-01-25 | ||||
| License: MIT | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Dependencies: | ||||
|     - NanoVectorDB | ||||
|     - NumPy | ||||
|     - LightRAG | ||||
|  | ||||
| Features: | ||||
|     - Batch-embed document contents and upsert the resulting vectors | ||||
|     - Query vectors by cosine similarity above a configurable threshold | ||||
|     - Delete vectors, entities, and entity relations by id | ||||
|     - Persist the vector index to one JSON file per namespace | ||||
|  | ||||
| Usage: | ||||
|     from minirag.kg.nano_vector_db_impl import NanoVectorDBStorage | ||||
|  | ||||
| """ | ||||
|  | ||||
| import asyncio | ||||
| import os | ||||
| from tqdm.asyncio import tqdm as tqdm_async | ||||
| from dataclasses import dataclass | ||||
| import numpy as np | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("nano-vectordb"): | ||||
|     pm.install("nano-vectordb") | ||||
|  | ||||
| from nano_vectordb import NanoVectorDB | ||||
| import time | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     logger, | ||||
|     compute_mdhash_id, | ||||
| ) | ||||
|  | ||||
| from lightrag.base import ( | ||||
|     BaseVectorStorage, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class NanoVectorDBStorage(BaseVectorStorage): | ||||
|     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         # Use global config value if specified, otherwise use default | ||||
|         config = self.global_config.get("vector_db_storage_cls_kwargs", {}) | ||||
|         self.cosine_better_than_threshold = config.get( | ||||
|             "cosine_better_than_threshold", self.cosine_better_than_threshold | ||||
|         ) | ||||
|  | ||||
|         self._client_file_name = os.path.join( | ||||
|             self.global_config["working_dir"], f"vdb_{self.namespace}.json" | ||||
|         ) | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|         self._client = NanoVectorDB( | ||||
|             self.embedding_func.embedding_dim, storage_file=self._client_file_name | ||||
|         ) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         logger.info(f"Inserting {len(data)} vectors to {self.namespace}") | ||||
|         if not len(data): | ||||
|             logger.warning("You insert an empty data to vector DB") | ||||
|             return [] | ||||
|  | ||||
|         current_time = time.time() | ||||
|         list_data = [ | ||||
|             { | ||||
|                 "__id__": k, | ||||
|                 "__created_at__": current_time, | ||||
|                 **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, | ||||
|             } | ||||
|             for k, v in data.items() | ||||
|         ] | ||||
|         contents = [v["content"] for v in data.values()] | ||||
|         batches = [ | ||||
|             contents[i : i + self._max_batch_size] | ||||
|             for i in range(0, len(contents), self._max_batch_size) | ||||
|         ] | ||||
|  | ||||
|         async def wrapped_task(batch): | ||||
|             result = await self.embedding_func(batch) | ||||
|             pbar.update(1) | ||||
|             return result | ||||
|  | ||||
|         embedding_tasks = [wrapped_task(batch) for batch in batches] | ||||
|         pbar = tqdm_async( | ||||
|             total=len(embedding_tasks), desc="Generating embeddings", unit="batch" | ||||
|         ) | ||||
|         embeddings_list = await asyncio.gather(*embedding_tasks) | ||||
|  | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         if len(embeddings) == len(list_data): | ||||
|             for i, d in enumerate(list_data): | ||||
|                 d["__vector__"] = embeddings[i] | ||||
|             results = self._client.upsert(datas=list_data) | ||||
|             return results | ||||
|         else: | ||||
|             # sometimes the embedding is not returned correctly. just log it. | ||||
|             logger.error( | ||||
|                 f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" | ||||
|             ) | ||||
|  | ||||
|     async def query(self, query: str, top_k=5): | ||||
|         embedding = await self.embedding_func([query]) | ||||
|         embedding = embedding[0] | ||||
|         logger.info( | ||||
|             f"Query: {query}, top_k: {top_k}, cosine_better_than_threshold: {self.cosine_better_than_threshold}" | ||||
|         ) | ||||
|         results = self._client.query( | ||||
|             query=embedding, | ||||
|             top_k=top_k, | ||||
|             better_than_threshold=self.cosine_better_than_threshold, | ||||
|         ) | ||||
|         results = [ | ||||
|             { | ||||
|                 **dp, | ||||
|                 "id": dp["__id__"], | ||||
|                 "distance": dp["__metrics__"], | ||||
|                 "created_at": dp.get("__created_at__"), | ||||
|             } | ||||
|             for dp in results | ||||
|         ] | ||||
|         return results | ||||
|  | ||||
|     @property | ||||
|     def client_storage(self): | ||||
|         return getattr(self._client, "_NanoVectorDB__storage") | ||||
|  | ||||
|     async def delete(self, ids: list[str]): | ||||
|         """Delete vectors with specified IDs | ||||
|  | ||||
|         Args: | ||||
|             ids: List of vector IDs to be deleted | ||||
|         """ | ||||
|         try: | ||||
|             self._client.delete(ids) | ||||
|             logger.info( | ||||
|                 f"Successfully deleted {len(ids)} vectors from {self.namespace}" | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error while deleting vectors from {self.namespace}: {e}") | ||||
|  | ||||
|     async def delete_entity(self, entity_name: str): | ||||
|         try: | ||||
|             entity_id = compute_mdhash_id(entity_name, prefix="ent-") | ||||
|             logger.debug( | ||||
|                 f"Attempting to delete entity {entity_name} with ID {entity_id}" | ||||
|             ) | ||||
|             # Check if the entity exists | ||||
|             if self._client.get([entity_id]): | ||||
|                 await self.delete([entity_id]) | ||||
|                 logger.debug(f"Successfully deleted entity {entity_name}") | ||||
|             else: | ||||
|                 logger.debug(f"Entity {entity_name} not found in storage") | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error deleting entity {entity_name}: {e}") | ||||
|  | ||||
|     async def delete_entity_relation(self, entity_name: str): | ||||
|         try: | ||||
|             relations = [ | ||||
|                 dp | ||||
|                 for dp in self.client_storage["data"] | ||||
|                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name | ||||
|             ] | ||||
|             logger.debug(f"Found {len(relations)} relations for entity {entity_name}") | ||||
|             ids_to_delete = [relation["__id__"] for relation in relations] | ||||
|  | ||||
|             if ids_to_delete: | ||||
|                 await self.delete(ids_to_delete) | ||||
|                 logger.debug( | ||||
|                     f"Deleted {len(ids_to_delete)} relations for {entity_name}" | ||||
|                 ) | ||||
|             else: | ||||
|                 logger.debug(f"No relations found for entity {entity_name}") | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error deleting relations for {entity_name}: {e}") | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         self._client.save() | ||||
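|  | ||||
| # --- Editor's sketch (not part of the commit): minimal end-to-end use of the | ||||
| # vector storage above, assuming an already-initialized instance `vdb` of this | ||||
| # class; the coroutine name `demo` and the query text are hypothetical. | ||||
| async def demo(vdb): | ||||
|     hits = await vdb.query("what is minirag?", top_k=5) | ||||
|     for hit in hits:  # each hit carries "id", "distance" and "created_at" | ||||
|         print(hit["id"], hit["distance"]) | ||||
|     await vdb.delete([h["id"] for h in hits])  # delete by vector id | ||||
|     await vdb.index_done_callback()  # persists the index via save() | ||||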
| @@ -1,18 +1,20 @@ | ||||
| import asyncio | ||||
| import inspect | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Union, Tuple, List, Dict | ||||
| from minirag.utils import logger | ||||
| from ..base import BaseGraphStorage | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("neo4j"): | ||||
|     pm.install("neo4j") | ||||
|  | ||||
| from neo4j import ( | ||||
|     AsyncGraphDatabase, | ||||
|     exceptions as neo4jExceptions, | ||||
|     AsyncDriver, | ||||
|     AsyncManagedTransaction, | ||||
|     GraphDatabase, | ||||
| ) | ||||
|  | ||||
|  | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
| @@ -20,6 +22,9 @@ from tenacity import ( | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class Neo4JStorage(BaseGraphStorage): | ||||
| @@ -27,17 +32,63 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|     def load_nx_graph(file_name): | ||||
|         print("no preloading of graph with neo4j in production") | ||||
|  | ||||
|     def __init__(self, namespace, global_config): | ||||
|         super().__init__(namespace=namespace, global_config=global_config) | ||||
|     def __init__(self, namespace, global_config, embedding_func): | ||||
|         super().__init__( | ||||
|             namespace=namespace, | ||||
|             global_config=global_config, | ||||
|             embedding_func=embedding_func, | ||||
|         ) | ||||
|         self._driver = None | ||||
|         self._driver_lock = asyncio.Lock() | ||||
|         URI = os.environ["NEO4J_URI"] | ||||
|         USERNAME = os.environ["NEO4J_USERNAME"] | ||||
|         PASSWORD = os.environ["NEO4J_PASSWORD"] | ||||
|         MAX_CONNECTION_POOL_SIZE = int( | ||||
|             os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) | ||||
|         ) | ||||
|         DATABASE = os.environ.get( | ||||
|             "NEO4J_DATABASE" | ||||
|         )  # If this param is None, the home database will be used. If it is not None, the specified database will be used. | ||||
|         self._DATABASE = DATABASE | ||||
|         self._driver: AsyncDriver = AsyncGraphDatabase.driver( | ||||
|             URI, auth=(USERNAME, PASSWORD) | ||||
|         ) | ||||
|         return None | ||||
|         _database_name = "home database" if DATABASE is None else f"database {DATABASE}" | ||||
|         with GraphDatabase.driver( | ||||
|             URI, | ||||
|             auth=(USERNAME, PASSWORD), | ||||
|             max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, | ||||
|         ) as _sync_driver: | ||||
|             try: | ||||
|                 with _sync_driver.session(database=DATABASE) as session: | ||||
|                     try: | ||||
|                         session.run("MATCH (n) RETURN n LIMIT 0") | ||||
|                         logger.info(f"Connected to {DATABASE} at {URI}") | ||||
|                     except neo4jExceptions.ServiceUnavailable as e: | ||||
|                         logger.error( | ||||
|                             f"{DATABASE} at {URI} is not available".capitalize() | ||||
|                         ) | ||||
|                         raise e | ||||
|             except neo4jExceptions.AuthError as e: | ||||
|                 logger.error(f"Authentication failed for {DATABASE} at {URI}") | ||||
|                 raise e | ||||
|             except neo4jExceptions.ClientError as e: | ||||
|                 if e.code != "Neo.ClientError.Database.DatabaseNotFound": | ||||
|                     raise e | ||||
|                 logger.info( | ||||
|                     f"{DATABASE} at {URI} not found. Trying to create the specified database.".capitalize() | ||||
|                 ) | ||||
|                 try: | ||||
|                     with _sync_driver.session() as session: | ||||
|                         session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS") | ||||
|                         logger.info(f"{DATABASE} at {URI} created".capitalize()) | ||||
|                 except neo4jExceptions.ClientError as e: | ||||
|                     if ( | ||||
|                         e.code | ||||
|                         == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" | ||||
|                     ): | ||||
|                         logger.warning( | ||||
|                             "This Neo4j instance does not support creating databases. Use Neo4j Desktop/Enterprise or DozerDB instead." | ||||
|                         ) | ||||
|                     logger.error(f"Failed to create {DATABASE} at {URI}") | ||||
|                     raise e | ||||
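|         # Editor's note (not part of the commit): this constructor reads | ||||
|         # NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD (required; a missing one | ||||
|         # raises KeyError) plus the optional NEO4J_MAX_CONNECTION_POOL_SIZE and | ||||
|         # NEO4J_DATABASE variables, e.g. (illustrative values): | ||||
|         # | ||||
|         #     NEO4J_URI=neo4j://localhost:7687 | ||||
|         #     NEO4J_USERNAME=neo4j | ||||
|         #     NEO4J_PASSWORD=password | ||||
|         #     NEO4J_DATABASE=neo4j | ||||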
|  | ||||
|     def __post_init__(self): | ||||
|         self._node_embed_algorithms = { | ||||
| @@ -59,7 +110,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         entity_name_label = node_id.strip('"') | ||||
|  | ||||
|         async with self._driver.session() as session: | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             query = ( | ||||
|                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists" | ||||
|             ) | ||||
| @@ -74,7 +125,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|         entity_name_label_source = source_node_id.strip('"') | ||||
|         entity_name_label_target = target_node_id.strip('"') | ||||
|  | ||||
|         async with self._driver.session() as session: | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             query = ( | ||||
|                 f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) " | ||||
|                 "RETURN COUNT(r) > 0 AS edgeExists" | ||||
| @@ -86,11 +137,8 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|             ) | ||||
|             return single_result["edgeExists"] | ||||
|  | ||||
|         def close(self): | ||||
|             self._driver.close() | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         async with self._driver.session() as session: | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             entity_name_label = node_id.strip('"') | ||||
|             query = f"MATCH (n:`{entity_name_label}`) RETURN n" | ||||
|             result = await session.run(query) | ||||
| @@ -107,7 +155,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         entity_name_label = node_id.strip('"') | ||||
|  | ||||
|         async with self._driver.session() as session: | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             query = f""" | ||||
|                 MATCH (n:`{entity_name_label}`) | ||||
|                 RETURN COUNT{{ (n)--() }} AS totalEdgeCount | ||||
| @@ -154,7 +202,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|         Returns: | ||||
|             list: List of all relationships/edges found | ||||
|         """ | ||||
|         async with self._driver.session() as session: | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             query = f""" | ||||
|             MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`) | ||||
|             RETURN properties(r) as edge_properties | ||||
| @@ -185,7 +233,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|         query = f"""MATCH (n:`{node_label}`) | ||||
|                 OPTIONAL MATCH (n)-[r]-(connected) | ||||
|                 RETURN n, r, connected""" | ||||
|         async with self._driver.session() as session: | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             results = await session.run(query) | ||||
|             edges = [] | ||||
|             async for record in results: | ||||
| @@ -214,6 +262,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|                 neo4jExceptions.ServiceUnavailable, | ||||
|                 neo4jExceptions.TransientError, | ||||
|                 neo4jExceptions.WriteServiceUnavailable, | ||||
|                 neo4jExceptions.ClientError, | ||||
|             ) | ||||
|         ), | ||||
|     ) | ||||
| @@ -239,7 +288,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|             ) | ||||
|  | ||||
|         try: | ||||
|             async with self._driver.session() as session: | ||||
|             async with self._driver.session(database=self._DATABASE) as session: | ||||
|                 await session.execute_write(_do_upsert) | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error during upsert: {str(e)}") | ||||
| @@ -286,7 +335,7 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|             ) | ||||
|  | ||||
|         try: | ||||
|             async with self._driver.session() as session: | ||||
|             async with self._driver.session(database=self._DATABASE) as session: | ||||
|                 await session.execute_write(_do_upsert_edge) | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error during edge upsert: {str(e)}") | ||||
| @@ -294,3 +343,177 @@ class Neo4JStorage(BaseGraphStorage): | ||||
|  | ||||
|     async def _node2vec_embed(self): | ||||
|         print("Implemented but never called.") | ||||
|  | ||||
|     async def get_knowledge_graph( | ||||
|         self, node_label: str, max_depth: int = 5 | ||||
|     ) -> Dict[str, List[Dict]]: | ||||
|         """ | ||||
|         Get the complete connected subgraph for the specified node (including the starting node itself) | ||||
|  | ||||
|         Key fixes: | ||||
|         1. Include the starting node itself | ||||
|         2. Handle multi-label nodes | ||||
|         3. Clarify relationship directions | ||||
|         4. Add depth control | ||||
|         """ | ||||
|         label = node_label.strip('"') | ||||
|         result = {"nodes": [], "edges": []} | ||||
|         seen_nodes = set() | ||||
|         seen_edges = set() | ||||
|  | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             try: | ||||
|                 # Critical debug step: first verify if starting node exists | ||||
|                 validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" | ||||
|                 validate_result = await session.run(validate_query) | ||||
|                 if not await validate_result.single(): | ||||
|                     logger.warning(f"Starting node {label} does not exist!") | ||||
|                     return result | ||||
|  | ||||
|                 # Optimized query (including direction handling and self-loops) | ||||
|                 main_query = f""" | ||||
|                 MATCH (start:`{label}`) | ||||
|                 WITH start | ||||
|                 CALL apoc.path.subgraphAll(start, {{ | ||||
|                     relationshipFilter: '>', | ||||
|                     minLevel: 0, | ||||
|                     maxLevel: {max_depth}, | ||||
|                     bfs: true | ||||
|                 }}) | ||||
|                 YIELD nodes, relationships | ||||
|                 RETURN nodes, relationships | ||||
|                 """ | ||||
|                 result_set = await session.run(main_query) | ||||
|                 record = await result_set.single() | ||||
|  | ||||
|                 if record: | ||||
|                     # Handle nodes (compatible with multi-label cases) | ||||
|                     for node in record["nodes"]: | ||||
|                         # Use node ID + label combination as unique identifier | ||||
|                         node_id = f"{node.id}_{'_'.join(node.labels)}" | ||||
|                         if node_id not in seen_nodes: | ||||
|                             node_data = dict(node) | ||||
|                             node_data["labels"] = list(node.labels)  # Keep all labels | ||||
|                             result["nodes"].append(node_data) | ||||
|                             seen_nodes.add(node_id) | ||||
|  | ||||
|                     # Handle relationships (including direction information) | ||||
|                     for rel in record["relationships"]: | ||||
|                         edge_id = f"{rel.id}_{rel.type}" | ||||
|                         if edge_id not in seen_edges: | ||||
|                             start = rel.start_node | ||||
|                             end = rel.end_node | ||||
|                             edge_data = dict(rel) | ||||
|                             edge_data.update( | ||||
|                                 { | ||||
|                                     "source": f"{start.id}_{'_'.join(start.labels)}", | ||||
|                                     "target": f"{end.id}_{'_'.join(end.labels)}", | ||||
|                                     "type": rel.type, | ||||
|                                     # Driver relationships are directed start -> end, | ||||
|                                     # so "source"/"target" above already encode direction. | ||||
|                                     "direction": "OUTGOING", | ||||
|                                 } | ||||
|                             ) | ||||
|                             result["edges"].append(edge_data) | ||||
|                             seen_edges.add(edge_id) | ||||
|  | ||||
|                     logger.info( | ||||
|                         f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}" | ||||
|                     ) | ||||
|  | ||||
|             except neo4jExceptions.ClientError as e: | ||||
|                 logger.error(f"APOC query failed: {str(e)}") | ||||
|                 return await self._robust_fallback(label, max_depth) | ||||
|  | ||||
|         return result | ||||
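|     # Editor's note (not part of the commit): the main query above requires the | ||||
|     # APOC plugin (apoc.path.subgraphAll); without it the ClientError handler | ||||
|     # falls back to _robust_fallback below. A hypothetical call: | ||||
|     # | ||||
|     #     kg = await storage.get_knowledge_graph('"PERSON"', max_depth=3) | ||||
|     #     print(len(kg["nodes"]), len(kg["edges"])) | ||||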
|  | ||||
|     async def _robust_fallback( | ||||
|         self, label: str, max_depth: int | ||||
|     ) -> Dict[str, List[Dict]]: | ||||
|         """Fallback subgraph traversal used when the APOC query fails""" | ||||
|         result = {"nodes": [], "edges": []} | ||||
|         visited_nodes = set() | ||||
|         visited_edges = set() | ||||
|  | ||||
|         async def traverse(current_label: str, current_depth: int): | ||||
|             if current_depth > max_depth: | ||||
|                 return | ||||
|  | ||||
|             # Get current node details | ||||
|             node = await self.get_node(current_label) | ||||
|             if not node: | ||||
|                 return | ||||
|  | ||||
|             node_id = f"{current_label}" | ||||
|             if node_id in visited_nodes: | ||||
|                 return | ||||
|             visited_nodes.add(node_id) | ||||
|  | ||||
|             # Add node data (with complete labels) | ||||
|             node_data = {k: v for k, v in node.items()} | ||||
|             node_data["labels"] = [ | ||||
|                 current_label | ||||
|             ]  # Assume get_node method returns label information | ||||
|             result["nodes"].append(node_data) | ||||
|  | ||||
|             # Get all outgoing and incoming edges | ||||
|             query = f""" | ||||
|             MATCH (a)-[r]-(b) | ||||
|             WHERE a:`{current_label}` OR b:`{current_label}` | ||||
|             RETURN a, r, b, | ||||
|                    CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction | ||||
|             """ | ||||
|             async with self._driver.session(database=self._DATABASE) as session: | ||||
|                 results = await session.run(query) | ||||
|                 async for record in results: | ||||
|                     # Handle edges | ||||
|                     rel = record["r"] | ||||
|                     edge_id = f"{rel.id}_{rel.type}" | ||||
|                     if edge_id not in visited_edges: | ||||
|                         edge_data = dict(rel) | ||||
|                         edge_data.update( | ||||
|                             { | ||||
|                                 "source": list(record["a"].labels)[0], | ||||
|                                 "target": list(record["b"].labels)[0], | ||||
|                                 "type": rel.type, | ||||
|                                 "direction": record["direction"], | ||||
|                             } | ||||
|                         ) | ||||
|                         result["edges"].append(edge_data) | ||||
|                         visited_edges.add(edge_id) | ||||
|  | ||||
|                         # Recursively traverse adjacent nodes | ||||
|                         next_label = ( | ||||
|                             list(record["b"].labels)[0] | ||||
|                             if record["direction"] == "OUTGOING" | ||||
|                             else list(record["a"].labels)[0] | ||||
|                         ) | ||||
|                         await traverse(next_label, current_depth + 1) | ||||
|  | ||||
|         await traverse(label, 0) | ||||
|         return result | ||||
|  | ||||
|     async def get_all_labels(self) -> List[str]: | ||||
|         """ | ||||
|         Get all existing node labels in the database | ||||
|         Returns: | ||||
|             ["Person", "Company", ...]  # Alphabetically sorted label list | ||||
|         """ | ||||
|         async with self._driver.session(database=self._DATABASE) as session: | ||||
|             # Method 1: direct metadata query (available on Neo4j 4.3+) | ||||
|             # query = "CALL db.labels() YIELD label RETURN label" | ||||
|  | ||||
|             # Method 2: Query compatible with older versions | ||||
|             query = """ | ||||
|                 MATCH (n) | ||||
|                 WITH DISTINCT labels(n) AS node_labels | ||||
|                 UNWIND node_labels AS label | ||||
|                 RETURN DISTINCT label | ||||
|                 ORDER BY label | ||||
|             """ | ||||
|  | ||||
|             result = await session.run(query) | ||||
|             labels = [] | ||||
|             async for record in result: | ||||
|                 labels.append(record["label"]) | ||||
|             return labels | ||||
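|             # Editor's note (not part of the commit): on Neo4j 4.3+ the | ||||
|             # commented Method 1 is cheaper, since it reads label metadata | ||||
|             # instead of scanning every node: | ||||
|             # | ||||
|             #     result = await session.run("CALL db.labels() YIELD label RETURN label") | ||||
|             #     labels = [record["label"] async for record in result] | ||||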
|   | ||||
							
								
								
									
										228
									
								
								minirag/kg/networkx_impl.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
							| @@ -0,0 +1,228 @@ | ||||
| """ | ||||
| NetworkX Storage Module | ||||
| ======================= | ||||
|  | ||||
| This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. | ||||
|  | ||||
| The `NetworkXStorage` class extends the `BaseGraphStorage` class from the MiniRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. | ||||
|  | ||||
| Author: lightrag team | ||||
| Created: 2024-01-25 | ||||
| License: MIT | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in all | ||||
| copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| SOFTWARE. | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Dependencies: | ||||
|     - NetworkX | ||||
|     - NumPy | ||||
|     - MiniRAG | ||||
|     - graspologic | ||||
|  | ||||
| Features: | ||||
|     - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) | ||||
|     - Query graph nodes and edges | ||||
|     - Calculate node and edge degrees | ||||
|     - Embed nodes using various algorithms (e.g., Node2Vec) | ||||
|     - Remove nodes and edges from the graph | ||||
|  | ||||
| Usage: | ||||
|     from minirag.kg.networkx_impl import NetworkXStorage | ||||
|  | ||||
| """ | ||||
|  | ||||
| import html | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Union, cast | ||||
| import networkx as nx | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| from minirag.utils import ( | ||||
|     logger, | ||||
| ) | ||||
|  | ||||
| from minirag.base import ( | ||||
|     BaseGraphStorage, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class NetworkXStorage(BaseGraphStorage): | ||||
|     @staticmethod | ||||
|     def load_nx_graph(file_name) -> Union[nx.Graph, None]: | ||||
|         if os.path.exists(file_name): | ||||
|             return nx.read_graphml(file_name) | ||||
|         return None | ||||
|  | ||||
|     @staticmethod | ||||
|     def write_nx_graph(graph: nx.Graph, file_name): | ||||
|         logger.info( | ||||
|             f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" | ||||
|         ) | ||||
|         nx.write_graphml(graph, file_name) | ||||
|  | ||||
|     @staticmethod | ||||
|     def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: | ||||
|         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py | ||||
|         Return the largest connected component of the graph, with nodes and edges sorted in a stable way. | ||||
|         """ | ||||
|         from graspologic.utils import largest_connected_component | ||||
|  | ||||
|         graph = graph.copy() | ||||
|         graph = cast(nx.Graph, largest_connected_component(graph)) | ||||
|         node_mapping = { | ||||
|             node: html.unescape(node.upper().strip()) for node in graph.nodes() | ||||
|         }  # type: ignore | ||||
|         graph = nx.relabel_nodes(graph, node_mapping) | ||||
|         return NetworkXStorage._stabilize_graph(graph) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _stabilize_graph(graph: nx.Graph) -> nx.Graph: | ||||
|         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py | ||||
|         Ensure an undirected graph with the same relationships will always be read the same way. | ||||
|         """ | ||||
|         fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() | ||||
|  | ||||
|         sorted_nodes = graph.nodes(data=True) | ||||
|         sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) | ||||
|  | ||||
|         fixed_graph.add_nodes_from(sorted_nodes) | ||||
|         edges = list(graph.edges(data=True)) | ||||
|  | ||||
|         if not graph.is_directed(): | ||||
|  | ||||
|             def _sort_source_target(edge): | ||||
|                 source, target, edge_data = edge | ||||
|                 if source > target: | ||||
|                     temp = source | ||||
|                     source = target | ||||
|                     target = temp | ||||
|                 return source, target, edge_data | ||||
|  | ||||
|             edges = [_sort_source_target(edge) for edge in edges] | ||||
|  | ||||
|         def _get_edge_key(source: Any, target: Any) -> str: | ||||
|             return f"{source} -> {target}" | ||||
|  | ||||
|         edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) | ||||
|  | ||||
|         fixed_graph.add_edges_from(edges) | ||||
|         return fixed_graph | ||||
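|     # Editor's note (not part of the commit): a small illustration of why | ||||
|     # _stabilize_graph sorts nodes and normalizes undirected edge endpoints: | ||||
|     # two graphs built in different insertion orders serialize identically. | ||||
|     # | ||||
|     #     g1 = nx.Graph([("B", "A"), ("C", "A")]) | ||||
|     #     g2 = nx.Graph([("A", "C"), ("A", "B")]) | ||||
|     #     s1 = NetworkXStorage._stabilize_graph(g1) | ||||
|     #     s2 = NetworkXStorage._stabilize_graph(g2) | ||||
|     #     assert list(s1.edges()) == list(s2.edges())  # [("A", "B"), ("A", "C")] | ||||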
|  | ||||
|     def __post_init__(self): | ||||
|         self._graphml_xml_file = os.path.join( | ||||
|             self.global_config["working_dir"], f"graph_{self.namespace}.graphml" | ||||
|         ) | ||||
|         preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) | ||||
|         if preloaded_graph is not None: | ||||
|             logger.info( | ||||
|                 f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" | ||||
|             ) | ||||
|         self._graph = preloaded_graph or nx.Graph() | ||||
|         self._node_embed_algorithms = { | ||||
|             "node2vec": self._node2vec_embed, | ||||
|         } | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) | ||||
|  | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         return self._graph.has_node(node_id) | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         return self._graph.has_edge(source_node_id, target_node_id) | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         return self._graph.nodes.get(node_id) | ||||
|  | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         return self._graph.degree(node_id) | ||||
|  | ||||
|     async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||||
|         return self._graph.degree(src_id) + self._graph.degree(tgt_id) | ||||
|  | ||||
|     async def get_edge( | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         return self._graph.edges.get((source_node_id, target_node_id)) | ||||
|  | ||||
|     async def get_node_edges(self, source_node_id: str): | ||||
|         if self._graph.has_node(source_node_id): | ||||
|             return list(self._graph.edges(source_node_id)) | ||||
|         return None | ||||
|  | ||||
|     async def upsert_node(self, node_id: str, node_data: dict[str, str]): | ||||
|         self._graph.add_node(node_id, **node_data) | ||||
|  | ||||
|     async def upsert_edge( | ||||
|         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] | ||||
|     ): | ||||
|         self._graph.add_edge(source_node_id, target_node_id, **edge_data) | ||||
|  | ||||
|     async def delete_node(self, node_id: str): | ||||
|         """ | ||||
|         Delete a node from the graph based on the specified node_id. | ||||
|  | ||||
|         :param node_id: The node_id to delete | ||||
|         """ | ||||
|         if self._graph.has_node(node_id): | ||||
|             self._graph.remove_node(node_id) | ||||
|             logger.info(f"Node {node_id} deleted from the graph.") | ||||
|         else: | ||||
|             logger.warning(f"Node {node_id} not found in the graph for deletion.") | ||||
|  | ||||
|     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: | ||||
|         if algorithm not in self._node_embed_algorithms: | ||||
|             raise ValueError(f"Node embedding algorithm {algorithm} not supported") | ||||
|         return await self._node_embed_algorithms[algorithm]() | ||||
|  | ||||
|     # @TODO: NOT USED | ||||
|     async def _node2vec_embed(self): | ||||
|         from graspologic import embed | ||||
|  | ||||
|         embeddings, nodes = embed.node2vec_embed( | ||||
|             self._graph, | ||||
|             **self.global_config["node2vec_params"], | ||||
|         ) | ||||
|  | ||||
|         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] | ||||
|         return embeddings, nodes_ids | ||||
|  | ||||
|     def remove_nodes(self, nodes: list[str]): | ||||
|         """Delete multiple nodes | ||||
|  | ||||
|         Args: | ||||
|             nodes: List of node IDs to be deleted | ||||
|         """ | ||||
|         for node in nodes: | ||||
|             if self._graph.has_node(node): | ||||
|                 self._graph.remove_node(node) | ||||
|  | ||||
|     def remove_edges(self, edges: list[tuple[str, str]]): | ||||
|         """Delete multiple edges | ||||
|  | ||||
|         Args: | ||||
|             edges: List of edges to be deleted, each edge is a (source, target) tuple | ||||
|         """ | ||||
|         for source, target in edges: | ||||
|             if self._graph.has_edge(source, target): | ||||
|                 self._graph.remove_edge(source, target) | ||||
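|  | ||||
|  | ||||
| # --- Editor's sketch (not part of the commit): minimal use of NetworkXStorage. | ||||
| # The constructor keywords mirror the dataclass fields referenced above; the | ||||
| # "working_dir" key and the graph_{namespace}.graphml file name come from | ||||
| # __post_init__, while the namespace string and paths are illustrative. | ||||
| import asyncio | ||||
|  | ||||
| async def demo(): | ||||
|     storage = NetworkXStorage( | ||||
|         namespace="chunk_entity_relation",  # hypothetical namespace | ||||
|         global_config={"working_dir": "./rag_storage"}, | ||||
|         embedding_func=None,  # unused by the graph methods shown above | ||||
|     ) | ||||
|     await storage.upsert_node('"CARL"', {"entity_type": "PERSON"}) | ||||
|     await storage.upsert_edge('"CARL"', '"ACME"', {"weight": "1.0"}) | ||||
|     print(await storage.node_degree('"CARL"'))  # -> 1 | ||||
|     await storage.index_done_callback()  # writes graph_chunk_entity_relation.graphml | ||||
|  | ||||
| asyncio.run(demo()) | ||||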
| @@ -1,3 +1,4 @@ | ||||
| import os | ||||
| import asyncio | ||||
|  | ||||
| # import html | ||||
| @@ -6,6 +7,11 @@ from dataclasses import dataclass | ||||
| from typing import Union | ||||
| import numpy as np | ||||
| import array | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("oracledb"): | ||||
|     pm.install("oracledb") | ||||
|  | ||||
|  | ||||
| from ..utils import logger | ||||
| from ..base import ( | ||||
| @@ -114,16 +120,19 @@ class OracleDB: | ||||
|  | ||||
|         logger.info("Finished check all tables in Oracle database") | ||||
|  | ||||
|     async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]: | ||||
|     async def query( | ||||
|         self, sql: str, params: dict = None, multirows: bool = False | ||||
|     ) -> Union[dict, None]: | ||||
|         async with self.pool.acquire() as connection: | ||||
|             connection.inputtypehandler = self.input_type_handler | ||||
|             connection.outputtypehandler = self.output_type_handler | ||||
|             with connection.cursor() as cursor: | ||||
|                 try: | ||||
|                     await cursor.execute(sql) | ||||
|                     await cursor.execute(sql, params) | ||||
|                 except Exception as e: | ||||
|                     logger.error(f"Oracle database error: {e}") | ||||
|                     print(sql) | ||||
|                     print(params) | ||||
|                     raise | ||||
|                 columns = [column[0].lower() for column in cursor.description] | ||||
|                 if multirows: | ||||
| @@ -140,7 +149,7 @@ class OracleDB: | ||||
|                         data = None | ||||
|                 return data | ||||
|  | ||||
|     async def execute(self, sql: str, data: list = None): | ||||
|     async def execute(self, sql: str, data: Union[list, dict] = None): | ||||
|         # logger.info("go into OracleDB execute method") | ||||
|         try: | ||||
|             async with self.pool.acquire() as connection: | ||||
| @@ -150,8 +159,6 @@ class OracleDB: | ||||
|                     if data is None: | ||||
|                         await cursor.execute(sql) | ||||
|                     else: | ||||
|                         # print(data) | ||||
|                         # print(sql) | ||||
|                         await cursor.execute(sql, data) | ||||
|                     await connection.commit() | ||||
|         except Exception as e: | ||||
| @@ -164,34 +171,64 @@ class OracleDB: | ||||
| @dataclass | ||||
| class OracleKVStorage(BaseKVStorage): | ||||
|     # should pass db object to self.db | ||||
|     db: OracleDB = None | ||||
|     meta_fields = None | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._data = {} | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|         self._max_batch_size = self.global_config.get("embedding_batch_num", 10) | ||||
|  | ||||
|     ################ QUERY METHODS ################ | ||||
|  | ||||
|     async def get_by_id(self, id: str) -> Union[dict, None]: | ||||
|         """根据 id 获取 doc_full 数据.""" | ||||
|         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format( | ||||
|             workspace=self.db.workspace, id=id | ||||
|         ) | ||||
|         """get doc_full data based on id.""" | ||||
|         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] | ||||
|         params = {"workspace": self.db.workspace, "id": id} | ||||
|         # print("get_by_id:"+SQL) | ||||
|         res = await self.db.query(SQL) | ||||
|         if "llm_response_cache" == self.namespace: | ||||
|             array_res = await self.db.query(SQL, params, multirows=True) | ||||
|             res = {} | ||||
|             for row in array_res: | ||||
|                 res[row["id"]] = row | ||||
|         else: | ||||
|             res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             data = res  # {"data":res} | ||||
|             # print (data) | ||||
|             return data | ||||
|             return res | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: | ||||
|         """Specifically for llm_response_cache.""" | ||||
|         SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] | ||||
|         params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} | ||||
|         if "llm_response_cache" == self.namespace: | ||||
|             array_res = await self.db.query(SQL, params, multirows=True) | ||||
|             res = {} | ||||
|             for row in array_res: | ||||
|                 res[row["id"]] = row | ||||
|             return res | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     # Query by id | ||||
|     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: | ||||
|         """根据 id 获取 doc_chunks 数据""" | ||||
|         """get doc_chunks data based on id""" | ||||
|         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( | ||||
|             workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids]) | ||||
|             ids=",".join([f"'{id}'" for id in ids]) | ||||
|         ) | ||||
|         params = {"workspace": self.db.workspace} | ||||
|         # print("get_by_ids:"+SQL) | ||||
|         res = await self.db.query(SQL, multirows=True) | ||||
|         res = await self.db.query(SQL, params, multirows=True) | ||||
|         if "llm_response_cache" == self.namespace: | ||||
|             modes = set() | ||||
|             dict_res: dict[str, dict] = {} | ||||
|             for row in res: | ||||
|                 modes.add(row["mode"]) | ||||
|             for mode in modes: | ||||
|                 if mode not in dict_res: | ||||
|                     dict_res[mode] = {} | ||||
|             for row in res: | ||||
|                 dict_res[row["mode"]][row["id"]] = row | ||||
|             res = [{k: v} for k, v in dict_res.items()] | ||||
|         if res: | ||||
|             data = res  # [{"data":i} for i in res] | ||||
|             # print(data) | ||||
| @@ -199,33 +236,43 @@ class OracleKVStorage(BaseKVStorage): | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     async def get_by_status_and_ids( | ||||
|         self, status: str, ids: list[str] | ||||
|     ) -> Union[list[dict], None]: | ||||
|         """Specifically for llm_response_cache.""" | ||||
|         if ids is not None: | ||||
|             SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( | ||||
|                 ids=",".join([f"'{id}'" for id in ids]) | ||||
|             ) | ||||
|         else: | ||||
|             SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] | ||||
|         params = {"workspace": self.db.workspace, "status": status} | ||||
|         res = await self.db.query(SQL, params, multirows=True) | ||||
|         if res: | ||||
|             return res | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     async def filter_keys(self, keys: list[str]) -> set[str]: | ||||
|         """过滤掉重复内容""" | ||||
|         """Return keys that don't exist in storage""" | ||||
|         SQL = SQL_TEMPLATES["filter_keys"].format( | ||||
|             table_name=N_T[self.namespace], | ||||
|             workspace=self.db.workspace, | ||||
|             ids=",".join([f"'{k}'" for k in keys]), | ||||
|             table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) | ||||
|         ) | ||||
|         res = await self.db.query(SQL, multirows=True) | ||||
|         data = None | ||||
|         params = {"workspace": self.db.workspace} | ||||
|         res = await self.db.query(SQL, params, multirows=True) | ||||
|         if res: | ||||
|             exist_keys = [key["id"] for key in res] | ||||
|             data = set([s for s in keys if s not in exist_keys]) | ||||
|             return data | ||||
|         else: | ||||
|             exist_keys = [] | ||||
|             data = set([s for s in keys if s not in exist_keys]) | ||||
|         return data | ||||
|             return set(keys) | ||||
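|     # Editor's note (not part of the commit): e.g. for keys=["k1", "k2"] where | ||||
|     # only "k1" is already stored, filter_keys returns {"k2"}, so callers insert | ||||
|     # only the missing records. | ||||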
|  | ||||
|     ################ INSERT METHODS ################ | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         left_data = {k: v for k, v in data.items() if k not in self._data} | ||||
|         self._data.update(left_data) | ||||
|         # print(self._data) | ||||
|         # values = [] | ||||
|         if self.namespace == "text_chunks": | ||||
|             list_data = [ | ||||
|                 { | ||||
|                     "__id__": k, | ||||
|                     "id": k, | ||||
|                     **{k1: v1 for k1, v1 in v.items()}, | ||||
|                 } | ||||
|                 for k, v in data.items() | ||||
| @@ -241,32 +288,50 @@ class OracleKVStorage(BaseKVStorage): | ||||
|             embeddings = np.concatenate(embeddings_list) | ||||
|             for i, d in enumerate(list_data): | ||||
|                 d["__vector__"] = embeddings[i] | ||||
|             # print(list_data) | ||||
|  | ||||
|             merge_sql = SQL_TEMPLATES["merge_chunk"] | ||||
|             for item in list_data: | ||||
|                 merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"]) | ||||
|  | ||||
|                 values = [ | ||||
|                     item["__id__"], | ||||
|                     item["content"], | ||||
|                     self.db.workspace, | ||||
|                     item["tokens"], | ||||
|                     item["chunk_order_index"], | ||||
|                     item["full_doc_id"], | ||||
|                     item["__vector__"], | ||||
|                 ] | ||||
|                 # print(merge_sql) | ||||
|                 await self.db.execute(merge_sql, values) | ||||
|  | ||||
|                 _data = { | ||||
|                     "id": item["id"], | ||||
|                     "content": item["content"], | ||||
|                     "workspace": self.db.workspace, | ||||
|                     "tokens": item["tokens"], | ||||
|                     "chunk_order_index": item["chunk_order_index"], | ||||
|                     "full_doc_id": item["full_doc_id"], | ||||
|                     "content_vector": item["__vector__"], | ||||
|                     "status": item["status"], | ||||
|                 } | ||||
|                 await self.db.execute(merge_sql, _data) | ||||
|         if self.namespace == "full_docs": | ||||
|             for k, v in self._data.items(): | ||||
|             for k, v in data.items(): | ||||
|                 # values.clear() | ||||
|                 merge_sql = SQL_TEMPLATES["merge_doc_full"].format( | ||||
|                     check_id=k, | ||||
|                 ) | ||||
|                 values = [k, self._data[k]["content"], self.db.workspace] | ||||
|                 # print(merge_sql) | ||||
|                 await self.db.execute(merge_sql, values) | ||||
|         return left_data | ||||
|                 merge_sql = SQL_TEMPLATES["merge_doc_full"] | ||||
|                 _data = { | ||||
|                     "id": k, | ||||
|                     "content": v["content"], | ||||
|                     "workspace": self.db.workspace, | ||||
|                 } | ||||
|                 await self.db.execute(merge_sql, _data) | ||||
|  | ||||
|         if self.namespace == "llm_response_cache": | ||||
|             for mode, items in data.items(): | ||||
|                 for k, v in items.items(): | ||||
|                     upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] | ||||
|                     _data = { | ||||
|                         "workspace": self.db.workspace, | ||||
|                         "id": k, | ||||
|                         "original_prompt": v["original_prompt"], | ||||
|                         "return_value": v["return"], | ||||
|                         "cache_mode": mode, | ||||
|                     } | ||||
|  | ||||
|                     await self.db.execute(upsert_sql, _data) | ||||
|         return None | ||||
|  | ||||
|     async def change_status(self, id: str, status: str): | ||||
|         SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) | ||||
|         params = {"workspace": self.db.workspace, "id": id, "status": status} | ||||
|         await self.db.execute(SQL, params) | ||||
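|     # Editor's note (not part of the commit): the SQL templates below use | ||||
|     # python-oracledb named bind variables (:workspace, :id, ...), so query() | ||||
|     # and execute() pass a dict whose keys match the placeholder names, e.g.: | ||||
|     # | ||||
|     #     await cursor.execute( | ||||
|     #         "select id from LIGHTRAG_DOC_FULL where workspace=:workspace and id=:id", | ||||
|     #         {"workspace": "demo", "id": "doc-1"},  # illustrative values | ||||
|     #     ) | ||||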
|  | ||||
|     async def index_done_callback(self): | ||||
|         if self.namespace in ["full_docs", "text_chunks"]: | ||||
| @@ -275,10 +340,16 @@ class OracleKVStorage(BaseKVStorage): | ||||
|  | ||||
| @dataclass | ||||
| class OracleVectorDBStorage(BaseVectorStorage): | ||||
|     cosine_better_than_threshold: float = 0.2 | ||||
|     # should pass db object to self.db | ||||
|     db: OracleDB = None | ||||
|     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         pass | ||||
|         # Use global config value if specified, otherwise use default | ||||
|         config = self.global_config.get("vector_db_storage_cls_kwargs", {}) | ||||
|         self.cosine_better_than_threshold = config.get( | ||||
|             "cosine_better_than_threshold", self.cosine_better_than_threshold | ||||
|         ) | ||||
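|         # Editor's note (not part of the commit): the threshold resolves as | ||||
|         # global_config["vector_db_storage_cls_kwargs"]["cosine_better_than_threshold"] | ||||
|         # if present, else the COSINE_THRESHOLD environment variable, else 0.2: | ||||
|         # | ||||
|         #     global_config = { | ||||
|         #         "vector_db_storage_cls_kwargs": {"cosine_better_than_threshold": 0.3} | ||||
|         #     } | ||||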
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         """Insert data into the vector database""" | ||||
| @@ -295,18 +366,17 @@ class OracleVectorDBStorage(BaseVectorStorage): | ||||
|         # Convert precision | ||||
|         dtype = str(embedding.dtype).upper() | ||||
|         dimension = embedding.shape[0] | ||||
|         embedding_string = ", ".join(map(str, embedding.tolist())) | ||||
|         embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" | ||||
|  | ||||
|         SQL = SQL_TEMPLATES[self.namespace].format( | ||||
|             embedding_string=embedding_string, | ||||
|             dimension=dimension, | ||||
|             dtype=dtype, | ||||
|             workspace=self.db.workspace, | ||||
|             top_k=top_k, | ||||
|             better_than_threshold=self.cosine_better_than_threshold, | ||||
|         ) | ||||
|         SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype) | ||||
|         params = { | ||||
|             "embedding_string": embedding_string, | ||||
|             "workspace": self.db.workspace, | ||||
|             "top_k": top_k, | ||||
|             "better_than_threshold": self.cosine_better_than_threshold, | ||||
|         } | ||||
|         # print(SQL) | ||||
|         results = await self.db.query(SQL, multirows=True) | ||||
|         results = await self.db.query(SQL, params=params, multirows=True) | ||||
|         # print("vector search result:",results) | ||||
|         return results | ||||
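|     # Editor's note (not part of the commit): the bound :embedding_string is a | ||||
|     # bracketed vector literal built from the query embedding; for | ||||
|     # np.array([0.1, 0.2, 0.3], dtype=np.float32) it becomes "[0.1, 0.2, 0.3]", | ||||
|     # while dtype ("FLOAT32") and dimension (3) are interpolated into the SQL | ||||
|     # template itself via str.format(). | ||||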
|  | ||||
| @@ -317,7 +387,7 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         """Load the graph from a graphml file""" | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|         self._max_batch_size = self.global_config.get("embedding_batch_num", 10) | ||||
|  | ||||
|     #################### insert method ################ | ||||
|  | ||||
| @@ -328,6 +398,8 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|         entity_type = node_data["entity_type"] | ||||
|         description = node_data["description"] | ||||
|         source_id = node_data["source_id"] | ||||
|         logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") | ||||
|  | ||||
|         content = entity_name + description | ||||
|         contents = [content] | ||||
|         batches = [ | ||||
| @@ -339,22 +411,17 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|         ) | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         content_vector = embeddings[0] | ||||
|         merge_sql = SQL_TEMPLATES["merge_node"].format( | ||||
|             workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id | ||||
|         ) | ||||
|         # print(merge_sql) | ||||
|         await self.db.execute( | ||||
|             merge_sql, | ||||
|             [ | ||||
|                 self.db.workspace, | ||||
|                 entity_name, | ||||
|                 entity_type, | ||||
|                 description, | ||||
|                 source_id, | ||||
|                 content, | ||||
|                 content_vector, | ||||
|             ], | ||||
|         ) | ||||
|         merge_sql = SQL_TEMPLATES["merge_node"] | ||||
|         data = { | ||||
|             "workspace": self.db.workspace, | ||||
|             "name": entity_name, | ||||
|             "entity_type": entity_type, | ||||
|             "description": description, | ||||
|             "source_chunk_id": source_id, | ||||
|             "content": content, | ||||
|             "content_vector": content_vector, | ||||
|         } | ||||
|         await self.db.execute(merge_sql, data) | ||||
|         # self._graph.add_node(node_id, **node_data) | ||||
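|     # Editor's note (not part of the commit): the hunk above elides the middle | ||||
|     # of the embedding step; elsewhere in this codebase it follows the pattern | ||||
|     # below (a sketch under that assumption): split `contents` into chunks of | ||||
|     # _max_batch_size, embed each chunk concurrently, then concatenate. | ||||
|     # | ||||
|     #     batches = [ | ||||
|     #         contents[i : i + self._max_batch_size] | ||||
|     #         for i in range(0, len(contents), self._max_batch_size) | ||||
|     #     ] | ||||
|     #     embeddings_list = await asyncio.gather( | ||||
|     #         *[self.embedding_func(batch) for batch in batches] | ||||
|     #     ) | ||||
|     #     embeddings = np.concatenate(embeddings_list) | ||||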
|  | ||||
|     async def upsert_edge( | ||||
| @@ -368,6 +435,10 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|         keywords = edge_data["keywords"] | ||||
|         description = edge_data["description"] | ||||
|         source_chunk_id = edge_data["source_id"] | ||||
|         logger.debug( | ||||
|             f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" | ||||
|         ) | ||||
|  | ||||
|         content = keywords + source_name + target_name + description | ||||
|         contents = [content] | ||||
|         batches = [ | ||||
| @@ -379,27 +450,20 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|         ) | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         content_vector = embeddings[0] | ||||
|         merge_sql = SQL_TEMPLATES["merge_edge"].format( | ||||
|             workspace=self.db.workspace, | ||||
|             source_name=source_name, | ||||
|             target_name=target_name, | ||||
|             source_chunk_id=source_chunk_id, | ||||
|         ) | ||||
|         merge_sql = SQL_TEMPLATES["merge_edge"] | ||||
|         data = { | ||||
|             "workspace": self.db.workspace, | ||||
|             "source_name": source_name, | ||||
|             "target_name": target_name, | ||||
|             "weight": weight, | ||||
|             "keywords": keywords, | ||||
|             "description": description, | ||||
|             "source_chunk_id": source_chunk_id, | ||||
|             "content": content, | ||||
|             "content_vector": content_vector, | ||||
|         } | ||||
|         # print(merge_sql) | ||||
|         await self.db.execute( | ||||
|             merge_sql, | ||||
|             [ | ||||
|                 self.db.workspace, | ||||
|                 source_name, | ||||
|                 target_name, | ||||
|                 weight, | ||||
|                 keywords, | ||||
|                 description, | ||||
|                 source_chunk_id, | ||||
|                 content, | ||||
|                 content_vector, | ||||
|             ], | ||||
|         ) | ||||
|         await self.db.execute(merge_sql, data) | ||||
|         # self._graph.add_edge(source_node_id, target_node_id, **edge_data) | ||||
|  | ||||
|     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: | ||||
| @@ -429,12 +493,11 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|     #################### query method ################# | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         """Check whether a node exists by node id""" | ||||
|         SQL = SQL_TEMPLATES["has_node"].format( | ||||
|             workspace=self.db.workspace, node_id=node_id | ||||
|         ) | ||||
|         SQL = SQL_TEMPLATES["has_node"] | ||||
|         params = {"workspace": self.db.workspace, "node_id": node_id} | ||||
|         # print(SQL) | ||||
|         # print(self.db.workspace, node_id) | ||||
|         res = await self.db.query(SQL) | ||||
|         res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             # print("Node exist!",res) | ||||
|             return True | ||||
| @@ -444,13 +507,14 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         """Check whether an edge exists by source and target node ids""" | ||||
|         SQL = SQL_TEMPLATES["has_edge"].format( | ||||
|             workspace=self.db.workspace, | ||||
|             source_node_id=source_node_id, | ||||
|             target_node_id=target_node_id, | ||||
|         ) | ||||
|         SQL = SQL_TEMPLATES["has_edge"] | ||||
|         params = { | ||||
|             "workspace": self.db.workspace, | ||||
|             "source_node_id": source_node_id, | ||||
|             "target_node_id": target_node_id, | ||||
|         } | ||||
|         # print(SQL) | ||||
|         res = await self.db.query(SQL) | ||||
|         res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             # print("Edge exist!",res) | ||||
|             return True | ||||
| @@ -460,11 +524,10 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|  | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         """Get a node's degree by node id""" | ||||
|         SQL = SQL_TEMPLATES["node_degree"].format( | ||||
|             workspace=self.db.workspace, node_id=node_id | ||||
|         ) | ||||
|         SQL = SQL_TEMPLATES["node_degree"] | ||||
|         params = {"workspace": self.db.workspace, "node_id": node_id} | ||||
|         # print(SQL) | ||||
|         res = await self.db.query(SQL) | ||||
|         res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             # print("Node degree",res["degree"]) | ||||
|             return res["degree"] | ||||
| @@ -480,12 +543,11 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         """Get node data by node id""" | ||||
|         SQL = SQL_TEMPLATES["get_node"].format( | ||||
|             workspace=self.db.workspace, node_id=node_id | ||||
|         ) | ||||
|         SQL = SQL_TEMPLATES["get_node"] | ||||
|         params = {"workspace": self.db.workspace, "node_id": node_id} | ||||
|         # print(self.db.workspace, node_id) | ||||
|         # print(SQL) | ||||
|         res = await self.db.query(SQL) | ||||
|         res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             # print("Get node!",self.db.workspace, node_id,res) | ||||
|             return res | ||||
| @@ -497,12 +559,13 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         """Get an edge by source and target node ids""" | ||||
|         SQL = SQL_TEMPLATES["get_edge"].format( | ||||
|             workspace=self.db.workspace, | ||||
|             source_node_id=source_node_id, | ||||
|             target_node_id=target_node_id, | ||||
|         ) | ||||
|         res = await self.db.query(SQL) | ||||
|         SQL = SQL_TEMPLATES["get_edge"] | ||||
|         params = { | ||||
|             "workspace": self.db.workspace, | ||||
|             "source_node_id": source_node_id, | ||||
|             "target_node_id": target_node_id, | ||||
|         } | ||||
|         res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0]) | ||||
|             return res | ||||
| @@ -513,10 +576,9 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|     async def get_node_edges(self, source_node_id: str): | ||||
|         """Get all edges of a node by node id""" | ||||
|         if await self.has_node(source_node_id): | ||||
|             SQL = SQL_TEMPLATES["get_node_edges"].format( | ||||
|                 workspace=self.db.workspace, source_node_id=source_node_id | ||||
|             ) | ||||
|             res = await self.db.query(sql=SQL, multirows=True) | ||||
|             SQL = SQL_TEMPLATES["get_node_edges"] | ||||
|             params = {"workspace": self.db.workspace, "source_node_id": source_node_id} | ||||
|             res = await self.db.query(sql=SQL, params=params, multirows=True) | ||||
|             if res: | ||||
|                 data = [(i["source_name"], i["target_name"]) for i in res] | ||||
|                 # print("Get node edge!",self.db.workspace, source_node_id,data) | ||||
| @@ -525,6 +587,29 @@ class OracleGraphStorage(BaseGraphStorage): | ||||
|                 # print("Node Edge not exist!",self.db.workspace, source_node_id) | ||||
|                 return [] | ||||
|  | ||||
|     async def get_all_nodes(self, limit: int): | ||||
|         """Query all nodes""" | ||||
|         SQL = SQL_TEMPLATES["get_all_nodes"] | ||||
|         params = {"workspace": self.db.workspace, "limit": str(limit)} | ||||
|         res = await self.db.query(sql=SQL, params=params, multirows=True) | ||||
|         if res: | ||||
|             return res | ||||
|  | ||||
|     async def get_all_edges(self, limit: int): | ||||
|         """Query all edges""" | ||||
|         SQL = SQL_TEMPLATES["get_all_edges"] | ||||
|         params = {"workspace": self.db.workspace, "limit": str(limit)} | ||||
|         res = await self.db.query(sql=SQL, params=params, multirows=True) | ||||
|         if res: | ||||
|             return res | ||||
|  | ||||
|     async def get_statistics(self): | ||||
|         SQL = SQL_TEMPLATES["get_statistics"] | ||||
|         params = {"workspace": self.db.workspace} | ||||
|         res = await self.db.query(sql=SQL, params=params, multirows=True) | ||||
|         if res: | ||||
|             return res | ||||
|  | ||||
|  | ||||
| N_T = { | ||||
|     "full_docs": "LIGHTRAG_DOC_FULL", | ||||
| @@ -537,20 +622,26 @@ N_T = { | ||||
| TABLES = { | ||||
|     "LIGHTRAG_DOC_FULL": { | ||||
|         "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( | ||||
|                     id varchar(256)PRIMARY KEY, | ||||
|                     id varchar(256), | ||||
|                     workspace varchar(1024), | ||||
|                     doc_name varchar(1024), | ||||
|                     content CLOB, | ||||
|                     meta JSON, | ||||
|                     content_summary varchar(1024), | ||||
|                     content_length NUMBER, | ||||
|                     status varchar(256), | ||||
|                     chunks_count NUMBER, | ||||
|                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | ||||
|                     updatetime TIMESTAMP DEFAULT NULL | ||||
|                     updatetime TIMESTAMP DEFAULT NULL, | ||||
|                     error varchar(4096) | ||||
|                     )""" | ||||
|     }, | ||||
|     "LIGHTRAG_DOC_CHUNKS": { | ||||
|         "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( | ||||
|                     id varchar(256) PRIMARY KEY, | ||||
|                     id varchar(256), | ||||
|                     workspace varchar(1024), | ||||
|                     full_doc_id varchar(256), | ||||
|                     status varchar(256), | ||||
|                     chunk_order_index NUMBER, | ||||
|                     tokens NUMBER, | ||||
|                     content CLOB, | ||||
| @@ -592,9 +683,15 @@ TABLES = { | ||||
|     "LIGHTRAG_LLM_CACHE": { | ||||
|         "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( | ||||
|                     id varchar(256) PRIMARY KEY, | ||||
|                     send clob, | ||||
|                     return clob, | ||||
|                     model varchar(1024), | ||||
|                     workspace varchar(1024), | ||||
|                     cache_mode varchar(256), | ||||
|                     model_name varchar(256), | ||||
|                     original_prompt clob, | ||||
|                     return_value clob, | ||||
|                     embedding CLOB, | ||||
|                     embedding_shape NUMBER, | ||||
|                     embedding_min NUMBER, | ||||
|                     embedding_max NUMBER, | ||||
|                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | ||||
|                     updatetime TIMESTAMP DEFAULT NULL | ||||
|                     )""" | ||||
| @@ -619,82 +716,141 @@ TABLES = { | ||||
|  | ||||
| SQL_TEMPLATES = { | ||||
|     # SQL for KVStorage | ||||
|     "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'", | ||||
|     "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'", | ||||
|     "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})", | ||||
|     "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID  from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})", | ||||
|     "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})", | ||||
|     "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a | ||||
|                     USING DUAL | ||||
|                     ON (a.id = '{check_id}') | ||||
|                     WHEN NOT MATCHED THEN | ||||
|                     INSERT(id,content,workspace) values(:1,:2,:3) | ||||
|                     """, | ||||
|     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a | ||||
|                     USING DUAL | ||||
|                     ON (a.id = '{check_id}') | ||||
|                     WHEN NOT MATCHED THEN | ||||
|                     INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) | ||||
|                     values (:1,:2,:3,:4,:5,:6,:7) """, | ||||
|     "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", | ||||
|     "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", | ||||
|     "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" | ||||
|         FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", | ||||
|     "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" | ||||
|         FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", | ||||
|     "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" | ||||
|         FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace  AND id IN ({ids})""", | ||||
|     "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", | ||||
|     "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID  from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", | ||||
|     "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", | ||||
|     "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", | ||||
|     "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", | ||||
|     "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", | ||||
|     "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", | ||||
|     "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", | ||||
|     "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a | ||||
|         USING DUAL | ||||
|         ON (a.id = :id and a.workspace = :workspace) | ||||
|         WHEN NOT MATCHED THEN | ||||
|         INSERT(id,content,workspace) values(:id,:content,:workspace)""", | ||||
|     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS | ||||
|         USING DUAL | ||||
|         ON (id = :id and workspace = :workspace) | ||||
|         WHEN NOT MATCHED THEN INSERT | ||||
|             (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) | ||||
|             values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, | ||||
|     "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a | ||||
|         USING DUAL | ||||
|         ON (a.id = :id) | ||||
|         WHEN NOT MATCHED THEN | ||||
|         INSERT (workspace,id,original_prompt,return_value,cache_mode) | ||||
|             VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) | ||||
|         WHEN MATCHED THEN UPDATE | ||||
|             SET original_prompt = :original_prompt, | ||||
|             return_value = :return_value, | ||||
|             cache_mode = :cache_mode, | ||||
|             updatetime = SYSDATE""", | ||||
|     # SQL for VectorStorage | ||||
|     "entities": """SELECT name as entity_name FROM | ||||
|         (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance | ||||
|         FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}') | ||||
|         WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", | ||||
|         (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance | ||||
|         FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace) | ||||
|         WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", | ||||
|     "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM | ||||
|         (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance | ||||
|         FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}') | ||||
|         WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", | ||||
|         (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance | ||||
|         FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace) | ||||
|         WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", | ||||
|     "chunks": """SELECT id FROM | ||||
|         (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance | ||||
|         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}') | ||||
|         WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""", | ||||
|         (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance | ||||
|         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace) | ||||
|         WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""", | ||||
|     # SQL for GraphStorage | ||||
|     "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph | ||||
|         MATCH (a) | ||||
|         WHERE a.workspace='{workspace}' AND a.name='{node_id}' | ||||
|         WHERE a.workspace=:workspace AND a.name=:node_id | ||||
|         COLUMNS (a.name))""", | ||||
|     "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph | ||||
|         MATCH (a) -[e]-> (b) | ||||
|         WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' | ||||
|         AND a.name='{source_node_id}' AND b.name='{target_node_id}' | ||||
|         WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace | ||||
|         AND a.name=:source_node_id AND b.name=:target_node_id | ||||
|         COLUMNS (e.source_name,e.target_name)  )""", | ||||
|     "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph | ||||
|         MATCH (a)-[e]->(b) | ||||
|         WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' | ||||
|         AND a.name='{node_id}' or b.name = '{node_id}' | ||||
|         WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace | ||||
|         AND (a.name=:node_id OR b.name=:node_id) | ||||
|         COLUMNS (a.name))""", | ||||
|     "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description | ||||
|         FROM GRAPH_TABLE (lightrag_graph | ||||
|         MATCH (a) | ||||
|         WHERE a.workspace='{workspace}' AND a.name='{node_id}' | ||||
|         WHERE a.workspace=:workspace AND a.name=:node_id | ||||
|         COLUMNS (a.name) | ||||
|         ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name | ||||
|         WHERE t2.workspace='{workspace}'""", | ||||
|         WHERE t2.workspace=:workspace""", | ||||
|     "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords, | ||||
|         NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords | ||||
|         FROM GRAPH_TABLE (lightrag_graph | ||||
|         MATCH (a)-[e]->(b) | ||||
|         WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' | ||||
|         AND a.name='{source_node_id}' and b.name = '{target_node_id}' | ||||
|         WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace | ||||
|         AND a.name=:source_node_id and b.name = :target_node_id | ||||
|         COLUMNS (e.id,a.name as source_id) | ||||
|         ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""", | ||||
|     "get_node_edges": """SELECT source_name,target_name | ||||
|             FROM GRAPH_TABLE (lightrag_graph | ||||
|             MATCH (a)-[e]->(b) | ||||
|             WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}' | ||||
|             AND a.name='{source_node_id}' | ||||
|             WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace | ||||
|             AND a.name=:source_node_id | ||||
|             COLUMNS (a.name as source_name,b.name as target_name))""", | ||||
|     "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a | ||||
|                     USING DUAL | ||||
|                     ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}') | ||||
|                     ON (a.workspace=:workspace and a.name=:name) | ||||
|                 WHEN NOT MATCHED THEN | ||||
|                     INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) | ||||
|                     values (:1,:2,:3,:4,:5,:6,:7) """, | ||||
|                     values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) | ||||
|                 WHEN MATCHED THEN | ||||
|                     UPDATE SET | ||||
|                     entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", | ||||
|     "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a | ||||
|                     USING DUAL | ||||
|                     ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}') | ||||
|                     ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) | ||||
|                 WHEN NOT MATCHED THEN | ||||
|                     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) | ||||
|                     values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """, | ||||
|                     values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) | ||||
|                 WHEN MATCHED THEN | ||||
|                     UPDATE SET | ||||
|                     weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", | ||||
|     "get_all_nodes": """WITH t0 AS ( | ||||
|                         SELECT name AS id, entity_type AS label, entity_type, description, | ||||
|                             '["' || replace(source_chunk_id, '<SEP>', '","') || '"]'     source_chunk_ids | ||||
|                         FROM lightrag_graph_nodes | ||||
|                         WHERE workspace = :workspace | ||||
|                         ORDER BY createtime DESC fetch first :limit rows only | ||||
|                     ), t1 AS ( | ||||
|                         SELECT t0.id, source_chunk_id | ||||
|                         FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) ) | ||||
|                     ), t2 AS ( | ||||
|                         SELECT t1.id, LISTAGG(t2.content, '\n') content | ||||
|                         FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id | ||||
|                         GROUP BY t1.id | ||||
|                     ) | ||||
|                     SELECT t0.id, label, entity_type, description, t2.content | ||||
|                     FROM t0 LEFT JOIN t2 ON t0.id = t2.id""", | ||||
|     "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target, | ||||
|                 t1.weight,t1.DESCRIPTION,t2.content | ||||
|                 FROM LIGHTRAG_GRAPH_EDGES t1 | ||||
|                 LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id | ||||
|                 WHERE t1.workspace=:workspace | ||||
|                 order by t1.CREATETIME DESC | ||||
|                 fetch first :limit rows only""", | ||||
|     "get_statistics": """select  count(distinct CASE WHEN type='node' THEN id END) as nodes_count, | ||||
|                 count(distinct CASE WHEN type='edge' THEN id END) as edges_count | ||||
|                 FROM ( | ||||
|                 select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph | ||||
|                     MATCH (a) WHERE a.workspace=:workspace columns(a.name as id)) | ||||
|                 UNION | ||||
|                 select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph | ||||
|                     MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id)) | ||||
|                 )""", | ||||
| } | ||||
|   | ||||
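For orientation, a minimal usage sketch (not part of the commit) of the named-bind
templates above; it assumes an already-initialized Oracle connection wrapper `db`
exposing the async `query(sql=..., params=..., multirows=...)` method used
throughout this file, and the workspace/node values are illustrative:

    async def example_get_node(db):
        # The driver binds :workspace and :node_id by name from the params dict
        sql = SQL_TEMPLATES["get_node"]
        params = {"workspace": db.workspace, "node_id": "SCROOGE"}
        return await db.query(sql=sql, params=params)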
1182  minirag/kg/postgres_impl.py  (new file; diff suppressed because it is too large)
136  minirag/kg/postgres_impl_test.py  (new file)
							| @@ -0,0 +1,136 @@ | ||||
| import asyncio | ||||
| import sys | ||||
| import os | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("psycopg-pool"): | ||||
|     pm.install("psycopg-pool") | ||||
|     pm.install("psycopg[binary,pool]") | ||||
| if not pm.is_installed("asyncpg"): | ||||
|     pm.install("asyncpg") | ||||
|  | ||||
| import asyncpg | ||||
| import psycopg | ||||
| from psycopg_pool import AsyncConnectionPool | ||||
| from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage | ||||
|  | ||||
| DB = "rag" | ||||
| USER = "rag" | ||||
| PASSWORD = "rag" | ||||
| HOST = "localhost" | ||||
| PORT = "15432" | ||||
| os.environ["AGE_GRAPH_NAME"] = "dickens" | ||||
|  | ||||
| if sys.platform.startswith("win"): | ||||
|     import asyncio.windows_events | ||||
|  | ||||
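|     # psycopg's asyncio support requires the selector event loop on Windows (the proactor loop is the default) | ||||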
|     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) | ||||
|  | ||||
|  | ||||
| async def get_pool(): | ||||
|     return await asyncpg.create_pool( | ||||
|         f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}", | ||||
|         min_size=10, | ||||
|         max_size=10, | ||||
|         max_queries=5000, | ||||
|         max_inactive_connection_lifetime=300.0, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def main1(): | ||||
|     connection_string = ( | ||||
|         f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" | ||||
|     ) | ||||
|     pool = AsyncConnectionPool(connection_string, open=False) | ||||
|     await pool.open() | ||||
|  | ||||
|     try: | ||||
|         conn = await pool.getconn(timeout=10) | ||||
|         async with conn.cursor() as curs: | ||||
|             try: | ||||
|                 await curs.execute('SET search_path = ag_catalog, "$user", public') | ||||
|                 await curs.execute("SELECT create_graph('dickens-2')") | ||||
|                 await conn.commit() | ||||
|                 print("create_graph success") | ||||
|             except ( | ||||
|                 psycopg.errors.InvalidSchemaName, | ||||
|                 psycopg.errors.UniqueViolation, | ||||
|             ): | ||||
|                 print("create_graph already exists") | ||||
|                 await conn.rollback() | ||||
|     finally: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| db = PostgreSQLDB( | ||||
|     config={ | ||||
|         "host": "localhost", | ||||
|         "port": 15432, | ||||
|         "user": "rag", | ||||
|         "password": "rag", | ||||
|         "database": "r1", | ||||
|     } | ||||
| ) | ||||
|  | ||||
|  | ||||
| async def query_with_age(): | ||||
|     await db.initdb() | ||||
|     graph = PGGraphStorage( | ||||
|         namespace="chunk_entity_relation", | ||||
|         global_config={}, | ||||
|         embedding_func=None, | ||||
|     ) | ||||
|     graph.db = db | ||||
|     res = await graph.get_node('"A CHRISTMAS CAROL"') | ||||
|     print("Node is: ", res) | ||||
|     res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG") | ||||
|     print("Edge is: ", res) | ||||
|     res = await graph.get_node_edges('"SCROOGE"') | ||||
|     print("Node Edges are: ", res) | ||||
|  | ||||
|  | ||||
| async def create_edge_with_age(): | ||||
|     await db.initdb() | ||||
|     graph = PGGraphStorage( | ||||
|         namespace="chunk_entity_relation", | ||||
|         global_config={}, | ||||
|         embedding_func=None, | ||||
|     ) | ||||
|     graph.db = db | ||||
|     await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"}) | ||||
|     await graph.upsert_node('"THE GIRLS"', {"world": "hello"}) | ||||
|     await graph.upsert_edge( | ||||
|         '"THE CRATCHITS"', | ||||
|         '"THE GIRLS"', | ||||
|         edge_data={ | ||||
|             "weight": 7.0, | ||||
|             "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.', | ||||
|             "keywords": '"family, collective effort"', | ||||
|             "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8", | ||||
|         }, | ||||
|     ) | ||||
|     res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"') | ||||
|     print("Edge is: ", res) | ||||
|  | ||||
|  | ||||
| async def main(): | ||||
|     pool = await get_pool() | ||||
|     sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)" | ||||
|     # cypher = "MATCH (n:how_are_you_doing) RETURN n" | ||||
|     async with pool.acquire() as conn: | ||||
|         try: | ||||
|             await conn.execute( | ||||
|                 """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""" | ||||
|             ) | ||||
|         except asyncpg.exceptions.InvalidSchemaNameError: | ||||
|             print("create_graph already exists") | ||||
|         # stmt = await conn.prepare(sql) | ||||
|         row = await conn.fetch(sql) | ||||
|         print("row is: ", row) | ||||
|  | ||||
|         row = await conn.fetchrow("select '100'::int + 200 as result") | ||||
|         print(row)  # <Record result=300> | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     asyncio.run(query_with_age()) | ||||
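Note: the script above assumes an AGE-enabled PostgreSQL instance on
localhost:15432 with the rag/rag credentials; `main()` creates the `dickens`
graph that `AGE_GRAPH_NAME` points at, so it likely needs to run once before
`query_with_age()` returns any data.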
70  minirag/kg/redis_impl.py  (new file)
							| @@ -0,0 +1,70 @@ | ||||
| import os | ||||
| from tqdm.asyncio import tqdm as tqdm_async | ||||
| from dataclasses import dataclass | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("redis"): | ||||
|     pm.install("redis") | ||||
|  | ||||
| # aioredis is deprecated; its functionality was merged into redis.asyncio | ||||
| from redis.asyncio import Redis | ||||
| from lightrag.utils import logger | ||||
| from lightrag.base import BaseKVStorage | ||||
| import json | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class RedisKVStorage(BaseKVStorage): | ||||
|     def __post_init__(self): | ||||
|         redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379") | ||||
|         self._redis = Redis.from_url(redis_url, decode_responses=True) | ||||
|         logger.info(f"Use Redis as KV {self.namespace}") | ||||
|  | ||||
|     async def all_keys(self) -> list[str]: | ||||
|         keys = await self._redis.keys(f"{self.namespace}:*") | ||||
|         return [key.split(":", 1)[-1] for key in keys] | ||||
|  | ||||
|     async def get_by_id(self, id): | ||||
|         data = await self._redis.get(f"{self.namespace}:{id}") | ||||
|         return json.loads(data) if data else None | ||||
|  | ||||
|     async def get_by_ids(self, ids, fields=None): | ||||
|         pipe = self._redis.pipeline() | ||||
|         for id in ids: | ||||
|             pipe.get(f"{self.namespace}:{id}") | ||||
|         results = await pipe.execute() | ||||
|  | ||||
|         if fields: | ||||
|             # Filter fields if specified | ||||
|             return [ | ||||
|                 {field: value.get(field) for field in fields if field in value} | ||||
|                 if (value := json.loads(result)) | ||||
|                 else None | ||||
|                 for result in results | ||||
|             ] | ||||
|  | ||||
|         return [json.loads(result) if result else None for result in results] | ||||
|  | ||||
|     async def filter_keys(self, data: list[str]) -> set[str]: | ||||
|         pipe = self._redis.pipeline() | ||||
|         for key in data: | ||||
|             pipe.exists(f"{self.namespace}:{key}") | ||||
|         results = await pipe.execute() | ||||
|  | ||||
|         existing_ids = {data[i] for i, exists in enumerate(results) if exists} | ||||
|         return set(data) - existing_ids | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         pipe = self._redis.pipeline() | ||||
|         for k, v in tqdm_async(data.items(), desc="Upserting"): | ||||
|             pipe.set(f"{self.namespace}:{k}", json.dumps(v)) | ||||
|         await pipe.execute() | ||||
|  | ||||
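|         # Return the stored mapping with each record tagged with its key as "_id" | ||||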
|         for k in data: | ||||
|             data[k]["_id"] = k | ||||
|         return data | ||||
|  | ||||
|     async def drop(self): | ||||
|         keys = await self._redis.keys(f"{self.namespace}:*") | ||||
|         if keys: | ||||
|             await self._redis.delete(*keys) | ||||
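A minimal usage sketch (not part of the commit) for the Redis backend above;
the constructor fields are assumed to follow the `BaseKVStorage` dataclass, and
the namespace and payload are illustrative:

    import asyncio

    async def demo():
        kv = RedisKVStorage(
            namespace="full_docs",  # illustrative namespace
            global_config={},       # assumed sufficient for the KV path
            embedding_func=None,    # unused by this backend
        )
        await kv.upsert({"doc-1": {"content": "hello"}})
        print(await kv.get_by_id("doc-1"))

    asyncio.run(demo())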
665  minirag/kg/tidb_impl.py  (new file)
							| @@ -0,0 +1,665 @@ | ||||
| import asyncio | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from typing import Union | ||||
|  | ||||
| import numpy as np | ||||
| import pipmaster as pm | ||||
|  | ||||
| if not pm.is_installed("pymysql"): | ||||
|     pm.install("pymysql") | ||||
| if not pm.is_installed("sqlalchemy"): | ||||
|     pm.install("sqlalchemy") | ||||
|  | ||||
| from sqlalchemy import create_engine, text | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage | ||||
| from lightrag.utils import logger | ||||
|  | ||||
|  | ||||
| class TiDB(object): | ||||
|     def __init__(self, config, **kwargs): | ||||
|         self.host = config.get("host", None) | ||||
|         self.port = config.get("port", None) | ||||
|         self.user = config.get("user", None) | ||||
|         self.password = config.get("password", None) | ||||
|         self.database = config.get("database", None) | ||||
|         self.workspace = config.get("workspace", None) | ||||
|         connection_string = ( | ||||
|             f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" | ||||
|             f"?ssl_verify_cert=true&ssl_verify_identity=true" | ||||
|         ) | ||||
|  | ||||
|         try: | ||||
|             self.engine = create_engine(connection_string) | ||||
|             logger.info(f"Connected to TiDB database at {self.database}") | ||||
|         except Exception as e: | ||||
|             logger.error(f"Failed to connect to TiDB database at {self.database}") | ||||
|             logger.error(f"TiDB database error: {e}") | ||||
|             raise | ||||
|  | ||||
|     async def check_tables(self): | ||||
|         for k, v in TABLES.items(): | ||||
|             try: | ||||
|                 await self.query(f"SELECT 1 FROM {k}".format(k=k)) | ||||
|             except Exception as e: | ||||
|                 logger.error(f"Failed to check table {k} in TiDB database") | ||||
|                 logger.error(f"TiDB database error: {e}") | ||||
|                 try: | ||||
|                     # print(v["ddl"]) | ||||
|                     await self.execute(v["ddl"]) | ||||
|                     logger.info(f"Created table {k} in TiDB database") | ||||
|                 except Exception as e: | ||||
|                     logger.error(f"Failed to create table {k} in TiDB database") | ||||
|                     logger.error(f"TiDB database error: {e}") | ||||
|  | ||||
|     async def query( | ||||
|         self, sql: str, params: dict = None, multirows: bool = False | ||||
|     ) -> Union[dict, None]: | ||||
|         if params is None: | ||||
|             params = {"workspace": self.workspace} | ||||
|         else: | ||||
|             params.update({"workspace": self.workspace}) | ||||
|         with self.engine.connect() as conn, conn.begin(): | ||||
|             try: | ||||
|                 result = conn.execute(text(sql), params) | ||||
|             except Exception as e: | ||||
|                 logger.error(f"Tidb database error: {e}") | ||||
|                 print(sql) | ||||
|                 print(params) | ||||
|                 raise | ||||
|             if multirows: | ||||
|                 rows = result.all() | ||||
|                 if rows: | ||||
|                     data = [dict(zip(result.keys(), row)) for row in rows] | ||||
|                 else: | ||||
|                     data = [] | ||||
|             else: | ||||
|                 row = result.first() | ||||
|                 if row: | ||||
|                     data = dict(zip(result.keys(), row)) | ||||
|                 else: | ||||
|                     data = None | ||||
|             return data | ||||
|  | ||||
|     async def execute(self, sql: str, data: list | dict = None): | ||||
|         # logger.info("go into TiDBDB execute method") | ||||
|         try: | ||||
|             with self.engine.connect() as conn, conn.begin(): | ||||
|                 if data is None: | ||||
|                     conn.execute(text(sql)) | ||||
|                 else: | ||||
|                     conn.execute(text(sql), parameters=data) | ||||
|         except Exception as e: | ||||
|             logger.error(f"TiDB database error: {e}") | ||||
|             print(sql) | ||||
|             print(data) | ||||
|             raise | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class TiDBKVStorage(BaseKVStorage): | ||||
|     # should pass db object to self.db | ||||
|     def __post_init__(self): | ||||
|         self._data = {} | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|  | ||||
|     ################ QUERY METHODS ################ | ||||
|  | ||||
|     async def get_by_id(self, id: str) -> Union[dict, None]: | ||||
|         """根据 id 获取 doc_full 数据.""" | ||||
|         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] | ||||
|         params = {"id": id} | ||||
|         # print("get_by_id:"+SQL) | ||||
|         res = await self.db.query(SQL, params) | ||||
|         if res: | ||||
|             data = res  # {"data":res} | ||||
|             # print (data) | ||||
|             return data | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     # Query by id | ||||
|     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: | ||||
|         """根据 id 获取 doc_chunks 数据""" | ||||
|         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( | ||||
|             ids=",".join([f"'{id}'" for id in ids]) | ||||
|         ) | ||||
|         # print("get_by_ids:"+SQL) | ||||
|         res = await self.db.query(SQL, multirows=True) | ||||
|         if res: | ||||
|             data = res  # [{"data":i} for i in res] | ||||
|             # print(data) | ||||
|             return data | ||||
|         else: | ||||
|             return None | ||||
|  | ||||
|     async def filter_keys(self, keys: list[str]) -> set[str]: | ||||
|         """过滤掉重复内容""" | ||||
|         SQL = SQL_TEMPLATES["filter_keys"].format( | ||||
|             table_name=N_T[self.namespace], | ||||
|             id_field=N_ID[self.namespace], | ||||
|             ids=",".join([f"'{id}'" for id in keys]), | ||||
|         ) | ||||
|         try: | ||||
|             res = await self.db.query(SQL, multirows=True) | ||||
|         except Exception as e: | ||||
|             logger.error(f"TiDB database error: {e}") | ||||
|             print(SQL) | ||||
|             raise | ||||
|         exist_keys = [key["id"] for key in res] if res else [] | ||||
|         return {s for s in keys if s not in exist_keys} | ||||
|  | ||||
|     ################ INSERT full_doc AND chunks ################ | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         left_data = {k: v for k, v in data.items() if k not in self._data} | ||||
|         self._data.update(left_data) | ||||
|         if self.namespace == "text_chunks": | ||||
|             list_data = [ | ||||
|                 { | ||||
|                     "__id__": k, | ||||
|                     **{k1: v1 for k1, v1 in v.items()}, | ||||
|                 } | ||||
|                 for k, v in data.items() | ||||
|             ] | ||||
|             contents = [v["content"] for v in data.values()] | ||||
|             batches = [ | ||||
|                 contents[i : i + self._max_batch_size] | ||||
|                 for i in range(0, len(contents), self._max_batch_size) | ||||
|             ] | ||||
|             embeddings_list = await asyncio.gather( | ||||
|                 *[self.embedding_func(batch) for batch in batches] | ||||
|             ) | ||||
|             embeddings = np.concatenate(embeddings_list) | ||||
|             for i, d in enumerate(list_data): | ||||
|                 d["__vector__"] = embeddings[i] | ||||
|  | ||||
|             merge_sql = SQL_TEMPLATES["upsert_chunk"] | ||||
|             data = [] | ||||
|             for item in list_data: | ||||
|                 data.append( | ||||
|                     { | ||||
|                         "id": item["__id__"], | ||||
|                         "content": item["content"], | ||||
|                         "tokens": item["tokens"], | ||||
|                         "chunk_order_index": item["chunk_order_index"], | ||||
|                         "full_doc_id": item["full_doc_id"], | ||||
|                         "content_vector": f"{item["__vector__"].tolist()}", | ||||
|                         "workspace": self.db.workspace, | ||||
|                     } | ||||
|                 ) | ||||
|             await self.db.execute(merge_sql, data) | ||||
|  | ||||
|         if self.namespace == "full_docs": | ||||
|             merge_sql = SQL_TEMPLATES["upsert_doc_full"] | ||||
|             data = [] | ||||
|             for k, v in self._data.items(): | ||||
|                 data.append( | ||||
|                     { | ||||
|                         "id": k, | ||||
|                         "content": v["content"], | ||||
|                         "workspace": self.db.workspace, | ||||
|                     } | ||||
|                 ) | ||||
|             await self.db.execute(merge_sql, data) | ||||
|         return left_data | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         if self.namespace in ["full_docs", "text_chunks"]: | ||||
|             logger.info("full doc and chunk data had been saved into TiDB db!") | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class TiDBVectorDBStorage(BaseVectorStorage): | ||||
|     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._client_file_name = os.path.join( | ||||
|             self.global_config["working_dir"], f"vdb_{self.namespace}.json" | ||||
|         ) | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|         # Use global config value if specified, otherwise use default | ||||
|         config = self.global_config.get("vector_db_storage_cls_kwargs", {}) | ||||
|         self.cosine_better_than_threshold = config.get( | ||||
|             "cosine_better_than_threshold", self.cosine_better_than_threshold | ||||
|         ) | ||||
|  | ||||
|     async def query(self, query: str, top_k: int) -> list[dict]: | ||||
|         """search from tidb vector""" | ||||
|  | ||||
|         embeddings = await self.embedding_func([query]) | ||||
|         embedding = embeddings[0] | ||||
|  | ||||
|         embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]" | ||||
|  | ||||
|         params = { | ||||
|             "embedding_string": embedding_string, | ||||
|             "top_k": top_k, | ||||
|             "better_than_threshold": self.cosine_better_than_threshold, | ||||
|         } | ||||
|  | ||||
|         results = await self.db.query( | ||||
|             SQL_TEMPLATES[self.namespace], params=params, multirows=True | ||||
|         ) | ||||
|         print("vector search result:", results) | ||||
|         if not results: | ||||
|             return [] | ||||
|         return results | ||||
|  | ||||
|     ###### INSERT entities And relationships ###### | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         if not len(data): | ||||
|             logger.warning("Attempted to insert empty data into the vector DB") | ||||
|             return [] | ||||
|         if self.namespace == "chunks": | ||||
|             # Chunks are already written by TiDBKVStorage.upsert, so skip them here | ||||
|             return [] | ||||
|         logger.info(f"Inserting {len(data)} vectors to {self.namespace}") | ||||
|  | ||||
|         list_data = [ | ||||
|             { | ||||
|                 "id": k, | ||||
|                 **{k1: v1 for k1, v1 in v.items()}, | ||||
|             } | ||||
|             for k, v in data.items() | ||||
|         ] | ||||
|         contents = [v["content"] for v in data.values()] | ||||
|         batches = [ | ||||
|             contents[i : i + self._max_batch_size] | ||||
|             for i in range(0, len(contents), self._max_batch_size) | ||||
|         ] | ||||
|         embedding_tasks = [self.embedding_func(batch) for batch in batches] | ||||
|         embeddings_list = [] | ||||
|         for f in tqdm( | ||||
|             asyncio.as_completed(embedding_tasks), | ||||
|             total=len(embedding_tasks), | ||||
|             desc="Generating embeddings", | ||||
|             unit="batch", | ||||
|         ): | ||||
|             embeddings = await f | ||||
|             embeddings_list.append(embeddings) | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         for i, d in enumerate(list_data): | ||||
|             d["content_vector"] = embeddings[i] | ||||
|  | ||||
|         if self.namespace == "entities": | ||||
|             data = [] | ||||
|             for item in list_data: | ||||
|                 param = { | ||||
|                     "id": item["id"], | ||||
|                     "name": item["entity_name"], | ||||
|                     "content": item["content"], | ||||
|                     "content_vector": f"{item["content_vector"].tolist()}", | ||||
|                     "workspace": self.db.workspace, | ||||
|                 } | ||||
|                 # update entity_id if node inserted by graph_storage_instance before | ||||
|                 has = await self.db.query(SQL_TEMPLATES["has_entity"], param) | ||||
|                 if has["cnt"] != 0: | ||||
|                     await self.db.execute(SQL_TEMPLATES["update_entity"], param) | ||||
|                     continue | ||||
|  | ||||
|                 data.append(param) | ||||
|             if data: | ||||
|                 merge_sql = SQL_TEMPLATES["insert_entity"] | ||||
|                 await self.db.execute(merge_sql, data) | ||||
|  | ||||
|         elif self.namespace == "relationships": | ||||
|             data = [] | ||||
|             for item in list_data: | ||||
|                 param = { | ||||
|                     "id": item["id"], | ||||
|                     "source_name": item["src_id"], | ||||
|                     "target_name": item["tgt_id"], | ||||
|                     "content": item["content"], | ||||
|                     "content_vector": f"{item["content_vector"].tolist()}", | ||||
|                     "workspace": self.db.workspace, | ||||
|                 } | ||||
|                 # update relation_id if node inserted by graph_storage_instance before | ||||
|                 has = await self.db.query(SQL_TEMPLATES["has_relationship"], param) | ||||
|                 if has["cnt"] != 0: | ||||
|                     await self.db.execute(SQL_TEMPLATES["update_relationship"], param) | ||||
|                     continue | ||||
|  | ||||
|                 data.append(param) | ||||
|             if data: | ||||
|                 merge_sql = SQL_TEMPLATES["insert_relationship"] | ||||
|                 await self.db.execute(merge_sql, data) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class TiDBGraphStorage(BaseGraphStorage): | ||||
|     def __post_init__(self): | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|  | ||||
|     #################### upsert method ################ | ||||
|     async def upsert_node(self, node_id: str, node_data: dict[str, str]): | ||||
|         entity_name = node_id | ||||
|         entity_type = node_data["entity_type"] | ||||
|         description = node_data["description"] | ||||
|         source_id = node_data["source_id"] | ||||
|         logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}") | ||||
|         content = entity_name + description | ||||
|         contents = [content] | ||||
|         batches = [ | ||||
|             contents[i : i + self._max_batch_size] | ||||
|             for i in range(0, len(contents), self._max_batch_size) | ||||
|         ] | ||||
|         embeddings_list = await asyncio.gather( | ||||
|             *[self.embedding_func(batch) for batch in batches] | ||||
|         ) | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         content_vector = embeddings[0] | ||||
|         sql = SQL_TEMPLATES["upsert_node"] | ||||
|         data = { | ||||
|             "workspace": self.db.workspace, | ||||
|             "name": entity_name, | ||||
|             "entity_type": entity_type, | ||||
|             "description": description, | ||||
|             "source_chunk_id": source_id, | ||||
|             "content": content, | ||||
|             "content_vector": f"{content_vector.tolist()}", | ||||
|         } | ||||
|         await self.db.execute(sql, data) | ||||
|  | ||||
|     async def upsert_edge( | ||||
|         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] | ||||
|     ): | ||||
|         source_name = source_node_id | ||||
|         target_name = target_node_id | ||||
|         weight = edge_data["weight"] | ||||
|         keywords = edge_data["keywords"] | ||||
|         description = edge_data["description"] | ||||
|         source_chunk_id = edge_data["source_id"] | ||||
|         logger.debug( | ||||
|             f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}" | ||||
|         ) | ||||
|  | ||||
|         content = keywords + source_name + target_name + description | ||||
|         contents = [content] | ||||
|         batches = [ | ||||
|             contents[i : i + self._max_batch_size] | ||||
|             for i in range(0, len(contents), self._max_batch_size) | ||||
|         ] | ||||
|         embeddings_list = await asyncio.gather( | ||||
|             *[self.embedding_func(batch) for batch in batches] | ||||
|         ) | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         content_vector = embeddings[0] | ||||
|         merge_sql = SQL_TEMPLATES["upsert_edge"] | ||||
|         data = { | ||||
|             "workspace": self.db.workspace, | ||||
|             "source_name": source_name, | ||||
|             "target_name": target_name, | ||||
|             "weight": weight, | ||||
|             "keywords": keywords, | ||||
|             "description": description, | ||||
|             "source_chunk_id": source_chunk_id, | ||||
|             "content": content, | ||||
|             "content_vector": f"{content_vector.tolist()}", | ||||
|         } | ||||
|         await self.db.execute(merge_sql, data) | ||||
|  | ||||
|     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: | ||||
|         if algorithm not in self._node_embed_algorithms: | ||||
|             raise ValueError(f"Node embedding algorithm {algorithm} not supported") | ||||
|         return await self._node_embed_algorithms[algorithm]() | ||||
|  | ||||
|     # Query | ||||
|  | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         sql = SQL_TEMPLATES["has_entity"] | ||||
|         param = {"name": node_id, "workspace": self.db.workspace} | ||||
|         has = await self.db.query(sql, param) | ||||
|         return has["cnt"] != 0 | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         sql = SQL_TEMPLATES["has_relationship"] | ||||
|         param = { | ||||
|             "source_name": source_node_id, | ||||
|             "target_name": target_node_id, | ||||
|             "workspace": self.db.workspace, | ||||
|         } | ||||
|         has = await self.db.query(sql, param) | ||||
|         return has["cnt"] != 0 | ||||
|  | ||||
|     async def node_degree(self, node_id: str) -> int: | ||||
|         sql = SQL_TEMPLATES["node_degree"] | ||||
|         param = {"name": node_id, "workspace": self.db.workspace} | ||||
|         result = await self.db.query(sql, param) | ||||
|         return result["cnt"] | ||||
|  | ||||
|     async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||||
|         degree = await self.node_degree(src_id) + await self.node_degree(tgt_id) | ||||
|         return degree | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         sql = SQL_TEMPLATES["get_node"] | ||||
|         param = {"name": node_id, "workspace": self.db.workspace} | ||||
|         return await self.db.query(sql, param) | ||||
|  | ||||
|     async def get_edge( | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         sql = SQL_TEMPLATES["get_edge"] | ||||
|         param = { | ||||
|             "source_name": source_node_id, | ||||
|             "target_name": target_node_id, | ||||
|             "workspace": self.db.workspace, | ||||
|         } | ||||
|         return await self.db.query(sql, param) | ||||
|  | ||||
|     async def get_node_edges( | ||||
|         self, source_node_id: str | ||||
|     ) -> Union[list[tuple[str, str]], None]: | ||||
|         sql = SQL_TEMPLATES["get_node_edges"] | ||||
|         param = {"source_name": source_node_id, "workspace": self.db.workspace} | ||||
|         res = await self.db.query(sql, param, multirows=True) | ||||
|         if res: | ||||
|             data = [(i["source_name"], i["target_name"]) for i in res] | ||||
|             return data | ||||
|         else: | ||||
|             return [] | ||||
|  | ||||
|  | ||||
| N_T = { | ||||
|     "full_docs": "LIGHTRAG_DOC_FULL", | ||||
|     "text_chunks": "LIGHTRAG_DOC_CHUNKS", | ||||
|     "chunks": "LIGHTRAG_DOC_CHUNKS", | ||||
|     "entities": "LIGHTRAG_GRAPH_NODES", | ||||
|     "relationships": "LIGHTRAG_GRAPH_EDGES", | ||||
| } | ||||
| N_ID = { | ||||
|     "full_docs": "doc_id", | ||||
|     "text_chunks": "chunk_id", | ||||
|     "chunks": "chunk_id", | ||||
|     "entities": "entity_id", | ||||
|     "relationships": "relation_id", | ||||
| } | ||||
|  | ||||
| TABLES = { | ||||
|     "LIGHTRAG_DOC_FULL": { | ||||
|         "ddl": """ | ||||
|         CREATE TABLE LIGHTRAG_DOC_FULL ( | ||||
|             `id` BIGINT PRIMARY KEY AUTO_RANDOM, | ||||
|             `doc_id` VARCHAR(256) NOT NULL, | ||||
|             `workspace` varchar(1024), | ||||
|             `content` LONGTEXT, | ||||
|             `meta` JSON, | ||||
|             `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | ||||
|             `updatetime` TIMESTAMP DEFAULT NULL, | ||||
|             UNIQUE KEY (`doc_id`) | ||||
|         ); | ||||
|         """ | ||||
|     }, | ||||
|     "LIGHTRAG_DOC_CHUNKS": { | ||||
|         "ddl": """ | ||||
|         CREATE TABLE LIGHTRAG_DOC_CHUNKS ( | ||||
|             `id` BIGINT PRIMARY KEY AUTO_RANDOM, | ||||
|             `chunk_id` VARCHAR(256) NOT NULL, | ||||
|             `full_doc_id` VARCHAR(256) NOT NULL, | ||||
|             `workspace` varchar(1024), | ||||
|             `chunk_order_index` INT, | ||||
|             `tokens` INT, | ||||
|             `content` LONGTEXT, | ||||
|             `content_vector` VECTOR, | ||||
|             `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP, | ||||
|             `updatetime` DATETIME DEFAULT NULL, | ||||
|             UNIQUE KEY (`chunk_id`) | ||||
|         ); | ||||
|         """ | ||||
|     }, | ||||
|     "LIGHTRAG_GRAPH_NODES": { | ||||
|         "ddl": """ | ||||
|         CREATE TABLE LIGHTRAG_GRAPH_NODES ( | ||||
|             `id` BIGINT PRIMARY KEY AUTO_RANDOM, | ||||
|             `entity_id`  VARCHAR(256), | ||||
|             `workspace` varchar(1024), | ||||
|             `name` VARCHAR(2048), | ||||
|             `entity_type` VARCHAR(1024), | ||||
|             `description` LONGTEXT, | ||||
|             `source_chunk_id` VARCHAR(256), | ||||
|             `content` LONGTEXT, | ||||
|             `content_vector` VECTOR, | ||||
|             `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP, | ||||
|             `updatetime` DATETIME DEFAULT NULL, | ||||
|             KEY (`entity_id`) | ||||
|         ); | ||||
|         """ | ||||
|     }, | ||||
|     "LIGHTRAG_GRAPH_EDGES": { | ||||
|         "ddl": """ | ||||
|         CREATE TABLE LIGHTRAG_GRAPH_EDGES ( | ||||
|             `id` BIGINT PRIMARY KEY AUTO_RANDOM, | ||||
|             `relation_id`  VARCHAR(256), | ||||
|             `workspace` varchar(1024), | ||||
|             `source_name` VARCHAR(2048), | ||||
|             `target_name` VARCHAR(2048), | ||||
|             `weight` DECIMAL, | ||||
|             `keywords` TEXT, | ||||
|             `description` LONGTEXT, | ||||
|             `source_chunk_id` varchar(256), | ||||
|             `content` LONGTEXT, | ||||
|             `content_vector` VECTOR, | ||||
|             `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP, | ||||
|             `updatetime` DATETIME DEFAULT NULL, | ||||
|             KEY (`relation_id`) | ||||
|         ); | ||||
|         """ | ||||
|     }, | ||||
|     "LIGHTRAG_LLM_CACHE": { | ||||
|         "ddl": """ | ||||
|         CREATE TABLE LIGHTRAG_LLM_CACHE ( | ||||
|             `id` BIGINT PRIMARY KEY AUTO_INCREMENT, | ||||
|             `send` TEXT, | ||||
|             `return` TEXT, | ||||
|             `model` VARCHAR(1024), | ||||
|             `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP, | ||||
|             `updatetime` DATETIME DEFAULT NULL | ||||
|         ); | ||||
|         """ | ||||
|     }, | ||||
| } | ||||
|  | ||||
|  | ||||
| SQL_TEMPLATES = { | ||||
|     # SQL for KVStorage | ||||
|     "get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace", | ||||
|     "get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace", | ||||
|     "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace", | ||||
|     "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace", | ||||
|     "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace", | ||||
|     # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE) | ||||
|     "upsert_doc_full": """ | ||||
|         INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace) | ||||
|         VALUES (:id, :content, :workspace) | ||||
|         ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP | ||||
|     """, | ||||
|     "upsert_chunk": """ | ||||
|         INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace) | ||||
|         VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace) | ||||
|         ON DUPLICATE KEY UPDATE | ||||
|         content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index), | ||||
|         full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP | ||||
|     """, | ||||
|     # SQL for VectorStorage | ||||
|     "entities": """SELECT n.name as entity_name FROM | ||||
|         (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance | ||||
|         FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n | ||||
|         WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k | ||||
|     """, | ||||
|     "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM | ||||
|         (SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance | ||||
|         FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e | ||||
|         WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k | ||||
|     """, | ||||
|     "chunks": """SELECT c.id FROM | ||||
|         (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance | ||||
|         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c | ||||
|         WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k | ||||
|     """, | ||||
|     "has_entity": """ | ||||
|         SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace | ||||
|     """, | ||||
|     "has_relationship": """ | ||||
|         SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace | ||||
|     """, | ||||
|     "update_entity": """ | ||||
|         UPDATE LIGHTRAG_GRAPH_NODES SET | ||||
|             entity_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP | ||||
|         WHERE workspace = :workspace AND name = :name | ||||
|     """, | ||||
|     "update_relationship": """ | ||||
|         UPDATE LIGHTRAG_GRAPH_EDGES SET | ||||
|             relation_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP | ||||
|         WHERE workspace = :workspace AND source_name = :source_name AND target_name = :target_name | ||||
|     """, | ||||
|     "insert_entity": """ | ||||
|         INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace) | ||||
|         VALUES(:id, :name, :content, :content_vector, :workspace) | ||||
|     """, | ||||
|     "insert_relationship": """ | ||||
|         INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace) | ||||
|         VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace) | ||||
|     """, | ||||
|     # SQL for GraphStorage | ||||
|     "get_node": """ | ||||
|         SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector | ||||
|         FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace | ||||
|     """, | ||||
|     "get_edge": """ | ||||
|         SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector | ||||
|         FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace | ||||
|     """, | ||||
|     "get_node_edges": """ | ||||
|         SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector | ||||
|         FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace | ||||
|     """, | ||||
|     "node_degree": """ | ||||
|         SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name) | ||||
|     """, | ||||
|     "upsert_node": """ | ||||
|         INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description) | ||||
|         VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description) | ||||
|         ON DUPLICATE KEY UPDATE | ||||
|         name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector), | ||||
|         workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, | ||||
|         source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description) | ||||
|     """, | ||||
|     "upsert_edge": """ | ||||
|         INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector, | ||||
|             workspace, weight, keywords, description, source_chunk_id) | ||||
|         VALUES(:source_name, :target_name, :content, :content_vector, | ||||
|             :workspace, :weight, :keywords, :description, :source_chunk_id) | ||||
|         ON DUPLICATE KEY UPDATE | ||||
|         source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content), | ||||
|         content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP, | ||||
|         weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description), | ||||
|         source_chunk_id = VALUES(source_chunk_id) | ||||
|     """, | ||||
| } | ||||
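For orientation, these templates use named binds (:workspace, :top_k) and MySQL-style ON DUPLICATE KEY UPDATE, which fits a SQLAlchemy-over-TiDB setup. Below is a minimal sketch of how the "chunks" query might be executed; it assumes the dict above is named SQL_TEMPLATES and uses a placeholder DSN and threshold (both illustrative, not part of the commit):

from sqlalchemy import create_engine, text

# Placeholder DSN: TiDB speaks the MySQL wire protocol.
engine = create_engine("mysql+pymysql://user:pass@127.0.0.1:4000/rag")

def query_similar_chunks(embedding_string: str, workspace: str, top_k: int = 5):
    # Bind the same named parameters the template declares
    # (:embedding_string, :workspace, :better_than_threshold, :top_k).
    with engine.connect() as conn:
        rows = conn.execute(
            text(SQL_TEMPLATES["chunks"]),
            {
                "embedding_string": embedding_string,
                "workspace": workspace,
                "better_than_threshold": 0.2,
                "top_k": top_k,
            },
        )
        return [row.id for row in rows]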
							
								
								
									
minirag/llm.py (730 lines changed)
							| @@ -1,729 +1,5 @@ | ||||
| import os | ||||
| import copy | ||||
| from functools import lru_cache | ||||
| import json | ||||
| import aioboto3 | ||||
| import aiohttp | ||||
| import numpy as np | ||||
| import ollama | ||||
|  | ||||
| from openai import ( | ||||
|     AsyncOpenAI, | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     Timeout, | ||||
|     AsyncAzureOpenAI, | ||||
| ) | ||||
|  | ||||
| import base64 | ||||
| import struct | ||||
|  | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||||
| import torch | ||||
| from pydantic import BaseModel, Field | ||||
| from typing import List, Dict, Callable, Any | ||||
| from .base import BaseKVStorage | ||||
| from .utils import compute_args_hash, wrap_embedding_func_with_attrs | ||||
|  | ||||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), | ||||
| ) | ||||
| async def openai_complete_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     base_url=None, | ||||
|     api_key=None, | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     if api_key: | ||||
|         os.environ["OPENAI_API_KEY"] = api_key | ||||
|  | ||||
|     openai_async_client = ( | ||||
|         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) | ||||
|     ) | ||||
|     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|     if hashing_kv is not None: | ||||
|         args_hash = compute_args_hash(model, messages) | ||||
|         if_cache_return = await hashing_kv.get_by_id(args_hash) | ||||
|         if if_cache_return is not None: | ||||
|             return if_cache_return["return"] | ||||
|  | ||||
|     response = await openai_async_client.chat.completions.create( | ||||
|         model=model, messages=messages, **kwargs | ||||
|     ) | ||||
|  | ||||
|     if hashing_kv is not None: | ||||
|         await hashing_kv.upsert( | ||||
|             {args_hash: {"return": response.choices[0].message.content, "model": model}} | ||||
|         ) | ||||
|      | ||||
|     return response.choices[0].message.content | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), | ||||
| ) | ||||
| async def azure_openai_complete_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     base_url=None, | ||||
|     api_key=None, | ||||
|     **kwargs, | ||||
| ): | ||||
|     if api_key: | ||||
|         os.environ["AZURE_OPENAI_API_KEY"] = api_key | ||||
|     if base_url: | ||||
|         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url | ||||
|  | ||||
|     openai_async_client = AsyncAzureOpenAI( | ||||
|         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||||
|         api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||||
|         api_version=os.getenv("AZURE_OPENAI_API_VERSION"), | ||||
|     ) | ||||
|  | ||||
|     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     if prompt is not None: | ||||
|         messages.append({"role": "user", "content": prompt}) | ||||
|     if hashing_kv is not None: | ||||
|         args_hash = compute_args_hash(model, messages) | ||||
|         if_cache_return = await hashing_kv.get_by_id(args_hash) | ||||
|         if if_cache_return is not None: | ||||
|             return if_cache_return["return"] | ||||
|  | ||||
|     response = await openai_async_client.chat.completions.create( | ||||
|         model=model, messages=messages, **kwargs | ||||
|     ) | ||||
|  | ||||
|     if hashing_kv is not None: | ||||
|         await hashing_kv.upsert( | ||||
|             {args_hash: {"return": response.choices[0].message.content, "model": model}} | ||||
|         ) | ||||
|     return response.choices[0].message.content | ||||
|  | ||||
|  | ||||
| class BedrockError(Exception): | ||||
|     """Generic error for issues related to Amazon Bedrock""" | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(5), | ||||
|     wait=wait_exponential(multiplier=1, max=60), | ||||
|     retry=retry_if_exception_type(BedrockError), | ||||
| ) | ||||
| async def bedrock_complete_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     aws_access_key_id=None, | ||||
|     aws_secret_access_key=None, | ||||
|     aws_session_token=None, | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( | ||||
|         "AWS_ACCESS_KEY_ID", aws_access_key_id | ||||
|     ) | ||||
|     os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( | ||||
|         "AWS_SECRET_ACCESS_KEY", aws_secret_access_key | ||||
|     ) | ||||
|     os.environ["AWS_SESSION_TOKEN"] = os.environ.get( | ||||
|         "AWS_SESSION_TOKEN", aws_session_token | ||||
|     ) | ||||
|  | ||||
|     # Fix message history format | ||||
|     messages = [] | ||||
|     for history_message in history_messages: | ||||
|         message = copy.copy(history_message) | ||||
|         message["content"] = [{"text": message["content"]}] | ||||
|         messages.append(message) | ||||
|  | ||||
|     # Add user prompt | ||||
|     messages.append({"role": "user", "content": [{"text": prompt}]}) | ||||
|  | ||||
|     # Initialize Converse API arguments | ||||
|     args = {"modelId": model, "messages": messages} | ||||
|  | ||||
|     # Define system prompt | ||||
|     if system_prompt: | ||||
|         args["system"] = [{"text": system_prompt}] | ||||
|  | ||||
|     # Map and set up inference parameters | ||||
|     inference_params_map = { | ||||
|         "max_tokens": "maxTokens", | ||||
|         "top_p": "topP", | ||||
|         "stop_sequences": "stopSequences", | ||||
|     } | ||||
|     if inference_params := list( | ||||
|         set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) | ||||
|     ): | ||||
|         args["inferenceConfig"] = {} | ||||
|         for param in inference_params: | ||||
|             args["inferenceConfig"][inference_params_map.get(param, param)] = ( | ||||
|                 kwargs.pop(param) | ||||
|             ) | ||||
|  | ||||
|     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) | ||||
|     if hashing_kv is not None: | ||||
|         args_hash = compute_args_hash(model, messages) | ||||
|         if_cache_return = await hashing_kv.get_by_id(args_hash) | ||||
|         if if_cache_return is not None: | ||||
|             return if_cache_return["return"] | ||||
|  | ||||
|     # Call model via Converse API | ||||
|     session = aioboto3.Session() | ||||
|     async with session.client("bedrock-runtime") as bedrock_async_client: | ||||
|         try: | ||||
|             response = await bedrock_async_client.converse(**args, **kwargs) | ||||
|         except Exception as e: | ||||
|             raise BedrockError(e) | ||||
|  | ||||
|         if hashing_kv is not None: | ||||
|             await hashing_kv.upsert( | ||||
|                 { | ||||
|                     args_hash: { | ||||
|                         "return": response["output"]["message"]["content"][0]["text"], | ||||
|                         "model": model, | ||||
|                     } | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|         return response["output"]["message"]["content"][0]["text"] | ||||
|  | ||||
|  | ||||
| @lru_cache(maxsize=1) | ||||
| def initialize_hf_model(model_name): | ||||
|     hf_tokenizer = AutoTokenizer.from_pretrained( | ||||
|         model_name, device_map="auto", trust_remote_code=True | ||||
|     ) | ||||
|     hf_model = AutoModelForCausalLM.from_pretrained( | ||||
|         model_name, device_map="auto", trust_remote_code=True | ||||
|     ) | ||||
|     if hf_tokenizer.pad_token is None: | ||||
|         hf_tokenizer.pad_token = hf_tokenizer.eos_token | ||||
|  | ||||
|     return hf_model, hf_tokenizer | ||||
|  | ||||
|  | ||||
| async def hf_model_if_cache( | ||||
|     model, prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     model_name = model | ||||
|     hf_model, hf_tokenizer = initialize_hf_model(model_name) | ||||
|     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|  | ||||
|     if hashing_kv is not None: | ||||
|         args_hash = compute_args_hash(model, messages) | ||||
|         if_cache_return = await hashing_kv.get_by_id(args_hash) | ||||
|         if if_cache_return is not None: | ||||
|             return if_cache_return["return"] | ||||
|     input_prompt = "" | ||||
|     try: | ||||
|         input_prompt = hf_tokenizer.apply_chat_template( | ||||
|             messages, tokenize=False, add_generation_prompt=True | ||||
|         ) | ||||
|     except Exception: | ||||
|         try: | ||||
|             ori_message = copy.deepcopy(messages) | ||||
|             if messages[0]["role"] == "system": | ||||
|                 messages[1]["content"] = ( | ||||
|                     "<system>" | ||||
|                     + messages[0]["content"] | ||||
|                     + "</system>\n" | ||||
|                     + messages[1]["content"] | ||||
|                 ) | ||||
|                 messages = messages[1:] | ||||
|                 input_prompt = hf_tokenizer.apply_chat_template( | ||||
|                     messages, tokenize=False, add_generation_prompt=True | ||||
|                 ) | ||||
|         except Exception: | ||||
|             len_message = len(ori_message) | ||||
|             for msgid in range(len_message): | ||||
|                 input_prompt = ( | ||||
|                     input_prompt | ||||
|                     + "<" | ||||
|                     + ori_message[msgid]["role"] | ||||
|                     + ">" | ||||
|                     + ori_message[msgid]["content"] | ||||
|                     + "</" | ||||
|                     + ori_message[msgid]["role"] | ||||
|                     + ">\n" | ||||
|                 ) | ||||
|  | ||||
|     input_ids = hf_tokenizer( | ||||
|         input_prompt, return_tensors="pt", padding=True, truncation=True | ||||
|     ).to("cuda") | ||||
|     torch.cuda.empty_cache() | ||||
|     # inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()} | ||||
|     output = hf_model.generate( | ||||
|         **input_ids, max_new_tokens=500, num_return_sequences=1, early_stopping=True | ||||
|     ) | ||||
|     response_text = hf_tokenizer.decode( | ||||
|         output[0][len(input_ids[0]) :], skip_special_tokens=True | ||||
|     ) | ||||
|  | ||||
|     # Truncate at the completion marker, if the model emitted one | ||||
|     FINDSTRING = "<|COMPLETE|>" | ||||
|     last_assistant_index = response_text.find(FINDSTRING) | ||||
|     if last_assistant_index != -1: | ||||
|         response_text = response_text[: last_assistant_index + len(FINDSTRING)] | ||||
|  | ||||
|     if hashing_kv is not None: | ||||
|         await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) | ||||
|  | ||||
|     return response_text | ||||
|  | ||||
|  | ||||
| async def ollama_model_if_cache( | ||||
|     model, prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     kwargs.pop("max_tokens", None) | ||||
|     kwargs.pop("response_format", None) | ||||
|     host = kwargs.pop("host", None) | ||||
|     timeout = kwargs.pop("timeout", None) | ||||
|  | ||||
|     ollama_client = ollama.AsyncClient(host=host, timeout=timeout) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|  | ||||
|     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|     if hashing_kv is not None: | ||||
|         args_hash = compute_args_hash(model, messages) | ||||
|         if_cache_return = await hashing_kv.get_by_id(args_hash) | ||||
|         if if_cache_return is not None: | ||||
|             return if_cache_return["return"] | ||||
|  | ||||
|     response = await ollama_client.chat(model=model, messages=messages, **kwargs) | ||||
|  | ||||
|     result = response["message"]["content"] | ||||
|  | ||||
|     if hashing_kv is not None: | ||||
|         await hashing_kv.upsert({args_hash: {"return": result, "model": model}}) | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| @lru_cache(maxsize=1) | ||||
| def initialize_lmdeploy_pipeline( | ||||
|     model, | ||||
|     tp=1, | ||||
|     chat_template=None, | ||||
|     log_level="WARNING", | ||||
|     model_format="hf", | ||||
|     quant_policy=0, | ||||
| ): | ||||
|     from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig | ||||
|  | ||||
|     lmdeploy_pipe = pipeline( | ||||
|         model_path=model, | ||||
|         backend_config=TurbomindEngineConfig( | ||||
|             tp=tp, model_format=model_format, quant_policy=quant_policy | ||||
|         ), | ||||
|         chat_template_config=ChatTemplateConfig(model_name=chat_template) | ||||
|         if chat_template | ||||
|         else None, | ||||
|         log_level="WARNING", | ||||
|     ) | ||||
|     return lmdeploy_pipe | ||||
|  | ||||
|  | ||||
| async def lmdeploy_model_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     chat_template=None, | ||||
|     model_format="hf", | ||||
|     quant_policy=0, | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     """ | ||||
|     Args: | ||||
|         model (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                     - i) A local directory path of a turbomind model which is | ||||
|                         converted by `lmdeploy convert` command or download | ||||
|                         from ii) and iii). | ||||
|                     - ii) The model_id of a lmdeploy-quantized model hosted | ||||
|                         inside a model repo on huggingface.co, such as | ||||
|                         "InternLM/internlm-chat-20b-4bit", | ||||
|                         "lmdeploy/llama2-chat-70b-4bit", etc. | ||||
|                     - iii) The model_id of a model hosted inside a model repo | ||||
|                         on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                         "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                         and so on. | ||||
|         chat_template (str): needed when model is a pytorch model on | ||||
|             huggingface.co, such as "internlm-chat-7b", | ||||
|             "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, | ||||
|             and when the model name of local path did not match the original model name in HF. | ||||
|         tp (int): tensor parallel | ||||
|         prompt (Union[str, List[str]]): input texts to be completed. | ||||
|         do_preprocess (bool): whether pre-process the messages. Default to | ||||
|             True, which means chat_template will be applied. | ||||
|         skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|             in the decoding. Default to be True. | ||||
|         do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. | ||||
|             Default to be False, which means greedy decoding will be applied. | ||||
|     """ | ||||
|     try: | ||||
|         import lmdeploy | ||||
|         from lmdeploy import version_info, GenerationConfig | ||||
|     except Exception: | ||||
|         raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.") | ||||
|  | ||||
|     kwargs.pop("response_format", None) | ||||
|     max_new_tokens = kwargs.pop("max_tokens", 512) | ||||
|     tp = kwargs.pop("tp", 1) | ||||
|     skip_special_tokens = kwargs.pop("skip_special_tokens", True) | ||||
|     do_preprocess = kwargs.pop("do_preprocess", True) | ||||
|     do_sample = kwargs.pop("do_sample", False) | ||||
|     gen_params = kwargs | ||||
|  | ||||
|     version = version_info | ||||
|     if version < (0, 6, 0): | ||||
|         if do_sample: | ||||
|             raise RuntimeError( | ||||
|                 "`do_sample` parameter is not supported by lmdeploy until " | ||||
|                 f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}" | ||||
|             ) | ||||
|     else: | ||||
|         # Pass the caller's choice through; greedy decoding stays the default | ||||
|         gen_params.update(do_sample=do_sample) | ||||
|  | ||||
|     lmdeploy_pipe = initialize_lmdeploy_pipeline( | ||||
|         model=model, | ||||
|         tp=tp, | ||||
|         chat_template=chat_template, | ||||
|         model_format=model_format, | ||||
|         quant_policy=quant_policy, | ||||
|         log_level="WARNING", | ||||
|     ) | ||||
|  | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|  | ||||
|     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|     if hashing_kv is not None: | ||||
|         args_hash = compute_args_hash(model, messages) | ||||
|         if_cache_return = await hashing_kv.get_by_id(args_hash) | ||||
|         if if_cache_return is not None: | ||||
|             return if_cache_return["return"] | ||||
|  | ||||
|     gen_config = GenerationConfig( | ||||
|         skip_special_tokens=skip_special_tokens, | ||||
|         max_new_tokens=max_new_tokens, | ||||
|         **gen_params, | ||||
|     ) | ||||
|  | ||||
|     response = "" | ||||
|     async for res in lmdeploy_pipe.generate( | ||||
|         messages, | ||||
|         gen_config=gen_config, | ||||
|         do_preprocess=do_preprocess, | ||||
|         stream_response=False, | ||||
|         session_id=1, | ||||
|     ): | ||||
|         response += res.response | ||||
|  | ||||
|     if hashing_kv is not None: | ||||
|         await hashing_kv.upsert({args_hash: {"return": response, "model": model}}) | ||||
|     return response | ||||
|  | ||||
|  | ||||
| async def gpt_4o_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     return await openai_complete_if_cache( | ||||
|         "gpt-4o", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def gpt_4o_mini_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     return await openai_complete_if_cache( | ||||
|         "gpt-4o-mini", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def azure_openai_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     return await azure_openai_complete_if_cache( | ||||
|         "conversation-4o-mini", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def bedrock_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     return await bedrock_complete_if_cache( | ||||
|         "anthropic.claude-3-haiku-20240307-v1:0", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def hf_model_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     model_name = kwargs["hashing_kv"].global_config["llm_model_name"] | ||||
|     return await hf_model_if_cache( | ||||
|         model_name, | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def ollama_model_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], **kwargs | ||||
| ) -> str: | ||||
|     model_name = kwargs["hashing_kv"].global_config["llm_model_name"] | ||||
|     return await ollama_model_if_cache( | ||||
|         model_name, | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=60), | ||||
|     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), | ||||
| ) | ||||
| async def openai_embedding( | ||||
|     texts: list[str], | ||||
|     model: str = "text-embedding-3-small", | ||||
|     base_url: str = None, | ||||
|     api_key: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key: | ||||
|         os.environ["OPENAI_API_KEY"] = api_key | ||||
|  | ||||
|     openai_async_client = ( | ||||
|         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) | ||||
|     ) | ||||
|     response = await openai_async_client.embeddings.create( | ||||
|         model=model, input=texts, encoding_format="float" | ||||
|     ) | ||||
|     return np.array([dp.embedding for dp in response.data]) | ||||
|  | ||||
|  | ||||
| @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), | ||||
| ) | ||||
| async def azure_openai_embedding( | ||||
|     texts: list[str], | ||||
|     model: str = "text-embedding-3-small", | ||||
|     base_url: str = None, | ||||
|     api_key: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key: | ||||
|         os.environ["AZURE_OPENAI_API_KEY"] = api_key | ||||
|     if base_url: | ||||
|         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url | ||||
|  | ||||
|     openai_async_client = AsyncAzureOpenAI( | ||||
|         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||||
|         api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||||
|         api_version=os.getenv("AZURE_OPENAI_API_VERSION"), | ||||
|     ) | ||||
|  | ||||
|     response = await openai_async_client.embeddings.create( | ||||
|         model=model, input=texts, encoding_format="float" | ||||
|     ) | ||||
|     return np.array([dp.embedding for dp in response.data]) | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=60), | ||||
|     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), | ||||
| ) | ||||
| async def siliconcloud_embedding( | ||||
|     texts: list[str], | ||||
|     model: str = "netease-youdao/bce-embedding-base_v1", | ||||
|     base_url: str = "https://api.siliconflow.cn/v1/embeddings", | ||||
|     max_token_size: int = 512, | ||||
|     api_key: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key and not api_key.startswith("Bearer "): | ||||
|         api_key = "Bearer " + api_key | ||||
|  | ||||
|     headers = {"Authorization": api_key, "Content-Type": "application/json"} | ||||
|  | ||||
|     truncate_texts = [text[0:max_token_size] for text in texts] | ||||
|  | ||||
|     payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} | ||||
|  | ||||
|     base64_strings = [] | ||||
|     async with aiohttp.ClientSession() as session: | ||||
|         async with session.post(base_url, headers=headers, json=payload) as response: | ||||
|             content = await response.json() | ||||
|             if "code" in content: | ||||
|                 raise ValueError(content) | ||||
|             base64_strings = [item["embedding"] for item in content["data"]] | ||||
|  | ||||
|     embeddings = [] | ||||
|     for string in base64_strings: | ||||
|         decode_bytes = base64.b64decode(string) | ||||
|         n = len(decode_bytes) // 4 | ||||
|         float_array = struct.unpack("<" + "f" * n, decode_bytes) | ||||
|         embeddings.append(float_array) | ||||
|     return np.array(embeddings) | ||||
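The base64 branch above relies on the API returning each vector as base64-encoded little-endian float32 bytes. A self-contained round-trip with made-up numbers shows the decoding the loop performs:

import base64
import struct

import numpy as np

# Made-up vector; encode it the way the API would ship it.
vec = [0.25, -1.0, 3.5]
encoded = base64.b64encode(struct.pack("<3f", *vec)).decode()

# Decode exactly as siliconcloud_embedding does: 4 bytes per float32.
decoded_bytes = base64.b64decode(encoded)
n = len(decoded_bytes) // 4
restored = np.array(struct.unpack("<" + "f" * n, decoded_bytes))
assert np.allclose(restored, vec)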
|  | ||||
|  | ||||
| # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) | ||||
| # @retry( | ||||
| #     stop=stop_after_attempt(3), | ||||
| #     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
| #     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions | ||||
| # ) | ||||
| async def bedrock_embedding( | ||||
|     texts: list[str], | ||||
|     model: str = "amazon.titan-embed-text-v2:0", | ||||
|     aws_access_key_id=None, | ||||
|     aws_secret_access_key=None, | ||||
|     aws_session_token=None, | ||||
| ) -> np.ndarray: | ||||
|     os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( | ||||
|         "AWS_ACCESS_KEY_ID", aws_access_key_id | ||||
|     ) | ||||
|     os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( | ||||
|         "AWS_SECRET_ACCESS_KEY", aws_secret_access_key | ||||
|     ) | ||||
|     os.environ["AWS_SESSION_TOKEN"] = os.environ.get( | ||||
|         "AWS_SESSION_TOKEN", aws_session_token | ||||
|     ) | ||||
|  | ||||
|     session = aioboto3.Session() | ||||
|     async with session.client("bedrock-runtime") as bedrock_async_client: | ||||
|         if (model_provider := model.split(".")[0]) == "amazon": | ||||
|             embed_texts = [] | ||||
|             for text in texts: | ||||
|                 if "v2" in model: | ||||
|                     body = json.dumps( | ||||
|                         { | ||||
|                             "inputText": text, | ||||
|                             # 'dimensions': embedding_dim, | ||||
|                             "embeddingTypes": ["float"], | ||||
|                         } | ||||
|                     ) | ||||
|                 elif "v1" in model: | ||||
|                     body = json.dumps({"inputText": text}) | ||||
|                 else: | ||||
|                     raise ValueError(f"Model {model} is not supported!") | ||||
|  | ||||
|                 response = await bedrock_async_client.invoke_model( | ||||
|                     modelId=model, | ||||
|                     body=body, | ||||
|                     accept="application/json", | ||||
|                     contentType="application/json", | ||||
|                 ) | ||||
|  | ||||
|                 response_body = await response.get("body").json() | ||||
|  | ||||
|                 embed_texts.append(response_body["embedding"]) | ||||
|         elif model_provider == "cohere": | ||||
|             body = json.dumps( | ||||
|                 {"texts": texts, "input_type": "search_document", "truncate": "NONE"} | ||||
|             ) | ||||
|  | ||||
|             response = await bedrock_async_client.invoke_model( | ||||
|                 modelId=model, | ||||
|                 body=body, | ||||
|                 accept="application/json", | ||||
|                 contentType="application/json", | ||||
|             ) | ||||
|  | ||||
|             response_body = json.loads(response.get("body").read()) | ||||
|  | ||||
|             embed_texts = response_body["embeddings"] | ||||
|         else: | ||||
|             raise ValueError(f"Model provider '{model_provider}' is not supported!") | ||||
|  | ||||
|         return np.array(embed_texts) | ||||
|  | ||||
|  | ||||
| async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: | ||||
|     embed_model.to('cuda:0') | ||||
|     input_ids = tokenizer( | ||||
|         texts, return_tensors="pt", padding=True, truncation=True | ||||
|     ).input_ids.cuda() | ||||
|     with torch.no_grad(): | ||||
|         outputs = embed_model(input_ids) | ||||
|         embeddings = outputs.last_hidden_state.mean(dim=1) | ||||
|     return embeddings.detach().cpu().numpy() | ||||
|  | ||||
|  | ||||
| async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: | ||||
|     embed_text = [] | ||||
|     ollama_client = ollama.Client(**kwargs) | ||||
|     for text in texts: | ||||
|         data = ollama_client.embeddings(model=embed_model, prompt=text) | ||||
|         embed_text.append(data["embedding"]) | ||||
|  | ||||
|     return embed_text | ||||
| from pydantic import BaseModel, Field | ||||
|  | ||||
|  | ||||
| class Model(BaseModel): | ||||
| @@ -793,6 +69,8 @@ class MultiModel: | ||||
|         self, prompt, system_prompt=None, history_messages=[], **kwargs | ||||
|     ) -> str: | ||||
|         kwargs.pop("model", None)  # stop from overwriting the custom model name | ||||
|         kwargs.pop("keyword_extraction", None) | ||||
|         kwargs.pop("mode", None) | ||||
|         next_model = self._next_model() | ||||
|         args = dict( | ||||
|             prompt=prompt, | ||||
| @@ -809,6 +87,8 @@ if __name__ == "__main__": | ||||
|     import asyncio | ||||
|  | ||||
|     async def main(): | ||||
|         from lightrag.llm.openai import gpt_4o_mini_complete | ||||
|  | ||||
|         result = await gpt_4o_mini_complete("How are you?") | ||||
|         print(result) | ||||
|  | ||||
|   | ||||
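The bulk of this diff deletes the provider-specific functions from minirag/llm.py (keeping MultiModel and the demo main) and moves each provider into its own module under minirag/llm/, introduced below. A hedged sketch of the resulting import pattern, with an illustrative prompt:

import asyncio

# Provider functions now live in per-backend modules; the path follows the
# new files introduced below and is otherwise an assumption of this sketch.
from minirag.llm.azure_openai import azure_openai_complete

async def main():
    # All providers share the (prompt, system_prompt, history_messages) shape,
    # so swapping backends is a one-line import change.
    print(await azure_openai_complete("How are you?"))

asyncio.run(main())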
							
								
								
									
minirag/llm/__init__.py (new file, 0 lines)
minirag/llm/azure_openai.py (new file, 189 lines)
							| @@ -0,0 +1,189 @@ | ||||
| """ | ||||
| Azure OpenAI LLM Interface Module | ||||
| ========================== | ||||
|  | ||||
| This module provides interfaces for interacting with Azure OpenAI's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - openai | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.azure_openai import azure_openai_complete, azure_openai_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
|  | ||||
| import os | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("openai"): | ||||
|     pm.install("openai") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| from openai import ( | ||||
|     AsyncAzureOpenAI, | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     wrap_embedding_func_with_attrs, | ||||
|     locate_json_string_body_from_string, | ||||
|     safe_unicode_decode, | ||||
| ) | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def azure_openai_complete_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     base_url=None, | ||||
|     api_key=None, | ||||
|     api_version=None, | ||||
|     **kwargs, | ||||
| ): | ||||
|     if api_key: | ||||
|         os.environ["AZURE_OPENAI_API_KEY"] = api_key | ||||
|     if base_url: | ||||
|         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url | ||||
|     if api_version: | ||||
|         os.environ["AZURE_OPENAI_API_VERSION"] = api_version | ||||
|  | ||||
|     openai_async_client = AsyncAzureOpenAI( | ||||
|         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||||
|         api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||||
|         api_version=os.getenv("AZURE_OPENAI_API_VERSION"), | ||||
|     ) | ||||
|     kwargs.pop("hashing_kv", None) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     if prompt is not None: | ||||
|         messages.append({"role": "user", "content": prompt}) | ||||
|  | ||||
|     if "response_format" in kwargs: | ||||
|         response = await openai_async_client.beta.chat.completions.parse( | ||||
|             model=model, messages=messages, **kwargs | ||||
|         ) | ||||
|     else: | ||||
|         response = await openai_async_client.chat.completions.create( | ||||
|             model=model, messages=messages, **kwargs | ||||
|         ) | ||||
|  | ||||
|     if hasattr(response, "__aiter__"): | ||||
|  | ||||
|         async def inner(): | ||||
|             async for chunk in response: | ||||
|                 if len(chunk.choices) == 0: | ||||
|                     continue | ||||
|                 content = chunk.choices[0].delta.content | ||||
|                 if content is None: | ||||
|                     continue | ||||
|                 if r"\u" in content: | ||||
|                     content = safe_unicode_decode(content.encode("utf-8")) | ||||
|                 yield content | ||||
|  | ||||
|         return inner() | ||||
|     else: | ||||
|         content = response.choices[0].message.content | ||||
|         if r"\u" in content: | ||||
|             content = safe_unicode_decode(content.encode("utf-8")) | ||||
|         return content | ||||
|  | ||||
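Because kwargs are forwarded to the client, passing stream=True makes the call return an async generator instead of a string; the hasattr(response, "__aiter__") check above is what routes between the two. A sketch of consuming the streaming path (prompt and deployment name are placeholders):

import asyncio

async def stream_demo():
    # stream=True reaches chat.completions.create via **kwargs, so the return
    # value is an async generator of text chunks rather than a str.
    chunks = await azure_openai_complete_if_cache(
        "gpt-4o-mini", "Tell me a short joke.", stream=True
    )
    async for piece in chunks:
        print(piece, end="", flush=True)

asyncio.run(stream_demo())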
|  | ||||
| async def azure_openai_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> str: | ||||
|     keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction) | ||||
|     result = await azure_openai_complete_if_cache( | ||||
|         os.getenv("LLM_MODEL", "gpt-4o-mini"), | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     if keyword_extraction:  # TODO: use JSON API | ||||
|         return locate_json_string_body_from_string(result) | ||||
|     return result | ||||
|  | ||||
|  | ||||
| @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191) | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def azure_openai_embed( | ||||
|     texts: list[str], | ||||
|     model: str = "text-embedding-3-small", | ||||
|     base_url: str = None, | ||||
|     api_key: str = None, | ||||
|     api_version: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key: | ||||
|         os.environ["AZURE_OPENAI_API_KEY"] = api_key | ||||
|     if base_url: | ||||
|         os.environ["AZURE_OPENAI_ENDPOINT"] = base_url | ||||
|     if api_version: | ||||
|         os.environ["AZURE_OPENAI_API_VERSION"] = api_version | ||||
|  | ||||
|     openai_async_client = AsyncAzureOpenAI( | ||||
|         azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), | ||||
|         api_key=os.getenv("AZURE_OPENAI_API_KEY"), | ||||
|         api_version=os.getenv("AZURE_OPENAI_API_VERSION"), | ||||
|     ) | ||||
|  | ||||
|     response = await openai_async_client.embeddings.create( | ||||
|         model=model, input=texts, encoding_format="float" | ||||
|     ) | ||||
|     return np.array([dp.embedding for dp in response.data]) | ||||
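A minimal usage sketch for the embedding entry point, assuming the AZURE_OPENAI_* variables referenced above are already exported; the texts are placeholders:

import asyncio

async def embed_demo():
    # Requires AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY /
    # AZURE_OPENAI_API_VERSION in the environment, plus a matching deployment.
    vectors = await azure_openai_embed(["hello graph", "hello vector"])
    print(vectors.shape)  # (2, 1536) for text-embedding-3-small

asyncio.run(embed_demo())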
							
								
								
									
minirag/llm/bedrock.py (new file, 225 lines)
							| @@ -0,0 +1,225 @@ | ||||
| """ | ||||
| Bedrock LLM Interface Module | ||||
| ========================== | ||||
|  | ||||
| This module provides interfaces for interacting with Bedrock's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - aioboto3, tenacity | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.bedrock import bedrock_complete, bedrock_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
|  | ||||
| import copy | ||||
| import os | ||||
| import json | ||||
|  | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| if not pm.is_installed("aioboto3"): | ||||
|     pm.install("aioboto3") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
| import aioboto3 | ||||
| import numpy as np | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     locate_json_string_body_from_string, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class BedrockError(Exception): | ||||
|     """Generic error for issues related to Amazon Bedrock""" | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(5), | ||||
|     wait=wait_exponential(multiplier=1, max=60), | ||||
|     retry=retry_if_exception_type(BedrockError), | ||||
| ) | ||||
| async def bedrock_complete_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     aws_access_key_id=None, | ||||
|     aws_secret_access_key=None, | ||||
|     aws_session_token=None, | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( | ||||
|         "AWS_ACCESS_KEY_ID", aws_access_key_id | ||||
|     ) | ||||
|     os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( | ||||
|         "AWS_SECRET_ACCESS_KEY", aws_secret_access_key | ||||
|     ) | ||||
|     os.environ["AWS_SESSION_TOKEN"] = os.environ.get( | ||||
|         "AWS_SESSION_TOKEN", aws_session_token | ||||
|     ) | ||||
|     kwargs.pop("hashing_kv", None) | ||||
|     # Fix message history format | ||||
|     messages = [] | ||||
|     for history_message in history_messages: | ||||
|         message = copy.copy(history_message) | ||||
|         message["content"] = [{"text": message["content"]}] | ||||
|         messages.append(message) | ||||
|  | ||||
|     # Add user prompt | ||||
|     messages.append({"role": "user", "content": [{"text": prompt}]}) | ||||
|  | ||||
|     # Initialize Converse API arguments | ||||
|     args = {"modelId": model, "messages": messages} | ||||
|  | ||||
|     # Define system prompt | ||||
|     if system_prompt: | ||||
|         args["system"] = [{"text": system_prompt}] | ||||
|  | ||||
|     # Map and set up inference parameters | ||||
|     inference_params_map = { | ||||
|         "max_tokens": "maxTokens", | ||||
|         "top_p": "topP", | ||||
|         "stop_sequences": "stopSequences", | ||||
|     } | ||||
|     if inference_params := list( | ||||
|         set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) | ||||
|     ): | ||||
|         args["inferenceConfig"] = {} | ||||
|         for param in inference_params: | ||||
|             args["inferenceConfig"][inference_params_map.get(param, param)] = ( | ||||
|                 kwargs.pop(param) | ||||
|             ) | ||||
|  | ||||
|     # Call model via Converse API | ||||
|     session = aioboto3.Session() | ||||
|     async with session.client("bedrock-runtime") as bedrock_async_client: | ||||
|         try: | ||||
|             response = await bedrock_async_client.converse(**args, **kwargs) | ||||
|         except Exception as e: | ||||
|             raise BedrockError(e) | ||||
|  | ||||
|     return response["output"]["message"]["content"][0]["text"] | ||||
|  | ||||
|  | ||||
| async def bedrock_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> str: | ||||
|     keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction) | ||||
|     result = await bedrock_complete_if_cache( | ||||
|         "anthropic.claude-3-haiku-20240307-v1:0", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     if keyword_extraction:  # TODO: use JSON API | ||||
|         return locate_json_string_body_from_string(result) | ||||
|     return result | ||||
|  | ||||
|  | ||||
| # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) | ||||
| # @retry( | ||||
| #     stop=stop_after_attempt(3), | ||||
| #     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
| #     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions | ||||
| # ) | ||||
| async def bedrock_embed( | ||||
|     texts: list[str], | ||||
|     model: str = "amazon.titan-embed-text-v2:0", | ||||
|     aws_access_key_id=None, | ||||
|     aws_secret_access_key=None, | ||||
|     aws_session_token=None, | ||||
| ) -> np.ndarray: | ||||
|     os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( | ||||
|         "AWS_ACCESS_KEY_ID", aws_access_key_id | ||||
|     ) | ||||
|     os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( | ||||
|         "AWS_SECRET_ACCESS_KEY", aws_secret_access_key | ||||
|     ) | ||||
|     os.environ["AWS_SESSION_TOKEN"] = os.environ.get( | ||||
|         "AWS_SESSION_TOKEN", aws_session_token | ||||
|     ) | ||||
|  | ||||
|     session = aioboto3.Session() | ||||
|     async with session.client("bedrock-runtime") as bedrock_async_client: | ||||
|         if (model_provider := model.split(".")[0]) == "amazon": | ||||
|             embed_texts = [] | ||||
|             for text in texts: | ||||
|                 if "v2" in model: | ||||
|                     body = json.dumps( | ||||
|                         { | ||||
|                             "inputText": text, | ||||
|                             # 'dimensions': embedding_dim, | ||||
|                             "embeddingTypes": ["float"], | ||||
|                         } | ||||
|                     ) | ||||
|                 elif "v1" in model: | ||||
|                     body = json.dumps({"inputText": text}) | ||||
|                 else: | ||||
|                     raise ValueError(f"Model {model} is not supported!") | ||||
|  | ||||
|                 response = await bedrock_async_client.invoke_model( | ||||
|                     modelId=model, | ||||
|                     body=body, | ||||
|                     accept="application/json", | ||||
|                     contentType="application/json", | ||||
|                 ) | ||||
|  | ||||
|                 response_body = await response.get("body").json() | ||||
|  | ||||
|                 embed_texts.append(response_body["embedding"]) | ||||
|         elif model_provider == "cohere": | ||||
|             body = json.dumps( | ||||
|                 {"texts": texts, "input_type": "search_document", "truncate": "NONE"} | ||||
|             ) | ||||
|  | ||||
|             response = await bedrock_async_client.invoke_model( | ||||
|                 modelId=model, | ||||
|                 body=body, | ||||
|                 accept="application/json", | ||||
|                 contentType="application/json", | ||||
|             ) | ||||
|  | ||||
|             response_body = json.loads(response.get("body").read()) | ||||
|  | ||||
|             embed_texts = response_body["embeddings"] | ||||
|         else: | ||||
|             raise ValueError(f"Model provider '{model_provider}' is not supported!") | ||||
|  | ||||
|         return np.array(embed_texts) | ||||
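A usage sketch under the assumption that AWS credentials are already present in the environment and the account has Bedrock access to the default model IDs above:

import asyncio

async def demo():
    # Relies on AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY (and optionally
    # AWS_SESSION_TOKEN) being exported before the call.
    answer = await bedrock_complete("Name three uses of a knowledge graph.")
    print(answer)

    embeddings = await bedrock_embed(["graph nodes", "graph edges"])
    print(embeddings.shape)  # (2, 1024) for titan-embed-text-v2

asyncio.run(demo())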
							
								
								
									
minirag/llm/hf.py (new file, 188 lines)
							| @@ -0,0 +1,188 @@ | ||||
| """ | ||||
| Hugging Face LLM Interface Module | ||||
| ========================== | ||||
|  | ||||
| This module provides interfaces for interacting with Hugging Face's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - transformers | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.hf import hf_model_complete, hf_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
| import copy | ||||
| import os | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("transformers"): | ||||
|     pm.install("transformers") | ||||
| if not pm.is_installed("torch"): | ||||
|     pm.install("torch") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| from transformers import AutoTokenizer, AutoModelForCausalLM | ||||
| from functools import lru_cache | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
| from lightrag.exceptions import ( | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from lightrag.utils import ( | ||||
|     locate_json_string_body_from_string, | ||||
| ) | ||||
| import torch | ||||
| import numpy as np | ||||
|  | ||||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||||
|  | ||||
|  | ||||
| @lru_cache(maxsize=1) | ||||
| def initialize_hf_model(model_name): | ||||
|     hf_tokenizer = AutoTokenizer.from_pretrained( | ||||
|         model_name, device_map="auto", trust_remote_code=True | ||||
|     ) | ||||
|     hf_model = AutoModelForCausalLM.from_pretrained( | ||||
|         model_name, device_map="auto", trust_remote_code=True | ||||
|     ) | ||||
|     if hf_tokenizer.pad_token is None: | ||||
|         hf_tokenizer.pad_token = hf_tokenizer.eos_token | ||||
|  | ||||
|     return hf_model, hf_tokenizer | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def hf_model_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     model_name = model | ||||
|     hf_model, hf_tokenizer = initialize_hf_model(model_name) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|     kwargs.pop("hashing_kv", None) | ||||
|     input_prompt = "" | ||||
|     try: | ||||
|         input_prompt = hf_tokenizer.apply_chat_template( | ||||
|             messages, tokenize=False, add_generation_prompt=True | ||||
|         ) | ||||
|     except Exception: | ||||
|         try: | ||||
|             ori_message = copy.deepcopy(messages) | ||||
|             if messages[0]["role"] == "system": | ||||
|                 messages[1]["content"] = ( | ||||
|                     "<system>" | ||||
|                     + messages[0]["content"] | ||||
|                     + "</system>\n" | ||||
|                     + messages[1]["content"] | ||||
|                 ) | ||||
|                 messages = messages[1:] | ||||
|                 input_prompt = hf_tokenizer.apply_chat_template( | ||||
|                     messages, tokenize=False, add_generation_prompt=True | ||||
|                 ) | ||||
|         except Exception: | ||||
|             len_message = len(ori_message) | ||||
|             for msgid in range(len_message): | ||||
|                 input_prompt = ( | ||||
|                     input_prompt | ||||
|                     + "<" | ||||
|                     + ori_message[msgid]["role"] | ||||
|                     + ">" | ||||
|                     + ori_message[msgid]["content"] | ||||
|                     + "</" | ||||
|                     + ori_message[msgid]["role"] | ||||
|                     + ">\n" | ||||
|                 ) | ||||
|  | ||||
|     # Move the tokenized batch to the model's device once (works on CPU or GPU) | ||||
|     inputs = hf_tokenizer( | ||||
|         input_prompt, return_tensors="pt", padding=True, truncation=True | ||||
|     ).to(hf_model.device) | ||||
|     output = hf_model.generate( | ||||
|         **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True | ||||
|     ) | ||||
|     # Decode only the newly generated tokens, skipping the echoed prompt | ||||
|     response_text = hf_tokenizer.decode( | ||||
|         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True | ||||
|     ) | ||||
|  | ||||
|     return response_text | ||||
|  | ||||
|  | ||||
| async def hf_model_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> str: | ||||
|     keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction) | ||||
|     model_name = kwargs["hashing_kv"].global_config["llm_model_name"] | ||||
|     result = await hf_model_if_cache( | ||||
|         model_name, | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|     if keyword_extraction:  # TODO: use JSON API | ||||
|         return locate_json_string_body_from_string(result) | ||||
|     return result | ||||
|  | ||||
|  | ||||
| async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray: | ||||
|     device = next(embed_model.parameters()).device | ||||
|     input_ids = tokenizer( | ||||
|         texts, return_tensors="pt", padding=True, truncation=True | ||||
|     ).input_ids.to(device) | ||||
|     with torch.no_grad(): | ||||
|         outputs = embed_model(input_ids) | ||||
|         embeddings = outputs.last_hidden_state.mean(dim=1) | ||||
|     if embeddings.dtype == torch.bfloat16: | ||||
|         return embeddings.detach().to(torch.float32).cpu().numpy() | ||||
|     else: | ||||
|         return embeddings.detach().cpu().numpy() | ||||
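|  | ||||
|  | ||||
| # Hedged usage sketch (not part of the module): one way to wire hf_embed up. | ||||
| # The model id below is illustrative only, not something this code mandates. | ||||
| # | ||||
| #   from transformers import AutoTokenizer, AutoModel | ||||
| #   tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") | ||||
| #   mdl = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") | ||||
| #   vecs = await hf_embed(["hello", "world"], tok, mdl)  # one row per text | ||||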
							
								
								
									
minirag/llm/jina.py (new file, 86 lines)
							| @@ -0,0 +1,86 @@ | ||||
| """ | ||||
| Jina Embedding Interface Module | ||||
| ========================== | ||||
|  | ||||
| This module provides interfaces for interacting with jina system, | ||||
| including embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added embedding generation | ||||
|  | ||||
| Dependencies: | ||||
|     - aiohttp | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.jina import jina_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
| import os | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("aiohttp"): | ||||
|     pm.install("aiohttp") | ||||
|  | ||||
|  | ||||
| import numpy as np | ||||
| import aiohttp | ||||
|  | ||||
|  | ||||
| async def fetch_data(url, headers, data): | ||||
|     async with aiohttp.ClientSession() as session: | ||||
|         async with session.post(url, headers=headers, json=data) as response: | ||||
|             response_json = await response.json() | ||||
|             data_list = response_json.get("data", []) | ||||
|             return data_list | ||||
|  | ||||
|  | ||||
| async def jina_embed( | ||||
|     texts: list[str], | ||||
|     dimensions: int = 1024, | ||||
|     late_chunking: bool = False, | ||||
|     base_url: str = None, | ||||
|     api_key: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key: | ||||
|         os.environ["JINA_API_KEY"] = api_key | ||||
|     if "JINA_API_KEY" not in os.environ: | ||||
|         raise ValueError("No Jina API key: pass api_key or set JINA_API_KEY.") | ||||
|     url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url | ||||
|     headers = { | ||||
|         "Content-Type": "application/json", | ||||
|         "Authorization": f"Bearer {os.environ['JINA_API_KEY']}", | ||||
|     } | ||||
|     data = { | ||||
|         "model": "jina-embeddings-v3", | ||||
|         "normalized": True, | ||||
|         "embedding_type": "float", | ||||
|         "dimensions": f"{dimensions}", | ||||
|         "late_chunking": late_chunking, | ||||
|         "input": texts, | ||||
|     } | ||||
|     data_list = await fetch_data(url, headers, data) | ||||
|     return np.array([dp["embedding"] for dp in data_list]) | ||||
							
								
								
									
minirag/llm/lmdeploy.py (new file, 191 lines)
							| @@ -0,0 +1,191 @@ | ||||
| """ | ||||
| LMDeploy LLM Interface Module | ||||
| ============================= | ||||
|  | ||||
| This module provides interfaces for interacting with LMDeploy's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - tenacity | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.lmdeploy import lmdeploy_model_complete, lmdeploy_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("lmdeploy"): | ||||
|     pm.install("lmdeploy[all]") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| from lightrag.exceptions import ( | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
|  | ||||
| from functools import lru_cache | ||||
|  | ||||
|  | ||||
| @lru_cache(maxsize=1) | ||||
| def initialize_lmdeploy_pipeline( | ||||
|     model, | ||||
|     tp=1, | ||||
|     chat_template=None, | ||||
|     log_level="WARNING", | ||||
|     model_format="hf", | ||||
|     quant_policy=0, | ||||
| ): | ||||
|     from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig | ||||
|  | ||||
|     lmdeploy_pipe = pipeline( | ||||
|         model_path=model, | ||||
|         backend_config=TurbomindEngineConfig( | ||||
|             tp=tp, model_format=model_format, quant_policy=quant_policy | ||||
|         ), | ||||
|         chat_template_config=( | ||||
|             ChatTemplateConfig(model_name=chat_template) if chat_template else None | ||||
|         ), | ||||
|         log_level="WARNING", | ||||
|     ) | ||||
|     return lmdeploy_pipe | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def lmdeploy_model_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     chat_template=None, | ||||
|     model_format="hf", | ||||
|     quant_policy=0, | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     """ | ||||
|     Args: | ||||
|         model (str): The path to the model. | ||||
|             It could be one of the following options: | ||||
|                     - i) A local directory path of a turbomind model which is | ||||
|                         converted by `lmdeploy convert` command or download | ||||
|                         from ii) and iii). | ||||
|                     - ii) The model_id of a lmdeploy-quantized model hosted | ||||
|                         inside a model repo on huggingface.co, such as | ||||
|                         "InternLM/internlm-chat-20b-4bit", | ||||
|                         "lmdeploy/llama2-chat-70b-4bit", etc. | ||||
|                     - iii) The model_id of a model hosted inside a model repo | ||||
|                         on huggingface.co, such as "internlm/internlm-chat-7b", | ||||
|                         "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat" | ||||
|                         and so on. | ||||
|         chat_template (str): needed when the model is a pytorch model hosted on | ||||
|             huggingface.co, such as "internlm-chat-7b", | ||||
|             "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on, | ||||
|             or when the model name of a local path does not match the original model name on HF. | ||||
|         tp (int): tensor parallel | ||||
|         prompt (Union[str, List[str]]): input texts to be completed. | ||||
|         do_preprocess (bool): whether pre-process the messages. Default to | ||||
|             True, which means chat_template will be applied. | ||||
|         skip_special_tokens (bool): Whether or not to remove special tokens | ||||
|             in the decoding. Default to be True. | ||||
|         do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. | ||||
|             Default to be False, which means greedy decoding will be applied. | ||||
|     """ | ||||
|     try: | ||||
|         import lmdeploy | ||||
|         from lmdeploy import version_info, GenerationConfig | ||||
|     except ImportError: | ||||
|         raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.") | ||||
|     kwargs.pop("hashing_kv", None) | ||||
|     kwargs.pop("response_format", None) | ||||
|     max_new_tokens = kwargs.pop("max_tokens", 512) | ||||
|     tp = kwargs.pop("tp", 1) | ||||
|     skip_special_tokens = kwargs.pop("skip_special_tokens", True) | ||||
|     do_preprocess = kwargs.pop("do_preprocess", True) | ||||
|     do_sample = kwargs.pop("do_sample", False) | ||||
|     gen_params = kwargs | ||||
|  | ||||
|     version = version_info | ||||
|     if do_sample and version < (0, 6, 0): | ||||
|         raise RuntimeError( | ||||
|             "`do_sample` parameter is not supported by lmdeploy until " | ||||
|             f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}" | ||||
|         ) | ||||
|     if version >= (0, 6, 0): | ||||
|         gen_params.update(do_sample=do_sample) | ||||
|  | ||||
|     lmdeploy_pipe = initialize_lmdeploy_pipeline( | ||||
|         model=model, | ||||
|         tp=tp, | ||||
|         chat_template=chat_template, | ||||
|         model_format=model_format, | ||||
|         quant_policy=quant_policy, | ||||
|         log_level="WARNING", | ||||
|     ) | ||||
|  | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|  | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|  | ||||
|     gen_config = GenerationConfig( | ||||
|         skip_special_tokens=skip_special_tokens, | ||||
|         max_new_tokens=max_new_tokens, | ||||
|         **gen_params, | ||||
|     ) | ||||
|  | ||||
|     response = "" | ||||
|     async for res in lmdeploy_pipe.generate( | ||||
|         messages, | ||||
|         gen_config=gen_config, | ||||
|         do_preprocess=do_preprocess, | ||||
|         stream_response=False, | ||||
|         session_id=1, | ||||
|     ): | ||||
|         response += res.response | ||||
|     return response | ||||
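|  | ||||
|  | ||||
| # Hedged usage sketch (the model id is an example, not a default of this code): | ||||
| # | ||||
| #   import asyncio | ||||
| #   text = asyncio.run( | ||||
| #       lmdeploy_model_if_cache( | ||||
| #           "internlm/internlm-chat-7b", | ||||
| #           "Summarize RAG in one sentence.", | ||||
| #           system_prompt="You are concise.", | ||||
| #           max_tokens=128, | ||||
| #       ) | ||||
| #   ) | ||||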
							
								
								
									
minirag/llm/lollms.py (new file, 224 lines)
							| @@ -0,0 +1,224 @@ | ||||
| """ | ||||
| LoLLMs (Lord of Large Language Models) Interface Module | ||||
| ===================================================== | ||||
|  | ||||
| This module provides the official interface for interacting with LoLLMs (Lord of Large Language and multimodal Systems), | ||||
| a unified framework for AI model interaction and deployment. | ||||
|  | ||||
| LoLLMs is designed as a "one tool to rule them all" solution, providing seamless integration | ||||
| with various AI models while maintaining high performance and user-friendly interfaces. | ||||
|  | ||||
| Author: ParisNeo | ||||
| Created: 2024-01-24 | ||||
| License: Apache 2.0 | ||||
|  | ||||
| Copyright (c) 2024 ParisNeo | ||||
|  | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|  | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
|  | ||||
| Version: 2.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 2.0.0 (2024-01-24): | ||||
|     * Added async support for model inference | ||||
|     * Implemented streaming capabilities | ||||
|     * Added embedding generation functionality | ||||
|     * Enhanced parameter handling | ||||
|     * Improved error handling and timeout management | ||||
|  | ||||
| Dependencies: | ||||
|     - aiohttp | ||||
|     - numpy | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Features: | ||||
|     - Async text generation with streaming support | ||||
|     - Embedding generation | ||||
|     - Configurable model parameters | ||||
|     - System prompt and chat history support | ||||
|     - Timeout handling | ||||
|     - API key authentication | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.lollms import lollms_model_complete, lollms_embed | ||||
|  | ||||
| Project Repository: https://github.com/ParisNeo/lollms | ||||
| Documentation: https://github.com/ParisNeo/lollms/docs | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "ParisNeo" | ||||
| __status__ = "Production" | ||||
| __project_url__ = "https://github.com/ParisNeo/lollms" | ||||
| __doc_url__ = "https://github.com/ParisNeo/lollms/docs" | ||||
| import sys | ||||
|  | ||||
| if sys.version_info < (3, 9): | ||||
|     from typing import AsyncIterator | ||||
| else: | ||||
|     from collections.abc import AsyncIterator | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| if not pm.is_installed("aiohttp"): | ||||
|     pm.install("aiohttp") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| import aiohttp | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
| from lightrag.exceptions import ( | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
|  | ||||
| from typing import Union, List | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def lollms_model_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     base_url="http://localhost:9600", | ||||
|     **kwargs, | ||||
| ) -> Union[str, AsyncIterator[str]]: | ||||
|     """Client implementation for lollms generation.""" | ||||
|  | ||||
|     stream = bool(kwargs.get("stream")) | ||||
|     api_key = kwargs.pop("api_key", None) | ||||
|     headers = ( | ||||
|         {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} | ||||
|         if api_key | ||||
|         else {"Content-Type": "application/json"} | ||||
|     ) | ||||
|  | ||||
|     # Extract lollms specific parameters | ||||
|     request_data = { | ||||
|         "prompt": prompt, | ||||
|         "model_name": model, | ||||
|         "personality": kwargs.get("personality", -1), | ||||
|         "n_predict": kwargs.get("n_predict", None), | ||||
|         "stream": stream, | ||||
|         "temperature": kwargs.get("temperature", 0.1), | ||||
|         "top_k": kwargs.get("top_k", 50), | ||||
|         "top_p": kwargs.get("top_p", 0.95), | ||||
|         "repeat_penalty": kwargs.get("repeat_penalty", 0.8), | ||||
|         "repeat_last_n": kwargs.get("repeat_last_n", 40), | ||||
|         "seed": kwargs.get("seed", None), | ||||
|         "n_threads": kwargs.get("n_threads", 8), | ||||
|     } | ||||
|  | ||||
|     # Prepare the full prompt including history | ||||
|     full_prompt = "" | ||||
|     if system_prompt: | ||||
|         full_prompt += f"{system_prompt}\n" | ||||
|     for msg in history_messages: | ||||
|         full_prompt += f"{msg['role']}: {msg['content']}\n" | ||||
|     full_prompt += prompt | ||||
|  | ||||
|     request_data["prompt"] = full_prompt | ||||
|     timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None)) | ||||
|  | ||||
|     async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session: | ||||
|         if stream: | ||||
|  | ||||
|             async def inner(): | ||||
|                 async with session.post( | ||||
|                     f"{base_url}/lollms_generate", json=request_data | ||||
|                 ) as response: | ||||
|                     async for line in response.content: | ||||
|                         yield line.decode().strip() | ||||
|  | ||||
|             return inner() | ||||
|         else: | ||||
|             async with session.post( | ||||
|                 f"{base_url}/lollms_generate", json=request_data | ||||
|             ) as response: | ||||
|                 return await response.text() | ||||
|  | ||||
|  | ||||
| async def lollms_model_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> Union[str, AsyncIterator[str]]: | ||||
|     """Complete function for lollms model generation.""" | ||||
|  | ||||
|     # keyword_extraction comes from the explicit parameter; fall back to kwargs | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|  | ||||
|     # Get model name from config | ||||
|     model_name = kwargs["hashing_kv"].global_config["llm_model_name"] | ||||
|  | ||||
|     # If keyword extraction is needed, we might need to modify the prompt | ||||
|     # or add specific parameters for JSON output (if lollms supports it) | ||||
|     if keyword_extraction: | ||||
|         # Note: You might need to adjust this based on how lollms handles structured output | ||||
|         pass | ||||
|  | ||||
|     return await lollms_model_if_cache( | ||||
|         model_name, | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def lollms_embed( | ||||
|     texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs | ||||
| ) -> np.ndarray: | ||||
|     """ | ||||
|     Generate embeddings for a list of texts using lollms server. | ||||
|  | ||||
|     Args: | ||||
|         texts: List of strings to embed | ||||
|         embed_model: Model name (not used directly as lollms uses configured vectorizer) | ||||
|         base_url: URL of the lollms server | ||||
|         **kwargs: Additional arguments passed to the request | ||||
|  | ||||
|     Returns: | ||||
|         np.ndarray: Array of embeddings | ||||
|     """ | ||||
|     api_key = kwargs.pop("api_key", None) | ||||
|     headers = ( | ||||
|         {"Content-Type": "application/json", "Authorization": api_key} | ||||
|         if api_key | ||||
|         else {"Content-Type": "application/json"} | ||||
|     ) | ||||
|     async with aiohttp.ClientSession(headers=headers) as session: | ||||
|         embeddings = [] | ||||
|         for text in texts: | ||||
|             request_data = {"text": text} | ||||
|  | ||||
|             async with session.post( | ||||
|                 f"{base_url}/lollms_embed", | ||||
|                 json=request_data, | ||||
|             ) as response: | ||||
|                 result = await response.json() | ||||
|                 embeddings.append(result["vector"]) | ||||
|  | ||||
|         return np.array(embeddings) | ||||
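|  | ||||
|  | ||||
| # Hedged usage sketch (assumes a lollms server on localhost:9600; the model | ||||
| # name is illustrative): | ||||
| # | ||||
| #   import asyncio | ||||
| #   vecs = asyncio.run(lollms_embed(["hello", "world"])) | ||||
| #   text = asyncio.run( | ||||
| #       lollms_model_if_cache("mistral-7b", "Hello!", temperature=0.2) | ||||
| #   ) | ||||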
							
								
								
									
minirag/llm/nvidia_openai.py (new file, 108 lines)
							| @@ -0,0 +1,108 @@ | ||||
| """ | ||||
| NVIDIA OpenAI-Compatible LLM Interface Module | ||||
| ============================================= | ||||
|  | ||||
| This module provides interfaces for interacting with NVIDIA's OpenAI-compatible | ||||
| endpoints, including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - openai | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.nvidia_openai import nvidia_openai_model_complete, nvidia_openai_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
|  | ||||
| import os | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("openai"): | ||||
|     pm.install("openai") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| from openai import ( | ||||
|     AsyncOpenAI, | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     wrap_embedding_func_with_attrs, | ||||
| ) | ||||
|  | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512) | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=60), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def nvidia_openai_embed( | ||||
|     texts: list[str], | ||||
|     model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1", | ||||
|     # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding | ||||
|     base_url: str = "https://integrate.api.nvidia.com/v1", | ||||
|     api_key: str = None, | ||||
|     input_type: str = "passage",  # query for retrieval, passage for embedding | ||||
|     trunc: str = "NONE",  # NONE or START or END | ||||
|     encode: str = "float",  # float or base64 | ||||
| ) -> np.ndarray: | ||||
|     if api_key: | ||||
|         os.environ["OPENAI_API_KEY"] = api_key | ||||
|  | ||||
|     openai_async_client = ( | ||||
|         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) | ||||
|     ) | ||||
|     response = await openai_async_client.embeddings.create( | ||||
|         model=model, | ||||
|         input=texts, | ||||
|         encoding_format=encode, | ||||
|         extra_body={"input_type": input_type, "truncate": trunc}, | ||||
|     ) | ||||
|     return np.array([dp.embedding for dp in response.data]) | ||||
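|  | ||||
|  | ||||
| # Hedged usage sketch (assumes an NVIDIA API key, passed here or already in | ||||
| # OPENAI_API_KEY): | ||||
| # | ||||
| #   import asyncio | ||||
| #   vecs = asyncio.run( | ||||
| #       nvidia_openai_embed(["a passage to index"], input_type="passage") | ||||
| #   )  # use input_type="query" when embedding retrieval queries | ||||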
							
								
								
									
minirag/llm/ollama.py (new file, 158 lines)
							| @@ -0,0 +1,158 @@ | ||||
| """ | ||||
| Ollama LLM Interface Module | ||||
| =========================== | ||||
|  | ||||
| This module provides interfaces for interacting with Ollama's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - ollama | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.ollama_interface import ollama_model_complete, ollama_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
| import sys | ||||
|  | ||||
| if sys.version_info < (3, 9): | ||||
|     from typing import AsyncIterator | ||||
| else: | ||||
|     from collections.abc import AsyncIterator | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("ollama"): | ||||
|     pm.install("ollama") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| import ollama | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
| from lightrag.exceptions import ( | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| import numpy as np | ||||
| from typing import Union | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def ollama_model_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     **kwargs, | ||||
| ) -> Union[str, AsyncIterator[str]]: | ||||
|     stream = bool(kwargs.get("stream")) | ||||
|     kwargs.pop("max_tokens", None) | ||||
|     # kwargs.pop("response_format", None) # allow json | ||||
|     host = kwargs.pop("host", None) | ||||
|     timeout = kwargs.pop("timeout", None) | ||||
|     kwargs.pop("hashing_kv", None) | ||||
|     api_key = kwargs.pop("api_key", None) | ||||
|     headers = ( | ||||
|         {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} | ||||
|         if api_key | ||||
|         else {"Content-Type": "application/json"} | ||||
|     ) | ||||
|     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|  | ||||
|     response = await ollama_client.chat(model=model, messages=messages, **kwargs) | ||||
|     if stream: | ||||
|         """cannot cache stream response""" | ||||
|  | ||||
|         async def inner(): | ||||
|             async for chunk in response: | ||||
|                 yield chunk["message"]["content"] | ||||
|  | ||||
|         return inner() | ||||
|     else: | ||||
|         return response["message"]["content"] | ||||
|  | ||||
|  | ||||
| async def ollama_model_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> Union[str, AsyncIterator[str]]: | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|     if keyword_extraction: | ||||
|         kwargs["format"] = "json" | ||||
|     model_name = kwargs["hashing_kv"].global_config["llm_model_name"] | ||||
|     return await ollama_model_if_cache( | ||||
|         model_name, | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: | ||||
|     """ | ||||
|     Deprecated in favor of `embed`. | ||||
|     """ | ||||
|     embed_text = [] | ||||
|     ollama_client = ollama.Client(**kwargs) | ||||
|     for text in texts: | ||||
|         data = ollama_client.embeddings(model=embed_model, prompt=text) | ||||
|         embed_text.append(data["embedding"]) | ||||
|  | ||||
|     return np.array(embed_text) | ||||
|  | ||||
|  | ||||
| async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: | ||||
|     api_key = kwargs.pop("api_key", None) | ||||
|     headers = ( | ||||
|         {"Content-Type": "application/json", "Authorization": api_key} | ||||
|         if api_key | ||||
|         else {"Content-Type": "application/json"} | ||||
|     ) | ||||
|     kwargs["headers"] = headers | ||||
|     ollama_client = ollama.Client(**kwargs) | ||||
|     data = ollama_client.embed(model=embed_model, input=texts) | ||||
|     return data["embeddings"] | ||||
							
								
								
									
minirag/llm/openai.py (new file, 230 lines)
							| @@ -0,0 +1,230 @@ | ||||
| """ | ||||
| OpenAI LLM Interface Module | ||||
| =========================== | ||||
|  | ||||
| This module provides interfaces for interacting with OpenAI's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - openai | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.openai import openai_model_complete, openai_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
|  | ||||
| import sys | ||||
| import os | ||||
|  | ||||
| if sys.version_info < (3, 9): | ||||
|     from typing import AsyncIterator | ||||
| else: | ||||
|     from collections.abc import AsyncIterator | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("openai"): | ||||
|     pm.install("openai") | ||||
|  | ||||
| from openai import ( | ||||
|     AsyncOpenAI, | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
| from lightrag.utils import ( | ||||
|     wrap_embedding_func_with_attrs, | ||||
|     locate_json_string_body_from_string, | ||||
|     safe_unicode_decode, | ||||
|     logger, | ||||
| ) | ||||
| from lightrag.types import GPTKeywordExtractionFormat | ||||
|  | ||||
| import numpy as np | ||||
| from typing import Union | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def openai_complete_if_cache( | ||||
|     model, | ||||
|     prompt, | ||||
|     system_prompt=None, | ||||
|     history_messages=[], | ||||
|     base_url=None, | ||||
|     api_key=None, | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     if api_key: | ||||
|         os.environ["OPENAI_API_KEY"] = api_key | ||||
|  | ||||
|     openai_async_client = ( | ||||
|         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) | ||||
|     ) | ||||
|     kwargs.pop("hashing_kv", None) | ||||
|     kwargs.pop("keyword_extraction", None) | ||||
|     messages = [] | ||||
|     if system_prompt: | ||||
|         messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|  | ||||
|     # Log the query input for debugging | ||||
|     logger.debug("===== Query Input to LLM =====") | ||||
|     logger.debug(f"Query: {prompt}") | ||||
|     logger.debug(f"System prompt: {system_prompt}") | ||||
|     logger.debug("Full context:") | ||||
|     if "response_format" in kwargs: | ||||
|         response = await openai_async_client.beta.chat.completions.parse( | ||||
|             model=model, messages=messages, **kwargs | ||||
|         ) | ||||
|     else: | ||||
|         response = await openai_async_client.chat.completions.create( | ||||
|             model=model, messages=messages, **kwargs | ||||
|         ) | ||||
|  | ||||
|     if hasattr(response, "__aiter__"): | ||||
|  | ||||
|         async def inner(): | ||||
|             async for chunk in response: | ||||
|                 content = chunk.choices[0].delta.content | ||||
|                 if content is None: | ||||
|                     continue | ||||
|                 if r"\u" in content: | ||||
|                     content = safe_unicode_decode(content.encode("utf-8")) | ||||
|                 yield content | ||||
|  | ||||
|         return inner() | ||||
|     else: | ||||
|         content = response.choices[0].message.content | ||||
|         if r"\u" in content: | ||||
|             content = safe_unicode_decode(content.encode("utf-8")) | ||||
|         return content | ||||
|  | ||||
|  | ||||
| async def openai_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> Union[str, AsyncIterator[str]]: | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|     if keyword_extraction: | ||||
|         kwargs["response_format"] = GPTKeywordExtractionFormat | ||||
|     model_name = kwargs["hashing_kv"].global_config["llm_model_name"] | ||||
|     return await openai_complete_if_cache( | ||||
|         model_name, | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def gpt_4o_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> str: | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|     if keyword_extraction: | ||||
|         kwargs["response_format"] = GPTKeywordExtractionFormat | ||||
|     return await openai_complete_if_cache( | ||||
|         "gpt-4o", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def gpt_4o_mini_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> str: | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|     if keyword_extraction: | ||||
|         kwargs["response_format"] = GPTKeywordExtractionFormat | ||||
|     return await openai_complete_if_cache( | ||||
|         "gpt-4o-mini", | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         **kwargs, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| async def nvidia_openai_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ) -> str: | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|     result = await openai_complete_if_cache( | ||||
|         "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k | ||||
|         prompt, | ||||
|         system_prompt=system_prompt, | ||||
|         history_messages=history_messages, | ||||
|         base_url="https://integrate.api.nvidia.com/v1", | ||||
|         **kwargs, | ||||
|     ) | ||||
|     if keyword_extraction:  # TODO: use JSON API | ||||
|         return locate_json_string_body_from_string(result) | ||||
|     return result | ||||
|  | ||||
|  | ||||
| @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=60), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def openai_embed( | ||||
|     texts: list[str], | ||||
|     model: str = "text-embedding-3-small", | ||||
|     base_url: str = None, | ||||
|     api_key: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key: | ||||
|         os.environ["OPENAI_API_KEY"] = api_key | ||||
|  | ||||
|     openai_async_client = ( | ||||
|         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) | ||||
|     ) | ||||
|     response = await openai_async_client.embeddings.create( | ||||
|         model=model, input=texts, encoding_format="float" | ||||
|     ) | ||||
|     return np.array([dp.embedding for dp in response.data]) | ||||
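|  | ||||
|  | ||||
| # Hedged usage sketch (requires OPENAI_API_KEY; the model names are the same | ||||
| # defaults used above): | ||||
| # | ||||
| #   import asyncio | ||||
| #   text = asyncio.run( | ||||
| #       openai_complete_if_cache("gpt-4o-mini", "Say hi in five words.") | ||||
| #   ) | ||||
| #   vecs = asyncio.run(openai_embed(["hello", "world"]))  # shape (2, 1536) | ||||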
							
								
								
									
minirag/llm/siliconcloud.py (new file, 109 lines)
							| @@ -0,0 +1,109 @@ | ||||
| """ | ||||
| SiliconCloud Embedding Interface Module | ||||
| ======================================= | ||||
|  | ||||
| This module provides interfaces for interacting with the SiliconCloud service, | ||||
| including embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added embedding generation | ||||
|  | ||||
| Dependencies: | ||||
|     - tenacity | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.siliconcloud import siliconcloud_model_complete, siliconcloud_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("openai"): | ||||
|     pm.install("openai") | ||||
| if not pm.is_installed("aiohttp"): | ||||
|     pm.install("aiohttp") | ||||
| if not pm.is_installed("tenacity"): | ||||
|     pm.install("tenacity") | ||||
|  | ||||
| from openai import ( | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
|  | ||||
| import numpy as np | ||||
| import aiohttp | ||||
| import base64 | ||||
| import struct | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=60), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def siliconcloud_embedding( | ||||
|     texts: list[str], | ||||
|     model: str = "netease-youdao/bce-embedding-base_v1", | ||||
|     base_url: str = "https://api.siliconflow.cn/v1/embeddings", | ||||
|     max_token_size: int = 512, | ||||
|     api_key: str = None, | ||||
| ) -> np.ndarray: | ||||
|     if api_key and not api_key.startswith("Bearer "): | ||||
|         api_key = "Bearer " + api_key | ||||
|  | ||||
|     headers = {"Authorization": api_key, "Content-Type": "application/json"} | ||||
|  | ||||
|     # NOTE: character-level truncation, a cheap proxy for max_token_size tokens | ||||
|     truncate_texts = [text[0:max_token_size] for text in texts] | ||||
|  | ||||
|     payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} | ||||
|  | ||||
|     base64_strings = [] | ||||
|     async with aiohttp.ClientSession() as session: | ||||
|         async with session.post(base_url, headers=headers, json=payload) as response: | ||||
|             content = await response.json() | ||||
|             if "code" in content: | ||||
|                 raise ValueError(content) | ||||
|             base64_strings = [item["embedding"] for item in content["data"]] | ||||
|  | ||||
|     embeddings = [] | ||||
|     for string in base64_strings: | ||||
|         decode_bytes = base64.b64decode(string) | ||||
|         n = len(decode_bytes) // 4 | ||||
|         float_array = struct.unpack("<" + "f" * n, decode_bytes) | ||||
|         embeddings.append(float_array) | ||||
|     return np.array(embeddings) | ||||
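|  | ||||
|  | ||||
| # Hedged usage sketch (illustrative; pass your own SiliconCloud key): | ||||
| # | ||||
| #   import asyncio | ||||
| #   vecs = asyncio.run( | ||||
| #       siliconcloud_embedding(["hello", "world"], api_key="sk-...") | ||||
| #   )  # one vector per input text | ||||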
							
								
								
									
minirag/llm/zhipu.py (new file, 246 lines)
							| @@ -0,0 +1,246 @@ | ||||
| """ | ||||
| Zhipu LLM Interface Module | ||||
| ========================== | ||||
|  | ||||
| This module provides interfaces for interacting with Zhipu's language models, | ||||
| including text generation and embedding capabilities. | ||||
|  | ||||
| Author: Lightrag team | ||||
| Created: 2024-01-24 | ||||
| License: MIT License | ||||
|  | ||||
| Copyright (c) 2024 Lightrag | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| of this software and associated documentation files (the "Software"), to deal | ||||
| in the Software without restriction, including without limitation the rights | ||||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| copies of the Software, and to permit persons to whom the Software is | ||||
| furnished to do so, subject to the following conditions: | ||||
|  | ||||
| Version: 1.0.0 | ||||
|  | ||||
| Change Log: | ||||
| - 1.0.0 (2024-01-24): Initial release | ||||
|     * Added async chat completion support | ||||
|     * Added embedding generation | ||||
|     * Added stream response capability | ||||
|  | ||||
| Dependencies: | ||||
|     - tenacity | ||||
|     - numpy | ||||
|     - pipmaster | ||||
|     - Python >= 3.10 | ||||
|  | ||||
| Usage: | ||||
|     from llm_interfaces.zhipu import zhipu_model_complete, zhipu_embed | ||||
| """ | ||||
|  | ||||
| __version__ = "1.0.0" | ||||
| __author__ = "lightrag Team" | ||||
| __status__ = "Production" | ||||
|  | ||||
| import re | ||||
| import json | ||||
| import pipmaster as pm  # Pipmaster for dynamic library install | ||||
|  | ||||
| # install specific modules | ||||
| if not pm.is_installed("zhipuai"): | ||||
|     pm.install("zhipuai") | ||||
|  | ||||
| from openai import ( | ||||
|     APIConnectionError, | ||||
|     RateLimitError, | ||||
|     APITimeoutError, | ||||
| ) | ||||
| from tenacity import ( | ||||
|     retry, | ||||
|     stop_after_attempt, | ||||
|     wait_exponential, | ||||
|     retry_if_exception_type, | ||||
| ) | ||||
|  | ||||
| from lightrag.utils import ( | ||||
|     wrap_embedding_func_with_attrs, | ||||
|     logger, | ||||
| ) | ||||
|  | ||||
| from lightrag.types import GPTKeywordExtractionFormat | ||||
|  | ||||
| import numpy as np | ||||
| from typing import Union, List, Optional, Dict | ||||
|  | ||||
|  | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=10), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def zhipu_complete_if_cache( | ||||
|     prompt: Union[str, List[Dict[str, str]]], | ||||
|     model: str = "glm-4-flashx",  # The most cost/performance balance model in glm-4 series | ||||
|     api_key: Optional[str] = None, | ||||
|     system_prompt: Optional[str] = None, | ||||
|     history_messages: List[Dict[str, str]] = [], | ||||
|     **kwargs, | ||||
| ) -> str: | ||||
|     # dynamically load ZhipuAI | ||||
|     try: | ||||
|         from zhipuai import ZhipuAI | ||||
|     except ImportError: | ||||
|         raise ImportError("Please install zhipuai before initialize zhipuai backend.") | ||||
|  | ||||
|     if api_key: | ||||
|         client = ZhipuAI(api_key=api_key) | ||||
|     else: | ||||
|         # please set ZHIPUAI_API_KEY in your environment | ||||
|         # os.environ["ZHIPUAI_API_KEY"] | ||||
|         client = ZhipuAI() | ||||
|  | ||||
|     messages = [] | ||||
|  | ||||
|     if not system_prompt: | ||||
|         system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***" | ||||
|  | ||||
|     # A system prompt is always present at this point (a default was set above) | ||||
|     messages.append({"role": "system", "content": system_prompt}) | ||||
|     messages.extend(history_messages) | ||||
|     messages.append({"role": "user", "content": prompt}) | ||||
|  | ||||
|     # Add debug logging | ||||
|     logger.debug("===== Query Input to LLM =====") | ||||
|     logger.debug(f"Query: {prompt}") | ||||
|     logger.debug(f"System prompt: {system_prompt}") | ||||
|  | ||||
|     # Remove unsupported kwargs | ||||
|     kwargs = { | ||||
|         k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"] | ||||
|     } | ||||
|  | ||||
|     response = client.chat.completions.create(model=model, messages=messages, **kwargs) | ||||
|  | ||||
|     return response.choices[0].message.content | ||||
|  | ||||
|  | ||||
| async def zhipu_complete( | ||||
|     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs | ||||
| ): | ||||
|     # keyword_extraction comes from the explicit parameter; fall back to kwargs | ||||
|     keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None) | ||||
|  | ||||
|     if keyword_extraction: | ||||
|         # Add a system prompt to guide the model to return JSON format | ||||
|         extraction_prompt = """You are a helpful assistant that extracts keywords from text. | ||||
|         Please analyze the content and extract two types of keywords: | ||||
|         1. High-level keywords: Important concepts and main themes | ||||
|         2. Low-level keywords: Specific details and supporting elements | ||||
|  | ||||
|         Return your response in this exact JSON format: | ||||
|         { | ||||
|             "high_level_keywords": ["keyword1", "keyword2"], | ||||
|             "low_level_keywords": ["keyword1", "keyword2", "keyword3"] | ||||
|         } | ||||
|  | ||||
|         Only return the JSON, no other text.""" | ||||
|  | ||||
|         # Combine with existing system prompt if any | ||||
|         if system_prompt: | ||||
|             system_prompt = f"{system_prompt}\n\n{extraction_prompt}" | ||||
|         else: | ||||
|             system_prompt = extraction_prompt | ||||
|  | ||||
|         try: | ||||
|             response = await zhipu_complete_if_cache( | ||||
|                 prompt=prompt, | ||||
|                 system_prompt=system_prompt, | ||||
|                 history_messages=history_messages, | ||||
|                 **kwargs, | ||||
|             ) | ||||
|  | ||||
|             # Try to parse as JSON | ||||
|             try: | ||||
|                 data = json.loads(response) | ||||
|                 return GPTKeywordExtractionFormat( | ||||
|                     high_level_keywords=data.get("high_level_keywords", []), | ||||
|                     low_level_keywords=data.get("low_level_keywords", []), | ||||
|                 ) | ||||
|             except json.JSONDecodeError: | ||||
|                 # If direct JSON parsing fails, try to extract JSON from text | ||||
|                 match = re.search(r"\{[\s\S]*\}", response) | ||||
|                 if match: | ||||
|                     try: | ||||
|                         data = json.loads(match.group()) | ||||
|                         return GPTKeywordExtractionFormat( | ||||
|                             high_level_keywords=data.get("high_level_keywords", []), | ||||
|                             low_level_keywords=data.get("low_level_keywords", []), | ||||
|                         ) | ||||
|                     except json.JSONDecodeError: | ||||
|                         pass | ||||
|  | ||||
|                 # If all parsing fails, log warning and return empty format | ||||
|                 logger.warning( | ||||
|                     f"Failed to parse keyword extraction response: {response}" | ||||
|                 ) | ||||
|                 return GPTKeywordExtractionFormat( | ||||
|                     high_level_keywords=[], low_level_keywords=[] | ||||
|                 ) | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error during keyword extraction: {str(e)}") | ||||
|             return GPTKeywordExtractionFormat( | ||||
|                 high_level_keywords=[], low_level_keywords=[] | ||||
|             ) | ||||
|     else: | ||||
|         # For non-keyword-extraction, just return the raw response string | ||||
|         return await zhipu_complete_if_cache( | ||||
|             prompt=prompt, | ||||
|             system_prompt=system_prompt, | ||||
|             history_messages=history_messages, | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) | ||||
| @retry( | ||||
|     stop=stop_after_attempt(3), | ||||
|     wait=wait_exponential(multiplier=1, min=4, max=60), | ||||
|     retry=retry_if_exception_type( | ||||
|         (RateLimitError, APIConnectionError, APITimeoutError) | ||||
|     ), | ||||
| ) | ||||
| async def zhipu_embedding( | ||||
|     texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs | ||||
| ) -> np.ndarray: | ||||
|     # dynamically load ZhipuAI | ||||
|     try: | ||||
|         from zhipuai import ZhipuAI | ||||
|     except ImportError: | ||||
|         raise ImportError("Please install zhipuai before initialize zhipuai backend.") | ||||
|     if api_key: | ||||
|         client = ZhipuAI(api_key=api_key) | ||||
|     else: | ||||
|         # please set ZHIPUAI_API_KEY in your environment | ||||
|         # os.environ["ZHIPUAI_API_KEY"] | ||||
|         client = ZhipuAI() | ||||
|  | ||||
|     # Convert single text to list if needed | ||||
|     if isinstance(texts, str): | ||||
|         texts = [texts] | ||||
|  | ||||
|     embeddings = [] | ||||
|     for text in texts: | ||||
|         try: | ||||
|             response = client.embeddings.create(model=model, input=[text], **kwargs) | ||||
|             embeddings.append(response.data[0].embedding) | ||||
|         except Exception as e: | ||||
|             raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") | ||||
|  | ||||
|     return np.array(embeddings) | ||||
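|  | ||||
| # Usage sketch (illustrative): with ZHIPUAI_API_KEY exported, the coroutine | ||||
| # can be driven from synchronous code; the shape assumes the default | ||||
| # embedding-3 model declared above (1024 dimensions): | ||||
| # | ||||
| #     import asyncio | ||||
| #     vectors = asyncio.run(zhipu_embedding(["hello", "world"])) | ||||
| #     print(vectors.shape)  # expected: (2, 1024) | ||||
|  | ||||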
| @@ -33,15 +33,31 @@ from .base import ( | ||||
|     QueryParam, | ||||
| ) | ||||
|  | ||||
| STORAGES = { | ||||
|     "NetworkXStorage": ".kg.networkx_impl", | ||||
|     "JsonKVStorage": ".kg.json_kv_impl", | ||||
|     "NanoVectorDBStorage": ".kg.nano_vector_db_impl", | ||||
|     "JsonDocStatusStorage": ".kg.jsondocstatus_impl", | ||||
|     "Neo4JStorage": ".kg.neo4j_impl", | ||||
|     "OracleKVStorage": ".kg.oracle_impl", | ||||
|     "OracleGraphStorage": ".kg.oracle_impl", | ||||
|     "OracleVectorDBStorage": ".kg.oracle_impl", | ||||
|     "MilvusVectorDBStorge": ".kg.milvus_impl", | ||||
|     "MongoKVStorage": ".kg.mongo_impl", | ||||
|     "MongoGraphStorage": ".kg.mongo_impl", | ||||
|     "RedisKVStorage": ".kg.redis_impl", | ||||
|     "ChromaVectorDBStorage": ".kg.chroma_impl", | ||||
|     "TiDBKVStorage": ".kg.tidb_impl", | ||||
|     "TiDBVectorDBStorage": ".kg.tidb_impl", | ||||
|     "TiDBGraphStorage": ".kg.tidb_impl", | ||||
|     "PGKVStorage": ".kg.postgres_impl", | ||||
|     "PGVectorStorage": ".kg.postgres_impl", | ||||
|     "AGEStorage": ".kg.age_impl", | ||||
|     "PGGraphStorage": ".kg.postgres_impl", | ||||
|     "GremlinStorage": ".kg.gremlin_impl", | ||||
|     "PGDocStatusStorage": ".kg.postgres_impl", | ||||
| } | ||||
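|  | ||||
| # Usage sketch (illustrative): each entry maps a storage class name to the | ||||
| # module defining it, so a backend is imported only when selected. A minimal | ||||
| # configuration (working_dir is a hypothetical path) might look like: | ||||
| # | ||||
| #     rag = MiniRAG( | ||||
| #         working_dir="./rag_cache", | ||||
| #         kv_storage="JsonKVStorage", | ||||
| #         vector_storage="NanoVectorDBStorage", | ||||
| #         graph_storage="NetworkXStorage", | ||||
| #     ) | ||||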
|  | ||||
| # future KG integrations | ||||
|  | ||||
| @@ -49,17 +65,50 @@ from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBS | ||||
| #     GraphStorage as ArangoDBStorage | ||||
| # ) | ||||
|  | ||||
| def lazy_external_import(module_name: str, class_name: str): | ||||
|     """Lazily import a class from an external module based on the package of the caller.""" | ||||
|  | ||||
|     # Get the caller's module and package | ||||
|     import inspect | ||||
|  | ||||
|     caller_frame = inspect.currentframe().f_back | ||||
|     module = inspect.getmodule(caller_frame) | ||||
|     package = module.__package__ if module else None | ||||
|  | ||||
|     def import_class(*args, **kwargs): | ||||
|         import importlib | ||||
|  | ||||
|         module = importlib.import_module(module_name, package=package) | ||||
|         cls = getattr(module, class_name) | ||||
|         return cls(*args, **kwargs) | ||||
|  | ||||
|     return import_class | ||||
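|  | ||||
| # Usage sketch: the returned factory defers the real import until first | ||||
| # instantiation, so optional backends never load at package import time | ||||
| # (the arguments shown are illustrative): | ||||
| # | ||||
| #     Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") | ||||
| #     storage = Neo4JStorage(namespace="chunks", global_config=cfg)  # imports here | ||||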
|  | ||||
|  | ||||
| def always_get_an_event_loop() -> asyncio.AbstractEventLoop: | ||||
|     """ | ||||
|     Ensure that there is always an event loop available. | ||||
|  | ||||
|     This function tries to get the current event loop. If the current event loop is closed or does not exist, | ||||
|     it creates a new event loop and sets it as the current event loop. | ||||
|  | ||||
|     Returns: | ||||
|         asyncio.AbstractEventLoop: The current or newly created event loop. | ||||
|     """ | ||||
|     try: | ||||
|         # Try to get the current event loop | ||||
|         current_loop = asyncio.get_event_loop() | ||||
|         if current_loop.is_closed(): | ||||
|             raise RuntimeError("Event loop is closed.") | ||||
|         return current_loop | ||||
|  | ||||
|     except RuntimeError: | ||||
|         # If no event loop exists or it is closed, create a new one | ||||
|         logger.info("Creating a new event loop in main thread.") | ||||
|         new_loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(new_loop) | ||||
|         return new_loop | ||||
|  | ||||
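| # Usage sketch: the synchronous wrappers below (e.g. insert) rely on this | ||||
| # helper so they keep working even after a previous loop was closed: | ||||
| # | ||||
| #     loop = always_get_an_event_loop() | ||||
| #     loop.run_until_complete(coro())  # coro: any awaitable entry point | ||||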
|  | ||||
| @dataclass | ||||
| @@ -100,12 +149,12 @@ class MiniRAG: | ||||
|         } | ||||
|     ) | ||||
|  | ||||
|     embedding_func: EmbeddingFunc = None | ||||
|     embedding_batch_num: int = 32 | ||||
|     embedding_func_max_async: int = 16 | ||||
|  | ||||
|     # LLM | ||||
|     llm_model_func: callable = None | ||||
|     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # alternatives: 'meta-llama/Llama-3.2-1B', 'google/gemma-2-2b-it' | ||||
|     llm_model_max_token_size: int = 32768 | ||||
|     llm_model_max_async: int = 16 | ||||
| @@ -120,27 +169,55 @@ class MiniRAG: | ||||
|     addon_params: dict = field(default_factory=dict) | ||||
|     convert_response_to_json_func: callable = convert_response_to_json | ||||
|  | ||||
|     # Add new field for document status storage type | ||||
|     doc_status_storage: str = field(default="JsonDocStatusStorage") | ||||
|  | ||||
|     # Custom Chunking Function | ||||
|     chunking_func: callable = chunking_by_token_size | ||||
|     chunking_func_kwargs: dict = field(default_factory=dict) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         log_file = os.path.join(self.working_dir, "minirag.log") | ||||
|         set_logger(log_file) | ||||
|         logger.setLevel(self.log_level) | ||||
|  | ||||
|         logger.info(f"Logger initialized for working directory: {self.working_dir}") | ||||
|         if not os.path.exists(self.working_dir): | ||||
|             logger.info(f"Creating working directory {self.working_dir}") | ||||
|             os.makedirs(self.working_dir) | ||||
|  | ||||
|         # show config | ||||
|         global_config = asdict(self) | ||||
|         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()]) | ||||
|         logger.debug(f"MiniRAG init with param:\n  {_print_config}\n") | ||||
|  | ||||
|         # @TODO: should move all storage setup here to leverage initial start params attached to self. | ||||
|  | ||||
|         self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( | ||||
|             self._get_storage_class(self.kv_storage) | ||||
|         ) | ||||
|         self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( | ||||
|             self.vector_storage | ||||
|         ) | ||||
|         self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( | ||||
|             self.graph_storage | ||||
|         ) | ||||
|  | ||||
|         self.key_string_value_json_storage_cls = partial( | ||||
|             self.key_string_value_json_storage_cls, global_config=global_config | ||||
|         ) | ||||
|  | ||||
|         self.vector_db_storage_cls = partial( | ||||
|             self.vector_db_storage_cls, global_config=global_config | ||||
|         ) | ||||
|  | ||||
|         self.graph_storage_cls = partial( | ||||
|             self.graph_storage_cls, global_config=global_config | ||||
|         ) | ||||
|         self.json_doc_status_storage = self.key_string_value_json_storage_cls( | ||||
|             namespace="json_doc_status_storage", | ||||
|             embedding_func=None, | ||||
|         ) | ||||
|  | ||||
|         if not os.path.exists(self.working_dir): | ||||
|             logger.info(f"Creating working directory {self.working_dir}") | ||||
| @@ -218,21 +295,39 @@ class MiniRAG: | ||||
|                 **self.llm_model_kwargs, | ||||
|             ) | ||||
|         ) | ||||
|         # Initialize document status storage | ||||
|         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) | ||||
|         self.doc_status = self.doc_status_storage_cls( | ||||
|             namespace="doc_status", | ||||
|             global_config=global_config, | ||||
|             embedding_func=None, | ||||
|         ) | ||||
|  | ||||
|     def _get_storage_class(self, storage_name: str) -> Type: | ||||
|         import_path = STORAGES[storage_name] | ||||
|         storage_class = lazy_external_import(import_path, storage_name) | ||||
|         return storage_class | ||||
|  | ||||
|     def set_storage_client(self, db_client): | ||||
|         # Now only tested on Oracle Database | ||||
|         for storage in [ | ||||
|             self.vector_db_storage_cls, | ||||
|             self.graph_storage_cls, | ||||
|             self.doc_status, | ||||
|             self.full_docs, | ||||
|             self.text_chunks, | ||||
|             self.llm_response_cache, | ||||
|             self.key_string_value_json_storage_cls, | ||||
|             self.chunks_vdb, | ||||
|             self.relationships_vdb, | ||||
|             self.entities_vdb, | ||||
|             self.chunk_entity_relation_graph, | ||||
|         ]: | ||||
|             # set client | ||||
|             storage.db = db_client | ||||
|  | ||||
|  | ||||
|     def insert(self, string_or_strings): | ||||
|         loop = always_get_an_event_loop() | ||||
|   | ||||
| @@ -1,354 +0,0 @@ | ||||
| import asyncio | ||||
| import html | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Union, cast | ||||
| import networkx as nx | ||||
| import numpy as np | ||||
| from nano_vectordb import NanoVectorDB | ||||
| import copy | ||||
| from .utils import ( | ||||
|     logger, | ||||
|     load_json, | ||||
|     write_json, | ||||
|     compute_mdhash_id, | ||||
|     merge_tuples, | ||||
| ) | ||||
|  | ||||
| from .base import ( | ||||
|     BaseGraphStorage, | ||||
|     BaseKVStorage, | ||||
|     BaseVectorStorage, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class JsonKVStorage(BaseKVStorage): | ||||
|     def __post_init__(self): | ||||
|         working_dir = self.global_config["working_dir"] | ||||
|         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") | ||||
|         self._data = load_json(self._file_name) or {} | ||||
|         logger.info(f"Load KV {self.namespace} with {len(self._data)} data") | ||||
|  | ||||
|     async def all_keys(self) -> list[str]: | ||||
|         return list(self._data.keys()) | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         write_json(self._data, self._file_name) | ||||
|  | ||||
|     async def get_by_id(self, id): | ||||
|         return self._data.get(id, None) | ||||
|  | ||||
|     async def get_by_ids(self, ids, fields=None): | ||||
|         if fields is None: | ||||
|             return [self._data.get(id, None) for id in ids] | ||||
|         return [ | ||||
|             ( | ||||
|                 {k: v for k, v in self._data[id].items() if k in fields} | ||||
|                 if self._data.get(id, None) | ||||
|                 else None | ||||
|             ) | ||||
|             for id in ids | ||||
|         ] | ||||
|  | ||||
|     async def filter_keys(self, data: list[str]) -> set[str]: | ||||
|         return set([s for s in data if s not in self._data]) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         left_data = {k: v for k, v in data.items() if k not in self._data} | ||||
|         self._data.update(left_data) | ||||
|         return left_data | ||||
|  | ||||
|     async def drop(self): | ||||
|         self._data = {} | ||||
|  | ||||
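| # Usage sketch (illustrative; constructor fields follow the BaseKVStorage | ||||
| # dataclass): data persists to <working_dir>/kv_store_<namespace>.json once | ||||
| # index_done_callback runs: | ||||
| # | ||||
| #     kv = JsonKVStorage( | ||||
| #         namespace="full_docs", | ||||
| #         global_config={"working_dir": "./cache"}, | ||||
| #         embedding_func=None, | ||||
| #     ) | ||||
| #     await kv.upsert({"doc-1": {"content": "hello"}}) | ||||
| #     await kv.index_done_callback()  # writes ./cache/kv_store_full_docs.json | ||||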
|  | ||||
| @dataclass | ||||
| class NanoVectorDBStorage(BaseVectorStorage): | ||||
|     cosine_better_than_threshold: float = 0.0  # was 0.2 | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._client_file_name = os.path.join( | ||||
|             self.global_config["working_dir"], f"vdb_{self.namespace}.json" | ||||
|         ) | ||||
|         self._max_batch_size = self.global_config["embedding_batch_num"] | ||||
|         self._client = NanoVectorDB( | ||||
|             self.embedding_func.embedding_dim, storage_file=self._client_file_name | ||||
|         ) | ||||
|         self.cosine_better_than_threshold = self.global_config.get( | ||||
|             "cosine_better_than_threshold", self.cosine_better_than_threshold | ||||
|         ) | ||||
|  | ||||
|     async def upsert(self, data: dict[str, dict]): | ||||
|         logger.info(f"Inserting {len(data)} vectors to {self.namespace}") | ||||
|         if not data: | ||||
|             logger.warning("You are inserting empty data into the vector DB") | ||||
|             return [] | ||||
|         list_data = [ | ||||
|             { | ||||
|                 "__id__": k, | ||||
|                 **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, | ||||
|             } | ||||
|             for k, v in data.items() | ||||
|         ] | ||||
|         contents = [v["content"] for v in data.values()] | ||||
|         batches = [ | ||||
|             contents[i : i + self._max_batch_size] | ||||
|             for i in range(0, len(contents), self._max_batch_size) | ||||
|         ] | ||||
|         embeddings_list = await asyncio.gather( | ||||
|             *[self.embedding_func(batch) for batch in batches] | ||||
|         ) | ||||
|         embeddings = np.concatenate(embeddings_list) | ||||
|         for i, d in enumerate(list_data): | ||||
|             d["__vector__"] = embeddings[i] | ||||
|         results = self._client.upsert(datas=list_data) | ||||
|         return results | ||||
|  | ||||
|     async def query(self, query: str, top_k=5): | ||||
|         embedding = await self.embedding_func([query]) | ||||
|         embedding = embedding[0] | ||||
|         results = self._client.query( | ||||
|             query=embedding, | ||||
|             top_k=top_k, | ||||
|             better_than_threshold=self.cosine_better_than_threshold, | ||||
|         ) | ||||
|  | ||||
|         results = [ | ||||
|             {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results | ||||
|         ] | ||||
|         return results | ||||
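|  | ||||
|     # Usage sketch: query() embeds the text once, filters hits by the cosine | ||||
|     # threshold, and unpacks NanoVectorDB's internal fields into id/distance | ||||
|     # (vdb stands for any constructed instance of this class): | ||||
|     # | ||||
|     #     hits = await vdb.query("renewable energy", top_k=3) | ||||
|     #     for hit in hits: | ||||
|     #         print(hit["id"], hit["distance"]) | ||||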
|  | ||||
|     @property | ||||
|     def client_storage(self): | ||||
|         return getattr(self._client, "_NanoVectorDB__storage") | ||||
|  | ||||
|     async def delete_entity(self, entity_name: str): | ||||
|         try: | ||||
|             entity_id = [compute_mdhash_id(entity_name, prefix="ent-")] | ||||
|  | ||||
|             if self._client.get(entity_id): | ||||
|                 self._client.delete(entity_id) | ||||
|                 logger.info(f"Entity {entity_name} have been deleted.") | ||||
|             else: | ||||
|                 logger.info(f"No entity found with name {entity_name}.") | ||||
|         except Exception as e: | ||||
|             logger.error(f"Error while deleting entity {entity_name}: {e}") | ||||
|  | ||||
|     async def delete_relation(self, entity_name: str): | ||||
|         try: | ||||
|             relations = [ | ||||
|                 dp | ||||
|                 for dp in self.client_storage["data"] | ||||
|                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name | ||||
|             ] | ||||
|             ids_to_delete = [relation["__id__"] for relation in relations] | ||||
|  | ||||
|             if ids_to_delete: | ||||
|                 self._client.delete(ids_to_delete) | ||||
|                 logger.info( | ||||
|                     f"All relations related to entity {entity_name} have been deleted." | ||||
|                 ) | ||||
|             else: | ||||
|                 logger.info(f"No relations found for entity {entity_name}.") | ||||
|         except Exception as e: | ||||
|             logger.error( | ||||
|                 f"Error while deleting relations for entity {entity_name}: {e}" | ||||
|             ) | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         self._client.save() | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class NetworkXStorage(BaseGraphStorage): | ||||
|     @staticmethod | ||||
|     def load_nx_graph(file_name) -> nx.Graph: | ||||
|         if os.path.exists(file_name): | ||||
|             return nx.read_graphml(file_name) | ||||
|         return None | ||||
|  | ||||
|     @staticmethod | ||||
|     def write_nx_graph(graph: nx.Graph, file_name): | ||||
|         logger.info( | ||||
|             f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges" | ||||
|         ) | ||||
|         nx.write_graphml(graph, file_name) | ||||
|  | ||||
|     @staticmethod | ||||
|     def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: | ||||
|         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py | ||||
|         Return the largest connected component of the graph, with nodes and edges sorted in a stable way. | ||||
|         """ | ||||
|         from graspologic.utils import largest_connected_component | ||||
|  | ||||
|         graph = graph.copy() | ||||
|         graph = cast(nx.Graph, largest_connected_component(graph)) | ||||
|         node_mapping = { | ||||
|             node: html.unescape(node.upper().strip()) for node in graph.nodes() | ||||
|         }  # type: ignore | ||||
|         graph = nx.relabel_nodes(graph, node_mapping) | ||||
|         return NetworkXStorage._stabilize_graph(graph) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _stabilize_graph(graph: nx.Graph) -> nx.Graph: | ||||
|         """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py | ||||
|         Ensure an undirected graph with the same relationships will always be read the same way. | ||||
|         """ | ||||
|         fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph() | ||||
|  | ||||
|         sorted_nodes = graph.nodes(data=True) | ||||
|         sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0]) | ||||
|  | ||||
|         fixed_graph.add_nodes_from(sorted_nodes) | ||||
|         edges = list(graph.edges(data=True)) | ||||
|  | ||||
|         if not graph.is_directed(): | ||||
|  | ||||
|             def _sort_source_target(edge): | ||||
|                 source, target, edge_data = edge | ||||
|                 if source > target: | ||||
|                     source, target = target, source | ||||
|                 return source, target, edge_data | ||||
|  | ||||
|             edges = [_sort_source_target(edge) for edge in edges] | ||||
|  | ||||
|         def _get_edge_key(source: Any, target: Any) -> str: | ||||
|             return f"{source} -> {target}" | ||||
|  | ||||
|         edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) | ||||
|  | ||||
|         fixed_graph.add_edges_from(edges) | ||||
|         return fixed_graph | ||||
|  | ||||
|     def __post_init__(self): | ||||
|         self._graphml_xml_file = os.path.join( | ||||
|             self.global_config["working_dir"], f"graph_{self.namespace}.graphml" | ||||
|         ) | ||||
|         preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) | ||||
|         if preloaded_graph is not None: | ||||
|             logger.info( | ||||
|                 f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges" | ||||
|             ) | ||||
|         self._graph = preloaded_graph or nx.Graph() | ||||
|         self._node_embed_algorithms = { | ||||
|             "node2vec": self._node2vec_embed, | ||||
|         } | ||||
|  | ||||
|     async def index_done_callback(self): | ||||
|         NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file) | ||||
|  | ||||
|     async def has_node(self, node_id: str) -> bool: | ||||
|         return self._graph.has_node(node_id) | ||||
|  | ||||
|     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: | ||||
|         return self._graph.has_edge(source_node_id, target_node_id) | ||||
|  | ||||
|     async def get_node(self, node_id: str) -> Union[dict, None]: | ||||
|         return self._graph.nodes.get(node_id) | ||||
|  | ||||
|     async def get_types(self) -> tuple: | ||||
|         """Collect all entity types plus up to two example node names per type.""" | ||||
|         all_entity_type = [] | ||||
|         all_type_w_name = {} | ||||
|         for node_name, attr in self._graph.nodes(data=True): | ||||
|             key = attr["entity_type"].strip('"') | ||||
|             all_entity_type.append(key) | ||||
|             if key not in all_type_w_name: | ||||
|                 all_type_w_name[key] = [node_name.strip('"')] | ||||
|             elif len(all_type_w_name[key]) <= 1: | ||||
|                 all_type_w_name[key].append(node_name.strip('"')) | ||||
|  | ||||
|         return list(set(all_entity_type)), all_type_w_name | ||||
|  | ||||
|     async def get_node_from_types(self, type_list) -> list: | ||||
|         node_list = [] | ||||
|         for name, attr in self._graph.nodes(data=True): | ||||
|             node_type = attr.get("entity_type", "").strip('"') | ||||
|             if node_type in type_list: | ||||
|                 node_list.append(name) | ||||
|         node_datas = await asyncio.gather( | ||||
|             *[self.get_node(name) for name in node_list] | ||||
|         ) | ||||
|         node_datas = [ | ||||
|             {**n, "entity_name": k} | ||||
|             for k, n in zip(node_list, node_datas) | ||||
|             if n is not None | ||||
|         ] | ||||
|         return node_datas | ||||
|  | ||||
|     async def get_neighbors_within_k_hops(self, source_node_id: str, k): | ||||
|         count = 0 | ||||
|         if await self.has_node(source_node_id): | ||||
|             source_edge = list(self._graph.edges(source_node_id)) | ||||
|         else: | ||||
|             logger.warning(f"No node found with id {source_node_id}.") | ||||
|             return [] | ||||
|         count = count + 1 | ||||
|         while count < k: | ||||
|             count = count + 1 | ||||
|             sc_edge = copy.deepcopy(source_edge) | ||||
|             source_edge = [] | ||||
|             for pair in sc_edge: | ||||
|                 append_edge = list(self._graph.edges(pair[-1])) | ||||
|                 for tuples in merge_tuples([pair], append_edge): | ||||
|                     source_edge.append(tuples) | ||||
|         return source_edge | ||||
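|  | ||||
|     # Illustrative note: each pass extends every frontier tuple by the edges | ||||
|     # of its last node, so with k=2 a path A -> B -> C surfaces as the merged | ||||
|     # tuple ("A", "B", "C") produced by merge_tuples. | ||||
|  | ||||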
|     async def node_degree(self, node_id: str) -> int: | ||||
|         return self._graph.degree(node_id) | ||||
|  | ||||
|     async def edge_degree(self, src_id: str, tgt_id: str) -> int: | ||||
|         return self._graph.degree(src_id) + self._graph.degree(tgt_id) | ||||
|  | ||||
|     async def get_edge( | ||||
|         self, source_node_id: str, target_node_id: str | ||||
|     ) -> Union[dict, None]: | ||||
|         return self._graph.edges.get((source_node_id, target_node_id)) | ||||
|  | ||||
|     async def get_node_edges(self, source_node_id: str): | ||||
|         if self._graph.has_node(source_node_id): | ||||
|             return list(self._graph.edges(source_node_id)) | ||||
|         return None | ||||
|  | ||||
|     async def upsert_node(self, node_id: str, node_data: dict[str, str]): | ||||
|         self._graph.add_node(node_id, **node_data) | ||||
|  | ||||
|     async def upsert_edge( | ||||
|         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] | ||||
|     ): | ||||
|         self._graph.add_edge(source_node_id, target_node_id, **edge_data) | ||||
|  | ||||
|     async def delete_node(self, node_id: str): | ||||
|         """ | ||||
|         Delete a node from the graph based on the specified node_id. | ||||
|  | ||||
|         :param node_id: The node_id to delete | ||||
|         """ | ||||
|         if self._graph.has_node(node_id): | ||||
|             self._graph.remove_node(node_id) | ||||
|             logger.info(f"Node {node_id} deleted from the graph.") | ||||
|         else: | ||||
|             logger.warning(f"Node {node_id} not found in the graph for deletion.") | ||||
|  | ||||
|     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: | ||||
|         if algorithm not in self._node_embed_algorithms: | ||||
|             raise ValueError(f"Node embedding algorithm {algorithm} not supported") | ||||
|         return await self._node_embed_algorithms[algorithm]() | ||||
|  | ||||
|     # @TODO: NOT USED | ||||
|     async def _node2vec_embed(self): | ||||
|         from graspologic import embed | ||||
|  | ||||
|         embeddings, nodes = embed.node2vec_embed( | ||||
|             self._graph, | ||||
|             **self.global_config["node2vec_params"], | ||||
|         ) | ||||
|  | ||||
|         nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes] | ||||
|         return embeddings, nodes_ids | ||||
| @@ -1,38 +1,33 @@ | ||||
| accelerate | ||||
| aioboto3 | ||||
| aiofiles | ||||
| aiohttp | ||||
| asyncpg | ||||
| configparser | ||||
|  | ||||
| # database packages | ||||
| graspologic | ||||
| gremlinpython | ||||
| hnswlib | ||||
| nano-vectordb | ||||
| neo4j | ||||
| networkx | ||||
|  | ||||
| # Basic modules | ||||
| numpy | ||||
| ollama | ||||
| openai | ||||
| oracledb | ||||
| psycopg-pool | ||||
| psycopg[binary,pool] | ||||
| pipmaster | ||||
| pydantic | ||||
| pymilvus | ||||
| pymongo | ||||
| pymysql | ||||
|  | ||||
| # File manipulation libraries | ||||
| PyPDF2 | ||||
| python-docx | ||||
| python-dotenv | ||||
| pyvis | ||||
| python-pptx | ||||
|  | ||||
| setuptools | ||||
| # lmdeploy[all] | ||||
| sqlalchemy | ||||
| tenacity | ||||
| json_repair | ||||
| rouge | ||||
| nltk | ||||
|  | ||||
|  | ||||
| # LLM packages | ||||
| tiktoken | ||||
| torch | ||||
| tqdm | ||||
| transformers | ||||
| xxhash | ||||
|  | ||||
| # Extra libraries are installed when needed using pipmaster | ||||
|   | ||||
19  setup.py
							| @@ -22,6 +22,15 @@ def read_requirements(): | ||||
|         ) | ||||
|     return deps | ||||
|  | ||||
| def read_api_requirements(): | ||||
|     api_deps = [] | ||||
|     try: | ||||
|         with open("./minirag/api/requirements.txt") as f: | ||||
|             api_deps = [line.strip() for line in f if line.strip()] | ||||
|     except FileNotFoundError: | ||||
|         print("Warning: API requirements.txt not found.") | ||||
|     return api_deps | ||||
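|  | ||||
| # With the "api" extra declared below, server dependencies stay optional and | ||||
| # install on demand: | ||||
| # | ||||
| #     pip install minirag-hku[api] | ||||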
|  | ||||
|  | ||||
| long_description = read_long_description() | ||||
| requirements = read_requirements() | ||||
| @@ -29,7 +38,7 @@ requirements = read_requirements() | ||||
| setuptools.setup( | ||||
|     name="minirag-hku", | ||||
|     url="https://github.com/HKUDS/MiniRAG", | ||||
|     version="0.0.1", | ||||
|     version="0.0.2", | ||||
|     author="HKUDS", | ||||
|     description="MiniRAG: Towards Extremely Simple Retrieval-Augmented Generation", | ||||
|     long_description=long_description, | ||||
| @@ -48,4 +57,12 @@ setuptools.setup( | ||||
|     python_requires=">=3.9",#rec: 3.9.19 | ||||
|     install_requires=requirements, | ||||
|     include_package_data=True,  # Includes non-code files from MANIFEST.in | ||||
|     extras_require={ | ||||
|         "api": read_api_requirements(),  # API requirements as optional | ||||
|     }, | ||||
|     entry_points={ | ||||
|         "console_scripts": [ | ||||
|             "minirag-server=minirag.api.minirag_server:main [api]", | ||||
|         ], | ||||
|     }, | ||||
| ) | ||||