Mirror of https://github.com/microsoft/graphrag.git
synced 2025-03-11 01:26:14 +03:00
Feat/llm provider query (#1735)
* Add ModelProvider to Query package.
* Spellcheck + others
* Semver
* Fix tests
* Format
* Fix Pyright
* Fix tests
* Fix for smoke tests
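For orientation, a minimal sketch of the call pattern this PR moves the query package to: model settings are resolved per search mode from `GraphRagConfig`, instantiated through `ModelManager`, and invoked via the async protocol methods (`achat`, `aembed_batch`) instead of the old `BaseLLM`/`BaseTextEmbedding` wrappers. The function below is illustrative rather than code from this diff; `config` is assumed to be an already-loaded `GraphRagConfig`, and the registration name is a placeholder.

```python
# Illustrative sketch only (not part of this PR): query-side chat via the new ModelProvider.
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager


async def demo_query_chat(config: GraphRagConfig) -> str:
    # Resolve the chat model configured for local search (field added by this PR).
    model_settings = config.get_language_model_config(config.local_search.chat_model_id)
    chat_model = ModelManager().get_or_create_chat_model(
        name="example_chat",  # hypothetical registration name
        model_type=model_settings.type,
        config=model_settings,
    )
    # Async chat replaces the old BaseLLM.agenerate call.
    response = await chat_model.achat("This is an LLM connectivity test. Say Hello World")
    return str(response.output.content)
```

Synchronous callers can use the matching `chat`/`embed` methods, which delegate to the async versions through the `run_coroutine_sync` helper added in this PR.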
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Use ModelProvider for query module"
}
@@ -120,10 +120,7 @@ unhot
groupby
retryer
agenerate
aembed
dedupe
dropna
dtypes
notna

# LLM Terms
@@ -131,6 +128,8 @@ AOAI
embedder
llm
llms
achat
aembed

# Galaxy-Brain Terms
Unipartite

@@ -28,3 +28,6 @@ class QueryCallbacks(BaseLLMCallback):

def on_reduce_response_end(self, reduce_response_output: str) -> None:
"""Handle the end of reduce operation."""

def on_llm_new_token(self, token) -> None:
"""Handle when a new token is generated."""

@@ -48,6 +48,8 @@ class BasicSearchDefaults:
n: int = 1
max_tokens: int = 12_000
llm_max_tokens: int = 2000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID


@dataclass
@@ -122,7 +124,9 @@ class DriftSearchDefaults:
local_search_temperature: float = 0
local_search_top_p: float = 1
local_search_n: int = 1
local_search_llm_max_gen_tokens: int = 12_000
local_search_llm_max_gen_tokens: int = 4_096
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID


@dataclass
@@ -239,6 +243,8 @@ class GlobalSearchDefaults:
dynamic_search_use_summary: bool = False
dynamic_search_concurrent_coroutines: int = 16
dynamic_search_max_level: int = 2
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID


@dataclass
@@ -305,6 +311,8 @@ class LocalSearchDefaults:
n: int = 1
max_tokens: int = 12_000
llm_max_tokens: int = 2000
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID


@dataclass

@@ -145,18 +145,26 @@ snapshots:
## See the config docs: https://microsoft.github.io/graphrag/config/yaml/#query

local_search:
  chat_model_id: {graphrag_config_defaults.local_search.chat_model_id}
  embedding_model_id: {graphrag_config_defaults.local_search.embedding_model_id}
  prompt: "prompts/local_search_system_prompt.txt"

global_search:
  chat_model_id: {graphrag_config_defaults.global_search.chat_model_id}
  embedding_model_id: {graphrag_config_defaults.global_search.embedding_model_id}
  map_prompt: "prompts/global_search_map_system_prompt.txt"
  reduce_prompt: "prompts/global_search_reduce_system_prompt.txt"
  knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt"

drift_search:
  chat_model_id: {graphrag_config_defaults.drift_search.chat_model_id}
  embedding_model_id: {graphrag_config_defaults.drift_search.embedding_model_id}
  prompt: "prompts/drift_search_system_prompt.txt"
  reduce_prompt: "prompts/drift_search_reduce_prompt.txt"

basic_search:
  chat_model_id: {graphrag_config_defaults.basic_search.chat_model_id}
  embedding_model_id: {graphrag_config_defaults.basic_search.embedding_model_id}
  prompt: "prompts/basic_search_system_prompt.txt"
"""


@@ -15,6 +15,14 @@ class BasicSearchConfig(BaseModel):
description="The basic search prompt to use.",
default=graphrag_config_defaults.basic_search.prompt,
)
chat_model_id: str = Field(
description="The model ID to use for basic search.",
default=graphrag_config_defaults.basic_search.chat_model_id,
)
embedding_model_id: str = Field(
description="The model ID to use for text embeddings.",
default=graphrag_config_defaults.basic_search.embedding_model_id,
)
text_unit_prop: float = Field(
description="The text unit proportion.",
default=graphrag_config_defaults.basic_search.text_unit_prop,

@@ -19,6 +19,14 @@ class DRIFTSearchConfig(BaseModel):
description="The drift search reduce prompt to use.",
default=graphrag_config_defaults.drift_search.reduce_prompt,
)
chat_model_id: str = Field(
description="The model ID to use for drift search.",
default=graphrag_config_defaults.drift_search.chat_model_id,
)
embedding_model_id: str = Field(
description="The model ID to use for drift search.",
default=graphrag_config_defaults.drift_search.embedding_model_id,
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=graphrag_config_defaults.drift_search.temperature,

@@ -19,6 +19,10 @@ class GlobalSearchConfig(BaseModel):
description="The global search reducer to use.",
default=graphrag_config_defaults.global_search.reduce_prompt,
)
chat_model_id: str = Field(
description="The model ID to use for global search.",
default=graphrag_config_defaults.global_search.chat_model_id,
)
knowledge_prompt: str | None = Field(
description="The global search general prompt to use.",
default=graphrag_config_defaults.global_search.knowledge_prompt,

@@ -15,6 +15,14 @@ class LocalSearchConfig(BaseModel):
description="The local search prompt to use.",
default=graphrag_config_defaults.local_search.prompt,
)
chat_model_id: str = Field(
description="The model ID to use for local search.",
default=graphrag_config_defaults.local_search.chat_model_id,
)
embedding_model_id: str = Field(
description="The model ID to use for text embeddings.",
default=graphrag_config_defaults.local_search.embedding_model_id,
)
text_unit_prop: float = Field(
description="The text unit proportion.",
default=graphrag_config_defaults.local_search.text_unit_prop,

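Each search config above gains `chat_model_id` and `embedding_model_id` fields that name entries in the top-level models block and default to `DEFAULT_CHAT_MODEL_ID` / `DEFAULT_EMBEDDING_MODEL_ID`. A hedged sketch of how those IDs are resolved; the helper below is illustrative, and the `LanguageModelConfig` import path is an assumption.

```python
# Illustrative only: resolving the per-search-mode model IDs added above.
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.language_model_config import LanguageModelConfig  # assumed module path


def resolve_drift_models(
    config: GraphRagConfig,
) -> tuple[LanguageModelConfig, LanguageModelConfig]:
    """Look up the chat and embedding model settings drift search will use."""
    chat_settings = config.get_language_model_config(config.drift_search.chat_model_id)
    embedding_settings = config.get_language_model_config(config.drift_search.embedding_model_id)
    return chat_settings, embedding_settings
```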
@@ -88,7 +88,7 @@ async def _execute(
) -> list[list[float]]:
async def embed(chunk: list[str]):
async with semaphore:
chunk_embeddings = await model.embed(chunk)
chunk_embeddings = await model.aembed_batch(chunk)
result = np.array(chunk_embeddings)
tick(1)
return result

@@ -166,7 +166,7 @@ class ClaimExtractor:
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
)

response = await self._model.chat(
response = await self._model.achat(
self._extraction_prompt.format(**{
self._input_text_key: doc,
**prompt_args,
@@ -177,7 +177,7 @@ class ClaimExtractor:

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
response = await self._model.chat(
response = await self._model.achat(
CONTINUE_PROMPT,
name=f"extract-continuation-{i}",
history=response.history,
@@ -191,7 +191,7 @@ class ClaimExtractor:
if i >= self._max_gleanings - 1:
break

response = await self._model.chat(
response = await self._model.achat(
LOOP_PROMPT,
name=f"extract-loopcheck-{i}",
history=response.history,

@@ -152,7 +152,7 @@ class GraphExtractor:
async def _process_document(
self, text: str, prompt_variables: dict[str, str]
) -> str:
response = await self._model.chat(
response = await self._model.achat(
self._extraction_prompt.format(**{
**prompt_variables,
self._input_text_key: text,
@@ -162,7 +162,7 @@ class GraphExtractor:

# Repeat to ensure we maximize entity count
for i in range(self._max_gleanings):
response = await self._model.chat(
response = await self._model.achat(
CONTINUE_PROMPT,
name=f"extract-continuation-{i}",
history=response.history,
@@ -173,7 +173,7 @@ class GraphExtractor:
if i >= self._max_gleanings - 1:
break

response = await self._model.chat(
response = await self._model.achat(
LOOP_PROMPT,
name=f"extract-loopcheck-{i}",
history=response.history,

@@ -78,7 +78,7 @@ class CommunityReportsExtractor:
prompt = self._extraction_prompt.replace(
"{" + self._input_text_key + "}", input_text
)
response = await self._model.chat(
response = await self._model.achat(
prompt,
json=True, # Leaving this as True to avoid creating new cache entries
name="create_community_report",

@@ -125,7 +125,7 @@ class SummarizeExtractor:
self, id: str | tuple[str, str] | list[str], descriptions: list[str]
):
"""Summarize descriptions using the LLM."""
response = await self._model.chat(
response = await self._model.achat(
self._summarization_prompt.format(**{
self._entity_name_key: json.dumps(id, ensure_ascii=False),
self._input_descriptions_key: json.dumps(

@@ -30,7 +30,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
)

try:
asyncio.run(llm.chat("This is an LLM connectivity test. Say Hello World"))
asyncio.run(llm.achat("This is an LLM connectivity test. Say Hello World"))
logger.success("LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
@@ -49,7 +49,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
)

try:
asyncio.run(embed_llm.embed(["This is an LLM Embedding Test String"]))
asyncio.run(embed_llm.aembed_batch(["This is an LLM Embedding Test String"]))
logger.success("Embedding LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa

@@ -8,6 +8,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
from collections.abc import AsyncGenerator

from graphrag.language_model.response.base import ModelResponse


@@ -18,7 +20,51 @@ class EmbeddingModel(Protocol):
This protocol defines the methods required for an embedding-based LM.
"""

async def embed(self, text: str | list[str], **kwargs: Any) -> list[list[float]]:
async def aembed_batch(
self, text_list: list[str], **kwargs: Any
) -> list[list[float]]:
"""
Generate an embedding vector for the given list of strings.

Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
-------
A collections of list of floats representing the embedding vector for each item in the batch.
"""
...

async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate an embedding vector for the given text.

Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
-------
A list of floats representing the embedding vector.
"""
...

def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""
Generate an embedding vector for the given list of strings.

Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
-------
A collections of list of floats representing the embedding vector for each item in the batch.
"""
...

def embed(self, text: str, **kwargs: Any) -> list[float]:
"""
Generate an embedding vector for the given text.

@@ -41,12 +87,15 @@ class ChatModel(Protocol):
Prompt is always required for the chat method, and any other keyword arguments are forwarded to the Model provider.
"""

async def chat(self, prompt: str, **kwargs: Any) -> ModelResponse:
async def achat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
"""
Generate a response for the given text.

Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
@@ -55,3 +104,55 @@ class ChatModel(Protocol):

"""
...

async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
"""
Generate a response for the given text using a streaming interface.

Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
-------
A generator that yields strings representing the response.
"""
...

def chat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> ModelResponse:
"""
Generate a response for the given text.

Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
-------
A string representing the response.

"""
...

def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
"""
Generate a response for the given text using a streaming interface.

Args:
prompt: The text to generate a response for.
history: The conversation history.
**kwargs: Additional keyword arguments (e.g., model parameters).

Returns
-------
A generator that yields strings representing the response.
"""
...

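A short sketch of consuming these protocols from query code; `model` may be any `ChatModel` implementation (such as the fnllm providers in the next file), and the helper itself is illustrative rather than part of this diff.

```python
# Illustrative consumer of the ChatModel protocol's streaming surface.
from graphrag.language_model.protocol.base import ChatModel


async def stream_answer(model: ChatModel, prompt: str) -> str:
    """Collect a streamed response chunk by chunk via achat_stream."""
    chunks: list[str] = []
    async for chunk in model.achat_stream(prompt, history=None):
        chunks.append(chunk)
    return "".join(chunks)
```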
@@ -3,13 +3,15 @@

"""A module containing fnllm model provider definitions."""

from collections.abc import AsyncGenerator

from fnllm.openai import (
create_openai_chat_llm,
create_openai_client,
create_openai_embeddings_llm,
)
from fnllm.types import ChatLLM as FNLLMChatLLM
from fnllm.types import EmbeddingsLLM as FNLLMEmbeddingLLM
from fnllm.openai.types.client import OpenAIChatLLM as FNLLMChatLLM
from fnllm.openai.types.client import OpenAIEmbeddingsLLM as FNLLMEmbeddingLLM

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -21,6 +23,7 @@ from graphrag.language_model.providers.fnllm.utils import (
_create_cache,
_create_error_handler,
_create_openai_config,
run_coroutine_sync,
)
from graphrag.language_model.response.base import (
BaseModelOutput,
@@ -39,21 +42,23 @@ class OpenAIChatFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, False)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=False)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_chat_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)

async def chat(self, prompt: str, **kwargs) -> ModelResponse:
async def achat(
self, prompt: str, history: list | None = None, **kwargs
) -> ModelResponse:
"""
Chat with the Model using the given prompt.

@@ -65,7 +70,10 @@ class OpenAIChatFNLLM:
-------
The response from the Model.
"""
response = await self.model(prompt, **kwargs)
if history is None:
response = await self.model(prompt, **kwargs)
else:
response = await self.model(prompt, history=history, **kwargs)
return BaseModelResponse(
output=BaseModelOutput(content=response.output.content),
parsed_response=response.parsed_json,
@@ -75,6 +83,59 @@ class OpenAIChatFNLLM:
metrics=response.metrics,
)

async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.

Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.

Returns
-------
A generator that yields strings representing the response.
"""
if history is None:
response = await self.model(prompt, stream=True, **kwargs)
else:
response = await self.model(prompt, history=history, stream=True, **kwargs)
async for chunk in response.output.content:
if chunk is not None:
yield chunk

def chat(self, prompt: str, history: list | None = None, **kwargs) -> ModelResponse:
"""
Chat with the Model using the given prompt.

Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The response from the Model.
"""
return run_coroutine_sync(self.achat(prompt, history=history, **kwargs))

def chat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.

Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.

Returns
-------
A generator that yields strings representing the response.
"""
msg = "chat_stream is not supported for synchronous execution"
raise NotImplementedError(msg)


class OpenAIEmbeddingFNLLM:
"""An OpenAI Embedding Model provider using the fnllm library."""
@@ -86,21 +147,21 @@ class OpenAIEmbeddingFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, False)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=False)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_embeddings_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)

async def embed(self, text: str | list[str], **kwargs) -> list[list[float]]:
async def aembed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.

@@ -112,13 +173,60 @@ class OpenAIEmbeddingFNLLM:
-------
The embeddings of the text.
"""
response = await self.model(text, **kwargs)
response = await self.model(text_list, **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[list[float]] = response.output.embeddings
return embeddings

async def aembed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.

Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The embeddings of the text.
"""
response = await self.model([text], **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[float] = response.output.embeddings[0]
return embeddings

def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.

Args:
text: The text to embed.
kwargs: Additional arguments to pass to the LLM.

Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed_batch(text_list, **kwargs))

def embed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.

Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed(text, **kwargs))


class AzureOpenAIChatFNLLM:
"""An Azure OpenAI Chat LLM provider using the fnllm library."""
@@ -130,21 +238,72 @@ class AzureOpenAIChatFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, True)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=True)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_chat_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)

async def chat(self, prompt: str, **kwargs) -> ModelResponse:
async def achat(
self, prompt: str, history: list | None = None, **kwargs
) -> ModelResponse:
"""
Chat with the Model using the given prompt.

Args:
prompt: The prompt to chat with.
history: The conversation history.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The response from the Model.
"""
if history is None:
response = await self.model(prompt, **kwargs)
else:
response = await self.model(prompt, history=history, **kwargs)
return BaseModelResponse(
output=BaseModelOutput(content=response.output.content),
parsed_response=response.parsed_json,
history=response.history,
cache_hit=response.cache_hit,
tool_calls=response.tool_calls,
metrics=response.metrics,
)

async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.

Args:
prompt: The prompt to chat with.
history: The conversation history.
kwargs: Additional arguments to pass to the Model.

Returns
-------
A generator that yields strings representing the response.
"""
if history is None:
response = await self.model(prompt, stream=True, **kwargs)
else:
response = await self.model(prompt, history=history, stream=True, **kwargs)
async for chunk in response.output.content:
if chunk is not None:
yield chunk

def chat(self, prompt: str, history: list | None = None, **kwargs) -> ModelResponse:
"""
Chat with the Model using the given prompt.

@@ -156,15 +315,24 @@ class AzureOpenAIChatFNLLM:
-------
The response from the Model.
"""
response = await self.model(prompt, **kwargs)
return BaseModelResponse(
output=BaseModelOutput(content=response.output.content),
parsed_response=response.parsed_json,
history=response.history,
cache_hit=response.cache_hit,
tool_calls=response.tool_calls,
metrics=response.metrics,
)
return run_coroutine_sync(self.achat(prompt, history=history, **kwargs))

def chat_stream(
self, prompt: str, history: list | None = None, **kwargs
) -> AsyncGenerator[str, None]:
"""
Stream Chat with the Model using the given prompt.

Args:
prompt: The prompt to chat with.
kwargs: Additional arguments to pass to the Model.

Returns
-------
A generator that yields strings representing the response.
"""
msg = "chat_stream is not supported for synchronous execution"
raise NotImplementedError(msg)


class AzureOpenAIEmbeddingFNLLM:
@@ -177,21 +345,21 @@ class AzureOpenAIEmbeddingFNLLM:
*,
name: str,
config: LanguageModelConfig,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
callbacks: WorkflowCallbacks | None = None,
cache: PipelineCache | None = None,
) -> None:
model_config = _create_openai_config(config, True)
error_handler = _create_error_handler(callbacks)
model_config = _create_openai_config(config, azure=True)
error_handler = _create_error_handler(callbacks) if callbacks else None
model_cache = _create_cache(cache, name)
client = create_openai_client(model_config)
self.model = create_openai_embeddings_llm(
model_config,
client=client,
cache=model_cache,
events=FNLLMEvents(error_handler),
events=FNLLMEvents(error_handler) if error_handler else None,
)

async def embed(self, text: str | list[str], **kwargs) -> list[list[float]]:
async def aembed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.

@@ -203,9 +371,56 @@ class AzureOpenAIEmbeddingFNLLM:
-------
The embeddings of the text.
"""
response = await self.model(text, **kwargs)
response = await self.model(text_list, **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[list[float]] = response.output.embeddings
return embeddings

async def aembed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.

Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The embeddings of the text.
"""
response = await self.model([text], **kwargs)
if response.output.embeddings is None:
msg = "No embeddings found in response"
raise ValueError(msg)
embeddings: list[float] = response.output.embeddings[0]
return embeddings

def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
"""
Embed the given text using the Model.

Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed_batch(text_list, **kwargs))

def embed(self, text: str, **kwargs) -> list[float]:
"""
Embed the given text using the Model.

Args:
text: The text to embed.
kwargs: Additional arguments to pass to the Model.

Returns
-------
The embeddings of the text.
"""
return run_coroutine_sync(self.aembed(text, **kwargs))

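For reference, a hedged sketch of the embedding surface these providers now expose: `aembed_batch`/`aembed` are the async entry points, while `embed_batch`/`embed` are thin sync wrappers over them. The module path in the import is an assumption; the functions are illustrative.

```python
# Illustrative usage of the fnllm embedding providers' new method names.
from graphrag.language_model.providers.fnllm.models import OpenAIEmbeddingFNLLM  # assumed module path


def embed_titles_sync(embedder: OpenAIEmbeddingFNLLM, titles: list[str]) -> list[list[float]]:
    """Synchronous batch embedding; delegates to aembed_batch via run_coroutine_sync."""
    return embedder.embed_batch(titles)


async def embed_title(embedder: OpenAIEmbeddingFNLLM, title: str) -> list[float]:
    """Async single-text embedding; the provider wraps the text in a one-element batch."""
    return await embedder.aembed(title)
```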
@@ -3,6 +3,11 @@

"""A module containing utils for fnllm."""

import asyncio
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar

from fnllm.base.config import JsonStrategy, RetryStrategy
from fnllm.openai import AzureOpenAIConfig, OpenAIConfig, PublicOpenAIConfig
from fnllm.openai.types.chat.parameters import OpenAIChatParameters
@@ -25,6 +30,8 @@ def _create_cache(cache: PipelineCache | None, name: str) -> FNLLMCacheProvider


def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
"""Create an error handler from a WorkflowCallbacks."""

def on_error(
error: BaseException | None = None,
stack: str | None = None,
@@ -36,6 +43,7 @@ def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:


def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAIConfig:
"""Create an OpenAIConfig from a LanguageModelConfig."""
encoding_model = config.encoding_model
json_strategy = (
JsonStrategy.VALID if config.model_supports_json else JsonStrategy.LOOSE
@@ -92,3 +100,28 @@ def _create_openai_config(config: LanguageModelConfig, azure: bool) -> OpenAICon
chat_parameters=chat_parameters,
sleep_on_rate_limit_recommendation=True,
)


# FNLLM does not support sync operations, so we workaround running in an available loop/thread.
T = TypeVar("T")

_loop = asyncio.new_event_loop()

_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)


def run_coroutine_sync(coroutine: Coroutine[Any, Any, T]) -> T:
"""
Run a coroutine synchronously.

Args:
coroutine: The coroutine to run.

Returns
-------
The result of the coroutine.
"""
if not _thr.is_alive():
_thr.start()
future = asyncio.run_coroutine_threadsafe(coroutine, _loop)
return future.result()

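The `run_coroutine_sync` helper above is what lets the synchronous protocol methods (`chat`, `embed`, `embed_batch`) reuse their async counterparts without an event loop in the caller: the coroutine is submitted to a dedicated background loop thread and the result is awaited via a concurrent future. A small usage sketch with a stand-in coroutine:

```python
# Illustrative use of run_coroutine_sync with a stand-in coroutine.
from graphrag.language_model.providers.fnllm.utils import run_coroutine_sync


async def _double(x: int) -> int:
    return x * 2


def double_sync(x: int) -> int:
    """Run an async helper from synchronous code, mirroring how chat()/embed() are implemented."""
    return run_coroutine_sync(_double(x))
```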
@@ -30,6 +30,6 @@ async def generate_community_report_rating(
domain=domain, persona=persona, input_text=docs_str
)

response = await model.chat(domain_prompt)
response = await model.achat(domain_prompt)

return str(response.output.content).strip()

@@ -30,6 +30,6 @@ async def generate_community_reporter_role(
domain=domain, persona=persona, input_text=docs_str
)

response = await model.chat(domain_prompt)
response = await model.achat(domain_prompt)

return str(response.output.content)

@@ -22,6 +22,6 @@ async def generate_domain(model: ChatModel, docs: str | list[str]) -> str:
docs_str = " ".join(docs) if isinstance(docs, list) else docs
domain_prompt = GENERATE_DOMAIN_PROMPT.format(input_text=docs_str)

response = await model.chat(domain_prompt)
response = await model.achat(domain_prompt)

return str(response.output.content)

@@ -57,7 +57,7 @@ async def generate_entity_relationship_examples(
messages = messages[:MAX_EXAMPLES]

tasks = [
model.chat(message, history=history, json=json_mode) for message in messages
model.achat(message, history=history, json=json_mode) for message in messages
]

responses = await asyncio.gather(*tasks)

@@ -46,11 +46,11 @@ async def generate_entity_types(
history = [{"role": "system", "content": persona}]

if json_mode:
response = await model.chat(
response = await model.achat(
entity_types_prompt, history=history, json_model=EntityTypesResponse
)
parsed_model = response.parsed_response
return parsed_model.entity_types if parsed_model else []

response = await model.chat(entity_types_prompt, history=history, json=json_mode)
response = await model.achat(entity_types_prompt, history=history, json=json_mode)
return str(response.output.content)

@@ -22,6 +22,6 @@ async def detect_language(model: ChatModel, docs: str | list[str]) -> str:
docs_str = " ".join(docs) if isinstance(docs, list) else docs
language_prompt = DETECT_LANGUAGE_PROMPT.format(input_text=docs_str)

response = await model.chat(language_prompt)
response = await model.achat(language_prompt)

return str(response.output.content)

@@ -22,6 +22,6 @@ async def generate_persona(
formatted_task = task.format(domain=domain)
persona_prompt = GENERATE_PERSONA_PROMPT.format(sample_task=formatted_task)

response = await model.chat(persona_prompt)
response = await model.achat(persona_prompt)

return str(response.output.content)

@@ -30,7 +30,9 @@ async def _embed_chunks(
) -> tuple[pd.DataFrame, np.ndarray]:
"""Convert text chunks into dense text embeddings."""
sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
embeddings = await embedding_llm.embed(sampled_text_chunks["chunks"].tolist())
embeddings = await embedding_llm.aembed_batch(
sampled_text_chunks["chunks"].tolist()
)
return text_chunks, np.array(embeddings)



@@ -54,7 +54,7 @@ class DRIFTContextBuilder(ABC):
"""Base class for DRIFT-search context builders."""

@abstractmethod
def build_context(
async def build_context(
self,
query: str,
**kwargs,

@@ -14,9 +14,9 @@ import tiktoken

from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
from graphrag.query.llm.base import BaseLLM

log = logging.getLogger(__name__)

@@ -33,7 +33,7 @@ class DynamicCommunitySelection:
self,
community_reports: list[CommunityReport],
communities: list[Community],
llm: BaseLLM,
model: ChatModel,
token_encoder: tiktoken.Encoding,
rate_query: str = RATE_QUERY,
use_summary: bool = False,
@@ -44,7 +44,7 @@ class DynamicCommunitySelection:
concurrent_coroutines: int = 8,
llm_kwargs: Any = DEFAULT_RATE_LLM_PARAMS,
):
self.llm = llm
self.model = model
self.token_encoder = token_encoder
self.rate_query = rate_query
self.num_repeats = num_repeats
@@ -98,7 +98,7 @@ class DynamicCommunitySelection:
if self.use_summary
else self.reports[community].full_content
),
llm=self.llm,
model=self.model,
token_encoder=self.token_encoder,
rate_query=self.rate_query,
num_repeats=self.num_repeats,

@@ -7,12 +7,12 @@ from enum import Enum

from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.input.retrieval.entities import (
get_entity_by_id,
get_entity_by_key,
get_entity_by_name,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores.base import BaseVectorStore


@@ -37,7 +37,7 @@ class EntityVectorStoreKey(str, Enum):
def map_query_to_entities(
query: str,
text_embedding_vectorstore: BaseVectorStore,
text_embedder: BaseTextEmbedding,
text_embedder: EmbeddingModel,
all_entities_dict: dict[str, Entity],
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
include_entity_names: list[str] | None = None,

@@ -11,8 +11,8 @@ from typing import Any
import numpy as np
import tiktoken

from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object

log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
async def rate_relevancy(
query: str,
description: str,
llm: BaseLLM,
model: ChatModel,
token_encoder: tiktoken.Encoding,
rate_query: str = RATE_QUERY,
num_repeats: int = 1,
@@ -47,11 +47,13 @@ async def rate_relevancy(
"role": "system",
"content": rate_query.format(description=description, question=query),
},
{"role": "user", "content": query},
]
for _ in range(num_repeats):
async with semaphore if semaphore is not None else nullcontext():
response = await llm.agenerate(messages=messages, **llm_kwargs)
model_response = await model.achat(
prompt=query, history=messages, model_parameters=llm_kwargs, json=True
)
response = model_response.output.content
try:
_, parsed_response = try_parse_json_object(response)
ratings.append(parsed_response["rating"])

@@ -13,8 +13,8 @@ from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.manager import ModelManager
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.get_client import get_llm, get_text_embedder
from graphrag.query.structured_search.basic_search.basic_context import (
BasicSearchContext,
)
@@ -47,15 +47,38 @@ def get_local_search_engine(
callbacks: list[QueryCallbacks] | None = None,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
model_settings = config.get_language_model_config(config.local_search.chat_model_id)

if model_settings.max_retries == -1:
model_settings.max_retries = (
len(reports) + len(entities) + len(relationships) + len(covariates)
)

chat_model = ModelManager().get_or_create_chat_model(
name="local_search_chat",
model_type=model_settings.type,
config=model_settings,
)

embedding_settings = config.get_language_model_config(
config.local_search.embedding_model_id
)
if embedding_settings.max_retries == -1:
embedding_settings.max_retries = (
len(reports) + len(entities) + len(relationships)
)
embedding_model = ModelManager().get_or_create_embedding_model(
name="local_search_embedding",
model_type=embedding_settings.type,
config=embedding_settings,
)

token_encoder = tiktoken.get_encoding(model_settings.encoding_model)

ls_config = config.local_search

return LocalSearch(
llm=llm,
model=chat_model,
system_prompt=system_prompt,
context_builder=LocalSearchMixedContext(
community_reports=reports,
@@ -65,11 +88,11 @@ def get_local_search_engine(
covariates=covariates,
entity_text_embeddings=description_embedding_store,
embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
text_embedder=text_embedder,
text_embedder=embedding_model,
token_encoder=token_encoder,
),
token_encoder=token_encoder,
llm_params={
model_params={
"max_tokens": ls_config.llm_max_tokens, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
"temperature": ls_config.temperature,
"top_p": ls_config.top_p,
@@ -108,10 +131,20 @@ def get_global_search_engine(
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
# TODO: Global search should select model based on config??
default_llm_settings = config.get_language_model_config("default_chat_model")
model_settings = config.get_language_model_config(
config.global_search.chat_model_id
)

if model_settings.max_retries == -1:
model_settings.max_retries = len(reports) + len(entities)
model = ModelManager().get_or_create_chat_model(
name="global_search",
model_type=model_settings.type,
config=model_settings,
)

# Here we get encoding based on specified encoding name
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
token_encoder = tiktoken.get_encoding(model_settings.encoding_model)
gs_config = config.global_search

dynamic_community_selection_kwargs = {}
@@ -119,9 +152,9 @@ def get_global_search_engine(
# TODO: Allow for another llm definition only for Global Search to leverage -mini models

dynamic_community_selection_kwargs.update({
"llm": get_llm(config),
"model": model,
# And here we get encoding based on model
"token_encoder": tiktoken.encoding_for_model(default_llm_settings.model),
"token_encoder": tiktoken.encoding_for_model(model_settings.model),
"keep_parent": gs_config.dynamic_search_keep_parent,
"num_repeats": gs_config.dynamic_search_num_repeats,
"use_summary": gs_config.dynamic_search_use_summary,
@@ -131,7 +164,7 @@ def get_global_search_engine(
})

return GlobalSearch(
llm=get_llm(config),
model=model,
map_system_prompt=map_system_prompt,
reduce_system_prompt=reduce_system_prompt,
general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt,
@@ -190,16 +223,43 @@ def get_drift_search_engine(
callbacks: list[QueryCallbacks] | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
chat_model_settings = config.get_language_model_config(
config.drift_search.chat_model_id
)

if chat_model_settings.max_retries == -1:
chat_model_settings.max_retries = (
config.drift_search.drift_k_followups
* config.drift_search.primer_folds
* config.drift_search.n_depth
)

chat_model = ModelManager().get_or_create_chat_model(
name="drift_search_chat",
model_type=chat_model_settings.type,
config=chat_model_settings,
)

embedding_model_settings = config.get_language_model_config(
config.drift_search.embedding_model_id
)
if embedding_model_settings.max_retries == -1:
embedding_model_settings.max_retries = (
len(reports) + len(entities) + len(relationships)
)

embedding_model = ModelManager().get_or_create_embedding_model(
name="drift_search_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)

return DRIFTSearch(
llm=llm,
model=chat_model,
context_builder=DRIFTSearchContextBuilder(
chat_llm=llm,
text_embedder=text_embedder,
model=chat_model,
text_embedder=embedding_model,
entities=entities,
relationships=relationships,
reports=reports,
@@ -223,18 +283,40 @@ def get_basic_search_engine(
callbacks: list[QueryCallbacks] | None = None,
) -> BasicSearch:
"""Create a basic search engine based on data + configuration."""
default_llm_settings = config.get_language_model_config("default_chat_model")
llm = get_llm(config)
text_embedder = get_text_embedder(config)
token_encoder = tiktoken.get_encoding(default_llm_settings.encoding_model)
chat_model_settings = config.get_language_model_config(
config.basic_search.chat_model_id
)

if chat_model_settings.max_retries == -1:
chat_model_settings.max_retries = len(text_units)

chat_model = ModelManager().get_or_create_chat_model(
name="basic_search_chat",
model_type=chat_model_settings.type,
config=chat_model_settings,
)

embedding_model_settings = config.get_language_model_config(
config.basic_search.embedding_model_id
)
if embedding_model_settings.max_retries == -1:
embedding_model_settings.max_retries = len(text_units)

embedding_model = ModelManager().get_or_create_embedding_model(
name="basic_search_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)

token_encoder = tiktoken.get_encoding(chat_model_settings.encoding_model)

ls_config = config.basic_search

return BasicSearch(
llm=llm,
model=chat_model,
system_prompt=system_prompt,
context_builder=BasicSearchContext(
text_embedder=text_embedder,
text_embedder=embedding_model,
text_unit_embeddings=text_unit_embeddings,
text_units=text_units,
token_encoder=token_encoder,

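A condensed sketch of the model wiring the factory functions above now perform, shown for basic search's embedding model. The standalone function is illustrative, but the config IDs, the `-1` max_retries fallback, and the `ModelManager` call mirror this diff.

```python
# Illustrative condensation of the ModelManager wiring used by the search factories above.
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel


def build_basic_search_embedder(config: GraphRagConfig, text_unit_count: int) -> EmbeddingModel:
    """Resolve basic search's embedding model with the same max_retries fallback as the factory."""
    settings = config.get_language_model_config(config.basic_search.embedding_model_id)
    if settings.max_retries == -1:  # -1 means "derive from workload size" in the factory code
        settings.max_retries = text_unit_count
    return ModelManager().get_or_create_embedding_model(
        name="basic_search_embedding",
        model_type=settings.type,
        config=settings,
    )
```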
@@ -18,7 +18,8 @@ from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.query.factory import get_text_embedder
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.input.loaders.dfs import (
read_communities,
read_community_reports,
@@ -27,7 +28,6 @@ from graphrag.query.input.loaders.dfs import (
read_relationships,
read_text_units,
)
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.vector_stores.base import BaseVectorStore

log = logging.getLogger(__name__)
@@ -106,7 +106,15 @@ def read_indexer_reports(
content_embedding_col not in reports_df.columns
or reports_df.loc[:, content_embedding_col].isna().any()
):
embedder = get_text_embedder(config)
# TODO: Find a way to retrieve the right embedding model id.
embedding_model_settings = config.get_language_model_config(
"default_embedding_model"
)
embedder = ModelManager().get_or_create_embedding_model(
name="default_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
reports_df = embed_community_reports(
reports_df, embedder, embedding_col=content_embedding_col
)
@@ -211,7 +219,7 @@ def read_indexer_communities(

def embed_community_reports(
reports_df: pd.DataFrame,
embedder: OpenAIEmbedding,
embedder: EmbeddingModel,
source_col: str = "full_content",
embedding_col: str = "full_content_embedding",
) -> pd.DataFrame:

@@ -1,65 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Base classes for LLM and Embedding models."""

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from typing import Any

from graphrag.callbacks.llm_callbacks import BaseLLMCallback


class BaseLLM(ABC):
"""The Base LLM implementation."""

@abstractmethod
def generate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate a response."""

@abstractmethod
def stream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> Generator[str, None, None]:
"""Generate a response with streaming."""

@abstractmethod
async def agenerate(
self,
messages: str | list[Any],
streaming: bool = True,
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> str:
"""Generate a response asynchronously."""

@abstractmethod
async def astream_generate(
self,
messages: str | list[Any],
callbacks: list[BaseLLMCallback] | None = None,
**kwargs: Any,
) -> AsyncGenerator[str, None]:
"""Generate a response asynchronously with streaming."""
...


class BaseTextEmbedding(ABC):
"""The text embedding interface."""

@abstractmethod
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""Embed a text string."""

@abstractmethod
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""Embed a text string asynchronously."""
@@ -1,82 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Initialize LLM and Embedding clients."""

from azure.identity import DefaultAzureCredential, get_bearer_token_provider

from graphrag.config.defaults import language_model_defaults
from graphrag.config.enums import AuthType, ModelType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType


def get_llm(config: GraphRagConfig) -> ChatOpenAI:
"""Get the LLM client."""
llm_config = config.get_language_model_config("default_chat_model")
is_azure_client = llm_config.type == ModelType.AzureOpenAIChat
debug_llm_key = llm_config.api_key or ""
llm_debug_info = {
**llm_config.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
audience = (
llm_config.audience
if llm_config.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=llm_config.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and llm_config.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=llm_config.api_base,
organization=llm_config.organization,
model=llm_config.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=llm_config.deployment_name,
api_version=llm_config.api_version,
max_retries=llm_config.max_retries
if llm_config.max_retries != -1
else language_model_defaults.max_retries,
request_timeout=llm_config.request_timeout,
)


def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
"""Get the LLM client for embeddings."""
embeddings_llm_config = config.get_language_model_config(config.embed_text.model_id)
is_azure_client = embeddings_llm_config.type == ModelType.AzureOpenAIEmbedding
debug_embedding_api_key = embeddings_llm_config.api_key or ""
llm_debug_info = {
**embeddings_llm_config.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if embeddings_llm_config.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
audience = embeddings_llm_config.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=embeddings_llm_config.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client
and embeddings_llm_config.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=embeddings_llm_config.api_base,
organization=embeddings_llm_config.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=embeddings_llm_config.model,
deployment_name=embeddings_llm_config.deployment_name,
api_version=embeddings_llm_config.api_version,
max_retries=embeddings_llm_config.max_retries
if embeddings_llm_config.max_retries != -1
else language_model_defaults.max_retries,
)
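With this file removed, external callers of `get_llm`/`get_text_embedder` need an equivalent built on `ModelManager`. A hedged migration sketch; the registration name is a placeholder, while `"default_chat_model"` is the same ID the deleted helper used.

```python
# Hedged migration sketch for callers of the removed get_llm()/get_text_embedder() helpers.
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager


def ask_sync(config: GraphRagConfig, question: str) -> str:
    """Synchronous one-shot chat, roughly replacing the old ChatOpenAI-based client."""
    settings = config.get_language_model_config("default_chat_model")
    model = ModelManager().get_or_create_chat_model(
        name="migration_example",  # hypothetical registration name
        model_type=settings.type,
        config=settings,
    )
    response = model.chat(question)  # sync wrapper; delegates to achat via run_coroutine_sync
    return str(response.output.content)
```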
@@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""GraphRAG Orchestration OpenAI Wrappers."""
@@ -1,188 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Base classes for LLM and Embedding models."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
from graphrag.logger.base import StatusLogger
|
||||
from graphrag.logger.console import ConsoleReporter
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.query.llm.oai.typing import OpenaiApiType
|
||||
|
||||
|
||||
class BaseOpenAILLM(ABC):
|
||||
"""The Base OpenAI LLM implementation."""
|
||||
|
||||
_async_client: AsyncOpenAI | AsyncAzureOpenAI
|
||||
_sync_client: OpenAI | AzureOpenAI
|
||||
|
||||
def __init__(self):
|
||||
self._create_openai_client()
|
||||
|
||||
@abstractmethod
|
||||
def _create_openai_client(self):
|
||||
"""Create a new synchronous and asynchronous OpenAI client instance."""
|
||||
|
||||
def set_clients(
|
||||
self,
|
||||
sync_client: OpenAI | AzureOpenAI,
|
||||
async_client: AsyncOpenAI | AsyncAzureOpenAI,
|
||||
):
|
||||
"""
|
||||
Set the synchronous and asynchronous clients used for making API requests.
|
||||
|
||||
Args:
|
||||
sync_client (OpenAI | AzureOpenAI): The sync client object.
|
||||
async_client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
|
||||
"""
|
||||
self._sync_client = sync_client
|
||||
self._async_client = async_client
|
||||
|
||||
@property
|
||||
def async_client(self) -> AsyncOpenAI | AsyncAzureOpenAI | None:
|
||||
"""
|
||||
Get the asynchronous client used for making API requests.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AsyncOpenAI | AsyncAzureOpenAI: The async client object.
|
||||
"""
|
||||
return self._async_client
|
||||
|
||||
@property
|
||||
def sync_client(self) -> OpenAI | AzureOpenAI | None:
|
||||
"""
|
||||
Get the synchronous client used for making API requests.
|
||||
|
||||
Returns
|
||||
-------
|
||||
OpenAI | AzureOpenAI: The sync client object.
OpenAI | AzureOpenAI: The sync client object.
|
||||
"""
|
||||
return self._sync_client
|
||||
|
||||
@async_client.setter
|
||||
def async_client(self, client: AsyncOpenAI | AsyncAzureOpenAI):
|
||||
"""
|
||||
Set the asynchronous client used for making API requests.
|
||||
|
||||
Args:
|
||||
client (AsyncOpenAI | AsyncAzureOpenAI): The async client object.
|
||||
"""
|
||||
self._async_client = client
|
||||
|
||||
@sync_client.setter
|
||||
def sync_client(self, client: OpenAI | AzureOpenAI):
|
||||
"""
|
||||
Set the synchronous client used for making API requests.
|
||||
|
||||
Args:
|
||||
client (OpenAI | AzureOpenAI): The sync client object.
|
||||
"""
|
||||
self._sync_client = client
|
||||
|
||||
|
||||
class OpenAILLMImpl(BaseOpenAILLM):
|
||||
"""Orchestration OpenAI LLM Implementation."""
|
||||
|
||||
_reporter: StatusLogger = ConsoleReporter()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token_provider: Callable | None = None,
|
||||
deployment_name: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
|
||||
organization: str | None = None,
|
||||
max_retries: int = 10,
|
||||
request_timeout: float = 180.0,
|
||||
logger: StatusLogger | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.azure_ad_token_provider = azure_ad_token_provider
|
||||
self.deployment_name = deployment_name
|
||||
self.api_base = api_base
|
||||
self.api_version = api_version
|
||||
self.api_type = api_type
|
||||
self.organization = organization
|
||||
self.max_retries = max_retries
|
||||
self.request_timeout = request_timeout
|
||||
self.logger = logger or ConsoleReporter()
|
||||
|
||||
try:
|
||||
# Create OpenAI sync and async clients
|
||||
super().__init__()
|
||||
except Exception as e:
|
||||
self._reporter.error(
|
||||
message="Failed to create OpenAI client",
|
||||
details={self.__class__.__name__: str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
def _create_openai_client(self):
|
||||
"""Create a new OpenAI client instance."""
|
||||
if self.api_type == OpenaiApiType.AzureOpenAI:
|
||||
if self.api_base is None:
|
||||
msg = "api_base is required for Azure OpenAI"
|
||||
raise ValueError(msg)
|
||||
|
||||
sync_client = AzureOpenAI(
|
||||
api_key=self.api_key,
|
||||
azure_ad_token_provider=self.azure_ad_token_provider,
|
||||
organization=self.organization,
|
||||
# Azure-Specifics
|
||||
api_version=self.api_version,
|
||||
azure_endpoint=self.api_base,
|
||||
azure_deployment=self.deployment_name,
|
||||
# Retry Configuration
|
||||
timeout=self.request_timeout,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
|
||||
async_client = AsyncAzureOpenAI(
|
||||
api_key=self.api_key,
|
||||
azure_ad_token_provider=self.azure_ad_token_provider,
|
||||
organization=self.organization,
|
||||
# Azure-Specifics
|
||||
api_version=self.api_version,
|
||||
azure_endpoint=self.api_base,
|
||||
azure_deployment=self.deployment_name,
|
||||
# Retry Configuration
|
||||
timeout=self.request_timeout,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
self.set_clients(sync_client=sync_client, async_client=async_client)
|
||||
|
||||
else:
|
||||
sync_client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
organization=self.organization,
|
||||
# Retry Configuration
|
||||
timeout=self.request_timeout,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
|
||||
async_client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
organization=self.organization,
|
||||
# Retry Configuration
|
||||
timeout=self.request_timeout,
|
||||
max_retries=self.max_retries,
|
||||
)
|
||||
self.set_clients(sync_client=sync_client, async_client=async_client)
|
||||
|
||||
|
||||
class OpenAITextEmbeddingImpl(BaseTextEmbedding):
|
||||
"""Orchestration OpenAI Text Embedding Implementation."""
|
||||
|
||||
_reporter: StatusLogger | None = None
|
||||
|
||||
def _create_openai_client(self, api_type: OpenaiApiType):
|
||||
"""Create a new synchronous and asynchronous OpenAI client instance."""
|
||||
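For context on the file deleted above, a minimal self-contained sketch of the dual sync/async client pattern it encapsulated, using only the public openai SDK constructors shown in that file (all values are placeholders):

from openai import AsyncOpenAI, OpenAI

def make_clients(api_key: str, base_url: str | None = None,
                 timeout: float = 180.0, max_retries: int = 10):
    """Build a matching pair of synchronous and asynchronous clients."""
    sync_client = OpenAI(api_key=api_key, base_url=base_url,
                         timeout=timeout, max_retries=max_retries)
    async_client = AsyncOpenAI(api_key=api_key, base_url=base_url,
                               timeout=timeout, max_retries=max_retries)
    return sync_client, async_client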
@@ -1,329 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Chat-based OpenAI LLM implementation."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from typing import Any
|
||||
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
RetryError,
|
||||
Retrying,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from graphrag.logger.base import StatusLogger
|
||||
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
|
||||
from graphrag.query.llm.oai.base import OpenAILLMImpl
|
||||
from graphrag.query.llm.oai.typing import (
|
||||
OPENAI_RETRY_ERROR_TYPES,
|
||||
OpenaiApiType,
|
||||
)
|
||||
|
||||
_MODEL_REQUIRED_MSG = "model is required"
|
||||
|
||||
|
||||
class ChatOpenAI(BaseLLM, OpenAILLMImpl):
|
||||
"""Wrapper for OpenAI ChatCompletion models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str | None = None,
|
||||
azure_ad_token_provider: Callable | None = None,
|
||||
deployment_name: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
|
||||
organization: str | None = None,
|
||||
max_retries: int = 10,
|
||||
request_timeout: float = 180.0,
|
||||
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
|
||||
logger: StatusLogger | None = None,
|
||||
):
|
||||
OpenAILLMImpl.__init__(
|
||||
self=self,
|
||||
api_key=api_key,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
deployment_name=deployment_name,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_type=api_type, # type: ignore
|
||||
organization=organization,
|
||||
max_retries=max_retries,
|
||||
request_timeout=request_timeout,
|
||||
logger=logger,
|
||||
)
|
||||
self.model = model
|
||||
self.retry_error_types = retry_error_types
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate text."""
|
||||
try:
|
||||
retryer = Retrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types),
|
||||
)
|
||||
for attempt in retryer:
|
||||
with attempt:
|
||||
return self._generate(
|
||||
messages=messages,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
except RetryError as e:
|
||||
self._reporter.error(
|
||||
message="Error at generate()", details={self.__class__.__name__: str(e)}
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
# TODO: why not just throw in this case?
|
||||
return ""
|
||||
|
||||
def stream_generate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Generate text with streaming."""
|
||||
try:
|
||||
retryer = Retrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types),
|
||||
)
|
||||
for attempt in retryer:
|
||||
with attempt:
|
||||
generator = self._stream_generate(
|
||||
messages=messages,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
yield from generator
|
||||
|
||||
except RetryError as e:
|
||||
self._reporter.error(
|
||||
message="Error at stream_generate()",
|
||||
details={self.__class__.__name__: str(e)},
|
||||
)
|
||||
return
|
||||
else:
|
||||
return
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate text asynchronously."""
|
||||
try:
|
||||
retryer = AsyncRetrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types), # type: ignore
|
||||
)
|
||||
async for attempt in retryer:
|
||||
with attempt:
|
||||
return await self._agenerate(
|
||||
messages=messages,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
except RetryError as e:
|
||||
self._reporter.error(f"Error at agenerate(): {e}")
|
||||
return ""
|
||||
else:
|
||||
# TODO: why not just throw in this case?
|
||||
return ""
|
||||
|
||||
async def astream_generate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate text asynchronously with streaming."""
|
||||
try:
|
||||
retryer = AsyncRetrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types), # type: ignore
|
||||
)
|
||||
async for attempt in retryer:
|
||||
with attempt:
|
||||
generator = self._astream_generate(
|
||||
messages=messages,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
async for response in generator:
|
||||
yield response
|
||||
except RetryError as e:
|
||||
self._reporter.error(f"Error at astream_generate(): {e}")
|
||||
return
|
||||
else:
|
||||
return
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
model = self.model
|
||||
if not model:
|
||||
raise ValueError(_MODEL_REQUIRED_MSG)
|
||||
response = self.sync_client.chat.completions.create( # type: ignore
|
||||
model=model,
|
||||
messages=messages, # type: ignore
|
||||
stream=streaming,
|
||||
**kwargs,
|
||||
) # type: ignore
|
||||
if streaming:
|
||||
full_response = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = response.__next__() # type: ignore
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = (
|
||||
chunk.choices[0].delta.content
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content
|
||||
else ""
|
||||
) # type: ignore
|
||||
|
||||
full_response += delta
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_llm_new_token(delta)
|
||||
if chunk.choices[0].finish_reason == "stop": # type: ignore
|
||||
break
|
||||
except StopIteration:
|
||||
break
|
||||
return full_response
|
||||
return response.choices[0].message.content or "" # type: ignore
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Generator[str, None, None]:
|
||||
model = self.model
|
||||
if not model:
|
||||
raise ValueError(_MODEL_REQUIRED_MSG)
|
||||
response = self.sync_client.chat.completions.create( # type: ignore
|
||||
model=model,
|
||||
messages=messages, # type: ignore
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
for chunk in response:
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = (
|
||||
chunk.choices[0].delta.content
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content
|
||||
else ""
|
||||
)
|
||||
|
||||
yield delta
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_llm_new_token(delta)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
model = self.model
|
||||
if not model:
|
||||
raise ValueError(_MODEL_REQUIRED_MSG)
|
||||
response = await self.async_client.chat.completions.create( # type: ignore
|
||||
model=model,
|
||||
messages=messages, # type: ignore
|
||||
stream=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
if streaming:
|
||||
full_response = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = await response.__anext__() # type: ignore
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = (
|
||||
chunk.choices[0].delta.content
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content
|
||||
else ""
|
||||
) # type: ignore
|
||||
|
||||
full_response += delta
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_llm_new_token(delta)
|
||||
if chunk.choices[0].finish_reason == "stop": # type: ignore
|
||||
break
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
return full_response
|
||||
|
||||
return response.choices[0].message.content or "" # type: ignore
|
||||
|
||||
async def _astream_generate(
|
||||
self,
|
||||
messages: str | list[Any],
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
model = self.model
|
||||
if not model:
|
||||
raise ValueError(_MODEL_REQUIRED_MSG)
|
||||
response = await self.async_client.chat.completions.create( # type: ignore
|
||||
model=model,
|
||||
messages=messages, # type: ignore
|
||||
stream=True,
|
||||
**kwargs,
|
||||
)
|
||||
async for chunk in response:
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = (
|
||||
chunk.choices[0].delta.content
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content
|
||||
else ""
|
||||
) # type: ignore
|
||||
|
||||
yield delta
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_llm_new_token(delta)
|
||||
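A sketch of how the removed ChatOpenAI wrapper was typically driven: a callback receives each streamed token while generate() accumulates the full answer. PrintingCallback is a hypothetical stand-in for a BaseLLMCallback subclass; `chat` is assumed to be a ChatOpenAI instance built with the constructor shown above.

class PrintingCallback:
    """Minimal stand-in callback that prints streamed tokens as they arrive."""

    def on_llm_new_token(self, token: str) -> None:
        print(token, end="", flush=True)


def run_streaming_chat(chat, question: str) -> str:
    """Drive a ChatOpenAI-style wrapper: stream tokens, return the full answer."""
    return chat.generate(
        messages=[{"role": "user", "content": question}],
        streaming=True,
        callbacks=[PrintingCallback()],
    )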
@@ -1,183 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""OpenAI Embedding model implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
RetryError,
|
||||
Retrying,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.logger.base import StatusLogger
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.query.llm.oai.base import OpenAILLMImpl
|
||||
from graphrag.query.llm.oai.typing import (
|
||||
OPENAI_RETRY_ERROR_TYPES,
|
||||
OpenaiApiType,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import chunk_text
|
||||
|
||||
|
||||
class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
|
||||
"""Wrapper for OpenAI Embedding models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token_provider: Callable | None = None,
|
||||
model: str = "text-embedding-3-small",
|
||||
deployment_name: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
|
||||
organization: str | None = None,
|
||||
encoding_name: str = defs.ENCODING_MODEL,
|
||||
max_tokens: int = 8191,
|
||||
max_retries: int = 10,
|
||||
request_timeout: float = 180.0,
|
||||
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
|
||||
logger: StatusLogger | None = None,
|
||||
):
|
||||
OpenAILLMImpl.__init__(
|
||||
self=self,
|
||||
api_key=api_key,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
deployment_name=deployment_name,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_type=api_type, # type: ignore
|
||||
organization=organization,
|
||||
max_retries=max_retries,
|
||||
request_timeout=request_timeout,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.encoding_name = encoding_name
|
||||
self.max_tokens = max_tokens
|
||||
self.token_encoder = tiktoken.get_encoding(self.encoding_name)
|
||||
self.retry_error_types = retry_error_types
|
||||
|
||||
def embed(self, text: str, **kwargs: Any) -> list[float]:
|
||||
"""
|
||||
Embed text using OpenAI Embedding's sync function.
|
||||
|
||||
For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average.
|
||||
Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
|
||||
"""
|
||||
token_chunks = chunk_text(
|
||||
text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
|
||||
)
|
||||
chunk_embeddings = []
|
||||
chunk_lens = []
|
||||
for chunk in token_chunks:
|
||||
try:
|
||||
embedding, chunk_len = self._embed_with_retry(chunk, **kwargs)
|
||||
chunk_embeddings.append(embedding)
|
||||
chunk_lens.append(chunk_len)
|
||||
# TODO: catch a more specific exception
|
||||
except Exception as e: # noqa BLE001
|
||||
self._reporter.error(
|
||||
message="Error embedding chunk",
|
||||
details={self.__class__.__name__: str(e)},
|
||||
)
|
||||
|
||||
continue
|
||||
chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
|
||||
chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
|
||||
return chunk_embeddings.tolist()
|
||||
|
||||
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
|
||||
"""
|
||||
Embed text using OpenAI Embedding's async function.
|
||||
|
||||
For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average.
|
||||
"""
|
||||
token_chunks = chunk_text(
|
||||
text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
|
||||
)
|
||||
chunk_embeddings = []
|
||||
chunk_lens = []
|
||||
embedding_results = await asyncio.gather(*[
|
||||
self._aembed_with_retry(chunk, **kwargs) for chunk in token_chunks
|
||||
])
|
||||
embedding_results = [result for result in embedding_results if result[0]]
|
||||
chunk_embeddings = [result[0] for result in embedding_results]
|
||||
chunk_lens = [result[1] for result in embedding_results]
|
||||
chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens) # type: ignore
|
||||
chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
|
||||
return chunk_embeddings.tolist()
|
||||
|
||||
def _embed_with_retry(
|
||||
self, text: str | tuple, **kwargs: Any
|
||||
) -> tuple[list[float], int]:
|
||||
try:
|
||||
retryer = Retrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types),
|
||||
)
|
||||
for attempt in retryer:
|
||||
with attempt:
|
||||
embedding = (
|
||||
self.sync_client.embeddings.create( # type: ignore
|
||||
input=text,
|
||||
model=self.model,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
.data[0]
|
||||
.embedding
|
||||
or []
|
||||
)
|
||||
return (embedding, len(text))
|
||||
except RetryError as e:
|
||||
self._reporter.error(
|
||||
message="Error at embed_with_retry()",
|
||||
details={self.__class__.__name__: str(e)},
|
||||
)
|
||||
return ([], 0)
|
||||
else:
|
||||
# TODO: why not just throw in this case?
|
||||
return ([], 0)
|
||||
|
||||
async def _aembed_with_retry(
|
||||
self, text: str | tuple, **kwargs: Any
|
||||
) -> tuple[list[float], int]:
|
||||
try:
|
||||
retryer = AsyncRetrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types),
|
||||
)
|
||||
async for attempt in retryer:
|
||||
with attempt:
|
||||
embedding = (
|
||||
await self.async_client.embeddings.create( # type: ignore
|
||||
input=text,
|
||||
model=self.model,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
).data[0].embedding or []
|
||||
return (embedding, len(text))
|
||||
except RetryError as e:
|
||||
self._reporter.error(
|
||||
message="Error at embed_with_retry()",
|
||||
details={self.__class__.__name__: str(e)},
|
||||
)
|
||||
return ([], 0)
|
||||
else:
|
||||
# TODO: why not just throw in this case?
|
||||
return ([], 0)
|
||||
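A self-contained illustration of the combine step used by embed()/aembed() above: per-chunk embeddings are averaged weighted by chunk length, then L2-normalized. Dummy vectors stand in for real model output.

import numpy as np

chunk_embeddings = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
chunk_lens = [30, 10]  # e.g. token counts per chunk

# Weighted average across chunks, then normalize to unit length.
combined = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
combined = combined / np.linalg.norm(combined)
print(combined.tolist())  # a single vector representing the whole text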
@@ -1,187 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""OpenAI Wrappers for Orchestration."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
RetryError,
|
||||
Retrying,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from graphrag.query.llm.base import BaseLLMCallback
|
||||
from graphrag.query.llm.oai.base import OpenAILLMImpl
|
||||
from graphrag.query.llm.oai.typing import (
|
||||
OPENAI_RETRY_ERROR_TYPES,
|
||||
OpenaiApiType,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAI(OpenAILLMImpl):
|
||||
"""Wrapper for OpenAI Completion models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str,
|
||||
deployment_name: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_type: OpenaiApiType = OpenaiApiType.OpenAI,
|
||||
organization: str | None = None,
|
||||
max_retries: int = 10,
|
||||
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
self.api_base = api_base
|
||||
self.api_version = api_version
|
||||
self.api_type = api_type
|
||||
self.organization = organization
|
||||
self.max_retries = max_retries
|
||||
self.retry_error_types = retry_error_types
|
||||
|
||||
def generate(
|
||||
self,
|
||||
messages: str | list[str],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate text."""
|
||||
try:
|
||||
retryer = Retrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types),
|
||||
)
|
||||
for attempt in retryer:
|
||||
with attempt:
|
||||
return self._generate(
|
||||
messages=messages,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
except RetryError:
|
||||
log.exception("RetryError at generate(): %s")
|
||||
return ""
|
||||
else:
|
||||
# TODO: why not just throw in this case?
|
||||
return ""
|
||||
|
||||
async def agenerate(
|
||||
self,
|
||||
messages: str | list[str],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate Text Asynchronously."""
|
||||
try:
|
||||
retryer = AsyncRetrying(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential_jitter(max=10),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(self.retry_error_types),
|
||||
)
|
||||
async for attempt in retryer:
|
||||
with attempt:
|
||||
return await self._agenerate(
|
||||
messages=messages,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
except RetryError:
|
||||
log.exception("Error at agenerate()")
|
||||
return ""
|
||||
else:
|
||||
# TODO: why not just throw in this case?
|
||||
return ""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: str | list[str],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
response = self.sync_client.chat.completions.create( # type: ignore
|
||||
model=self.model,
|
||||
messages=messages, # type: ignore
|
||||
stream=streaming,
|
||||
**kwargs,
|
||||
) # type: ignore
|
||||
if streaming:
|
||||
full_response = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = response.__next__() # type: ignore
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = (
|
||||
chunk.choices[0].delta.content
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content
|
||||
else ""
|
||||
) # type: ignore
|
||||
|
||||
full_response += delta
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_llm_new_token(delta)
|
||||
if chunk.choices[0].finish_reason == "stop": # type: ignore
|
||||
break
|
||||
except StopIteration:
|
||||
break
|
||||
return full_response
|
||||
return response.choices[0].message.content or "" # type: ignore
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: str | list[str],
|
||||
streaming: bool = True,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
response = await self.async_client.chat.completions.create( # type: ignore
|
||||
model=self.model,
|
||||
messages=messages, # type: ignore
|
||||
stream=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
if streaming:
|
||||
full_response = ""
|
||||
while True:
|
||||
try:
|
||||
chunk = await response.__anext__() # type: ignore
|
||||
if not chunk or not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = (
|
||||
chunk.choices[0].delta.content
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content
|
||||
else ""
|
||||
) # type: ignore
|
||||
|
||||
full_response += delta
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_llm_new_token(delta)
|
||||
if chunk.choices[0].finish_reason == "stop": # type: ignore
|
||||
break
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
return full_response
|
||||
return response.choices[0].message.content or "" # type: ignore
|
||||
@@ -1,27 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""OpenAI wrapper options."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
||||
OPENAI_RETRY_ERROR_TYPES = (
|
||||
# TODO: update these when we update to OpenAI 1+ library
|
||||
cast("Any", openai).RateLimitError,
|
||||
cast("Any", openai).APIConnectionError,
|
||||
cast("Any", openai).APIError,
|
||||
cast("Any", httpx).RemoteProtocolError,
|
||||
cast("Any", httpx).ReadTimeout,
|
||||
# TODO: replace with comparable OpenAI 1+ error
|
||||
)
|
||||
|
||||
|
||||
class OpenaiApiType(str, Enum):
|
||||
"""The OpenAI Flavor."""
|
||||
|
||||
OpenAI = "openai"
|
||||
AzureOpenAI = "azure"
|
||||
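A sketch of how a retry-error tuple like OPENAI_RETRY_ERROR_TYPES is consumed elsewhere in this diff: tenacity retries only on the listed transient error types, with exponential jitter and a retry cap. The helper name is illustrative, not part of the codebase.

from tenacity import (
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

def call_with_retries(fn, retry_error_types: tuple[type[BaseException], ...],
                      max_retries: int = 10):
    """Run fn(), retrying only on the given exception types."""
    retryer = Retrying(
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential_jitter(max=10),
        reraise=True,
        retry=retry_if_exception_type(retry_error_types),
    )
    for attempt in retryer:
        with attempt:
            return fn()
    return None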
@@ -9,11 +9,11 @@ from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.builders import (
|
||||
GlobalContextBuilder,
|
||||
LocalContextBuilder,
|
||||
)
|
||||
from graphrag.query.llm.base import BaseLLM
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -32,20 +32,20 @@ class BaseQuestionGen(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
model: ChatModel,
|
||||
context_builder: GlobalContextBuilder | LocalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
llm_params: dict[str, Any] | None = None,
|
||||
context_builder_params: dict[str, Any] | None = None,
|
||||
):
|
||||
self.llm = llm
|
||||
self.model = model
|
||||
self.context_builder = context_builder
|
||||
self.token_encoder = token_encoder
|
||||
self.llm_params = llm_params or {}
|
||||
self.context_builder_params = context_builder_params or {}
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
async def generate(
|
||||
self,
|
||||
question_history: list[str],
|
||||
context_data: str | None,
|
||||
|
||||
@@ -9,6 +9,8 @@ from typing import Any, cast
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
|
||||
from graphrag.query.context_builder.builders import (
|
||||
ContextBuilderResult,
|
||||
@@ -17,7 +19,6 @@ from graphrag.query.context_builder.builders import (
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
|
||||
|
||||
@@ -29,7 +30,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
model: ChatModel,
|
||||
context_builder: LocalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
system_prompt: str = QUESTION_SYSTEM_PROMPT,
|
||||
@@ -38,7 +39,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
context_builder_params: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
llm_params=llm_params,
|
||||
@@ -95,15 +96,17 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
)
|
||||
question_messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": question_text},
|
||||
]
|
||||
|
||||
response = await self.llm.agenerate(
|
||||
messages=question_messages,
|
||||
streaming=True,
|
||||
callbacks=self.callbacks,
|
||||
**self.llm_params,
|
||||
)
|
||||
response = ""
|
||||
async for chunk in self.model.achat_stream(
|
||||
prompt=question_text,
|
||||
history=question_messages,
|
||||
model_parameters=self.llm_params,
|
||||
):
|
||||
response += chunk
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(chunk)
|
||||
|
||||
return QuestionResult(
|
||||
response=response.split("\n"),
|
||||
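A minimal, self-contained sketch of the new streaming pattern adopted above: accumulate achat_stream() chunks and fan each one out to callbacks. FakeChatModel is a stand-in; the real ChatModel protocol comes from graphrag.language_model.protocol.base.

import asyncio

class FakeChatModel:
    """Stand-in model that yields a few fixed chunks."""

    async def achat_stream(self, prompt, history=None, model_parameters=None):
        for chunk in ["What ", "changed", "?\n"]:
            yield chunk

async def stream_to_string(model, prompt, history, callbacks=()):
    """Accumulate streamed chunks while notifying callbacks per token."""
    response = ""
    async for chunk in model.achat_stream(prompt=prompt, history=history,
                                          model_parameters={}):
        response += chunk
        for callback in callbacks:
            callback.on_llm_new_token(chunk)
    return response

print(asyncio.run(stream_to_string(FakeChatModel(), "generate questions", [])))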
@@ -126,7 +129,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
|
||||
)
|
||||
|
||||
def generate(
|
||||
async def generate(
|
||||
self,
|
||||
question_history: list[str],
|
||||
context_data: str | None,
|
||||
@@ -178,12 +181,15 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
{"role": "user", "content": question_text},
|
||||
]
|
||||
|
||||
response = self.llm.generate(
|
||||
messages=question_messages,
|
||||
streaming=True,
|
||||
callbacks=self.callbacks,
|
||||
**self.llm_params,
|
||||
)
|
||||
response = ""
|
||||
async for chunk in self.model.achat_stream(
|
||||
prompt=question_text,
|
||||
history=question_messages,
|
||||
model_parameters=self.llm_params,
|
||||
):
|
||||
response += chunk
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(chunk)
|
||||
|
||||
return QuestionResult(
|
||||
response=response.split("\n"),
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any, Generic, TypeVar
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.builders import (
|
||||
BasicContextBuilder,
|
||||
DRIFTContextBuilder,
|
||||
@@ -20,7 +21,6 @@ from graphrag.query.context_builder.builders import (
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.llm.base import BaseLLM
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,16 +56,16 @@ class BaseSearch(ABC, Generic[T]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
model: ChatModel,
|
||||
context_builder: T,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
llm_params: dict[str, Any] | None = None,
|
||||
model_params: dict[str, Any] | None = None,
|
||||
context_builder_params: dict[str, Any] | None = None,
|
||||
):
|
||||
self.llm = llm
|
||||
self.model = model
|
||||
self.context_builder = context_builder
|
||||
self.token_encoder = token_encoder
|
||||
self.llm_params = llm_params or {}
|
||||
self.model_params = model_params or {}
|
||||
self.context_builder_params = context_builder_params or {}
|
||||
|
||||
@abstractmethod
|
||||
@@ -80,11 +80,12 @@ class BaseSearch(ABC, Generic[T]):
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@abstractmethod
|
||||
def stream_search(
|
||||
async def stream_search(
|
||||
self,
|
||||
query: str,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream search for the given query."""
|
||||
yield "" # This makes it an async generator.
|
||||
msg = "Subclasses must implement this method"
|
||||
raise NotImplementedError(msg)
|
||||
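Why the abstract stream_search() above yields an empty string before raising: a body needs a yield for Python to treat it as an async generator rather than a coroutine, so overriding subclasses can be consumed uniformly with `async for`. A stand-alone illustration:

from collections.abc import AsyncGenerator

class Base:
    async def stream_search(self, query: str) -> AsyncGenerator[str, None]:
        yield ""  # makes this an async generator rather than a plain coroutine
        msg = "Subclasses must implement this method"
        raise NotImplementedError(msg)

class Echo(Base):
    async def stream_search(self, query: str) -> AsyncGenerator[str, None]:
        for token in query.split():
            yield token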
|
||||
@@ -7,12 +7,12 @@ import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.data_model.text_unit import TextUnit
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
from graphrag.query.context_builder.builders import (
|
||||
BasicContextBuilder,
|
||||
ContextBuilderResult,
|
||||
)
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class BasicSearchContext(BasicContextBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_embedder: BaseTextEmbedding,
|
||||
text_embedder: EmbeddingModel,
|
||||
text_unit_embeddings: BaseVectorStore,
|
||||
text_units: list[TextUnit] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
|
||||
@@ -11,12 +11,12 @@ from typing import Any
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.query.basic_search_system_prompt import (
|
||||
BASIC_SEARCH_SYSTEM_PROMPT,
|
||||
)
|
||||
from graphrag.query.context_builder.builders import BasicContextBuilder
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.llm.base import BaseLLM
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
|
||||
@@ -36,7 +36,7 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
model: ChatModel,
|
||||
context_builder: BasicContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
system_prompt: str | None = None,
|
||||
@@ -46,10 +46,10 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
context_builder_params: dict | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
llm_params=llm_params,
|
||||
model_params=llm_params,
|
||||
context_builder_params=context_builder_params or {},
|
||||
)
|
||||
self.system_prompt = system_prompt or BASIC_SEARCH_SYSTEM_PROMPT
|
||||
@@ -86,15 +86,17 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
response = await self.llm.agenerate(
|
||||
messages=search_messages,
|
||||
streaming=True,
|
||||
callbacks=self.callbacks, # type: ignore
|
||||
**self.llm_params,
|
||||
)
|
||||
response = ""
|
||||
async for chunk in self.model.achat_stream(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
model_parameters=self.model_params,
|
||||
):
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(chunk)
|
||||
response += chunk
|
||||
|
||||
llm_calls["response"] = 1
|
||||
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
|
||||
@@ -125,11 +127,11 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
output_tokens=0,
|
||||
)
|
||||
|
||||
def stream_search(
|
||||
async def stream_search(
|
||||
self,
|
||||
query: str,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Build basic search context that fits a single context window and generate answer for the user query."""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -144,14 +146,16 @@ class BasicSearch(BaseSearch[BasicContextBuilder]):
|
||||
)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
for callback in self.callbacks:
|
||||
callback.on_context(context_result.context_records)
|
||||
|
||||
return self.llm.astream_generate( # type: ignore
|
||||
messages=search_messages,
|
||||
callbacks=self.callbacks, # type: ignore
|
||||
**self.llm_params,
|
||||
)
|
||||
async for chunk_response in self.model.achat_stream(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
model_parameters=self.model_params,
|
||||
):
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(chunk_response)
|
||||
yield chunk_response
|
||||
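A sketch of consuming the reworked stream_search(), which is now itself an async generator rather than a method returning one (`search` is assumed to be a configured BasicSearch built elsewhere):

import asyncio

async def print_streamed_answer(search, query: str) -> None:
    async for token in search.stream_search(query):
        print(token, end="", flush=True)

# asyncio.run(print_streamed_answer(search, "What are the main themes?"))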
|
||||
@@ -17,13 +17,12 @@ from graphrag.data_model.covariate import Covariate
|
||||
from graphrag.data_model.entity import Entity
|
||||
from graphrag.data_model.relationship import Relationship
|
||||
from graphrag.data_model.text_unit import TextUnit
|
||||
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
|
||||
from graphrag.prompts.query.drift_search_system_prompt import (
|
||||
DRIFT_LOCAL_SYSTEM_PROMPT,
|
||||
DRIFT_REDUCE_PROMPT,
|
||||
)
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.structured_search.base import DRIFTContextBuilder
|
||||
from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor
|
||||
from graphrag.query.structured_search.local_search.mixed_context import (
|
||||
@@ -39,8 +38,8 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_llm: ChatOpenAI,
|
||||
text_embedder: BaseTextEmbedding,
|
||||
model: ChatModel,
|
||||
text_embedder: EmbeddingModel,
|
||||
entities: list[Entity],
|
||||
entity_text_embeddings: BaseVectorStore,
|
||||
text_units: list[TextUnit] | None = None,
|
||||
@@ -57,7 +56,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
):
|
||||
"""Initialize the DRIFT search context builder with necessary components."""
|
||||
self.config = config or DRIFTSearchConfig()
|
||||
self.chat_llm = chat_llm
|
||||
self.model = model
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
|
||||
@@ -163,7 +162,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
and isinstance(query_embedding[0], type(embedding[0]))
|
||||
)
|
||||
|
||||
def build_context(
|
||||
async def build_context(
|
||||
self, query: str, **kwargs
|
||||
) -> tuple[pd.DataFrame, dict[str, int]]:
|
||||
"""
|
||||
@@ -191,13 +190,13 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
raise ValueError(missing_reports_error)
|
||||
|
||||
query_processor = PrimerQueryProcessor(
|
||||
chat_llm=self.chat_llm,
|
||||
chat_model=self.model,
|
||||
text_embedder=self.text_embedder,
|
||||
token_encoder=self.token_encoder,
|
||||
reports=self.reports,
|
||||
)
|
||||
|
||||
query_embedding, token_ct = query_processor(query)
|
||||
query_embedding, token_ct = await query_processor(query)
|
||||
|
||||
report_df = self.convert_reports_to_df(self.reports)
|
||||
|
||||
|
||||
@@ -15,11 +15,10 @@ from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
|
||||
from graphrag.data_model.community_report import CommunityReport
|
||||
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
|
||||
from graphrag.prompts.query.drift_search_system_prompt import (
|
||||
DRIFT_PRIMER_PROMPT,
|
||||
)
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import SearchResult
|
||||
|
||||
@@ -31,8 +30,8 @@ class PrimerQueryProcessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_llm: ChatOpenAI,
|
||||
text_embedder: BaseTextEmbedding,
|
||||
chat_model: ChatModel,
|
||||
text_embedder: EmbeddingModel,
|
||||
reports: list[CommunityReport],
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
):
|
||||
@@ -45,12 +44,12 @@ class PrimerQueryProcessor:
|
||||
reports (list[CommunityReport]): List of community reports.
|
||||
token_encoder (tiktoken.Encoding, optional): Token encoder for token counting.
|
||||
"""
|
||||
self.chat_llm = chat_llm
|
||||
self.chat_model = chat_model
|
||||
self.text_embedder = text_embedder
|
||||
self.token_encoder = token_encoder
|
||||
self.reports = reports
|
||||
|
||||
def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
|
||||
async def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
|
||||
"""
|
||||
Expand the query using a random community report template.
|
||||
|
||||
@@ -68,9 +67,9 @@ class PrimerQueryProcessor:
|
||||
{template}\n"
|
||||
Ensure that the hypothetical answer does not reference new named entities that are not present in the original query."""
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
model_response = await self.chat_model.achat(prompt)
|
||||
text = model_response.output.content
|
||||
|
||||
text = self.chat_llm.generate(messages)
|
||||
prompt_tokens = num_tokens(prompt, self.token_encoder)
|
||||
output_tokens = num_tokens(text, self.token_encoder)
|
||||
token_ct = {
|
||||
@@ -83,7 +82,7 @@ class PrimerQueryProcessor:
|
||||
return query, token_ct
|
||||
return text, token_ct
|
||||
|
||||
def __call__(self, query: str) -> tuple[list[float], dict[str, int]]:
|
||||
async def __call__(self, query: str) -> tuple[list[float], dict[str, int]]:
|
||||
"""
|
||||
Call method to process the query, expand it, and embed the result.
|
||||
|
||||
@@ -94,7 +93,7 @@ class PrimerQueryProcessor:
|
||||
-------
|
||||
tuple[list[float], int]: List of embeddings for the expanded query and the token count.
|
||||
"""
|
||||
hyde_query, token_ct = self.expand_query(query)
|
||||
hyde_query, token_ct = await self.expand_query(query)
|
||||
log.info("Expanded query: %s", hyde_query)
|
||||
return self.text_embedder.embed(hyde_query), token_ct
|
||||
|
||||
@@ -105,7 +104,7 @@ class DRIFTPrimer:
|
||||
def __init__(
|
||||
self,
|
||||
config: DRIFTSearchConfig,
|
||||
chat_llm: ChatOpenAI,
|
||||
chat_model: ChatModel,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
):
|
||||
"""
|
||||
@@ -116,7 +115,7 @@ class DRIFTPrimer:
|
||||
chat_model (ChatModel): The language model used for searching.
|
||||
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
|
||||
"""
|
||||
self.llm = chat_llm
|
||||
self.chat_model = chat_model
|
||||
self.config = config
|
||||
self.token_encoder = token_encoder
|
||||
|
||||
@@ -138,11 +137,9 @@ class DRIFTPrimer:
|
||||
prompt = DRIFT_PRIMER_PROMPT.format(
|
||||
query=query, community_reports=community_reports
|
||||
)
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
response = await self.llm.agenerate(
|
||||
messages, response_format={"type": "json_object"}
|
||||
)
|
||||
model_response = await self.chat_model.achat(prompt, json=True)
|
||||
response = model_response.output.content
|
||||
|
||||
parsed_response = json.loads(response)
|
||||
|
||||
@@ -173,6 +170,7 @@ class DRIFTPrimer:
|
||||
start_time = time.perf_counter()
|
||||
report_folds = self.split_reports(top_k_reports)
|
||||
tasks = [self.decompose_query(query, fold) for fold in report_folds]
|
||||
|
||||
results_with_tokens = await tqdm_asyncio.gather(*tasks, leave=False)
|
||||
|
||||
completion_time = time.perf_counter() - start_time
|
||||
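A self-contained sketch of the JSON-mode call pattern adopted above: request a JSON response via achat(..., json=True) and parse the returned content. FakeJsonModel and its payload are stand-ins for the real ChatModel provider.

import asyncio
import json
from types import SimpleNamespace

class FakeJsonModel:
    async def achat(self, prompt, json=False, **kwargs):  # `json` mirrors the real keyword
        payload = '{"answer": "...", "follow_ups": []}'  # placeholder response body
        return SimpleNamespace(output=SimpleNamespace(content=payload))

async def decompose(model, prompt: str) -> dict:
    model_response = await model.achat(prompt, json=True)
    return json.loads(model_response.output.content)

print(asyncio.run(decompose(FakeJsonModel(), "split this query")))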
|
||||
@@ -12,9 +12,9 @@ import tiktoken
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.conversation_history import ConversationHistory
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
from graphrag.query.structured_search.drift_search.action import DriftAction
|
||||
@@ -33,7 +33,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: ChatOpenAI,
|
||||
model: ChatModel,
|
||||
context_builder: DRIFTSearchContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
query_state: QueryState | None = None,
|
||||
@@ -49,14 +49,14 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
token_encoder (tiktoken.Encoding, optional): Token encoder for managing tokens.
|
||||
query_state (QueryState, optional): State of the current search query.
|
||||
"""
|
||||
super().__init__(llm, context_builder, token_encoder)
|
||||
super().__init__(model, context_builder, token_encoder)
|
||||
|
||||
self.context_builder = context_builder
|
||||
self.token_encoder = token_encoder
|
||||
self.query_state = query_state or QueryState()
|
||||
self.primer = DRIFTPrimer(
|
||||
config=self.context_builder.config,
|
||||
chat_llm=llm,
|
||||
chat_model=model,
|
||||
token_encoder=token_encoder,
|
||||
)
|
||||
self.callbacks = callbacks or []
|
||||
@@ -90,11 +90,11 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
}
|
||||
|
||||
return LocalSearch(
|
||||
llm=self.llm,
|
||||
model=self.model,
|
||||
system_prompt=self.context_builder.local_system_prompt,
|
||||
context_builder=self.context_builder.local_mixed_context,
|
||||
token_encoder=self.token_encoder,
|
||||
llm_params=llm_params,
|
||||
model_params=llm_params,
|
||||
context_builder_params=local_context_params,
|
||||
response_type="multiple paragraphs",
|
||||
callbacks=self.callbacks,
|
||||
@@ -203,7 +203,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
# Check if query state is empty
|
||||
if not self.query_state.graph:
|
||||
# Prime the search with the primer
|
||||
primer_context, token_ct = self.context_builder.build_context(query)
|
||||
primer_context, token_ct = await self.context_builder.build_context(query)
|
||||
llm_calls["build_context"] = token_ct["llm_calls"]
|
||||
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
|
||||
output_tokens["build_context"] = token_ct["prompt_tokens"]
|
||||
@@ -362,16 +362,16 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
reduced_response = self.llm.generate(
|
||||
messages=search_messages,
|
||||
streaming=False,
|
||||
callbacks=self.callbacks, # type: ignore
|
||||
**llm_kwargs,
|
||||
model_response = await self.model.achat(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
model_parameters=llm_kwargs,
|
||||
)
|
||||
|
||||
reduced_response = model_response.output.content
|
||||
|
||||
llm_calls["reduce"] = 1
|
||||
prompt_tokens["reduce"] = num_tokens(
|
||||
search_prompt, self.token_encoder
|
||||
@@ -380,7 +380,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
|
||||
return reduced_response
|
||||
|
||||
def _reduce_response_streaming(
|
||||
async def _reduce_response_streaming(
|
||||
self,
|
||||
responses: str | dict[str, Any],
|
||||
query: str,
|
||||
@@ -419,11 +419,13 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
return self.llm.astream_generate(
|
||||
search_messages,
|
||||
callbacks=self.callbacks, # type: ignore
|
||||
**llm_kwargs,
|
||||
)
|
||||
async for response in self.model.achat_stream(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
model_parameters=llm_kwargs,
|
||||
):
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(response)
|
||||
yield response
|
||||
|
||||
@@ -46,7 +46,7 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
self.dynamic_community_selection = DynamicCommunitySelection(
|
||||
community_reports=community_reports,
|
||||
communities=communities,
|
||||
llm=dynamic_community_selection_kwargs.pop("llm"),
|
||||
model=dynamic_community_selection_kwargs.pop("model"),
|
||||
token_encoder=dynamic_community_selection_kwargs.pop("token_encoder"),
|
||||
**dynamic_community_selection_kwargs,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.query_callbacks import QueryCallbacks
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
|
||||
GENERAL_KNOWLEDGE_INSTRUCTION,
|
||||
)
|
||||
@@ -29,7 +30,6 @@ from graphrag.query.context_builder.builders import GlobalContextBuilder
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.llm.base import BaseLLM
|
||||
from graphrag.query.llm.text_utils import num_tokens, try_parse_json_object
|
||||
from graphrag.query.structured_search.base import BaseSearch, SearchResult
|
||||
|
||||
@@ -60,7 +60,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
model: ChatModel,
|
||||
context_builder: GlobalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
map_system_prompt: str | None = None,
|
||||
@@ -77,7 +77,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
concurrent_coroutines: int = 32,
|
||||
):
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
context_builder_params=context_builder_params,
|
||||
@@ -106,7 +106,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
self,
|
||||
query: str,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream the global search response."""
|
||||
context_result = await self.context_builder.build_context(
|
||||
query=query,
|
||||
@@ -130,7 +130,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
async for response in self._stream_reduce_response(
|
||||
map_responses=map_responses, # type: ignore
|
||||
query=query,
|
||||
**self.reduce_llm_params,
|
||||
model_parameters=self.reduce_llm_params,
|
||||
):
|
||||
yield response
|
||||
|
||||
@@ -218,12 +218,15 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
search_prompt = self.map_system_prompt.format(context_data=context_data)
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
async with self.semaphore:
|
||||
search_response = await self.llm.agenerate(
|
||||
messages=search_messages, streaming=False, **llm_kwargs
|
||||
model_response = await self.model.achat(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
model_parameters=llm_kwargs,
|
||||
json=True,
|
||||
)
|
||||
search_response = model_response.output.content
|
||||
log.info("Map response: %s", search_response)
|
||||
try:
|
||||
# parse search response json
|
||||
@@ -373,12 +376,16 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
search_response = await self.llm.agenerate(
|
||||
search_messages,
|
||||
streaming=True,
|
||||
callbacks=self.callbacks, # type: ignore
|
||||
**llm_kwargs, # type: ignore
|
||||
)
|
||||
search_response = ""
|
||||
async for chunk_response in self.model.achat_stream(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
model_parameters=llm_kwargs,
|
||||
):
|
||||
search_response += chunk_response
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(chunk_response)
|
||||
|
||||
return SearchResult(
|
||||
response=search_response,
|
||||
context_data=text_data,
|
||||
@@ -468,12 +475,13 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
search_prompt += "\n" + self.general_knowledge_inclusion_prompt
|
||||
search_messages = [
|
||||
{"role": "system", "content": search_prompt},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
|
||||
async for resp in self.llm.astream_generate( # type: ignore
|
||||
search_messages,
|
||||
callbacks=self.callbacks, # type: ignore
|
||||
**llm_kwargs, # type: ignore
|
||||
async for chunk_response in self.model.achat_stream(
|
||||
prompt=query,
|
||||
history=search_messages,
|
||||
**llm_kwargs,
|
||||
):
|
||||
yield resp
|
||||
for callback in self.callbacks:
|
||||
callback.on_llm_new_token(chunk_response)
|
||||
yield chunk_response
|
||||
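A self-contained sketch of the concurrency pattern GlobalSearch uses for its map step: a semaphore caps how many community-report batches are sent to the chat model at once (32 by default per `concurrent_coroutines` above). The sleep stands in for a model.achat(...) call.

import asyncio

async def map_one(semaphore: asyncio.Semaphore, batch_id: int) -> str:
    async with semaphore:
        await asyncio.sleep(0.01)  # stand-in for an awaited chat-model call
        return f"map response {batch_id}"

async def map_all(num_batches: int, concurrent_coroutines: int = 32) -> list[str]:
    semaphore = asyncio.Semaphore(concurrent_coroutines)
    return await asyncio.gather(*[map_one(semaphore, i) for i in range(num_batches)])

print(asyncio.run(map_all(4)))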
|
||||
@@ -14,6 +14,7 @@ from graphrag.data_model.covariate import Covariate
|
||||
from graphrag.data_model.entity import Entity
|
||||
from graphrag.data_model.relationship import Relationship
|
||||
from graphrag.data_model.text_unit import TextUnit
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
from graphrag.query.context_builder.builders import ContextBuilderResult
|
||||
from graphrag.query.context_builder.community_context import (
|
||||
build_community_context,
|
||||
@@ -39,7 +40,6 @@ from graphrag.query.input.retrieval.community_reports import (
|
||||
get_candidate_communities,
|
||||
)
|
||||
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
|
||||
from graphrag.query.llm.base import BaseTextEmbedding
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.structured_search.base import LocalContextBuilder
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
@@ -54,7 +54,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
|
||||
self,
|
||||
entities: list[Entity],
|
||||
entity_text_embeddings: BaseVectorStore,
|
||||
text_embedder: BaseTextEmbedding,
|
||||
text_embedder: EmbeddingModel,
|
||||
text_units: list[TextUnit] | None = None,
|
||||
community_reports: list[CommunityReport] | None = None,
|
||||
relationships: list[Relationship] | None = None,
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any
import tiktoken

from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.local_search_system_prompt import (
    LOCAL_SEARCH_SYSTEM_PROMPT,
)
@@ -18,7 +19,6 @@ from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
    ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult

@@ -35,20 +35,20 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):

    def __init__(
        self,
        llm: BaseLLM,
        model: ChatModel,
        context_builder: LocalContextBuilder,
        token_encoder: tiktoken.Encoding | None = None,
        system_prompt: str | None = None,
        response_type: str = "multiple paragraphs",
        callbacks: list[QueryCallbacks] | None = None,
        llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
        model_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
        context_builder_params: dict | None = None,
    ):
        super().__init__(
            llm=llm,
            model=model,
            context_builder=context_builder,
            token_encoder=token_encoder,
            llm_params=llm_params,
            model_params=model_params,
            context_builder_params=context_builder_params or {},
        )
        self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
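LocalSearch is now constructed around a ChatModel and model_params rather than a BaseLLM and llm_params. A sketch of the wiring, under the assumption that LocalSearch is importable from graphrag.query.structured_search.local_search.search (the file path is not shown in this diff) and that the caller already holds a configured context builder:

from typing import Any

from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.structured_search.local_search.search import LocalSearch  # assumed path


def build_local_search(
    model: ChatModel,
    context_builder: LocalContextBuilder,
    model_params: dict[str, Any],
) -> LocalSearch:
    """Assemble a LocalSearch engine around an injected ChatModel."""
    return LocalSearch(
        model=model,
        context_builder=context_builder,
        response_type="multiple paragraphs",
        model_params=model_params,
    )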
@@ -89,26 +89,30 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
            context_data=context_result.context_chunks,
            response_type=self.response_type,
        )
        search_messages = [
        history_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]

        response = await self.llm.agenerate(
            messages=search_messages,
            streaming=True,
            callbacks=self.callbacks,  # type: ignore
            **self.llm_params,
        )
        full_response = ""

        async for response in self.model.achat_stream(
            prompt=query,
            history=history_messages,
            model_parameters=self.model_params,
        ):
            full_response += response
            for callback in self.callbacks:
                callback.on_llm_new_token(response)

        llm_calls["response"] = 1
        prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
        output_tokens["response"] = num_tokens(response, self.token_encoder)
        output_tokens["response"] = num_tokens(full_response, self.token_encoder)

        for callback in self.callbacks:
            callback.on_context(context_result.context_records)

        return SearchResult(
            response=response,
            response=full_response,
            context_data=context_result.context_records,
            context_text=context_result.context_chunks,
            completion_time=time.time() - start_time,
@@ -132,7 +136,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
                output_tokens=0,
            )

    def stream_search(
    async def stream_search(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
@@ -149,16 +153,18 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
        search_prompt = self.system_prompt.format(
            context_data=context_result.context_chunks, response_type=self.response_type
        )
        search_messages = [
        history_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": query},
        ]

        for callback in self.callbacks:
            callback.on_context(context_result.context_records)

        return self.llm.astream_generate(  # type: ignore
            messages=search_messages,
            callbacks=self.callbacks,  # type: ignore
            **self.llm_params,
        )
        async for response in self.model.achat_stream(
            prompt=query,
            history=history_messages,
            model_parameters=self.model_params,
        ):
            for callback in self.callbacks:
                callback.on_llm_new_token(response)
            yield response

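With stream_search converted to an async generator (it now yields chunks itself instead of returning the result of astream_generate), callers consume it with async for, while the blocking search still resolves to a SearchResult. Illustrative helpers, assuming engine is an already configured LocalSearch instance:

async def collect_streamed_answer(engine, query: str) -> str:
    """Drain LocalSearch.stream_search into a single string."""
    chunks: list[str] = []
    async for chunk in engine.stream_search(query):
        chunks.append(chunk)
    return "".join(chunks)


async def blocking_answer(engine, query: str) -> str:
    """The non-streaming path still returns a SearchResult."""
    result = await engine.search(query)
    return result.response  # for LocalSearch this is the accumulated full_response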
tests/integration/language_model/test_factory.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LLMFactory Tests.

These tests will test the LLMFactory class and the creation of custom and provided LLMs.
"""

from collections.abc import AsyncGenerator
from typing import Any

from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.response.base import (
    BaseModelOutput,
    BaseModelResponse,
    ModelResponse,
)


async def test_create_custom_chat_model():
    class CustomChatModel:
        def __init__(self, **kwargs):
            pass

        async def achat(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> ModelResponse:
            return BaseModelResponse(output=BaseModelOutput(content="content"))

        def chat(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> ModelResponse:
            return BaseModelResponse(output=BaseModelOutput(content="content"))

        async def achat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> AsyncGenerator[str, None]: ...

        def chat_stream(
            self, prompt: str, history: list | None = None, **kwargs: Any
        ) -> AsyncGenerator[str, None]: ...

    ModelFactory.register_chat("custom_chat", CustomChatModel)
    model = ModelManager().get_or_create_chat_model("custom", "custom_chat")
    assert isinstance(model, CustomChatModel)
    response = await model.achat("prompt")
    assert response.output.content == "content"

    response = model.chat("prompt")
    assert response.output.content == "content"


async def test_create_custom_embedding_llm():
    class CustomEmbeddingModel:
        def __init__(self, **kwargs):
            pass

        async def aembed(self, text: str, **kwargs) -> list[float]:
            return [1.0]

        def embed(self, text: str, **kwargs) -> list[float]:
            return [1.0]

        async def aembed_batch(
            self, text_list: list[str], **kwargs
        ) -> list[list[float]]:
            return [[1.0]]

        def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
            return [[1.0]]

    ModelFactory.register_embedding("custom_embedding", CustomEmbeddingModel)
    llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
    assert isinstance(llm, CustomEmbeddingModel)
    response = await llm.aembed("text")
    assert response == [1.0]

    response = llm.embed("text")
    assert response == [1.0]

    response = await llm.aembed_batch(["text"])
    assert response == [[1.0]]

    response = llm.embed_batch(["text"])
    assert response == [[1.0]]
@@ -1,45 +0,0 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""LLMFactory Tests.

These tests will test the LLMFactory class and the creation of custom and provided LLMs.
"""

from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.response.base import (
    BaseModelOutput,
    BaseModelResponse,
    ModelResponse,
)


async def test_create_custom_chat_llm():
    class CustomChatLLM:
        def __init__(self, **kwargs):
            pass

        async def chat(self, prompt: str, **kwargs) -> ModelResponse:
            return BaseModelResponse(output=BaseModelOutput(content="content"))

    ModelFactory.register_chat("custom_chat", CustomChatLLM)
    llm = ModelManager().get_or_create_chat_model("custom", "custom_chat")
    assert isinstance(llm, CustomChatLLM)
    response = await llm.chat("prompt")
    assert response.output.content == "content"


async def test_create_custom_embedding_llm():
    class CustomEmbeddingLLM:
        def __init__(self, **kwargs):
            pass

        async def embed(self, text: str | list[str], **kwargs) -> list[list[float]]:
            return [[1.0]]

    ModelFactory.register_embedding("custom_embedding", CustomEmbeddingLLM)
    llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
    assert isinstance(llm, CustomEmbeddingLLM)
    response = await llm.embed("text")
    assert response == [[1.0]]
@@ -3,6 +3,7 @@

"""A module containing mock model provider definitions."""

from collections.abc import AsyncGenerator
from typing import Any

from pydantic import BaseModel
@@ -28,9 +29,38 @@ class MockChatLLM:
        self.responses = config.responses if config and config.responses else responses
        self.response_index = 0

    async def chat(
    async def achat(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> ModelResponse:
        """Return the next response in the list."""
        return self.chat(prompt, history, **kwargs)

    async def achat_stream(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Return the next response in the list."""
        if not self.responses:
            return

        for response in self.responses:
            response = (
                response.model_dump_json()
                if isinstance(response, BaseModel)
                else response
            )

            yield response

    def chat(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> ModelResponse:
        """Return the next response in the list."""
@@ -50,6 +80,15 @@ class MockChatLLM:
            parsed_response=parsed_json,
        )

    def chat_stream(
        self,
        prompt: str,
        history: list | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Return the next response in the list."""
        raise NotImplementedError


class MockEmbeddingLLM:
    """A mock embedding LLM provider."""
@@ -57,8 +96,24 @@ class MockEmbeddingLLM:
    def __init__(self, **kwargs: Any):
        pass

    async def embed(self, text: str | list[str], **kwargs: Any) -> list[list[float]]:
    def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
        """Generate an embedding for the input text."""
        if isinstance(text, str):
        if isinstance(text_list, str):
            return [[1.0, 1.0, 1.0]]
        return [[1.0, 1.0, 1.0] for _ in text]
        return [[1.0, 1.0, 1.0] for _ in text_list]

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Generate an embedding for the input text."""
        return [1.0, 1.0, 1.0]

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """Generate an embedding for the input text."""
        return [1.0, 1.0, 1.0]

    async def aembed_batch(
        self, text_list: list[str], **kwargs: Any
    ) -> list[list[float]]:
        """Generate an embedding for the input text."""
        if isinstance(text_list, str):
            return [[1.0, 1.0, 1.0]]
        return [[1.0, 1.0, 1.0] for _ in text_list]

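The mock now covers the full embedding surface: embed/aembed return a single vector, while embed_batch/aembed_batch return one vector per input. A generic helper sketch written against the EmbeddingModel protocol (the helper name is illustrative, not part of this change):

from graphrag.language_model.protocol.base import EmbeddingModel


async def embed_corpus(model: EmbeddingModel, texts: list[str]) -> list[list[float]]:
    """Embed a list of texts through the async batch API."""
    vectors = await model.aembed_batch(texts)
    # Providers are expected to return exactly one vector per input text.
    return vectors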
@@ -5,11 +5,11 @@ from typing import Any

from graphrag.data_model.entity import Entity
from graphrag.data_model.types import TextEmbedder
from graphrag.language_model.manager import ModelManager
from graphrag.query.context_builder.entity_extraction import (
    EntityVectorStoreKey,
    map_query_to_entities,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.vector_stores.base import (
    BaseVectorStore,
    VectorStoreDocument,
@@ -60,14 +60,6 @@ class MockBaseVectorStore(BaseVectorStore):
        return result


class MockBaseTextEmbedding(BaseTextEmbedding):
    def embed(self, text: str, **kwargs: Any) -> list[float]:
        return [len(text)]

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        return [len(text)]


def test_map_query_to_entities():
    entities = [
        Entity(
@@ -102,7 +94,9 @@ def test_map_query_to_entities():
            VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
            for entity in entities
        ]),
        text_embedder=MockBaseTextEmbedding(),
        text_embedder=ModelManager().get_or_create_embedding_model(
            model_type="mock_embedding", name="mock"
        ),
        all_entities_dict={entity.id: entity for entity in entities},
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        k=1,
@@ -122,7 +116,9 @@ def test_map_query_to_entities():
            VectorStoreDocument(id=entity.title, text=entity.title, vector=None)
            for entity in entities
        ]),
        text_embedder=MockBaseTextEmbedding(),
        text_embedder=ModelManager().get_or_create_embedding_model(
            model_type="mock_embedding", name="mock"
        ),
        all_entities_dict={entity.id: entity for entity in entities},
        embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
        k=1,
@@ -142,7 +138,9 @@ def test_map_query_to_entities():
            VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
            for entity in entities
        ]),
        text_embedder=MockBaseTextEmbedding(),
        text_embedder=ModelManager().get_or_create_embedding_model(
            model_type="mock_embedding", name="mock"
        ),
        all_entities_dict={entity.id: entity for entity in entities},
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        k=2,
@@ -167,7 +165,9 @@ def test_map_query_to_entities():
            VectorStoreDocument(id=entity.id, text=entity.title, vector=None)
            for entity in entities
        ]),
        text_embedder=MockBaseTextEmbedding(),
        text_embedder=ModelManager().get_or_create_embedding_model(
            model_type="mock_embedding", name="mock"
        ),
        all_entities_dict={entity.id: entity for entity in entities},
        embedding_vectorstore_key=EntityVectorStoreKey.TITLE,
        k=2,