Compare commits

...

5 Commits

SHA1 Message Date
b5735a39df Pull request #1: Feature/init
Merge in SE471652/llm_factory from feature/init to develop

* commit '1dfedfb03bbe9cbf59aba493f2b899b00af6ec99':
  basic initialization test added
  base_query_engine added
  errors added
  llms added
  query engine added
  requirements file created
  readme file created
  toml file created
2024-09-09 10:31:03 +03:00
1dfedfb03b basic initialization test added 2024-09-09 10:24:49 +03:00
03908e4d2d base_query_engine added 2024-09-09 10:24:37 +03:00
8daddf989d errors added 2024-09-09 10:24:31 +03:00
fb952f967e llms added 2024-09-09 10:24:27 +03:00
4 changed files with 164 additions and 0 deletions

src/base_query_engine.py (Normal file, 45 lines)

@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from langchain_community.callbacks import get_openai_callback
from loguru import logger


class BaseQueryEngine(ABC):
    def __init__(self, model, provider):
        self.model = model
        self.provider = provider

    @abstractmethod
    def get_sample_message_content(self):
        pass

    @abstractmethod
    def construct_message_content(self, inputs):
        pass

    @abstractmethod
    def calculate_consumption(self, cb):
        pass

    def get_provider_name(self):
        return self.provider

    def ask(self, inputs):
        # get_openai_callback records token counts and cost for the call.
        with get_openai_callback() as cb:
            message_content = self.construct_message_content(inputs)
            response = self.model.invoke(message_content)
            logger.debug(f"Total Tokens: {cb.total_tokens}")
            logger.debug(f"Prompt Tokens: {cb.prompt_tokens}")
            logger.debug(f"Completion Tokens: {cb.completion_tokens}")
            logger.debug(f"Total Cost (USD): ${round(cb.total_cost, 4)}")
            consumption = self.calculate_consumption(cb)
            answer = self.parse_response(response, inputs)
        return answer, consumption

    def parse_response(self, response, inputs):
        # Apply an optional output parser supplied by the caller.
        if "parser" in inputs:
            return inputs["parser"].invoke(input=response.content)
        return response.content

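The concrete engines live in src/query_engine.py, which this diff does not include. Purely as a sketch of how a subclass might satisfy the three abstract hooks, here is a hypothetical text-only engine; the class name, message shape, and consumption fields below are illustrative assumptions, not the repository's actual implementation:

    # Hypothetical subclass for illustration only; the real engines are in
    # src/query_engine.py, which is not part of this diff.
    from base_query_engine import BaseQueryEngine

    class ExampleTextQueryEngine(BaseQueryEngine):
        def get_sample_message_content(self):
            # Example of the message shape construct_message_content produces.
            return [("human", "Hello!")]

        def construct_message_content(self, inputs):
            # Turn the caller's inputs dict into messages for model.invoke();
            # the "question" key is an assumption made for this sketch.
            return [("human", inputs["question"])]

        def calculate_consumption(self, cb):
            # Summarize the usage captured by get_openai_callback() in ask().
            return {"total_tokens": cb.total_tokens, "cost_usd": cb.total_cost}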
src/errors.py (Normal file, 25 lines)

@@ -0,0 +1,25 @@
class InsertionError(Exception):
    def __init__(self, message):
        super().__init__(message)


class DeletionError(Exception):
    def __init__(self, message):
        super().__init__(message)


class LanguageError(Exception):
    def __init__(self, message):
        super().__init__(message)


class LLMProviderError(ConnectionError):
    def __init__(self, message):
        super().__init__(message)


class LLMFactoryError(ConnectionError):
    def __init__(self, message):
        super().__init__(message)

src/llms.py (Normal file, 84 lines)

@@ -0,0 +1,84 @@
import os
import query_engine as qe
from typing import List
from errors import LLMFactoryError
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_community.chat_models import ChatOllama


class LLMFactory:
    def __init__(self):
        self.model = None

    def create_model(self, provider: str, conversational: bool, model_name: str, multimodal: List[str] = None):
        if provider == "azure" and conversational:
            self.model = AzureChatOpenAI(
                openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
                openai_api_version=os.environ["AZURE_API_VERSION"],
                azure_endpoint=os.environ["AZURE_OPENAI_API_ENDPOINT"],
                azure_deployment=model_name,
            )
            return qe.QueryEngineOpenAIConversationalText(self.model)
        elif provider == "openai":
            if conversational:
                if multimodal:
                    if multimodal == ["text", "image"]:
                        self.model = ChatOpenAI(
                            # Environment variables are strings; cast before use.
                            temperature=float(os.environ["OPENAI_TEMPERATURE"]),
                            model=model_name,
                            timeout=10,
                            max_retries=3,
                        )
                        return qe.QueryEngineOpenAIConversationalMultiModal(self.model)
                    else:
                        raise LLMFactoryError(f"given multimodality is not implemented: {multimodal}")
                else:
                    raise LLMFactoryError("non-multimodal is not implemented")
            else:
                raise LLMFactoryError("non-conversational is not implemented")
        elif provider == "ollama":
            if conversational:
                self.model = ChatOllama(
                    base_url=os.environ["OLLAMA_BASE_URL"],
                    temperature=0,
                    model=model_name,
                    keep_alive=int(os.environ["OLLAMA_KEEP_ALIVE"]),
                    num_ctx=int(os.environ["OLLAMA_NUM_CTX"]),
                    timeout=300,
                    max_retries=3,
                )
                if multimodal:
                    if multimodal == ["text", "image"]:
                        return qe.QueryEngineOllamaConversationalMultiModal(self.model, provider)
                    else:
                        raise LLMFactoryError(f"given multimodality is not implemented: {multimodal}")
                else:
                    return qe.QueryEngineOllamaConversationalText(self.model, provider)
            else:
                raise LLMFactoryError("non-conversational is not implemented")
        elif provider == "llamacpp":
            if conversational:
                # llama.cpp is reached through its OpenAI-compatible endpoint,
                # so the stock ChatOpenAI client is pointed at it.
                self.model = ChatOpenAI(
                    model=model_name,  # e.g. "models/Meta-Llama-3-8B-Instruct.Q8_0.gguf"
                    temperature=0,
                    max_tokens=2048,
                    stop=["<|eot_id|>"],
                    timeout=None,
                    max_retries=2,
                    api_key="sk-xxx",  # dummy key; the local server does not check it
                    base_url=os.environ["LLAMACPP_BASE_URL"],
                )
                if multimodal:
                    raise LLMFactoryError(f"given {provider}, multimodality is not implemented: {multimodal}")
                else:
                    return qe.QueryEngineLlamacppConversationalText(self.model)
            else:
                raise LLMFactoryError("non-conversational is not implemented")
        else:
            # Guard against silently returning None for an unknown provider.
            raise LLMFactoryError(f"unknown provider: {provider}")

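create_model pulls all provider settings from the environment rather than from its arguments. A minimal sketch of the variables each branch reads (the names come from the code above; every value below is a placeholder, not a tested configuration):

    import os

    # Azure branch
    os.environ["AZURE_OPENAI_API_KEY"] = "<key>"
    os.environ["AZURE_API_VERSION"] = "2024-02-01"  # placeholder API version
    os.environ["AZURE_OPENAI_API_ENDPOINT"] = "https://<resource>.openai.azure.com"

    # OpenAI branch
    os.environ["OPENAI_TEMPERATURE"] = "0"

    # Ollama branch (11434 is Ollama's default port)
    os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"
    os.environ["OLLAMA_KEEP_ALIVE"] = "-1"   # cast to int by the factory
    os.environ["OLLAMA_NUM_CTX"] = "8192"    # placeholder context window

    # llama.cpp branch (OpenAI-compatible server endpoint)
    os.environ["LLAMACPP_BASE_URL"] = "http://localhost:8080/v1"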
src/test.py (Normal file, 10 lines)

@@ -0,0 +1,10 @@
import llms

llm_factory = llms.LLMFactory()
model_ollama = llm_factory.create_model(
    provider="ollama",
    conversational=True,
    model_name="llama3.1:8b-instruct-fp16",
)
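
The test stops at construction; an actual query would go through BaseQueryEngine.ask on the returned engine. A sketch, assuming the Ollama text engine builds its messages from a dict like the one below (the exact inputs keys are defined in query_engine.py, which this diff does not show):

    # Assumed inputs shape; the real keys come from
    # QueryEngineOllamaConversationalText.construct_message_content.
    answer, consumption = model_ollama.ask({"question": "What is 2 + 2?"})
    print(answer, consumption)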