refactor: simplify LLM tests and remove duplication
@@ -1,7 +1,10 @@
 import os
 import pdb
+from dataclasses import dataclass
 
 from dotenv import load_dotenv
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_ollama import ChatOllama
 
 load_dotenv()
@@ -9,154 +12,115 @@ import sys
 
 sys.path.append(".")
 
+
+@dataclass
+class LLMConfig:
+    provider: str
+    model_name: str
+    temperature: float = 0.8
+    base_url: str = None
+    api_key: str = None
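+
+    # Note: base_url and api_key may be left as None; test_llm() then falls
+    # back to the provider's environment variables via get_env_value().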
+
+
+def create_message_content(text, image_path=None):
+    content = [{"type": "text", "text": text}]
+
+    if image_path:
+        from src.utils import utils
+        image_data = utils.encode_image(image_path)
+        content.append({
+            "type": "image_url",
+            "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}
+        })
+
+    return content
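+# e.g. create_message_content("describe this image", "assets/examples/test.png")
+# yields [{"type": "text", ...}, {"type": "image_url", ...}], the content-list
+# format LangChain multimodal chat messages accept.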
+
+
+def get_env_value(key, provider):
+    env_mappings = {
+        "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
+        "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
+        "gemini": {"api_key": "GOOGLE_API_KEY"},
+        "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"}
+    }
+
+    if provider in env_mappings and key in env_mappings[provider]:
+        return os.getenv(env_mappings[provider][key], "")
+    return ""
+
+
+def test_llm(config, query, image_path=None, system_message=None):
+    from src.utils import utils
+
+    # Special handling for Ollama-based models
+    if config.provider == "ollama":
+        if "deepseek-r1" in config.model_name:
+            from src.utils.llm import DeepSeekR1ChatOllama
+            llm = DeepSeekR1ChatOllama(model=config.model_name)
+        else:
+            llm = ChatOllama(model=config.model_name)
+
+        ai_msg = llm.invoke(query)
+        print(ai_msg.content)
+        if "deepseek-r1" in config.model_name:
+            pdb.set_trace()
+        return
+
+    # For other providers, use the standard configuration
+    llm = utils.get_llm_model(
+        provider=config.provider,
+        model_name=config.model_name,
+        temperature=config.temperature,
+        base_url=config.base_url or get_env_value("base_url", config.provider),
+        api_key=config.api_key or get_env_value("api_key", config.provider)
+    )
+
+    # Prepare messages for non-Ollama models
+    messages = []
+    if system_message:
+        messages.append(SystemMessage(content=create_message_content(system_message)))
+    messages.append(HumanMessage(content=create_message_content(query, image_path)))
+    ai_msg = llm.invoke(messages)
+
+    # Handle different response types
+    if hasattr(ai_msg, "reasoning_content"):
+        print(ai_msg.reasoning_content)
+    print(ai_msg.content)
+
+    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
+        print(llm.model_name)
+        pdb.set_trace()
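+# Typical usage, mirroring the provider tests below:
+#   test_llm(LLMConfig(provider="openai", model_name="gpt-4o"),
+#            "Describe this image", "assets/examples/test.png")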
 
 
 def test_openai_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="openai",
-        model_name="gpt-4o",
-        temperature=0.8,
-        base_url=os.getenv("OPENAI_ENDPOINT", ""),
-        api_key=os.getenv("OPENAI_API_KEY", "")
-    )
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+    config = LLMConfig(provider="openai", model_name="gpt-4o")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
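+    # Assumes OPENAI_API_KEY (and optionally OPENAI_ENDPOINT) are set in .env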
 
 
 def test_gemini_model():
-    # you need to enable your api key first: https://ai.google.dev/palm_docs/oauth_quickstart
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="gemini",
-        model_name="gemini-2.0-flash-exp",
-        temperature=0.8,
-        api_key=os.getenv("GOOGLE_API_KEY", "")
-    )
-
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+    # Enable your API key first if you haven't: https://ai.google.dev/palm_docs/oauth_quickstart
+    config = LLMConfig(provider="gemini", model_name="gemini-2.0-flash-exp")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
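+    # Assumes GOOGLE_API_KEY is set in .env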
 
 
 def test_azure_openai_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="azure_openai",
-        model_name="gpt-4o",
-        temperature=0.8,
-        base_url=os.getenv("AZURE_OPENAI_ENDPOINT", ""),
-        api_key=os.getenv("AZURE_OPENAI_API_KEY", "")
-    )
-    image_path = "assets/examples/test.png"
-    image_data = utils.encode_image(image_path)
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "describe this image"},
-            {
-                "type": "image_url",
-                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
-            },
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+    config = LLMConfig(provider="azure_openai", model_name="gpt-4o")
+    test_llm(config, "Describe this image", "assets/examples/test.png")
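+    # Assumes AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are set in .env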
 
 
 def test_deepseek_model():
-    from langchain_core.messages import HumanMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-chat",
-        temperature=0.8,
-        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
-        api_key=os.getenv("DEEPSEEK_API_KEY", "")
-    )
-    message = HumanMessage(
-        content=[
-            {"type": "text", "text": "who are you?"}
-        ]
-    )
-    ai_msg = llm.invoke([message])
-    print(ai_msg.content)
+    config = LLMConfig(provider="deepseek", model_name="deepseek-chat")
+    test_llm(config, "Who are you?")
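+    # Assumes DEEPSEEK_ENDPOINT and DEEPSEEK_API_KEY are set in .env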
 
 
 def test_deepseek_r1_model():
-    from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
-    from src.utils import utils
-
-    llm = utils.get_llm_model(
-        provider="deepseek",
-        model_name="deepseek-reasoner",
-        temperature=0.8,
-        base_url=os.getenv("DEEPSEEK_ENDPOINT", ""),
-        api_key=os.getenv("DEEPSEEK_API_KEY", "")
-    )
-    messages = []
-    sys_message = SystemMessage(
-        content=[{"type": "text", "text": "you are a helpful AI assistant"}]
-    )
-    messages.append(sys_message)
-    user_message = HumanMessage(
-        content=[
-            {"type": "text", "text": "9.11 and 9.8, which is greater?"}
-        ]
-    )
-    messages.append(user_message)
-    ai_msg = llm.invoke(messages)
-    print(ai_msg.reasoning_content)
-    print(ai_msg.content)
-    print(llm.model_name)
-    pdb.set_trace()
+    config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
+    test_llm(config, "Which is greater, 9.11 or 9.8?", system_message="You are a helpful AI assistant.")
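+    # deepseek-reasoner replies carry reasoning_content, which test_llm()
+    # prints before the final answer, then drops into pdb for inspection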
 
 
 def test_ollama_model():
-    from langchain_ollama import ChatOllama
+    config = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
+    test_llm(config, "Sing a ballad of LangChain.")
-
-    llm = ChatOllama(model="qwen2.5:7b")
-    ai_msg = llm.invoke("Sing a ballad of LangChain.")
-    print(ai_msg.content)
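+    # Assumes a local Ollama server is running with the qwen2.5:7b model pulled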
 
 
 def test_deepseek_r1_ollama_model():
-    from src.utils.llm import DeepSeekR1ChatOllama
+    config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
+    test_llm(config, "How many 'r's are in the word 'strawberry'?")
-
-    llm = DeepSeekR1ChatOllama(model="deepseek-r1:14b")
-    ai_msg = llm.invoke("how many r in strawberry?")
-    print(ai_msg.content)
-    pdb.set_trace()
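+    # Assumes the deepseek-r1:14b model is pulled locally; test_llm() uses the
+    # DeepSeekR1ChatOllama wrapper and stops in pdb for deepseek-r1 models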
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # test_openai_model()
     # test_gemini_model()
     # test_azure_openai_model()
     test_deepseek_model()
     # test_ollama_model()
     # test_deepseek_r1_model()
     # test_deepseek_r1_ollama_model()