Export NLP community reports prompt (#1697)

* Properly export the NLP community reports prompt

* Semver

* Fix verb tests
Author: Nathan Evans
Date: 2025-02-12 10:41:39 -08:00
Committed by: GitHub
Parent: b94290ec2b
Commit: fe461417b5

9 changed files with 32 additions and 13 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Export NLP community reports prompt."
+}

View File

@@ -10,6 +10,9 @@ from graphrag.logger.factory import LoggerFactory, LoggerType
 from graphrag.prompts.index.community_report import (
     COMMUNITY_REPORT_PROMPT,
 )
+from graphrag.prompts.index.community_report_text_units import (
+    COMMUNITY_REPORT_TEXT_PROMPT,
+)
 from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT
 from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
 from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
@@ -72,7 +75,8 @@ def initialize_project_at(path: Path, force: bool) -> None:
"extract_graph": GRAPH_EXTRACTION_PROMPT,
"summarize_descriptions": SUMMARIZE_PROMPT,
"extract_claims": EXTRACT_CLAIMS_PROMPT,
"community_report": COMMUNITY_REPORT_PROMPT,
"community_report_graph": COMMUNITY_REPORT_PROMPT,
"community_report_text": COMMUNITY_REPORT_TEXT_PROMPT,
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
"drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,

View File

@@ -114,7 +114,8 @@ extract_claims:
 community_reports:
   model_id: {defs.COMMUNITY_REPORT_MODEL_ID}
-  prompt: "prompts/community_report.txt"
+  graph_prompt: "prompts/community_report_graph.txt"
+  text_prompt: "prompts/community_report_text.txt"
   max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
   max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
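These template keys map one-to-one onto the renamed fields of CommunityReportsConfig in the next file. A minimal construction sketch, assuming the module path follows the other config models and that the remaining fields (model_id, max_length, max_input_length) keep their defaults:

from graphrag.config.models.community_reports_config import CommunityReportsConfig

reports_config = CommunityReportsConfig(
    graph_prompt="prompts/community_report_graph.txt",
    text_prompt="prompts/community_report_text.txt",
)
# The old single `prompt` field is gone, so existing configs that set it need the rename.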

View File

@@ -14,8 +14,13 @@ from graphrag.config.models.language_model_config import LanguageModelConfig
 class CommunityReportsConfig(BaseModel):
     """Configuration section for community reports."""

-    prompt: str | None = Field(
-        description="The community report extraction prompt to use.", default=None
+    graph_prompt: str | None = Field(
+        description="The community report extraction prompt to use for graph-based summarization.",
+        default=None,
+    )
+    text_prompt: str | None = Field(
+        description="The community report extraction prompt to use for text-based summarization.",
+        default=None,
     )
     max_length: int = Field(
         description="The community report maximum length in tokens.",
@@ -46,10 +51,15 @@ class CommunityReportsConfig(BaseModel):
"llm": model_config.model_dump(),
"stagger": model_config.parallelization_stagger,
"num_threads": model_config.parallelization_num_threads,
"extraction_prompt": (Path(root_dir) / self.prompt).read_text(
"graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
encoding="utf-8"
)
if self.prompt
if self.graph_prompt
else None,
"text_prompt": (Path(root_dir) / self.text_prompt).read_text(
encoding="utf-8"
)
if self.text_prompt
else None,
"max_report_length": self.max_length,
"max_input_length": self.max_input_length,

View File

@@ -46,6 +46,8 @@ async def create_community_reports(
     if claims_input is not None:
         claims = _prep_claims(claims_input)
+    summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
     max_input_length = summarization_strategy.get(
         "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
     )
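Because this workflow now indexes summarization_strategy["graph_prompt"] directly (and the text-unit workflow below does the same with "text_prompt"), a strategy dict supplied by hand has to include those keys; the updated verb test at the end of this commit passes an empty string for graph_prompt. A hedged sketch of such a dict, showing both keys in one place for illustration (in practice each workflow receives its own strategy, and llm_settings is a placeholder):

llm_settings: dict = {}  # placeholder for real model settings

summarization_strategy = {
    "type": "graph_intelligence",
    "llm": llm_settings,
    # Read directly as summarization_strategy["graph_prompt"] / ["text_prompt"];
    # the verb test below sets graph_prompt to "" rather than omitting it.
    "graph_prompt": "",
    "text_prompt": "",
}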

View File

@@ -24,9 +24,6 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
     build_level_context,
     build_local_context,
 )
-from graphrag.prompts.index.community_report_text_units import (
-    COMMUNITY_REPORT_PROMPT,
-)
 log = logging.getLogger(__name__)
@@ -44,8 +41,7 @@ async def create_community_reports_text(
"""All the steps to transform community reports."""
nodes = explode_communities(communities, entities)
# TEMP: forcing override of the prompt until we can put it into config
summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
summarization_strategy["extraction_prompt"] = summarization_strategy["text_prompt"]
max_input_length = summarization_strategy.get(
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH

View File

@@ -3,7 +3,7 @@
"""A file containing prompts definition."""
COMMUNITY_REPORT_PROMPT = """
COMMUNITY_REPORT_TEXT_PROMPT = """
You are an AI assistant that helps a human analyst to perform general information discovery.
Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
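With the rename, the text-unit (NLP) report prompt is exported under its own name instead of reusing the graph prompt's constant name, so it can be imported directly, for example to inspect it or to seed a custom prompt file (module path as shown in the import diff above):

from graphrag.prompts.index.community_report_text_units import COMMUNITY_REPORT_TEXT_PROMPT

# Preview the default text-unit community report prompt.
print(COMMUNITY_REPORT_TEXT_PROMPT[:300])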

View File

@@ -418,7 +418,8 @@ def assert_summarize_descriptions_configs(
 def assert_community_reports_configs(
     actual: CommunityReportsConfig, expected: CommunityReportsConfig
 ) -> None:
-    assert actual.prompt == expected.prompt
+    assert actual.graph_prompt == expected.graph_prompt
+    assert actual.text_prompt == expected.text_prompt
     assert actual.max_length == expected.max_length
     assert actual.max_input_length == expected.max_input_length
     assert actual.strategy == expected.strategy

View File

@@ -61,6 +61,7 @@ async def test_create_community_reports():
     config.community_reports.strategy = {
         "type": "graph_intelligence",
         "llm": llm_settings,
+        "graph_prompt": "",
     }
     await run_workflow(