Context property bag ("state") (#1774)

* Add pipeline state property bag to run context

* Move state creation out of context util

* Move callbacks into PipelineRunContext

* Add semver change note

* Rename state.json to context.json to avoid confusion with stats.json

* Expand smoke test row count

* Add util to create storage and cache
Author: Nathan Evans
Date: 2025-02-28 09:31:48 -08:00
Committed by: GitHub
Parent: a15942629b
Commit: bd06d8b4f0
58 changed files with 286 additions and 272 deletions
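The practical upshot: a workflow can stash values in `context.state` for downstream workflows (or later runs) instead of returning a mutated config. A minimal sketch, modeled on the state-passthrough test added in this commit (function names and state keys are illustrative):

```python
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput


async def stamp_state(_config, context: PipelineRunContext) -> WorkflowFunctionOutput:
    # workflows now mutate the shared property bag instead of returning a mutated config
    context.state["count"] = context.state.get("count", 0) + 1
    return WorkflowFunctionOutput(result=None)


async def demo() -> None:
    # seed pre-existing state, exactly as the new passthrough test does
    context = create_run_context(state={"count": 4})
    await stamp_state(None, context)
    assert context.state["count"] == 5
```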

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Adds general-purpose pipeline run state object."
}

View File

@@ -14,10 +14,10 @@ from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import IndexingMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.run.run_pipeline import run_pipeline
from graphrag.index.run.utils import create_callback_chain
from graphrag.index.typing import WorkflowFunction
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.workflow import WorkflowFunction
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger

View File

@@ -4,7 +4,7 @@
"""A no-op implementation of WorkflowCallbacks."""
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress

View File

@@ -5,7 +5,7 @@
from typing import Protocol
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress

View File

@@ -4,7 +4,7 @@
"""A module containing the WorkflowCallbacks registry."""
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress

View File

@@ -13,7 +13,7 @@ from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.storage.factory import StorageFactory
from graphrag.utils.api import create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, storage_has_table
if TYPE_CHECKING:
@@ -497,10 +497,7 @@ def _resolve_output_files(
dataframe_dict["num_indexes"] = len(config.outputs)
dataframe_dict["index_names"] = config.outputs.keys()
for output in config.outputs.values():
output_config = output.model_dump()
storage_obj = StorageFactory().create_storage(
storage_type=output_config["type"], kwargs=output_config
)
storage_obj = create_storage_from_config(output)
for name in output_list:
if name not in dataframe_dict:
dataframe_dict[name] = []
@@ -527,10 +524,7 @@ def _resolve_output_files(
return dataframe_dict
# Loading output files for single-index search
dataframe_dict["multi-index"] = False
output_config = config.output.model_dump()
storage_obj = StorageFactory().create_storage(
storage_type=output_config["type"], kwargs=output_config
)
storage_obj = create_storage_from_config(config.output)
for name in output_list:
df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj))
dataframe_dict[name] = df_value

View File

@@ -1,39 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
# isort: skip_file
"""A module containing the 'PipelineRunStats' and 'PipelineRunContext' models."""
from dataclasses import dataclass as dc_dataclass
from dataclasses import field
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.storage.pipeline_storage import PipelineStorage
@dc_dataclass
class PipelineRunStats:
"""Pipeline running stats."""
total_runtime: float = field(default=0)
"""Float representing the total runtime."""
num_documents: int = field(default=0)
"""Number of documents."""
input_load_time: float = field(default=0)
"""Float representing the input load time."""
workflows: dict[str, dict[str, float]] = field(default_factory=dict)
"""A dictionary of workflows."""
@dc_dataclass
class PipelineRunContext:
"""Provides the context for the current pipeline run."""
stats: PipelineRunStats
storage: PipelineStorage
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
cache: PipelineCache
"Cache instance for reading previous LLM responses."

View File

@@ -13,7 +13,7 @@ from graphrag.config.enums import AsyncType
from graphrag.index.operations.build_noun_graph.np_extractors.base import (
BaseNounPhraseExtractor,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.index.utils.hashing import gen_sha512_hash

View File

@@ -11,7 +11,7 @@ from typing import Any
import tiktoken
from graphrag.config.defaults import ENCODING_MODEL, graphrag_config_defaults
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.extract_claims import (
CONTINUE_PROMPT,

View File

@@ -20,7 +20,7 @@ from graphrag.index.operations.extract_covariates.typing import (
Covariate,
CovariateExtractionResult,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.language_model.manager import ModelManager
log = logging.getLogger(__name__)

View File

@@ -16,7 +16,7 @@ from graphrag.index.operations.extract_graph.typing import (
EntityExtractStrategy,
ExtractEntityStrategyType,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.index.utils.derive_from_rows import derive_from_rows
log = logging.getLogger(__name__)

View File

@@ -14,7 +14,7 @@ import networkx as nx
import tiktoken
from graphrag.config.defaults import ENCODING_MODEL, graphrag_config_defaults
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.index.utils.string import clean_str
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.extract_graph import (

View File

@@ -14,7 +14,7 @@ from graphrag.index.operations.layout_graph.typing import (
GraphLayout,
NodePosition,
)
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"

View File

@@ -12,7 +12,7 @@ from graphrag.index.operations.layout_graph.typing import (
GraphLayout,
NodePosition,
)
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
# TODO: This could be handled more elegantly, like what columns to use
# for "size" or "cluster"

View File

@@ -10,7 +10,7 @@ from typing import Any
from pydantic import BaseModel, Field
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT

View File

@@ -21,7 +21,7 @@ from graphrag.index.operations.summarize_communities.typing import (
from graphrag.index.operations.summarize_communities.utils import (
get_levels,
)
from graphrag.index.run.derive_from_rows import derive_from_rows
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker
log = logging.getLogger(__name__)

View File

@@ -6,7 +6,7 @@
import json
from dataclasses import dataclass
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT

View File

@@ -13,24 +13,23 @@ from dataclasses import asdict
import pandas as pd
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunStats
from graphrag.index.input.factory import create_input
from graphrag.index.run.pipeline import Pipeline
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.update.incremental_index import (
get_delta_docs,
update_dataframe_outputs,
)
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_cache_from_config, create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__)
@@ -46,17 +45,8 @@ async def run_pipeline(
"""Run all workflows using a simplified pipeline."""
root_dir = config.root_dir
storage_config = config.output.model_dump()
storage = StorageFactory().create_storage(
storage_type=storage_config["type"],
kwargs=storage_config,
)
cache_config = config.cache.model_dump()
cache = CacheFactory().create_cache(
cache_type=cache_config["type"],
root_dir=root_dir,
kwargs=cache_config,
)
storage = create_storage_from_config(config.output)
cache = create_cache_from_config(config.cache, root_dir)
dataset = await create_input(config.input, logger, root_dir)
@@ -70,11 +60,7 @@ async def run_pipeline(
warning_msg = "Incremental indexing found no new documents, exiting."
logger.warning(warning_msg)
else:
update_storage_config = config.update_index_output.model_dump()
update_storage = StorageFactory().create_storage(
storage_type=update_storage_config["type"],
kwargs=update_storage_config,
)
update_storage = create_storage_from_config(config.update_index_output)
# we use this to store the new subset index, and will merge its content with the previous index
timestamped_storage = update_storage.child(time.strftime("%Y%m%d-%H%M%S"))
delta_storage = timestamped_storage.child("delta")
@@ -133,15 +119,20 @@ async def _run_pipeline(
) -> AsyncIterable[PipelineRunResult]:
start_time = time.time()
context = create_run_context(storage=storage, cache=cache, stats=None)
# load existing state in case any workflows are stateful
state_json = await storage.get("context.json")
state = json.loads(state_json) if state_json else {}
context = create_run_context(
storage=storage, cache=cache, callbacks=callbacks, state=state
)
log.info("Final # of rows loaded: %s", len(dataset))
context.stats.num_documents = len(dataset)
last_workflow = "starting documents"
conf = config.model_copy()
try:
await _dump_stats(context.stats, context.storage)
await _dump_json(context)
await write_table_to_storage(dataset, "documents", context.storage)
for name, workflow_function in pipeline.run():
@@ -149,32 +140,33 @@ async def _run_pipeline(
progress = logger.child(name, transient=False)
callbacks.workflow_start(name, None)
work_time = time.time()
result = await workflow_function(
conf,
context,
callbacks,
)
result = await workflow_function(config, context)
progress(Progress(percent=1))
callbacks.workflow_end(name, result)
if result.config:
conf = result.config
yield PipelineRunResult(name, result.result, conf, None)
yield PipelineRunResult(
workflow=name, result=result.result, state=context.state, errors=None
)
context.stats.workflows[name] = {"overall": time.time() - work_time}
context.stats.total_runtime = time.time() - start_time
await _dump_stats(context.stats, context.storage)
await _dump_json(context)
except Exception as e:
log.exception("error running workflow %s", last_workflow)
callbacks.error("Error running pipeline!", e, traceback.format_exc())
yield PipelineRunResult(last_workflow, None, conf, [e])
yield PipelineRunResult(
workflow=last_workflow, result=None, state=context.state, errors=[e]
)
async def _dump_stats(stats: PipelineRunStats, storage: PipelineStorage) -> None:
"""Dump the stats to the storage."""
await storage.set(
"stats.json", json.dumps(asdict(stats), indent=4, ensure_ascii=False)
async def _dump_json(context: PipelineRunContext) -> None:
"""Dump the stats and context state to the storage."""
await context.storage.set(
"stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
)
await context.storage.set(
"context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
)
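Since `_dump_json` rewrites `context.json` at startup and after every workflow, anything placed in `context.state` survives into subsequent runs, provided it is JSON-serializable. A small read-back sketch mirroring the rehydration step above (the helper name is hypothetical):

```python
import json

from graphrag.storage.pipeline_storage import PipelineStorage


async def load_previous_state(storage: PipelineStorage) -> dict:
    # mirrors _run_pipeline's rehydration: a missing file means empty state
    state_json = await storage.get("context.json")
    return json.loads(state_json) if state_json else {}
```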

View File

@@ -5,25 +5,32 @@
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
from graphrag.index.context import PipelineRunContext, PipelineRunStats
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
def create_run_context(
storage: PipelineStorage | None,
cache: PipelineCache | None,
stats: PipelineRunStats | None,
storage: PipelineStorage | None = None,
cache: PipelineCache | None = None,
callbacks: WorkflowCallbacks | None = None,
stats: PipelineRunStats | None = None,
state: PipelineState | None = None,
) -> PipelineRunContext:
"""Create the run context for the pipeline."""
return PipelineRunContext(
stats=stats or PipelineRunStats(),
cache=cache or InMemoryCache(),
storage=storage or MemoryPipelineStorage(),
callbacks=callbacks or NoopWorkflowCallbacks(),
state=state or {},
)
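With every parameter now optional, a bare `create_run_context()` yields a fully in-memory context, which is what the verb tests below rely on. For example:

```python
from graphrag.index.run.utils import create_run_context

# everything defaults: in-memory storage and cache, no-op callbacks, empty state
context = create_run_context()

# or seed the property bag up front, as test_pipeline_existing_state does below
context = create_run_context(state={"count": 4})
```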

View File

@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Root typings for GraphRAG."""

View File

@@ -0,0 +1,28 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
# isort: skip_file
"""A module containing the 'PipelineRunContext' models."""
from dataclasses import dataclass
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.storage.pipeline_storage import PipelineStorage
@dataclass
class PipelineRunContext:
"""Provides the context for the current pipeline run."""
stats: PipelineRunStats
storage: PipelineStorage
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
cache: PipelineCache
"Cache instance for reading previous LLM responses."
callbacks: WorkflowCallbacks
"Callbacks to be called during the pipeline run."
state: PipelineState
"Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Shared error handler types."""
from collections.abc import Callable
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

View File

@@ -5,7 +5,7 @@
from collections.abc import Generator
from graphrag.index.typing import Workflow
from graphrag.index.typing.workflow import Workflow
class Pipeline:

View File

@@ -6,7 +6,7 @@
from dataclasses import dataclass
from typing import Any
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.state import PipelineState
@dataclass
@@ -17,6 +17,6 @@ class PipelineRunResult:
"""The name of the workflow that was executed."""
result: Any | None
"""The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
config: GraphRagConfig | None
"""Final config after running the workflow, which may have been mutated."""
state: PipelineState
"""Ongoing pipeline context state object."""
errors: list[BaseException] | None

View File

@@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Pipeline state types."""
from typing import Any
PipelineState = dict[Any, Any]
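`PipelineState` is deliberately loose, though since `_dump_json` serializes it with `json.dumps`, values should stay JSON-serializable. For instance (keys shown are hypothetical):

```python
from graphrag.index.typing.state import PipelineState

# keys and values are untyped, but keep them JSON-serializable:
# the pipeline persists this bag into context.json after each workflow
state: PipelineState = {"count": 1, "notes": ["hypothetical", "values"]}
```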

View File

@@ -0,0 +1,23 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Pipeline stats types."""
from dataclasses import dataclass, field
@dataclass
class PipelineRunStats:
"""Pipeline running stats."""
total_runtime: float = field(default=0)
"""Float representing the total runtime."""
num_documents: int = field(default=0)
"""Number of documents."""
input_load_time: float = field(default=0)
"""Float representing the input load time."""
workflows: dict[str, dict[str, float]] = field(default_factory=dict)
"""A dictionary of workflows."""

View File

@@ -1,17 +1,14 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the 'PipelineRunResult' model."""
"""Pipeline workflow types."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
from graphrag.index.typing.context import PipelineRunContext
@dataclass
@@ -20,12 +17,10 @@ class WorkflowFunctionOutput:
result: Any | None
"""The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
config: GraphRagConfig | None
"""If the config is mutated, return the mutated config here. This allows users to design workflows that tune config for downstream workflow use."""
WorkflowFunction = Callable[
[GraphRagConfig, PipelineRunContext, WorkflowCallbacks],
[GraphRagConfig, PipelineRunContext],
Awaitable[WorkflowFunctionOutput],
]
Workflow = tuple[str, WorkflowFunction]
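Custom workflows therefore shrink to a two-argument signature, reaching callbacks, cache, storage, and state through the context. A minimal sketch (the workflow name and state key are illustrative):

```python
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunction, WorkflowFunctionOutput


async def my_workflow(
    _config: GraphRagConfig, context: PipelineRunContext
) -> WorkflowFunctionOutput:
    # callbacks now travel on the context, so the old third parameter is gone
    context.state["my_workflow_ran"] = True
    return WorkflowFunctionOutput(result=None)


# the alias still type-checks against the slimmer two-argument signature
fn: WorkflowFunction = my_workflow
```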

View File

@@ -11,10 +11,10 @@ import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.operations.chunk_text.strategies import get_encoding_fn
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import Progress
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
@@ -23,7 +23,6 @@ from graphrag.utils.storage import load_table_from_storage, write_table_to_stora
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform base text_units."""
documents = await load_table_from_storage("documents", context.storage)
@@ -32,7 +31,7 @@ async def run_workflow(
output = create_base_text_units(
documents,
callbacks,
context.callbacks,
chunks.group_by_columns,
chunks.size,
chunks.overlap,
@@ -44,7 +43,7 @@ async def run_workflow(
await write_table_to_storage(output, "text_units", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
def create_base_text_units(

View File

@@ -10,20 +10,18 @@ from uuid import uuid4
import numpy as np
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform final communities."""
entities = await load_table_from_storage("entities", context.storage)
@@ -43,7 +41,7 @@ async def run_workflow(
await write_table_to_storage(output, "communities", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
def create_communities(

View File

@@ -11,7 +11,6 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@@ -25,7 +24,8 @@ from graphrag.index.operations.summarize_communities.graph_context.context_build
from graphrag.index.operations.summarize_communities.summarize_communities import (
summarize_communities,
)
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
@@ -36,7 +36,6 @@ from graphrag.utils.storage import (
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
edges = await load_table_from_storage("relationships", context.storage)
@@ -62,7 +61,7 @@ async def run_workflow(
entities=entities,
communities=communities,
claims_input=claims,
callbacks=callbacks,
callbacks=context.callbacks,
cache=context.cache,
summarization_strategy=summarization_strategy,
async_mode=async_mode,
@@ -71,7 +70,7 @@ async def run_workflow(
await write_table_to_storage(output, "community_reports", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
async def create_community_reports(

View File

@@ -12,7 +12,6 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@@ -26,7 +25,8 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
build_level_context,
build_local_context,
)
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__)
@@ -35,7 +35,6 @@ log = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
entities = await load_table_from_storage("entities", context.storage)
@@ -56,7 +55,7 @@ async def run_workflow(
entities,
communities,
text_units,
callbacks,
context.callbacks,
context.cache,
summarization_strategy,
async_mode=async_mode,
@@ -65,7 +64,7 @@ async def run_workflow(
await write_table_to_storage(output, "community_reports", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
async def create_community_reports_text(

View File

@@ -5,18 +5,16 @@
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
_config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform final documents."""
documents = await load_table_from_storage("documents", context.storage)
@@ -26,7 +24,7 @@ async def run_workflow(
await write_table_to_storage(output, "documents", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
def create_final_documents(

View File

@@ -5,11 +5,10 @@
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
@@ -20,7 +19,6 @@ from graphrag.utils.storage import (
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform the text units."""
text_units = await load_table_from_storage("text_units", context.storage)
@@ -43,7 +41,7 @@ async def run_workflow(
await write_table_to_storage(output, "text_units", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
def create_final_text_units(

View File

@@ -13,18 +13,17 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.extract_covariates.extract_covariates import (
extract_covariates as extractor,
)
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to extract and format covariates."""
text_units = await load_table_from_storage("text_units", context.storage)
@@ -41,7 +40,7 @@ async def run_workflow(
output = await extract_covariates(
text_units,
callbacks,
context.callbacks,
context.cache,
"claim",
extraction_strategy,
@@ -52,7 +51,7 @@ async def run_workflow(
await write_table_to_storage(output, "covariates", context.storage)
return WorkflowFunctionOutput(result=output, config=None)
return WorkflowFunctionOutput(result=output)
async def extract_covariates(

View File

@@ -11,21 +11,20 @@ from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.extract_graph.extract_graph import (
extract_graph as extractor,
)
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
text_units = await load_table_from_storage("text_units", context.storage)
@@ -46,7 +45,7 @@ async def run_workflow(
entities, relationships = await extract_graph(
text_units=text_units,
callbacks=callbacks,
callbacks=context.callbacks,
cache=context.cache,
extraction_strategy=extraction_strategy,
extraction_num_threads=extract_graph_llm_settings.concurrent_requests,
@@ -63,8 +62,7 @@ async def run_workflow(
result={
"entities": entities,
"relationships": relationships,
},
config=None,
}
)

View File

@@ -6,22 +6,20 @@
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph
from graphrag.index.operations.build_noun_graph.np_extractors.factory import (
create_noun_phrase_extractor,
)
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
text_units = await load_table_from_storage("text_units", context.storage)
@@ -39,8 +37,7 @@ async def run_workflow(
result={
"entities": entities,
"relationships": relationships,
},
config=None,
}
)

View File

@@ -7,8 +7,8 @@ from typing import ClassVar
from graphrag.config.enums import IndexingMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.pipeline import Pipeline
from graphrag.index.typing import WorkflowFunction
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.workflow import WorkflowFunction
class PipelineFactory:

View File

@@ -8,19 +8,18 @@ import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.finalize_entities import finalize_entities
from graphrag.index.operations.finalize_relationships import finalize_relationships
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
entities = await load_table_from_storage("entities", context.storage)
@@ -29,7 +28,7 @@ async def run_workflow(
final_entities, final_relationships = finalize_graph(
entities,
relationships,
callbacks,
callbacks=context.callbacks,
embed_config=config.embed_graph,
layout_enabled=config.umap.enabled,
)
@@ -50,8 +49,7 @@ async def run_workflow(
result={
"entities": entities,
"relationships": relationships,
},
config=None,
}
)

View File

@@ -22,9 +22,9 @@ from graphrag.config.embeddings import (
text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
log = logging.getLogger(__name__)
@@ -33,7 +33,6 @@ log = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
documents = await load_table_from_storage("documents", context.storage)
@@ -47,27 +46,27 @@ async def run_workflow(
embedded_fields = get_embedded_fields(config)
text_embed = get_embedding_settings(config)
result = await generate_text_embeddings(
output = await generate_text_embeddings(
documents=documents,
relationships=relationships,
text_units=text_units,
entities=entities,
community_reports=community_reports,
callbacks=callbacks,
callbacks=context.callbacks,
cache=context.cache,
text_embed_config=text_embed,
embedded_fields=embedded_fields,
)
if config.snapshots.embeddings:
for name, table in result.items():
for name, table in output.items():
await write_table_to_storage(
table,
f"embeddings.{name}",
context.storage,
)
return WorkflowFunctionOutput(result=result, config=None)
return WorkflowFunctionOutput(result=output)
async def generate_text_embeddings(

View File

@@ -5,21 +5,19 @@
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.graph_to_dataframes import graph_to_dataframes
from graphrag.index.operations.prune_graph import prune_graph as prune_graph_operation
from graphrag.index.typing import WorkflowFunctionOutput
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: WorkflowCallbacks,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
entities = await load_table_from_storage("entities", context.storage)
@@ -38,8 +36,7 @@ async def run_workflow(
result={
"entities": pruned_entities,
"relationships": pruned_relationships,
},
config=None,
}
)

View File

@@ -7,7 +7,7 @@ from typing import Any
from fnllm.events import LLMEvents
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
class FNLLMEvents(LLMEvents):

View File

@@ -18,7 +18,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.language_model_config import (
LanguageModelConfig,
)
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.language_model.providers.fnllm.cache import FNLLMCacheProvider

View File

@@ -6,8 +6,14 @@
from pathlib import Path
from typing import Any
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
@@ -230,3 +236,22 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
if prompt_file.exists():
return prompt_file.read_bytes().decode(encoding="utf-8")
return None
def create_storage_from_config(output: OutputConfig) -> PipelineStorage:
"""Create a storage object from the config."""
storage_config = output.model_dump()
return StorageFactory().create_storage(
storage_type=storage_config["type"],
kwargs=storage_config,
)
def create_cache_from_config(cache: CacheConfig, root_dir: str) -> PipelineCache:
"""Create a cache object from the config."""
cache_config = cache.model_dump()
return CacheFactory().create_cache(
cache_type=cache_config["type"],
root_dir=root_dir,
kwargs=cache_config,
)
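These two helpers absorb the factory boilerplate that was previously inlined at each call site (see the `run_pipeline` and CLI diffs above). A usage sketch, assuming a loaded `GraphRagConfig`:

```python
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.utils.api import create_cache_from_config, create_storage_from_config


def make_io(config: GraphRagConfig):
    # one call per object replaces the model_dump + factory
    # invocation previously repeated at each call site
    storage = create_storage_from_config(config.output)
    cache = create_cache_from_config(config.cache, config.root_dir)
    return storage, cache
```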

View File

@@ -11,7 +11,7 @@
"finalize_graph": {
"row_range": [
1,
100
500
],
"nan_allowed_columns": [
"x",

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_base_text_units import run_workflow
from graphrag.utils.storage import load_table_from_storage
@@ -22,11 +21,7 @@ async def test_create_base_text_units():
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
@@ -46,11 +41,7 @@ async def test_create_base_text_units_metadata():
await update_document_metadata(config.input.metadata, context)
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
compare_outputs(actual, expected)
@@ -70,11 +61,7 @@ async def test_create_base_text_units_metadata_included_in_chunk():
await update_document_metadata(config.input.metadata, context)
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
# only check the columns from the base workflow - our expected table is the final and will have more

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
from graphrag.index.workflows.create_communities import (
@@ -32,7 +31,6 @@ async def test_create_communities():
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage("communities", context.storage)

View File

@@ -2,7 +2,6 @@
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import ModelType
from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
@@ -65,11 +64,7 @@ async def test_create_community_reports():
"graph_prompt": "",
}
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("community_reports", context.storage)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
from graphrag.index.workflows.create_final_documents import (
@@ -27,11 +26,7 @@ async def test_create_final_documents():
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("documents", context.storage)
@@ -54,11 +49,7 @@ async def test_create_final_documents_with_metadata_column():
expected = await load_table_from_storage("documents", context.storage)
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("documents", context.storage)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
from graphrag.index.workflows.create_final_text_units import (
@@ -32,11 +31,7 @@ async def test_create_final_text_units():
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
config.extract_claims.enabled = True
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)

View File

@@ -3,7 +3,6 @@
from pandas.testing import assert_series_equal
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import ModelType
from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
@@ -44,11 +43,7 @@ async def test_extract_covariates():
"claim_description": "description",
}
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
actual = await load_table_from_storage("covariates", context.storage)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import ModelType
from graphrag.index.workflows.extract_graph import (
@@ -60,11 +59,7 @@ async def test_extract_graph():
"llm": summarize_llm_settings,
}
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.extract_graph_nlp import (
run_workflow,
@@ -21,11 +20,7 @@ async def test_extract_graph_nlp():
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.data_model.schemas import (
ENTITIES_FINAL_COLUMNS,
@@ -24,11 +23,7 @@ async def test_finalize_graph():
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)
@@ -54,11 +49,7 @@ async def test_finalize_graph_umap():
config.embed_graph.enabled = True
config.umap.enabled = True
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.embeddings import (
all_embeddings,
@@ -43,11 +42,7 @@ async def test_generate_text_embeddings():
config.embed_text.target = TextEmbeddingTarget.all
config.snapshots.embeddings = True
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
parquet_files = context.storage.keys()

View File

@@ -0,0 +1,54 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Tests for pipeline state passthrough."""
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.workflows.factory import PipelineFactory
from tests.verbs.util import DEFAULT_MODEL_CONFIG
async def run_workflow_1( # noqa: RUF029
_config: GraphRagConfig, context: PipelineRunContext
):
context.state["count"] = 1
return WorkflowFunctionOutput(result=None)
async def run_workflow_2( # noqa: RUF029
_config: GraphRagConfig, context: PipelineRunContext
):
context.state["count"] += 1
return WorkflowFunctionOutput(result=None)
async def test_pipeline_state():
# checks that we can update the arbitrary state block within the pipeline run context
PipelineFactory.register("workflow_1", run_workflow_1)
PipelineFactory.register("workflow_2", run_workflow_2)
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
config.workflows = ["workflow_1", "workflow_2"]
context = create_run_context()
for _, fn in PipelineFactory.create_pipeline(config).run():
await fn(config, context)
assert context.state["count"] == 2
async def test_pipeline_existing_state():
PipelineFactory.register("workflow_2", run_workflow_2)
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
config.workflows = ["workflow_2"]
context = create_run_context(state={"count": 4})
for _, fn in PipelineFactory.create_pipeline(config).run():
await fn(config, context)
assert context.state["count"] == 5

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.index.workflows.prune_graph import (
@@ -25,11 +24,7 @@ async def test_prune_graph():
min_node_freq=4, min_node_degree=0, min_edge_weight_pct=0
)
await run_workflow(
config,
context,
NoopWorkflowCallbacks(),
)
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)

View File

@@ -5,8 +5,8 @@ import pandas as pd
from pandas.testing import assert_series_equal
import graphrag.config.defaults as defs
from graphrag.index.context import PipelineRunContext
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
pd.set_option("display.max_columns", None)
@@ -33,7 +33,7 @@ DEFAULT_MODEL_CONFIG = {
async def create_test_context(storage: list[str] | None = None) -> PipelineRunContext:
"""Create a test context with tables loaded into storage storage."""
context = create_run_context(None, None, None)
context = create_run_context()
# always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input
input = load_test_table("documents")