graphrag-microsoft/graphrag/api/query.py

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Query Engine API.
This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.
Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results back.
- multi_index_global_search: Perform a global search across multiple indexes.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results back.
- multi_index_local_search: Perform a local search across multiple indexes.
- drift_search: Perform a DRIFT search.
- drift_search_streaming: Perform a DRIFT search and stream results back.
- multi_index_drift_search: Perform a DRIFT search across multiple indexes.
- basic_search: Perform a basic search.
- basic_search_streaming: Perform a basic search and stream results back.
- multi_index_basic_search: Perform a basic search across multiple indexes.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
from collections.abc import AsyncGenerator
from typing import Any
import pandas as pd
from pydantic import validate_call
from graphrag.config.embeddings import (
community_full_content_embedding,
entity_description_embedding,
text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
get_basic_search_engine,
get_drift_search_engine,
get_global_search_engine,
get_local_search_engine,
)
from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.utils.api import (
get_embedding_store,
load_search_prompt,
reformat_context_data,
update_context_data,
)
from graphrag.utils.cli import redact
logger = PrintProgressLogger("")
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a global search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- community_level (int | None): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
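Examples
--------
A minimal usage sketch. The project root, output paths, community_level, and
response_type values below are illustrative; the parquet files come from a prior
indexing run, and load_config is shown here as one way to obtain a GraphRagConfig
from a project's settings.yaml.

>>> import asyncio
>>> from pathlib import Path
>>> import pandas as pd
>>> from graphrag.config.load_config import load_config
>>> root = Path("./ragtest")  # illustrative project root containing settings.yaml
>>> config = load_config(root)
>>> out = root / "output"  # illustrative output directory from a prior indexing run
>>> response, context = asyncio.run(
...     global_search(
...         config=config,
...         entities=pd.read_parquet(out / "entities.parquet"),
...         communities=pd.read_parquet(out / "communities.parquet"),
...         community_reports=pd.read_parquet(out / "community_reports.parquet"),
...         community_level=2,
...         dynamic_community_selection=False,
...         response_type="Multiple Paragraphs",
...         query="What are the major themes in this dataset?",
...     )
... )
>>> print(response)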
"""
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in global_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
async def global_search_streaming(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a global search and return the context data and response via a generator.
Context data is returned as a dictionary of lists, with one list entry for each record.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- community_level (int | None): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
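Examples
--------
A sketch of consuming the stream from an async context, with inputs prepared as in
the global_search example. The first yielded item is the context-data dict; every
later item is a chunk of the response text. The wrapper function shown here is a
hypothetical helper for illustration only.

>>> async def run_streaming_query(config, entities, communities, community_reports, query):
...     context = None
...     answer = ""
...     async for chunk in global_search_streaming(
...         config=config,
...         entities=entities,
...         communities=communities,
...         community_reports=community_reports,
...         community_level=2,
...         dynamic_community_selection=False,
...         response_type="Multiple Paragraphs",
...         query=query,
...     ):
...         if context is None:
...             context = chunk  # first chunk: the complete context data
...         else:
...             answer += chunk  # subsequent chunks: response tokens
...     return answer, context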
"""
communities_ = read_indexer_communities(communities, community_reports)
reports = read_indexer_reports(
community_reports,
communities,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
entities_ = read_indexer_entities(
entities, communities, community_level=community_level
)
map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
)
knowledge_prompt = load_search_prompt(
config.root_dir, config.global_search.knowledge_prompt
)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=entities_,
communities=communities_,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt,
)
search_result = search_engine.stream_search(query=query)
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_global_search(
config: GraphRagConfig,
entities_list: list[pd.DataFrame],
communities_list: list[pd.DataFrame],
community_reports_list: list[pd.DataFrame],
index_names: list[str],
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
streaming: bool,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a global search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
- index_names (list[str]): A list of index names.
- community_level (int | None): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
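Examples
--------
A sketch for two separately built indexes. Directory and index names are
illustrative, and config is assumed to be a loaded GraphRagConfig (see the
global_search example).

>>> import asyncio
>>> import pandas as pd
>>> roots = ["./index_a/output", "./index_b/output"]  # illustrative index output dirs
>>> response, context = asyncio.run(
...     multi_index_global_search(
...         config=config,
...         entities_list=[pd.read_parquet(f"{r}/entities.parquet") for r in roots],
...         communities_list=[pd.read_parquet(f"{r}/communities.parquet") for r in roots],
...         community_reports_list=[
...             pd.read_parquet(f"{r}/community_reports.parquet") for r in roots
...         ],
...         index_names=["index_a", "index_b"],
...         community_level=2,
...         dynamic_community_selection=False,
...         response_type="Multiple Paragraphs",
...         streaming=False,
...         query="Compare the main topics covered by the two datasets.",
...     )
... )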
"""
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_global_search"
raise NotImplementedError(message)
links = {
"communities": {},
"community_reports": {},
"entities": {},
}
max_vals = {
"communities": -1,
"community_reports": -1,
"entities": -1,
}
communities_dfs = []
community_reports_dfs = []
entities_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's community reports dataframe for merging
community_reports_df = community_reports_list[idx]
community_reports_df["community"] = community_reports_df["community"].astype(
int
)
for i in community_reports_df["community"]:
links["community_reports"][i + max_vals["community_reports"] + 1] = {
"index_name": index_name,
"id": str(i),
}
community_reports_df["community"] += max_vals["community_reports"] + 1
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
max_vals["community_reports"] = int(community_reports_df["community"].max())
community_reports_dfs.append(community_reports_df)
# Prepare each index's communities dataframe for merging
communities_df = communities_list[idx]
communities_df["community"] = communities_df["community"].astype(int)
communities_df["parent"] = communities_df["parent"].astype(int)
for i in communities_df["community"]:
links["communities"][i + max_vals["communities"] + 1] = {
"index_name": index_name,
"id": str(i),
}
communities_df["community"] += max_vals["communities"] + 1
communities_df["parent"] = communities_df["parent"].apply(
lambda x: x if x == -1 else x + max_vals["communities"] + 1
)
communities_df["human_readable_id"] += max_vals["communities"] + 1
# concat the index name to the entity_ids, since this is used for joining later
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["communities"] = int(communities_df["community"].max())
communities_dfs.append(communities_df)
# Prepare each index's entities dataframe for merging
entities_df = entities_list[idx]
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": index_name,
"id": i,
}
entities_df["human_readable_id"] += max_vals["entities"] + 1
entities_df["title"] = entities_df["title"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["entities"] = int(entities_df["human_readable_id"].max())
entities_dfs.append(entities_df)
# Merge the dataframes
community_reports_combined = pd.concat(
community_reports_dfs, axis=0, ignore_index=True, sort=False
)
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
communities_combined = pd.concat(
communities_dfs, axis=0, ignore_index=True, sort=False
)
result = await global_search(
config,
entities=entities_combined,
communities=communities_combined,
community_reports=community_reports_combined,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
)
# Update the context data by linking index names and community ids
context = update_context_data(result[1], links)
return (result[0], context)
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
covariates: pd.DataFrame | None,
community_level: int,
response_type: str,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a local search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- covariates (pd.DataFrame | None): An optional DataFrame containing the final covariates (from covariates.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
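Examples
--------
A minimal sketch. Local search also reads the entity-description embeddings from
the vector store named in config, so that store must already be populated by
indexing. covariates may be None when claim extraction was not run. Paths and
parameter values are illustrative, and config is assumed to be loaded as in the
global_search example.

>>> import asyncio
>>> import pandas as pd
>>> out = "./ragtest/output"  # illustrative output directory
>>> response, context = asyncio.run(
...     local_search(
...         config=config,
...         entities=pd.read_parquet(f"{out}/entities.parquet"),
...         communities=pd.read_parquet(f"{out}/communities.parquet"),
...         community_reports=pd.read_parquet(f"{out}/community_reports.parquet"),
...         text_units=pd.read_parquet(f"{out}/text_units.parquet"),
...         relationships=pd.read_parquet(f"{out}/relationships.parquet"),
...         covariates=None,
...         community_level=2,
...         response_type="Multiple Paragraphs",
...         query="Which entities are most central and how are they connected?",
...     )
... )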
"""
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in local_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
covariates=covariates,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
covariates: pd.DataFrame | None,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a local search and return the context data and response via a generator.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- covariates (pd.DataFrame | None): An optional DataFrame containing the final covariates (from covariates.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
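Examples
--------
A sketch that prints tokens as they arrive, with the first yielded item captured
separately as the context data. The helper function and the frames dict of input
DataFrames are hypothetical; inputs are loaded as in the local_search example.

>>> async def stream_to_stdout(config, frames, query):
...     context = None
...     async for chunk in local_search_streaming(
...         config=config,
...         entities=frames["entities"],
...         communities=frames["communities"],
...         community_reports=frames["community_reports"],
...         text_units=frames["text_units"],
...         relationships=frames["relationships"],
...         covariates=None,
...         community_level=2,
...         response_type="Multiple Paragraphs",
...         query=query,
...     ):
...         if context is None:
...             context = chunk  # first chunk: context data
...         else:
...             print(chunk, end="", flush=True)  # later chunks: response tokens
...     return context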
"""
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
search_engine = get_local_search_engine(
config=config,
reports=read_indexer_reports(community_reports, communities, community_level),
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
covariates={"claims": covariates_},
description_embedding_store=description_embedding_store, # type: ignore
response_type=response_type,
system_prompt=prompt,
)
search_result = search_engine.stream_search(query=query)
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_local_search(
config: GraphRagConfig,
entities_list: list[pd.DataFrame],
communities_list: list[pd.DataFrame],
community_reports_list: list[pd.DataFrame],
text_units_list: list[pd.DataFrame],
relationships_list: list[pd.DataFrame],
covariates_list: list[pd.DataFrame] | None,
index_names: list[str],
community_level: int,
response_type: str,
streaming: bool,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a local search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
- relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
- covariates_list (list[pd.DataFrame]): [Optional] A list of DataFrames containing the final covariates (from covariates.parquet)
- index_names (list[str]): A list of index names.
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
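Examples
--------
A sketch for two indexes without covariates (covariates_list=None). The per-index
DataFrame lists are built the same way as in the multi_index_global_search example,
and all values shown are illustrative.

>>> import asyncio
>>> response, context = asyncio.run(
...     multi_index_local_search(
...         config=config,
...         entities_list=entities_list,
...         communities_list=communities_list,
...         community_reports_list=community_reports_list,
...         text_units_list=text_units_list,
...         relationships_list=relationships_list,
...         covariates_list=None,
...         index_names=["index_a", "index_b"],
...         community_level=2,
...         response_type="Multiple Paragraphs",
...         streaming=False,
...         query="Which entities appear in both indexes?",
...     )
... )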
"""
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_index_local_search"
raise NotImplementedError(message)
links = {
"community_reports": {},
"communities": {},
"entities": {},
"text_units": {},
"relationships": {},
"covariates": {},
}
max_vals = {
"community_reports": -1,
"communities": -1,
"entities": -1,
"text_units": 0,
"relationships": -1,
"covariates": 0,
}
community_reports_dfs = []
communities_dfs = []
entities_dfs = []
relationships_dfs = []
text_units_dfs = []
covariates_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's communities dataframe for merging
communities_df = communities_list[idx]
communities_df["community"] = communities_df["community"].astype(int)
for i in communities_df["community"]:
links["communities"][i + max_vals["communities"] + 1] = {
"index_name": index_name,
"id": str(i),
}
communities_df["community"] += max_vals["communities"] + 1
communities_df["human_readable_id"] += max_vals["communities"] + 1
# concat the index name to the entity_ids, since this is used for joining later
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["communities"] = int(communities_df["community"].max())
communities_dfs.append(communities_df)
# Prepare each index's community reports dataframe for merging
community_reports_df = community_reports_list[idx]
community_reports_df["community"] = community_reports_df["community"].astype(
int
)
for i in community_reports_df["community"]:
links["community_reports"][i + max_vals["community_reports"] + 1] = {
"index_name": index_name,
"id": str(i),
}
community_reports_df["community"] += max_vals["community_reports"] + 1
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
max_vals["community_reports"] = int(community_reports_df["community"].max())
community_reports_dfs.append(community_reports_df)
# Prepare each index's entities dataframe for merging
entities_df = entities_list[idx]
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": index_name,
"id": i,
}
entities_df["human_readable_id"] += max_vals["entities"] + 1
entities_df["title"] = entities_df["title"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["id"] = entities_df["id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["entities"] = int(entities_df["human_readable_id"].max())
entities_dfs.append(entities_df)
# Prepare each index's relationships dataframe for merging
relationships_df = relationships_list[idx]
for i in relationships_df["human_readable_id"].astype(int):
links["relationships"][i + max_vals["relationships"] + 1] = {
"index_name": index_name,
"id": i,
}
if max_vals["relationships"] != -1:
col = (
relationships_df["human_readable_id"].astype(int)
+ max_vals["relationships"]
+ 1
)
relationships_df["human_readable_id"] = col.astype(str)
relationships_df["source"] = relationships_df["source"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["target"] = relationships_df["target"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["relationships"] = int(relationships_df["human_readable_id"].max())
relationships_dfs.append(relationships_df)
# Prepare each index's text units dataframe for merging
text_units_df = text_units_list[idx]
for i in range(text_units_df.shape[0]):
links["text_units"][i + max_vals["text_units"]] = {
"index_name": index_name,
"id": i,
}
text_units_df["id"] = text_units_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
text_units_df["human_readable_id"] = (
text_units_df["human_readable_id"] + max_vals["text_units"]
)
max_vals["text_units"] += text_units_df.shape[0]
text_units_dfs.append(text_units_df)
# If present, prepare each index's covariates dataframe for merging
if covariates_list is not None:
covariates_df = covariates_list[idx]
for i in covariates_df["human_readable_id"].astype(int):
links["covariates"][i + max_vals["covariates"]] = {
"index_name": index_name,
"id": i,
}
covariates_df["id"] = covariates_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
covariates_df["human_readable_id"] = (
covariates_df["human_readable_id"] + max_vals["covariates"]
)
covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
covariates_df["subject_id"] = covariates_df["subject_id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
max_vals["covariates"] += covariates_df.shape[0]
covariates_dfs.append(covariates_df)
# Merge the dataframes
communities_combined = pd.concat(
communities_dfs, axis=0, ignore_index=True, sort=False
)
community_reports_combined = pd.concat(
community_reports_dfs, axis=0, ignore_index=True, sort=False
)
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
relationships_combined = pd.concat(
relationships_dfs, axis=0, ignore_index=True, sort=False
)
text_units_combined = pd.concat(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
covariates_combined = None
if len(covariates_dfs) > 0:
covariates_combined = pd.concat(
covariates_dfs, axis=0, ignore_index=True, sort=False
)
result = await local_search(
config,
entities=entities_combined,
communities=communities_combined,
community_reports=community_reports_combined,
text_units=text_units_combined,
relationships=relationships_combined,
covariates=covariates_combined,
community_level=community_level,
response_type=response_type,
query=query,
)
# Update the context data by linking index names and community ids
context = update_context_data(result[1], links)
return (result[0], context)
@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
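Examples
--------
A minimal sketch. DRIFT search reads both the entity-description embeddings and the
community full-content report embeddings from the configured vector store, so both
must exist from indexing. Paths and parameter values are illustrative, and config
is assumed to be loaded as in the global_search example.

>>> import asyncio
>>> import pandas as pd
>>> out = "./ragtest/output"  # illustrative output directory
>>> response, context = asyncio.run(
...     drift_search(
...         config=config,
...         entities=pd.read_parquet(f"{out}/entities.parquet"),
...         communities=pd.read_parquet(f"{out}/communities.parquet"),
...         community_reports=pd.read_parquet(f"{out}/community_reports.parquet"),
...         text_units=pd.read_parquet(f"{out}/text_units.parquet"),
...         relationships=pd.read_parquet(f"{out}/relationships.parquet"),
...         community_level=2,
...         response_type="Multiple Paragraphs",
...         query="What are the key relationships among the main entities?",
...     )
... )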
"""
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in drift_search_streaming(
config=config,
entities=entities,
communities=communities,
community_reports=community_reports,
text_units=text_units,
relationships=relationships,
community_level=community_level,
response_type=response_type,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
config: GraphRagConfig,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
text_units: pd.DataFrame,
relationships: pd.DataFrame,
community_level: int,
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a DRIFT search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=entity_description_embedding,
)
full_content_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=community_full_content_embedding,
)
entities_ = read_indexer_entities(entities, communities, community_level)
reports = read_indexer_reports(community_reports, communities, community_level)
read_indexer_report_embeddings(reports, full_content_embedding_store)
prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
reduce_prompt = load_search_prompt(
config.root_dir, config.drift_search.reduce_prompt
)
search_engine = get_drift_search_engine(
config=config,
reports=reports,
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
)
search_result = search_engine.stream_search(query=query)
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_drift_search(
config: GraphRagConfig,
entities_list: list[pd.DataFrame],
communities_list: list[pd.DataFrame],
community_reports_list: list[pd.DataFrame],
text_units_list: list[pd.DataFrame],
relationships_list: list[pd.DataFrame],
index_names: list[str],
community_level: int,
response_type: str,
streaming: bool,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a DRIFT search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
- communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
- community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
- relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
- index_names (list[str]): A list of index names.
- community_level (int): The community level to search at.
- response_type (str): The response type to return.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_drift_search"
raise NotImplementedError(message)
links = {
"community_reports": {},
"communities": {},
"entities": {},
"text_units": {},
"relationships": {},
}
max_vals = {
"community_reports": -1,
"communities": -1,
"entities": -1,
"text_units": 0,
"relationships": -1,
}
communities_dfs = []
community_reports_dfs = []
entities_dfs = []
relationships_dfs = []
text_units_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's communities dataframe for merging
communities_df = communities_list[idx]
communities_df["community"] = communities_df["community"].astype(int)
for i in communities_df["community"]:
links["communities"][i + max_vals["communities"] + 1] = {
"index_name": index_name,
"id": str(i),
}
communities_df["community"] += max_vals["communities"] + 1
communities_df["human_readable_id"] += max_vals["communities"] + 1
# concat the index name to the entity_ids, since this is used for joining later
communities_df["entity_ids"] = communities_df["entity_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["communities"] = int(communities_df["community"].max())
communities_dfs.append(communities_df)
# Prepare each index's community reports dataframe for merging
community_reports_df = community_reports_list[idx]
community_reports_df["community"] = community_reports_df["community"].astype(
int
)
for i in community_reports_df["community"]:
links["community_reports"][i + max_vals["community_reports"] + 1] = {
"index_name": index_name,
"id": str(i),
}
community_reports_df["community"] += max_vals["community_reports"] + 1
community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
community_reports_df["id"] = community_reports_df["id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
max_vals["community_reports"] = int(community_reports_df["community"].max())
community_reports_dfs.append(community_reports_df)
# Prepare each index's entities dataframe for merging
entities_df = entities_list[idx]
for i in entities_df["human_readable_id"]:
links["entities"][i + max_vals["entities"] + 1] = {
"index_name": index_name,
"id": i,
}
entities_df["human_readable_id"] += max_vals["entities"] + 1
entities_df["title"] = entities_df["title"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["id"] = entities_df["id"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["entities"] = int(entities_df["human_readable_id"].max())
entities_dfs.append(entities_df)
# Prepare each index's relationships dataframe for merging
relationships_df = relationships_list[idx]
for i in relationships_df["human_readable_id"].astype(int):
links["relationships"][i + max_vals["relationships"] + 1] = {
"index_name": index_name,
"id": i,
}
if max_vals["relationships"] != -1:
col = (
relationships_df["human_readable_id"].astype(int)
+ max_vals["relationships"]
+ 1
)
relationships_df["human_readable_id"] = col.astype(str)
relationships_df["source"] = relationships_df["source"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["target"] = relationships_df["target"].apply(
lambda x, index_name=index_name: x + f"-{index_name}"
)
relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
)
max_vals["relationships"] = int(
relationships_df["human_readable_id"].astype(int).max()
)
relationships_dfs.append(relationships_df)
# Prepare each index's text units dataframe for merging
text_units_df = text_units_list[idx]
for i in range(text_units_df.shape[0]):
links["text_units"][i + max_vals["text_units"]] = {
"index_name": index_name,
"id": i,
}
text_units_df["id"] = text_units_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
text_units_df["human_readable_id"] = (
text_units_df["human_readable_id"] + max_vals["text_units"]
)
max_vals["text_units"] += text_units_df.shape[0]
text_units_dfs.append(text_units_df)
# Merge the dataframes
communities_combined = pd.concat(
communities_dfs, axis=0, ignore_index=True, sort=False
)
community_reports_combined = pd.concat(
community_reports_dfs, axis=0, ignore_index=True, sort=False
)
entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
relationships_combined = pd.concat(
relationships_dfs, axis=0, ignore_index=True, sort=False
)
text_units_combined = pd.concat(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
result = await drift_search(
config,
entities=entities_combined,
communities=communities_combined,
community_reports=community_reports_combined,
text_units=text_units_combined,
relationships=relationships_combined,
community_level=community_level,
response_type=response_type,
query=query,
)
# Update the context data by linking index names and community ids
context = {}
if type(result[1]) is dict:
for key in result[1]:
context[key] = update_context_data(result[1][key], links)
else:
context = result[1]
return (result[0], context)
@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search(
config: GraphRagConfig,
text_units: pd.DataFrame,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a basic search and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
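Examples
--------
A minimal sketch. Basic search only needs the text units plus their text embeddings
in the configured vector store. The path is illustrative, and config is assumed to
be loaded as in the global_search example.

>>> import asyncio
>>> import pandas as pd
>>> text_units = pd.read_parquet("./ragtest/output/text_units.parquet")
>>> response, context = asyncio.run(
...     basic_search(
...         config=config,
...         text_units=text_units,
...         query="Summarize the main topics in the source documents.",
...     )
... )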
"""
full_response = ""
context_data = {}
get_context_data = True
# NOTE: when streaming, the first chunk of returned data is the complete context data.
# All subsequent chunks are the query response.
async for chunk in basic_search_streaming(
config=config,
text_units=text_units,
query=query,
):
if get_context_data:
context_data = chunk
get_context_data = False
else:
full_response += chunk
return full_response, context_data
@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search_streaming(
config: GraphRagConfig,
text_units: pd.DataFrame,
query: str,
) -> AsyncGenerator:
"""Perform a local search and return the context data and response via a generator.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
"""
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
embedding_name=text_unit_text_embedding,
)
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
search_engine = get_basic_search_engine(
config=config,
text_units=read_indexer_text_units(text_units),
text_unit_embeddings=description_embedding_store,
system_prompt=prompt,
)
search_result = search_engine.stream_search(query=query)
# NOTE: when streaming results, a context data object is returned as the first result
# and the query response in subsequent tokens
context_data = {}
get_context_data = True
async for stream_chunk in search_result:
if get_context_data:
context_data = reformat_context_data(stream_chunk) # type: ignore
yield context_data
get_context_data = False
else:
yield stream_chunk
@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_basic_search(
config: GraphRagConfig,
text_units_list: list[pd.DataFrame],
index_names: list[str],
streaming: bool,
query: str,
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a basic search across multiple indexes and return the context data and response.
Parameters
----------
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
- index_names (list[str]): A list of index names.
- streaming (bool): Whether to stream the results or not.
- query (str): The user query to search for.
Returns
-------
TODO: Document the search response type and format.
Raises
------
TODO: Document any exceptions to expect.
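Examples
--------
A sketch for two indexes; only the text units are needed per index. Directory and
index names are illustrative, and config is assumed to be loaded as in the
global_search example.

>>> import asyncio
>>> import pandas as pd
>>> text_units_list = [
...     pd.read_parquet("./index_a/output/text_units.parquet"),
...     pd.read_parquet("./index_b/output/text_units.parquet"),
... ]
>>> response, context = asyncio.run(
...     multi_index_basic_search(
...         config=config,
...         text_units_list=text_units_list,
...         index_names=["index_a", "index_b"],
...         streaming=False,
...         query="Summarize the topics shared across both document sets.",
...     )
... )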
"""
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_basic_search"
raise NotImplementedError(message)
links = {
"text_units": {},
}
max_vals = {
"text_units": 0,
}
text_units_dfs = []
for idx, index_name in enumerate(index_names):
# Prepare each index's text units dataframe for merging
text_units_df = text_units_list[idx]
for i in range(text_units_df.shape[0]):
links["text_units"][i + max_vals["text_units"]] = {
"index_name": index_name,
"id": i,
}
text_units_df["id"] = text_units_df["id"].apply(
lambda x, index_name=index_name: f"{x}-{index_name}"
)
text_units_df["human_readable_id"] = (
text_units_df["human_readable_id"] + max_vals["text_units"]
)
max_vals["text_units"] += text_units_df.shape[0]
text_units_dfs.append(text_units_df)
# Merge the dataframes
text_units_combined = pd.concat(
text_units_dfs, axis=0, ignore_index=True, sort=False
)
return await basic_search(
config,
text_units=text_units_combined,
query=query,
)