Mirror of https://github.com/microsoft/graphrag.git
Export NLP community reports prompt (#1697)
* Properly export the NLP community reports prompt
* Semver
* Fix verb tests
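With this change both community report prompts are module-level constants that `graphrag init` exports into the project's prompts folder. A minimal sketch of importing them directly (the import paths match the diff below; the print is only a smoke check):

# Both constants are importable after this change.
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
from graphrag.prompts.index.community_report_text_units import COMMUNITY_REPORT_TEXT_PROMPT

print(len(COMMUNITY_REPORT_PROMPT), len(COMMUNITY_REPORT_TEXT_PROMPT))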
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Export NLP community reports prompt."
+}
@@ -10,6 +10,9 @@ from graphrag.logger.factory import LoggerFactory, LoggerType
 from graphrag.prompts.index.community_report import (
     COMMUNITY_REPORT_PROMPT,
 )
+from graphrag.prompts.index.community_report_text_units import (
+    COMMUNITY_REPORT_TEXT_PROMPT,
+)
 from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT
 from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
 from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
@@ -72,7 +75,8 @@ def initialize_project_at(path: Path, force: bool) -> None:
         "extract_graph": GRAPH_EXTRACTION_PROMPT,
         "summarize_descriptions": SUMMARIZE_PROMPT,
         "extract_claims": EXTRACT_CLAIMS_PROMPT,
-        "community_report": COMMUNITY_REPORT_PROMPT,
+        "community_report_graph": COMMUNITY_REPORT_PROMPT,
+        "community_report_text": COMMUNITY_REPORT_TEXT_PROMPT,
         "drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
         "drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
         "global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
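The dict above only maps prompt names to their default contents; the loop that writes them to disk is not part of this diff. A hedged sketch of how such a mapping is presumably materialized under prompts/ (helper name and signature are hypothetical):

from pathlib import Path

def write_default_prompts(prompts_dir: Path, prompts: dict[str, str]) -> None:
    # Hypothetical helper: one .txt file per prompt name, e.g.
    # prompts/community_report_graph.txt and prompts/community_report_text.txt.
    prompts_dir.mkdir(parents=True, exist_ok=True)
    for name, content in prompts.items():
        (prompts_dir / f"{name}.txt").write_text(content, encoding="utf-8")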
@@ -114,7 +114,8 @@ extract_claims:
 
 community_reports:
   model_id: {defs.COMMUNITY_REPORT_MODEL_ID}
-  prompt: "prompts/community_report.txt"
+  graph_prompt: "prompts/community_report_graph.txt"
+  text_prompt: "prompts/community_report_text.txt"
   max_length: {defs.COMMUNITY_REPORT_MAX_LENGTH}
   max_input_length: {defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH}
 
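The template now points community reports at two prompt files relative to the project root. A minimal sketch of reading those settings back (assumes PyYAML and a hypothetical ./ragtest root created by `graphrag init`):

from pathlib import Path
import yaml

root = Path("./ragtest")  # hypothetical project root
settings = yaml.safe_load((root / "settings.yaml").read_text(encoding="utf-8"))
reports = settings["community_reports"]

# Resolve the two prompt paths written by the template above.
graph_prompt = (root / reports["graph_prompt"]).read_text(encoding="utf-8")
text_prompt = (root / reports["text_prompt"]).read_text(encoding="utf-8")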
@@ -14,8 +14,13 @@ from graphrag.config.models.language_model_config import LanguageModelConfig
 class CommunityReportsConfig(BaseModel):
     """Configuration section for community reports."""
 
-    prompt: str | None = Field(
-        description="The community report extraction prompt to use.", default=None
-    )
+    graph_prompt: str | None = Field(
+        description="The community report extraction prompt to use for graph-based summarization.",
+        default=None,
+    )
+    text_prompt: str | None = Field(
+        description="The community report extraction prompt to use for text-based summarization.",
+        default=None,
+    )
     max_length: int = Field(
         description="The community report maximum length in tokens.",
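A minimal sketch of constructing the updated model (field names come from the diff; the module path and the remaining defaults are assumed to be unchanged):

from graphrag.config.models.community_reports_config import CommunityReportsConfig

cfg = CommunityReportsConfig(
    graph_prompt="prompts/community_report_graph.txt",
    text_prompt="prompts/community_report_text.txt",
)
# Both fields are optional; leaving one unset presumably falls back to the packaged prompt.
print(cfg.graph_prompt, cfg.text_prompt)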
@@ -46,10 +51,15 @@
             "llm": model_config.model_dump(),
             "stagger": model_config.parallelization_stagger,
             "num_threads": model_config.parallelization_num_threads,
-            "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
+            "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
                 encoding="utf-8"
             )
-            if self.prompt
+            if self.graph_prompt
            else None,
+            "text_prompt": (Path(root_dir) / self.text_prompt).read_text(
+                encoding="utf-8"
+            )
+            if self.text_prompt
+            else None,
             "max_report_length": self.max_length,
             "max_input_length": self.max_input_length,
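Both prompts resolve with the same optional-path pattern shown above. A stand-alone illustration of that pattern (mirrors the expressions in the diff; not a call into graphrag itself):

from pathlib import Path

def resolve_prompt(root_dir: str, prompt_path: str | None) -> str | None:
    # Read the prompt file relative to the project root, or return None when unset.
    return (
        (Path(root_dir) / prompt_path).read_text(encoding="utf-8")
        if prompt_path
        else None
    )

print(resolve_prompt("./ragtest", None))  # None -> downstream code keeps its default prompt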
@@ -46,6 +46,8 @@ async def create_community_reports(
     if claims_input is not None:
         claims = _prep_claims(claims_input)
 
+    summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
+
     max_input_length = summarization_strategy.get(
         "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
     )
@@ -24,9 +24,6 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
     build_level_context,
     build_local_context,
 )
-from graphrag.prompts.index.community_report_text_units import (
-    COMMUNITY_REPORT_PROMPT,
-)
 
 log = logging.getLogger(__name__)
 
@@ -44,8 +41,7 @@ async def create_community_reports_text(
     """All the steps to transform community reports."""
     nodes = explode_communities(communities, entities)
 
-    # TEMP: forcing override of the prompt until we can put it into config
-    summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
+    summarization_strategy["extraction_prompt"] = summarization_strategy["text_prompt"]
 
     max_input_length = summarization_strategy.get(
         "max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
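Net effect of the two workflow edits: the graph workflow copies graph_prompt and the text-unit workflow copies text_prompt into the extraction_prompt slot that the summarizer still reads. Illustrative only (hypothetical helper, not in the diff):

def select_prompt(summarization_strategy: dict, use_text_units: bool) -> dict:
    # Same routing the two workflows now perform inline.
    key = "text_prompt" if use_text_units else "graph_prompt"
    summarization_strategy["extraction_prompt"] = summarization_strategy[key]
    return summarization_strategy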
@@ -3,7 +3,7 @@
 
 """A file containing prompts definition."""
 
-COMMUNITY_REPORT_PROMPT = """
+COMMUNITY_REPORT_TEXT_PROMPT = """
 You are an AI assistant that helps a human analyst to perform general information discovery.
 Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
 
@@ -418,7 +418,8 @@ def assert_summarize_descriptions_configs(
 def assert_community_reports_configs(
     actual: CommunityReportsConfig, expected: CommunityReportsConfig
 ) -> None:
-    assert actual.prompt == expected.prompt
+    assert actual.graph_prompt == expected.graph_prompt
+    assert actual.text_prompt == expected.text_prompt
     assert actual.max_length == expected.max_length
     assert actual.max_input_length == expected.max_input_length
     assert actual.strategy == expected.strategy
@@ -61,6 +61,7 @@ async def test_create_community_reports():
     config.community_reports.strategy = {
         "type": "graph_intelligence",
         "llm": llm_settings,
+        "graph_prompt": "",
     }
 
     await run_workflow(
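A text-unit counterpart of this test fixture would presumably override the other key the same way; hypothetical sketch (not part of this diff, and config / llm_settings come from the surrounding test):

config.community_reports.strategy = {
    "type": "graph_intelligence",
    "llm": llm_settings,
    "text_prompt": "",
}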