mirror of https://github.com/microsoft/graphrag.git (synced 2025-03-11 01:26:14 +03:00)
Context property bag ("state") (#1774)
* Add pipeline state property bag to run context
* Move state creation out of context util
* Move callbacks into PipelineRunContext
* Semver
* Rename state.json to context.json to avoid confusion with stats.json
* Expand smoke test row count
* Add util to create storage and cache
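The heart of this change is the new `state` property bag carried on `PipelineRunContext`, which the diff below threads through every workflow. A minimal sketch of how a custom, stateful workflow might plug in (the workflow name `tag_documents` and its payload are hypothetical; the registration pattern mirrors the new tests/verbs/test_pipeline_state.py at the bottom of this diff):

    from graphrag.config.models.graph_rag_config import GraphRagConfig
    from graphrag.index.typing.context import PipelineRunContext
    from graphrag.index.typing.workflow import WorkflowFunctionOutput
    from graphrag.index.workflows.factory import PipelineFactory


    async def tag_documents(
        _config: GraphRagConfig, context: PipelineRunContext
    ) -> WorkflowFunctionOutput:
        # stash arbitrary values for downstream workflows (or later runs) to read
        context.state["doc_tags"] = {"lang": "en"}
        return WorkflowFunctionOutput(result=None)


    PipelineFactory.register("tag_documents", tag_documents)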
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Adds general-purpose pipeline run state object."
+}
@@ -14,10 +14,10 @@ from graphrag.callbacks.reporting import create_pipeline_reporter
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import IndexingMethod
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.run.pipeline_run_result import PipelineRunResult
 from graphrag.index.run.run_pipeline import run_pipeline
 from graphrag.index.run.utils import create_callback_chain
-from graphrag.index.typing import WorkflowFunction
+from graphrag.index.typing.pipeline_run_result import PipelineRunResult
+from graphrag.index.typing.workflow import WorkflowFunction
 from graphrag.index.workflows.factory import PipelineFactory
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.null_progress import NullProgressLogger
@@ -4,7 +4,7 @@
 """A no-op implementation of WorkflowCallbacks."""

 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.index.run.pipeline_run_result import PipelineRunResult
+from graphrag.index.typing.pipeline_run_result import PipelineRunResult
 from graphrag.logger.progress import Progress
@@ -5,7 +5,7 @@

 from typing import Protocol

-from graphrag.index.run.pipeline_run_result import PipelineRunResult
+from graphrag.index.typing.pipeline_run_result import PipelineRunResult
 from graphrag.logger.progress import Progress
@@ -4,7 +4,7 @@
 """A module containing the WorkflowCallbacks registry."""

 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.index.run.pipeline_run_result import PipelineRunResult
+from graphrag.index.typing.pipeline_run_result import PipelineRunResult
 from graphrag.logger.progress import Progress
@@ -13,7 +13,7 @@ from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
 from graphrag.config.load_config import load_config
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.logger.print_progress import PrintProgressLogger
-from graphrag.storage.factory import StorageFactory
+from graphrag.utils.api import create_storage_from_config
 from graphrag.utils.storage import load_table_from_storage, storage_has_table

 if TYPE_CHECKING:
@@ -497,10 +497,7 @@ def _resolve_output_files(
     dataframe_dict["num_indexes"] = len(config.outputs)
     dataframe_dict["index_names"] = config.outputs.keys()
     for output in config.outputs.values():
-        output_config = output.model_dump()
-        storage_obj = StorageFactory().create_storage(
-            storage_type=output_config["type"], kwargs=output_config
-        )
+        storage_obj = create_storage_from_config(output)
         for name in output_list:
             if name not in dataframe_dict:
                 dataframe_dict[name] = []
@@ -527,10 +524,7 @@ def _resolve_output_files(
        return dataframe_dict
    # Loading output files for single-index search
    dataframe_dict["multi-index"] = False
-   output_config = config.output.model_dump()
-   storage_obj = StorageFactory().create_storage(
-       storage_type=output_config["type"], kwargs=output_config
-   )
+   storage_obj = create_storage_from_config(config.output)
    for name in output_list:
        df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj))
        dataframe_dict[name] = df_value
@@ -1,39 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
# isort: skip_file
|
||||
"""A module containing the 'PipelineRunStats' and 'PipelineRunContext' models."""
|
||||
|
||||
from dataclasses import dataclass as dc_dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
|
||||
@dc_dataclass
|
||||
class PipelineRunStats:
|
||||
"""Pipeline running stats."""
|
||||
|
||||
total_runtime: float = field(default=0)
|
||||
"""Float representing the total runtime."""
|
||||
|
||||
num_documents: int = field(default=0)
|
||||
"""Number of documents."""
|
||||
|
||||
input_load_time: float = field(default=0)
|
||||
"""Float representing the input load time."""
|
||||
|
||||
workflows: dict[str, dict[str, float]] = field(default_factory=dict)
|
||||
"""A dictionary of workflows."""
|
||||
|
||||
|
||||
@dc_dataclass
|
||||
class PipelineRunContext:
|
||||
"""Provides the context for the current pipeline run."""
|
||||
|
||||
stats: PipelineRunStats
|
||||
storage: PipelineStorage
|
||||
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
|
||||
cache: PipelineCache
|
||||
"Cache instance for reading previous LLM responses."
|
||||
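For reference, this deleted module is split apart below: PipelineRunStats moves to graphrag/index/typing/stats.py, and PipelineRunContext moves to graphrag/index/typing/context.py, where it grows the new callbacks and state fields.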
@@ -13,7 +13,7 @@ from graphrag.config.enums import AsyncType
 from graphrag.index.operations.build_noun_graph.np_extractors.base import (
     BaseNounPhraseExtractor,
 )
-from graphrag.index.run.derive_from_rows import derive_from_rows
+from graphrag.index.utils.derive_from_rows import derive_from_rows
 from graphrag.index.utils.hashing import gen_sha512_hash
@@ -11,7 +11,7 @@ from typing import Any
 import tiktoken

 from graphrag.config.defaults import ENCODING_MODEL, graphrag_config_defaults
-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.index.extract_claims import (
     CONTINUE_PROMPT,
@@ -20,7 +20,7 @@ from graphrag.index.operations.extract_covariates.typing import (
     Covariate,
     CovariateExtractionResult,
 )
-from graphrag.index.run.derive_from_rows import derive_from_rows
+from graphrag.index.utils.derive_from_rows import derive_from_rows
 from graphrag.language_model.manager import ModelManager

 log = logging.getLogger(__name__)
@@ -16,7 +16,7 @@ from graphrag.index.operations.extract_graph.typing import (
     EntityExtractStrategy,
     ExtractEntityStrategyType,
 )
-from graphrag.index.run.derive_from_rows import derive_from_rows
+from graphrag.index.utils.derive_from_rows import derive_from_rows

 log = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ import networkx as nx
 import tiktoken

 from graphrag.config.defaults import ENCODING_MODEL, graphrag_config_defaults
-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.index.utils.string import clean_str
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.index.extract_graph import (
@@ -14,7 +14,7 @@ from graphrag.index.operations.layout_graph.typing import (
     GraphLayout,
     NodePosition,
 )
-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn

 # TODO: This could be handled more elegantly, like what columns to use
 # for "size" or "cluster"
@@ -12,7 +12,7 @@ from graphrag.index.operations.layout_graph.typing import (
     GraphLayout,
     NodePosition,
 )
-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn

 # TODO: This could be handled more elegantly, like what columns to use
 # for "size" or "cluster"
@@ -10,7 +10,7 @@ from typing import Any

 from pydantic import BaseModel, Field

-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
@@ -21,7 +21,7 @@ from graphrag.index.operations.summarize_communities.typing import (
 from graphrag.index.operations.summarize_communities.utils import (
     get_levels,
 )
-from graphrag.index.run.derive_from_rows import derive_from_rows
+from graphrag.index.utils.derive_from_rows import derive_from_rows
 from graphrag.logger.progress import progress_ticker

 log = logging.getLogger(__name__)
@@ -6,7 +6,7 @@
 import json
 from dataclasses import dataclass

-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.index.utils.tokens import num_tokens_from_string
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
@@ -13,24 +13,23 @@ from dataclasses import asdict

 import pandas as pd

-from graphrag.cache.factory import CacheFactory
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunStats
 from graphrag.index.input.factory import create_input
-from graphrag.index.run.pipeline import Pipeline
-from graphrag.index.run.pipeline_run_result import PipelineRunResult
 from graphrag.index.run.utils import create_run_context
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.pipeline import Pipeline
+from graphrag.index.typing.pipeline_run_result import PipelineRunResult
 from graphrag.index.update.incremental_index import (
     get_delta_docs,
     update_dataframe_outputs,
 )
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.progress import Progress
-from graphrag.storage.factory import StorageFactory
 from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.api import create_cache_from_config, create_storage_from_config
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

 log = logging.getLogger(__name__)
@@ -46,17 +45,8 @@ async def run_pipeline(
     """Run all workflows using a simplified pipeline."""
     root_dir = config.root_dir

-    storage_config = config.output.model_dump()
-    storage = StorageFactory().create_storage(
-        storage_type=storage_config["type"],
-        kwargs=storage_config,
-    )
-    cache_config = config.cache.model_dump()
-    cache = CacheFactory().create_cache(
-        cache_type=cache_config["type"],
-        root_dir=root_dir,
-        kwargs=cache_config,
-    )
+    storage = create_storage_from_config(config.output)
+    cache = create_cache_from_config(config.cache, root_dir)

     dataset = await create_input(config.input, logger, root_dir)
@@ -70,11 +60,7 @@ async def run_pipeline(
         warning_msg = "Incremental indexing found no new documents, exiting."
         logger.warning(warning_msg)
     else:
-        update_storage_config = config.update_index_output.model_dump()
-        update_storage = StorageFactory().create_storage(
-            storage_type=update_storage_config["type"],
-            kwargs=update_storage_config,
-        )
+        update_storage = create_storage_from_config(config.update_index_output)
         # we use this to store the new subset index, and will merge its content with the previous index
         timestamped_storage = update_storage.child(time.strftime("%Y%m%d-%H%M%S"))
         delta_storage = timestamped_storage.child("delta")
@@ -133,15 +119,20 @@ async def _run_pipeline(
 ) -> AsyncIterable[PipelineRunResult]:
     start_time = time.time()

-    context = create_run_context(storage=storage, cache=cache, stats=None)
+    # load existing state in case any workflows are stateful
+    state_json = await storage.get("context.json")
+    state = json.loads(state_json) if state_json else {}
+
+    context = create_run_context(
+        storage=storage, cache=cache, callbacks=callbacks, state=state
+    )

     log.info("Final # of rows loaded: %s", len(dataset))
     context.stats.num_documents = len(dataset)
     last_workflow = "starting documents"

-    conf = config.model_copy()
     try:
-        await _dump_stats(context.stats, context.storage)
+        await _dump_json(context)
         await write_table_to_storage(dataset, "documents", context.storage)

         for name, workflow_function in pipeline.run():
@@ -149,32 +140,33 @@
             progress = logger.child(name, transient=False)
             callbacks.workflow_start(name, None)
             work_time = time.time()
-            result = await workflow_function(
-                conf,
-                context,
-                callbacks,
-            )
+            result = await workflow_function(config, context)
             progress(Progress(percent=1))
             callbacks.workflow_end(name, result)
-            if result.config:
-                conf = result.config
-            yield PipelineRunResult(name, result.result, conf, None)
+            yield PipelineRunResult(
+                workflow=name, result=result.result, state=context.state, errors=None
+            )

             context.stats.workflows[name] = {"overall": time.time() - work_time}

         context.stats.total_runtime = time.time() - start_time
-        await _dump_stats(context.stats, context.storage)
+        await _dump_json(context)

     except Exception as e:
         log.exception("error running workflow %s", last_workflow)
         callbacks.error("Error running pipeline!", e, traceback.format_exc())
-        yield PipelineRunResult(last_workflow, None, conf, [e])
+        yield PipelineRunResult(
+            workflow=last_workflow, result=None, state=context.state, errors=[e]
+        )


-async def _dump_stats(stats: PipelineRunStats, storage: PipelineStorage) -> None:
-    """Dump the stats to the storage."""
-    await storage.set(
-        "stats.json", json.dumps(asdict(stats), indent=4, ensure_ascii=False)
+async def _dump_json(context: PipelineRunContext) -> None:
+    """Dump the stats and context state to the storage."""
+    await context.storage.set(
+        "stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
+    )
+    await context.storage.set(
+        "context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
     )
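Two JSON artifacts are now written per run: stats.json (timings, as before) and context.json (the serialized state bag), with context.json reloaded at the start of the next run. A minimal sketch of that round trip, using only the json calls shown above:

    import json

    # what _dump_json writes for a given state bag ...
    state = {"count": 4}
    serialized = json.dumps(state, indent=4, ensure_ascii=False)  # stored as context.json

    # ... and what the next _run_pipeline reads back
    restored = json.loads(serialized) if serialized else {}
    assert restored == {"count": 4}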
@@ -5,25 +5,32 @@

 from graphrag.cache.memory_pipeline_cache import InMemoryCache
 from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
-from graphrag.index.context import PipelineRunContext, PipelineRunStats
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.state import PipelineState
+from graphrag.index.typing.stats import PipelineRunStats
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
 from graphrag.storage.pipeline_storage import PipelineStorage


 def create_run_context(
-    storage: PipelineStorage | None,
-    cache: PipelineCache | None,
-    stats: PipelineRunStats | None,
+    storage: PipelineStorage | None = None,
+    cache: PipelineCache | None = None,
+    callbacks: WorkflowCallbacks | None = None,
+    stats: PipelineRunStats | None = None,
+    state: PipelineState | None = None,
 ) -> PipelineRunContext:
     """Create the run context for the pipeline."""
     return PipelineRunContext(
         stats=stats or PipelineRunStats(),
         cache=cache or InMemoryCache(),
         storage=storage or MemoryPipelineStorage(),
+        callbacks=callbacks or NoopWorkflowCallbacks(),
+        state=state or {},
     )
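With every parameter now optional, callers can spin up a fully in-memory context, which is how the new tests below use it; seeding prior state is just another keyword argument. A usage sketch:

    from graphrag.index.run.utils import create_run_context

    context = create_run_context()  # in-memory storage/cache, no-op callbacks, empty state
    context.state["checkpoint"] = "loaded"

    seeded = create_run_context(state={"count": 4})  # resume with pre-existing state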
4 graphrag/index/typing/__init__.py Normal file
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Root typings for GraphRAG."""
28 graphrag/index/typing/context.py Normal file
@@ -0,0 +1,28 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+# isort: skip_file
+"""A module containing the 'PipelineRunContext' models."""
+
+from dataclasses import dataclass
+
+from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+from graphrag.index.typing.state import PipelineState
+from graphrag.index.typing.stats import PipelineRunStats
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+
+@dataclass
+class PipelineRunContext:
+    """Provides the context for the current pipeline run."""
+
+    stats: PipelineRunStats
+    storage: PipelineStorage
+    "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
+    cache: PipelineCache
+    "Cache instance for reading previous LLM responses."
+    callbacks: WorkflowCallbacks
+    "Callbacks to be called during the pipeline run."
+    state: PipelineState
+    "Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."
8 graphrag/index/typing/error_handler.py Normal file
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Shared error handler types."""
+
+from collections.abc import Callable
+
+ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
@@ -5,7 +5,7 @@

 from collections.abc import Generator

-from graphrag.index.typing import Workflow
+from graphrag.index.typing.workflow import Workflow


 class Pipeline:
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from typing import Any

-from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.typing.state import PipelineState


 @dataclass
@@ -17,6 +17,6 @@ class PipelineRunResult:
     """The name of the workflow that was executed."""
     result: Any | None
     """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
-    config: GraphRagConfig | None
-    """Final config after running the workflow, which may have been mutated."""
+    state: PipelineState
+    """Ongoing pipeline context state object."""
     errors: list[BaseException] | None
8 graphrag/index/typing/state.py Normal file
@@ -0,0 +1,8 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Pipeline state types."""
+
+from typing import Any
+
+PipelineState = dict[Any, Any]
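PipelineState is deliberately just a dict alias, so there is no schema to migrate. One practical caveat (an observation, not something this type enforces) is that values should survive json.dumps, since run_pipeline persists the bag to context.json:

    from graphrag.index.typing.state import PipelineState

    state: PipelineState = {"count": 0, "notes": ["seeded"]}  # JSON-friendly values round-trip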
23 graphrag/index/typing/stats.py Normal file
@@ -0,0 +1,23 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Pipeline stats types."""
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class PipelineRunStats:
+    """Pipeline running stats."""
+
+    total_runtime: float = field(default=0)
+    """Float representing the total runtime."""
+
+    num_documents: int = field(default=0)
+    """Number of documents."""
+
+    input_load_time: float = field(default=0)
+    """Float representing the input load time."""
+
+    workflows: dict[str, dict[str, float]] = field(default_factory=dict)
+    """A dictionary of workflows."""
@@ -1,17 +1,14 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""A module containing the 'PipelineRunResult' model."""
+"""Pipeline workflow types."""

 from collections.abc import Awaitable, Callable
 from dataclasses import dataclass
 from typing import Any

-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
-
-ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
+from graphrag.index.typing.context import PipelineRunContext


 @dataclass
@@ -20,12 +17,10 @@ class WorkflowFunctionOutput:

     result: Any | None
     """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
-    config: GraphRagConfig | None
-    """If the config is mutated, return the mutated config here. This allows users to design workflows that tune config for downstream workflow use."""


 WorkflowFunction = Callable[
-    [GraphRagConfig, PipelineRunContext, WorkflowCallbacks],
+    [GraphRagConfig, PipelineRunContext],
     Awaitable[WorkflowFunctionOutput],
 ]
 Workflow = tuple[str, WorkflowFunction]
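The narrowed WorkflowFunction signature means a conforming workflow now takes only the config and the run context; callbacks are reached through the context rather than a third parameter, and config mutation is no longer returned. A minimal conforming function, mirroring the new tests at the end of this diff:

    from graphrag.config.models.graph_rag_config import GraphRagConfig
    from graphrag.index.typing.context import PipelineRunContext
    from graphrag.index.typing.workflow import WorkflowFunctionOutput


    async def my_workflow(
        _config: GraphRagConfig, context: PipelineRunContext
    ) -> WorkflowFunctionOutput:
        context.state["touched"] = True  # context.callbacks is also available here
        return WorkflowFunctionOutput(result=None)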
@@ -11,10 +11,10 @@ import pandas as pd
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.chunking_config import ChunkStrategyType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.chunk_text.chunk_text import chunk_text
 from graphrag.index.operations.chunk_text.strategies import get_encoding_fn
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.index.utils.hashing import gen_sha512_hash
 from graphrag.logger.progress import Progress
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
@@ -23,7 +23,6 @@ from graphrag.utils.storage import load_table_from_storage, write_table_to_stora
 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform base text_units."""
     documents = await load_table_from_storage("documents", context.storage)
@@ -32,7 +31,7 @@ async def run_workflow(

     output = create_base_text_units(
         documents,
-        callbacks,
+        context.callbacks,
         chunks.group_by_columns,
         chunks.size,
         chunks.overlap,
@@ -44,7 +43,7 @@ async def run_workflow(

     await write_table_to_storage(output, "text_units", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 def create_base_text_units(
@@ -10,20 +10,18 @@ from uuid import uuid4
 import numpy as np
 import pandas as pd

-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    _callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform final communities."""
     entities = await load_table_from_storage("entities", context.storage)
@@ -43,7 +41,7 @@ async def run_workflow(

     await write_table_to_storage(output, "communities", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 def create_communities(
@@ -11,7 +11,6 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.defaults import graphrag_config_defaults
 from graphrag.config.enums import AsyncType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.finalize_community_reports import (
     finalize_community_reports,
 )
@@ -25,7 +24,8 @@ from graphrag.index.operations.summarize_communities.graph_context.context_build
 from graphrag.index.operations.summarize_communities.summarize_communities import (
     summarize_communities,
 )
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import (
     load_table_from_storage,
     storage_has_table,
@@ -36,7 +36,6 @@ from graphrag.utils.storage import (
 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform community reports."""
     edges = await load_table_from_storage("relationships", context.storage)
@@ -62,7 +61,7 @@ async def run_workflow(
         entities=entities,
         communities=communities,
         claims_input=claims,
-        callbacks=callbacks,
+        callbacks=context.callbacks,
         cache=context.cache,
         summarization_strategy=summarization_strategy,
         async_mode=async_mode,
@@ -71,7 +70,7 @@ async def run_workflow(

     await write_table_to_storage(output, "community_reports", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 async def create_community_reports(
||||
@@ -12,7 +12,6 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.operations.finalize_community_reports import (
|
||||
finalize_community_reports,
|
||||
)
|
||||
@@ -26,7 +25,8 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
     build_level_context,
     build_local_context,
 )
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

 log = logging.getLogger(__name__)
@@ -35,7 +35,6 @@ log = logging.getLogger(__name__)
 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform community reports."""
     entities = await load_table_from_storage("entities", context.storage)
@@ -56,7 +55,7 @@ async def run_workflow(
         entities,
         communities,
         text_units,
-        callbacks,
+        context.callbacks,
         context.cache,
         summarization_strategy,
         async_mode=async_mode,
@@ -65,7 +64,7 @@ async def run_workflow(

     await write_table_to_storage(output, "community_reports", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 async def create_community_reports_text(
@@ -5,18 +5,16 @@

 import pandas as pd

-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
-from graphrag.index.context import PipelineRunContext
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     _config: GraphRagConfig,
     context: PipelineRunContext,
-    _callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform final documents."""
     documents = await load_table_from_storage("documents", context.storage)
@@ -26,7 +24,7 @@ async def run_workflow(

     await write_table_to_storage(output, "documents", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 def create_final_documents(
@@ -5,11 +5,10 @@

 import pandas as pd

-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
-from graphrag.index.context import PipelineRunContext
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import (
     load_table_from_storage,
     storage_has_table,
@@ -20,7 +19,6 @@ from graphrag.utils.storage import (
 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    _callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform the text units."""
     text_units = await load_table_from_storage("text_units", context.storage)
@@ -43,7 +41,7 @@ async def run_workflow(

     await write_table_to_storage(output, "text_units", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 def create_final_text_units(
@@ -13,18 +13,17 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import AsyncType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.extract_covariates.extract_covariates import (
     extract_covariates as extractor,
 )
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to extract and format covariates."""
     text_units = await load_table_from_storage("text_units", context.storage)
@@ -41,7 +40,7 @@ async def run_workflow(

     output = await extract_covariates(
         text_units,
-        callbacks,
+        context.callbacks,
         context.cache,
         "claim",
         extraction_strategy,
@@ -52,7 +51,7 @@ async def run_workflow(

     await write_table_to_storage(output, "covariates", context.storage)

-    return WorkflowFunctionOutput(result=output, config=None)
+    return WorkflowFunctionOutput(result=output)


 async def extract_covariates(
@@ -11,21 +11,20 @@ from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import AsyncType
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.extract_graph.extract_graph import (
     extract_graph as extractor,
 )
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to create the base entity graph."""
     text_units = await load_table_from_storage("text_units", context.storage)
@@ -46,7 +45,7 @@ async def run_workflow(

     entities, relationships = await extract_graph(
         text_units=text_units,
-        callbacks=callbacks,
+        callbacks=context.callbacks,
         cache=context.cache,
         extraction_strategy=extraction_strategy,
         extraction_num_threads=extract_graph_llm_settings.concurrent_requests,
@@ -63,8 +62,7 @@ async def run_workflow(
         result={
             "entities": entities,
             "relationships": relationships,
-        },
-        config=None,
+        }
     )
@@ -6,22 +6,20 @@
 import pandas as pd

 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph
 from graphrag.index.operations.build_noun_graph.np_extractors.factory import (
     create_noun_phrase_extractor,
 )
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    _callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to create the base entity graph."""
     text_units = await load_table_from_storage("text_units", context.storage)
@@ -39,8 +37,7 @@ async def run_workflow(
         result={
             "entities": entities,
             "relationships": relationships,
-        },
-        config=None,
+        }
     )
@@ -7,8 +7,8 @@ from typing import ClassVar

 from graphrag.config.enums import IndexingMethod
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.run.pipeline import Pipeline
-from graphrag.index.typing import WorkflowFunction
+from graphrag.index.typing.pipeline import Pipeline
+from graphrag.index.typing.workflow import WorkflowFunction


 class PipelineFactory:
@@ -8,19 +8,18 @@ import pandas as pd
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.embed_graph_config import EmbedGraphConfig
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.finalize_entities import finalize_entities
 from graphrag.index.operations.finalize_relationships import finalize_relationships
 from graphrag.index.operations.snapshot_graphml import snapshot_graphml
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to create the base entity graph."""
     entities = await load_table_from_storage("entities", context.storage)
@@ -29,7 +28,7 @@ async def run_workflow(
     final_entities, final_relationships = finalize_graph(
         entities,
         relationships,
-        callbacks,
+        callbacks=context.callbacks,
         embed_config=config.embed_graph,
         layout_enabled=config.umap.enabled,
     )
@@ -50,8 +49,7 @@ async def run_workflow(
         result={
             "entities": entities,
             "relationships": relationships,
-        },
-        config=None,
+        }
     )
@@ -22,9 +22,9 @@ from graphrag.config.embeddings import (
     text_unit_text_embedding,
 )
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.embed_text import embed_text
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

 log = logging.getLogger(__name__)
@@ -33,7 +33,6 @@ log = logging.getLogger(__name__)
 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to transform community reports."""
     documents = await load_table_from_storage("documents", context.storage)
@@ -47,27 +46,27 @@ async def run_workflow(
     embedded_fields = get_embedded_fields(config)
     text_embed = get_embedding_settings(config)

-    result = await generate_text_embeddings(
+    output = await generate_text_embeddings(
         documents=documents,
         relationships=relationships,
         text_units=text_units,
         entities=entities,
         community_reports=community_reports,
-        callbacks=callbacks,
+        callbacks=context.callbacks,
         cache=context.cache,
         text_embed_config=text_embed,
         embedded_fields=embedded_fields,
     )

     if config.snapshots.embeddings:
-        for name, table in result.items():
+        for name, table in output.items():
             await write_table_to_storage(
                 table,
                 f"embeddings.{name}",
                 context.storage,
             )

-    return WorkflowFunctionOutput(result=result, config=None)
+    return WorkflowFunctionOutput(result=output)


 async def generate_text_embeddings(
@@ -5,21 +5,19 @@

 import pandas as pd

-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.graph_to_dataframes import graph_to_dataframes
 from graphrag.index.operations.prune_graph import prune_graph as prune_graph_operation
-from graphrag.index.typing import WorkflowFunctionOutput
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


 async def run_workflow(
     config: GraphRagConfig,
     context: PipelineRunContext,
-    _callbacks: WorkflowCallbacks,
 ) -> WorkflowFunctionOutput:
     """All the steps to create the base entity graph."""
     entities = await load_table_from_storage("entities", context.storage)
@@ -38,8 +36,7 @@ async def run_workflow(
         result={
             "entities": pruned_entities,
             "relationships": pruned_relationships,
-        },
-        config=None,
+        }
     )
@@ -7,7 +7,7 @@ from typing import Any

 from fnllm.events import LLMEvents

-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn


 class FNLLMEvents(LLMEvents):
@@ -18,7 +18,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.language_model_config import (
     LanguageModelConfig,
 )
-from graphrag.index.typing import ErrorHandlerFn
+from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.language_model.providers.fnllm.cache import FNLLMCacheProvider
@@ -6,8 +6,14 @@
 from pathlib import Path
 from typing import Any

+from graphrag.cache.factory import CacheFactory
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.config.embeddings import create_collection_name
+from graphrag.config.models.cache_config import CacheConfig
+from graphrag.config.models.output_config import OutputConfig
 from graphrag.data_model.types import TextEmbedder
+from graphrag.storage.factory import StorageFactory
+from graphrag.storage.pipeline_storage import PipelineStorage
 from graphrag.vector_stores.base import (
     BaseVectorStore,
     VectorStoreDocument,
@@ -230,3 +236,22 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
         if prompt_file.exists():
             return prompt_file.read_bytes().decode(encoding="utf-8")
     return None
+
+
+def create_storage_from_config(output: OutputConfig) -> PipelineStorage:
+    """Create a storage object from the config."""
+    storage_config = output.model_dump()
+    return StorageFactory().create_storage(
+        storage_type=storage_config["type"],
+        kwargs=storage_config,
+    )
+
+
+def create_cache_from_config(cache: CacheConfig, root_dir: str) -> PipelineCache:
+    """Create a cache object from the config."""
+    cache_config = cache.model_dump()
+    return CacheFactory().create_cache(
+        cache_type=cache_config["type"],
+        root_dir=root_dir,
+        kwargs=cache_config,
+    )
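These two helpers centralize the StorageFactory/CacheFactory boilerplate that was previously inlined at each call site. A usage sketch, assuming config is an already-loaded GraphRagConfig (see the run_pipeline changes above for the real call sites):

    from graphrag.utils.api import create_cache_from_config, create_storage_from_config

    storage = create_storage_from_config(config.output)
    cache = create_cache_from_config(config.cache, config.root_dir)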
2 tests/fixtures/text/config.json vendored
@@ -11,7 +11,7 @@
     "finalize_graph": {
         "row_range": [
             1,
-            100
+            500
         ],
         "nan_allowed_columns": [
             "x",
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.index.workflows.create_base_text_units import run_workflow
 from graphrag.utils.storage import load_table_from_storage
@@ -22,11 +21,7 @@ async def test_create_base_text_units():

     config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("text_units", context.storage)
@@ -46,11 +41,7 @@ async def test_create_base_text_units_metadata():

     await update_document_metadata(config.input.metadata, context)

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("text_units", context.storage)
     compare_outputs(actual, expected)
@@ -70,11 +61,7 @@ async def test_create_base_text_units_metadata_included_in_chunk():

     await update_document_metadata(config.input.metadata, context)

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("text_units", context.storage)
     # only check the columns from the base workflow - our expected table is the final and will have more
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
 from graphrag.index.workflows.create_communities import (
@@ -32,7 +31,6 @@ async def test_create_communities():
     await run_workflow(
         config,
         context,
-        NoopWorkflowCallbacks(),
     )

     actual = await load_table_from_storage("communities", context.storage)
@@ -2,7 +2,6 @@
 # Licensed under the MIT License


-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.enums import ModelType
 from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
@@ -65,11 +64,7 @@ async def test_create_community_reports():
         "graph_prompt": "",
     }

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("community_reports", context.storage)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
 from graphrag.index.workflows.create_final_documents import (
@@ -27,11 +26,7 @@ async def test_create_final_documents():

     config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("documents", context.storage)
@@ -54,11 +49,7 @@ async def test_create_final_documents_with_metadata_column():

     expected = await load_table_from_storage("documents", context.storage)

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("documents", context.storage)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
 from graphrag.index.workflows.create_final_text_units import (
@@ -32,11 +31,7 @@ async def test_create_final_text_units():
     config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
     config.extract_claims.enabled = True

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("text_units", context.storage)
@@ -3,7 +3,6 @@

 from pandas.testing import assert_series_equal

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.enums import ModelType
 from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
@@ -44,11 +43,7 @@ async def test_extract_covariates():
         "claim_description": "description",
     }

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     actual = await load_table_from_storage("covariates", context.storage)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.enums import ModelType
 from graphrag.index.workflows.extract_graph import (
@@ -60,11 +59,7 @@ async def test_extract_graph():
         "llm": summarize_llm_settings,
     }

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     nodes_actual = await load_table_from_storage("entities", context.storage)
     edges_actual = await load_table_from_storage("relationships", context.storage)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.index.workflows.extract_graph_nlp import (
     run_workflow,
@@ -21,11 +20,7 @@ async def test_extract_graph_nlp():

     config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     nodes_actual = await load_table_from_storage("entities", context.storage)
     edges_actual = await load_table_from_storage("relationships", context.storage)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.data_model.schemas import (
     ENTITIES_FINAL_COLUMNS,
@@ -24,11 +23,7 @@ async def test_finalize_graph():

     config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     nodes_actual = await load_table_from_storage("entities", context.storage)
     edges_actual = await load_table_from_storage("relationships", context.storage)
@@ -54,11 +49,7 @@ async def test_finalize_graph_umap():
     config.embed_graph.enabled = True
     config.umap.enabled = True

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     nodes_actual = await load_table_from_storage("entities", context.storage)
     edges_actual = await load_table_from_storage("relationships", context.storage)
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.embeddings import (
     all_embeddings,
@@ -43,11 +42,7 @@ async def test_generate_text_embeddings():
     config.embed_text.target = TextEmbeddingTarget.all
     config.snapshots.embeddings = True

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     parquet_files = context.storage.keys()
54 tests/verbs/test_pipeline_state.py Normal file
@@ -0,0 +1,54 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Tests for pipeline state passthrough."""
+
+from graphrag.config.create_graphrag_config import create_graphrag_config
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.utils import create_run_context
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.workflows.factory import PipelineFactory
+from tests.verbs.util import DEFAULT_MODEL_CONFIG
+
+
+async def run_workflow_1(  # noqa: RUF029
+    _config: GraphRagConfig, context: PipelineRunContext
+):
+    context.state["count"] = 1
+    return WorkflowFunctionOutput(result=None)
+
+
+async def run_workflow_2(  # noqa: RUF029
+    _config: GraphRagConfig, context: PipelineRunContext
+):
+    context.state["count"] += 1
+    return WorkflowFunctionOutput(result=None)
+
+
+async def test_pipeline_state():
+    # checks that we can update the arbitrary state block within the pipeline run context
+    PipelineFactory.register("workflow_1", run_workflow_1)
+    PipelineFactory.register("workflow_2", run_workflow_2)
+
+    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
+    config.workflows = ["workflow_1", "workflow_2"]
+    context = create_run_context()
+
+    for _, fn in PipelineFactory.create_pipeline(config).run():
+        await fn(config, context)
+
+    assert context.state["count"] == 2
+
+
+async def test_pipeline_existing_state():
+    PipelineFactory.register("workflow_2", run_workflow_2)
+
+    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
+    config.workflows = ["workflow_2"]
+    context = create_run_context(state={"count": 4})
+
+    for _, fn in PipelineFactory.create_pipeline(config).run():
+        await fn(config, context)
+
+    assert context.state["count"] == 5
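Note that these tests invoke the registered workflow functions directly rather than going through _run_pipeline, so they exercise only the in-memory bag; combined with the context.json handling shown earlier, state written by workflow_1 in one run would also be visible to workflow_2 in a later, resumed run.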
@@ -1,7 +1,6 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.config.create_graphrag_config import create_graphrag_config
 from graphrag.config.models.prune_graph_config import PruneGraphConfig
 from graphrag.index.workflows.prune_graph import (
@@ -25,11 +24,7 @@ async def test_prune_graph():
         min_node_freq=4, min_node_degree=0, min_edge_weight_pct=0
     )

-    await run_workflow(
-        config,
-        context,
-        NoopWorkflowCallbacks(),
-    )
+    await run_workflow(config, context)

     nodes_actual = await load_table_from_storage("entities", context.storage)
@@ -5,8 +5,8 @@ import pandas as pd
 from pandas.testing import assert_series_equal

 import graphrag.config.defaults as defs
-from graphrag.index.context import PipelineRunContext
 from graphrag.index.run.utils import create_run_context
+from graphrag.index.typing.context import PipelineRunContext
 from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

 pd.set_option("display.max_columns", None)
@@ -33,7 +33,7 @@ DEFAULT_MODEL_CONFIG = {

 async def create_test_context(storage: list[str] | None = None) -> PipelineRunContext:
     """Create a test context with tables loaded into storage storage."""
-    context = create_run_context(None, None, None)
+    context = create_run_context()

     # always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input
     input = load_test_table("documents")