Feat/llm provider query (#1735)

* Add ModelProvider to Query package.

* Spellcheck + others

* Semver

* Fix tests

* Format

* Fix Pyright

* Fix tests

* Fix for smoke tests
Author: Alonso Guevara
Date: 2025-02-24 18:35:51 -06:00
Committed by: GitHub
Parent: faa05b691f
Commit: e0d233fe10
57 changed files with 919 additions and 1371 deletions

View File

@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Use ModelProvider for query module"
}

View File

@@ -120,10 +120,7 @@ unhot
groupby
retryer
agenerate
aembed
dedupe
dropna
dtypes
notna
# LLM Terms
@@ -131,6 +128,8 @@ AOAI
embedder
llm
llms
achat
aembed
# Galaxy-Brain Terms
Unipartite

View File

@@ -28,3 +28,6 @@ class QueryCallbacks(BaseLLMCallback):
def on_reduce_response_end(self, reduce_response_output: str) -> None:
"""Handle the end of reduce operation."""
def on_llm_new_token(self, token) -> None:
"""Handle when a new token is generated."""

View File

@@ -48,6 +48,8 @@ class BasicSearchDefaults:
n: int = 1
max_tokens: int = 12_000
llm_max_tokens: int = 2000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@dataclass
@@ -122,7 +124,9 @@ class DriftSearchDefaults:
local_search_temperature: float = 0
local_search_top_p: float = 1
local_search_n: int = 1
local_search_llm_max_gen_tokens: int = 12_000
local_search_llm_max_gen_tokens: int = 4_096
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@dataclass
@@ -239,6 +243,8 @@ class GlobalSearchDefaults:
dynamic_search_use_summary: bool = False
dynamic_search_concurrent_coroutines: int = 16
dynamic_search_max_level: int = 2
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@dataclass
@@ -305,6 +311,8 @@ class LocalSearchDefaults:
n: int = 1
max_tokens: int = 12_000
llm_max_tokens: int = 2000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@dataclass

View File

@@ -145,18 +145,26 @@ snapshots:
## See the config docs: https://microsoft.github.io/graphrag/config/yaml/#query
local_search:
chat_model_id: {graphrag_config_defaults.local_search.chat_model_id}
embedding_model_id: {graphrag_config_defaults.local_search.embedding_model_id}
prompt: "prompts/local_search_system_prompt.txt"
global_search:
chat_model_id: {graphrag_config_defaults.global_search.chat_model_id}
embedding_model_id: {graphrag_config_defaults.global_search.embedding_model_id}
map_prompt: "prompts/global_search_map_system_prompt.txt"
reduce_prompt: "prompts/global_search_reduce_system_prompt.txt"
knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt"
drift_search:
chat_model_id: {graphrag_config_defaults.drift_search.chat_model_id}
embedding_model_id: {graphrag_config_defaults.drift_search.embedding_model_id}
prompt: "prompts/drift_search_system_prompt.txt"
reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
basic_search:
chat_model_id: {graphrag_config_defaults.basic_search.chat_model_id}
embedding_model_id: {graphrag_config_defaults.basic_search.embedding_model_id}
prompt: "prompts/basic_search_system_prompt.txt"
"""

View File

@@ -15,6 +15,14 @@ class BasicSearchConfig(BaseModel):
description="The basic search prompt to use.",
default=graphrag_config_defaults.basic_search.prompt,
)
chat_model_id: str = Field(
description="The model ID to use for basic search.",
default=graphrag_config_defaults.basic_search.chat_model_id,
)
embedding_model_id: str = Field(
description="The model ID to use for text embeddings.",
default=graphrag_config_defaults.basic_search.embedding_model_id,
)
text_unit_prop: float = Field(
description="The text unit proportion.",
default=graphrag_config_defaults.basic_search.text_unit_prop,

View File

@@ -19,6 +19,14 @@ class DRIFTSearchConfig(BaseModel):
description="The drift search reduce prompt to use.",
default=graphrag_config_defaults.drift_search.reduce_prompt,
)
chat_model_id: str = Field(
description="The model ID to use for drift search.",
default=graphrag_config_defaults.drift_search.chat_model_id,
)
embedding_model_id: str = Field(
description="The model ID to use for drift search.",
default=graphrag_config_defaults.drift_search.embedding_model_id,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=graphrag_config_defaults.drift_search.temperature,

View File

@@ -19,6 +19,10 @@ class GlobalSearchConfig(BaseModel):
description="The global search reducer to use.",
default=graphrag_config_defaults.global_search.reduce_prompt,
)
chat_model_id: str = Field(
description="The model ID to use for global search.",
default=graphrag_config_defaults.global_search.chat_model_id,
)
knowledge_prompt: str | None = Field(
description="The global search general prompt to use.",
default=graphrag_config_defaults.global_search.knowledge_prompt,

View File

@@ -15,6 +15,14 @@ class LocalSearchConfig(BaseModel):
description="The local search prompt to use.",
default=graphrag_config_defaults.local_search.prompt,
)
chat_model_id: str = Field(
description="The model ID to use for local search.",
default=graphrag_config_defaults.local_search.chat_model_id,
)
embedding_model_id: str = Field(
description="The model ID to use for text embeddings.",
default=graphrag_config_defaults.local_search.embedding_model_id,
)
text_unit_prop: float = Field(
description="The text unit proportion.",
default=graphrag_config_defaults.local_search.text_unit_prop,

View File

@@ -88,7 +88,7 @@ async def _execute(
) -> list[list[float]]:
async def embed(chunk: list[str]):
async with semaphore:
chunk_embeddings = await model.embed(chunk)
chunk_embeddings = await model.aembed_batch(chunk)
result = np.array(chunk_embeddings)
tick(1)
return result

View File

@@ -166,7 +166,7 @@ class ClaimExtractor:
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)
response = await self._model.chat(
response = await self._model.achat(
self._extraction_prompt.format(**{
self._input_text_key: doc,
**prompt_args,
@@ -177,7 +177,7 @@ class ClaimExtractor:
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
response = await self._model.chat(
response = await self._model.achat(
CONTINUE_PROMPT,
name=f"extract-continuation-{i}",
history=response.history,
@@ -191,7 +191,7 @@ class ClaimExtractor:
if i >= self._max_gleanings - 1:
break
response = await self._model.chat(
response = await self._model.achat(
LOOP_PROMPT,
name=f"extract-loopcheck-{i}",
history=response.history,

View File

@@ -152,7 +152,7 @@ class GraphExtractor:
async def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
response = await self._model.chat(
response = await self._model.achat(
self._extraction_prompt.format(**{
**prompt_variables,
self._input_text_key: text,
@@ -162,7 +162,7 @@ class GraphExtractor:
# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
response = await self._model.chat(
response = await self._model.achat(
CONTINUE_PROMPT,
name=f"extract-continuation-{i}",
history=response.history,
@@ -173,7 +173,7 @@ class GraphExtractor:
if i >= self._max_gleanings - 1:
break
response = await self._model.chat(
response = await self._model.achat(
LOOP_PROMPT,
name=f"extract-loopcheck-{i}",
history=response.history,

View File

@@ -78,7 +78,7 @@ class CommunityReportsExtractor:
prompt = self._extraction_prompt.replace(
"{" + self._input_text_key + "}", input_text
)
response = await self._model.chat(
response = await self._model.achat(
prompt,
json=True, # Leaving this as True to avoid creating new cache entries
name="create_community_report",

View File

@@ -125,7 +125,7 @@ class SummarizeExtractor:
self, id: str | tuple[str, str] | list[str], descriptions: list[str]
):
"""Summarize descriptions using the LLM."""
response = await self._model.chat(
response = await self._model.achat(
self._summarization_prompt.format(**{
self._entity_name_key: json.dumps(id, ensure_ascii=False),
self._input_descriptions_key: json.dumps(

View File

@@ -30,7 +30,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
)
try:
asyncio.run(llm.chat("This is an LLM connectivity test. Say Hello World"))
asyncio.run(llm.achat("This is an LLM connectivity test. Say Hello World"))
logger.success("LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
@@ -49,7 +49,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
)
try:
asyncio.run(embed_llm.embed(["This is an LLM Embedding Test String"]))
asyncio.run(embed_llm.aembed_batch(["This is an LLM Embedding Test String"]))
logger.success("Embedding LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa

View File

@@ -8,6 +8,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from graphrag.language_model.response.base import ModelResponse
@@ -18,7 +20,51 @@ class EmbeddingModel(Protocol):
This protocol defines the methods required for an embedding-based LM.
"""
async def embed(self, text: str | list[str], **kwargs: Any) -> list[list[float]]:
async def aembed_batch(
self, text_list: list[str], **kwargs: Any
) -> list[list[float]]:
"""
Generate an embedding vector for the given list of strings.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A collections of list of floats representing the embedding vector for each item in the batch.
"""
...
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate an embedding vector for the given text.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A list of floats representing the embedding vector.
"""
...
def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""
Generate an embedding vector for the given list of strings.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A collections of list of floats representing the embedding vector for each item in the batch.
"""
...
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate an embedding vector for the given text.
@@ -41,12 +87,15 @@ class ChatModel(Protocol):
Prompt is always required for the chat method, and any other keyword arguments are forwarded to the Model provider.
"""
async def chat(self, prompt: str, **kwargs: Any) -> ModelResponse:
async def achat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
"""
Generate a response for the given text.
Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
@@ -55,3 +104,55 @@ class ChatModel(Protocol):
"""
...
async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
"""
Generate a response for the given text using a streaming interface.
Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A generator that yields strings representing the response.
"""
...
def chat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
"""
Generate a response for the given text.
Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
The response from the Model.
"""
...
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
"""
Generate a response for the given text using a streaming interface.
Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A generator that yields strings representing the response.
"""
...
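
As a caller-side illustration, a minimal sketch typed only against the protocols above; the helper function is hypothetical and simply chains the new async methods.

from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel


async def summarize_then_embed(chat: ChatModel, embedder: EmbeddingModel, text: str) -> list[float]:
    """Ask the chat model for a one-sentence summary, then embed the summary text."""
    response = await chat.achat(f"Summarize in one sentence: {text}")
    return await embedder.aembed(str(response.output.content))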

View File

@@ -3,13 +3,15 @@
"""A module containing fnllm model provider definitions."""
from collections.abc import AsyncGenerator
from fnllm.openai import (
create_openai_chat_llm,
create_openai_client,
create_openai_embeddings_llm,
)
from fnllm.types import ChatLLM as FNLLMChatLLM
from fnllm.types import EmbeddingsLLM as FNLLMEmbeddingLLM
from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -21,6 +23,7 @@ from graphrag.language_model.providers.fnllm.utils import (
_create_cache,
_create_error_handler,
_create_openai_config,
run_coroutine_sync,
)
from graphrag.language_model.response.base import (
BaseModelOutput,
@@ -39,21 +42,23 @@ class OpenAIChatFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, False)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=False)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_chat_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)
async def chat(self, prompt: str, **kwargs) -> ModelResponse:
async def achat(
self, prompt: str, history: list | None = None, **kwargs
) -> ModelResponse:
"""
Chat with the Model using the given prompt.
@@ -65,7 +70,10 @@ class OpenAIChatFNLLM:
-------
The response from the Model.
"""
response = await self.model(prompt, **kwargs)
if history is None:
response = await self.model(prompt, **kwargs)
else:
response = await self.model(prompt, history=history, **kwargs)
return BaseModelResponse(
output=BaseModelOutput(content=response.output.content),
parsed_response=response.parsed_json,
@@ -75,6 +83,59 @@ class OpenAIChatFNLLM:
metrics=response.metrics,
)
async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.
Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.
Returns
-------
A generator that yields strings representing the response.
"""
if history is None:
response = await self.model(prompt, stream=True, **kwargs)
else:
response = await self.model(prompt, history=history, stream=True, **kwargs)
async for chunk in response.output.content:
if chunk is not None:
yield chunk
def chat(self, prompt: str, history: list | None = None, **kwargs) -> ModelResponse:
"""
Chat with the Model using the given prompt.
Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The response from the Model.
"""
return run_coroutine_sync(self.achat(prompt, history=history, **kwargs))
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.
Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.
Returns
-------
A generator that yields strings representing the response.
"""
msg = "chat_stream is not supported for synchronous execution"
raise NotImplementedError(msg)
class OpenAIEmbeddingFNLLM:
"""An OpenAI Embedding Model provider using the fnllm library."""
@@ -86,21 +147,21 @@ class OpenAIEmbeddingFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, False)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=False)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_embeddings_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)
async def embed(self, text: str | list[str], **kwargs) -> list[list[float]]:
async def aembed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.
@@ -112,13 +173,60 @@ class OpenAIEmbeddingFNLLM:
-------
The embeddings of the text.
"""
response = await self.model(text, **kwargs)
response = await self.model(text_list, **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[list[float]] = response.output.embeddings
return embeddings
async def aembed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.
Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The embeddings of the text.
"""
response = await self.model([text], **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[float] = response.output.embeddings[0]
return embeddings
def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.
Args:
text: The text to embed.
kwargs: Additional arguments to pass to the LLM.
Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed_batch(text_list, **kwargs))
def embed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.
Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed(text, **kwargs))
class AzureOpenAIChatFNLLM:
"""An Azure OpenAI Chat LLM provider using the fnllm library."""
@@ -130,21 +238,72 @@ class AzureOpenAIChatFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, True)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=True)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_chat_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)
async def chat(self, prompt: str, **kwargs) -> ModelResponse:
async def achat(
self, prompt: str, history: list | None = None, **kwargs
) -> ModelResponse:
"""
Chat with the Model using the given prompt.
Args:
prompt: The prompt to chat with.
history: The conversation history.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The response from the Model.
"""
if history is None:
response = await self.model(prompt, **kwargs)
else:
response = await self.model(prompt, history=history, **kwargs)
return BaseModelResponse(
output=BaseModelOutput(content=response.output.content),
parsed_response=response.parsed_json,
history=response.history,
cache_hit=response.cache_hit,
tool_calls=response.tool_calls,
metrics=response.metrics,
)
async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.
Args:
prompt: The prompt to chat with.
history: The conversation history.
kwargs: Additional arguments to pass to the Model.
Returns
-------
A generator that yields strings representing the response.
"""
if history is None:
response = await self.model(prompt, stream=True, **kwargs)
else:
response = await self.model(prompt, history=history, stream=True, **kwargs)
async for chunk in response.output.content:
if chunk is not None:
yield chunk
def chat(self, prompt: str, history: list | None = None, **kwargs) -> ModelResponse:
"""
Chat with the Model using the given prompt.
@@ -156,15 +315,24 @@ class AzureOpenAIChatFNLLM:
-------
The response from the Model.
"""
response = await self.model(prompt, **kwargs)
return BaseModelResponse(
output=BaseModelOutput(content=response.output.content),
parsed_response=response.parsed_json,
history=response.history,
cache_hit=response.cache_hit,
tool_calls=response.tool_calls,
metrics=response.metrics,
)
return run_coroutine_sync(self.achat(prompt, history=history, **kwargs))
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.
Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.
Returns
-------
A generator that yields strings representing the response.
"""
msg = "chat_stream is not supported for synchronous execution"
raise NotImplementedError(msg)
class AzureOpenAIEmbeddingFNLLM:
@@ -177,21 +345,21 @@ class AzureOpenAIEmbeddingFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, True)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=True)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_embeddings_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)
async def embed(self, text: str | list[str], **kwargs) -> list[list[float]]:
async def aembed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.
@@ -203,9 +371,56 @@ class AzureOpenAIEmbeddingFNLLM:
-------
The embeddings of the text.
"""
response = await self.model(text, **kwargs)
response = await self.model(text_list, **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[list[float]] = response.output.embeddings
return embeddings
async def aembed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.
Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The embeddings of the text.
"""
response = await self.model([text], **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[float] = response.output.embeddings[0]
return embeddings
def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.
Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed_batch(text_list, **kwargs))
def embed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.
Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.
Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed(text, **kwargs))
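
For completeness, a hedged sketch of consuming the new streaming variant; the helper below is illustrative and assumes only the achat_stream shape shown in this file (an async generator yielding string chunks).

from graphrag.language_model.protocol.base import ChatModel


async def stream_to_stdout(model: ChatModel, prompt: str) -> str:
    """Collect and echo chunks yielded by achat_stream."""
    collected: list[str] = []
    async for chunk in model.achat_stream(prompt):
        collected.append(chunk)
        print(chunk, end="", flush=True)
    return "".join(collected)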

View File

@@ -3,6 +3,11 @@
"""A module containing utils for fnllm."""
import asyncio
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar
from fnllm.base.config import JsonStrategy, RetryStrategy
from fnllm.openai import AzureOpenAIConfig, OpenAIConfig, PublicOpenAIConfig
from fnllm.openai.types.chat.parameters import OpenAIChatParameters
@@ -25,6 +30,8 @@ def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider
def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
"""Create an error handler from a WorkflowCallbacks."""
def on_error(
error: BaseException | None = None,
stack: str | None = None,
@@ -36,6 +43,7 @@ def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAIConfig:
"""Create an OpenAIConfig from a LanguageModelConfig."""
encoding_model = config.encoding_model
json_strategy = (
JsonStrategy.VALID if config.model_supports_json else JsonStrategy.LOOSE
@@ -92,3 +100,28 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
chat_parameters=chat_parameters,
sleep_on_rate_limit_recommendation=True,
)
# FNLLM does not support sync operations, so we workaround running in an available loop/thread.
T = TypeVar("T")
_loop = asyncio.new_event_loop()
_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)
def run_coroutine_sync(coroutine: Coroutine[Any, Any, T]) -> T:
"""
Run a coroutine synchronously.
Args:
coroutine: The coroutine to run.
Returns
-------
The result of the coroutine.
"""
if not _thr.is_alive():
_thr.start()
future = asyncio.run_coroutine_threadsafe(coroutine, _loop)
return future.result()
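
A small usage sketch for run_coroutine_sync; the coroutine below is a stand-in for a real fnllm call and is purely illustrative.

import asyncio

from graphrag.language_model.providers.fnllm.utils import run_coroutine_sync


async def add_later(a: int, b: int) -> int:
    await asyncio.sleep(0)  # stand-in for real async work (e.g., an fnllm request)
    return a + b


# Blocks the calling thread while the coroutine runs on the shared background loop.
print(run_coroutine_sync(add_later(2, 3)))  # 5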

View File

@@ -30,6 +30,6 @@ async def generate_community_report_rating(
domain=domain, persona=persona, input_text=docs_str
)
response = await model.chat(domain_prompt)
response = await model.achat(domain_prompt)
return str(response.output.content).strip()

View File

@@ -30,6 +30,6 @@ async def generate_community_reporter_role(
domain=domain, persona=persona, input_text=docs_str
)
response = await model.chat(domain_prompt)
response = await model.achat(domain_prompt)
return str(response.output.content)

View File

@@ -22,6 +22,6 @@ async def generate_domain(model: ChatModel, docs: str | list[str]) -> str:
docs_str = " ".join(docs) if isinstance(docs, list) else docs
domain_prompt = GENERATE_DOMAIN_PROMPT.format(input_text=docs_str)
response = await model.chat(domain_prompt)
response = await model.achat(domain_prompt)
return str(response.output.content)

View File

@@ -57,7 +57,7 @@ async def generate_entity_relationship_examples(
messages = messages[:MAX_EXAMPLES]
tasks = [
model.chat(message, history=history, json=json_mode) for message in messages
model.achat(message, history=history, json=json_mode) for message in messages
]
responses = await asyncio.gather(*tasks)

View File

@@ -46,11 +46,11 @@ async def generate_entity_types(
history = [{"role": "system", "content": persona}]
if json_mode:
response = await model.chat(
response = await model.achat(
entity_types_prompt, history=history, json_model=EntityTypesResponse
)
parsed_model = response.parsed_response
return parsed_model.entity_types if parsed_model else []
response = await model.chat(entity_types_prompt, history=history, json=json_mode)
response = await model.achat(entity_types_prompt, history=history, json=json_mode)
return str(response.output.content)

View File

@@ -22,6 +22,6 @@ async def detect_language(model: ChatModel, docs: str | list[str]) -> str:
docs_str = " ".join(docs) if isinstance(docs, list) else docs
language_prompt = DETECT_LANGUAGE_PROMPT.format(input_text=docs_str)
response = await model.chat(language_prompt)
response = await model.achat(language_prompt)
return str(response.output.content)

View File

@@ -22,6 +22,6 @@ async def generate_persona(
formatted_task = task.format(domain=domain)
persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task)
response = await model.chat(persona_prompt)
response = await model.achat(persona_prompt)
return str(response.output.content)

View File

@@ -30,7 +30,9 @@ async def _embed_chunks(
) -> tuple[pd.DataFrame, np.ndarray]:
"""Convert text chunks into dense text embeddings."""
sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
embeddings = await embedding_llm.embed(sampled_text_chunks["chunks"].tolist())
embeddings = await embedding_llm.aembed_batch(
sampled_text_chunks["chunks"].tolist()
)
return text_chunks, np.array(embeddings)

View File

@@ -54,7 +54,7 @@ class DRIFTContextBuilder(ABC):
"""Base class for DRIFT-search context builders."""
@abstractmethod
def build_context(
async def build_context(
self,
query: str,
**kwargs,

View File

@@ -14,9 +14,9 @@ import tiktoken
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
from graphrag.query.llm.base import BaseLLM
log = logging.getLogger(__name__)
@@ -33,7 +33,7 @@ class DynamicCommunitySelection:
self,
community_reports: list[CommunityReport],
communities: list[Community],
llm: BaseLLM,
model: ChatModel,
token_encoder: tiktoken.Encoding,
rate_query: str = RATE_QUERY,
use_summary: bool = False,
@@ -44,7 +44,7 @@ class DynamicCommunitySelection:
concurrent_coroutines: int = 8,
llm_kwargs: Any = DEFAULT_RATE_LLM_PARAMS,
):
self.llm = llm
self.model = model
self.token_encoder = token_encoder
self.rate_query = rate_query
self.num_repeats = num_repeats
@@ -98,7 +98,7 @@ class DynamicCommunitySelection:
if self.use_summary
else self.reports[community].full_content
),
llm=self.llm,
model=self.model,
token_encoder=self.token_encoder,
rate_query=self.rate_query,
num_repeats=self.num_repeats,
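
To show the llm= to model= rename from the caller's side, a minimal construction sketch; the module path, the report/community lists, and the encoding name are assumptions for illustration only.

import tiktoken

from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.dynamic_community_selection import (  # assumed path
    DynamicCommunitySelection,
)


def build_selector(model: ChatModel, reports: list, communities: list) -> DynamicCommunitySelection:
    # model replaces the former llm argument and is forwarded to rate_relevancy internally.
    return DynamicCommunitySelection(
        community_reports=reports,
        communities=communities,
        model=model,
        token_encoder=tiktoken.get_encoding("cl100k_base"),
    )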

View File

@@ -7,12 +7,12 @@ from enum import Enum
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.input.retrieval.entities import (
get_entity_by_id,
get_entity_by_key,
get_entity_by_name,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores.base import BaseVectorStore
@@ -37,7 +37,7 @@ class EntityVectorStoreKey(str, Enum):
def map_query_to_entities(
query: str,
text_embedding_vectorstore: BaseVectorStore,
text_embedder: BaseTextEmbedding,
text_embedder: EmbeddingModel,
all_entities_dict: dict[str, Entity],
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
include_entity_names: list[str] | None = None,

View File

@@ -11,8 +11,8 @@ from typing import Any
import numpy as np
import tiktoken
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
async def rate_relevancy(
query: str,
description: str,
llm: BaseLLM,
model: ChatModel,
token_encoder: tiktoken.Encoding,
rate_query: str = RATE_QUERY,
num_repeats: int = 1,
@@ -47,11 +47,13 @@ async def rate_relevancy(
"role": "system",
"content": rate_query.format(description=description, question=query),
},
{"role": "user", "content": query},
]
for _ in range(num_repeats):
async with semaphore if semaphore is not None else nullcontext():
response = await llm.agenerate(messages=messages, **llm_kwargs)
model_response = await model.achat(
prompt=query, history=messages, model_parameters=llm_kwargs, json=True
)
response = model_response.output.content
try:
_, parsed_response = try_parse_json_object(response)
ratings.append(parsed_response["rating"])

View File

@@ -13,8 +13,8 @@ from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.manager import ModelManager
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.get_client import get_llm, get_text_embedder
from graphrag.query.structured_search.basic_search.basic_context import (
BasicSearchContext,
)
@@ -47,15 +47,38 @@ def get_local_search_engine(
callbacks: list[QueryCallbacks] | None = None,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
model_settings = config.get_language_model_config(config.local_search.chat_model_id)
if model_settings.max_retries == -1:
model_settings.max_retries = (
len(reports) + len(entities) + len(relationships) + len(covariates)
)
chat_model = ModelManager().get_or_create_chat_model(
name="local_search_chat",
model_type=model_settings.type,
config=model_settings,
)
embedding_settings = config.get_language_model_config(
config.local_search.embedding_model_id
)
if embedding_settings.max_retries == -1:
embedding_settings.max_retries = (
len(reports) + len(entities) + len(relationships)
)
embedding_model = ModelManager().get_or_create_embedding_model(
name="local_search_embedding",
model_type=embedding_settings.type,
config=embedding_settings,
)
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
ls_config = config.local_search
return LocalSearch(
llm=llm,
model=chat_model,
system_prompt=system_prompt,
context_builder=LocalSearchMixedContext(
community_reports=reports,
@@ -65,11 +88,11 @@ def get_local_search_engine(
covariates=covariates,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
text_embedder=text_embedder,
text_embedder=embedding_model,
token_encoder=token_encoder,
),
token_encoder=token_encoder,
llm_params={
model_params={
"max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
"temperature": ls_config.temperature,
"top_p": ls_config.top_p,
@@ -108,10 +131,20 @@ def get_global_search_engine(
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
# TODO: Global search should select model based on config??
default_llm_settings = config.get_language_model_config("default_chat_model")
model_settings = config.get_language_model_config(
config.global_search.chat_model_id
)
if model_settings.max_retries == -1:
model_settings.max_retries = len(reports) + len(entities)
model = ModelManager().get_or_create_chat_model(
name="global_search",
model_type=model_settings.type,
config=model_settings,
)
# Here we get encoding based on specified encoding name
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
gs_config = config.global_search
dynamic_community_selection_kwargs = {}
@@ -119,9 +152,9 @@ def get_global_search_engine(
# TODO: Allow for another llm definition only for Global Search to leverage -mini models
dynamic_community_selection_kwargs.update({
"llm": get_llm(config),
"model": model,
# And here we get encoding based on model
"token_encoder": tiktoken.encoding_for_model(default_llm_settings.model),
"token_encoder": tiktoken.encoding_for_model(model_settings.model),
"keep_parent": gs_config.dynamic_search_keep_parent,
"num_repeats": gs_config.dynamic_search_num_repeats,
"use_summary": gs_config.dynamic_search_use_summary,
@@ -131,7 +164,7 @@ def get_global_search_engine(
})
return GlobalSearch(
llm=get_llm(config),
model=model,
map_system_prompt=map_system_prompt,
reduce_system_prompt=reduce_system_prompt,
general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt,
@@ -190,16 +223,43 @@ def get_drift_search_engine(
callbacks: list[QueryCallbacks] | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
chat_model_settings = config.get_language_model_config(
config.drift_search.chat_model_id
)
if chat_model_settings.max_retries == -1:
chat_model_settings.max_retries = (
config.drift_search.drift_k_followups
* config.drift_search.primer_folds
* config.drift_search.n_depth
)
chat_model = ModelManager().get_or_create_chat_model(
name="drift_search_chat",
model_type=chat_model_settings.type,
config=chat_model_settings,
)
embedding_model_settings = config.get_language_model_config(
config.drift_search.embedding_model_id
)
if embedding_model_settings.max_retries == -1:
embedding_model_settings.max_retries = (
len(reports) + len(entities) + len(relationships)
)
embedding_model = ModelManager().get_or_create_embedding_model(
name="drift_search_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)
return DRIFTSearch(
llm=llm,
model=chat_model,
context_builder=DRIFTSearchContextBuilder(
chat_llm=llm,
text_embedder=text_embedder,
model=chat_model,
text_embedder=embedding_model,
entities=entities,
relationships=relationships,
reports=reports,
@@ -223,18 +283,40 @@ def get_basic_search_engine(
callbacks: list[QueryCallbacks] | None = None,
) -> BasicSearch:
"""Create a basic search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
chat_model_settings = config.get_language_model_config(
config.basic_search.chat_model_id
)
if chat_model_settings.max_retries == -1:
chat_model_settings.max_retries = len(text_units)
chat_model = ModelManager().get_or_create_chat_model(
name="basic_search_chat",
model_type=chat_model_settings.type,
config=chat_model_settings,
)
embedding_model_settings = config.get_language_model_config(
config.basic_search.embedding_model_id
)
if embedding_model_settings.max_retries == -1:
embedding_model_settings.max_retries = len(text_units)
embedding_model = ModelManager().get_or_create_embedding_model(
name="basic_search_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)
ls_config = config.basic_search
return BasicSearch(
llm=llm,
model=chat_model,
system_prompt=system_prompt,
context_builder=BasicSearchContext(
text_embedder=text_embedder,
text_embedder=embedding_model,
text_unit_embeddings=text_unit_embeddings,
text_units=text_units,
token_encoder=token_encoder,
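
Summarizing the pattern above, a hedged sketch of how the factory now resolves per-search-type models through ModelManager using the new chat_model_id and embedding_model_id settings; the helper name is hypothetical and config is an already-loaded GraphRagConfig.

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager


def resolve_local_search_models(config: GraphRagConfig):
    chat_settings = config.get_language_model_config(config.local_search.chat_model_id)
    embedding_settings = config.get_language_model_config(
        config.local_search.embedding_model_id
    )
    chat_model = ModelManager().get_or_create_chat_model(
        name="local_search_chat", model_type=chat_settings.type, config=chat_settings
    )
    embedding_model = ModelManager().get_or_create_embedding_model(
        name="local_search_embedding",
        model_type=embedding_settings.type,
        config=embedding_settings,
    )
    return chat_model, embedding_model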

View File

@@ -18,7 +18,8 @@ from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.query.factory import get_text_embedder
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.input.loaders.dfs import (
read_communities,
read_community_reports,
@@ -27,7 +28,6 @@ from graphrag.query.input.loaders.dfs import (
read_relationships,
read_text_units,
)
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.vector_stores.base import BaseVectorStore
log = logging.getLogger(__name__)
@@ -106,7 +106,15 @@ def read_indexer_reports(
content_embedding_col not in reports_df.columns
or reports_df.loc[:, content_embedding_col].isna().any()
):
embedder = get_text_embedder(config)
# TODO: Find a way to retrieve the right embedding model id.
embedding_model_settings = config.get_language_model_config(
"default_embedding_model"
)
embedder = ModelManager().get_or_create_embedding_model(
name="default_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
reports_df = embed_community_reports(
reports_df, embedder, embedding_col=content_embedding_col
)
@@ -211,7 +219,7 @@ def read_indexer_communities(
def embed_community_reports(
reports_df: pd.DataFrame,
embedder: OpenAIEmbedding,
embedder: EmbeddingModel,
source_col: str = "full_content",
embedding_col: str = "full_content_embedding",
) -> pd.DataFrame:

View File

@@ -1,65 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Base classes for LLM and Embedding models."""
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from typing import Any
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
class BaseLLM(ABC):
"""The Base LLM implementation."""
@abstractmethod
def generate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate a response."""
@abstractmethod
def stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
"""Generate a response with streaming."""
@abstractmethod
async def agenerate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate a response asynchronously."""
@abstractmethod
async def astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
"""Generate a response asynchronously with streaming."""
...
class BaseTextEmbedding(ABC):
"""The text embedding interface."""
@abstractmethod
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""Embed a text string."""
@abstractmethod
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""Embed a text string asynchronously."""

View File

@@ -1,82 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Initialize LLM and Embedding clients."""
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from graphrag.config.defaults import language_model_defaults
from graphrag.config.enums import AuthType, ModelType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
"""Get the LLM client."""
llm_config = config.get_language_model_config("default_chat_model")
is_azure_client = llm_config.type == ModelType.AzureOpenAIChat
debug_llm_key = llm_config.api_key or ""
llm_debug_info = {
**llm_config.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
audience = (
llm_config.audience
if llm_config.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=llm_config.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and llm_config.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=llm_config.api_base,
organization=llm_config.organization,
model=llm_config.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=llm_config.deployment_name,
api_version=llm_config.api_version,
max_retries=llm_config.max_retries
if llm_config.max_retries != -1
else language_model_defaults.max_retries,
request_timeout=llm_config.request_timeout,
)
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
"""Get the LLM client for embeddings."""
embeddings_llm_config = config.get_language_model_config(config.embed_text.model_id)
is_azure_client = embeddings_llm_config.type == ModelType.AzureOpenAIEmbedding
debug_embedding_api_key = embeddings_llm_config.api_key or ""
llm_debug_info = {
**embeddings_llm_config.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if embeddings_llm_config.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
audience = embeddings_llm_config.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=embeddings_llm_config.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client
and embeddings_llm_config.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=embeddings_llm_config.api_base,
organization=embeddings_llm_config.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=embeddings_llm_config.model,
deployment_name=embeddings_llm_config.deployment_name,
api_version=embeddings_llm_config.api_version,
max_retries=embeddings_llm_config.max_retries
if embeddings_llm_config.max_retries != -1
else language_model_defaults.max_retries,
)

View File

@@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""GraphRAG Orchestration OpenAI Wrappers."""

View File

@@ -1,188 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Base classes for LLM and Embedding models."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from graphrag.logger.base import StatusLogger
from graphrag.logger.console import ConsoleReporter
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
class BaseOpenAILLM(ABC):
"""The Base OpenAI LLM implementation."""
_async_client: AsyncOpenAI | AsyncAzureOpenAI
_sync_client: OpenAI | AzureOpenAI
def __init__(self):
self._create_openai_client()
@abstractmethod
def _create_openai_client(self):
"""Create a new synchronous and asynchronous OpenAI client instance."""
def set_clients(
self,
sync_client: OpenAI | AzureOpenAI,
async_client: AsyncOpenAI | AsyncAzureOpenAI,
):
"""
Set the synchronous and asynchronous clients used for making API requests.
Args:
sync_client (OpenAI | AzureOpenAI): The sync client object.
async_client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
"""
self._sync_client = sync_client
self._async_client = async_client
@property
def async_client(self) -> AsyncOpenAI | AsyncAzureOpenAI | None:
"""
Get the asynchronous client used for making API requests.
Returns
-------
AsyncOpenAI | AsyncAzureOpenAI: The async client object.
"""
return self._async_client
@property
def sync_client(self) -> OpenAI | AzureOpenAI | None:
"""
Get the synchronous client used for making API requests.
Returns
-------
AsyncOpenAI | AsyncAzureOpenAI: The async client object.
"""
return self._sync_client
@async_client.setter
def async_client(self, client: AsyncOpenAI | AsyncAzureOpenAI):
"""
Set the asynchronous client used for making API requests.
Args:
client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
"""
self._async_client = client
@sync_client.setter
def sync_client(self, client: OpenAI | AzureOpenAI):
"""
Set the synchronous client used for making API requests.
Args:
client (OpenAI | AzureOpenAI): The sync client object.
"""
self._sync_client = client
class OpenAILLMImpl(BaseOpenAILLM):
"""Orchestration OpenAI LLM Implementation."""
_reporter: StatusLogger = ConsoleReporter()
def __init__(
self,
api_key: str | None = None,
azure_ad_token_provider: Callable | None = None,
deployment_name: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
organization: str | None = None,
max_retries: int = 10,
request_timeout: float = 180.0,
logger: StatusLogger | None = None,
):
self.api_key = api_key
self.azure_ad_token_provider = azure_ad_token_provider
self.deployment_name = deployment_name
self.api_base = api_base
self.api_version = api_version
self.api_type = api_type
self.organization = organization
self.max_retries = max_retries
self.request_timeout = request_timeout
self.logger = logger or ConsoleReporter()
try:
# Create OpenAI sync and async clients
super().__init__()
except Exception as e:
self._reporter.error(
message="Failed to create OpenAI client",
details={self.__class__.__name__: str(e)},
)
raise
def _create_openai_client(self):
"""Create a new OpenAI client instance."""
if self.api_type == OpenaiApiType.AzureOpenAI:
if self.api_base is None:
msg = "api_base is required for Azure OpenAI"
raise ValueError(msg)
sync_client = AzureOpenAI(
api_key=self.api_key,
azure_ad_token_provider=self.azure_ad_token_provider,
organization=self.organization,
# Azure-Specifics
api_version=self.api_version,
azure_endpoint=self.api_base,
azure_deployment=self.deployment_name,
# Retry Configuration
timeout=self.request_timeout,
max_retries=self.max_retries,
)
async_client = AsyncAzureOpenAI(
api_key=self.api_key,
azure_ad_token_provider=self.azure_ad_token_provider,
organization=self.organization,
# Azure-Specifics
api_version=self.api_version,
azure_endpoint=self.api_base,
azure_deployment=self.deployment_name,
# Retry Configuration
timeout=self.request_timeout,
max_retries=self.max_retries,
)
self.set_clients(sync_client=sync_client, async_client=async_client)
else:
sync_client = OpenAI(
api_key=self.api_key,
base_url=self.api_base,
organization=self.organization,
# Retry Configuration
timeout=self.request_timeout,
max_retries=self.max_retries,
)
async_client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.api_base,
organization=self.organization,
# Retry Configuration
timeout=self.request_timeout,
max_retries=self.max_retries,
)
self.set_clients(sync_client=sync_client, async_client=async_client)
class OpenAITextEmbeddingImpl(BaseTextEmbedding):
"""Orchestration OpenAI Text Embedding Implementation."""
_reporter: StatusLogger | None = None
def _create_openai_client(self, api_type: OpenaiApiType):
"""Create a new synchronous and asynchronous OpenAI client instance."""

View File

@@ -1,329 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Chat-based OpenAI LLM implementation."""
from collections.abc import AsyncGenerator, Callable, Generator
from typing import Any
from tenacity import (
AsyncRetrying,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from graphrag.logger.base import StatusLogger
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
OPENAI_RETRY_ERROR_TYPES,
OpenaiApiType,
)
_MODEL_REQUIRED_MSG = "model is required"
class ChatOpenAI(BaseLLM, OpenAILLMImpl):
"""Wrapper for OpenAI ChatCompletion models."""
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
azure_ad_token_provider: Callable | None = None,
deployment_name: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
organization: str | None = None,
max_retries: int = 10,
request_timeout: float = 180.0,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
logger: StatusLogger | None = None,
):
OpenAILLMImpl.__init__(
self=self,
api_key=api_key,
azure_ad_token_provider=azure_ad_token_provider,
deployment_name=deployment_name,
api_base=api_base,
api_version=api_version,
api_type=api_type, # type: ignore
organization=organization,
max_retries=max_retries,
request_timeout=request_timeout,
logger=logger,
)
self.model = model
self.retry_error_types = retry_error_types
def generate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate text."""
try:
retryer = Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
for attempt in retryer:
with attempt:
return self._generate(
messages=messages,
streaming=streaming,
callbacks=callbacks,
**kwargs,
)
except RetryError as e:
self._reporter.error(
message="Error at generate()", details={self.__class__.__name__: str(e)}
)
return ""
else:
# TODO: why not just throw in this case?
return ""
def stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
"""Generate text with streaming."""
try:
retryer = Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
for attempt in retryer:
with attempt:
generator = self._stream_generate(
messages=messages,
callbacks=callbacks,
**kwargs,
)
yield from generator
except RetryError as e:
self._reporter.error(
message="Error at stream_generate()",
details={self.__class__.__name__: str(e)},
)
return
else:
return
async def agenerate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate text asynchronously."""
try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types), # type: ignore
)
async for attempt in retryer:
with attempt:
return await self._agenerate(
messages=messages,
streaming=streaming,
callbacks=callbacks,
**kwargs,
)
except RetryError as e:
self._reporter.error(f"Error at agenerate(): {e}")
return ""
else:
# TODO: why not just throw in this case?
return ""
async def astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
"""Generate text asynchronously with streaming."""
try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types), # type: ignore
)
async for attempt in retryer:
with attempt:
generator = self._astream_generate(
messages=messages,
callbacks=callbacks,
**kwargs,
)
async for response in generator:
yield response
except RetryError as e:
self._reporter.error(f"Error at astream_generate(): {e}")
return
else:
return
def _generate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
model = self.model
if not model:
raise ValueError(_MODEL_REQUIRED_MSG)
response = self.sync_client.chat.completions.create( # type: ignore
model=model,
messages=messages, # type: ignore
stream=streaming,
**kwargs,
) # type: ignore
if streaming:
full_response = ""
while True:
try:
chunk = response.__next__() # type: ignore
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
) # type: ignore
full_response += delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
if chunk.choices[0].finish_reason == "stop": # type: ignore
break
except StopIteration:
break
return full_response
return response.choices[0].message.content or "" # type: ignore
def _stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
model = self.model
if not model:
raise ValueError(_MODEL_REQUIRED_MSG)
response = self.sync_client.chat.completions.create( # type: ignore
model=model,
messages=messages, # type: ignore
stream=True,
**kwargs,
)
for chunk in response:
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
)
yield delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
async def _agenerate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
model = self.model
if not model:
raise ValueError(_MODEL_REQUIRED_MSG)
response = await self.async_client.chat.completions.create( # type: ignore
model=model,
messages=messages, # type: ignore
stream=streaming,
**kwargs,
)
if streaming:
full_response = ""
while True:
try:
chunk = await response.__anext__() # type: ignore
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
) # type: ignore
full_response += delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
if chunk.choices[0].finish_reason == "stop": # type: ignore
break
except StopAsyncIteration:
break
return full_response
return response.choices[0].message.content or "" # type: ignore
async def _astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
model = self.model
if not model:
raise ValueError(_MODEL_REQUIRED_MSG)
response = await self.async_client.chat.completions.create( # type: ignore
model=model,
messages=messages, # type: ignore
stream=True,
**kwargs,
)
async for chunk in response:
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
) # type: ignore
yield delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
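The generate/agenerate/astream_generate wrappers above all share the same retry scaffolding: tenacity's Retrying/AsyncRetrying with a capped exponential backoff and a whitelist of retryable exception types. A minimal sketch of that scaffolding in isolation, assuming a hypothetical zero-argument coroutine call() and a caller-supplied tuple of retryable exception types:

from collections.abc import Awaitable, Callable

from tenacity import (
    AsyncRetrying,
    RetryError,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)


async def call_with_retry(
    call: Callable[[], Awaitable[str]],
    retry_error_types: tuple[type[BaseException], ...],
    max_retries: int = 10,
) -> str:
    """Retry an async call with capped exponential backoff, mirroring the wrappers above."""
    try:
        retryer = AsyncRetrying(
            stop=stop_after_attempt(max_retries),  # give up after max_retries attempts
            wait=wait_exponential_jitter(max=10),  # exponential backoff capped at ~10s
            reraise=True,
            retry=retry_if_exception_type(retry_error_types),
        )
        async for attempt in retryer:
            with attempt:
                return await call()
    except RetryError:
        # The wrappers above log the failure and fall back to an empty string.
        return ""
    return ""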

View File

@@ -1,183 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""OpenAI Embedding model implementation."""
import asyncio
from collections.abc import Callable
from typing import Any
import numpy as np
import tiktoken
from tenacity import (
AsyncRetrying,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
import graphrag.config.defaults as defs
from graphrag.logger.base import StatusLogger
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
OPENAI_RETRY_ERROR_TYPES,
OpenaiApiType,
)
from graphrag.query.llm.text_utils import chunk_text
class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
"""Wrapper for OpenAI Embedding models."""
def __init__(
self,
api_key: str | None = None,
azure_ad_token_provider: Callable | None = None,
model: str = "text-embedding-3-small",
deployment_name: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
organization: str | None = None,
encoding_name: str = defs.ENCODING_MODEL,
max_tokens: int = 8191,
max_retries: int = 10,
request_timeout: float = 180.0,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
logger: StatusLogger | None = None,
):
OpenAILLMImpl.__init__(
self=self,
api_key=api_key,
azure_ad_token_provider=azure_ad_token_provider,
deployment_name=deployment_name,
api_base=api_base,
api_version=api_version,
api_type=api_type, # type: ignore
organization=organization,
max_retries=max_retries,
request_timeout=request_timeout,
logger=logger,
)
self.model = model
self.encoding_name = encoding_name
self.max_tokens = max_tokens
self.token_encoder = tiktoken.get_encoding(self.encoding_name)
self.retry_error_types = retry_error_types
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""
Embed text using OpenAI Embedding's sync function.
For text longer than max_tokens, split it into chunks of at most max_tokens tokens, embed each chunk, then combine the chunk embeddings using a length-weighted average.
Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
"""
token_chunks = chunk_text(
text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
)
chunk_embeddings = []
chunk_lens = []
for chunk in token_chunks:
try:
embedding, chunk_len = self._embed_with_retry(chunk, **kwargs)
chunk_embeddings.append(embedding)
chunk_lens.append(chunk_len)
# TODO: catch a more specific exception
except Exception as e: # noqa BLE001
self._reporter.error(
message="Error embedding chunk",
details={self.__class__.__name__: str(e)},
)
continue
chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
return chunk_embeddings.tolist()
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""
Embed text using OpenAI Embedding's async function.
For text longer than max_tokens, split it into chunks of at most max_tokens tokens, embed each chunk, then combine the chunk embeddings using a length-weighted average.
"""
token_chunks = chunk_text(
text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
)
chunk_embeddings = []
chunk_lens = []
embedding_results = await asyncio.gather(*[
self._aembed_with_retry(chunk, **kwargs) for chunk in token_chunks
])
embedding_results = [result for result in embedding_results if result[0]]
chunk_embeddings = [result[0] for result in embedding_results]
chunk_lens = [result[1] for result in embedding_results]
chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) # type: ignore
chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
return chunk_embeddings.tolist()
def _embed_with_retry(
self, text: str | tuple, **kwargs: Any
) -> tuple[list[float], int]:
try:
retryer = Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
for attempt in retryer:
with attempt:
embedding = (
self.sync_client.embeddings.create( # type: ignore
input=text,
model=self.model,
**kwargs, # type: ignore
)
.data[0]
.embedding
or []
)
return (embedding, len(text))
except RetryError as e:
self._reporter.error(
message="Error at embed_with_retry()",
details={self.__class__.__name__: str(e)},
)
return ([], 0)
else:
# TODO: why not just throw in this case?
return ([], 0)
async def _aembed_with_retry(
self, text: str | tuple, **kwargs: Any
) -> tuple[list[float], int]:
try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
async for attempt in retryer:
with attempt:
embedding = (
await self.async_client.embeddings.create( # type: ignore
input=text,
model=self.model,
**kwargs, # type: ignore
)
).data[0].embedding or []
return (embedding, len(text))
except RetryError as e:
self._reporter.error(
message="Error at embed_with_retry()",
details={self.__class__.__name__: str(e)},
)
return ([], 0)
else:
# TODO: why not just throw in this case?
return ([], 0)
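Both embed and aembed above split long input into token chunks, embed each chunk, and then merge the per-chunk vectors with a length-weighted average followed by L2 normalization. A minimal sketch of just that merge step, assuming the chunk embeddings and chunk lengths have already been collected:

import numpy as np


def combine_chunk_embeddings(
    chunk_embeddings: list[list[float]], chunk_lens: list[int]
) -> list[float]:
    """Length-weighted average of per-chunk embeddings, L2-normalized as in embed/aembed above."""
    combined = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
    combined = combined / np.linalg.norm(combined)
    return combined.tolist()


# Example: a 100-token chunk counts twice as much as a 50-token chunk.
# combine_chunk_embeddings([[0.1, 0.2], [0.3, 0.4]], [100, 50])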

View File

@@ -1,187 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""OpenAI Wrappers for Orchestration."""
import logging
from typing import Any
from tenacity import (
AsyncRetrying,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from graphrag.query.llm.base import BaseLLMCallback
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
OPENAI_RETRY_ERROR_TYPES,
OpenaiApiType,
)
log = logging.getLogger(__name__)
class OpenAI(OpenAILLMImpl):
"""Wrapper for OpenAI Completion models."""
def __init__(
self,
api_key: str,
model: str,
deployment_name: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
organization: str | None = None,
max_retries: int = 10,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
):
self.api_key = api_key
self.model = model
self.deployment_name = deployment_name
self.api_base = api_base
self.api_version = api_version
self.api_type = api_type
self.organization = organization
self.max_retries = max_retries
self.retry_error_types = retry_error_types
def generate(
self,
messages: str | list[str],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate text."""
try:
retryer = Retrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
for attempt in retryer:
with attempt:
return self._generate(
messages=messages,
streaming=streaming,
callbacks=callbacks,
**kwargs,
)
except RetryError:
log.exception("RetryError at generate(): %s")
return ""
else:
# TODO: why not just throw in this case?
return ""
async def agenerate(
self,
messages: str | list[str],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate Text Asynchronously."""
try:
retryer = AsyncRetrying(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential_jitter(max=10),
reraise=True,
retry=retry_if_exception_type(self.retry_error_types),
)
async for attempt in retryer:
with attempt:
return await self._agenerate(
messages=messages,
streaming=streaming,
callbacks=callbacks,
**kwargs,
)
except RetryError:
log.exception("Error at agenerate()")
return ""
else:
# TODO: why not just throw in this case?
return ""
def _generate(
self,
messages: str | list[str],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
response = self.sync_client.chat.completions.create( # type: ignore
model=self.model,
messages=messages, # type: ignore
stream=streaming,
**kwargs,
) # type: ignore
if streaming:
full_response = ""
while True:
try:
chunk = response.__next__() # type: ignore
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
) # type: ignore
full_response += delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
if chunk.choices[0].finish_reason == "stop": # type: ignore
break
except StopIteration:
break
return full_response
return response.choices[0].message.content or "" # type: ignore
async def _agenerate(
self,
messages: str | list[str],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
response = await self.async_client.chat.completions.create( # type: ignore
model=self.model,
messages=messages, # type: ignore
stream=streaming,
**kwargs,
)
if streaming:
full_response = ""
while True:
try:
chunk = await response.__anext__() # type: ignore
if not chunk or not chunk.choices:
continue
delta = (
chunk.choices[0].delta.content
if chunk.choices[0].delta and chunk.choices[0].delta.content
else ""
) # type: ignore
full_response += delta
if callbacks:
for callback in callbacks:
callback.on_llm_new_token(delta)
if chunk.choices[0].finish_reason == "stop": # type: ignore
break
except StopAsyncIteration:
break
return full_response
return response.choices[0].message.content or "" # type: ignore

View File

@@ -1,27 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""OpenAI wrapper options."""
from enum import Enum
from typing import Any, cast
import httpx
import openai
OPENAI_RETRY_ERROR_TYPES = (
# TODO: update these when we update to OpenAI 1+ library
cast("Any", openai).RateLimitError,
cast("Any", openai).APIConnectionError,
cast("Any", openai).APIError,
cast("Any", httpx).RemoteProtocolError,
cast("Any", httpx).ReadTimeout,
# TODO: replace with comparable OpenAI 1+ error
)
class OpenaiApiType(str, Enum):
"""The OpenAI Flavor."""
OpenAI = "openai"
AzureOpenAI = "azure"

View File

@@ -9,11 +9,11 @@ from typing import Any
import tiktoken
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
GlobalContextBuilder,
LocalContextBuilder,
)
from graphrag.query.llm.base import BaseLLM
@dataclass
@@ -32,20 +32,20 @@ class BaseQuestionGen(ABC):
def __init__(
self,
llm: BaseLLM,
model: ChatModel,
context_builder: GlobalContextBuilder | LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
llm_params: dict[str, Any] | None = None,
context_builder_params: dict[str, Any] | None = None,
):
self.llm = llm
self.model = model
self.context_builder = context_builder
self.token_encoder = token_encoder
self.llm_params = llm_params or {}
self.context_builder_params = context_builder_params or {}
@abstractmethod
def generate(
async def generate(
self,
question_history: list[str],
context_data: str | None,

View File

@@ -9,6 +9,8 @@ from typing import Any, cast
import tiktoken
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
from graphrag.query.context_builder.builders import (
ContextBuilderResult,
@@ -17,7 +19,6 @@ from graphrag.query.context_builder.builders import (
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
@@ -29,7 +30,7 @@ class LocalQuestionGen(BaseQuestionGen):
def __init__(
self,
llm: BaseLLM,
model: ChatModel,
context_builder: LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
system_prompt: str = QUESTION_SYSTEM_PROMPT,
@@ -38,7 +39,7 @@ class LocalQuestionGen(BaseQuestionGen):
context_builder_params: dict[str, Any] | None = None,
):
super().__init__(
llm=llm,
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
llm_params=llm_params,
@@ -95,15 +96,17 @@ class LocalQuestionGen(BaseQuestionGen):
)
question_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question_text},
]
response = await self.llm.agenerate(
messages=question_messages,
streaming=True,
callbacks=self.callbacks,
**self.llm_params,
)
response = ""
async for chunk in self.model.achat_stream(
prompt=question_text,
history=question_messages,
model_parameters=self.llm_params,
):
response += chunk
for callback in self.callbacks:
callback.on_llm_new_token(chunk)
return QuestionResult(
response=response.split("\n"),
@@ -126,7 +129,7 @@ class LocalQuestionGen(BaseQuestionGen):
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
)
def generate(
async def generate(
self,
question_history: list[str],
context_data: str | None,
@@ -178,12 +181,15 @@ class LocalQuestionGen(BaseQuestionGen):
{"role": "user", "content": question_text},
]
response = self.llm.generate(
messages=question_messages,
streaming=True,
callbacks=self.callbacks,
**self.llm_params,
)
response = ""
async for chunk in self.model.achat_stream(
prompt=question_text,
history=question_messages,
model_parameters=self.llm_params,
):
response += chunk
for callback in self.callbacks:
callback.on_llm_new_token(chunk)
return QuestionResult(
response=response.split("\n"),
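The rewritten question generator above replaces llm.generate/llm.agenerate with the ChatModel.achat_stream protocol: chunks are accumulated into the final response and forwarded to each registered callback as they arrive. A minimal sketch of that consumption pattern, assuming an already-constructed ChatModel and a list of callbacks exposing on_llm_new_token:

from graphrag.language_model.protocol.base import ChatModel


async def stream_to_string(
    model: ChatModel, prompt: str, history: list, callbacks: list
) -> str:
    """Accumulate a streamed chat response while notifying callbacks of each new token."""
    response = ""
    async for chunk in model.achat_stream(prompt=prompt, history=history):
        response += chunk
        for callback in callbacks:
            callback.on_llm_new_token(chunk)
    return response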

View File

@@ -11,6 +11,7 @@ from typing import Any, Generic, TypeVar
import pandas as pd
import tiktoken
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
BasicContextBuilder,
DRIFTContextBuilder,
@@ -20,7 +21,6 @@ from graphrag.query.context_builder.builders import (
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
@dataclass
@@ -56,16 +56,16 @@ class BaseSearch(ABC, Generic[T]):
def __init__(
self,
llm: BaseLLM,
model: ChatModel,
context_builder: T,
token_encoder: tiktoken.Encoding | None = None,
llm_params: dict[str, Any] | None = None,
model_params: dict[str, Any] | None = None,
context_builder_params: dict[str, Any] | None = None,
):
self.llm = llm
self.model = model
self.context_builder = context_builder
self.token_encoder = token_encoder
self.llm_params = llm_params or {}
self.model_params = model_params or {}
self.context_builder_params = context_builder_params or {}
@abstractmethod
@@ -80,11 +80,12 @@ class BaseSearch(ABC, Generic[T]):
raise NotImplementedError(msg)
@abstractmethod
def stream_search(
async def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[Any, None]:
) -> AsyncGenerator[str, None]:
"""Stream search for the given query."""
yield "" # This makes it an async generator.
msg = "Subclasses must implement this method"
raise NotImplementedError(msg)

View File

@@ -7,12 +7,12 @@ import pandas as pd
import tiktoken
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.context_builder.builders import (
BasicContextBuilder,
ContextBuilderResult,
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores.base import BaseVectorStore
@@ -21,7 +21,7 @@ class BasicSearchContext(BasicContextBuilder):
def __init__(
self,
text_embedder: BaseTextEmbedding,
text_embedder: EmbeddingModel,
text_unit_embeddings: BaseVectorStore,
text_units: list[TextUnit] | None = None,
token_encoder: tiktoken.Encoding | None = None,

View File

@@ -11,12 +11,12 @@ from typing import Any
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.basic_search_system_prompt import (
BASIC_SEARCH_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.builders import BasicContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
@@ -36,7 +36,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
def __init__(
self,
llm: BaseLLM,
model: ChatModel,
context_builder: BasicContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
system_prompt: str | None = None,
@@ -46,10 +46,10 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
context_builder_params: dict | None = None,
):
super().__init__(
llm=llm,
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
llm_params=llm_params,
model_params=llm_params,
context_builder_params=context_builder_params or {},
)
self.system_prompt = system_prompt or BASIC_SEARCH_SYSTEM_PROMPT
@@ -86,15 +86,17 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
response = await self.llm.agenerate(
messages=search_messages,
streaming=True,
callbacks=self.callbacks, # type: ignore
**self.llm_params,
)
response = ""
async for chunk in self.model.achat_stream(
prompt=query,
history=search_messages,
model_parameters=self.model_params,
):
for callback in self.callbacks:
callback.on_llm_new_token(chunk)
response += chunk
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
@@ -125,11 +127,11 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
output_tokens=0,
)
def stream_search(
async def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator[Any, None]:
) -> AsyncGenerator[str, None]:
"""Build basic search context that fits a single context window and generate answer for the user query."""
start_time = time.time()
@@ -144,14 +146,16 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
for callback in self.callbacks:
callback.on_context(context_result.context_records)
return self.llm.astream_generate( # type: ignore
messages=search_messages,
callbacks=self.callbacks, # type: ignore
**self.llm_params,
)
async for chunk_response in self.model.achat_stream(
prompt=query,
history=search_messages,
model_parameters=self.model_params,
):
for callback in self.callbacks:
callback.on_llm_new_token(chunk_response)
yield chunk_response
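With stream_search now declared as an async generator of strings, callers iterate it directly rather than awaiting a coroutine that returns a generator. A minimal usage sketch; the search object is assumed to be an already-constructed BasicSearch (or any other BaseSearch subclass):

import asyncio


async def print_streamed_answer(search, query: str) -> None:
    """Print tokens as stream_search yields them."""
    async for token in search.stream_search(query):
        print(token, end="", flush=True)


# asyncio.run(print_streamed_answer(search, "What are the main themes?"))  # `search` is assumed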

View File

@@ -17,13 +17,12 @@ from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
DRIFT_REDUCE_PROMPT,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import DRIFTContextBuilder
from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor
from graphrag.query.structured_search.local_search.mixed_context import (
@@ -39,8 +38,8 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def __init__(
self,
chat_llm: ChatOpenAI,
text_embedder: BaseTextEmbedding,
model: ChatModel,
text_embedder: EmbeddingModel,
entities: list[Entity],
entity_text_embeddings: BaseVectorStore,
text_units: list[TextUnit] | None = None,
@@ -57,7 +56,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
self.chat_llm = chat_llm
self.model = model
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
@@ -163,7 +162,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
and isinstance(query_embedding[0], type(embedding[0]))
)
def build_context(
async def build_context(
self, query: str, **kwargs
) -> tuple[pd.DataFrame, dict[str, int]]:
"""
@@ -191,13 +190,13 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
raise ValueError(missing_reports_error)
query_processor = PrimerQueryProcessor(
chat_llm=self.chat_llm,
chat_model=self.model,
text_embedder=self.text_embedder,
token_encoder=self.token_encoder,
reports=self.reports,
)
query_embedding, token_ct = query_processor(query)
query_embedding, token_ct = await query_processor(query)
report_df = self.convert_reports_to_df(self.reports)

View File

@@ -15,11 +15,10 @@ from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.data_model.community_report import CommunityReport
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_PRIMER_PROMPT,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import SearchResult
@@ -31,8 +30,8 @@ class PrimerQueryProcessor:
def __init__(
self,
chat_llm: ChatOpenAI,
text_embedder: BaseTextEmbedding,
chat_model: ChatModel,
text_embedder: EmbeddingModel,
reports: list[CommunityReport],
token_encoder: tiktoken.Encoding | None = None,
):
@@ -45,12 +44,12 @@ class PrimerQueryProcessor:
reports (list[CommunityReport]): List of community reports.
token_encoder (tiktoken.Encoding, optional): Token encoder for token counting.
"""
self.chat_llm = chat_llm
self.chat_model = chat_model
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.reports = reports
def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
async def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
"""
Expand the query using a random community report template.
@@ -68,9 +67,9 @@ class PrimerQueryProcessor:
{template}\n"
Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
messages = [{"role": "user", "content": prompt}]
model_response = await self.chat_model.achat(prompt)
text = model_response.output.content
text = self.chat_llm.generate(messages)
prompt_tokens = num_tokens(prompt, self.token_encoder)
output_tokens = num_tokens(text, self.token_encoder)
token_ct = {
@@ -83,7 +82,7 @@ class PrimerQueryProcessor:
return query, token_ct
return text, token_ct
def __call__(self, query: str) -> tuple[list[float], dict[str, int]]:
async def __call__(self, query: str) -> tuple[list[float], dict[str, int]]:
"""
Call method to process the query, expand it, and embed the result.
@@ -94,7 +93,7 @@ class PrimerQueryProcessor:
-------
tuple[list[float], dict[str, int]]: List of embeddings for the expanded query and the token count.
"""
hyde_query, token_ct = self.expand_query(query)
hyde_query, token_ct = await self.expand_query(query)
log.info("Expanded query: %s", hyde_query)
return self.text_embedder.embed(hyde_query), token_ct
@@ -105,7 +104,7 @@ class DRIFTPrimer:
def __init__(
self,
config: DRIFTSearchConfig,
chat_llm: ChatOpenAI,
chat_model: ChatModel,
token_encoder: tiktoken.Encoding | None = None,
):
"""
@@ -116,7 +115,7 @@ class DRIFTPrimer:
chat_model (ChatModel): The language model used for searching.
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
"""
self.llm = chat_llm
self.chat_model = chat_model
self.config = config
self.token_encoder = token_encoder
@@ -138,11 +137,9 @@ class DRIFTPrimer:
prompt = DRIFT_PRIMER_PROMPT.format(
query=query, community_reports=community_reports
)
messages = [{"role": "user", "content": prompt}]
response = await self.llm.agenerate(
messages, response_format={"type": "json_object"}
)
model_response = await self.chat_model.achat(prompt, json=True)
response = model_response.output.content
parsed_response = json.loads(response)
@@ -173,6 +170,7 @@ class DRIFTPrimer:
start_time = time.perf_counter()
report_folds = self.split_reports(top_k_reports)
tasks = [self.decompose_query(query, fold) for fold in report_folds]
results_with_tokens = await tqdm_asyncio.gather(*tasks, leave=False)
completion_time = time.perf_counter() - start_time
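decompose_query above swaps llm.agenerate(..., response_format={"type": "json_object"}) for ChatModel.achat(prompt, json=True) and then parses the returned content. A minimal sketch of that request-and-parse step, assuming an already-constructed ChatModel:

import json

from graphrag.language_model.protocol.base import ChatModel


async def ask_for_json(chat_model: ChatModel, prompt: str) -> dict:
    """Request a JSON-formatted completion and parse it, as DRIFTPrimer.decompose_query does."""
    model_response = await chat_model.achat(prompt, json=True)
    return json.loads(model_response.output.content)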

View File

@@ -12,9 +12,9 @@ import tiktoken
from tqdm.asyncio import tqdm_asyncio
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.drift_search.action import DriftAction
@@ -33,7 +33,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
def __init__(
self,
llm: ChatOpenAI,
model: ChatModel,
context_builder: DRIFTSearchContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
query_state: QueryState | None = None,
@@ -49,14 +49,14 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
query_state (QueryState, optional): State of the current search query.
"""
super().__init__(llm, context_builder, token_encoder)
super().__init__(model, context_builder, token_encoder)
self.context_builder = context_builder
self.token_encoder = token_encoder
self.query_state = query_state or QueryState()
self.primer = DRIFTPrimer(
config=self.context_builder.config,
chat_llm=llm,
chat_model=model,
token_encoder=token_encoder,
)
self.callbacks = callbacks or []
@@ -90,11 +90,11 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
}
return LocalSearch(
llm=self.llm,
model=self.model,
system_prompt=self.context_builder.local_system_prompt,
context_builder=self.context_builder.local_mixed_context,
token_encoder=self.token_encoder,
llm_params=llm_params,
model_params=llm_params,
context_builder_params=local_context_params,
response_type="multiple paragraphs",
callbacks=self.callbacks,
@@ -203,7 +203,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
# Check if query state is empty
if not self.query_state.graph:
# Prime the search with the primer
primer_context, token_ct = self.context_builder.build_context(query)
primer_context, token_ct = await self.context_builder.build_context(query)
llm_calls["build_context"] = token_ct["llm_calls"]
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"]
@@ -362,16 +362,16 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
reduced_response = self.llm.generate(
messages=search_messages,
streaming=False,
callbacks=self.callbacks, # type: ignore
**llm_kwargs,
model_response = await self.model.achat(
prompt=query,
history=search_messages,
model_parameters=llm_kwargs,
)
reduced_response = model_response.output.content
llm_calls["reduce"] = 1
prompt_tokens["reduce"] = num_tokens(
search_prompt, self.token_encoder
@@ -380,7 +380,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
return reduced_response
def _reduce_response_streaming(
async def _reduce_response_streaming(
self,
responses: str | dict[str, Any],
query: str,
@@ -419,11 +419,13 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
return self.llm.astream_generate(
search_messages,
callbacks=self.callbacks, # type: ignore
**llm_kwargs,
)
async for response in self.model.achat_stream(
prompt=query,
history=search_messages,
model_parameters=llm_kwargs,
):
for callback in self.callbacks:
callback.on_llm_new_token(response)
yield response

View File

@@ -46,7 +46,7 @@ class GlobalCommunityContext(GlobalContextBuilder):
self.dynamic_community_selection = DynamicCommunitySelection(
community_reports=community_reports,
communities=communities,
llm=dynamic_community_selection_kwargs.pop("llm"),
model=dynamic_community_selection_kwargs.pop("model"),
token_encoder=dynamic_community_selection_kwargs.pop("token_encoder"),
**dynamic_community_selection_kwargs,
)

View File

@@ -15,6 +15,7 @@ import pandas as pd
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
)
@@ -29,7 +30,6 @@ from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
from graphrag.query.structured_search.base import BaseSearch, SearchResult
@@ -60,7 +60,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
def __init__(
self,
llm: BaseLLM,
model: ChatModel,
context_builder: GlobalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
map_system_prompt: str | None = None,
@@ -77,7 +77,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
concurrent_coroutines: int = 32,
):
super().__init__(
llm=llm,
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
context_builder_params=context_builder_params,
@@ -106,7 +106,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
self,
query: str,
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator:
) -> AsyncGenerator[str, None]:
"""Stream the global search response."""
context_result = await self.context_builder.build_context(
query=query,
@@ -130,7 +130,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
async for response in self._stream_reduce_response(
map_responses=map_responses, # type: ignore
query=query,
**self.reduce_llm_params,
model_parameters=self.reduce_llm_params,
):
yield response
@@ -218,12 +218,15 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
search_prompt = self.map_system_prompt.format(context_data=context_data)
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
async with self.semaphore:
search_response = await self.llm.agenerate(
messages=search_messages, streaming=False, **llm_kwargs
model_response = await self.model.achat(
prompt=query,
history=search_messages,
model_parameters=llm_kwargs,
json=True,
)
search_response = model_response.output.content
log.info("Map response: %s", search_response)
try:
# parse search response json
@@ -373,12 +376,16 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
{"role": "user", "content": query},
]
search_response = await self.llm.agenerate(
search_messages,
streaming=True,
callbacks=self.callbacks, # type: ignore
**llm_kwargs, # type: ignore
)
search_response = ""
async for chunk_response in self.model.achat_stream(
prompt=query,
history=search_messages,
model_parameters=llm_kwargs,
):
search_response += chunk_response
for callback in self.callbacks:
callback.on_llm_new_token(chunk_response)
return SearchResult(
response=search_response,
context_data=text_data,
@@ -468,12 +475,13 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
search_prompt += "\n" + self.general_knowledge_inclusion_prompt
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
async for resp in self.llm.astream_generate( # type: ignore
search_messages,
callbacks=self.callbacks, # type: ignore
**llm_kwargs, # type: ignore
async for chunk_response in self.model.achat_stream(
prompt=query,
history=search_messages,
**llm_kwargs,
):
yield resp
for callback in self.callbacks:
callback.on_llm_new_token(chunk_response)
yield chunk_response

View File

@@ -14,6 +14,7 @@ from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.context_builder.builders import ContextBuilderResult
from graphrag.query.context_builder.community_context import (
build_community_context,
@@ -39,7 +40,6 @@ from graphrag.query.input.retrieval.community_reports import (
get_candidate_communities,
)
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import LocalContextBuilder
from graphrag.vector_stores.base import BaseVectorStore
@@ -54,7 +54,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
self,
entities: list[Entity],
entity_text_embeddings: BaseVectorStore,
text_embedder: BaseTextEmbedding,
text_embedder: EmbeddingModel,
text_units: list[TextUnit] | None = None,
community_reports: list[CommunityReport] | None = None,
relationships: list[Relationship] | None = None,

View File

@@ -11,6 +11,7 @@ from typing import Any
import tiktoken
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.local_search_system_prompt import (
LOCAL_SEARCH_SYSTEM_PROMPT,
)
@@ -18,7 +19,6 @@ from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
@@ -35,20 +35,20 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
def __init__(
self,
llm: BaseLLM,
model: ChatModel,
context_builder: LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
callbacks: list[QueryCallbacks] | None = None,
llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
model_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
context_builder_params: dict | None = None,
):
super().__init__(
llm=llm,
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
llm_params=llm_params,
model_params=model_params,
context_builder_params=context_builder_params or {},
)
self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
@@ -89,26 +89,30 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
history_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
response = await self.llm.agenerate(
messages=search_messages,
streaming=True,
callbacks=self.callbacks, # type: ignore
**self.llm_params,
)
full_response = ""
async for response in self.model.achat_stream(
prompt=query,
history=history_messages,
model_parameters=self.model_params,
):
full_response += response
for callback in self.callbacks:
callback.on_llm_new_token(response)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
output_tokens["response"] = num_tokens(full_response, self.token_encoder)
for callback in self.callbacks:
callback.on_context(context_result.context_records)
return SearchResult(
response=response,
response=full_response,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
@@ -132,7 +136,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
output_tokens=0,
)
def stream_search(
async def stream_search(
self,
query: str,
conversation_history: ConversationHistory | None = None,
@@ -149,16 +153,18 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
search_prompt = self.system_prompt.format(
context_data=context_result.context_chunks, response_type=self.response_type
)
search_messages = [
history_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
]
for callback in self.callbacks:
callback.on_context(context_result.context_records)
return self.llm.astream_generate( # type: ignore
messages=search_messages,
callbacks=self.callbacks, # type: ignore
**self.llm_params,
)
async for response in self.model.achat_stream(
prompt=query,
history=history_messages,
model_parameters=self.model_params,
):
for callback in self.callbacks:
callback.on_llm_new_token(response)
yield response
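Every search class in this diff now fans streamed tokens out through QueryCallbacks.on_llm_new_token and reports the built context through on_context. A minimal callback sketch that just prints both, assuming the QueryCallbacks base class imported throughout these files; the constructor call in the trailing comment uses hypothetical model/builder objects:

from graphrag.callbacks.query_callbacks import QueryCallbacks


class PrintingCallbacks(QueryCallbacks):
    """Print streamed tokens and note when context records arrive."""

    def on_context(self, context) -> None:
        print(f"[context] {type(context).__name__} received")

    def on_llm_new_token(self, token) -> None:
        print(token, end="", flush=True)


# LocalSearch(model=model, context_builder=builder, callbacks=[PrintingCallbacks()])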

View File

@@ -0,0 +1,86 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LLMFactory Tests.
These tests will test the LLMFactory class and the creation of custom and provided LLMs.
"""
from collections.abc import AsyncGenerator
from typing import Any
from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.response.base import (
BaseModelOutput,
BaseModelResponse,
ModelResponse,
)
async def test_create_custom_chat_model():
class CustomChatModel:
def __init__(self, **kwargs):
pass
async def achat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
return BaseModelResponse(output=BaseModelOutput(content="content"))
def chat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
return BaseModelResponse(output=BaseModelOutput(content="content"))
async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]: ...
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]: ...
ModelFactory.register_chat("custom_chat", CustomChatModel)
model = ModelManager().get_or_create_chat_model("custom", "custom_chat")
assert isinstance(model, CustomChatModel)
response = await model.achat("prompt")
assert response.output.content == "content"
response = model.chat("prompt")
assert response.output.content == "content"
async def test_create_custom_embedding_llm():
class CustomEmbeddingModel:
def __init__(self, **kwargs):
pass
async def aembed(self, text: str, **kwargs) -> list[float]:
return [1.0]
def embed(self, text: str, **kwargs) -> list[float]:
return [1.0]
async def aembed_batch(
self, text_list: list[str], **kwargs
) -> list[list[float]]:
return [[1.0]]
def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
return [[1.0]]
ModelFactory.register_embedding("custom_embedding", CustomEmbeddingModel)
llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
assert isinstance(llm, CustomEmbeddingModel)
response = await llm.aembed("text")
assert response == [1.0]
response = llm.embed("text")
assert response == [1.0]
response = await llm.aembed_batch(["text"])
assert response == [[1.0]]
response = llm.embed_batch(["text"])
assert response == [[1.0]]

View File

@@ -1,45 +0,0 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LLMFactory Tests.
These tests will test the LLMFactory class and the creation of custom and provided LLMs.
"""
from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.response.base import (
BaseModelOutput,
BaseModelResponse,
ModelResponse,
)
async def test_create_custom_chat_llm():
class CustomChatLLM:
def __init__(self, **kwargs):
pass
async def chat(self, prompt: str, **kwargs) -> ModelResponse:
return BaseModelResponse(output=BaseModelOutput(content="content"))
ModelFactory.register_chat("custom_chat", CustomChatLLM)
llm = ModelManager().get_or_create_chat_model("custom", "custom_chat")
assert isinstance(llm, CustomChatLLM)
response = await llm.chat("prompt")
assert response.output.content == "content"
async def test_create_custom_embedding_llm():
class CustomEmbeddingLLM:
def __init__(self, **kwargs):
pass
async def embed(self, text: str | list[str], **kwargs) -> list[list[float]]:
return [[1.0]]
ModelFactory.register_embedding("custom_embedding", CustomEmbeddingLLM)
llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
assert isinstance(llm, CustomEmbeddingLLM)
response = await llm.embed("text")
assert response == [[1.0]]

View File

@@ -3,6 +3,7 @@
"""A module containing mock model provider definitions."""
from collections.abc import AsyncGenerator
from typing import Any
from pydantic import BaseModel
@@ -28,9 +29,38 @@ class MockChatLLM:
self.responses = config.responses if config and config.responses else responses
self.response_index = 0
async def chat(
async def achat(
self,
prompt: str,
history: list | None = None,
**kwargs,
) -> ModelResponse:
"""Return the next response in the list."""
return self.chat(prompt, history, **kwargs)
async def achat_stream(
self,
prompt: str,
history: list | None = None,
**kwargs,
) -> AsyncGenerator[str, None]:
"""Return the next response in the list."""
if not self.responses:
return
for response in self.responses:
response = (
response.model_dump_json()
if isinstance(response, BaseModel)
else response
)
yield response
def chat(
self,
prompt: str,
history: list | None = None,
**kwargs,
) -> ModelResponse:
"""Return the next response in the list."""
@@ -50,6 +80,15 @@ class MockChatLLM:
parsed_response=parsed_json,
)
def chat_stream(
self,
prompt: str,
history: list | None = None,
**kwargs,
) -> AsyncGenerator[str, None]:
"""Return the next response in the list."""
raise NotImplementedError
class MockEmbeddingLLM:
"""A mock embedding LLM provider."""
@@ -57,8 +96,24 @@ class MockEmbeddingLLM:
def __init__(self, **kwargs: Any):
pass
async def embed(self, text: str | list[str], **kwargs: Any) -> list[list[float]]:
def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""Generate an embedding for the input text."""
if isinstance(text, str):
if isinstance(text_list, str):
return [[1.0, 1.0, 1.0]]
return [[1.0, 1.0, 1.0] for _ in text]
return [[1.0, 1.0, 1.0] for _ in text_list]
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""Generate an embedding for the input text."""
return [1.0, 1.0, 1.0]
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""Generate an embedding for the input text."""
return [1.0, 1.0, 1.0]
async def aembed_batch(
self, text_list: list[str], **kwargs: Any
) -> list[list[float]]:
"""Generate an embedding for the input text."""
if isinstance(text_list, str):
return [[1.0, 1.0, 1.0]]
return [[1.0, 1.0, 1.0] for _ in text_list]

View File

@@ -5,11 +5,11 @@ from typing import Any
from graphrag.data_model.entity import Entity
from graphrag.data_model.types import TextEmbedder
from graphrag.language_model.manager import ModelManager
from graphrag.query.context_builder.entity_extraction import (
EntityVectorStoreKey,
map_query_to_entities,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores.base import (
BaseVectorStore,
VectorStoreDocument,
@@ -60,14 +60,6 @@ class MockBaseVectorStore(BaseVectorStore):
return result
class MockBaseTextEmbedding(BaseTextEmbedding):
def embed(self, text: str, **kwargs: Any) -> list[float]:
return [len(text)]
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
return [len(text)]
def test_map_query_to_entities():
entities = [
Entity(
@@ -102,7 +94,9 @@ def test_map_query_to_entities():
VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
text_embedder=ModelManager().get_or_create_embedding_model(
model_type="mock_embedding", name="mock"
),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.ID,
k=1,
@@ -122,7 +116,9 @@ def test_map_query_to_entities():
VectorStoreDocument(id=entity.title, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
text_embedder=ModelManager().get_or_create_embedding_model(
model_type="mock_embedding", name="mock"
),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
k=1,
@@ -142,7 +138,9 @@ def test_map_query_to_entities():
VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
text_embedder=ModelManager().get_or_create_embedding_model(
model_type="mock_embedding", name="mock"
),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.ID,
k=2,
@@ -167,7 +165,9 @@ def test_map_query_to_entities():
VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
for entity in entities
]),
text_embedder=MockBaseTextEmbedding(),
text_embedder=ModelManager().get_or_create_embedding_model(
model_type="mock_embedding", name="mock"
),
all_entities_dict={entity.id: entity for entity in entities},
embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
k=2,