# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""DRIFT Search implementation."""
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any

import tiktoken
from tqdm.asyncio import tqdm_asyncio

from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.drift_search.action import DriftAction
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.local_search.search import LocalSearch

log = logging.getLogger(__name__)


class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
    """Class representing a DRIFT Search."""

    def __init__(
self,
llm: ChatOpenAI,
context_builder: DRIFTSearchContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
query_state: QueryState | None = None,
):
"""
Initialize the DRIFTSearch class.
Args:
llm (ChatOpenAI): The language model used for searching.
context_builder (DRIFTSearchContextBuilder): Builder for search context.
config (DRIFTSearchConfig, optional): Configuration settings for DRIFTSearch.
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
query_state (QueryState, optional): State of the current search query.
"""
super().__init__(llm, context_builder, token_encoder)
self.context_builder = context_builder
self.token_encoder = token_encoder
self.query_state = query_state or QueryState()
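        # The primer issues the initial global LLM pass over top-ranked
        # community reports; follow-ups are then answered with local search.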
self.primer = DRIFTPrimer(
config=self.context_builder.config,
chat_llm=llm,
token_encoder=token_encoder,
)
        self.local_search = self.init_local_search()

    def init_local_search(self) -> LocalSearch:
        """
        Initialize the LocalSearch object with parameters based on the DRIFT search configuration.

        Returns
        -------
        LocalSearch: An instance of the LocalSearch class with the configured parameters.
        """
local_context_params = {
"text_unit_prop": self.context_builder.config.local_search_text_unit_prop,
"community_prop": self.context_builder.config.local_search_community_prop,
"top_k_mapped_entities": self.context_builder.config.local_search_top_k_mapped_entities,
"top_k_relationships": self.context_builder.config.local_search_top_k_relationships,
"include_entity_rank": True,
"include_relationship_weight": True,
"include_community_rank": False,
"return_candidate_context": False,
"embedding_vectorstore_key": EntityVectorStoreKey.ID,
"max_tokens": self.context_builder.config.local_search_max_data_tokens,
}
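        # Request JSON output so each local-search response can be parsed back
        # into a DriftAction carrying an answer and follow-up queries.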
llm_params = {
"max_tokens": self.context_builder.config.local_search_llm_max_gen_tokens,
"temperature": self.context_builder.config.local_search_temperature,
"response_format": {"type": "json_object"},
}
return LocalSearch(
llm=self.llm,
system_prompt=self.context_builder.local_system_prompt,
context_builder=self.context_builder.local_mixed_context,
token_encoder=self.token_encoder,
llm_params=llm_params,
context_builder_params=local_context_params,
response_type="multiple paragraphs",
        )

    def _process_primer_results(
        self, query: str, search_results: SearchResult
    ) -> DriftAction:
        """
        Process the results from the primer search to extract intermediate answers and follow-up queries.

        Args:
            query (str): The original search query.
            search_results (SearchResult): The results from the primer search.

        Returns
        -------
        DriftAction: Action generated from the primer response.

        Raises
        ------
        RuntimeError: If no intermediate answers or follow-up queries are found in the primer response.
        """
response = search_results.response
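        # The primer returns one dict per expanded query, each expected to carry
        # an intermediate answer, follow-up queries, and a score.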
if isinstance(response, list) and isinstance(response[0], dict):
intermediate_answers = [
i["intermediate_answer"] for i in response if "intermediate_answer" in i
]
if not intermediate_answers:
error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
raise RuntimeError(error_msg)
intermediate_answer = "\n\n".join([
i["intermediate_answer"] for i in response if "intermediate_answer" in i
])
follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
if not follow_ups:
error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries."
raise RuntimeError(error_msg)
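            # Average the scores across primer responses; a missing score
            # defaults to -inf, dragging the mean down.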
score = sum(i.get("score", float("-inf")) for i in response) / len(response)
response_data = {
"intermediate_answer": intermediate_answer,
"follow_up_queries": follow_ups,
"score": score,
}
return DriftAction.from_primer_response(query, response_data)
error_msg = "Response must be a list of dictionaries."
        raise ValueError(error_msg)

    async def _search_step(
        self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
    ) -> list[DriftAction]:
        """
        Perform an asynchronous search step by executing each DriftAction asynchronously.

        Args:
            global_query (str): The global query for the search.
            search_engine (LocalSearch): The local search engine instance.
            actions (list[DriftAction]): A list of actions to perform.

        Returns
        -------
        list[DriftAction]: The results from executing the search actions asynchronously.
        """
tasks = [
action.search(search_engine=search_engine, global_query=global_query)
for action in actions
]
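        # Run all per-action local searches concurrently; tqdm reports progress.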
        return await tqdm_asyncio.gather(*tasks, leave=False)

    async def search(
        self,
        query: str,
        conversation_history: Any = None,
        reduce: bool = True,
        **kwargs,
    ) -> SearchResult:
        """
        Perform an asynchronous DRIFT search.

        Args:
            query (str): The query to search for.
            conversation_history (Any, optional): The conversation history, if any.
            reduce (bool, optional): Whether to reduce the response to a single comprehensive response.

        Returns
        -------
        SearchResult: The search result containing the response and context data.

        Raises
        ------
        ValueError: If the query is empty.
        """
if query == "":
error_msg = "DRIFT Search query cannot be empty."
raise ValueError(error_msg)
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
start_time = time.perf_counter()
# Check if query state is empty
if not self.query_state.graph:
# Prime the search with the primer
primer_context, token_ct = self.context_builder.build_context(query)
llm_calls["build_context"] = token_ct["llm_calls"]
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"]
primer_response = await self.primer.search(
query=query, top_k_reports=primer_context
)
llm_calls["primer"] = primer_response.llm_calls
prompt_tokens["primer"] = primer_response.prompt_tokens
output_tokens["primer"] = primer_response.output_tokens
# Package response into DriftAction
init_action = self._process_primer_results(query, primer_response)
self.query_state.add_action(init_action)
self.query_state.add_all_follow_ups(init_action, init_action.follow_ups)
# Main loop
epochs = 0
llm_call_offset = 0
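        # Each epoch ranks the unanswered follow-ups, keeps the top
        # drift_k_followups, and answers them with local search.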
while epochs < self.context_builder.config.n_depth:
actions = self.query_state.rank_incomplete_actions()
if len(actions) == 0:
log.info("No more actions to take. Exiting DRIFT loop.")
break
actions = actions[: self.context_builder.config.drift_k_followups]
llm_call_offset += (
len(actions) - self.context_builder.config.drift_k_followups
)
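            # NOTE: the offset tracks follow-ups trimmed beyond the configured
            # limit; it is not consumed elsewhere in this method.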
# Process actions
results = await self._search_step(
global_query=query, search_engine=self.local_search, actions=actions
)
# Update query state
for action in results:
self.query_state.add_action(action)
self.query_state.add_all_follow_ups(action, action.follow_ups)
epochs += 1
t_elapsed = time.perf_counter() - start_time
# Calculate token usage
token_ct = self.query_state.action_token_ct()
llm_calls["action"] = token_ct["llm_calls"]
prompt_tokens["action"] = token_ct["prompt_tokens"]
output_tokens["action"] = token_ct["output_tokens"]
# Package up context data
response_state, context_data, context_text = self.query_state.serialize(
include_context=True
)
reduced_response = response_state
if reduce:
# Reduce response_state to a single comprehensive response
reduced_response = await self._reduce_response(
responses=response_state,
query=query,
llm_calls=llm_calls,
prompt_tokens=prompt_tokens,
output_tokens=output_tokens,
max_tokens=self.context_builder.config.reduce_max_tokens,
temperature=self.context_builder.config.reduce_temperature,
)
return SearchResult(
response=reduced_response,
context_data=context_data,
context_text=context_text,
completion_time=t_elapsed,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
        )

    async def stream_search(
        self, query: str, conversation_history: ConversationHistory | None = None
    ) -> AsyncGenerator[str, None]:
        """
        Perform a streaming DRIFT search asynchronously.

        Args:
            query (str): The query to search for.
            conversation_history (ConversationHistory, optional): The conversation history.
        """
result = await self.search(
query=query, conversation_history=conversation_history, reduce=False
)
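        # Reduction was skipped above, so the result holds the raw query state;
        # stream the reduce step over it instead of blocking.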
if isinstance(result.response, list):
result.response = result.response[0]
async for resp in self._reduce_response_streaming(
responses=result.response,
query=query,
max_tokens=self.context_builder.config.reduce_max_tokens,
temperature=self.context_builder.config.reduce_temperature,
):
            yield resp

    async def _reduce_response(
        self,
        responses: str | dict[str, Any],
        query: str,
        llm_calls: dict[str, int],
        prompt_tokens: dict[str, int],
        output_tokens: dict[str, int],
        **llm_kwargs,
    ) -> str:
        """Reduce the response to a single comprehensive response.

        Parameters
        ----------
        responses : str | dict[str, Any]
            The responses to reduce.
        query : str
            The original query.
        llm_calls : dict[str, int]
            Per-category LLM call counts; updated in place with the reduce call.
        prompt_tokens : dict[str, int]
            Per-category prompt token counts; updated in place.
        output_tokens : dict[str, int]
            Per-category output token counts; updated in place.
        llm_kwargs : dict[str, Any]
            Additional keyword arguments to pass to the LLM.

        Returns
        -------
        str
            The reduced response.
        """
        if isinstance(responses, str):
            reduce_responses = [responses]
        else:
            reduce_responses = [
                response["answer"]
                for response in responses.get("nodes", [])
                if response.get("answer")
            ]
search_prompt = self.context_builder.reduce_system_prompt.format(
context_data=reduce_responses,
response_type=self.context_builder.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
reduced_response = self.llm.generate(
messages=search_messages,
streaming=False,
callbacks=None,
**llm_kwargs,
)
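        # Record the reduce call in the per-category usage maps (updated in place).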
llm_calls["reduce"] = 1
prompt_tokens["reduce"] = num_tokens(
search_prompt, self.token_encoder
) + num_tokens(query, self.token_encoder)
output_tokens["reduce"] = num_tokens(reduced_response, self.token_encoder)
        return reduced_response

    async def _reduce_response_streaming(
self,
responses: str | dict[str, Any],
query: str,
**llm_kwargs,
) -> AsyncGenerator[str, None]:
"""Reduce the response to a single comprehensive response.
Parameters
----------
responses : str|dict[str, Any]
The responses to reduce.
query : str
The original query.
llm_kwargs : dict[str, Any]
Additional keyword arguments to pass to the LLM.
Returns
-------
str
The reduced response.
"""
        if isinstance(responses, str):
            reduce_responses = [responses]
        else:
            reduce_responses = [
                response["answer"]
                for response in responses.get("nodes", [])
                if response.get("answer")
            ]
search_prompt = self.context_builder.reduce_system_prompt.format(
context_data=reduce_responses,
response_type=self.context_builder.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
async for resp in self.llm.astream_generate(
search_messages,
callbacks=None,
**llm_kwargs,
):
yield resp
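
# Usage sketch (illustrative only; assumes a `ChatOpenAI` client and a fully
# configured `DRIFTSearchContextBuilder` with entities, community reports, and
# a vector store, both constructed elsewhere):
#
#   searcher = DRIFTSearch(llm=chat_llm, context_builder=drift_context_builder)
#   result = await searcher.search("What themes connect the main communities?")
#   print(result.response)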