improved existing insertion pipeline
@@ -3,7 +3,7 @@ import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast
+from typing import Type, cast, Any


 from .operate import (
@@ -20,7 +20,10 @@ from .utils import (
     limit_async_func_call,
     convert_response_to_json,
+    logger,
+    clean_text,
+    get_content_summary,
     set_logger,
     logger
 )
 from .base import (
     BaseGraphStorage,
@@ -28,6 +31,7 @@ from .base import (
     BaseVectorStorage,
     StorageNameSpace,
     QueryParam,
+    DocStatus,
 )

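The base import gains DocStatus, and the hunks below rely on four of its members (PENDING, PROCESSING, PROCESSED, FAILED). Its real definition lives in .base and is not part of this diff; as a rough sketch, assuming a plain string enum, it could look like:

    from enum import Enum

    class DocStatus(str, Enum):
        """Illustrative only; the actual enum is defined in .base."""
        PENDING = "pending"        # set by apipeline_enqueue_documents for newly stored docs
        PROCESSING = "processing"  # picked up again on the next processing run
        FAILED = "failed"          # picked up again on the next processing run
        PROCESSED = "processed"    # set once a document's chunks have been upserted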
@@ -326,54 +330,35 @@ class MiniRAG:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert(string_or_strings))

-    async def ainsert(self, string_or_strings):
-        update_storage = False
-        try:
-            if isinstance(string_or_strings, str):
-                string_or_strings = [string_or_strings]
-
-            new_docs = {
-                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
-                for c in string_or_strings
-            }
-            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
-            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
-            if not len(new_docs):
-                logger.warning("All docs are already in the storage")
-                return
-            update_storage = True
-            logger.info(f"[New Docs] inserting {len(new_docs)} docs")
-
-            inserting_chunks = {}
-            for doc_key, doc in new_docs.items():
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
-                    }
-                    for dp in chunking_by_token_size(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
-                    )
-                }
-                inserting_chunks.update(chunks)
-            _add_chunk_keys = await self.text_chunks.filter_keys(
-                list(inserting_chunks.keys())
-            )
-            inserting_chunks = {
-                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-            }
-            if not len(inserting_chunks):
-                logger.warning("All chunks are already in the storage")
-                return
-            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
-
-            await self.chunks_vdb.upsert(inserting_chunks)
-
-            logger.info("[Entity Extraction]...")
-            maybe_new_kg = await extract_entities(
+    async def ainsert(self, input: str | list[str], split_by_character: str | None = None, split_by_character_only: bool = False, ids: str | list[str] | None = None) -> None:
+        if isinstance(input, str):
+            input = [input]
+        if isinstance(ids, str):
+            ids = [ids]
+
+        await self.apipeline_enqueue_documents(input, ids)
+        await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
+
+        # Perform additional entity extraction as per original ainsert logic
+        inserting_chunks = {
+            compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                **dp,
+                "full_doc_id": doc_id,
+            }
+            for doc_id, status_doc in (await self.doc_status.get_docs_by_status(DocStatus.PROCESSED)).items()
+            for dp in self.chunking_func(
+                status_doc.content,
+                split_by_character,
+                split_by_character_only,
+                self.chunk_overlap_token_size,
+                self.chunk_token_size,
+                self.tiktoken_model_name,
+            )
+        }
+
+        if inserting_chunks:
+            logger.info("Performing entity extraction on newly processed chunks")
+            await extract_entities(
                 inserting_chunks,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
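The rewritten ainsert above is now a thin wrapper around the two pipeline stages added in the next hunk, followed by entity extraction over the processed documents. A hedged usage sketch (rag stands for an already constructed MiniRAG instance; the argument values are illustrative):

    # inside an async context
    await rag.ainsert(
        ["first document ...", "second document ..."],
        split_by_character="\n\n",        # passed through to chunking_func
        split_by_character_only=False,
        ids=["doc-first", "doc-second"],  # optional; must be unique and match the input length
    )

    # The stages can also be driven independently:
    await rag.apipeline_enqueue_documents(["third document ..."])  # dedupe + store as PENDING
    await rag.apipeline_process_enqueue_documents()                # chunk, upsert, mark PROCESSED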
@@ -381,16 +366,113 @@ class MiniRAG:
                 relationships_vdb=self.relationships_vdb,
                 global_config=asdict(self),
             )
-            if maybe_new_kg is None:
-                logger.warning("No new entities and relationships found")
-                return
-            self.chunk_entity_relation_graph = maybe_new_kg
-            await self.full_docs.upsert(new_docs)
-            await self.text_chunks.upsert(inserting_chunks)
-        finally:
-            if update_storage:
-                await self._insert_done()
+
+    async def apipeline_enqueue_documents(self, input: str | list[str], ids: list[str] | None = None) -> None:
+        """
+        Pipeline for Processing Documents
+
+        1. Validate ids if provided or generate MD5 hash IDs
+        2. Remove duplicate contents
+        3. Generate document initial status
+        4. Filter out already processed documents
+        5. Enqueue document in status
+        """
+        if isinstance(input, str):
+            input = [input]
+        if isinstance(ids, str):
+            ids = [ids]
+
+        if ids is not None:
+            if len(ids) != len(input):
+                raise ValueError("Number of IDs must match the number of documents")
+            if len(ids) != len(set(ids)):
+                raise ValueError("IDs must be unique")
+            contents = {id_: doc for id_, doc in zip(ids, input)}
+        else:
+            input = list(set(clean_text(doc) for doc in input))
+            contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
+
+        unique_contents = {id_: content for content, id_ in {content: id_ for id_, content in contents.items()}.items()}
+        new_docs: dict[str, Any] = {
+            id_: {
+                "content": content,
+                "content_summary": get_content_summary(content),
+                "content_length": len(content),
+                "status": DocStatus.PENDING,
+                "created_at": datetime.now().isoformat(),
+                "updated_at": datetime.now().isoformat(),
+            }
+            for id_, content in unique_contents.items()
+        }
+
+        all_new_doc_ids = set(new_docs.keys())
+        unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids)
+
+        new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids if doc_id in new_docs}
+        if not new_docs:
+            logger.info("No new unique documents were found.")
+            return
+
+        await self.doc_status.upsert(new_docs)
+        logger.info(f"Stored {len(new_docs)} new unique documents")
+
+    async def apipeline_process_enqueue_documents(self, split_by_character: str | None = None, split_by_character_only: bool = False) -> None:
+        """
+        Process pending documents by splitting them into chunks, processing
+        each chunk for entity and relation extraction, and updating the
+        document status.
+        """
+        processing_docs, failed_docs, pending_docs = await asyncio.gather(
+            self.doc_status.get_docs_by_status(DocStatus.PROCESSING),
+            self.doc_status.get_docs_by_status(DocStatus.FAILED),
+            self.doc_status.get_docs_by_status(DocStatus.PENDING),
+        )
+
+        to_process_docs: dict[str, Any] = {**processing_docs, **failed_docs, **pending_docs}
+        if not to_process_docs:
+            logger.info("No documents to process")
+            return
+
+        docs_batches = [
+            list(to_process_docs.items())[i : i + self.max_parallel_insert]
+            for i in range(0, len(to_process_docs), self.max_parallel_insert)
+        ]
+        logger.info(f"Number of batches to process: {len(docs_batches)}")
+
+        for batch_idx, docs_batch in enumerate(docs_batches):
+            for doc_id, status_doc in docs_batch:
+                chunks = {
+                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                        **dp,
+                        "full_doc_id": doc_id,
+                    }
+                    for dp in self.chunking_func(
+                        status_doc.content,
+                        split_by_character,
+                        split_by_character_only,
+                        self.chunk_overlap_token_size,
+                        self.chunk_token_size,
+                        self.tiktoken_model_name,
+                    )
+                }
+                await asyncio.gather(
+                    self.chunks_vdb.upsert(chunks),
+                    self.full_docs.upsert({doc_id: {"content": status_doc.content}}),
+                    self.text_chunks.upsert(chunks),
+                )
+                await self.doc_status.upsert({
+                    doc_id: {
+                        "status": DocStatus.PROCESSED,
+                        "chunks_count": len(chunks),
+                        "content": status_doc.content,
+                        "content_summary": status_doc.content_summary,
+                        "content_length": status_doc.content_length,
+                        "created_at": status_doc.created_at,
+                        "updated_at": datetime.now().isoformat(),
+                    }
+                })
+        logger.info("Document processing pipeline completed")


     async def _insert_done(self):
         tasks = []
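A note on the deduplication step in apipeline_enqueue_documents: inverting the id-to-content mapping twice keeps a single ID per distinct content, with the last ID seen winning. A standalone sketch with made-up values:

    contents = {
        "doc-a": "same text",
        "doc-b": "same text",   # duplicate content under a different ID
        "doc-c": "other text",
    }

    # content -> id: the last ID seen for a duplicated content survives ...
    by_content = {content: id_ for id_, content in contents.items()}
    # ... and inverting back yields an id -> content map with duplicates dropped.
    unique_contents = {id_: content for content, id_ in by_content.items()}

    assert unique_contents == {"doc-b": "same text", "doc-c": "other text"}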
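The batching in apipeline_process_enqueue_documents slices the merged PROCESSING/FAILED/PENDING queue into groups of max_parallel_insert; documents within a batch are then handled one after another in the loop above. The slicing itself, shown with hypothetical values:

    to_process_docs = {f"doc-{i}": f"content {i}" for i in range(7)}  # hypothetical queue
    max_parallel_insert = 3  # stands in for self.max_parallel_insert

    docs_batches = [
        list(to_process_docs.items())[i : i + max_parallel_insert]
        for i in range(0, len(to_process_docs), max_parallel_insert)
    ]

    # 7 queued documents with a batch size of 3 -> batches of 3, 3 and 1
    assert [len(batch) for batch in docs_batches] == [3, 3, 1]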
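Both the new ainsert body and the processing stage now call self.chunking_func positionally instead of chunking_by_token_size with keyword arguments, and both hash dp["content"] from every returned chunk. The implied call shape, written out as a sketch (parameter names are guesses based on the values passed; the project's actual default chunker may differ):

    from typing import Any, Iterable

    def chunking_func(
        content: str,
        split_by_character: str | None,
        split_by_character_only: bool,
        overlap_token_size: int,
        max_token_size: int,
        tiktoken_model: str,
    ) -> Iterable[dict[str, Any]]:
        # Each chunk dict must expose a "content" key, since both call sites
        # compute compute_mdhash_id(dp["content"], prefix="chunk-").
        ...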