mirror of
https://github.com/MadcowD/ell.git
synced 2024-09-22 16:14:36 +03:00
ensure tz on deserialize timestamp
this is needed to ensure we get a utc datetime when reading from sqlite or engines that don't support storing timestamps with a timezone
This commit is contained in:
88
tests/test_sql_store.py
Normal file
88
tests/test_sql_store.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlmodel import Session, select
|
||||
from ell.stores.sql import SQLStore, SerializedLMP
|
||||
from sqlalchemy import Engine, create_engine
|
||||
|
||||
from ell.types import utc_now
|
||||
|
||||
@pytest.fixture
|
||||
def in_memory_db():
|
||||
return create_engine("sqlite:///:memory:")
|
||||
|
||||
@pytest.fixture
|
||||
def sql_store(in_memory_db: Engine) -> SQLStore:
|
||||
store = SQLStore("sqlite:///:memory:")
|
||||
store.engine = in_memory_db
|
||||
SerializedLMP.metadata.create_all(in_memory_db)
|
||||
return store
|
||||
|
||||
def test_write_lmp(sql_store: SQLStore):
|
||||
# Arrange
|
||||
lmp_id = "test_lmp_1"
|
||||
name = "Test LMP"
|
||||
source = "def test_function(): pass"
|
||||
dependencies = str(["dep1", "dep2"])
|
||||
is_lmp = True
|
||||
lm_kwargs = '{"param1": "value1"}'
|
||||
version_number = 1
|
||||
uses = {"used_lmp_1": {}, "used_lmp_2": {}}
|
||||
global_vars = {"global_var1": "value1"}
|
||||
free_vars = {"free_var1": "value2"}
|
||||
commit_message = "Initial commit"
|
||||
created_at = utc_now()
|
||||
assert created_at.tzinfo is not None
|
||||
|
||||
# Act
|
||||
sql_store.write_lmp(
|
||||
lmp_id=lmp_id,
|
||||
name=name,
|
||||
source=source,
|
||||
dependencies=dependencies,
|
||||
is_lmp=is_lmp,
|
||||
lm_kwargs=lm_kwargs,
|
||||
version_number=version_number,
|
||||
uses=uses,
|
||||
global_vars=global_vars,
|
||||
free_vars=free_vars,
|
||||
commit_message=commit_message,
|
||||
created_at=created_at
|
||||
)
|
||||
|
||||
# Assert
|
||||
with Session(sql_store.engine) as session:
|
||||
result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()
|
||||
|
||||
assert result is not None
|
||||
assert result.lmp_id == lmp_id
|
||||
assert result.name == name
|
||||
assert result.source == source
|
||||
assert result.dependencies == str(dependencies)
|
||||
assert result.is_lm == is_lmp
|
||||
assert result.lm_kwargs == lm_kwargs
|
||||
assert result.version_number == version_number
|
||||
assert result.initial_global_vars == global_vars
|
||||
assert result.initial_free_vars == free_vars
|
||||
assert result.commit_message == commit_message
|
||||
# we want to assert created_at has timezone information
|
||||
assert result.created_at.tzinfo is not None
|
||||
|
||||
# Test that writing the same LMP again doesn't create a duplicate
|
||||
sql_store.write_lmp(
|
||||
lmp_id=lmp_id,
|
||||
name=name,
|
||||
source=source,
|
||||
dependencies=dependencies,
|
||||
is_lmp=is_lmp,
|
||||
lm_kwargs=lm_kwargs,
|
||||
version_number=version_number,
|
||||
uses=uses,
|
||||
global_vars=global_vars,
|
||||
free_vars=free_vars,
|
||||
commit_message=commit_message,
|
||||
created_at=created_at
|
||||
)
|
||||
|
||||
with Session(sql_store.engine) as session:
|
||||
count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count()
|
||||
assert count == 1
|
||||
Reference in New Issue
Block a user