LLM Factory + Provider base

Alonso Guevara
2025-02-13 15:00:16 -06:00
parent fe461417b5
commit 174c712a46
13 changed files with 211 additions and 0 deletions

graphrag/config/enums.py

@@ -112,6 +112,9 @@ class LLMType(str, Enum):
    # Debug
    StaticResponse = "static_response"
    # Mock
    Mock = "mock"

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'
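
Usage note: because LLMType mixes in str, members compare equal to their raw values, and the new __repr__ renders a member as its quoted value. A quick sketch, assuming the enum is importable from graphrag.config.enums:

from graphrag.config.enums import LLMType

assert LLMType.Mock == "mock"            # str mixin: equal to its value
assert repr(LLMType.Mock) == '"mock"'    # custom __repr__ quotes the value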

2
graphrag/llm/__init__.py Normal file

@@ -0,0 +1,2 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License


@@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the base LLM class."""
from typing import Protocol


class BaseLLM(Protocol):
    """A base class for LLMs."""

    def __init__(self):
        pass

    def get_response(self, input: str) -> str:
        """Get a response from the LLM."""
        ...

    def get_embedding(self, input: str) -> list[float]:
        """Get an embedding from the LLM."""
        ...
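
Since BaseLLM is a typing.Protocol, conformance is structural: any class with matching method signatures satisfies it, no inheritance required. A minimal sketch:

class EchoLLM:
    # No subclassing of BaseLLM; matching signatures are enough.
    def get_response(self, input: str) -> str:
        return input

    def get_embedding(self, input: str) -> list[float]:
        return [0.0, 0.0]

llm: BaseLLM = EchoLLM()  # accepted by static type checkers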


47
graphrag/llm/factory.py Normal file

@@ -0,0 +1,47 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing a factory for supported llm types."""
from typing import Any, Callable

from graphrag.config.enums import LLMType
from graphrag.llm.protocols.chat import ChatLLM
from graphrag.llm.protocols.embedding import EmbeddingLLM

# The provider implementations are added elsewhere in this commit; the module
# path below is an assumption for this excerpt.
from graphrag.llm.providers import (
    AzureOpenAIChat,
    AzureOpenAIEmbedding,
    MockChatLLM,
    MockEmbeddingLLM,
    OpenAIChat,
    OpenAIEmbedding,
)


class LLMFactory:
    """A factory for creating LLM instances."""

    _chat_registry: dict[str, Callable[..., ChatLLM]] = {}
    _embedding_registry: dict[str, Callable[..., EmbeddingLLM]] = {}

    @classmethod
    def register_chat(cls, key: str, creator: Callable[..., ChatLLM]) -> None:
        """Register a creator for a ChatLLM implementation."""
        cls._chat_registry[key] = creator

    @classmethod
    def register_embedding(cls, key: str, creator: Callable[..., EmbeddingLLM]) -> None:
        """Register a creator for an EmbeddingLLM implementation."""
        cls._embedding_registry[key] = creator

    @classmethod
    def create_chat_llm(cls, key: str, **kwargs: Any) -> ChatLLM:
        """Create a ChatLLM instance for the given key."""
        if key not in cls._chat_registry:
            msg = f"ChatLLM implementation '{key}' is not registered."
            raise ValueError(msg)
        return cls._chat_registry[key](**kwargs)

    @classmethod
    def create_embedding_llm(cls, key: str, **kwargs: Any) -> EmbeddingLLM:
        """Create an EmbeddingLLM instance for the given key."""
        if key not in cls._embedding_registry:
            msg = f"EmbeddingLLM implementation '{key}' is not registered."
            raise ValueError(msg)
        return cls._embedding_registry[key](**kwargs)


# --- Register default implementations ---
LLMFactory.register_chat(LLMType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChat(**kwargs))
LLMFactory.register_chat(LLMType.OpenAIChat, lambda **kwargs: OpenAIChat(**kwargs))
LLMFactory.register_chat(LLMType.Mock, lambda **kwargs: MockChatLLM())
LLMFactory.register_embedding(LLMType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbedding(**kwargs))
LLMFactory.register_embedding(LLMType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbedding(**kwargs))
LLMFactory.register_embedding(LLMType.Mock, lambda **kwargs: MockEmbeddingLLM())
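
The registry keeps the factory open for extension: callers can register their own implementations before creating them by key. A sketch with a hypothetical EchoChat class (not part of this commit):

from typing import Any

from graphrag.llm.factory import LLMFactory


class EchoChat:
    """Hypothetical chat implementation used only for illustration."""

    def chat(self, prompt: str, **kwargs: Any) -> str:
        # Echo the prompt back; a real provider would call a model here.
        return prompt


LLMFactory.register_chat("echo", lambda **kwargs: EchoChat())
llm = LLMFactory.create_chat_llm("echo")
assert llm.chat("hello") == "hello"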

97
graphrag/llm/manager.py Normal file

@@ -0,0 +1,97 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Singleton LLM Manager for ChatLLM and EmbeddingLLM instances.

This manager lets you register chat and embedding LLMs independently.
It leverages the LLMFactory for instantiation.
"""
from __future__ import annotations

from typing import Any

from graphrag.llm.factory import LLMFactory
from graphrag.llm.protocols import ChatLLM, EmbeddingLLM


class LLMManager:
    """Singleton manager for LLM instances."""

    _instance: LLMManager | None = None

    def __new__(cls) -> LLMManager:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Avoid reinitialization in the singleton.
        if not hasattr(self, "_initialized"):
            self.chat_llms: dict[str, ChatLLM] = {}
            self.embedding_llms: dict[str, EmbeddingLLM] = {}
            self._initialized = True

    @classmethod
    def get_instance(cls) -> LLMManager:
        """Return the singleton instance of LLMManager."""
        # cls() routes through __new__ and __init__, so the first call
        # initializes the instance (cls.__new__(cls) alone would skip __init__).
        return cls()

    def register_chat(self, name: str, chat_key: str, **chat_kwargs: Any) -> None:
        """Register a ChatLLM instance under a unique name.

        Args:
            name: Unique identifier for the ChatLLM instance.
            chat_key: Key for the ChatLLM implementation in LLMFactory.
            **chat_kwargs: Additional parameters for instantiation.
        """
        self.chat_llms[name] = LLMFactory.create_chat_llm(chat_key, **chat_kwargs)

    def register_embedding(self, name: str, embedding_key: str, **embedding_kwargs: Any) -> None:
        """Register an EmbeddingLLM instance under a unique name.

        Args:
            name: Unique identifier for the EmbeddingLLM instance.
            embedding_key: Key for the EmbeddingLLM implementation in LLMFactory.
            **embedding_kwargs: Additional parameters for instantiation.
        """
        self.embedding_llms[name] = LLMFactory.create_embedding_llm(embedding_key, **embedding_kwargs)

    def get_chat_llm(self, name: str) -> ChatLLM:
        """Retrieve the ChatLLM instance registered under the given name.

        Raises:
            ValueError: If no ChatLLM is registered under the name.
        """
        if name not in self.chat_llms:
            msg = f"No ChatLLM registered under name '{name}'."
            raise ValueError(msg)
        return self.chat_llms[name]

    def get_embedding_llm(self, name: str) -> EmbeddingLLM:
        """Retrieve the EmbeddingLLM instance registered under the given name.

        Raises:
            ValueError: If no EmbeddingLLM is registered under the name.
        """
        if name not in self.embedding_llms:
            msg = f"No EmbeddingLLM registered under name '{name}'."
            raise ValueError(msg)
        return self.embedding_llms[name]

    def remove_chat(self, name: str) -> None:
        """Remove the ChatLLM instance registered under the given name."""
        self.chat_llms.pop(name, None)

    def remove_embedding(self, name: str) -> None:
        """Remove the EmbeddingLLM instance registered under the given name."""
        self.embedding_llms.pop(name, None)

    def list_chat_llms(self) -> dict[str, ChatLLM]:
        """Return a copy of all registered ChatLLM instances."""
        return dict(self.chat_llms)

    def list_embedding_llms(self) -> dict[str, EmbeddingLLM]:
        """Return a copy of all registered EmbeddingLLM instances."""
        return dict(self.embedding_llms)
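
Usage note: the manager binds a caller-chosen name to a factory key, so one implementation can back several named configurations. A sketch using the mock provider registered in factory.py (the mock's output is illustrative only):

from graphrag.config.enums import LLMType
from graphrag.llm.manager import LLMManager

manager = LLMManager.get_instance()
manager.register_chat("default", LLMType.Mock)
reply = manager.get_chat_llm("default").chat("ping")
manager.remove_chat("default")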

7
graphrag/llm/protocols/__init__.py Normal file

@@ -0,0 +1,7 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

from .chat import ChatLLM
from .embedding import EmbeddingLLM

__all__ = ["ChatLLM", "EmbeddingLLM"]

19
graphrag/llm/protocols/chat.py Normal file

@@ -0,0 +1,19 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
from __future__ import annotations

from typing import Any, Protocol


class ChatLLM(Protocol):
    """A protocol for chat-capable LLM implementations."""

    def chat(self, prompt: str, **kwargs: Any) -> str:
        """Generate a chat response based on the provided prompt.

        Args:
            prompt: The text prompt to generate a response for.
            **kwargs: Additional keyword arguments (e.g., temperature, max_tokens).

        Returns:
            A string response generated by the LLM.
        """
        ...
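
Because ChatLLM is a Protocol, downstream code can type against the interface without importing any concrete provider. A brief sketch:

def summarize(llm: ChatLLM, text: str) -> str:
    # Any object with a matching chat() method can be passed in.
    return llm.chat(f"Summarize: {text}", temperature=0.0)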

19
graphrag/llm/protocols/embedding.py Normal file

@@ -0,0 +1,19 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
from __future__ import annotations

from typing import Any, Protocol


class EmbeddingLLM(Protocol):
    """A protocol for embedding-capable LLM implementations."""

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """Generate an embedding vector for the given text.

        Args:
            text: The text to generate an embedding for.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns:
            A list of floats representing the embedding vector.
        """
        ...
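
A consumer can likewise build utilities against the protocol alone; for example, a sketch of cosine similarity over two embeddings:

import math

def cosine_similarity(llm: EmbeddingLLM, a: str, b: str) -> float:
    # Embed both inputs, then take the normalized dot product.
    va, vb = llm.embed(a), llm.embed(b)
    dot = sum(x * y for x, y in zip(va, vb))
    norm = math.sqrt(sum(x * x for x in va)) * math.sqrt(sum(x * x for x in vb))
    return dot / norm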