fix parallel insertion
@@ -4,6 +4,7 @@ from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
 from typing import Type, cast, Any
+from dotenv import load_dotenv


 from .operate import (
@@ -23,7 +24,7 @@ from .utils import (
     clean_text,
     get_content_summary,
     set_logger,
-    logger
+    logger,
 )
 from .base import (
     BaseGraphStorage,
@@ -66,6 +67,8 @@ STORAGES = {
 #     GraphStorage as ArangoDBStorage
 # )

+load_dotenv(dotenv_path=".env", override=False)
+

 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
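A note on the new `load_dotenv(dotenv_path=".env", override=False)` call: with `override=False`, variables already set in the process environment take precedence over the `.env` file. A minimal sketch of the pattern, assuming a `.env` file containing `MAX_PARALLEL_INSERT=4`:

```python
import os

from dotenv import load_dotenv

# Same pattern as the diff: read .env but never clobber existing env vars.
load_dotenv(dotenv_path=".env", override=False)

# Falls back to 2 when the variable is absent, matching the dataclass default below.
max_parallel_insert = int(os.getenv("MAX_PARALLEL_INSERT", 2))
print(max_parallel_insert)
```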
@@ -155,7 +158,9 @@ class MiniRAG:

     # LLM
     llm_model_func: callable = None
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_name: str = (
+        "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    )
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
@@ -176,6 +181,8 @@ class MiniRAG:
     chunking_func: callable = chunking_by_token_size
     chunking_func_kwargs: dict = field(default_factory=dict)

+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 2)))
+
     def __post_init__(self):
         log_file = os.path.join(self.working_dir, "minirag.log")
         set_logger(log_file)
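`max_parallel_insert` is the knob this commit is named for: it caps how many documents are chunked and indexed concurrently. The consuming code is not part of this diff, so the following is only a sketch of the usual pattern, with a hypothetical `handle_document` coroutine standing in for the per-document work: an `asyncio.Semaphore` sized from the field bounds the number of in-flight tasks.

```python
import asyncio


async def handle_document(doc_id: str) -> None:
    """Hypothetical stand-in for per-document chunking and extraction."""
    await asyncio.sleep(0.1)


async def process_all(doc_ids: list[str], max_parallel_insert: int = 2) -> None:
    # At most `max_parallel_insert` documents are processed at any moment.
    semaphore = asyncio.Semaphore(max_parallel_insert)

    async def worker(doc_id: str) -> None:
        async with semaphore:
            await handle_document(doc_id)

    await asyncio.gather(*(worker(d) for d in doc_ids))


asyncio.run(process_all([f"doc-{i}" for i in range(8)], max_parallel_insert=2))
```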
@@ -330,14 +337,22 @@ class MiniRAG:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert(string_or_strings))

-    async def ainsert(self, input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ids: str | list[str] | None = None) -> None:
+    async def ainsert(
+        self,
+        input: str | list[str],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+        ids: str | list[str] | None = None,
+    ) -> None:
         if isinstance(input, str):
             input = [input]
         if isinstance(ids, str):
             ids = [ids]

         await self.apipeline_enqueue_documents(input, ids)
-        await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
+        await self.apipeline_process_enqueue_documents(
+            split_by_character, split_by_character_only
+        )

         # Perform additional entity extraction as per original ainsert logic
         inserting_chunks = {
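For reference, the reworked `ainsert` signature accepts a single string or a list, plus optional caller-supplied ids. A usage sketch, assuming `rag` is an already-initialized `MiniRAG` instance inside an async context:

```python
# Single document, defaults everywhere.
await rag.ainsert("a single document")

# Batch insert with character-based splitting and pre-assigned ids.
await rag.ainsert(
    ["doc one", "doc two"],
    split_by_character="\n\n",
    split_by_character_only=False,
    ids=["doc-1", "doc-2"],
)
```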
@@ -345,7 +360,9 @@ class MiniRAG:
                 **dp,
                 "full_doc_id": doc_id,
             }
-            for doc_id, status_doc in (await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)).items()
+            for doc_id, status_doc in (
+                await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)
+            ).items()
             for dp in self.chunking_func(
                 status_doc.content,
                 split_by_character,
@@ -367,7 +384,9 @@ class MiniRAG:
             global_config=asdict(self),
         )

-    async def apipeline_enqueue_documents(self, input: str | list[str], ids: list[str] | None = None) -> None:
+    async def apipeline_enqueue_documents(
+        self, input: str | list[str], ids: list[str] | None = None
+    ) -> None:
         """
         Pipeline for Processing Documents

@@ -392,7 +411,12 @@ class MiniRAG:
         input = list(set(clean_text(doc) for doc in input))
         contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}

-        unique_contents = {id_: content for content, id_ in {content: id_ for id_, content in contents.items()}.items()}
+        unique_contents = {
+            id_: content
+            for content, id_ in {
+                content: id_ for id_, content in contents.items()
+            }.items()
+        }
         new_docs: dict[str, Any] = {
             id_: {
                 "content": content,
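The reformatted `unique_contents` comprehension is a double dict inversion: mapping `content -> id_` first collapses duplicate contents (later ids overwrite earlier ones), and inverting back yields one `id_ -> content` pair per unique content. A standalone illustration of the idiom:

```python
contents = {"doc-a": "same text", "doc-b": "same text", "doc-c": "other text"}

# First inversion: duplicate contents collapse to a single (content, id) pair.
by_content = {content: id_ for id_, content in contents.items()}

# Second inversion: back to id -> content, now deduplicated by content.
unique_contents = {id_: content for content, id_ in by_content.items()}

print(unique_contents)  # {'doc-b': 'same text', 'doc-c': 'other text'}
```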
@@ -408,7 +432,11 @@ class MiniRAG:
         all_new_doc_ids = set(new_docs.keys())
         unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)

-        new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids if doc_id in new_docs}
+        new_docs = {
+            doc_id: new_docs[doc_id]
+            for doc_id in unique_new_doc_ids
+            if doc_id in new_docs
+        }
         if not new_docs:
             logger.info("No new unique documents were found.")
             return
@@ -416,7 +444,11 @@ class MiniRAG:
         await self.doc_status.upsert(new_docs)
         logger.info(f"Stored {len(new_docs)} new unique documents")

-    async def apipeline_process_enqueue_documents(self, split_by_character: str | None = None, split_by_character_only: bool = False) -> None:
+    async def apipeline_process_enqueue_documents(
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ) -> None:
         """
         Process pending documents by splitting them into chunks, processing
         each chunk for entity and relation extraction, and updating the
@@ -428,7 +460,11 @@ class MiniRAG:
             self.doc_status.get_docs_by_status(DocStatus.PENDING),
         )

-        to_process_docs: dict[str, Any] = {**processing_docs, **failed_docs, **pending_docs}
+        to_process_docs: dict[str, Any] = {
+            **processing_docs,
+            **failed_docs,
+            **pending_docs,
+        }
         if not to_process_docs:
             logger.info("No documents to process")
             return
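One behavior worth noting in the reformatted merge: with `{**processing_docs, **failed_docs, **pending_docs}`, a doc id that appears in several status maps takes its value from the rightmost dict, so PENDING overrides FAILED, which overrides PROCESSING. A quick illustration:

```python
processing_docs = {"doc-1": "processing"}
failed_docs = {"doc-1": "failed", "doc-2": "failed"}
pending_docs = {"doc-2": "pending"}

# Later unpacked dicts win on key collisions.
to_process_docs = {**processing_docs, **failed_docs, **pending_docs}
print(to_process_docs)  # {'doc-1': 'failed', 'doc-2': 'pending'}
```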
@@ -460,20 +496,21 @@ class MiniRAG:
                     self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
                     self.text_chunks.upsert(chunks),
                 )
-                await self.doc_status.upsert({
-                    doc_id: {
-                        "status": DocStatus.PROCESSED,
-                        "chunks_count": len(chunks),
-                        "content": status_doc.content,
-                        "content_summary": status_doc.content_summary,
-                        "content_length": status_doc.content_length,
-                        "created_at": status_doc.created_at,
-                        "updated_at": datetime.now().isoformat(),
+                await self.doc_status.upsert(
+                    {
+                        doc_id: {
+                            "status": DocStatus.PROCESSED,
+                            "chunks_count": len(chunks),
+                            "content": status_doc.content,
+                            "content_summary": status_doc.content_summary,
+                            "content_length": status_doc.content_length,
+                            "created_at": status_doc.created_at,
+                            "updated_at": datetime.now().isoformat(),
+                        }
                     }
-                })
+                )
         logger.info("Document processing pipeline completed")


     async def _insert_done(self):
         tasks = []
         for storage_inst in [
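The diff is truncated inside `_insert_done`. As a hedged sketch only (not the code from this commit), helpers like this typically collect a flush callback from each configured storage backend and await them together:

```python
import asyncio


async def _insert_done(storages: list) -> None:
    # Hypothetical shape: flush every non-null storage once insertion finishes.
    tasks = [s.index_done_callback() for s in storages if s is not None]
    await asyncio.gather(*tasks)
```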