ensure tz on deserialize timestamp

this is needed to ensure we get a utc datetime when reading from sqlite or engines that don't support storing timestamps with a timezone
This commit is contained in:
Alex Dixon
2024-08-05 22:07:19 -07:00
parent f11149bda5
commit 4f65459c8b
2 changed files with 113 additions and 4 deletions

88
tests/test_sql_store.py Normal file
View File

@@ -0,0 +1,88 @@
import pytest
from datetime import datetime, timezone
from sqlmodel import Session, select
from ell.stores.sql import SQLStore, SerializedLMP
from sqlalchemy import Engine, create_engine
from ell.types import utc_now
@pytest.fixture
def in_memory_db():
return create_engine("sqlite:///:memory:")
@pytest.fixture
def sql_store(in_memory_db: Engine) -> SQLStore:
store = SQLStore("sqlite:///:memory:")
store.engine = in_memory_db
SerializedLMP.metadata.create_all(in_memory_db)
return store
def test_write_lmp(sql_store: SQLStore):
# Arrange
lmp_id = "test_lmp_1"
name = "Test LMP"
source = "def test_function(): pass"
dependencies = str(["dep1", "dep2"])
is_lmp = True
lm_kwargs = '{"param1": "value1"}'
version_number = 1
uses = {"used_lmp_1": {}, "used_lmp_2": {}}
global_vars = {"global_var1": "value1"}
free_vars = {"free_var1": "value2"}
commit_message = "Initial commit"
created_at = utc_now()
assert created_at.tzinfo is not None
# Act
sql_store.write_lmp(
lmp_id=lmp_id,
name=name,
source=source,
dependencies=dependencies,
is_lmp=is_lmp,
lm_kwargs=lm_kwargs,
version_number=version_number,
uses=uses,
global_vars=global_vars,
free_vars=free_vars,
commit_message=commit_message,
created_at=created_at
)
# Assert
with Session(sql_store.engine) as session:
result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()
assert result is not None
assert result.lmp_id == lmp_id
assert result.name == name
assert result.source == source
assert result.dependencies == str(dependencies)
assert result.is_lm == is_lmp
assert result.lm_kwargs == lm_kwargs
assert result.version_number == version_number
assert result.initial_global_vars == global_vars
assert result.initial_free_vars == free_vars
assert result.commit_message == commit_message
# we want to assert created_at has timezone information
assert result.created_at.tzinfo is not None
# Test that writing the same LMP again doesn't create a duplicate
sql_store.write_lmp(
lmp_id=lmp_id,
name=name,
source=source,
dependencies=dependencies,
is_lmp=is_lmp,
lm_kwargs=lm_kwargs,
version_number=version_number,
uses=uses,
global_vars=global_vars,
free_vars=free_vars,
commit_message=commit_message,
created_at=created_at
)
with Session(sql_store.engine) as session:
count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count()
assert count == 1