fix bugs and update readme
This commit is contained in:
@@ -16,6 +16,13 @@ The Code Repository: **MiniRAG: Towards Extremely Simple Retrieval-Augmented Gen
|
||||
|
||||
[中文说明](./README_CN.md) | [日本語](./README_JA.md)
|
||||
|
||||
|
||||
|
||||
## 🎉 News
|
||||
- [x] [2025.02.14]🎯📢Now MiniRAG supports 10+ heterogeneous graph databases, including Neo4j, PostgreSQL, TiDB, etc. Happy valentine's day!🌹🌹🌹
|
||||
- [x] [2025.02.05]🎯📢Our team has released [VideoRAG](https://github.com/HKUDS/VideoRAG) understanding extremely long-context videos.
|
||||
- [x] [2025.02.01]🎯📢Now MiniRAG supports API&Docker deployment. see [This](./minirag/api/README.md) for more details.
|
||||
|
||||
## TLDR
|
||||
MiniRAG is an extremely simple retrieval-augmented generation framework that enables small models to achieve good RAG performance through heterogeneous graph indexing and lightweight topology-enhanced retrieval.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ if not pm.is_installed("psycopg-pool"):
|
||||
if not pm.is_installed("asyncpg"):
|
||||
pm.install("asyncpg")
|
||||
|
||||
|
||||
import copy
|
||||
import psycopg
|
||||
from psycopg.rows import namedtuple_row
|
||||
from psycopg_pool import AsyncConnectionPool, PoolTimeout
|
||||
@@ -25,8 +25,8 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
|
||||
from minirag.utils import logger
|
||||
from minirag.utils import merge_tuples
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
@@ -555,7 +555,52 @@ class AGEStorage(BaseGraphStorage):
|
||||
except Exception as e:
|
||||
logger.error("Error during upsert: {%s}", e)
|
||||
raise
|
||||
async def get_types(self):
|
||||
types = set()
|
||||
types_with_case = set()
|
||||
|
||||
for _, data in self._graph.nodes(data=True):
|
||||
if "type" in data:
|
||||
types.add(data["type"].lower())
|
||||
types_with_case.add(data["type"])
|
||||
return list(types), list(types_with_case)
|
||||
|
||||
|
||||
async def get_node_from_types(self,type_list) -> Union[dict, None]:
|
||||
node_list = []
|
||||
for name, arrt in self._graph.nodes(data = True):
|
||||
node_type = arrt.get('entity_type').strip('\"')
|
||||
if node_type in type_list:
|
||||
node_list.append(name)
|
||||
node_datas = await asyncio.gather(
|
||||
*[self.get_node(name) for name in node_list]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k}
|
||||
for k, n in zip(node_list, node_datas)
|
||||
if n is not None
|
||||
]
|
||||
return node_datas#,node_dict
|
||||
|
||||
|
||||
async def get_neighbors_within_k_hops(self,source_node_id: str, k):
|
||||
count = 0
|
||||
if await self.has_node(source_node_id):
|
||||
source_edge = list(self._graph.edges(source_node_id))
|
||||
else:
|
||||
print("NO THIS ID:",source_node_id)
|
||||
return []
|
||||
count = count+1
|
||||
while count<k:
|
||||
count = count+1
|
||||
sc_edge = copy.deepcopy(source_edge)
|
||||
source_edge =[]
|
||||
for pair in sc_edge:
|
||||
append_edge = list(self._graph.edges(pair[-1]))
|
||||
for tuples in merge_tuples([pair],append_edge):
|
||||
source_edge.append(tuples)
|
||||
return source_edge
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
||||
@@ -5,8 +5,10 @@ from typing import Union
|
||||
import numpy as np
|
||||
from chromadb import HttpClient
|
||||
from chromadb.config import Settings
|
||||
from lightrag.base import BaseVectorStorage
|
||||
from lightrag.utils import logger
|
||||
from minirag.base import BaseVectorStorage
|
||||
from minirag.utils import logger
|
||||
from minirag.utils import merge_tuples
|
||||
import copy
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -15,8 +15,9 @@ from tenacity import (
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
|
||||
from minirag.utils import logger
|
||||
import copy
|
||||
from minirag.utils import merge_tuples
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
|
||||
@@ -309,6 +310,42 @@ class GremlinStorage(BaseGraphStorage):
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
async def get_node_from_types(self,type_list) -> Union[dict, None]:
|
||||
node_list = []
|
||||
for name, arrt in self._graph.nodes(data = True):
|
||||
node_type = arrt.get('entity_type').strip('\"')
|
||||
if node_type in type_list:
|
||||
node_list.append(name)
|
||||
node_datas = await asyncio.gather(
|
||||
*[self.get_node(name) for name in node_list]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k}
|
||||
for k, n in zip(node_list, node_datas)
|
||||
if n is not None
|
||||
]
|
||||
return node_datas#,node_dict
|
||||
|
||||
|
||||
async def get_neighbors_within_k_hops(self,source_node_id: str, k):
|
||||
count = 0
|
||||
if await self.has_node(source_node_id):
|
||||
source_edge = list(self._graph.edges(source_node_id))
|
||||
else:
|
||||
print("NO THIS ID:",source_node_id)
|
||||
return []
|
||||
count = count+1
|
||||
while count<k:
|
||||
count = count+1
|
||||
sc_edge = copy.deepcopy(source_edge)
|
||||
source_edge =[]
|
||||
for pair in sc_edge:
|
||||
append_edge = list(self._graph.edges(pair[-1]))
|
||||
for tuples in merge_tuples([pair],append_edge):
|
||||
source_edge.append(tuples)
|
||||
return source_edge
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
|
||||
@@ -22,9 +22,10 @@ from tenacity import (
|
||||
retry_if_exception_type,
|
||||
)
|
||||
|
||||
from lightrag.utils import logger
|
||||
from minirag.utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
|
||||
import copy
|
||||
from minirag.utils import merge_tuples
|
||||
|
||||
@dataclass
|
||||
class Neo4JStorage(BaseGraphStorage):
|
||||
@@ -152,6 +153,42 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
return node_dict
|
||||
return None
|
||||
|
||||
|
||||
async def get_node_from_types(self,type_list) -> Union[dict, None]:
|
||||
node_list = []
|
||||
for name, arrt in self._graph.nodes(data = True):
|
||||
node_type = arrt.get('entity_type').strip('\"')
|
||||
if node_type in type_list:
|
||||
node_list.append(name)
|
||||
node_datas = await asyncio.gather(
|
||||
*[self.get_node(name) for name in node_list]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k}
|
||||
for k, n in zip(node_list, node_datas)
|
||||
if n is not None
|
||||
]
|
||||
return node_datas#,node_dict
|
||||
|
||||
|
||||
async def get_neighbors_within_k_hops(self,source_node_id: str, k):
|
||||
count = 0
|
||||
if await self.has_node(source_node_id):
|
||||
source_edge = list(self._graph.edges(source_node_id))
|
||||
else:
|
||||
print("NO THIS ID:",source_node_id)
|
||||
return []
|
||||
count = count+1
|
||||
while count<k:
|
||||
count = count+1
|
||||
sc_edge = copy.deepcopy(source_edge)
|
||||
source_edge =[]
|
||||
for pair in sc_edge:
|
||||
append_edge = list(self._graph.edges(pair[-1]))
|
||||
for tuples in merge_tuples([pair],append_edge):
|
||||
source_edge.append(tuples)
|
||||
return source_edge
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
|
||||
@@ -44,17 +44,17 @@ Features:
|
||||
- Remove nodes and edges from the graph
|
||||
|
||||
Usage:
|
||||
from lightrag.storage.networkx_storage import NetworkXStorage
|
||||
from minirags.kg.networkx_impl import NetworkXStorage
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Union, cast
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
import copy
|
||||
|
||||
from minirag.utils import (
|
||||
logger,
|
||||
@@ -64,6 +64,7 @@ from minirag.base import (
|
||||
BaseGraphStorage,
|
||||
)
|
||||
|
||||
from minirag.utils import merge_tuples
|
||||
|
||||
@dataclass
|
||||
class NetworkXStorage(BaseGraphStorage):
|
||||
@@ -155,6 +156,43 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
types_with_case.add(data["type"])
|
||||
return list(types), list(types_with_case)
|
||||
|
||||
|
||||
async def get_node_from_types(self,type_list) -> Union[dict, None]:
|
||||
node_list = []
|
||||
for name, arrt in self._graph.nodes(data = True):
|
||||
node_type = arrt.get('entity_type').strip('\"')
|
||||
if node_type in type_list:
|
||||
node_list.append(name)
|
||||
node_datas = await asyncio.gather(
|
||||
*[self.get_node(name) for name in node_list]
|
||||
)
|
||||
node_datas = [
|
||||
{**n, "entity_name": k}
|
||||
for k, n in zip(node_list, node_datas)
|
||||
if n is not None
|
||||
]
|
||||
return node_datas#,node_dict
|
||||
|
||||
|
||||
async def get_neighbors_within_k_hops(self,source_node_id: str, k):
|
||||
count = 0
|
||||
if await self.has_node(source_node_id):
|
||||
source_edge = list(self._graph.edges(source_node_id))
|
||||
else:
|
||||
print("NO THIS ID:",source_node_id)
|
||||
return []
|
||||
count = count+1
|
||||
while count<k:
|
||||
count = count+1
|
||||
sc_edge = copy.deepcopy(source_edge)
|
||||
source_edge =[]
|
||||
for pair in sc_edge:
|
||||
append_edge = list(self._graph.edges(pair[-1]))
|
||||
for tuples in merge_tuples([pair],append_edge):
|
||||
source_edge.append(tuples)
|
||||
return source_edge
|
||||
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
return self._graph.has_node(node_id)
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ if not pm.is_installed("redis"):
|
||||
|
||||
# aioredis is a depricated library, replaced with redis
|
||||
from redis.asyncio import Redis
|
||||
from lightrag.utils import logger
|
||||
from lightrag.base import BaseKVStorage
|
||||
from minirag.utils import logger
|
||||
from minirag.base import BaseKVStorage
|
||||
import json
|
||||
import copy
|
||||
from minirag.utils import merge_tuples
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,665 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("pymysql"):
|
||||
pm.install("pymysql")
|
||||
if not pm.is_installed("sqlalchemy"):
|
||||
pm.install("sqlalchemy")
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from tqdm import tqdm
|
||||
|
||||
from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
|
||||
from lightrag.utils import logger
|
||||
|
||||
|
||||
class TiDB(object):
|
||||
def __init__(self, config, **kwargs):
|
||||
self.host = config.get("host", None)
|
||||
self.port = config.get("port", None)
|
||||
self.user = config.get("user", None)
|
||||
self.password = config.get("password", None)
|
||||
self.database = config.get("database", None)
|
||||
self.workspace = config.get("workspace", None)
|
||||
connection_string = (
|
||||
f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
f"?ssl_verify_cert=true&ssl_verify_identity=true"
|
||||
)
|
||||
|
||||
try:
|
||||
self.engine = create_engine(connection_string)
|
||||
logger.info(f"Connected to TiDB database at {self.database}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to TiDB database at {self.database}")
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
raise
|
||||
|
||||
async def check_tables(self):
|
||||
for k, v in TABLES.items():
|
||||
try:
|
||||
await self.query(f"SELECT 1 FROM {k}".format(k=k))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check table {k} in TiDB database")
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
try:
|
||||
# print(v["ddl"])
|
||||
await self.execute(v["ddl"])
|
||||
logger.info(f"Created table {k} in TiDB database")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create table {k} in TiDB database")
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
|
||||
async def query(
|
||||
self, sql: str, params: dict = None, multirows: bool = False
|
||||
) -> Union[dict, None]:
|
||||
if params is None:
|
||||
params = {"workspace": self.workspace}
|
||||
else:
|
||||
params.update({"workspace": self.workspace})
|
||||
with self.engine.connect() as conn, conn.begin():
|
||||
try:
|
||||
result = conn.execute(text(sql), params)
|
||||
except Exception as e:
|
||||
logger.error(f"Tidb database error: {e}")
|
||||
print(sql)
|
||||
print(params)
|
||||
raise
|
||||
if multirows:
|
||||
rows = result.all()
|
||||
if rows:
|
||||
data = [dict(zip(result.keys(), row)) for row in rows]
|
||||
else:
|
||||
data = []
|
||||
else:
|
||||
row = result.first()
|
||||
if row:
|
||||
data = dict(zip(result.keys(), row))
|
||||
else:
|
||||
data = None
|
||||
return data
|
||||
|
||||
async def execute(self, sql: str, data: list | dict = None):
|
||||
# logger.info("go into TiDBDB execute method")
|
||||
try:
|
||||
with self.engine.connect() as conn, conn.begin():
|
||||
if data is None:
|
||||
conn.execute(text(sql))
|
||||
else:
|
||||
conn.execute(text(sql), parameters=data)
|
||||
except Exception as e:
|
||||
logger.error(f"TiDB database error: {e}")
|
||||
print(sql)
|
||||
print(data)
|
||||
raise
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBKVStorage(BaseKVStorage):
|
||||
# should pass db object to self.db
|
||||
def __post_init__(self):
|
||||
self._data = {}
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
################ QUERY METHODS ################
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict, None]:
|
||||
"""根据 id 获取 doc_full 数据."""
|
||||
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
||||
params = {"id": id}
|
||||
# print("get_by_id:"+SQL)
|
||||
res = await self.db.query(SQL, params)
|
||||
if res:
|
||||
data = res # {"data":res}
|
||||
# print (data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
||||
"""根据 id 获取 doc_chunks 数据"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
)
|
||||
# print("get_by_ids:"+SQL)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
if res:
|
||||
data = res # [{"data":i} for i in res]
|
||||
# print(data)
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||
"""过滤掉重复内容"""
|
||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||
table_name=N_T[self.namespace],
|
||||
id_field=N_ID[self.namespace],
|
||||
ids=",".join([f"'{id}'" for id in keys]),
|
||||
)
|
||||
try:
|
||||
await self.db.query(SQL)
|
||||
except Exception as e:
|
||||
logger.error(f"Tidb database error: {e}")
|
||||
print(SQL)
|
||||
res = await self.db.query(SQL, multirows=True)
|
||||
if res:
|
||||
exist_keys = [key["id"] for key in res]
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
else:
|
||||
exist_keys = []
|
||||
data = set([s for s in keys if s not in exist_keys])
|
||||
return data
|
||||
|
||||
################ INSERT full_doc AND chunks ################
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
left_data = {k: v for k, v in data.items() if k not in self._data}
|
||||
self._data.update(left_data)
|
||||
if self.namespace == "text_chunks":
|
||||
list_data = [
|
||||
{
|
||||
"__id__": k,
|
||||
**{k1: v1 for k1, v1 in v.items()},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["__vector__"] = embeddings[i]
|
||||
|
||||
merge_sql = SQL_TEMPLATES["upsert_chunk"]
|
||||
data = []
|
||||
for item in list_data:
|
||||
data.append(
|
||||
{
|
||||
"id": item["__id__"],
|
||||
"content": item["content"],
|
||||
"tokens": item["tokens"],
|
||||
"chunk_order_index": item["chunk_order_index"],
|
||||
"full_doc_id": item["full_doc_id"],
|
||||
"content_vector": f"{item["__vector__"].tolist()}",
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
)
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
if self.namespace == "full_docs":
|
||||
merge_sql = SQL_TEMPLATES["upsert_doc_full"]
|
||||
data = []
|
||||
for k, v in self._data.items():
|
||||
data.append(
|
||||
{
|
||||
"id": k,
|
||||
"content": v["content"],
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
)
|
||||
await self.db.execute(merge_sql, data)
|
||||
return left_data
|
||||
|
||||
async def index_done_callback(self):
|
||||
if self.namespace in ["full_docs", "text_chunks"]:
|
||||
logger.info("full doc and chunk data had been saved into TiDB db!")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
||||
|
||||
def __post_init__(self):
|
||||
self._client_file_name = os.path.join(
|
||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
# Use global config value if specified, otherwise use default
|
||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||
self.cosine_better_than_threshold = config.get(
|
||||
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||
"""search from tidb vector"""
|
||||
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
|
||||
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
|
||||
|
||||
params = {
|
||||
"embedding_string": embedding_string,
|
||||
"top_k": top_k,
|
||||
"better_than_threshold": self.cosine_better_than_threshold,
|
||||
}
|
||||
|
||||
results = await self.db.query(
|
||||
SQL_TEMPLATES[self.namespace], params=params, multirows=True
|
||||
)
|
||||
print("vector search result:", results)
|
||||
if not results:
|
||||
return []
|
||||
return results
|
||||
|
||||
###### INSERT entities And relationships ######
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
# ignore, upsert in TiDBKVStorage already
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
if self.namespace == "chunks":
|
||||
return []
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
|
||||
list_data = [
|
||||
{
|
||||
"id": k,
|
||||
**{k1: v1 for k1, v1 in v.items()},
|
||||
}
|
||||
for k, v in data.items()
|
||||
]
|
||||
contents = [v["content"] for v in data.values()]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embedding_tasks = [self.embedding_func(batch) for batch in batches]
|
||||
embeddings_list = []
|
||||
for f in tqdm(
|
||||
asyncio.as_completed(embedding_tasks),
|
||||
total=len(embedding_tasks),
|
||||
desc="Generating embeddings",
|
||||
unit="batch",
|
||||
):
|
||||
embeddings = await f
|
||||
embeddings_list.append(embeddings)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
for i, d in enumerate(list_data):
|
||||
d["content_vector"] = embeddings[i]
|
||||
|
||||
if self.namespace == "entities":
|
||||
data = []
|
||||
for item in list_data:
|
||||
param = {
|
||||
"id": item["id"],
|
||||
"name": item["entity_name"],
|
||||
"content": item["content"],
|
||||
"content_vector": f"{item["content_vector"].tolist()}",
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# update entity_id if node inserted by graph_storage_instance before
|
||||
has = await self.db.query(SQL_TEMPLATES["has_entity"], param)
|
||||
if has["cnt"] != 0:
|
||||
await self.db.execute(SQL_TEMPLATES["update_entity"], param)
|
||||
continue
|
||||
|
||||
data.append(param)
|
||||
if data:
|
||||
merge_sql = SQL_TEMPLATES["insert_entity"]
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
elif self.namespace == "relationships":
|
||||
data = []
|
||||
for item in list_data:
|
||||
param = {
|
||||
"id": item["id"],
|
||||
"source_name": item["src_id"],
|
||||
"target_name": item["tgt_id"],
|
||||
"content": item["content"],
|
||||
"content_vector": f"{item["content_vector"].tolist()}",
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
# update relation_id if node inserted by graph_storage_instance before
|
||||
has = await self.db.query(SQL_TEMPLATES["has_relationship"], param)
|
||||
if has["cnt"] != 0:
|
||||
await self.db.execute(SQL_TEMPLATES["update_relationship"], param)
|
||||
continue
|
||||
|
||||
data.append(param)
|
||||
if data:
|
||||
merge_sql = SQL_TEMPLATES["insert_relationship"]
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
def __post_init__(self):
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
#################### upsert method ################
|
||||
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
||||
entity_name = node_id
|
||||
entity_type = node_data["entity_type"]
|
||||
description = node_data["description"]
|
||||
source_id = node_data["source_id"]
|
||||
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
|
||||
content = entity_name + description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
sql = SQL_TEMPLATES["upsert_node"]
|
||||
data = {
|
||||
"workspace": self.db.workspace,
|
||||
"name": entity_name,
|
||||
"entity_type": entity_type,
|
||||
"description": description,
|
||||
"source_chunk_id": source_id,
|
||||
"content": content,
|
||||
"content_vector": f"{content_vector.tolist()}",
|
||||
}
|
||||
await self.db.execute(sql, data)
|
||||
|
||||
async def upsert_edge(
|
||||
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
||||
):
|
||||
source_name = source_node_id
|
||||
target_name = target_node_id
|
||||
weight = edge_data["weight"]
|
||||
keywords = edge_data["keywords"]
|
||||
description = edge_data["description"]
|
||||
source_chunk_id = edge_data["source_id"]
|
||||
logger.debug(
|
||||
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
|
||||
)
|
||||
|
||||
content = keywords + source_name + target_name + description
|
||||
contents = [content]
|
||||
batches = [
|
||||
contents[i : i + self._max_batch_size]
|
||||
for i in range(0, len(contents), self._max_batch_size)
|
||||
]
|
||||
embeddings_list = await asyncio.gather(
|
||||
*[self.embedding_func(batch) for batch in batches]
|
||||
)
|
||||
embeddings = np.concatenate(embeddings_list)
|
||||
content_vector = embeddings[0]
|
||||
merge_sql = SQL_TEMPLATES["upsert_edge"]
|
||||
data = {
|
||||
"workspace": self.db.workspace,
|
||||
"source_name": source_name,
|
||||
"target_name": target_name,
|
||||
"weight": weight,
|
||||
"keywords": keywords,
|
||||
"description": description,
|
||||
"source_chunk_id": source_chunk_id,
|
||||
"content": content,
|
||||
"content_vector": f"{content_vector.tolist()}",
|
||||
}
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
||||
if algorithm not in self._node_embed_algorithms:
|
||||
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
|
||||
return await self._node_embed_algorithms[algorithm]()
|
||||
|
||||
# Query
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
sql = SQL_TEMPLATES["has_entity"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
has = await self.db.query(sql, param)
|
||||
return has["cnt"] != 0
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
sql = SQL_TEMPLATES["has_relationship"]
|
||||
param = {
|
||||
"source_name": source_node_id,
|
||||
"target_name": target_node_id,
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
has = await self.db.query(sql, param)
|
||||
return has["cnt"] != 0
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
sql = SQL_TEMPLATES["node_degree"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
result = await self.db.query(sql, param)
|
||||
return result["cnt"]
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
|
||||
return degree
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
sql = SQL_TEMPLATES["get_node"]
|
||||
param = {"name": node_id, "workspace": self.db.workspace}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_edge(
|
||||
self, source_node_id: str, target_node_id: str
|
||||
) -> Union[dict, None]:
|
||||
sql = SQL_TEMPLATES["get_edge"]
|
||||
param = {
|
||||
"source_name": source_node_id,
|
||||
"target_name": target_node_id,
|
||||
"workspace": self.db.workspace,
|
||||
}
|
||||
return await self.db.query(sql, param)
|
||||
|
||||
async def get_node_edges(
|
||||
self, source_node_id: str
|
||||
) -> Union[list[tuple[str, str]], None]:
|
||||
sql = SQL_TEMPLATES["get_node_edges"]
|
||||
param = {"source_name": source_node_id, "workspace": self.db.workspace}
|
||||
res = await self.db.query(sql, param, multirows=True)
|
||||
if res:
|
||||
data = [(i["source_name"], i["target_name"]) for i in res]
|
||||
return data
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
N_T = {
|
||||
"full_docs": "LIGHTRAG_DOC_FULL",
|
||||
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"chunks": "LIGHTRAG_DOC_CHUNKS",
|
||||
"entities": "LIGHTRAG_GRAPH_NODES",
|
||||
"relationships": "LIGHTRAG_GRAPH_EDGES",
|
||||
}
|
||||
N_ID = {
|
||||
"full_docs": "doc_id",
|
||||
"text_chunks": "chunk_id",
|
||||
"chunks": "chunk_id",
|
||||
"entities": "entity_id",
|
||||
"relationships": "relation_id",
|
||||
}
|
||||
|
||||
TABLES = {
|
||||
"LIGHTRAG_DOC_FULL": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_DOC_FULL (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`doc_id` VARCHAR(256) NOT NULL,
|
||||
`workspace` varchar(1024),
|
||||
`content` LONGTEXT,
|
||||
`meta` JSON,
|
||||
`createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` TIMESTAMP DEFAULT NULL,
|
||||
UNIQUE KEY (`doc_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_DOC_CHUNKS": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_DOC_CHUNKS (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`chunk_id` VARCHAR(256) NOT NULL,
|
||||
`full_doc_id` VARCHAR(256) NOT NULL,
|
||||
`workspace` varchar(1024),
|
||||
`chunk_order_index` INT,
|
||||
`tokens` INT,
|
||||
`content` LONGTEXT,
|
||||
`content_vector` VECTOR,
|
||||
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` DATETIME DEFAULT NULL,
|
||||
UNIQUE KEY (`chunk_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_GRAPH_NODES": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_GRAPH_NODES (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`entity_id` VARCHAR(256),
|
||||
`workspace` varchar(1024),
|
||||
`name` VARCHAR(2048),
|
||||
`entity_type` VARCHAR(1024),
|
||||
`description` LONGTEXT,
|
||||
`source_chunk_id` VARCHAR(256),
|
||||
`content` LONGTEXT,
|
||||
`content_vector` VECTOR,
|
||||
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` DATETIME DEFAULT NULL,
|
||||
KEY (`entity_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_GRAPH_EDGES": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_GRAPH_EDGES (
|
||||
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
|
||||
`relation_id` VARCHAR(256),
|
||||
`workspace` varchar(1024),
|
||||
`source_name` VARCHAR(2048),
|
||||
`target_name` VARCHAR(2048),
|
||||
`weight` DECIMAL,
|
||||
`keywords` TEXT,
|
||||
`description` LONGTEXT,
|
||||
`source_chunk_id` varchar(256),
|
||||
`content` LONGTEXT,
|
||||
`content_vector` VECTOR,
|
||||
`createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
`updatetime` DATETIME DEFAULT NULL,
|
||||
KEY (`relation_id`)
|
||||
);
|
||||
"""
|
||||
},
|
||||
"LIGHTRAG_LLM_CACHE": {
|
||||
"ddl": """
|
||||
CREATE TABLE LIGHTRAG_LLM_CACHE (
|
||||
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
send TEXT,
|
||||
return TEXT,
|
||||
model VARCHAR(1024),
|
||||
createtime DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updatetime DATETIME DEFAULT NULL
|
||||
);
|
||||
"""
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
SQL_TEMPLATES = {
|
||||
# SQL for KVStorage
|
||||
"get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
|
||||
"get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
|
||||
"get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
|
||||
"get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
|
||||
"filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
|
||||
# SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
|
||||
"upsert_doc_full": """
|
||||
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
|
||||
VALUES (:id, :content, :workspace)
|
||||
ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||
""",
|
||||
"upsert_chunk": """
|
||||
INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
|
||||
VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
|
||||
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
|
||||
""",
|
||||
# SQL for VectorStorage
|
||||
"entities": """SELECT n.name as entity_name FROM
|
||||
(SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
|
||||
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k
|
||||
""",
|
||||
"relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
|
||||
(SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
|
||||
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k
|
||||
""",
|
||||
"chunks": """SELECT c.id FROM
|
||||
(SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
|
||||
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
|
||||
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k
|
||||
""",
|
||||
"has_entity": """
|
||||
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
|
||||
""",
|
||||
"has_relationship": """
|
||||
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
|
||||
""",
|
||||
"update_entity": """
|
||||
UPDATE LIGHTRAG_GRAPH_NODES SET
|
||||
entity_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
|
||||
WHERE workspace = :workspace AND name = :name
|
||||
""",
|
||||
"update_relationship": """
|
||||
UPDATE LIGHTRAG_GRAPH_EDGES SET
|
||||
relation_id = :id, content = :content, content_vector = :content_vector, updatetime = CURRENT_TIMESTAMP
|
||||
WHERE workspace = :workspace AND source_name = :source_name AND target_name = :target_name
|
||||
""",
|
||||
"insert_entity": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
|
||||
VALUES(:id, :name, :content, :content_vector, :workspace)
|
||||
""",
|
||||
"insert_relationship": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
|
||||
VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
|
||||
""",
|
||||
# SQL for GraphStorage
|
||||
"get_node": """
|
||||
SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector
|
||||
FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
|
||||
""",
|
||||
"get_edge": """
|
||||
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
|
||||
""",
|
||||
"get_node_edges": """
|
||||
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector
|
||||
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace
|
||||
""",
|
||||
"node_degree": """
|
||||
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name)
|
||||
""",
|
||||
"upsert_node": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description)
|
||||
VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
|
||||
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
|
||||
source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description)
|
||||
""",
|
||||
"upsert_edge": """
|
||||
INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector,
|
||||
workspace, weight, keywords, description, source_chunk_id)
|
||||
VALUES(:source_name, :target_name, :content, :content_vector,
|
||||
:workspace, :weight, :keywords, :description, :source_chunk_id)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
|
||||
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
|
||||
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
|
||||
source_chunk_id = VALUES(source_chunk_id)
|
||||
""",
|
||||
}
|
||||
Reference in New Issue
Block a user