mirror of
https://github.com/MadcowD/ell.git
synced 2024-09-22 16:14:36 +03:00
ensure tz on deserialize timestamp
this is needed to ensure we get a utc datetime when reading from sqlite or engines that don't support storing timestamps with a timezone
This commit is contained in:
@@ -8,7 +8,9 @@ from ell.util.dict_sync_meta import DictSyncMeta
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List, Optional
|
||||
from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float
|
||||
from sqlmodel import Field, SQLModel, Relationship, JSON, Column
|
||||
from sqlalchemy import TIMESTAMP, func
|
||||
import sqlalchemy.types as types
|
||||
|
||||
_lstr_generic = Union[lstr, str]
|
||||
|
||||
@@ -42,6 +44,10 @@ ChatLMP = Callable[[Chat, T], Chat]
|
||||
LMP = Union[OneTurn, MultiTurnLMP, ChatLMP]
|
||||
InvocableLM = Callable[..., _lstr_generic]
|
||||
|
||||
from datetime import timezone
|
||||
from sqlmodel import Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""
|
||||
@@ -62,6 +68,16 @@ class SerializedLMPUses(SQLModel, table=True):
|
||||
lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP
|
||||
|
||||
|
||||
class UTCTimestamp(types.TypeDecorator[datetime]):
|
||||
impl = types.TIMESTAMP
|
||||
def process_result_value(self, value: datetime, dialect:Any):
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
|
||||
def UTCTimestampField(index:bool=False, **kwargs:Any):
|
||||
return Field(
|
||||
sa_column= Column(UTCTimestamp(timezone=True),index=index, **kwargs))
|
||||
|
||||
|
||||
|
||||
class SerializedLMP(SQLModel, table=True):
|
||||
"""
|
||||
@@ -73,7 +89,12 @@ class SerializedLMP(SQLModel, table=True):
|
||||
name: str = Field(index=True) # Name of the LMP
|
||||
source: str # Source code or reference for the LMP
|
||||
dependencies: str # List of dependencies for the LMP, stored as a string
|
||||
created_at: datetime = Field(default_factory=utc_now, index=True) # Timestamp of when the LMP was created
|
||||
# Timestamp of when the LMP was created
|
||||
created_at: datetime = UTCTimestampField(
|
||||
index=True,
|
||||
default=func.now(),
|
||||
nullable=False
|
||||
)
|
||||
is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP
|
||||
lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP
|
||||
|
||||
@@ -139,8 +160,8 @@ class Invocation(SQLModel, table=True):
|
||||
completion_tokens: Optional[int] = Field(default=None)
|
||||
state_cache_key: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
created_at: datetime = Field(default_factory=utc_now) # Timestamp of when the invocation was created
|
||||
# Timestamp of when the invocation was created
|
||||
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
|
||||
invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation
|
||||
|
||||
# Relationships
|
||||
|
||||
88
tests/test_sql_store.py
Normal file
88
tests/test_sql_store.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlmodel import Session, select
|
||||
from ell.stores.sql import SQLStore, SerializedLMP
|
||||
from sqlalchemy import Engine, create_engine
|
||||
|
||||
from ell.types import utc_now
|
||||
|
||||
@pytest.fixture
|
||||
def in_memory_db():
|
||||
return create_engine("sqlite:///:memory:")
|
||||
|
||||
@pytest.fixture
|
||||
def sql_store(in_memory_db: Engine) -> SQLStore:
|
||||
store = SQLStore("sqlite:///:memory:")
|
||||
store.engine = in_memory_db
|
||||
SerializedLMP.metadata.create_all(in_memory_db)
|
||||
return store
|
||||
|
||||
def test_write_lmp(sql_store: SQLStore):
|
||||
# Arrange
|
||||
lmp_id = "test_lmp_1"
|
||||
name = "Test LMP"
|
||||
source = "def test_function(): pass"
|
||||
dependencies = str(["dep1", "dep2"])
|
||||
is_lmp = True
|
||||
lm_kwargs = '{"param1": "value1"}'
|
||||
version_number = 1
|
||||
uses = {"used_lmp_1": {}, "used_lmp_2": {}}
|
||||
global_vars = {"global_var1": "value1"}
|
||||
free_vars = {"free_var1": "value2"}
|
||||
commit_message = "Initial commit"
|
||||
created_at = utc_now()
|
||||
assert created_at.tzinfo is not None
|
||||
|
||||
# Act
|
||||
sql_store.write_lmp(
|
||||
lmp_id=lmp_id,
|
||||
name=name,
|
||||
source=source,
|
||||
dependencies=dependencies,
|
||||
is_lmp=is_lmp,
|
||||
lm_kwargs=lm_kwargs,
|
||||
version_number=version_number,
|
||||
uses=uses,
|
||||
global_vars=global_vars,
|
||||
free_vars=free_vars,
|
||||
commit_message=commit_message,
|
||||
created_at=created_at
|
||||
)
|
||||
|
||||
# Assert
|
||||
with Session(sql_store.engine) as session:
|
||||
result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()
|
||||
|
||||
assert result is not None
|
||||
assert result.lmp_id == lmp_id
|
||||
assert result.name == name
|
||||
assert result.source == source
|
||||
assert result.dependencies == str(dependencies)
|
||||
assert result.is_lm == is_lmp
|
||||
assert result.lm_kwargs == lm_kwargs
|
||||
assert result.version_number == version_number
|
||||
assert result.initial_global_vars == global_vars
|
||||
assert result.initial_free_vars == free_vars
|
||||
assert result.commit_message == commit_message
|
||||
# we want to assert created_at has timezone information
|
||||
assert result.created_at.tzinfo is not None
|
||||
|
||||
# Test that writing the same LMP again doesn't create a duplicate
|
||||
sql_store.write_lmp(
|
||||
lmp_id=lmp_id,
|
||||
name=name,
|
||||
source=source,
|
||||
dependencies=dependencies,
|
||||
is_lmp=is_lmp,
|
||||
lm_kwargs=lm_kwargs,
|
||||
version_number=version_number,
|
||||
uses=uses,
|
||||
global_vars=global_vars,
|
||||
free_vars=free_vars,
|
||||
commit_message=commit_message,
|
||||
created_at=created_at
|
||||
)
|
||||
|
||||
with Session(sql_store.engine) as session:
|
||||
count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count()
|
||||
assert count == 1
|
||||
Reference in New Issue
Block a user