Compare commits

5 commits: cdfada6d3c...b5735a39df
| Author | SHA1 | Date |
|---|---|---|
| | b5735a39df | |
| | 1dfedfb03b | |
| | 03908e4d2d | |
| | 8daddf989d | |
| | fb952f967e | |
src/base_query_engine.py (new file, 45 lines)
@@ -0,0 +1,45 @@
```python
from abc import ABC, abstractmethod

from langchain_community.callbacks import get_openai_callback
from loguru import logger


class BaseQueryEngine(ABC):

    def __init__(self, model, provider):
        self.model = model
        self.provider = provider

    @abstractmethod
    def get_sample_message_content(self):
        pass

    @abstractmethod
    def construct_message_content(self, inputs):
        pass

    @abstractmethod
    def calculate_consumption(self, cb):
        pass

    def get_provider_name(self):
        return self.provider

    def ask(self, inputs):
        with get_openai_callback() as cb:
            message_content = self.construct_message_content(inputs)
            response = self.model.invoke(message_content)

            logger.debug(f"Total Tokens: {cb.total_tokens}")
            logger.debug(f"Prompt Tokens: {cb.prompt_tokens}")
            logger.debug(f"Completion Tokens: {cb.completion_tokens}")
            logger.debug(f"Total Cost (USD): ${round(cb.total_cost, 4)}")

            consumption = self.calculate_consumption(cb)

            answer = self.parse_response(response, inputs)
            return answer, consumption

    def parse_response(self, response, inputs):
        if "parser" in inputs:
            return inputs["parser"].invoke(input=response.content)
        return response.content
```
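For orientation (not part of this changeset), a minimal sketch of how a concrete engine could satisfy the three abstract hooks; the class name, message format, and consumption fields below are all hypothetical:

```python
# Hypothetical example, not in this diff: a concrete BaseQueryEngine subclass.
class SimpleTextQueryEngine(BaseQueryEngine):

    def get_sample_message_content(self):
        # A canned payload, handy for smoke-testing a provider.
        return [{"role": "user", "content": "Say hello."}]

    def construct_message_content(self, inputs):
        # Map the caller's inputs dict onto whatever self.model.invoke() accepts.
        return [{"role": "user", "content": inputs["question"]}]

    def calculate_consumption(self, cb):
        # Reduce the get_openai_callback() counters to whatever the app tracks.
        return {"total_tokens": cb.total_tokens, "cost_usd": cb.total_cost}
```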
src/errors.py (new file, 25 lines)
@@ -0,0 +1,25 @@
```python
class InsertionError(Exception):
    def __init__(self, message):
        super().__init__(message)


class DeletionError(Exception):
    def __init__(self, message):
        super().__init__(message)


class LanguageError(Exception):
    def __init__(self, message):
        super().__init__(message)


class LLMProviderError(ConnectionError):
    def __init__(self, message):
        super().__init__(message)


class LLMFactoryError(ConnectionError):
    def __init__(self, message):
        super().__init__(message)
```
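These are plain marker exceptions; the last two subclass ConnectionError, so provider failures can also be caught by generic connection handling. As a hedged usage sketch (the factory call mirrors src/test.py below; the model name is a placeholder), callers can branch on LLMFactoryError for unsupported combinations:

```python
import llms
from errors import LLMFactoryError

llm_factory = llms.LLMFactory()
try:
    # conversational=False is an unsupported combination, so this raises.
    engine = llm_factory.create_model(
        provider="openai",
        conversational=False,
        model_name="gpt-4o",
    )
except LLMFactoryError as exc:
    print(f"unsupported configuration: {exc}")
```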
src/llms.py (new file, 84 lines)
@@ -0,0 +1,84 @@
```python
import os
import query_engine as qe

from typing import List

from errors import LLMFactoryError
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_community.chat_models import ChatOllama


class LLMFactory:

    def __init__(self):
        self.model = None

    def create_model(self, provider: str, conversational: bool, model_name: str, multimodal: List[str] = None):
        if provider == "azure" and conversational:
            self.model = AzureChatOpenAI(
                openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
                openai_api_version=os.environ["AZURE_API_VERSION"],
                azure_endpoint=os.environ["AZURE_OPENAI_API_ENDPOINT"],
                azure_deployment=model_name,
            )
            return qe.QueryEngineOpenAIConversationalText(self.model)
        elif provider == "openai":
            if conversational:
                if multimodal:
                    if multimodal == ["text", "image"]:
                        self.model = ChatOpenAI(
                            temperature=os.environ["OPENAI_TEMPERATURE"],
                            model=model_name,
                            timeout=10,
                            max_retries=3
                        )
                        return qe.QueryEngineOpenAIConversationalMultiModal(self.model)
                    else:
                        raise LLMFactoryError(f"given multimodality is not implemented: {multimodal}")
                else:
                    raise LLMFactoryError("non-multimodal is not implemented")
            else:
                raise LLMFactoryError("non-conversation is not implemented")
        elif provider == "ollama":
            if conversational:
                self.model = ChatOllama(
                    base_url=os.environ["OLLAMA_BASE_URL"],
                    temperature=0,
                    model=model_name,
                    keep_alive=int(os.environ["OLLAMA_KEEP_ALIVE"]),
                    num_ctx=int(os.environ["OLLAMA_NUM_CTX"]),
                    timeout=300,
                    max_retries=3
                )
                if multimodal:
                    if multimodal == ["text", "image"]:
                        return qe.QueryEngineOllamaConversationalMultiModal(self.model, provider)
                    else:
                        raise LLMFactoryError(f"given multimodality is not implemented: {multimodal}")
                else:
                    return qe.QueryEngineOllamaConversationalText(self.model, provider)
            else:
                raise LLMFactoryError("non-conversation is not implemented")
        elif provider == "llamacpp":
            if conversational:
                self.model = ChatOpenAI(
                    model=model_name,  # "models/Meta-Llama-3-8B-Instruct.Q8_0.gguf",
                    temperature=0,
                    max_tokens=2048,
                    stop=["<|eot_id|>"],
                    timeout=None,
                    max_retries=2,
                    api_key="sk-xxx",
                    base_url=os.environ["LLAMACPP_BASE_URL"],
                )
                if multimodal:
                    if multimodal == ["text", "image"]:
                        raise LLMFactoryError(f"given {provider}, multimodality is not implemented: {multimodal}")
                    else:
                        raise LLMFactoryError(f"given {provider}, multimodality is not implemented: {multimodal}")
                else:
                    return qe.QueryEngineLlamacppConversationalText(self.model)
            else:
                raise LLMFactoryError("non-conversation is not implemented")
```
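create_model pulls all endpoint configuration from environment variables rather than arguments. As a minimal sketch, the ollama branch above reads the three variables below before construction; http://localhost:11434 is the stock Ollama port, and the numeric values are placeholder assumptions:

```python
import os

# Placeholder values; point these at your own Ollama deployment.
os.environ.setdefault("OLLAMA_BASE_URL", "http://localhost:11434")
os.environ.setdefault("OLLAMA_KEEP_ALIVE", "300")  # parsed with int() above
os.environ.setdefault("OLLAMA_NUM_CTX", "4096")    # parsed with int() above
```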
src/test.py (new file, 10 lines)
@@ -0,0 +1,10 @@
```python
import llms

llm_factory = llms.LLMFactory()
model_ollama = llm_factory.create_model(
    provider="ollama",
    conversational=True,
    model_name="llama3.1:8b-instruct-fp16",
)
```
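The script stops after construction. A hypothetical next step would exercise the full ask() path; the expected inputs keys live in query_engine.py, which is not part of this diff, so the "question" key below is an assumption:

```python
# Hypothetical continuation; the "question" key is assumed, since
# query_engine.py's construct_message_content is not shown in this diff.
answer, consumption = model_ollama.ask({"question": "What is the capital of France?"})
print(answer)
print(consumption)
```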