From 9d51ec7e761ea98e2184c53703ac10bbdca91a8c Mon Sep 17 00:00:00 2001
From: haoyuhuang
Date: Wed, 30 Apr 2025 16:35:11 +0800
Subject: [PATCH] fix: remove hnswlib package dependence

---
 hirag/_storage/__init__.py    |   1 -
 hirag/_storage/vdb_hnswlib.py | 141 ----------------------------------
 2 files changed, 142 deletions(-)
 delete mode 100644 hirag/_storage/vdb_hnswlib.py

diff --git a/hirag/_storage/__init__.py b/hirag/_storage/__init__.py
index c8184ab..0c564cb 100644
--- a/hirag/_storage/__init__.py
+++ b/hirag/_storage/__init__.py
@@ -1,5 +1,4 @@
 from .gdb_networkx import NetworkXStorage
 from .gdb_neo4j import Neo4jStorage
-from .vdb_hnswlib import HNSWVectorStorage
 from .vdb_nanovectordb import NanoVectorDBStorage
 from .kv_json import JsonKVStorage
diff --git a/hirag/_storage/vdb_hnswlib.py b/hirag/_storage/vdb_hnswlib.py
deleted file mode 100644
index 3e98c95..0000000
--- a/hirag/_storage/vdb_hnswlib.py
+++ /dev/null
@@ -1,141 +0,0 @@
-import asyncio
-import os
-from dataclasses import dataclass, field
-from typing import Any
-import pickle
-import hnswlib
-import numpy as np
-import xxhash
-
-from .._utils import logger
-from ..base import BaseVectorStorage
-
-
-@dataclass
-class HNSWVectorStorage(BaseVectorStorage):
-    ef_construction: int = 100
-    M: int = 16
-    max_elements: int = 1000000
-    ef_search: int = 50
-    num_threads: int = -1
-    _index: Any = field(init=False)
-    _metadata: dict[str, dict] = field(default_factory=dict)
-    _current_elements: int = 0
-
-    def __post_init__(self):
-        self._index_file_name = os.path.join(
-            self.global_config["working_dir"], f"{self.namespace}_hnsw.index"
-        )
-        self._metadata_file_name = os.path.join(
-            self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl"
-        )
-        self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100)
-
-        hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {})
-        self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction)
-        self.M = hnsw_params.get("M", self.M)
-        self.max_elements = hnsw_params.get("max_elements", self.max_elements)
-        self.ef_search = hnsw_params.get("ef_search", self.ef_search)
-        self.num_threads = hnsw_params.get("num_threads", self.num_threads)
-        self._index = hnswlib.Index(
-            space="cosine", dim=self.embedding_func.embedding_dim
-        )
-
-        if os.path.exists(self._index_file_name) and os.path.exists(
-            self._metadata_file_name
-        ):
-            self._index.load_index(
-                self._index_file_name, max_elements=self.max_elements
-            )
-            with open(self._metadata_file_name, "rb") as f:
-                self._metadata, self._current_elements = pickle.load(f)
-            logger.info(
-                f"Loaded existing index for {self.namespace} with {self._current_elements} elements"
-            )
-        else:
-            self._index.init_index(
-                max_elements=self.max_elements,
-                ef_construction=self.ef_construction,
-                M=self.M,
-            )
-            self._index.set_ef(self.ef_search)
-            self._metadata = {}
-            self._current_elements = 0
-            logger.info(f"Created new index for {self.namespace}")
-
-    async def upsert(self, data: dict[str, dict]) -> np.ndarray:
-        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
-        if not data:
-            logger.warning("You insert an empty data to vector DB")
-            return []
-
-        if self._current_elements + len(data) > self.max_elements:
-            raise ValueError(
-                f"Cannot insert {len(data)} elements. Current: {self._current_elements}, Max: {self.max_elements}"
-            )
-
-        list_data = [
-            {
-                "id": k,
-                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
-            }
-            for k, v in data.items()
-        ]
-        contents = [v["content"] for v in data.values()]
-        batch_size = min(self._embedding_batch_num, len(contents))
-        embeddings = np.concatenate(
-            await asyncio.gather(
-                *[
-                    self.embedding_func(contents[i : i + batch_size])
-                    for i in range(0, len(contents), batch_size)
-                ]
-            )
-        )
-
-        ids = np.fromiter(
-            (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data),
-            dtype=np.uint32,
-            count=len(list_data),
-        )
-        self._metadata.update(
-            {
-                id_int: {
-                    k: v for k, v in d.items() if k in self.meta_fields or k == "id"
-                }
-                for id_int, d in zip(ids, list_data)
-            }
-        )
-        self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads)
-        self._current_elements = self._index.get_current_count()
-        return ids
-
-    async def query(self, query: str, top_k: int = 5) -> list[dict]:
-        if self._current_elements == 0:
-            return []
-
-        top_k = min(top_k, self._current_elements)
-
-        if top_k > self.ef_search:
-            logger.warning(
-                f"Setting ef_search to {top_k} because top_k is larger than ef_search"
-            )
-            self._index.set_ef(top_k)
-
-        embedding = await self.embedding_func([query])
-        labels, distances = self._index.knn_query(
-            data=embedding[0], k=top_k, num_threads=self.num_threads
-        )
-
-        return [
-            {
-                **self._metadata.get(label, {}),
-                "distance": distance,
-                "similarity": 1 - distance,
-            }
-            for label, distance in zip(labels[0], distances[0])
-        ]
-
-    async def index_done_callback(self):
-        self._index.save_index(self._index_file_name)
-        with open(self._metadata_file_name, "wb") as f:
-            pickle.dump((self._metadata, self._current_elements), f)
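
Note for downstream callers (below the patch, not part of the commit): with
HNSWVectorStorage deleted, NanoVectorDBStorage is the only vector backend still
exported from hirag._storage. The snippet below is a minimal, hypothetical
migration sketch; it assumes HiRAG keeps the nano-graphrag-style
`vector_db_storage_cls` hook, and the `working_dir` value is illustrative.

    # Hypothetical sketch, not part of this patch: code that previously
    # opted into the hnswlib backend switches to the remaining built-in one.
    # Assumes a nano-graphrag-style `vector_db_storage_cls` parameter.
    from hirag import HiRAG
    from hirag._storage import NanoVectorDBStorage

    graph = HiRAG(
        working_dir="./hirag_cache",                # illustrative path
        vector_db_storage_cls=NanoVectorDBStorage,  # remaining built-in backend
    )

Anyone who still needs HNSW search can keep the deleted module above out of
tree, install hnswlib themselves, and pass their copy through the same hook.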