add statistics during graph creation
create.py | 113
@@ -1,6 +1,9 @@
 import glob
 import os
 import asyncio
+import statistics
+from functools import wraps
+from typing import Callable, Dict, List, Any

 import aiofiles
 from lightrag import LightRAG, QueryParam
@@ -21,37 +24,51 @@ def read_text_file(file_path):
     return text


+def get_text_statistics(text):
+    """Calculate statistics for the given text."""
+    char_count = len(text)
+    word_count = len(text.split())
+    line_count = len(text.splitlines())
+    return {
+        'char_count': char_count,
+        'word_count': word_count,
+        'line_count': line_count
+    }
+
+
+def with_statistics(func: Callable) -> Callable:
+    @wraps(func)
+    def wrapper(file_path: str, *args, **kwargs) -> Dict[str, Any]:
+        # Read the file
+        text = read_text_file(file_path)
+
+        # Get text statistics
+        stats = get_text_statistics(text)
+        file_size = os.path.getsize(file_path)
+        stats['file_size'] = file_size
+        stats['file_name'] = os.path.basename(file_path)
+
+        # Log individual file statistics
+        logger.debug(f"File: {stats['file_name']}")
+        logger.debug(f" - Size: {file_size} bytes")
+        logger.debug(f" - Characters: {stats['char_count']}")
+        logger.debug(f" - Words: {stats['word_count']}")
+        logger.debug(f" - Lines: {stats['line_count']}")
+
+        # Call the original function
+        result = func(text, *args, **kwargs)
+
+        return {
+            'result': result,
+            'stats': stats
+        }
+
+    return wrapper

 async def initialize_rag_azure_openai():
     rag = LightRAG(
         working_dir=os.environ["KNOWLEDGE_GRAPH_PATH_GPT4o"],
         graph_storage="NetworkXStorage",
         kv_storage="JsonKVStorage",
         vector_storage="FaissVectorDBStorage",
         vector_db_storage_cls_kwargs={
             "cosine_better_than_threshold": 0.2
         },
         embedding_func=EmbeddingFunc(
             embedding_dim=3072,
             max_token_size=8192,
             func=lambda texts: azure_openai_embed(texts)
         ),
         llm_model_func=azure_openai_complete,
         enable_llm_cache=False,
         enable_llm_cache_for_entity_extract=False,
         embedding_cache_config={
             "enabled": False,
             "similarity_threshold": 0.95,
             "use_llm_check": False
         },
     )
     await rag.initialize_storages()
     await initialize_pipeline_status()
     return rag

 async def initilize_rag_ollama():
     rag = LightRAG(
-        working_dir=os.environ["KNOWLEDGE_GRAPH_PATH_GEMMA327b"],
+        working_dir=os.environ["KNOWLEDGE_GRAPH_PATH"],
         graph_storage="NetworkXStorage",  # "Neo4JStorage",
         kv_storage="JsonKVStorage",
         vector_storage="FaissVectorDBStorage",
@@ -86,31 +103,41 @@ async def initilize_rag_ollama():
     return rag


 def main():
     logger.info("Initializing lightRAG instance")
     rag = asyncio.run(initilize_rag_ollama())

     input_dir_path = "/Users/tcudikel/Dev/ancient-history/data/input/transcripts"
     txt_files = glob.glob(f"{input_dir_path}/*.txt")
-    logger.debug(f"found {len(txt_files)} files in {input_dir_path}")
+    logger.info(f"Found {len(txt_files)} text files in {input_dir_path}")

+    # Collect statistics
+    all_stats = []
+
+    @with_statistics
+    def process_file(text, rag):
+        return rag.insert(text)
+
     for file_path in tqdm(txt_files, desc="Processing files", unit="file", miniters=1, ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'):
-        text = read_text_file(file_path)
-        rag.insert(text)
+        result_with_stats = process_file(file_path, rag)
+        all_stats.append(result_with_stats['stats'])

+    # Calculate and log summary statistics
+    if all_stats:
+        char_counts = [stat['char_count'] for stat in all_stats]
+        word_counts = [stat['word_count'] for stat in all_stats]
+        line_counts = [stat['line_count'] for stat in all_stats]
+
+        logger.info("Text statistics summary:")
+        logger.info(f" - Total characters: {sum(char_counts)}")
+        logger.info(f" - Total words: {sum(word_counts)}")
+        logger.info(f" - Total lines: {sum(line_counts)}")
+        logger.info(f" - Average characters per file: {statistics.mean(char_counts):.2f}")
+        logger.info(f" - Average words per file: {statistics.mean(word_counts):.2f}")
+        logger.info(f" - Average lines per file: {statistics.mean(line_counts):.2f}")

     logger.success(f"{len(txt_files)} files inserted into the knowledge graph.")
     """mode="mix"
     rag.query(
         "What are the top themes in this story?",
         param=QueryParam(
             mode=mode,
             response_type="Single Paragraph",
             # conversation_history=,
             # history_turns=5,
         )
     )"""


 if __name__ == "__main__":
     main()
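
Note on the decorator's data flow, since it is easy to misread in the hunks above: callers pass a file path to the decorated function, the wrapped function itself receives the file's text, and the original return value comes back under 'result' alongside the per-file 'stats'. A minimal self-contained sketch of the same pattern (no LightRAG dependency; process_file and example.txt here are hypothetical stand-ins):

import os
from functools import wraps
from typing import Any, Callable, Dict


def with_statistics(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(file_path: str, *args, **kwargs) -> Dict[str, Any]:
        # Read the file once, so the wrapped function gets text, not a path
        with open(file_path, encoding="utf-8") as f:
            text = f.read()
        stats = {
            'char_count': len(text),
            'word_count': len(text.split()),
            'line_count': len(text.splitlines()),
            'file_size': os.path.getsize(file_path),
            'file_name': os.path.basename(file_path),
        }
        return {'result': func(text, *args, **kwargs), 'stats': stats}
    return wrapper


@with_statistics
def process_file(text: str) -> int:
    # Stand-in for rag.insert(text)
    return len(text)


# The wrapped call takes a path; 'result' holds the original return value.
out = process_file("example.txt")  # hypothetical input file
print(out['stats']['word_count'], out['result'])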
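
The commented-out block at the end of main() hints at the intended follow-up step: querying the finished graph. A hedged sketch of what that could look like, reusing the QueryParam call from the docstring (mode and response_type values taken from the source; query_graph is a hypothetical helper):

from lightrag import QueryParam


def query_graph(rag, question: str) -> str:
    # mode="mix" combines graph and vector retrieval in LightRAG
    return rag.query(
        question,
        param=QueryParam(
            mode="mix",
            response_type="Single Paragraph",
        ),
    )


# answer = query_graph(rag, "What are the top themes in this story?")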
||||