Mirror of https://github.com/microsoft/graphrag.git, synced 2025-03-11 01:26:14 +03:00
LLM Factory + Provider base
@@ -112,6 +112,9 @@ class LLMType(str, Enum):
    # Debug
    StaticResponse = "static_response"

    # Mock
    Mock = "mock"

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'
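Because LLMType subclasses str, members compare equal to their raw config strings, and the new __repr__ quotes the value. A quick sketch (assuming the enum is importable from graphrag.config.enums):

    from graphrag.config.enums import LLMType

    assert LLMType.Mock == "mock"                      # str-valued members
    assert repr(LLMType.StaticResponse) == '"static_response"'
    assert LLMType("mock") is LLMType.Mock             # round-trips from raw strings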
2 graphrag/llm/__init__.py Normal file
@@ -0,0 +1,2 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
0 graphrag/llm/clients/__init__.py Normal file
0 graphrag/llm/clients/azure_openai.py Normal file
17 graphrag/llm/clients/base_llm.py Normal file
@@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the base LLM class."""

from typing import Protocol


class BaseLLM(Protocol):
    """A base class for LLMs."""

    def get_response(self, input: str) -> str:
        """Get a response from the LLM."""
        ...

    def get_embedding(self, input: str) -> list[float]:
        """Get an embedding from the LLM."""
        ...
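Because BaseLLM is a typing.Protocol, clients satisfy it structurally rather than by inheritance. A minimal sketch, with a hypothetical StaticLLM standing in for the still-empty clients/static.py stub:

    from graphrag.llm.clients.base_llm import BaseLLM

    class StaticLLM:
        """A client that returns a canned response; handy for tests."""

        def __init__(self, response: str = "ok") -> None:
            self._response = response

        def get_response(self, input: str) -> str:
            return self._response

        def get_embedding(self, input: str) -> list[float]:
            return [0.0]

    llm: BaseLLM = StaticLLM()  # type-checks because the method signatures match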
0 graphrag/llm/clients/mock.py Normal file
0 graphrag/llm/clients/openai.py Normal file
0 graphrag/llm/clients/static.py Normal file
47 graphrag/llm/factory.py Normal file
@@ -0,0 +1,47 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""A package containing a factory for supported llm types."""

from typing import Any, Callable

from graphrag.config.enums import LLMType
from graphrag.llm.protocols.chat import ChatLLM
from graphrag.llm.protocols.embedding import EmbeddingLLM

# NOTE: the client modules are empty stubs in this commit; these imports assume
# the concrete implementations will land in graphrag.llm.clients.
from graphrag.llm.clients.azure_openai import AzureOpenAIChat, AzureOpenAIEmbedding
from graphrag.llm.clients.mock import MockChatLLM, MockEmbeddingLLM
from graphrag.llm.clients.openai import OpenAIChat, OpenAIEmbedding


class LLMFactory:
    """A factory for creating LLM instances."""

    _chat_registry: dict[str, Callable[..., ChatLLM]] = {}
    _embedding_registry: dict[str, Callable[..., EmbeddingLLM]] = {}

    @classmethod
    def register_chat(cls, key: str, creator: Callable[..., ChatLLM]) -> None:
        cls._chat_registry[key] = creator

    @classmethod
    def register_embedding(cls, key: str, creator: Callable[..., EmbeddingLLM]) -> None:
        cls._embedding_registry[key] = creator

    @classmethod
    def create_chat_llm(cls, key: str, **kwargs: Any) -> ChatLLM:
        if key not in cls._chat_registry:
            msg = f"ChatLLM implementation '{key}' is not registered."
            raise ValueError(msg)
        return cls._chat_registry[key](**kwargs)

    @classmethod
    def create_embedding_llm(cls, key: str, **kwargs: Any) -> EmbeddingLLM:
        if key not in cls._embedding_registry:
            msg = f"EmbeddingLLM implementation '{key}' is not registered."
            raise ValueError(msg)
        return cls._embedding_registry[key](**kwargs)


# --- Register default implementations ---
LLMFactory.register_chat(LLMType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChat(**kwargs))
LLMFactory.register_chat(LLMType.OpenAIChat, lambda **kwargs: OpenAIChat(**kwargs))
LLMFactory.register_chat(LLMType.Mock, lambda **kwargs: MockChatLLM())

LLMFactory.register_embedding(LLMType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbedding(**kwargs))
LLMFactory.register_embedding(LLMType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbedding(**kwargs))
LLMFactory.register_embedding(LLMType.Mock, lambda **kwargs: MockEmbeddingLLM())
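A usage sketch for the factory; the "my-echo" key and EchoChat class are hypothetical, purely for illustration:

    from typing import Any

    from graphrag.llm.factory import LLMFactory

    class EchoChat:
        """A trivial ChatLLM that echoes its prompt."""

        def chat(self, prompt: str, **kwargs: Any) -> str:
            return prompt

    LLMFactory.register_chat("my-echo", lambda **kwargs: EchoChat())
    chat_llm = LLMFactory.create_chat_llm("my-echo")
    assert chat_llm.chat("hello") == "hello"

    # Unregistered keys fail fast:
    # LLMFactory.create_chat_llm("missing")  # raises ValueError

Keeping the registries on the class (rather than per instance) means a registration made anywhere in the process is visible to every caller.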
97 graphrag/llm/manager.py Normal file
@@ -0,0 +1,97 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""Singleton LLM Manager for ChatLLM and EmbeddingLLM instances.

This manager lets you register chat and embedding LLMs independently.
It leverages the LLMFactory for instantiation.
"""

from __future__ import annotations

from typing import Any

from graphrag.llm.factory import LLMFactory
from graphrag.llm.protocols import ChatLLM, EmbeddingLLM


class LLMManager:
    """Singleton manager for LLM instances."""

    _instance: LLMManager | None = None

    def __new__(cls) -> LLMManager:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Avoid reinitialization in the singleton.
        if not hasattr(self, "_initialized"):
            self.chat_llms: dict[str, ChatLLM] = {}
            self.embedding_llms: dict[str, EmbeddingLLM] = {}
            self._initialized = True

    @classmethod
    def get_instance(cls) -> LLMManager:
        """Return the singleton instance of LLMManager."""
        # Calling the class runs both __new__ and the guarded __init__,
        # so the registries exist even on first access.
        return cls()

    def register_chat(self, name: str, chat_key: str, **chat_kwargs: Any) -> None:
        """
        Register a ChatLLM instance under a unique name.

        Args:
            name: Unique identifier for the ChatLLM instance.
            chat_key: Key for the ChatLLM implementation in LLMFactory.
            **chat_kwargs: Additional parameters for instantiation.
        """
        self.chat_llms[name] = LLMFactory.create_chat_llm(chat_key, **chat_kwargs)

    def register_embedding(self, name: str, embedding_key: str, **embedding_kwargs: Any) -> None:
        """
        Register an EmbeddingLLM instance under a unique name.

        Args:
            name: Unique identifier for the EmbeddingLLM instance.
            embedding_key: Key for the EmbeddingLLM implementation in LLMFactory.
            **embedding_kwargs: Additional parameters for instantiation.
        """
        self.embedding_llms[name] = LLMFactory.create_embedding_llm(embedding_key, **embedding_kwargs)

    def get_chat_llm(self, name: str) -> ChatLLM:
        """
        Retrieve the ChatLLM instance registered under the given name.

        Raises:
            ValueError: If no ChatLLM is registered under the name.
        """
        if name not in self.chat_llms:
            raise ValueError(f"No ChatLLM registered under name '{name}'.")
        return self.chat_llms[name]

    def get_embedding_llm(self, name: str) -> EmbeddingLLM:
        """
        Retrieve the EmbeddingLLM instance registered under the given name.

        Raises:
            ValueError: If no EmbeddingLLM is registered under the name.
        """
        if name not in self.embedding_llms:
            raise ValueError(f"No EmbeddingLLM registered under name '{name}'.")
        return self.embedding_llms[name]

    def remove_chat(self, name: str) -> None:
        """Remove the ChatLLM instance registered under the given name."""
        self.chat_llms.pop(name, None)

    def remove_embedding(self, name: str) -> None:
        """Remove the EmbeddingLLM instance registered under the given name."""
        self.embedding_llms.pop(name, None)

    def list_chat_llms(self) -> dict[str, ChatLLM]:
        """Return a copy of all registered ChatLLM instances."""
        return dict(self.chat_llms)

    def list_embedding_llms(self) -> dict[str, EmbeddingLLM]:
        """Return a copy of all registered EmbeddingLLM instances."""
        return dict(self.embedding_llms)
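A usage sketch for the manager, assuming the Mock defaults that factory.py registers above (their client classes are still empty stubs in this commit):

    from graphrag.config.enums import LLMType
    from graphrag.llm.manager import LLMManager

    manager = LLMManager.get_instance()
    manager.register_chat("default", LLMType.Mock)      # instantiated via LLMFactory
    reply = manager.get_chat_llm("default").chat("ping")

    # Every call site sees the same singleton, hence the same registrations.
    assert LLMManager.get_instance() is manager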
7 graphrag/llm/protocols/__init__.py Normal file
@@ -0,0 +1,7 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

from .chat import ChatLLM
from .embedding import EmbeddingLLM

__all__ = ["ChatLLM", "EmbeddingLLM"]
19 graphrag/llm/protocols/chat.py Normal file
@@ -0,0 +1,19 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

from __future__ import annotations

from typing import Any, Protocol


class ChatLLM(Protocol):
    def chat(self, prompt: str, **kwargs: Any) -> str:
        """
        Generate a chat response based on the provided prompt.

        Args:
            prompt: The text prompt to generate a response for.
            **kwargs: Additional keyword arguments (e.g., temperature, max_tokens).

        Returns:
            A string response generated by the LLM.
        """
        ...
19 graphrag/llm/protocols/embedding.py Normal file
@@ -0,0 +1,19 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

from __future__ import annotations

from typing import Any, Protocol


class EmbeddingLLM(Protocol):
    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """
        Generate an embedding vector for the given text.

        Args:
            text: The text to generate an embedding for.
            **kwargs: Additional keyword arguments (e.g., model parameters).

        Returns:
            A list of floats representing the embedding vector.
        """
        ...
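Because ChatLLM and EmbeddingLLM are typing.Protocol classes, concrete clients never subclass them; matching method signatures are enough, which keeps client code free of graphrag imports. A minimal sketch (names hypothetical):

    from typing import Any

    from graphrag.llm.protocols import ChatLLM, EmbeddingLLM

    class CannedChat:
        def chat(self, prompt: str, **kwargs: Any) -> str:
            return "canned answer"

    class ZeroEmbedder:
        def embed(self, text: str, **kwargs: Any) -> list[float]:
            return [0.0] * 8

    chat: ChatLLM = CannedChat()            # accepted by a static type checker
    embedder: EmbeddingLLM = ZeroEmbedder()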