graphrag-microsoft/graphrag/query/structured_search/global_search/search.py
Latest commit b8b949f3bb by Josh Bradley, 2025-02-13: Cleanup query api - remove code duplication (#1690)

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The GlobalSearch Implementation."""

import asyncio
import json
import logging
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

import pandas as pd
import tiktoken

from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
    GENERAL_KNOWLEDGE_INSTRUCTION,
)
from graphrag.prompts.query.global_search_map_system_prompt import (
    MAP_SYSTEM_PROMPT,
)
from graphrag.prompts.query.global_search_reduce_system_prompt import (
    NO_DATA_ANSWER,
    REDUCE_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
    ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
from graphrag.query.structured_search.base import BaseSearch, SearchResult

DEFAULT_MAP_LLM_PARAMS = {
    "max_tokens": 1000,
    "temperature": 0.0,
}

DEFAULT_REDUCE_LLM_PARAMS = {
    "max_tokens": 2000,
    "temperature": 0.0,
}
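
# NOTE: These module-level dicts are used as keyword defaults in GlobalSearch.__init__ and are
# passed by reference; when json_mode=True, __init__ adds a "response_format" entry to the map
# params dict it receives (the shared default dict, unless the caller passes its own copy).
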
log = logging.getLogger(__name__)


@dataclass(kw_only=True)
class GlobalSearchResult(SearchResult):
    """A GlobalSearch result."""

    map_responses: list[SearchResult]
    reduce_context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame]
    reduce_context_text: str | list[str] | dict[str, str]


class GlobalSearch(BaseSearch[GlobalContextBuilder]):
    """Search orchestration for global search mode."""

    def __init__(
        self,
        llm: BaseLLM,
        context_builder: GlobalContextBuilder,
        token_encoder: tiktoken.Encoding | None = None,
        map_system_prompt: str | None = None,
        reduce_system_prompt: str | None = None,
        response_type: str = "multiple paragraphs",
        allow_general_knowledge: bool = False,
        general_knowledge_inclusion_prompt: str | None = None,
        json_mode: bool = True,
        callbacks: list[GlobalSearchLLMCallback] | None = None,
        max_data_tokens: int = 8000,
        map_llm_params: dict[str, Any] = DEFAULT_MAP_LLM_PARAMS,
        reduce_llm_params: dict[str, Any] = DEFAULT_REDUCE_LLM_PARAMS,
        context_builder_params: dict[str, Any] | None = None,
        concurrent_coroutines: int = 32,
    ):
        super().__init__(
            llm=llm,
            context_builder=context_builder,
            token_encoder=token_encoder,
            context_builder_params=context_builder_params,
        )
        self.map_system_prompt = map_system_prompt or MAP_SYSTEM_PROMPT
        self.reduce_system_prompt = reduce_system_prompt or REDUCE_SYSTEM_PROMPT
        self.response_type = response_type
        self.allow_general_knowledge = allow_general_knowledge
        self.general_knowledge_inclusion_prompt = (
            general_knowledge_inclusion_prompt or GENERAL_KNOWLEDGE_INSTRUCTION
        )
        self.callbacks = callbacks
        self.max_data_tokens = max_data_tokens

        self.map_llm_params = map_llm_params
        self.reduce_llm_params = reduce_llm_params
        if json_mode:
            self.map_llm_params["response_format"] = {"type": "json_object"}
        else:
            # remove response_format key if json_mode is False
            self.map_llm_params.pop("response_format", None)

        self.semaphore = asyncio.Semaphore(concurrent_coroutines)

    async def stream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator:
        """Stream the global search response."""
        context_result = await self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )

        if self.callbacks:
            for callback in self.callbacks:
                callback.on_map_response_start(context_result.context_chunks)  # type: ignore
        map_responses = await asyncio.gather(*[
            self._map_response_single_batch(
                context_data=data, query=query, **self.map_llm_params
            )
            for data in context_result.context_chunks
        ])
        if self.callbacks:
            for callback in self.callbacks:
                callback.on_map_response_end(map_responses)  # type: ignore

        # send context records first before sending the reduce response
        yield context_result.context_records
        async for response in self._stream_reduce_response(
            map_responses=map_responses,  # type: ignore
            query=query,
            **self.reduce_llm_params,
        ):
            yield response
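
    # Consumption sketch (illustrative, not part of this module): the first item yielded by
    # stream_search is the context-records mapping, and every subsequent item is a chunk of the
    # reduce response. A hypothetical async caller might do:
    #
    #   context_records = None
    #   chunks: list[str] = []
    #   async for item in search_engine.stream_search("What are the main themes?"):
    #       if context_records is None:
    #           context_records = item  # first yield: context records used to build the answer
    #       else:
    #           chunks.append(item)     # later yields: streamed pieces of the final answer
    #   answer = "".join(chunks)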

    async def search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs: Any,
    ) -> GlobalSearchResult:
        """
        Perform a global search.

        Global search mode includes two steps:

        - Step 1: Run parallel LLM calls on communities' short summaries to generate an answer for each batch
        - Step 2: Combine the intermediate answers from Step 1 to generate the final answer
        """
        # Step 1: Generate answers for each batch of community short summaries
        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
        start_time = time.time()
        context_result = await self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )
        llm_calls["build_context"] = context_result.llm_calls
        prompt_tokens["build_context"] = context_result.prompt_tokens
        output_tokens["build_context"] = context_result.output_tokens

        if self.callbacks:
            for callback in self.callbacks:
                callback.on_map_response_start(context_result.context_chunks)  # type: ignore
        map_responses = await asyncio.gather(*[
            self._map_response_single_batch(
                context_data=data, query=query, **self.map_llm_params
            )
            for data in context_result.context_chunks
        ])
        if self.callbacks:
            for callback in self.callbacks:
                callback.on_map_response_end(map_responses)
        llm_calls["map"] = sum(response.llm_calls for response in map_responses)
        prompt_tokens["map"] = sum(response.prompt_tokens for response in map_responses)
        output_tokens["map"] = sum(response.output_tokens for response in map_responses)

        # Step 2: Combine the intermediate answers from Step 1 to generate the final answer
        reduce_response = await self._reduce_response(
            map_responses=map_responses,
            query=query,
            **self.reduce_llm_params,
        )
        llm_calls["reduce"] = reduce_response.llm_calls
        prompt_tokens["reduce"] = reduce_response.prompt_tokens
        output_tokens["reduce"] = reduce_response.output_tokens

        return GlobalSearchResult(
            response=reduce_response.response,
            context_data=context_result.context_records,
            context_text=context_result.context_chunks,
            map_responses=map_responses,
            reduce_context_data=reduce_response.context_data,
            reduce_context_text=reduce_response.context_text,
            completion_time=time.time() - start_time,
            llm_calls=sum(llm_calls.values()),
            prompt_tokens=sum(prompt_tokens.values()),
            output_tokens=sum(output_tokens.values()),
            llm_calls_categories=llm_calls,
            prompt_tokens_categories=prompt_tokens,
            output_tokens_categories=output_tokens,
        )
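
    # Note: GlobalSearchResult carries per-stage accounting alongside the final answer.
    # llm_calls / prompt_tokens / output_tokens are totals, while the *_categories dicts
    # break the same numbers down by stage ("build_context", "map", "reduce").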

    async def _map_response_single_batch(
        self,
        context_data: str,
        query: str,
        **llm_kwargs,
    ) -> SearchResult:
        """Generate answer for a single chunk of community reports."""
        start_time = time.time()
        search_prompt = ""
        try:
            search_prompt = self.map_system_prompt.format(context_data=context_data)
            search_messages = [
                {"role": "system", "content": search_prompt},
                {"role": "user", "content": query},
            ]
            async with self.semaphore:
                search_response = await self.llm.agenerate(
                    messages=search_messages, streaming=False, **llm_kwargs
                )
                log.info("Map response: %s", search_response)
            try:
                # parse search response json
                processed_response = self._parse_search_response(search_response)
            except ValueError:
                log.warning(
                    "Warning: Error parsing search response json - skipping this batch"
                )
                processed_response = []

            return SearchResult(
                response=processed_response,
                context_data=context_data,
                context_text=context_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=num_tokens(search_response, self.token_encoder),
            )

        except Exception:
            log.exception("Exception in _map_response_single_batch")
            return SearchResult(
                response=[{"answer": "", "score": 0}],
                context_data=context_data,
                context_text=context_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )
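
    # The map prompt asks the model for JSON shaped roughly like
    #   {"points": [{"description": "<key point text>", "score": <integer importance>}, ...]}
    # _parse_search_response below turns that payload into a list of
    # {"answer": ..., "score": ...} dicts consumed by the reduce step.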

    def _parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
        """Parse the search response json and return a list of key points.

        Parameters
        ----------
        search_response: str
            The search response json string

        Returns
        -------
        list[dict[str, Any]]
            A list of key points, each key point is a dictionary with "answer" and "score" keys
        """
        search_response, j = try_parse_json_object(search_response)
        if j == {}:
            return [{"answer": "", "score": 0}]

        parsed_elements = json.loads(search_response).get("points")
        if not parsed_elements or not isinstance(parsed_elements, list):
            return [{"answer": "", "score": 0}]

        return [
            {
                "answer": element["description"],
                "score": int(element["score"]),
            }
            for element in parsed_elements
            if "description" in element and "score" in element
        ]
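
    # Illustrative input/output for _parse_search_response:
    #   input:  '{"points": [{"description": "Community X centers on topic Y", "score": 80}]}'
    #   output: [{"answer": "Community X centers on topic Y", "score": 80}]
    # Unparsable JSON or a missing/empty "points" list falls back to [{"answer": "", "score": 0}].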

    async def _reduce_response(
        self,
        map_responses: list[SearchResult],
        query: str,
        **llm_kwargs,
    ) -> SearchResult:
        """Combine all intermediate responses from single batches into a final answer to the user query."""
        text_data = ""
        search_prompt = ""
        start_time = time.time()
        try:
            # collect all key points into a single list to prepare for sorting
            key_points = []
            for index, response in enumerate(map_responses):
                if not isinstance(response.response, list):
                    continue
                for element in response.response:
                    if not isinstance(element, dict):
                        continue
                    if "answer" not in element or "score" not in element:
                        continue
                    key_points.append({
                        "analyst": index,
                        "answer": element["answer"],
                        "score": element["score"],
                    })

            # filter out responses with score = 0 and rank the remaining responses by descending score
            filtered_key_points = [
                point
                for point in key_points
                if point["score"] > 0  # type: ignore
            ]

            if len(filtered_key_points) == 0 and not self.allow_general_knowledge:
                # return no data answer if no key points are found
                log.warning(
                    "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."
                )
                return SearchResult(
                    response=NO_DATA_ANSWER,
                    context_data="",
                    context_text="",
                    completion_time=time.time() - start_time,
                    llm_calls=0,
                    prompt_tokens=0,
                    output_tokens=0,
                )

            filtered_key_points = sorted(
                filtered_key_points,
                key=lambda x: x["score"],  # type: ignore
                reverse=True,  # type: ignore
            )

            data = []
            total_tokens = 0
            for point in filtered_key_points:
                formatted_response_data = []
                formatted_response_data.append(
                    f"----Analyst {point['analyst'] + 1}----"
                )
                formatted_response_data.append(
                    f"Importance Score: {point['score']}"  # type: ignore
                )
                formatted_response_data.append(point["answer"])  # type: ignore
                formatted_response_text = "\n".join(formatted_response_data)
                if (
                    total_tokens
                    + num_tokens(formatted_response_text, self.token_encoder)
                    > self.max_data_tokens
                ):
                    break
                data.append(formatted_response_text)
                total_tokens += num_tokens(formatted_response_text, self.token_encoder)
            text_data = "\n\n".join(data)

            search_prompt = self.reduce_system_prompt.format(
                report_data=text_data, response_type=self.response_type
            )
            if self.allow_general_knowledge:
                search_prompt += "\n" + self.general_knowledge_inclusion_prompt
            search_messages = [
                {"role": "system", "content": search_prompt},
                {"role": "user", "content": query},
            ]

            search_response = await self.llm.agenerate(
                search_messages,
                streaming=True,
                callbacks=self.callbacks,  # type: ignore
                **llm_kwargs,  # type: ignore
            )
            return SearchResult(
                response=search_response,
                context_data=text_data,
                context_text=text_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=num_tokens(search_response, self.token_encoder),
            )

        except Exception:
            log.exception("Exception in reduce_response")
            return SearchResult(
                response="",
                context_data=text_data,
                context_text=text_data,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )
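
    # _stream_reduce_response below mirrors the key-point collection, filtering, sorting, and
    # token-budget packing done in _reduce_response, but yields the reduce answer incrementally
    # instead of returning a SearchResult (and does not wrap the work in a try/except).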

    async def _stream_reduce_response(
        self,
        map_responses: list[SearchResult],
        query: str,
        **llm_kwargs,
    ) -> AsyncGenerator[str, None]:
        """Combine the intermediate map responses and stream the final answer to the user query."""
        # collect all key points into a single list to prepare for sorting
        key_points = []
        for index, response in enumerate(map_responses):
            if not isinstance(response.response, list):
                continue
            for element in response.response:
                if not isinstance(element, dict):
                    continue
                if "answer" not in element or "score" not in element:
                    continue
                key_points.append({
                    "analyst": index,
                    "answer": element["answer"],
                    "score": element["score"],
                })

        # filter out responses with score = 0 and rank the remaining responses by descending score
        filtered_key_points = [
            point
            for point in key_points
            if point["score"] > 0  # type: ignore
        ]

        if len(filtered_key_points) == 0 and not self.allow_general_knowledge:
            # return no data answer if no key points are found
            log.warning(
                "Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."
            )
            yield NO_DATA_ANSWER
            return

        filtered_key_points = sorted(
            filtered_key_points,
            key=lambda x: x["score"],  # type: ignore
            reverse=True,  # type: ignore
        )

        data = []
        total_tokens = 0
        for point in filtered_key_points:
            formatted_response_data = [
                f"----Analyst {point['analyst'] + 1}----",
                f"Importance Score: {point['score']}",
                point["answer"],
            ]
            formatted_response_text = "\n".join(formatted_response_data)
            if (
                total_tokens + num_tokens(formatted_response_text, self.token_encoder)
                > self.max_data_tokens
            ):
                break
            data.append(formatted_response_text)
            total_tokens += num_tokens(formatted_response_text, self.token_encoder)
        text_data = "\n\n".join(data)

        search_prompt = self.reduce_system_prompt.format(
            report_data=text_data, response_type=self.response_type
        )
        if self.allow_general_knowledge:
            search_prompt += "\n" + self.general_knowledge_inclusion_prompt
        search_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]

        async for resp in self.llm.astream_generate(  # type: ignore
            search_messages,
            callbacks=self.callbacks,  # type: ignore
            **llm_kwargs,  # type: ignore
        ):
            yield resp
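

# Construction sketch (hypothetical wiring; `my_llm` and `my_context_builder` stand in for a
# concrete BaseLLM implementation and GlobalContextBuilder instance, which this module does
# not define):
#
#   search_engine = GlobalSearch(
#       llm=my_llm,
#       context_builder=my_context_builder,
#       max_data_tokens=12_000,
#       map_llm_params=dict(DEFAULT_MAP_LLM_PARAMS),
#       reduce_llm_params=dict(DEFAULT_REDUCE_LLM_PARAMS),
#       concurrent_coroutines=32,
#   )
#   result = asyncio.run(search_engine.search("What are the main themes in the dataset?"))
#   print(result.response)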