# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""
Query Engine API.

This API provides access to the query engine of graphrag, allowing external applications
to hook into graphrag and run queries over a knowledge graph generated by graphrag.

Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results back.
- multi_index_global_search: Perform a global search across multiple indexes.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results back.
- multi_index_local_search: Perform a local search across multiple indexes.
- drift_search: Perform a DRIFT search.
- drift_search_streaming: Perform a DRIFT search and stream results back.
- multi_index_drift_search: Perform a DRIFT search across multiple indexes.
- basic_search: Perform a basic search.
- basic_search_streaming: Perform a basic search and stream results back.
- multi_index_basic_search: Perform a basic search across multiple indexes.

WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from collections.abc import AsyncGenerator
from typing import Any

import pandas as pd
from pydantic import validate_call

from graphrag.config.embeddings import (
    community_full_content_embedding,
    entity_description_embedding,
    text_unit_text_embedding,
)
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.logger.print_progress import PrintProgressLogger
from graphrag.query.factory import (
    get_basic_search_engine,
    get_drift_search_engine,
    get_global_search_engine,
    get_local_search_engine,
)
from graphrag.query.indexer_adapters import (
    read_indexer_communities,
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.utils.api import (
    get_embedding_store,
    load_search_prompt,
    reformat_context_data,
    update_context_data,
)
from graphrag.utils.cli import redact

logger = PrintProgressLogger("")
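
# All search functions below operate on DataFrames loaded from the indexer's
# parquet outputs. A typical loading pattern (illustrative; adjust the output
# directory to your project layout):
#
#     entities = pd.read_parquet("output/entities.parquet")
#     communities = pd.read_parquet("output/communities.parquet")
#     community_reports = pd.read_parquet("output/community_reports.parquet")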


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search(
    config: GraphRagConfig,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int | None,
    dynamic_community_selection: bool,
    response_type: str,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a global search and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
    - communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
    - community_level (int | None): The community level to search at.
    - dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
    - response_type (str): The type of response to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    full_response = ""
    context_data = {}
    get_context_data = True
    # NOTE: when streaming, the first chunk of returned data is the complete context data.
    # All subsequent chunks are the query response.
    async for chunk in global_search_streaming(
        config=config,
        entities=entities,
        communities=communities,
        community_reports=community_reports,
        community_level=community_level,
        dynamic_community_selection=dynamic_community_selection,
        response_type=response_type,
        query=query,
    ):
        if get_context_data:
            context_data = chunk
            get_context_data = False
        else:
            full_response += chunk
    return full_response, context_data
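
# Example (illustrative): running a global search end to end. The parquet
# paths, the config-loading helper, and the argument values are assumptions
# about your project layout, not part of this module.
#
#     import asyncio
#
#     import pandas as pd
#
#     config = load_config(...)  # hypothetical helper that yields a GraphRagConfig
#     response, context = asyncio.run(
#         global_search(
#             config=config,
#             entities=pd.read_parquet("output/entities.parquet"),
#             communities=pd.read_parquet("output/communities.parquet"),
#             community_reports=pd.read_parquet("output/community_reports.parquet"),
#             community_level=2,
#             dynamic_community_selection=False,
#             response_type="Multiple Paragraphs",
#             query="What are the main themes in this dataset?",
#         )
#     )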


@validate_call(config={"arbitrary_types_allowed": True})
async def global_search_streaming(
    config: GraphRagConfig,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    community_reports: pd.DataFrame,
    community_level: int | None,
    dynamic_community_selection: bool,
    response_type: str,
    query: str,
) -> AsyncGenerator:
    """Perform a global search and return the context data and response via a generator.

    Context data is returned as a dictionary of lists, with one list entry for each record.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
    - communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
    - community_level (int | None): The community level to search at.
    - dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
    - response_type (str): The type of response to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    communities_ = read_indexer_communities(communities, community_reports)
    reports = read_indexer_reports(
        community_reports,
        communities,
        community_level=community_level,
        dynamic_community_selection=dynamic_community_selection,
    )
    entities_ = read_indexer_entities(
        entities, communities, community_level=community_level
    )
    map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt)
    reduce_prompt = load_search_prompt(
        config.root_dir, config.global_search.reduce_prompt
    )
    knowledge_prompt = load_search_prompt(
        config.root_dir, config.global_search.knowledge_prompt
    )

    search_engine = get_global_search_engine(
        config,
        reports=reports,
        entities=entities_,
        communities=communities_,
        response_type=response_type,
        dynamic_community_selection=dynamic_community_selection,
        map_system_prompt=map_prompt,
        reduce_system_prompt=reduce_prompt,
        general_knowledge_inclusion_prompt=knowledge_prompt,
    )
    search_result = search_engine.stream_search(query=query)

    # NOTE: when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
    context_data = {}
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = reformat_context_data(stream_chunk)  # type: ignore
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk
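
# Example (illustrative): consuming any of the *_streaming generators in this
# module. The first yielded item is the context-data dict; every later item is
# a response token. Argument values are placeholders.
#
#     async def run() -> None:
#         context = None
#         async for chunk in global_search_streaming(
#             config=config,
#             entities=entities,
#             communities=communities,
#             community_reports=community_reports,
#             community_level=2,
#             dynamic_community_selection=False,
#             response_type="Multiple Paragraphs",
#             query="What are the main themes in this dataset?",
#         ):
#             if context is None:
#                 context = chunk  # first chunk: complete context data
#             else:
#                 print(chunk, end="")  # subsequent chunks: response tokens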


@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_global_search(
    config: GraphRagConfig,
    entities_list: list[pd.DataFrame],
    communities_list: list[pd.DataFrame],
    community_reports_list: list[pd.DataFrame],
    index_names: list[str],
    community_level: int | None,
    dynamic_community_selection: bool,
    response_type: str,
    streaming: bool,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a global search across multiple indexes and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
    - communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
    - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
    - index_names (list[str]): A list of index names.
    - community_level (int | None): The community level to search at.
    - dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
    - response_type (str): The type of response to return.
    - streaming (bool): Whether to stream the results or not.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Streaming not supported yet
    if streaming:
        message = "Streaming not yet implemented for multi_index_global_search"
        raise NotImplementedError(message)

    links = {
        "communities": {},
        "community_reports": {},
        "entities": {},
    }
    max_vals = {
        "communities": -1,
        "community_reports": -1,
        "entities": -1,
    }

    communities_dfs = []
    community_reports_dfs = []
    entities_dfs = []

    # Each index's ids are offset by the running maximum so they stay unique
    # after concatenation; `links` maps each offset id back to its
    # (index_name, original id) pair so results can be traced to their source index.
    for idx, index_name in enumerate(index_names):
        # Prepare each index's community reports dataframe for merging
        community_reports_df = community_reports_list[idx]
        community_reports_df["community"] = community_reports_df["community"].astype(
            int
        )
        for i in community_reports_df["community"]:
            links["community_reports"][i + max_vals["community_reports"] + 1] = {
                "index_name": index_name,
                "id": str(i),
            }
        community_reports_df["community"] += max_vals["community_reports"] + 1
        community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
        max_vals["community_reports"] = int(community_reports_df["community"].max())
        community_reports_dfs.append(community_reports_df)

        # Prepare each index's communities dataframe for merging
        communities_df = communities_list[idx]
        communities_df["community"] = communities_df["community"].astype(int)
        communities_df["parent"] = communities_df["parent"].astype(int)
        for i in communities_df["community"]:
            links["communities"][i + max_vals["communities"] + 1] = {
                "index_name": index_name,
                "id": str(i),
            }
        communities_df["community"] += max_vals["communities"] + 1
        communities_df["parent"] = communities_df["parent"].apply(
            lambda x: x if x == -1 else x + max_vals["communities"] + 1
        )
        communities_df["human_readable_id"] += max_vals["communities"] + 1
        # concat the index name to the entity_ids, since this is used for joining later
        communities_df["entity_ids"] = communities_df["entity_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["communities"] = int(communities_df["community"].max())
        communities_dfs.append(communities_df)

        # Prepare each index's entities dataframe for merging
        entities_df = entities_list[idx]
        for i in entities_df["human_readable_id"]:
            links["entities"][i + max_vals["entities"] + 1] = {
                "index_name": index_name,
                "id": i,
            }
        entities_df["human_readable_id"] += max_vals["entities"] + 1
        entities_df["title"] = entities_df["title"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["entities"] = int(entities_df["human_readable_id"].max())
        entities_dfs.append(entities_df)

    # Merge the dataframes
    community_reports_combined = pd.concat(
        community_reports_dfs, axis=0, ignore_index=True, sort=False
    )
    entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
    communities_combined = pd.concat(
        communities_dfs, axis=0, ignore_index=True, sort=False
    )

    result = await global_search(
        config,
        entities=entities_combined,
        communities=communities_combined,
        community_reports=community_reports_combined,
        community_level=community_level,
        dynamic_community_selection=dynamic_community_selection,
        response_type=response_type,
        query=query,
    )

    # Update the context data by linking index names and community ids
    context = update_context_data(result[1], links)

    return (result[0], context)
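
# Example (illustrative): searching across two indexes from inside an async
# function. Each list must be ordered consistently with `index_names`; the
# names and DataFrames are placeholders.
#
#     response, context = await multi_index_global_search(
#         config=config,
#         entities_list=[entities_a, entities_b],
#         communities_list=[communities_a, communities_b],
#         community_reports_list=[reports_a, reports_b],
#         index_names=["index-a", "index-b"],
#         community_level=2,
#         dynamic_community_selection=False,
#         response_type="Multiple Paragraphs",
#         streaming=False,  # streaming raises NotImplementedError here
#         query="Compare the main themes across both datasets.",
#     )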


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
    config: GraphRagConfig,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a local search and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
    - communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
    - covariates (pd.DataFrame | None): An optional DataFrame containing the final covariates (from covariates.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    full_response = ""
    context_data = {}
    get_context_data = True
    # NOTE: when streaming, the first chunk of returned data is the complete context data.
    # All subsequent chunks are the query response.
    async for chunk in local_search_streaming(
        config=config,
        entities=entities,
        communities=communities,
        community_reports=community_reports,
        text_units=text_units,
        relationships=relationships,
        covariates=covariates,
        community_level=community_level,
        response_type=response_type,
        query=query,
    ):
        if get_context_data:
            context_data = chunk
            get_context_data = False
        else:
            full_response += chunk
    return full_response, context_data
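
# Example (illustrative): a local search needs the extra artifacts plus a
# configured vector store holding entity description embeddings. Paths and the
# query are placeholders; pass covariates=None if claims were not extracted.
#
#     response, context = await local_search(
#         config=config,
#         entities=pd.read_parquet("output/entities.parquet"),
#         communities=pd.read_parquet("output/communities.parquet"),
#         community_reports=pd.read_parquet("output/community_reports.parquet"),
#         text_units=pd.read_parquet("output/text_units.parquet"),
#         relationships=pd.read_parquet("output/relationships.parquet"),
#         covariates=None,
#         community_level=2,
#         response_type="Multiple Paragraphs",
#         query="Who is Scrooge and what are his main relationships?",
#     )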


@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
    config: GraphRagConfig,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    covariates: pd.DataFrame | None,
    community_level: int,
    response_type: str,
    query: str,
) -> AsyncGenerator:
    """Perform a local search and return the context data and response via a generator.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
    - communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
    - covariates (pd.DataFrame | None): An optional DataFrame containing the final covariates (from covariates.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = {}
    for index, store in config.vector_store.items():
        vector_store_args[index] = store.model_dump()
    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

    description_embedding_store = get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=entity_description_embedding,
    )

    entities_ = read_indexer_entities(entities, communities, community_level)
    covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
    prompt = load_search_prompt(config.root_dir, config.local_search.prompt)

    search_engine = get_local_search_engine(
        config=config,
        reports=read_indexer_reports(community_reports, communities, community_level),
        text_units=read_indexer_text_units(text_units),
        entities=entities_,
        relationships=read_indexer_relationships(relationships),
        covariates={"claims": covariates_},
        description_embedding_store=description_embedding_store,  # type: ignore
        response_type=response_type,
        system_prompt=prompt,
    )
    search_result = search_engine.stream_search(query=query)

    # NOTE: when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
    context_data = {}
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = reformat_context_data(stream_chunk)  # type: ignore
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_local_search(
    config: GraphRagConfig,
    entities_list: list[pd.DataFrame],
    communities_list: list[pd.DataFrame],
    community_reports_list: list[pd.DataFrame],
    text_units_list: list[pd.DataFrame],
    relationships_list: list[pd.DataFrame],
    covariates_list: list[pd.DataFrame] | None,
    index_names: list[str],
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a local search across multiple indexes and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
    - communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
    - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
    - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
    - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
    - covariates_list (list[pd.DataFrame] | None): An optional list of DataFrames containing the final covariates (from covariates.parquet)
    - index_names (list[str]): A list of index names.
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - streaming (bool): Whether to stream the results or not.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Streaming not supported yet
    if streaming:
        message = "Streaming not yet implemented for multi_index_local_search"
        raise NotImplementedError(message)

    links = {
        "community_reports": {},
        "communities": {},
        "entities": {},
        "text_units": {},
        "relationships": {},
        "covariates": {},
    }
    max_vals = {
        "community_reports": -1,
        "communities": -1,
        "entities": -1,
        "text_units": 0,
        "relationships": -1,
        "covariates": 0,
    }
    community_reports_dfs = []
    communities_dfs = []
    entities_dfs = []
    relationships_dfs = []
    text_units_dfs = []
    covariates_dfs = []

    # Offset ids per index as in multi_index_global_search so they stay unique
    # after concatenation, recording the reverse mapping in `links`.
    for idx, index_name in enumerate(index_names):
        # Prepare each index's communities dataframe for merging
        communities_df = communities_list[idx]
        communities_df["community"] = communities_df["community"].astype(int)
        for i in communities_df["community"]:
            links["communities"][i + max_vals["communities"] + 1] = {
                "index_name": index_name,
                "id": str(i),
            }
        communities_df["community"] += max_vals["communities"] + 1
        communities_df["human_readable_id"] += max_vals["communities"] + 1
        # concat the index name to the entity_ids, since this is used for joining later
        communities_df["entity_ids"] = communities_df["entity_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["communities"] = int(communities_df["community"].max())
        communities_dfs.append(communities_df)

        # Prepare each index's community reports dataframe for merging
        community_reports_df = community_reports_list[idx]
        community_reports_df["community"] = community_reports_df["community"].astype(
            int
        )
        for i in community_reports_df["community"]:
            links["community_reports"][i + max_vals["community_reports"] + 1] = {
                "index_name": index_name,
                "id": str(i),
            }
        community_reports_df["community"] += max_vals["community_reports"] + 1
        community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
        max_vals["community_reports"] = int(community_reports_df["community"].max())
        community_reports_dfs.append(community_reports_df)

        # Prepare each index's entities dataframe for merging
        entities_df = entities_list[idx]
        for i in entities_df["human_readable_id"]:
            links["entities"][i + max_vals["entities"] + 1] = {
                "index_name": index_name,
                "id": i,
            }
        entities_df["human_readable_id"] += max_vals["entities"] + 1
        entities_df["title"] = entities_df["title"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        entities_df["id"] = entities_df["id"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["entities"] = int(entities_df["human_readable_id"].max())
        entities_dfs.append(entities_df)

        # Prepare each index's relationships dataframe for merging
        relationships_df = relationships_list[idx]
        for i in relationships_df["human_readable_id"].astype(int):
            links["relationships"][i + max_vals["relationships"] + 1] = {
                "index_name": index_name,
                "id": i,
            }
        if max_vals["relationships"] != -1:
            col = (
                relationships_df["human_readable_id"].astype(int)
                + max_vals["relationships"]
                + 1
            )
            relationships_df["human_readable_id"] = col.astype(str)
        relationships_df["source"] = relationships_df["source"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        relationships_df["target"] = relationships_df["target"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["relationships"] = int(relationships_df["human_readable_id"].max())
        relationships_dfs.append(relationships_df)

        # Prepare each index's text units dataframe for merging
        text_units_df = text_units_list[idx]
        for i in range(text_units_df.shape[0]):
            links["text_units"][i + max_vals["text_units"]] = {
                "index_name": index_name,
                "id": i,
            }
        text_units_df["id"] = text_units_df["id"].apply(
            lambda x, index_name=index_name: f"{x}-{index_name}"
        )
        text_units_df["human_readable_id"] = (
            text_units_df["human_readable_id"] + max_vals["text_units"]
        )
        max_vals["text_units"] += text_units_df.shape[0]
        text_units_dfs.append(text_units_df)

        # If present, prepare each index's covariates dataframe for merging
        if covariates_list is not None:
            covariates_df = covariates_list[idx]
            for i in covariates_df["human_readable_id"].astype(int):
                links["covariates"][i + max_vals["covariates"]] = {
                    "index_name": index_name,
                    "id": i,
                }
            covariates_df["id"] = covariates_df["id"].apply(
                lambda x, index_name=index_name: f"{x}-{index_name}"
            )
            covariates_df["human_readable_id"] = (
                covariates_df["human_readable_id"] + max_vals["covariates"]
            )
            covariates_df["text_unit_id"] = covariates_df["text_unit_id"].apply(
                lambda x, index_name=index_name: x + f"-{index_name}"
            )
            covariates_df["subject_id"] = covariates_df["subject_id"].apply(
                lambda x, index_name=index_name: x + f"-{index_name}"
            )
            max_vals["covariates"] += covariates_df.shape[0]
            covariates_dfs.append(covariates_df)

    # Merge the dataframes
    communities_combined = pd.concat(
        communities_dfs, axis=0, ignore_index=True, sort=False
    )
    community_reports_combined = pd.concat(
        community_reports_dfs, axis=0, ignore_index=True, sort=False
    )
    entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
    relationships_combined = pd.concat(
        relationships_dfs, axis=0, ignore_index=True, sort=False
    )
    text_units_combined = pd.concat(
        text_units_dfs, axis=0, ignore_index=True, sort=False
    )
    covariates_combined = None
    if len(covariates_dfs) > 0:
        covariates_combined = pd.concat(
            covariates_dfs, axis=0, ignore_index=True, sort=False
        )
    result = await local_search(
        config,
        entities=entities_combined,
        communities=communities_combined,
        community_reports=community_reports_combined,
        text_units=text_units_combined,
        relationships=relationships_combined,
        covariates=covariates_combined,
        community_level=community_level,
        response_type=response_type,
        query=query,
    )

    # Update the context data by linking index names and community ids
    context = update_context_data(result[1], links)

    return (result[0], context)
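
# Example (illustrative): the local multi-index variant takes one DataFrame per
# index for every artifact, all ordered like `index_names`. Pass
# covariates_list=None when no index has extracted claims; all values below
# are placeholders.
#
#     response, context = await multi_index_local_search(
#         config=config,
#         entities_list=[entities_a, entities_b],
#         communities_list=[communities_a, communities_b],
#         community_reports_list=[reports_a, reports_b],
#         text_units_list=[text_units_a, text_units_b],
#         relationships_list=[relationships_a, relationships_b],
#         covariates_list=None,
#         index_names=["index-a", "index-b"],
#         community_level=2,
#         response_type="Multiple Paragraphs",
#         streaming=False,
#         query="Compare the key figures across both datasets.",
#     )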


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search(
    config: GraphRagConfig,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a DRIFT search and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
    - communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    full_response = ""
    context_data = {}
    get_context_data = True
    # NOTE: when streaming, the first chunk of returned data is the complete context data.
    # All subsequent chunks are the query response.
    async for chunk in drift_search_streaming(
        config=config,
        entities=entities,
        communities=communities,
        community_reports=community_reports,
        text_units=text_units,
        relationships=relationships,
        community_level=community_level,
        response_type=response_type,
        query=query,
    ):
        if get_context_data:
            context_data = chunk
            get_context_data = False
        else:
            full_response += chunk
    return full_response, context_data
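
# Example (illustrative): DRIFT search uses the same artifacts as local search
# minus covariates, and additionally expects community report full-content
# embeddings in the configured vector store. Values are placeholders.
#
#     response, context = await drift_search(
#         config=config,
#         entities=entities,
#         communities=communities,
#         community_reports=community_reports,
#         text_units=text_units,
#         relationships=relationships,
#         community_level=2,
#         response_type="Multiple Paragraphs",
#         query="How do the main characters influence each other?",
#     )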


@validate_call(config={"arbitrary_types_allowed": True})
async def drift_search_streaming(
    config: GraphRagConfig,
    entities: pd.DataFrame,
    communities: pd.DataFrame,
    community_reports: pd.DataFrame,
    text_units: pd.DataFrame,
    relationships: pd.DataFrame,
    community_level: int,
    response_type: str,
    query: str,
) -> AsyncGenerator:
    """Perform a DRIFT search and return the context data and response via a generator.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities (pd.DataFrame): A DataFrame containing the final entities (from entities.parquet)
    - communities (pd.DataFrame): A DataFrame containing the final communities (from communities.parquet)
    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from community_reports.parquet)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from relationships.parquet)
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = {}
    for index, store in config.vector_store.items():
        vector_store_args[index] = store.model_dump()
    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

    description_embedding_store = get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=entity_description_embedding,
    )

    full_content_embedding_store = get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=community_full_content_embedding,
    )

    entities_ = read_indexer_entities(entities, communities, community_level)
    reports = read_indexer_reports(community_reports, communities, community_level)
    read_indexer_report_embeddings(reports, full_content_embedding_store)
    prompt = load_search_prompt(config.root_dir, config.drift_search.prompt)
    reduce_prompt = load_search_prompt(
        config.root_dir, config.drift_search.reduce_prompt
    )

    search_engine = get_drift_search_engine(
        config=config,
        reports=reports,
        text_units=read_indexer_text_units(text_units),
        entities=entities_,
        relationships=read_indexer_relationships(relationships),
        description_embedding_store=description_embedding_store,  # type: ignore
        local_system_prompt=prompt,
        reduce_system_prompt=reduce_prompt,
        response_type=response_type,
    )
    search_result = search_engine.stream_search(query=query)

    # NOTE: when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
    context_data = {}
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = reformat_context_data(stream_chunk)  # type: ignore
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_drift_search(
    config: GraphRagConfig,
    entities_list: list[pd.DataFrame],
    communities_list: list[pd.DataFrame],
    community_reports_list: list[pd.DataFrame],
    text_units_list: list[pd.DataFrame],
    relationships_list: list[pd.DataFrame],
    index_names: list[str],
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a DRIFT search across multiple indexes and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - entities_list (list[pd.DataFrame]): A list of DataFrames containing the final entities (from entities.parquet)
    - communities_list (list[pd.DataFrame]): A list of DataFrames containing the final communities (from communities.parquet)
    - community_reports_list (list[pd.DataFrame]): A list of DataFrames containing the final community reports (from community_reports.parquet)
    - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
    - relationships_list (list[pd.DataFrame]): A list of DataFrames containing the final relationships (from relationships.parquet)
    - index_names (list[str]): A list of index names.
    - community_level (int): The community level to search at.
    - response_type (str): The response type to return.
    - streaming (bool): Whether to stream the results or not.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Streaming not supported yet
    if streaming:
        message = "Streaming not yet implemented for multi_index_drift_search"
        raise NotImplementedError(message)

    links = {
        "community_reports": {},
        "communities": {},
        "entities": {},
        "text_units": {},
        "relationships": {},
    }
    max_vals = {
        "community_reports": -1,
        "communities": -1,
        "entities": -1,
        "text_units": 0,
        "relationships": -1,
    }

    communities_dfs = []
    community_reports_dfs = []
    entities_dfs = []
    relationships_dfs = []
    text_units_dfs = []

    # Offset ids per index as in multi_index_global_search so they stay unique
    # after concatenation, recording the reverse mapping in `links`.
    for idx, index_name in enumerate(index_names):
        # Prepare each index's communities dataframe for merging
        communities_df = communities_list[idx]
        communities_df["community"] = communities_df["community"].astype(int)
        for i in communities_df["community"]:
            links["communities"][i + max_vals["communities"] + 1] = {
                "index_name": index_name,
                "id": str(i),
            }
        communities_df["community"] += max_vals["communities"] + 1
        communities_df["human_readable_id"] += max_vals["communities"] + 1
        # concat the index name to the entity_ids, since this is used for joining later
        communities_df["entity_ids"] = communities_df["entity_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["communities"] = int(communities_df["community"].max())
        communities_dfs.append(communities_df)

        # Prepare each index's community reports dataframe for merging
        community_reports_df = community_reports_list[idx]
        community_reports_df["community"] = community_reports_df["community"].astype(
            int
        )
        for i in community_reports_df["community"]:
            links["community_reports"][i + max_vals["community_reports"] + 1] = {
                "index_name": index_name,
                "id": str(i),
            }
        community_reports_df["community"] += max_vals["community_reports"] + 1
        community_reports_df["human_readable_id"] += max_vals["community_reports"] + 1
        community_reports_df["id"] = community_reports_df["id"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        max_vals["community_reports"] = int(community_reports_df["community"].max())
        community_reports_dfs.append(community_reports_df)

        # Prepare each index's entities dataframe for merging
        entities_df = entities_list[idx]
        for i in entities_df["human_readable_id"]:
            links["entities"][i + max_vals["entities"] + 1] = {
                "index_name": index_name,
                "id": i,
            }
        entities_df["human_readable_id"] += max_vals["entities"] + 1
        entities_df["title"] = entities_df["title"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        entities_df["id"] = entities_df["id"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        entities_df["text_unit_ids"] = entities_df["text_unit_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["entities"] = int(entities_df["human_readable_id"].max())
        entities_dfs.append(entities_df)

        # Prepare each index's relationships dataframe for merging
        relationships_df = relationships_list[idx]
        for i in relationships_df["human_readable_id"].astype(int):
            links["relationships"][i + max_vals["relationships"] + 1] = {
                "index_name": index_name,
                "id": i,
            }
        if max_vals["relationships"] != -1:
            col = (
                relationships_df["human_readable_id"].astype(int)
                + max_vals["relationships"]
                + 1
            )
            relationships_df["human_readable_id"] = col.astype(str)
        relationships_df["source"] = relationships_df["source"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        relationships_df["target"] = relationships_df["target"].apply(
            lambda x, index_name=index_name: x + f"-{index_name}"
        )
        relationships_df["text_unit_ids"] = relationships_df["text_unit_ids"].apply(
            lambda x, index_name=index_name: [i + f"-{index_name}" for i in x]
        )
        max_vals["relationships"] = int(
            relationships_df["human_readable_id"].astype(int).max()
        )
        relationships_dfs.append(relationships_df)

        # Prepare each index's text units dataframe for merging
        text_units_df = text_units_list[idx]
        for i in range(text_units_df.shape[0]):
            links["text_units"][i + max_vals["text_units"]] = {
                "index_name": index_name,
                "id": i,
            }
        text_units_df["id"] = text_units_df["id"].apply(
            lambda x, index_name=index_name: f"{x}-{index_name}"
        )
        text_units_df["human_readable_id"] = (
            text_units_df["human_readable_id"] + max_vals["text_units"]
        )
        max_vals["text_units"] += text_units_df.shape[0]
        text_units_dfs.append(text_units_df)

    # Merge the dataframes
    communities_combined = pd.concat(
        communities_dfs, axis=0, ignore_index=True, sort=False
    )
    community_reports_combined = pd.concat(
        community_reports_dfs, axis=0, ignore_index=True, sort=False
    )
    entities_combined = pd.concat(entities_dfs, axis=0, ignore_index=True, sort=False)
    relationships_combined = pd.concat(
        relationships_dfs, axis=0, ignore_index=True, sort=False
    )
    text_units_combined = pd.concat(
        text_units_dfs, axis=0, ignore_index=True, sort=False
    )

    result = await drift_search(
        config,
        entities=entities_combined,
        communities=communities_combined,
        community_reports=community_reports_combined,
        text_units=text_units_combined,
        relationships=relationships_combined,
        community_level=community_level,
        response_type=response_type,
        query=query,
    )

    # Update the context data by linking index names and community ids
    context = {}
    if type(result[1]) is dict:
        for key in result[1]:
            context[key] = update_context_data(result[1][key], links)
    else:
        context = result[1]
    return (result[0], context)
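
# NOTE: unlike the other multi-index variants, DRIFT context data may arrive
# as a nested dict (one entry per sub-query), so the index-name link update
# above is applied per key rather than once over the whole result.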


@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search(
    config: GraphRagConfig,
    text_units: pd.DataFrame,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a basic search and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    full_response = ""
    context_data = {}
    get_context_data = True
    # NOTE: when streaming, the first chunk of returned data is the complete context data.
    # All subsequent chunks are the query response.
    async for chunk in basic_search_streaming(
        config=config,
        text_units=text_units,
        query=query,
    ):
        if get_context_data:
            context_data = chunk
            get_context_data = False
        else:
            full_response += chunk
    return full_response, context_data
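
# Example (illustrative): basic search only needs text units plus a vector
# store containing text-unit embeddings. The path and query are placeholders.
#
#     response, context = await basic_search(
#         config=config,
#         text_units=pd.read_parquet("output/text_units.parquet"),
#         query="What does the document say about deadlines?",
#     )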


@validate_call(config={"arbitrary_types_allowed": True})
async def basic_search_streaming(
    config: GraphRagConfig,
    text_units: pd.DataFrame,
    query: str,
) -> AsyncGenerator:
    """Perform a basic search and return the context data and response via a generator.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - text_units (pd.DataFrame): A DataFrame containing the final text units (from text_units.parquet)
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    vector_store_args = {}
    for index, store in config.vector_store.items():
        vector_store_args[index] = store.model_dump()
    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa

    description_embedding_store = get_embedding_store(
        config_args=vector_store_args,  # type: ignore
        embedding_name=text_unit_text_embedding,
    )

    prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)

    search_engine = get_basic_search_engine(
        config=config,
        text_units=read_indexer_text_units(text_units),
        text_unit_embeddings=description_embedding_store,
        system_prompt=prompt,
    )
    search_result = search_engine.stream_search(query=query)

    # NOTE: when streaming results, a context data object is returned as the first result
    # and the query response in subsequent tokens
    context_data = {}
    get_context_data = True
    async for stream_chunk in search_result:
        if get_context_data:
            context_data = reformat_context_data(stream_chunk)  # type: ignore
            yield context_data
            get_context_data = False
        else:
            yield stream_chunk


@validate_call(config={"arbitrary_types_allowed": True})
async def multi_index_basic_search(
    config: GraphRagConfig,
    text_units_list: list[pd.DataFrame],
    index_names: list[str],
    streaming: bool,
    query: str,
) -> tuple[
    str | dict[str, Any] | list[dict[str, Any]],
    str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
    """Perform a basic search across multiple indexes and return the context data and response.

    Parameters
    ----------
    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
    - text_units_list (list[pd.DataFrame]): A list of DataFrames containing the final text units (from text_units.parquet)
    - index_names (list[str]): A list of index names.
    - streaming (bool): Whether to stream the results or not.
    - query (str): The user query to search for.

    Returns
    -------
    TODO: Document the search response type and format.

    Raises
    ------
    TODO: Document any exceptions to expect.
    """
    # Streaming not supported yet
    if streaming:
        message = "Streaming not yet implemented for multi_index_basic_search"
        raise NotImplementedError(message)

    links = {
        "text_units": {},
    }
    max_vals = {
        "text_units": 0,
    }

    text_units_dfs = []

    for idx, index_name in enumerate(index_names):
        # Prepare each index's text units dataframe for merging
        text_units_df = text_units_list[idx]
        for i in range(text_units_df.shape[0]):
            links["text_units"][i + max_vals["text_units"]] = {
                "index_name": index_name,
                "id": i,
            }
        text_units_df["id"] = text_units_df["id"].apply(
            lambda x, index_name=index_name: f"{x}-{index_name}"
        )
        text_units_df["human_readable_id"] = (
            text_units_df["human_readable_id"] + max_vals["text_units"]
        )
        max_vals["text_units"] += text_units_df.shape[0]
        text_units_dfs.append(text_units_df)

    # Merge the dataframes
    text_units_combined = pd.concat(
        text_units_dfs, axis=0, ignore_index=True, sort=False
    )

    return await basic_search(
        config,
        text_units=text_units_combined,
        query=query,
    )
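
# Example (illustrative): merging text units from two indexes for a basic
# search; as above, list order must match `index_names`. Values are
# placeholders.
#
#     response, context = await multi_index_basic_search(
#         config=config,
#         text_units_list=[text_units_a, text_units_b],
#         index_names=["index-a", "index-b"],
#         streaming=False,
#         query="Summarize the shared topics across both corpora.",
#     )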