# graphrag/query/structured_search/local_search/search.py

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LocalSearch implementation."""
import logging
import time
from collections.abc import AsyncGenerator
from typing import Any
import tiktoken
from graphrag.prompts.query.local_search_system_prompt import (
    LOCAL_SEARCH_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
    ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
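
# Default LLM generation parameters applied when the caller does not
# override llm_params.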
DEFAULT_LLM_PARAMS = {
    "max_tokens": 1500,
    "temperature": 0.0,
}

log = logging.getLogger(__name__)


class LocalSearch(BaseSearch[LocalContextBuilder]):
    """Search orchestration for local search mode."""

    def __init__(
        self,
        llm: BaseLLM,
        context_builder: LocalContextBuilder,
        token_encoder: tiktoken.Encoding | None = None,
        system_prompt: str | None = None,
        response_type: str = "multiple paragraphs",
        callbacks: list[BaseLLMCallback] | None = None,
        llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
        context_builder_params: dict | None = None,
    ):
        super().__init__(
            llm=llm,
            context_builder=context_builder,
            token_encoder=token_encoder,
            llm_params=llm_params,
            context_builder_params=context_builder_params or {},
        )
        self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
        self.callbacks = callbacks
        self.response_type = response_type

    async def search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs,
    ) -> SearchResult:
        """Build a local search context that fits a single context window and generate an answer for the user query."""
        start_time = time.time()
        search_prompt = ""
        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
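        # Delegate context construction to the context builder, which selects
        # the most relevant indexed data and trims it to fit the model's
        # context window.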
        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **kwargs,
            **self.context_builder_params,
        )
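        # Record token usage from the context-building stage under its own
        # key so the SearchResult can report per-stage totals.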
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
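        # A drift_query kwarg signals that this call is a local follow-up
        # within DRIFT search; the original global query is then threaded
        # into the prompt's global_query placeholder.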
        try:
            if "drift_query" in kwargs:
                drift_query = kwargs["drift_query"]
                search_prompt = self.system_prompt.format(
                    context_data=context_result.context_chunks,
                    response_type=self.response_type,
                    global_query=drift_query,
                )
            else:
                search_prompt = self.system_prompt.format(
                    context_data=context_result.context_chunks,
                    response_type=self.response_type,
                )
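            # One chat completion: the system message carries the packed
            # context and instructions; the user message is the raw query.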
            search_messages = [
                {"role": "system", "content": search_prompt},
                {"role": "user", "content": query},
            ]
            response = await self.llm.agenerate(
                messages=search_messages,
                streaming=True,
                callbacks=self.callbacks,
                **self.llm_params,
            )
            llm_calls["response"] = 1
            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
            output_tokens["response"] = num_tokens(response, self.token_encoder)
            return SearchResult(
                response=response,
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=sum(llm_calls.values()),
                prompt_tokens=sum(prompt_tokens.values()),
                output_tokens=sum(output_tokens.values()),
                llm_calls_categories=llm_calls,
                prompt_tokens_categories=prompt_tokens,
                output_tokens_categories=output_tokens,
            )
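        # On failure, fall back to an empty response while preserving the
        # built context and timing for diagnostics.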
        except Exception:
            log.exception("Exception in search")
            return SearchResult(
                response="",
                context_data=context_result.context_records,
                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
                output_tokens=0,
            )

    async def stream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator:
        """Build a local search context that fits a single context window and stream the answer to the user query."""
        start_time = time.time()
        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )
        log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
        search_prompt = self.system_prompt.format(
            context_data=context_result.context_chunks, response_type=self.response_type
        )
        search_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]
        # send context records first before sending the reduce response
        yield context_result.context_records
        async for response in self.llm.astream_generate(  # type: ignore
            messages=search_messages,
            callbacks=self.callbacks,
            **self.llm_params,
        ):
            yield response
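

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of driving LocalSearch end to end. `build_llm()` and
# `build_context_builder()` are hypothetical stand-ins for however an
# application constructs a BaseLLM and a LocalContextBuilder from its
# GraphRAG index and model configuration.
#
#     import asyncio
#
#     async def run_local_search() -> None:
#         searcher = LocalSearch(
#             llm=build_llm(),                          # hypothetical factory
#             context_builder=build_context_builder(),  # hypothetical factory
#             response_type="multiple paragraphs",
#         )
#         result = await searcher.search("Who are the key people mentioned?")
#         print(result.response)
#         print(f"LLM calls: {result.llm_calls}, "
#               f"prompt tokens: {result.prompt_tokens}")
#
#     asyncio.run(run_local_search())
# ---------------------------------------------------------------------------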