passing tests

This commit is contained in:
William Guss
2024-08-29 10:43:57 -07:00
parent 0a510d15c6
commit fb3c469f92
5 changed files with 44 additions and 14 deletions

View File

@@ -37,7 +37,7 @@ class SQLStore(ell.store.Store):
def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
with Session(self.engine) as session:
# Bind the serialized_lmp to the session
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id).first()
lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id)).first()
if lmp:
# Already added to the DB.
@@ -55,7 +55,7 @@ class SQLStore(ell.store.Store):
def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
with Session(self.engine) as session:
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id).first()
lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id)).first()
assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously"
# Increment num_invocations

View File

@@ -175,7 +175,3 @@ class Invocation(InvocationBase, table=True):
Index('ix_invocation_created_at_latency_ms', 'created_at', 'latency_ms'),
Index('ix_invocation_created_at_tokens', 'created_at', 'prompt_tokens', 'completion_tokens'),
)
# Update forward references
Invocation.update_forward_refs()
InvocationContents.update_forward_refs()

View File

@@ -46,7 +46,7 @@ class ContentBlock(BaseModel):
text: Optional[_lstr_generic] = Field(default=None)
image: Optional[Union[PILImage.Image, str, np.ndarray]] = Field(default=None)
audio: Optional[_lstr_generic] = Field(default=None)
audio: Optional[Union[np.ndarray, List[float]]] = Field(default=None)
tool_call: Optional[ToolCall] = Field(default=None)
parsed: Optional[Union[Type[BaseModel], BaseModel]] = Field(default=None)
tool_result: Optional[ToolResult] = Field(default=None)

View File

@@ -2,6 +2,8 @@ import pytest
from pydantic import BaseModel
import ell
from src.ell.types.message import ContentBlock, ToolCall, ToolResult, Message
import numpy as np
from PIL import Image
class DummyParams(BaseModel):
param1: str
@@ -41,7 +43,7 @@ def test_content_block_coerce_base_model():
result = ContentBlock.coerce(formatted_response)
assert isinstance(result, ContentBlock)
assert result.parsed == formatted_response
assert result.type == "formatted_response"
assert result.type == "parsed"
def test_content_block_coerce_content_block():
original_block = ContentBlock(text="Original content")
@@ -72,13 +74,20 @@ def test_message_coercion():
def test_content_block_single_non_null():
# Valid cases
ContentBlock.model_validate(ContentBlock(text="Hello"))
ContentBlock.model_validate(ContentBlock(image="image.jpg"))
ContentBlock.model_validate(ContentBlock(audio="audio.mp3"))
# ContentBlock.model_validate(ContentBlock(image="image.jpg"))
# ContentBlock.model_validate(ContentBlock(audio="audio.mp3"))
ContentBlock.model_validate(ContentBlock(tool_call=ToolCall(tool=dummy_tool,
params=DummyParams(param1="test", param2=42))))
ContentBlock.model_validate(ContentBlock(parsed=DummyFormattedResponse(field1="test", field2=42)))
ContentBlock.model_validate(ContentBlock(tool_result=ToolResult(tool_call_id="123", result=[ContentBlock(text="Tool result")])))
# New valid cases for image and audio
dummy_image = Image.new('RGB', (100, 100))
dummy_audio = np.array([0.1, 0.2, 0.3])
ContentBlock.model_validate(ContentBlock(image=dummy_image))
ContentBlock.model_validate(ContentBlock(audio=dummy_audio))
# Invalid cases
with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(text="Hello", image="image.jpg"))
@@ -88,4 +97,30 @@ def test_content_block_single_non_null():
params=DummyParams(param1="test", param2=42))))
with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(image="image.jpg", audio="audio.mp3", parsed=DummyFormattedResponse(field1="test", field2=42)))
ContentBlock.model_validate(ContentBlock(image="image.jpg", audio="audio.mp3", parsed=DummyFormattedResponse(field1="test", field2=42)))
# New invalid cases for image and audio
with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(image="image.jpg"))
with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(audio="audio.mp3"))
# Add new tests for image and audio validation
def test_content_block_image_validation():
valid_image = Image.new('RGB', (100, 100))
invalid_image = "image.jpg"
ContentBlock.model_validate(ContentBlock(image=valid_image))
with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(image=invalid_image))
def test_content_block_audio_validation():
valid_audio = np.array([0.1, 0.2, 0.3])
invalid_audio = "audio.mp3"
ContentBlock.model_validate(ContentBlock(audio=valid_audio))
with pytest.raises(ValueError):
ContentBlock.model_validate(ContentBlock(audio=invalid_audio))

View File

@@ -2,7 +2,7 @@ import pytest
from datetime import datetime, timezone
from sqlmodel import Session, select
from ell.stores.sql import SQLStore, SerializedLMP
from sqlalchemy import Engine, create_engine
from sqlalchemy import Engine, create_engine, func
from ell.types.lmp import LMPType
from ell.types.lmp import utc_now
@@ -74,7 +74,6 @@ def test_write_lmp(sql_store: SQLStore):
# Test that writing the same LMP again doesn't create a duplicate
sql_store.write_lmp(SerializedLMP(lmp_id=lmp_id, name=name, source=source, dependencies=dependencies, lmp_type=LMPType.LM, api_params=api_params, version_number=version_number, initial_global_vars=global_vars, initial_free_vars=free_vars, commit_message=commit_message, created_at=created_at), uses)
with Session(sql_store.engine) as session:
count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count()
count = session.exec(select(func.count()).where(SerializedLMP.lmp_id == lmp_id)).one()
assert count == 1