Files
graphrag-microsoft/graphrag/query/structured_search/base.py
Josh Bradley b8b949f3bb Cleanup query api - remove code duplication (#1690)
* consolidate query api functions and remove code duplication

* refactor and remove more code duplication

* Add semversioner file

* fix basic search

* fix drift search and update base class function names

* update example notebooks
2025-02-13 16:31:08 -05:00

91 lines
2.6 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Base classes for search algos."""
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import pandas as pd
import tiktoken
from graphrag.query.context_builder.builders import (
BasicContextBuilder,
DRIFTContextBuilder,
GlobalContextBuilder,
LocalContextBuilder,
)
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
@dataclass
class SearchResult:
"""A Structured Search Result."""
response: str | dict[str, Any] | list[dict[str, Any]]
context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame]
# actual text strings that are in the context window, built from context_data
context_text: str | list[str] | dict[str, str]
completion_time: float
# total LLM calls and token usage
llm_calls: int
prompt_tokens: int
output_tokens: int
# breakdown of LLM calls and token usage
llm_calls_categories: dict[str, int] | None = None
prompt_tokens_categories: dict[str, int] | None = None
output_tokens_categories: dict[str, int] | None = None
# Type parameter for BaseSearch. This is a *constrained* TypeVar (positional
# constraints, not a ``bound=``): type checkers solve T to exactly one of the
# four context-builder types listed below.
T = TypeVar(
    "T",
    GlobalContextBuilder,
    LocalContextBuilder,
    DRIFTContextBuilder,
    BasicContextBuilder,
)
class BaseSearch(ABC, Generic[T]):
    """Abstract base shared by the structured search engines.

    Stores the LLM, the context builder (one of the types allowed by ``T``),
    an optional token encoder, and two optional parameter dictionaries.
    Concrete subclasses must implement both :meth:`search` and
    :meth:`stream_search`.
    """

    def __init__(
        self,
        llm: BaseLLM,
        context_builder: T,
        token_encoder: tiktoken.Encoding | None = None,
        llm_params: dict[str, Any] | None = None,
        context_builder_params: dict[str, Any] | None = None,
    ):
        """Store the collaborators and normalise the optional parameter dicts.

        Args:
            llm: The language model used by the search.
            context_builder: Builder that assembles the search context.
            token_encoder: Optional tiktoken encoding; kept as given (may be None).
            llm_params: Optional LLM call parameters.
            context_builder_params: Optional context-builder parameters.
        """
        self.llm = llm
        self.context_builder = context_builder
        self.token_encoder = token_encoder
        # A falsy dict argument (None or {}) becomes a fresh empty dict, so the
        # attributes below are never None.
        self.llm_params = llm_params if llm_params else {}
        self.context_builder_params = (
            context_builder_params if context_builder_params else {}
        )

    @abstractmethod
    async def search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs,
    ) -> SearchResult:
        """Answer ``query`` asynchronously and return a :class:`SearchResult`.

        Args:
            query: The question to search for.
            conversation_history: Optional conversation history supplied
                alongside the query.
            **kwargs: Implementation-specific extra options.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        error_message = "Subclasses must implement this method"
        raise NotImplementedError(error_message)

    @abstractmethod
    def stream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
    ) -> AsyncGenerator[Any, None]:
        """Stream the answer for ``query`` as an async generator.

        Declared as a plain ``def`` returning ``AsyncGenerator``. An
        overriding ``async def`` that uses ``yield`` satisfies this signature.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        error_message = "Subclasses must implement this method"
        raise NotImplementedError(error_message)