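"""Service wrapper that routes chat prompts to either an on-prem Ollama model or an Azure OpenAI deployment.

Expected environment variables (as referenced below):
  onprem: OLLAMA_MODEL_NAME, OLLAMA_BASE_URL, OLLAMA_KEEP_ALIVE
  azure:  AZURE_OPENAI_API_KEY, AZURE_API_VERSION, AZURE_OPENAI_API_ENDPOINT, AZURE_OPENAI_DEPLOYMENT_NAME
"""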
import os

from langchain_community.callbacks import get_openai_callback
from langchain_community.chat_models import ChatOllama
from langchain_openai import AzureChatOpenAI
from loguru import logger


class LLMProviderError(ConnectionError):
    """Raised when the configured LLM provider is unreachable or fails the connectivity test."""


class LLMService:
    def __init__(self, provider, model=None):
        self.provider = provider
        if self.provider == "onprem":
            self.model = os.environ["OLLAMA_MODEL_NAME"]
        elif self.provider == "azure":
            # Fall back to the deployment configured in the environment when no model is given.
            if model is None:
                self.model = os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"]
            else:
                self.model = model
        else:
            raise ValueError(f"Unsupported provider: {provider!r} (expected 'onprem' or 'azure')")

        # self.initialize_client_query_engine()

    def initialize_client_query_engine(self):
        logger.debug(f"initializing llm client on {self.provider} and its query engine")
        if self.provider == "onprem":
            self.llm = ChatOllama(
                model=self.model,
                base_url=os.environ["OLLAMA_BASE_URL"],
                temperature=0,
                keep_alive=os.environ["OLLAMA_KEEP_ALIVE"],
                verbose=True,
            )
            self._test_provider(llm=self.llm)
            query_engine = self.qeng_ollama
        elif self.provider == "azure":
            self.llm = AzureChatOpenAI(
                openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
                openai_api_version=os.environ["AZURE_API_VERSION"],
                azure_endpoint=os.environ["AZURE_OPENAI_API_ENDPOINT"],
                azure_deployment=self.model,
            )
            self._test_provider(llm=self.llm)
            query_engine = self.qeng_azure
        return query_engine

    def qeng_azure(self, prompt):
        logger.debug("prompting to azure")
        with get_openai_callback() as cb:
            response = self.llm.invoke(prompt)
            logger.debug(f"Total Tokens: {cb.total_tokens}")
            logger.debug(f"Prompt Tokens: {cb.prompt_tokens}")
            logger.debug(f"Completion Tokens: {cb.completion_tokens}")
            logger.debug(f"Total Cost (USD): ${cb.total_cost}")
        consumption = {"total_cost": round(cb.total_cost, 4), "total_tokens": cb.total_tokens}
        return response, consumption

    def qeng_ollama(self, prompt):
        logger.debug("prompting to ollama")
        response = self.llm.invoke(prompt)
        # Local Ollama calls are treated as free; token usage is not tracked here.
        consumption = {"total_cost": 0.0, "total_tokens": 0}
        return response, consumption

    def _test_provider(self, llm):
        try:
            response = llm.invoke("answer me only in a single word like OK")
            if response:
                logger.success("test OK: provided llm is responsive")
                return True
        except Exception as e:
            raise LLMProviderError(f"Error while testing provider: {e}") from e
        return False
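

# Usage sketch (not part of the original module): a hedged example of how the
# service might be driven, assuming the environment variables listed in the
# module docstring are set. The provider name and prompt below are illustrative.
if __name__ == "__main__":
    service = LLMService(provider="azure")
    query_engine = service.initialize_client_query_engine()
    answer, consumption = query_engine("Reply with a short greeting.")
    logger.info(f"answer: {answer.content} | consumption: {consumption}")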