fix parallel insertion

yashshah035
2025-03-29 20:13:53 +05:30
parent 7ef179222d
commit b869f2d95c


@@ -4,6 +4,7 @@ from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
 from typing import Type, cast, Any
+from dotenv import load_dotenv
 from .operate import (
@@ -23,7 +24,7 @@ from .utils import (
     clean_text,
     get_content_summary,
     set_logger,
-    logger
+    logger,
 )
 from .base import (
     BaseGraphStorage,
@@ -66,6 +67,8 @@ STORAGES = {
 #     GraphStorage as ArangoDBStorage
 # )
+load_dotenv(dotenv_path=".env", override=False)

 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
@@ -155,7 +158,9 @@ class MiniRAG:
     # LLM
     llm_model_func: callable = None
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = (
+        "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    )
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -176,6 +181,8 @@ class MiniRAG:
     chunking_func: callable = chunking_by_token_size
     chunking_func_kwargs: dict = field(default_factory=dict)
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2)))

     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "minirag.log")
         set_logger(log_file)
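
The new max_parallel_insert field caps how many documents the insertion pipeline works on at once; it reads MAX_PARALLEL_INSERT from the environment and falls back to 2. The diff does not show how the cap is enforced, but a common pattern for this kind of limit is an asyncio.Semaphore around the per-document coroutine — a sketch under that assumption, not the repository's actual implementation:

import asyncio

async def gated_process_all(doc_ids, process_document, max_parallel_insert: int = 2):
    # Allow at most max_parallel_insert process_document() coroutines to run
    # concurrently; the remaining documents wait on the semaphore.
    semaphore = asyncio.Semaphore(max_parallel_insert)

    async def bounded(doc_id):
        async with semaphore:
            return await process_document(doc_id)

    return await asyncio.gather(*(bounded(doc_id) for doc_id in doc_ids))
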
@@ -330,14 +337,22 @@ class MiniRAG:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert(string_or_strings))

-    async def ainsert(self, input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ids: str | list[str] | None = None) -> None:
+    async def ainsert(
+        self,
+        input: str | list[str],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+        ids: str | list[str] | None = None,
+    ) -> None:
         if isinstance(input, str):
             input = [input]
         if isinstance(ids, str):
             ids = [ids]

         await self.apipeline_enqueue_documents(input, ids)
-        await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
+        await self.apipeline_process_enqueue_documents(
+            split_by_character, split_by_character_only
+        )

         # Perform additional entity extraction as per original ainsert logic
         inserting_chunks = {
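
Reflowing the ainsert signature makes the optional parameters easier to scan: input may be a single string or a list of strings, ids may be a matching string or list, and the two split_by_character flags steer chunking before the token-size limits apply. A hypothetical call, assuming rag is an already-configured MiniRAG instance:

import asyncio

async def ingest(rag):
    # `rag` is assumed to be a configured MiniRAG instance (not constructed here).
    await rag.ainsert(
        input=["First document text ...", "Second document text ..."],
        split_by_character="\n\n",      # optional: try splitting on blank lines first
        split_by_character_only=False,  # then still enforce token-size chunking
    )

# From synchronous code: asyncio.run(ingest(rag)), or use the blocking
# rag.insert([...]) wrapper shown above.
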
@@ -345,7 +360,9 @@ class MiniRAG:
                 **dp,
                 "full_doc_id": doc_id,
             }
-            for doc_id, status_doc in (await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)).items()
+            for doc_id, status_doc in (
+                await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)
+            ).items()
             for dp in self.chunking_func(
                 status_doc.content,
                 split_by_character,
@@ -367,7 +384,9 @@ class MiniRAG:
             global_config=asdict(self),
         )

-    async def apipeline_enqueue_documents(self, input: str | list[str], ids: list[str] | None = None) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None = None
+    ) -> None:
         """
         Pipeline for Processing Documents
@@ -392,7 +411,12 @@ class MiniRAG:
         input = list(set(clean_text(doc) for doc in input))
         contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}

-        unique_contents = {id_: content for content, id_ in {content: id_ for id_, content in contents.items()}.items()}
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }

         new_docs: dict[str, Any] = {
             id_: {
                 "content": content,
@@ -408,7 +432,11 @@ class MiniRAG:
         all_new_doc_ids = set(new_docs.keys())
         unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)

-        new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids if doc_id in new_docs}
+        new_docs = {
+            doc_id: new_docs[doc_id]
+            for doc_id in unique_new_doc_ids
+            if doc_id in new_docs
+        }

         if not new_docs:
             logger.info("No new unique documents were found.")
             return
@@ -416,7 +444,11 @@ class MiniRAG:
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")

-    async def apipeline_process_enqueue_documents(self, split_by_character: str | None = None, split_by_character_only: bool = False) -> None:
+    async def apipeline_process_enqueue_documents(
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ) -> None:
         """
         Process pending documents by splitting them into chunks, processing
         each chunk for entity and relation extraction, and updating the
@@ -428,7 +460,11 @@ class MiniRAG:
             self.doc_status.get_docs_by_status(DocStatus.PENDING),
         )

-        to_process_docs: dict[str, Any] = {**processing_docs, **failed_docs, **pending_docs}
+        to_process_docs: dict[str, Any] = {
+            **processing_docs,
+            **failed_docs,
+            **pending_docs,
+        }

         if not to_process_docs:
             logger.info("No documents to process")
             return
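
to_process_docs merges the documents reported as processing, failed, and pending so they can all be (re)processed, and spelling the three unpackings out one per line also documents the precedence: later unpackings win on duplicate keys, so if the same document ID ever shows up under more than one status, the last dictionary's entry is kept. A quick illustration with plain dicts (hypothetical statuses, not MiniRAG status objects):

processing_docs = {"doc-1": {"status": "PROCESSING"}}
failed_docs = {"doc-2": {"status": "FAILED"}}
pending_docs = {"doc-1": {"status": "PENDING"}, "doc-3": {"status": "PENDING"}}

to_process_docs = {**processing_docs, **failed_docs, **pending_docs}
print(to_process_docs["doc-1"]["status"])  # PENDING -- the later unpacking overrides
print(sorted(to_process_docs))             # ['doc-1', 'doc-2', 'doc-3']
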
@@ -460,20 +496,21 @@ class MiniRAG:
                     self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
                     self.text_chunks.upsert(chunks),
                 )
-                await self.doc_status.upsert({
-                    doc_id: {
-                        "status": DocStatus.PROCESSED,
-                        "chunks_count": len(chunks),
-                        "content": status_doc.content,
-                        "content_summary": status_doc.content_summary,
-                        "content_length": status_doc.content_length,
-                        "created_at": status_doc.created_at,
-                        "updated_at": datetime.now().isoformat(),
+                await self.doc_status.upsert(
+                    {
+                        doc_id: {
+                            "status": DocStatus.PROCESSED,
+                            "chunks_count": len(chunks),
+                            "content": status_doc.content,
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
+                            "updated_at": datetime.now().isoformat(),
+                        }
                     }
-                })
+                )

         logger.info("Document processing pipeline completed")

     async def _insert_done(self):
         tasks = []
         for storage_inst in [