Pull request #1: Feature/init
Merge in SE471652/llm_factory from feature/init to develop

* commit '1dfedfb03bbe9cbf59aba493f2b899b00af6ec99':
  basic initialization test added
  base_query_engine added
  errors added
  llms added
  query engine added
  requirements file created
  readme file created
  toml file created
pyproject.toml (Normal file, 75 lines)
@@ -0,0 +1,75 @@
[project]
name = "llm-factory"
version = "0.0.1"
authors = [
    { name="Mithat Sinan Ergen", email="mithat.ergen@turkcell.com.tr" },
]
description = "LLM Factory package"
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
]
# Pinned runtime dependencies; under [build-system].requires they would only be
# installed into the isolated build environment, not alongside the package.
dependencies = [
    "pyparsing==3.0.9",
    "tomli==2.0.1",
    "colorama==0.4.6",
    "greenlet==1.1.3.post0",
    "aiohappyeyeballs==2.4.0",
    "aiohttp==3.10.5",
    "aiosignal==1.3.1",
    "annotated-types==0.7.0",
    "anyio==4.4.0",
    "async-timeout==4.0.3",
    "attrs==24.2.0",
    "certifi==2024.8.30",
    "charset-normalizer==3.3.2",
    "dataclasses-json==0.6.7",
    "distro==1.9.0",
    "exceptiongroup==1.2.2",
    "frozenlist==1.4.1",
    "h11==0.14.0",
    "httpcore==1.0.5",
    "httpx==0.27.2",
    "idna==3.8",
    "jiter==0.5.0",
    "jsonpatch==1.33",
    "jsonpointer==3.0.0",
    "langchain==0.2.6",
    "langchain-community==0.2.6",
    "langchain-core==0.2.11",
    "langchain-openai==0.1.14",
    "langchain-text-splitters==0.2.2",
    "langsmith==0.1.116",
    "loguru==0.7.2",
    "marshmallow==3.22.0",
    "multidict==6.0.5",
    "mypy-extensions==1.0.0",
    "numpy==1.26.4",
    "openai==1.44.0",
    "orjson==3.10.7",
    "packaging==24.1",
    "pydantic==2.9.0",
    "pydantic_core==2.23.2",
    "PyYAML==6.0.2",
    "regex==2024.7.24",
    "requests==2.32.3",
    "sniffio==1.3.1",
    "SQLAlchemy==2.0.34",
    "tenacity==8.5.0",
    "tiktoken==0.7.0",
    "tqdm==4.66.5",
    "typing-inspect==0.9.0",
    "typing_extensions==4.12.2",
    "tzdata==2024.1",
    "urllib3==2.2.2",
    "yarl==1.11.0",
    "win32-setctime==1.1.0",
]

[build-system]
requires = ["setuptools>=61.0", "build==0.9.0", "pep517==0.13.0"]
build-backend = "setuptools.build_meta"
requirements.txt (Normal file, 49 lines)
@@ -0,0 +1,49 @@
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.4.0
async-timeout==4.0.3
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.3.2
dataclasses-json==0.6.7
distro==1.9.0
exceptiongroup==1.2.2
frozenlist==1.4.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.2
idna==3.8
jiter==0.5.0
jsonpatch==1.33
jsonpointer==3.0.0
langchain==0.2.6
langchain-community==0.2.6
langchain-core==0.2.11
langchain-openai==0.1.14
langchain-text-splitters==0.2.2
langsmith==0.1.116
loguru==0.7.2
marshmallow==3.22.0
multidict==6.0.5
mypy-extensions==1.0.0
numpy==1.26.4
openai==1.44.0
orjson==3.10.7
packaging==24.1
pydantic==2.9.0
pydantic_core==2.23.2
PyYAML==6.0.2
regex==2024.7.24
requests==2.32.3
sniffio==1.3.1
SQLAlchemy==2.0.34
tenacity==8.5.0
tiktoken==0.7.0
tqdm==4.66.5
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
yarl==1.11.0
src/__init__.py (Normal file, 0 lines)
src/base_query_engine.py (Normal file, 45 lines)
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from langchain_community.callbacks import get_openai_callback
from loguru import logger


class BaseQueryEngine(ABC):

    # provider is optional: the factory constructs the OpenAI, Azure, and
    # llama.cpp engines with the model argument alone.
    def __init__(self, model, provider=None):
        self.model = model
        self.provider = provider

    @abstractmethod
    def get_sample_message_content(self):
        pass

    @abstractmethod
    def construct_message_content(self, inputs):
        pass

    @abstractmethod
    def calculate_consumption(self, cb):
        pass

    def get_provider_name(self):
        return self.provider

    def ask(self, inputs):
        # Run the call inside an OpenAI callback so token usage and cost are captured.
        with get_openai_callback() as cb:
            message_content = self.construct_message_content(inputs)
            response = self.model.invoke(message_content)

            logger.debug(f"Total Tokens: {cb.total_tokens}")
            logger.debug(f"Prompt Tokens: {cb.prompt_tokens}")
            logger.debug(f"Completion Tokens: {cb.completion_tokens}")
            logger.debug(f"Total Cost (USD): ${round(cb.total_cost, 4)}")

            consumption = self.calculate_consumption(cb)

            answer = self.parse_response(response, inputs)
            return answer, consumption

    def parse_response(self, response, inputs):
        # If the caller supplied an output parser, run it over the raw content.
        if "parser" in inputs:
            return inputs["parser"].invoke(input=response.content)
        return response.content
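For reference, a concrete engine only fills in the three abstract hooks; callback handling, logging, and parsing all come from ask(). A minimal sketch (the QueryEngineDemo class and the ChatOpenAI wiring are illustrative, not part of this PR, and assume OPENAI_API_KEY is set):

import llms  # noqa: F401  (ensures the package modules are importable)
from langchain_openai import ChatOpenAI

from base_query_engine import BaseQueryEngine


class QueryEngineDemo(BaseQueryEngine):

    def get_sample_message_content(self):
        return """inputs = {"prompt": user_prompt}"""

    def construct_message_content(self, inputs):
        # LangChain chat models accept a plain string as input.
        return inputs["prompt"]

    def calculate_consumption(self, cb):
        return {"total_cost": round(cb.total_cost, 4), "total_tokens": cb.total_tokens}


engine = QueryEngineDemo(ChatOpenAI(model="gpt-4o-mini"), provider="openai")
answer, consumption = engine.ask({"prompt": "Say hello."})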
src/errors.py (Normal file, 25 lines)
@@ -0,0 +1,25 @@
class InsertionError(Exception):
    pass


class DeletionError(Exception):
    pass


class LanguageError(Exception):
    pass


class LLMProviderError(ConnectionError):
    pass


class LLMFactoryError(ConnectionError):
    pass
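Both LLM errors subclass ConnectionError, so callers can catch them specifically or via the builtin. A short usage sketch (hypothetical call site, not in the PR; this combination raises because non-conversational mode is unimplemented):

import llms
from errors import LLMFactoryError

try:
    engine = llms.LLMFactory().create_model(
        provider="openai", conversational=False, model_name="gpt-4o"
    )
except LLMFactoryError as err:
    # Unsupported provider/mode combinations are reported this way.
    print(f"factory error: {err}")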
src/llms.py (Normal file, 84 lines)
@@ -0,0 +1,84 @@
import os
import query_engine as qe

from typing import List, Optional

from errors import LLMFactoryError
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_community.chat_models import ChatOllama


class LLMFactory:

    def __init__(self):
        self.model = None

    def create_model(self, provider: str, conversational: bool, model_name: str, multimodal: Optional[List[str]] = None):
        if provider == "azure":
            if conversational:
                self.model = AzureChatOpenAI(
                    openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
                    openai_api_version=os.environ["AZURE_API_VERSION"],
                    azure_endpoint=os.environ["AZURE_OPENAI_API_ENDPOINT"],
                    azure_deployment=model_name,
                )
                return qe.QueryEngineOpenAIConversationalText(self.model)
            raise LLMFactoryError("non-conversational mode is not implemented")
        elif provider == "openai":
            if conversational:
                if multimodal:
                    if multimodal == ["text", "image"]:
                        self.model = ChatOpenAI(
                            temperature=float(os.environ["OPENAI_TEMPERATURE"]),
                            model=model_name,
                            timeout=10,
                            max_retries=3,
                        )
                        return qe.QueryEngineOpenAIConversationalMultiModal(self.model)
                    raise LLMFactoryError(f"given multimodality is not implemented: {multimodal}")
                raise LLMFactoryError("non-multimodal mode is not implemented")
            raise LLMFactoryError("non-conversational mode is not implemented")
        elif provider == "ollama":
            if conversational:
                self.model = ChatOllama(
                    base_url=os.environ["OLLAMA_BASE_URL"],
                    temperature=0,
                    model=model_name,
                    keep_alive=int(os.environ["OLLAMA_KEEP_ALIVE"]),
                    num_ctx=int(os.environ["OLLAMA_NUM_CTX"]),
                    timeout=300,
                    max_retries=3,
                )
                if multimodal:
                    if multimodal == ["text", "image"]:
                        return qe.QueryEngineOllamaConversationalMultiModal(self.model, provider)
                    raise LLMFactoryError(f"given multimodality is not implemented: {multimodal}")
                return qe.QueryEngineOllamaConversationalText(self.model, provider)
            raise LLMFactoryError("non-conversational mode is not implemented")
        elif provider == "llamacpp":
            if conversational:
                self.model = ChatOpenAI(
                    model=model_name,  # e.g. "models/Meta-Llama-3-8B-Instruct.Q8_0.gguf"
                    temperature=0,
                    max_tokens=2048,
                    stop=["<|eot_id|>"],
                    timeout=None,
                    max_retries=2,
                    api_key="sk-xxx",  # llama.cpp servers ignore the key, but the client requires one
                    base_url=os.environ["LLAMACPP_BASE_URL"],
                )
                if multimodal:
                    raise LLMFactoryError(f"given {provider}, multimodality is not implemented: {multimodal}")
                return qe.QueryEngineLlamacppConversationalText(self.model)
            raise LLMFactoryError("non-conversational mode is not implemented")
        raise LLMFactoryError(f"unknown provider: {provider}")
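create_model pulls all provider configuration from environment variables rather than from arguments. A minimal Azure setup sketch (every value below is a placeholder, not a tested configuration):

import os

import llms

os.environ["AZURE_OPENAI_API_KEY"] = "<key>"
os.environ["AZURE_API_VERSION"] = "<api-version>"
os.environ["AZURE_OPENAI_API_ENDPOINT"] = "https://<resource>.openai.azure.com"

engine = llms.LLMFactory().create_model(
    provider="azure",
    conversational=True,
    model_name="<deployment-name>",
)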
src/query_engine.py (Normal file, 92 lines)
@@ -0,0 +1,92 @@
from langchain_community.callbacks import get_openai_callback
from loguru import logger

from base_query_engine import BaseQueryEngine
from langchain_core.messages import HumanMessage


class QueryEngineOpenAIConversationalMultiModal(BaseQueryEngine):

    def get_sample_message_content(self):
        # Sample shape inferred from construct_message_content below.
        return """inputs = {"prompt": prompt, "parser": output_parser, "image": base64_jpeg}"""

    def construct_message_content(self, inputs):
        return [
            HumanMessage(content=[
                {"type": "text", "text": inputs["prompt"]},
                {"type": "text", "text": inputs["parser"].get_format_instructions()},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{inputs['image']}"}}
            ])
        ]

    def calculate_consumption(self, cb):
        return {"total_cost": round(cb.total_cost, 4), "total_tokens": cb.total_tokens}


class QueryEngineOpenAIConversationalText(BaseQueryEngine):

    def get_sample_message_content(self):
        return """inputs = {"prompt": system_prompt + user_prompt}"""

    def construct_message_content(self, inputs):
        return [
            HumanMessage(content=[
                {"type": "text", "text": inputs["prompt"]}
            ])
        ]

    def calculate_consumption(self, cb):
        return {"total_cost": round(cb.total_cost, 4), "total_tokens": cb.total_tokens}


class QueryEngineOllamaConversationalMultiModal(BaseQueryEngine):

    def get_sample_message_content(self):
        # Sample shape inferred from construct_message_content below.
        return """inputs = {"prompt": prompt, "parser": output_parser, "image": base64_jpeg}"""

    def construct_message_content(self, inputs):
        return [
            HumanMessage(content=[
                {"type": "text", "text": inputs["prompt"]},
                {"type": "text", "text": inputs["parser"].get_format_instructions()},
                # Ollama expects the data URL directly, not wrapped in a {"url": ...} dict.
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{inputs['image']}"}
            ])
        ]

    def calculate_consumption(self, cb):
        # The OpenAI callback cannot meter Ollama calls, so usage is reported as zero.
        return {"total_cost": 0, "total_tokens": 0}


class QueryEngineOllamaConversationalText(BaseQueryEngine):

    def get_sample_message_content(self):
        return """inputs = {"prompt": system_prompt + user_prompt}"""

    def construct_message_content(self, inputs):
        return [
            HumanMessage(content=[
                {"type": "text", "text": inputs["prompt"]}
            ])
        ]

    def calculate_consumption(self, cb):
        return {"total_cost": 0, "total_tokens": 0}


class QueryEngineLlamacppConversationalText(BaseQueryEngine):

    def get_sample_message_content(self):
        return """inputs = {"system": system_prompt, "prompt": user_prompt}"""

    def construct_message_content(self, inputs):
        # llama.cpp's OpenAI-compatible server takes role/content message dicts.
        return [
            {"role": "system", "content": inputs["system"]},
            {"role": "user", "content": inputs["prompt"]}
        ]

    def calculate_consumption(self, cb):
        return {"total_cost": 0, "total_tokens": cb.total_tokens}

    def ask(self, inputs):
        with get_openai_callback() as cb:
            message_content = self.construct_message_content(inputs)
            response = self.model.invoke(message_content)
            logger.debug(f"Total Tokens: {cb.total_tokens}")
            logger.debug(f"Prompt Tokens: {cb.prompt_tokens}")
            logger.debug(f"Completion Tokens: {cb.completion_tokens}")
            logger.debug(f"Total Cost (USD): ${round(cb.total_cost, 4)}")
            consumption = self.calculate_consumption(cb)

        answer = self.parse_response(response, inputs)
        return answer, consumption
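The two multimodal engines expect inputs to carry a prompt, a LangChain output parser (whose format instructions are appended to the message), and a base64-encoded JPEG. An illustrative call, assuming engine came from the factory; the image path and the choice of JsonOutputParser are assumptions:

import base64

from langchain_core.output_parsers import JsonOutputParser

with open("photo.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

inputs = {
    "prompt": "Describe this image as JSON.",
    "parser": JsonOutputParser(),
    "image": image_b64,
}
answer, consumption = engine.ask(inputs)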
src/test.py (Normal file, 10 lines)
@@ -0,0 +1,10 @@
import llms


llm_factory = llms.LLMFactory()
model_ollama = llm_factory.create_model(
    provider="ollama",
    conversational=True,
    model_name="llama3.1:8b-instruct-fp16",
)
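The smoke test stops at model creation, and create_model reads OLLAMA_BASE_URL, OLLAMA_KEEP_ALIVE, and OLLAMA_NUM_CTX when it runs, so those must be set first. One way to extend the test into a full round trip (the environment values and prompt are examples, not part of the PR):

import os

# create_model reads these at call time, so set them before the factory runs.
os.environ.setdefault("OLLAMA_BASE_URL", "http://localhost:11434")
os.environ.setdefault("OLLAMA_KEEP_ALIVE", "300")
os.environ.setdefault("OLLAMA_NUM_CTX", "4096")

import llms

model_ollama = llms.LLMFactory().create_model(
    provider="ollama",
    conversational=True,
    model_name="llama3.1:8b-instruct-fp16",
)
answer, consumption = model_ollama.ask({"prompt": "Reply with OK."})
print(answer, consumption)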