fix parallel insertion
@@ -4,6 +4,7 @@ from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
 from typing import Type, cast, Any
+from dotenv import load_dotenv
 
 
 from .operate import (
@@ -23,7 +24,7 @@ from .utils import (
     clean_text,
     get_content_summary,
     set_logger,
-    logger
+    logger,
 )
 from .base import (
     BaseGraphStorage,
@@ -66,6 +67,8 @@ STORAGES = {
 #     GraphStorage as ArangoDBStorage
 # )
 
+load_dotenv(dotenv_path=".env", override=False)
+
 
 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
@@ -155,7 +158,9 @@ class MiniRAG:
 
     # LLM
     llm_model_func: callable = None
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = (
+        "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    )
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -176,6 +181,8 @@ class MiniRAG:
     chunking_func: callable = chunking_by_token_size
     chunking_func_kwargs: dict = field(default_factory=dict)
 
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2)))
+
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "minirag.log")
         set_logger(log_file)
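Note: the new max_parallel_insert field reads MAX_PARALLEL_INSERT from the environment (default 2). This hunk only introduces the setting; the code that consumes it is outside the diff. A minimal sketch of the usual enforcement pattern, assuming an asyncio.Semaphore guards the per-document work (process_document here is a hypothetical stand-in for the real pipeline step, not MiniRAG code):

    import asyncio
    import os

    async def process_document(doc_id: str) -> None:
        # Hypothetical stand-in for the real per-document pipeline work.
        await asyncio.sleep(0.1)

    async def process_with_cap(sem: asyncio.Semaphore, doc_id: str) -> None:
        # At most max_parallel_insert documents enter this block at once.
        async with sem:
            await process_document(doc_id)

    async def main() -> None:
        max_parallel_insert = int(os.getenv("MAX_PARALLEL_INSERT", 2))
        sem = asyncio.Semaphore(max_parallel_insert)
        await asyncio.gather(*(process_with_cap(sem, f"doc-{i}") for i in range(8)))

    asyncio.run(main())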
@@ -330,14 +337,22 @@ class MiniRAG:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert(string_or_strings))
 
-    async def ainsert(self, input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ids: str | list[str] | None = None) -> None:
+    async def ainsert(
+        self,
+        input: str | list[str],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+        ids: str | list[str] | None = None,
+    ) -> None:
         if isinstance(input, str):
             input = [input]
         if isinstance(ids, str):
             ids = [ids]
 
         await self.apipeline_enqueue_documents(input, ids)
-        await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
+        await self.apipeline_process_enqueue_documents(
+            split_by_character, split_by_character_only
+        )
 
         # Perform additional entity extraction as per original ainsert logic
         inserting_chunks = {
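For reference, a usage sketch of the reflowed ainsert signature; rag stands in for an already-configured MiniRAG instance, the argument values are illustrative only, and the comments reflect my reading of the parameters rather than documented behavior:

    # Inside an async context:
    await rag.ainsert(
        ["first document ...", "second document ..."],
        split_by_character="\n\n",      # optional: split chunks at this character first
        split_by_character_only=False,  # if False, oversized chunks are split further by token size
        ids=["doc-a", "doc-b"],         # optional: stable document IDs
    )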
@@ -345,7 +360,9 @@ class MiniRAG:
                 **dp,
                 "full_doc_id": doc_id,
             }
-            for doc_id, status_doc in (await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)).items()
+            for doc_id, status_doc in (
+                await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)
+            ).items()
             for dp in self.chunking_func(
                 status_doc.content,
                 split_by_character,
@@ -367,7 +384,9 @@ class MiniRAG:
                 global_config=asdict(self),
             )
 
-    async def apipeline_enqueue_documents(self, input: str | list[str], ids: list[str] | None = None) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None = None
+    ) -> None:
         """
         Pipeline for Processing Documents
 
@@ -392,7 +411,12 @@ class MiniRAG:
             input = list(set(clean_text(doc) for doc in input))
             contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
 
-        unique_contents = {id_: content for content, id_ in {content: id_ for id_, content in contents.items()}.items()}
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }
         new_docs: dict[str, Any] = {
             id_: {
                 "content": content,
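The reflowed unique_contents comprehension deduplicates documents by content: inverting {id: content} to {content: id} keeps exactly one id per distinct content (later entries overwrite earlier ones), and the outer comprehension inverts back to the {id: content} shape. A small worked example with illustrative values:

    contents = {"doc-1": "alpha", "doc-2": "beta", "doc-3": "alpha"}
    by_content = {content: id_ for id_, content in contents.items()}
    # {"alpha": "doc-3", "beta": "doc-2"} -- "doc-3" overwrote "doc-1"
    unique_contents = {id_: content for content, id_ in by_content.items()}
    # {"doc-3": "alpha", "doc-2": "beta"}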
@@ -408,7 +432,11 @@ class MiniRAG:
         all_new_doc_ids = set(new_docs.keys())
         unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
 
-        new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids if doc_id in new_docs}
+        new_docs = {
+            doc_id: new_docs[doc_id]
+            for doc_id in unique_new_doc_ids
+            if doc_id in new_docs
+        }
         if not new_docs:
             logger.info("No new unique documents were found.")
             return
@@ -416,7 +444,11 @@ class MiniRAG:
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")
 
-    async def apipeline_process_enqueue_documents(self, split_by_character: str | None = None, split_by_character_only: bool = False) -> None:
+    async def apipeline_process_enqueue_documents(
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ) -> None:
         """
         Process pending documents by splitting them into chunks, processing
         each chunk for entity and relation extraction, and updating the
@@ -428,7 +460,11 @@ class MiniRAG:
             self.doc_status.get_docs_by_status(DocStatus.PENDING),
         )
 
-        to_process_docs: dict[str, Any] = {**processing_docs, **failed_docs, **pending_docs}
+        to_process_docs: dict[str, Any] = {
+            **processing_docs,
+            **failed_docs,
+            **pending_docs,
+        }
         if not to_process_docs:
             logger.info("No documents to process")
             return
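The merged to_process_docs relies on dict-unpacking precedence: if the same doc id appears under more than one status, the right-most mapping wins (here, pending_docs). A minimal illustration with placeholder values:

    processing_docs = {"doc-1": "processing"}
    failed_docs = {"doc-2": "failed"}
    pending_docs = {"doc-1": "pending"}  # same key as in processing_docs

    to_process_docs = {**processing_docs, **failed_docs, **pending_docs}
    # {"doc-1": "pending", "doc-2": "failed"} -- the pending entry won for doc-1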
@@ -460,7 +496,8 @@ class MiniRAG:
                     self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
                     self.text_chunks.upsert(chunks),
                 )
-                await self.doc_status.upsert({
+                await self.doc_status.upsert(
+                    {
                         doc_id: {
                             "status": DocStatus.PROCESSED,
                             "chunks_count": len(chunks),
@@ -470,10 +507,10 @@ class MiniRAG:
                             "created_at": status_doc.created_at,
                             "updated_at": datetime.now().isoformat(),
                         }
-                })
+                    }
+                )
         logger.info("Document processing pipeline completed")
-
 
     async def _insert_done(self):
         tasks = []
         for storage_inst in [