Improved existing insertion pipeline

This commit is contained in:
yashshah035
2025-03-12 10:48:00 +05:30
parent 415b7992c1
commit 7724e86e15


@@ -3,7 +3,7 @@ import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast
from typing import Type, cast, Any
from .operate import (
@@ -20,7 +20,10 @@ from .utils import (
limit_async_func_call,
convert_response_to_json,
logger,
clean_text,
get_content_summary,
set_logger,
logger
)
from .base import (
BaseGraphStorage,
@@ -28,6 +31,7 @@ from .base import (
BaseVectorStorage,
StorageNameSpace,
QueryParam,
DocStatus,
)
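The new DocStatus import drives the enqueue/process pipeline added below. As a rough sketch only, assuming the four states the new code references (the actual definition lives in .base and is not shown in this diff), such a status enum might look like:

```python
from enum import Enum


class DocStatus(str, Enum):
    """Sketch only: the real enum is defined in .base and may differ."""

    PENDING = "pending"        # enqueued, waiting to be chunked
    PROCESSING = "processing"  # picked up by an earlier run but not finished
    PROCESSED = "processed"    # chunked, stored, and counted
    FAILED = "failed"          # a previous attempt raised and should be retried
```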
@@ -326,54 +330,35 @@ class MiniRAG:
loop = always_get_an_event_loop()
return loop.run_until_complete(self.ainsert(string_or_strings))
async def ainsert(self, string_or_strings):
update_storage = False
try:
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
async def ainsert(
    self,
    input: str | list[str],
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
    ids: str | list[str] | None = None,
) -> None:
if isinstance(input, str):
input = [input]
if isinstance(ids, str):
ids = [ids]
new_docs = {
compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
for c in string_or_strings
await self.apipeline_enqueue_documents(input, ids)
await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
# Perform additional entity extraction as per original ainsert logic
inserting_chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
logger.warning("All docs are already in the storage")
return
update_storage = True
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
inserting_chunks = {}
for doc_key, doc in new_docs.items():
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_key,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
inserting_chunks.update(chunks)
_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
for doc_id, status_doc in (await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)).items()
for dp in self.chunking_func(
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
}
await self.chunks_vdb.upsert(inserting_chunks)
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
if inserting_chunks:
logger.info("Performing entity extraction on newly processed chunks")
await extract_entities(
inserting_chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
@@ -381,16 +366,113 @@ class MiniRAG:
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.warning("No new entities and relationships found")
return
self.chunk_entity_relation_graph = maybe_new_kg
await self.full_docs.upsert(new_docs)
await self.text_chunks.upsert(inserting_chunks)
finally:
if update_storage:
await self._insert_done()
async def apipeline_enqueue_documents(self, input: str | list[str], ids: list[str] | None = None) -> None:
"""
Pipeline for Processing Documents
1. Validate ids if provided or generate MD5 hash IDs
2. Remove duplicate contents
3. Generate document initial status
4. Filter out already processed documents
5. Enqueue document in status
"""
if isinstance(input, str):
input = [input]
if isinstance(ids, str):
ids = [ids]
if ids is not None:
if len(ids) != len(input):
raise ValueError("Number of IDs must match the number of documents")
if len(ids) != len(set(ids)):
raise ValueError("IDs must be unique")
contents = {id_: doc for id_, doc in zip(ids, input)}
else:
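# No caller-supplied IDs: deduplicate the cleaned texts and derive stable "doc-" IDs from their MD5 hashes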
input = list(set(clean_text(doc) for doc in input))
contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
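# If several IDs map to identical content, keep one entry per distinct content (the last ID wins)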
unique_contents = {id_: content for content, id_ in {content: id_ for id_, content in contents.items()}.items()}
new_docs: dict[str, Any] = {
id_: {
"content": content,
"content_summary": get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
}
for id_, content in unique_contents.items()
}
all_new_doc_ids = set(new_docs.keys())
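# filter_keys returns only the IDs not yet tracked in doc_status, so previously enqueued documents are skipped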
unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids if doc_id in new_docs}
if not new_docs:
logger.info("No new unique documents were found.")
return
await self.doc_status.upsert(new_docs)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_enqueue_documents(
    self,
    split_by_character: str | None = None,
    split_by_character_only: bool = False,
) -> None:
"""
Process pending documents by splitting them into chunks, processing
each chunk for entity and relation extraction, and updating the
document status.
"""
processing_docs, failed_docs, pending_docs = await asyncio.gather(
self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
self.doc_status.get_docs_by_status(DocStatus.FAILED),
self.doc_status.get_docs_by_status(DocStatus.PENDING),
)
to_process_docs: dict[str, Any] = {**processing_docs, **failed_docs, **pending_docs}
if not to_process_docs:
logger.info("No documents to process")
return
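# Group the remaining documents into batches of at most max_parallel_insert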
docs_batches = [
list(to_process_docs.items())[i : i + self.max_parallel_insert]
for i in range(0, len(to_process_docs), self.max_parallel_insert)
]
logger.info(f"Number of batches to process: {len(docs_batches)}")
for batch_idx, docs_batch in enumerate(docs_batches):
for doc_id, status_doc in docs_batch:
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in self.chunking_func(
status_doc.content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
self.tiktoken_model_name,
)
}
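# Persist this document's chunks to the vector store and the full-doc/chunk stores concurrently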
await asyncio.gather(
self.chunks_vdb.upsert(chunks),
self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
self.text_chunks.upsert(chunks),
)
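# Mark the document as PROCESSED and record how many chunks it produced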
await self.doc_status.upsert({
doc_id: {
"status": DocStatus.PROCESSED,
"chunks_count": len(chunks),
"content": status_doc.content,
"content_summary": status_doc.content_summary,
"content_length": status_doc.content_length,
"created_at": status_doc.created_at,
"updated_at": datetime.now().isoformat(),
}
})
logger.info("Document processing pipeline completed")
async def _insert_done(self):
tasks = []
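For context, a minimal usage sketch of the reworked ainsert signature follows. It is not part of this commit; the import path and constructor arguments are assumptions for illustration:

```python
import asyncio

from minirag import MiniRAG  # import path assumed for illustration


async def main() -> None:
    rag = MiniRAG(working_dir="./rag_storage")  # constructor arguments are assumptions
    # Enqueue two documents with caller-supplied IDs, then let the pipeline
    # chunk them, upsert the chunks, and mark them PROCESSED.
    await rag.ainsert(
        ["First document text ...", "Second document text ..."],
        split_by_character="\n\n",      # optional character to split on before token chunking
        split_by_character_only=False,  # exact chunking behavior is delegated to chunking_func
        ids=["doc-001", "doc-002"],     # optional; must be unique and match the number of documents
    )


asyncio.run(main())
```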