mirror of https://github.com/hhy-huang/HiRAG.git, synced 2025-09-16 23:52:00 +03:00
fix: remove hnswlib package dependence
__init__.py
@@ -1,5 +1,4 @@
from .gdb_networkx import NetworkXStorage
from .gdb_neo4j import Neo4jStorage
from .vdb_hnswlib import HNSWVectorStorage
from .vdb_nanovectordb import NanoVectorDBStorage
from .kv_json import JsonKVStorage
vdb_hnswlib.py
@@ -1,141 +0,0 @@
import asyncio
import os
from dataclasses import dataclass, field
from typing import Any
import pickle
import hnswlib
import numpy as np
import xxhash

from .._utils import logger
from ..base import BaseVectorStorage


@dataclass
class HNSWVectorStorage(BaseVectorStorage):
    ef_construction: int = 100
    M: int = 16
    max_elements: int = 1000000
    ef_search: int = 50
    num_threads: int = -1
    _index: Any = field(init=False)
    _metadata: dict[str, dict] = field(default_factory=dict)
    _current_elements: int = 0

    def __post_init__(self):
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
        )
        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)

        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
        self.M = hnsw_params.get("M", self.M)
        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
        self._index = hnswlib.Index(
            space="cosine", dim=self.embedding_func.embedding_dim
        )

        if os.path.exists(self._index_file_name) and os.path.exists(
            self._metadata_file_name
        ):
            self._index.load_index(
                self._index_file_name, max_elements=self.max_elements
            )
            with open(self._metadata_file_name, "rb") as f:
                self._metadata, self._current_elements = pickle.load(f)
            logger.info(
                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
            )
        else:
            self._index.init_index(
                max_elements=self.max_elements,
                ef_construction=self.ef_construction,
                M=self.M,
            )
            self._index.set_ef(self.ef_search)
            self._metadata = {}
            self._current_elements = 0
            logger.info(f"Created new index for {self.namespace}")

    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You insert an empty data to vector DB")
            return []

        if self._current_elements + len(data) > self.max_elements:
            raise ValueError(
                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
            )

        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batch_size = min(self._embedding_batch_num, len(contents))
        embeddings = np.concatenate(
            await asyncio.gather(
                *[
                    self.embedding_func(contents[i : i + batch_size])
                    for i in range(0, len(contents), batch_size)
                ]
            )
        )

        ids = np.fromiter(
            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
            dtype=np.uint32,
            count=len(list_data),
        )
        self._metadata.update(
            {
                id_int: {
                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
                }
                for id_int, d in zip(ids, list_data)
            }
        )
        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
        self._current_elements = self._index.get_current_count()
        return ids

    async def query(self, query: str, top_k: int = 5) -> list[dict]:
        if self._current_elements == 0:
            return []

        top_k = min(top_k, self._current_elements)

        if top_k > self.ef_search:
            logger.warning(
                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
            )
            self._index.set_ef(top_k)

        embedding = await self.embedding_func([query])
        labels, distances = self._index.knn_query(
            data=embedding[0], k=top_k, num_threads=self.num_threads
        )

        return [
            {
                **self._metadata.get(label, {}),
                "distance": distance,
                "similarity": 1 - distance,
            }
            for label, distance in zip(labels[0], distances[0])
        ]

    async def index_done_callback(self):
        self._index.save_index(self._index_file_name)
        with open(self._metadata_file_name, "wb") as f:
            pickle.dump((self._metadata, self._current_elements), f)
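
Projects that still rely on the HNSW backend after this change can keep the removed class as local code and pass it in explicitly. Below is a minimal sketch, assuming the HiRAG constructor still accepts a custom vector storage class via a `vector_db_storage_cls` argument (suggested by the `vector_db_storage_cls_kwargs` lookup in the deleted file); the module and parameter names used here are illustrative and should be checked against the current HiRAG API.

# Minimal sketch, not part of this commit. Assumptions:
#   - hnswlib and xxhash are installed separately, since this commit removes
#     the package dependency,
#   - the deleted HNSWVectorStorage class has been copied into a local module
#     (my_hnsw_storage.py is a hypothetical name),
#   - HiRAG exposes `working_dir`, `vector_db_storage_cls`, and
#     `vector_db_storage_cls_kwargs` constructor arguments (inferred from the
#     deleted file's use of global_config, not verified here).
from hirag import HiRAG

from my_hnsw_storage import HNSWVectorStorage  # local copy of the deleted class

graph = HiRAG(
    working_dir="./hirag_cache",
    vector_db_storage_cls=HNSWVectorStorage,
    vector_db_storage_cls_kwargs={"M": 32, "ef_construction": 200, "ef_search": 100},
)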