Refactor config (#1593)

* Refactor config

- Add new ModelConfig to represent LLM settings
    - Combines LLMParameters, ParallelizationParameters, encoding_model, and async_mode
- Add top level models config that is a list of available LLM ModelConfigs
- Remove LLMConfig inheritance and delete LLMConfig
    - Replace the inheritance with a model_id reference to the ModelConfig listed in the top level models config
- Remove all fallbacks and hydration logic from create_graphrag_config
    - This removes the automatic env variable overrides
- Support env variables within config files using Templating
    - This requires "$" to be escaped with extra "$" so ".*\\.txt$" becomes ".*\\.txt$$"
- Update init content to initialize new config file with the ModelConfig structure

* Use dict of ModelConfig instead of list

* Add model validations and unit tests

* Fix ruff checks

* Add semversioner change

* Fix unit tests

* validate root_dir in pydantic model

* Rename ModelConfig to LanguageModelConfig

* Rename ModelConfigMissingError to LanguageModelConfigMissingError

* Add validationg for unexpected API keys

* Allow skipping pydantic validation for testing/mocking purposes.

* Add default lm configs to verb tests

* smoke test

* remove config from flows to fix llm arg mapping

* Fix embedding llm arg mapping

* Remove timestamp from smoke test outputs

* Remove unused "subworkflows" smoke test properties

* Add models to smoke test configs

* Update smoke test output path

* Send logs to logs folder

* Fix output path

* Fix csv test file pattern

* Update placeholder

* Format

* Instantiate default model configs

* Fix unit tests for config defaults

* Fix migration notebook

* Remove create_pipeline_config

* Remove several unused config models

* Remove indexing embedding and input configs

* Move embeddings function to config

* Remove skip_workflows

* Remove skip embeddings in favor of explicit naming

* fix unit test spelling mistake

* self.models[model_id] is already a language model. Remove redundant casting.

* update validation errors to instruct users to rerun graphrag init

* instantiate LanguageModelConfigs with validation

* skip validation in unit tests

* update verb tests to use default model settings instead of skipping validation

* test using llm settings

* cleanup verb tests

* remove unsafe default model config

* remove the ability to skip pydantic validation

* remove None union types when default values are set

* move vector_store from embeddings to top level of config and delete resolve_paths

* update vector store settings

* fix vector store and smoke tests

* fix serializing vector_store settings

* fix vector_store usage

* fix vector_store type

* support cli overrides for loading graphrag config

* rename storage to output

* Add --force flag to init

* Remove run_id and resume, fix Drift config assignment

* Ruff

---------

Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
Derek Worthen
2025-01-21 15:52:06 -08:00
committed by GitHub
parent 47adfe16f0
commit c644338bae
104 changed files with 2251 additions and 3608 deletions

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -45,19 +45,20 @@
"\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.config.resolve_path import resolve_paths\n",
"from graphrag.index.create_pipeline_config import create_pipeline_config\n",
"from graphrag.storage.factory import create_storage\n",
"from graphrag.storage.factory import StorageFactory\n",
"\n",
"# This first block does some config loading, path resolution, and translation that is normally done by the CLI/API when running a full workflow\n",
"config = load_config(Path(PROJECT_DIRECTORY))\n",
"resolve_paths(config)\n",
"pipeline_config = create_pipeline_config(config)\n",
"storage = create_storage(pipeline_config.storage)"
"storage_config = config.storage.model_dump() # type: ignore\n",
"storage = StorageFactory().create_storage(\n",
" storage_type=storage_config[\"type\"], kwargs=storage_config\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@@ -68,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@@ -97,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@@ -108,22 +109,16 @@
"# First we'll go through any parquet files that had model changes and update them\n",
"# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
"\n",
"final_documents = await load_table_from_storage(\n",
" \"create_final_documents.parquet\", storage\n",
")\n",
"final_text_units = await load_table_from_storage(\n",
" \"create_final_text_units.parquet\", storage\n",
")\n",
"final_entities = await load_table_from_storage(\"create_final_entities.parquet\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes.parquet\", storage)\n",
"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
"final_relationships = await load_table_from_storage(\n",
" \"create_final_relationships.parquet\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\n",
" \"create_final_communities.parquet\", storage\n",
" \"create_final_relationships\", storage\n",
")\n",
"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
"final_community_reports = await load_table_from_storage(\n",
" \"create_final_community_reports.parquet\", storage\n",
" \"create_final_community_reports\", storage\n",
")\n",
"\n",
"\n",
@@ -183,44 +178,41 @@
" parent_df, on=\"community\", how=\"left\"\n",
" )\n",
"\n",
"await write_table_to_storage(final_documents, \"create_final_documents.parquet\", storage)\n",
"await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n",
"await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n",
"await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n",
"await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n",
"await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n",
"await write_table_to_storage(\n",
" final_text_units, \"create_final_text_units.parquet\", storage\n",
")\n",
"await write_table_to_storage(final_entities, \"create_final_entities.parquet\", storage)\n",
"await write_table_to_storage(final_nodes, \"create_final_nodes.parquet\", storage)\n",
"await write_table_to_storage(\n",
" final_relationships, \"create_final_relationships.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_communities, \"create_final_communities.parquet\", storage\n",
")\n",
"await write_table_to_storage(\n",
" final_community_reports, \"create_final_community_reports.parquet\", storage\n",
" final_community_reports, \"create_final_community_reports\", storage\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from graphrag.cache.factory import create_cache\n",
"from graphrag.cache.factory import CacheFactory\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
"# We'll construct the context and run this function flow directly to avoid everything else\n",
"\n",
"workflow = next(\n",
" (x for x in pipeline_config.workflows if x.name == \"generate_text_embeddings\"), None\n",
")\n",
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"\n",
"embedded_fields = get_embedded_fields(config)\n",
"text_embed = get_embedding_settings(config)\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"cache_config = config.cache.model_dump() # type: ignore\n",
"cache = CacheFactory().create_cache(\n",
" cache_type=cache_config[\"type\"], # type: ignore\n",
" root_dir=PROJECT_DIRECTORY,\n",
" kwargs=cache_config,\n",
")\n",
"\n",
"await generate_text_embeddings(\n",
" final_documents=None,\n",