beginning of lstr pydantization

This commit is contained in:
William Guss
2024-08-22 17:17:46 -07:00
parent 4b2465d70c
commit feb98386b1
13 changed files with 162 additions and 385 deletions

View File

@@ -2,7 +2,7 @@ import dataclasses
import random
from typing import Callable, List, Tuple
import ell
from ell.lstr import lstr
from ell._lstr import _lstr
ell.config.verbose = True
@@ -10,7 +10,7 @@ ell.config.verbose = True
# Option 0 (pythonic.)
# This is going to be NON native to ell
def parse_outputs(result : lstr) -> str:
def parse_outputs(result : _lstr) -> str:
name = result.split(":")[0]
backstory = result.split(":")[1]
return Personality(name, backstory)
@@ -31,7 +31,7 @@ def get_personality():
# Option 1.
def parse_outputs(result : lstr) -> str:
def parse_outputs(result : _lstr) -> str:
name = result.split(":")[0]
backstory = result.split(":")[1]
return Personality(name, backstory)
@@ -56,7 +56,7 @@ class Personality:
backstory : str
@staticmethod
def parse_outputs(result : lstr) -> str:
def parse_outputs(result : _lstr) -> str:
name = result.split(":")[0]
backstory = result.split(":")[1]
return Personality(name, backstory)
@@ -73,7 +73,7 @@ Backstory: <3 sentence backstory>'""" # System prompt
# Option 3. Another decorator
def parse_outputs(result : lstr) -> str:
def parse_outputs(result : _lstr) -> str:
name = result.split(":")[0]
backstory = result.split(":")[1]
return Personality(name, backstory)
@@ -103,7 +103,7 @@ class PersonalitySchema(ell.Schema):
name : str
backstory : str
def parse_outputs(result : lstr) -> str:
def parse_outputs(result : _lstr) -> str:
name = result.split(":")[0]
backstory = result.split(":")[1]
return Personality(name, backstory)
@@ -126,7 +126,7 @@ def create_random_personality():
# or
def parser(result : lstr) -> OutputFormat:
def parser(result : _lstr) -> OutputFormat:
name = result.split(":")[0]
backstory = result.split(":")[1]
return OutputFormat(name, backstory)
@@ -323,7 +323,7 @@ def internal_make_a_rpg_character(name : str):
@ell.track
@retry(tries=3)
def parse(result : lstr):
def parse(result : _lstr):
return json.parse(result)

View File

@@ -4,7 +4,7 @@ import ell
import requests
import ell.lmp.tool
from ell.lstr import lstr
from ell._lstr import _lstr
from ell.stores.sql import SQLiteStore
@@ -13,7 +13,7 @@ def get_html_content(
url : str = Field(description="The URL to get the HTML content of. Never include the protocol (like http:// or https://)"),
):
"""Get the HTML content of a URL."""
return lstr(requests.get(url))
return _lstr(requests.get(url))
@ell.text(model="gpt-4o", tools=[get_html_content], eager=True)

View File

@@ -7,7 +7,7 @@ from ell.stores.sql import SQLiteStore
def get_random_length():
return int(np.random.beta(2, 6) * 1500)
@ell.text(model="gpt-4o-mini")
@ell.multimodal(model="gpt-4o-mini")
def hello(world : str):
"""Your goal is to be really meant to the other guy whiel say hello"""
name = world.capitalize()
@@ -21,7 +21,7 @@ if __name__ == "__main__":
ell.set_store(SQLiteStore('sqlite_example'), autocommit=True)
greeting = hello("sam altman") # > "hello sama! ... "
print(greeting[0:5].__origin_trace__)
# print(greeting[0:5].__origin_trace__)
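For reference, the origin-trace propagation that the commented-out line above exercised still holds on the renamed class; a minimal sketch (the origin string below is made up for illustration):
from ell._lstr import _lstr

s = _lstr("hello sam altman", _origin_trace="invocation-xyz")
print(s[0:5])                   # "hello"
print(s[0:5].__origin_trace__)  # frozenset({'invocation-xyz'}) -- the trace survives slicing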

View File

@@ -1,7 +1,6 @@
"""
LM string that supports logits and keeps track of its _origin_trace even after mutation.
"""
import numpy as np
from typing import (
Optional,
@@ -16,12 +15,12 @@ from typing import (
Callable,
)
from typing_extensions import override
from pydantic import GetCoreSchemaHandler
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema
from pydantic_core import CoreSchema, core_schema
class lstr(str):
class _lstr(str):
"""
A string class that supports logits and keeps track of its _origin_trace even after mutation.
This class is designed to be used in prompt engineering libraries where it is essential to associate
@@ -97,7 +96,7 @@ class lstr(str):
logits (np.ndarray, optional): The logits associated with this string. Defaults to None.
_origin_trace (Union[str, FrozenSet[str]], optional): The _origin_trace(s) of this string. Defaults to None.
"""
instance = super(lstr, cls).__new__(cls, content)
instance = super(_lstr, cls).__new__(cls, content)
# instance._logits = logits
if isinstance(_origin_trace, str):
instance.__origin_trace__ = frozenset({_origin_trace})
@@ -110,23 +109,48 @@ class lstr(str):
# _logits: Optional[np.ndarray]
__origin_trace__: FrozenSet[str]
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
# todo: make this work with originators & all
return core_schema.no_info_after_validator_function(cls, handler(str))
# @property
# def logits(self) -> Optional[np.ndarray]:
# """
# Get the logits associated with this lstr instance.
def validate_lstr(value):
print("valiudating", value)
import traceback
print("Call stack:")
for line in traceback.format_stack():
print(line.strip())
if isinstance(value, dict) and value.get('__lstr', False):
content = value['content']
_origin_trace = value['_originator'].split(',')
return cls(content, _origin_trace=_origin_trace)
elif isinstance(value, str):
print("returning lstr")
return cls(value)
elif isinstance(value, cls):
return value
else:
raise ValueError(f"Invalid value for lstr: {value}")
# Returns:
# Optional[np.ndarray]: The logits associated with this lstr instance, or None if no logits are available.
# """
# return self._logits
return core_schema.json_or_python_schema(
json_schema=core_schema.typed_dict_schema({
'content': core_schema.typed_dict_field(core_schema.str_schema()),
'_originator': core_schema.typed_dict_field(core_schema.str_schema()),
'__lstr': core_schema.typed_dict_field(core_schema.bool_schema()),
}),
python_schema=core_schema.union_schema([
core_schema.is_instance_schema(cls),
core_schema.no_info_plain_validator_function(validate_lstr),
]),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: {
"content": str(instance),
"_originator": ",".join(instance._origin_trace),
"__lstr": True
}
)
)
@property
def _origin_trace(self) -> FrozenSet[str]:
"""
@@ -149,7 +173,7 @@ class lstr(str):
"""
return super().__repr__()
def __add__(self, other: Union[str, "lstr"]) -> "lstr":
def __add__(self, other: Union[str, "_lstr"]) -> "_lstr":
"""
Concatenate this lstr instance with another string or lstr instance.
@@ -159,21 +183,21 @@ class lstr(str):
Returns:
lstr: A new lstr instance containing the concatenated content, with the _origin_trace(s) updated accordingly.
"""
new_content = super(lstr, self).__add__(other)
new_content = super(_lstr, self).__add__(other)
self_origin = self.__origin_trace__
if isinstance(other, lstr):
if isinstance(other, _lstr):
new_origin = set(self_origin)
new_origin.update(other.__origin_trace__)
new_origin = frozenset(new_origin)
else:
new_origin = self_origin
return lstr(new_content, None, new_origin)
return _lstr(new_content, None, new_origin)
def __mod__(
self, other: Union[str, "lstr", Tuple[Union[str, "lstr"], ...]]
) -> "lstr":
self, other: Union[str, "_lstr", Tuple[Union[str, "_lstr"], ...]]
) -> "_lstr":
"""
Perform a modulo operation between this lstr instance and another string, lstr, or a tuple of strings and lstrs,
tracing the operation by logging the operands and the result.
@@ -186,22 +210,22 @@ class lstr(str):
"""
# If 'other' is a tuple, we need to handle each element
if isinstance(other, tuple):
result_content = super(lstr, self).__mod__(tuple(str(o) for o in other))
result_content = super(_lstr, self).__mod__(tuple(str(o) for o in other))
new__origin_trace__s = set(self.__origin_trace__)
for item in other:
if isinstance(item, lstr):
if isinstance(item, _lstr):
new__origin_trace__s.update(item.__origin_trace__)
new__origin_trace__ = frozenset(new__origin_trace__s)
else:
result_content = super(lstr, self).__mod__(other)
if isinstance(other, lstr):
result_content = super(_lstr, self).__mod__(other)
if isinstance(other, _lstr):
new__origin_trace__ = self.__origin_trace__.union(other.__origin_trace__)
else:
new__origin_trace__ = self.__origin_trace__
return lstr(result_content, None, new__origin_trace__)
return _lstr(result_content, None, new__origin_trace__)
def __mul__(self, other: SupportsIndex) -> "lstr":
def __mul__(self, other: SupportsIndex) -> "_lstr":
"""
Perform a multiplication operation between this lstr instance and an integer or another lstr,
tracing the operation by logging the operands and the result.
@@ -213,14 +237,14 @@ class lstr(str):
lstr: A new lstr instance containing the result of the multiplication operation, with the _origin_trace(s) updated accordingly.
"""
if isinstance(other, SupportsIndex):
result_content = super(lstr, self).__mul__(other)
result_content = super(_lstr, self).__mul__(other)
new__origin_trace__ = self.__origin_trace__
else:
return NotImplemented
return lstr(result_content, None, new__origin_trace__)
return _lstr(result_content, None, new__origin_trace__)
def __rmul__(self, other: SupportsIndex) -> "lstr":
def __rmul__(self, other: SupportsIndex) -> "_lstr":
"""
Perform a right multiplication operation between an integer or another lstr and this lstr instance,
tracing the operation by logging the operands and the result.
@@ -233,7 +257,7 @@ class lstr(str):
"""
return self.__mul__(other) # Multiplication is commutative in this context
def __getitem__(self, key: Union[SupportsIndex, slice]) -> "lstr":
def __getitem__(self, key: Union[SupportsIndex, slice]) -> "_lstr":
"""
Get a slice or index of this lstr instance.
@@ -243,14 +267,14 @@ class lstr(str):
Returns:
lstr: A new lstr instance containing the sliced or indexed content, with the _origin_trace(s) preserved.
"""
result = super(lstr, self).__getitem__(key)
result = super(_lstr, self).__getitem__(key)
# This is a matter of opinion. I believe that when you index into a language model output, you are divorcing the logits of the indexed result from the context which produced them. It is only reasonable to index directly into the logits when the original context is unchanged, and so any mutation of the string should invalidate the logits.
# try:
# logit_subset = self._logits[key] if self._logits else None
# except:
# logit_subset = None
logit_subset = None
return lstr(result, logit_subset, self.__origin_trace__)
return _lstr(result, logit_subset, self.__origin_trace__)
def __getattribute__(self, name: str) -> Union[Callable, Any]:
"""
@@ -273,7 +297,7 @@ class lstr(str):
if name == "__class__":
return type(self)
if callable(attr) and name not in lstr.__dict__:
if callable(attr) and name not in _lstr.__dict__:
def wrapped(*args: Any, **kwargs: Any) -> Any:
result = attr(*args, **kwargs)
@@ -281,12 +305,12 @@ class lstr(str):
if isinstance(result, str):
_origin_traces = self.__origin_trace__
for arg in args:
if isinstance(arg, lstr):
if isinstance(arg, _lstr):
_origin_traces = _origin_traces.union(arg.__origin_trace__)
for key, value in kwargs.items():
if isinstance(value, lstr):
if isinstance(value, _lstr):
_origin_traces = _origin_traces.union(value.__origin_trace__)
return lstr(result, None, _origin_traces)
return _lstr(result, None, _origin_traces)
return result
@@ -295,7 +319,7 @@ class lstr(str):
return attr
@override
def join(self, iterable: Iterable[Union[str, "lstr"]]) -> "lstr":
def join(self, iterable: Iterable[Union[str, "_lstr"]]) -> "_lstr":
"""
Join a sequence of strings or lstr instances into a single lstr instance.
@@ -306,17 +330,17 @@ class lstr(str):
lstr: A new lstr instance containing the joined content, with the _origin_trace(s) updated accordingly.
"""
parts = [str(item) for item in iterable]
new_content = super(lstr, self).join(parts)
new_content = super(_lstr, self).join(parts)
new__origin_trace__ = self.__origin_trace__
for item in iterable:
if isinstance(item, lstr):
if isinstance(item, _lstr):
new__origin_trace__ = new__origin_trace__.union(item.__origin_trace__)
return lstr(new_content, None, new__origin_trace__)
return _lstr(new_content, None, new__origin_trace__)
@override
def split(
self, sep: Optional[Union[str, "lstr"]] = None, maxsplit: SupportsIndex = -1
) -> List["lstr"]:
self, sep: Optional[Union[str, "_lstr"]] = None, maxsplit: SupportsIndex = -1
) -> List["_lstr"]:
"""
Split this lstr instance into a list of lstr instances based on a separator.
@@ -327,12 +351,12 @@ class lstr(str):
Returns:
List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved.
"""
return self._split_helper(super(lstr, self).split, sep, maxsplit)
return self._split_helper(super(_lstr, self).split, sep, maxsplit)
@override
def rsplit(
self, sep: Optional[Union[str, "lstr"]] = None, maxsplit: SupportsIndex = -1
) -> List["lstr"]:
self, sep: Optional[Union[str, "_lstr"]] = None, maxsplit: SupportsIndex = -1
) -> List["_lstr"]:
"""
Split this lstr instance into a list of lstr instances based on a separator, starting from the right.
@@ -343,10 +367,10 @@ class lstr(str):
Returns:
List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved.
"""
return self._split_helper(super(lstr, self).rsplit, sep, maxsplit)
return self._split_helper(super(_lstr, self).rsplit, sep, maxsplit)
@override
def splitlines(self, keepends: bool = False) -> List["lstr"]:
def splitlines(self, keepends: bool = False) -> List["_lstr"]:
"""
Split this lstr instance into a list of lstr instances based on line breaks.
@@ -357,12 +381,12 @@ class lstr(str):
List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved.
"""
return [
lstr(p, None, self.__origin_trace__)
for p in super(lstr, self).splitlines(keepends=keepends)
_lstr(p, None, self.__origin_trace__)
for p in super(_lstr, self).splitlines(keepends=keepends)
]
@override
def partition(self, sep: Union[str, "lstr"]) -> Tuple["lstr", "lstr", "lstr"]:
def partition(self, sep: Union[str, "_lstr"]) -> Tuple["_lstr", "_lstr", "_lstr"]:
"""
Partition this lstr instance into three lstr instances based on a separator.
@@ -372,10 +396,10 @@ class lstr(str):
Returns:
Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly.
"""
return self._partition_helper(super(lstr, self).partition, sep)
return self._partition_helper(super(_lstr, self).partition, sep)
@override
def rpartition(self, sep: Union[str, "lstr"]) -> Tuple["lstr", "lstr", "lstr"]:
def rpartition(self, sep: Union[str, "_lstr"]) -> Tuple["_lstr", "_lstr", "_lstr"]:
"""
Partition this lstr instance into three lstr instances based on a separator, starting from the right.
@@ -385,11 +409,11 @@ class lstr(str):
Returns:
Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly.
"""
return self._partition_helper(super(lstr, self).rpartition, sep)
return self._partition_helper(super(_lstr, self).rpartition, sep)
def _partition_helper(
self, method, sep: Union[str, "lstr"]
) -> Tuple["lstr", "lstr", "lstr"]:
self, method, sep: Union[str, "_lstr"]
) -> Tuple["_lstr", "_lstr", "_lstr"]:
"""
Helper method for partitioning this lstr instance based on a separator.
@@ -403,21 +427,21 @@ class lstr(str):
part1, part2, part3 = method(sep)
new__origin_trace__ = (
self.__origin_trace__ | sep.__origin_trace__
if isinstance(sep, lstr)
if isinstance(sep, _lstr)
else self.__origin_trace__
)
return (
lstr(part1, None, new__origin_trace__),
lstr(part2, None, new__origin_trace__),
lstr(part3, None, new__origin_trace__),
_lstr(part1, None, new__origin_trace__),
_lstr(part2, None, new__origin_trace__),
_lstr(part3, None, new__origin_trace__),
)
def _split_helper(
self,
method,
sep: Optional[Union[str, "lstr"]] = None,
sep: Optional[Union[str, "_lstr"]] = None,
maxsplit: SupportsIndex = -1,
) -> List["lstr"]:
) -> List["_lstr"]:
"""
Helper method for splitting this lstr instance based on a separator.
@@ -431,11 +455,11 @@ class lstr(str):
"""
_origin_traces = (
self.__origin_trace__ | sep.__origin_trace__
if isinstance(sep, lstr)
if isinstance(sep, _lstr)
else self.__origin_trace__
)
parts = method(sep, maxsplit)
return [lstr(part, None, _origin_traces) for part in parts]
return [_lstr(part, None, _origin_traces) for part in parts]
if __name__ == "__main__":
@@ -450,14 +474,14 @@ if __name__ == "__main__":
s1 = generate_random_string(1000)
s2 = generate_random_string(1000)
lstr_time = timeit.timeit(lambda: lstr(s1) + lstr(s2), number=10000)
lstr_time = timeit.timeit(lambda: _lstr(s1) + _lstr(s2), number=10000)
str_time = timeit.timeit(lambda: s1 + s2, number=10000)
print(f"Concatenation: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s")
def test_slicing():
s = generate_random_string(10000)
ls = lstr(s)
ls = _lstr(s)
lstr_time = timeit.timeit(lambda: ls[1000:2000], number=10000)
str_time = timeit.timeit(lambda: s[1000:2000], number=10000)
@@ -466,7 +490,7 @@ if __name__ == "__main__":
def test_splitting():
s = generate_random_string(10000)
ls = lstr(s)
ls = _lstr(s)
lstr_time = timeit.timeit(lambda: ls.split(), number=1000)
str_time = timeit.timeit(lambda: s.split(), number=1000)
@@ -475,9 +499,9 @@ if __name__ == "__main__":
def test_joining():
words = [generate_random_string(10) for _ in range(1000)]
lwords = [lstr(word) for word in words]
lwords = [_lstr(word) for word in words]
lstr_time = timeit.timeit(lambda: lstr(' ').join(lwords), number=1000)
lstr_time = timeit.timeit(lambda: _lstr(' ').join(lwords), number=1000)
str_time = timeit.timeit(lambda: ' '.join(words), number=1000)
print(f"Joining: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s")
@@ -495,8 +519,8 @@ if __name__ == "__main__":
def test_add():
s1 = generate_random_string(1000)
s2 = generate_random_string(1000)
ls1 = lstr(s1, None, "origin1")
ls2 = lstr(s2, None, "origin2")
ls1 = _lstr(s1, None, "origin1")
ls2 = _lstr(s2, None, "origin2")
for _ in range(100000):
result = ls1 + ls2
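For context on the schema added above: a minimal sketch of what _lstr's new __get_pydantic_core_schema__ should enable, assuming it lands as written. The Tracked model is hypothetical and exists only for illustration.
from pydantic import BaseModel
from ell._lstr import _lstr

class Tracked(BaseModel):  # hypothetical model, not part of this commit
    text: _lstr

# Plain strings are coerced through validate_lstr into an _lstr instance.
t = Tracked(text="hello")
assert isinstance(t.text, _lstr)

# The dict form emitted by the serializer round-trips and restores the origin trace.
t2 = Tracked(text={"content": "hi", "_originator": "lmp-123", "__lstr": True})
assert t2.text.__origin_trace__ == frozenset({"lmp-123"})

# model_dump() goes through the plain serializer, so the expected shape is roughly:
# {'text': {'content': 'hello', '_originator': '', '__lstr': True}}
print(t.model_dump())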

View File

@@ -1,6 +1,6 @@
from ell.configurator import config
from ell.lmp.track import track
from ell.lstr import lstr
from ell._lstr import _lstr
from ell.types import LMP, InvocableLM, LMPParams, LMPType, Message, MessageContentBlock, MessageOrDict, _lstr_generic
from ell.util._warnings import _warnings
from ell.util.api import call
@@ -70,7 +70,7 @@ def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> l
"""
if isinstance(prompt_ret, str):
return [
Message(role="system", content=[MessageContentBlock(text=lstr(prompt.__doc__) or config.default_system_prompt)]),
Message(role="system", content=[MessageContentBlock(text=_lstr(prompt.__doc__) or config.default_system_prompt)]),
Message(role="user", content=[MessageContentBlock(text=prompt_ret)]),
]
else:

View File

@@ -1,9 +1,9 @@
import logging
import threading
from ell.types import LMPType, SerializedLStr, utc_now, SerializedLMP, Invocation, InvocationTrace
from ell.types import LMPType, utc_now, SerializedLMP, Invocation, InvocationTrace
import ell.util.closure
from ell.configurator import config
from ell.lstr import lstr
from ell._lstr import _lstr
import inspect
@@ -202,25 +202,12 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion
invocation_kwargs=invocation_kwargs,
args=cleaned_invocation_params.get('args', []),
kwargs=cleaned_invocation_params.get('kwargs', {}),
used_by_id=parent_invocation_id
used_by_id=parent_invocation_id,
results=result
)
results = []
if isinstance(result, lstr):
results = [result]
elif isinstance(result, list):
results = result
else:
raise TypeError("Result must be either lstr or List[lstr]")
serialized_results = [
SerializedLStr(
content=str(res),
# logits=res.logits
) for res in results
]
config._store.write_invocation(invocation, serialized_results, consumes)
config._store.write_invocation(invocation, consumes)
def compute_state_cache_key(ipstr, fn_closure):
_global_free_vars_str = f"{json.dumps(get_immutable_vars(fn_closure[2]), sort_keys=True, default=repr)}"
@@ -269,7 +256,7 @@ def prepare_invocation_params(fn_args, fn_kwargs):
lambda arr: arr.tolist()
)
invocation_converter.register_unstructure_hook(
lstr,
_lstr,
process_lstr
)
invocation_converter.register_unstructure_hook(

View File

@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Optional, Dict, List, Set, Union
from ell.lstr import lstr
from ell.types import InvocableLM, SerializedLMP, Invocation, SerializedLStr
from ell._lstr import _lstr
from ell.types import InvocableLM, SerializedLMP, Invocation
class Store(ABC):
@@ -23,7 +23,7 @@ class Store(ABC):
pass
@abstractmethod
def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
"""
Write an invocation of an LMP to the storage.

View File

@@ -1,219 +0,0 @@
"""
This is an old stub from a previous jsonl store implementation that was
used for testing. Feel free to PR a full implementation if you'd like; it just needs an index.
"""
# import os
# import json
# from typing import Any, Optional, Dict, List
# from ell.lstr import lstr
# import ell.store
# import numpy as np
# import glob
# from operator import itemgetter
# import warnings
# import cattrs
# class JsonlStore(ell.store.Store):
# def __init__(self, storage_dir: str, max_file_size: int = 1024 * 1024, check_empty: bool = False): # 1MB default
# self.storage_dir = storage_dir
# self.max_file_size = max_file_size
# os.makedirs(storage_dir, exist_ok=True)
# self.open_files = {}
# if check_empty and not os.path.exists(os.path.join(storage_dir, 'invocations')) and \
# not os.path.exists(os.path.join(storage_dir, 'programs')):
# warnings.warn(f"The ELL storage directory '{storage_dir}' is empty. No invocations or programs found.")
# self.converter = cattrs.Converter()
# self._setup_cattrs()
# def lst_converter(self, obj: Any) -> Any:
# # print(obj)
# # return obj
# return self.converter.unstructure(dict(content=str(obj), **obj.__dict__, __is_lstr=True))
# return obj
# def _setup_cattrs(self):
# self.converter.register_unstructure_hook(
# np.ndarray,
# lambda arr: arr.tolist()
# )
# self.converter.register_unstructure_hook(
# lstr,
# self.lst_converter
# )
# self.converter.register_unstructure_hook(
# set,
# lambda s: list(s)
# )
# self.converter.register_unstructure_hook(
# frozenset,
# lambda s: list(s)
# )
# def _serialize(self, obj: Any) -> Any:
# return self.converter.unstructure(obj)
# def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str],
# created_at: float, is_lmp: bool, lm_kwargs: Optional[str],
# uses: Dict[str, Any]) -> Optional[Any]:
# """
# Write the LMP (Language Model Program) to a JSON file.
# """
# file_path = os.path.join(self.storage_dir, 'programs', f"{name}_{lmp_id}.json")
# os.makedirs(os.path.dirname(file_path), exist_ok=True)
# lmp_data = {
# 'lmp_id': lmp_id,
# 'name': name,
# 'source': source,
# 'dependencies': dependencies,
# 'created_at': created_at,
# 'is_lmp': is_lmp,
# 'lm_kwargs': lm_kwargs,
# 'uses': uses
# }
# with open(file_path, 'w') as f:
# json.dump(self._serialize(lmp_data), f)
# return None
# def write_invocation(self, lmp_id: str, args: str, kwargs: str, result: str,
# created_at: float, invocation_kwargs: Dict[str, Any]) -> Optional[Any]:
# """
# Write an LMP invocation to a JSONL file in a nested folder structure.
# """
# dir_path = os.path.join(self.storage_dir, 'invocations', lmp_id[:2], lmp_id[2:4], lmp_id[4:])
# os.makedirs(dir_path, exist_ok=True)
# if lmp_id not in self.open_files:
# index = 0
# while True:
# file_path = os.path.join(dir_path, f"invocations.{index}.jsonl")
# if not os.path.exists(file_path) or os.path.getsize(file_path) < self.max_file_size:
# self.open_files[lmp_id] = {'file': open(file_path, 'a'), 'path': file_path}
# break
# index += 1
# invocation_data = {
# 'lmp_id': lmp_id,
# 'args': args,
# 'kwargs': kwargs,
# 'result': result,
# 'invocation_kwargs': invocation_kwargs,
# 'created_at': created_at
# }
# file_info = self.open_files[lmp_id]
# file_info['file'].write(json.dumps(self._serialize(invocation_data)) + '\n')
# file_info['file'].flush()
# if os.path.getsize(file_info['path']) >= self.max_file_size:
# file_info['file'].close()
# del self.open_files[lmp_id]
# return invocation_data
# def get_lmps(self, **filters: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
# lmps = []
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
# with open(file_path, 'r') as f:
# lmp = json.load(f)
# if filters:
# if all(lmp.get(k) == v for k, v in filters.items()):
# lmps.append(lmp)
# else:
# lmps.append(lmp)
# return lmps
# def get_invocations(self, lmp_id: str, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
# invocations = []
# dir_path = os.path.join(self.storage_dir, 'invocations', lmp_id[:2], lmp_id[2:4], lmp_id[4:])
# for file_path in glob.glob(os.path.join(dir_path, 'invocations.*.jsonl')):
# with open(file_path, 'r') as f:
# for line in f:
# invocation = json.loads(line)
# if filters:
# if all(invocation.get(k) == v for k, v in filters.items()):
# invocations.append(invocation)
# else:
# invocations.append(invocation)
# return invocations
# def get_lmp(self, lmp_id: str) -> Optional[Dict[str, Any]]:
# """
# Get a specific LMP by its ID.
# """
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
# with open(file_path, 'r') as f:
# lmp = json.load(f)
# if lmp.get('lmp_id') == lmp_id:
# return lmp
# return None
# def search_lmps(self, query: str) -> List[Dict[str, Any]]:
# lmps = []
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
# with open(file_path, 'r') as f:
# lmp = json.load(f)
# if query.lower() in json.dumps(lmp).lower():
# lmps.append(lmp)
# return lmps
# def search_invocations(self, query: str) -> List[Dict[str, Any]]:
# invocations = []
# for dir_path in glob.glob(os.path.join(self.storage_dir, 'invocations', '*', '*', '*')):
# for file_path in glob.glob(os.path.join(dir_path, 'invocations.*.jsonl')):
# with open(file_path, 'r') as f:
# for line in f:
# invocation = json.loads(line)
# if query.lower() in json.dumps(invocation).lower():
# invocations.append(invocation)
# return invocations
# def get_lmp_versions(self, lmp_id: str) -> List[Dict[str, Any]]:
# """
# Get all versions of an LMP with the given lmp_id.
# """
# target_lmp = self.get_lmp(lmp_id)
# if not target_lmp:
# return []
# versions = []
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', f"{target_lmp['name']}_*.json")):
# with open(file_path, 'r') as f:
# lmp = json.load(f)
# versions.append(lmp)
# # Sort versions by created_at timestamp, newest first
# return sorted(versions, key=lambda x: x['created_at'], reverse=True)
# def get_latest_lmps(self) -> List[Dict[str, Any]]:
# """
# Get the latest version of each unique LMP.
# """
# lmps_by_name = {}
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
# with open(file_path, 'r') as f:
# lmp = json.load(f)
# name = lmp['name']
# if name not in lmps_by_name or lmp['created_at'] > lmps_by_name[name]['created_at']:
# lmps_by_name[name] = lmp
# # Return the list of latest LMPs, sorted by name
# return sorted(lmps_by_name.values(), key=itemgetter('name'))
# def __del__(self):
# """
# Close all open files when the serializer is destroyed.
# """
# for file_info in self.open_files.values():
# file_info['file'].close()
# def install(self):
# """
# Install the serializer into all invocations of the ell wrapper.
# """
# ell.config.register_serializer(self)

View File

@@ -7,16 +7,14 @@ import ell.store
import cattrs
import numpy as np
from sqlalchemy.sql import text
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
from ell.lstr import lstr
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, utc_now
from ell._lstr import _lstr
from sqlalchemy import or_, func, and_, extract, case
class SQLStore(ell.store.Store):
def __init__(self, db_uri: str):
self.engine = create_engine(db_uri)
SQLModel.metadata.create_all(self.engine)
self.open_files: Dict[str, Dict[str, Any]] = {}
@@ -39,7 +37,7 @@ class SQLStore(ell.store.Store):
session.commit()
return None
def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
with Session(self.engine) as session:
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id).first()
assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously"
@@ -52,10 +50,6 @@ class SQLStore(ell.store.Store):
session.add(invocation)
for result in results:
result.producer_invocation = invocation
session.add(result)
# Now create traces.
for consumed_id in consumes:
session.add(InvocationTrace(
@@ -116,9 +110,8 @@ class SQLStore(ell.store.Store):
def get_invocations(self, session: Session, lmp_filters: Dict[str, Any], skip: int = 0, limit: int = 10, filters: Optional[Dict[str, Any]] = None, hierarchical: bool = False) -> List[Dict[str, Any]]:
def fetch_invocation(inv_id):
query = (
select(Invocation, SerializedLStr, SerializedLMP)
select(Invocation, SerializedLMP)
.join(SerializedLMP)
.outerjoin(SerializedLStr)
.where(Invocation.id == inv_id)
)
results = session.exec(query).all()
@@ -126,10 +119,10 @@ class SQLStore(ell.store.Store):
if not results:
return None
inv, lstr, lmp = results[0]
inv, lmp = results[0]
inv_dict = inv.model_dump()
inv_dict['lmp'] = lmp.model_dump()
inv_dict['results'] = [dict(**l.model_dump(), __lstr=True) for l in [r[1] for r in results if r[1]]]
inv_dict['results'] = inv_dict['lmp']['results']
# Fetch consumes and consumed_by invocation IDs
consumes_query = select(InvocationTrace.invocation_consuming_id).where(InvocationTrace.invocation_consumer_id == inv_id)

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import List, Optional, Dict, Any
from sqlmodel import SQLModel
from ell.types import SerializedLMPBase, InvocationBase, SerializedLStrBase
from ell.types import SerializedLMPBase, InvocationBase
class SerializedLMPPublic(SerializedLMPBase):
pass
@@ -25,7 +25,6 @@ class SerializedLMPUpdate(SQLModel):
class InvocationPublic(InvocationBase):
lmp: SerializedLMPPublic
results: List[SerializedLStrBase]
consumes: List[str]
consumed_by: List[str]
uses: List[str]
@@ -44,11 +43,6 @@ class InvocationUpdate(SQLModel):
state_cache_key: Optional[str] = None
invocation_kwargs: Optional[Dict[str, Any]] = None
class SerializedLStrPublic(SerializedLStrBase):
pass
class SerializedLStrCreate(SerializedLStrBase):
pass
class SerializedLStrUpdate(SQLModel):
content: Optional[str] = None

View File

@@ -4,7 +4,9 @@ from typing import Callable, Dict, List, Type, Union, Any, Optional
from pydantic import BaseModel, field_validator, validator
from ell.lstr import lstr
import enum
from ell._lstr import _lstr
from ell.util.dict_sync_meta import DictSyncMeta
from datetime import datetime, timezone
@@ -13,7 +15,7 @@ from sqlmodel import Field, SQLModel, Relationship, JSON, Column
from sqlalchemy import Index, func
import sqlalchemy.types as types
_lstr_generic = Union[lstr, str]
_lstr_generic = Union[_lstr, str]
OneTurn = Callable[..., _lstr_generic]
@@ -24,6 +26,8 @@ LMPParams = Dict[str, Any]
InvocableTool = Callable[..., _lstr_generic]
# todo: implement tracing for structured outputs. This is a v2 feature.
class ToolCall(BaseModel):
tool : InvocableTool
@@ -36,7 +40,7 @@ class MessageContentBlock(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default=None)
structured: Optional[Type[BaseModel]] = Field(default=None)
# XXX: Todo implement this.
# XXX: Should validate.
# @field_validator('*')
# @classmethod
# def check_at_least_one_field_set(cls, v, info):
@@ -47,14 +51,15 @@ class MessageContentBlock(BaseModel):
class Message(BaseModel):
role: str
content: List[MessageContentBlock] = Field(min_length=1)
def __json__(self):
print("dumping", self.model_dump())
return self.model_dump()
def to_openai_message(self):
assert len(self.content) == 1
return {
"role": self.role,
"content": self.content[0].text
}
# Well, this is disappointing. I wanted to effectively type hint by doing that dict sync meta, but at least we can still reference role or content this way. I will probably can the dict sync meta.
@@ -103,8 +108,6 @@ def UTCTimestampField(index:bool=False, **kwargs:Any):
sa_column=Column(UTCTimestamp(timezone=True), index=index, **kwargs))
import enum
class LMPType(str, enum.Enum):
LM = "LM"
TOOL = "TOOL"
@@ -112,6 +115,7 @@ class LMPType(str, enum.Enum):
OTHER = "OTHER"
class SerializedLMPBase(SQLModel):
lmp_id: Optional[str] = Field(default=None, primary_key=True)
name: str = Field(index=True)
@@ -156,15 +160,6 @@ class InvocationTrace(SQLModel, table=True):
invocation_consumer_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True)
invocation_consuming_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True)
class SerializedLStrBase(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
content: str
logits: List[float] = Field(default_factory=list, sa_column=Column(JSON))
producer_invocation_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True)
class SerializedLStr(SerializedLStrBase, table=True):
producer_invocation: Optional["Invocation"] = Relationship(back_populates="results")
# Should be subtyped for different kinds of LMPs.
class InvocationBase(SQLModel):
id: Optional[str] = Field(default=None, primary_key=True)
@@ -180,10 +175,10 @@ class InvocationBase(SQLModel):
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
invocation_kwargs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
used_by_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True)
results : Union[List[Message], Any] = Field(default_factory=list, sa_column=Column(JSON))
class Invocation(InvocationBase, table=True):
lmp: SerializedLMP = Relationship(back_populates="invocations")
results: List[SerializedLStr] = Relationship(back_populates="producer_invocation")
consumed_by: List["Invocation"] = Relationship(
back_populates="consumes",
link_model=InvocationTrace,
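For context on the Message/MessageContentBlock types above: _lstr content now flows through plain pydantic models, and invocation results are stored inline on the Invocation row as JSON rather than as related SerializedLStr rows. A minimal sketch of the message side (construction mirrors what ell.util.api does in this commit; the origin string is made up):
from ell._lstr import _lstr
from ell.types import Message, MessageContentBlock

block = MessageContentBlock(text=_lstr("Hello!", _origin_trace="invocation-123"))
msg = Message(role="assistant", content=[block])
print(msg.to_openai_message())  # {'role': 'assistant', 'content': 'Hello!'}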

View File

@@ -4,7 +4,7 @@ from functools import partial
from ell.configurator import config
import openai
from collections import defaultdict
from ell.lstr import lstr
from ell._lstr import _lstr
from ell.types import LMP, LMPParams, Message, MessageContentBlock, MessageOrDict
@@ -22,9 +22,11 @@ def process_messages_for_client(messages: list[Message], client: Any):
return [
message.to_openai_message()
for message in messages]
# elif isinstance(client, anthropic.Anthropic):
# return messages
# XXX: or some such.
def call(
*,
model: str,
@@ -36,7 +38,7 @@ def call(
_exempt_from_tracking: bool,
_logging_color=None,
_name: str = None,
) -> Tuple[Union[lstr, Iterable[lstr]], Optional[Dict[str, Any]]]:
) -> Tuple[Union[_lstr, Iterable[_lstr]], Optional[Dict[str, Any]]]:
"""
Helper function to run the language model with the provided messages and parameters.
"""
@@ -114,13 +116,14 @@ def call(
tracked_results = [
# TODO: Remove hardcoding
# TODO: Universal message format
Message(role=choice.message.role if not streaming else choice_deltas[0].delta.role, content=[MessageContentBlock(text=lstr(
Message(role=choice.message.role if not streaming else choice_deltas[0].delta.role, content=[MessageContentBlock(text=_lstr(
content="".join((choice.delta.content or "" for choice in choice_deltas)) if streaming else choice.message.content,
_origin_trace=_invocation_origin,
))])
for _, choice_deltas in sorted(choices_progress.items(), key= lambda x: x[0],)
]
print(tracked_results)
api_params= dict(model=model, messages=client_safe_messages_messages, lm_kwargs=lm_kwargs)
return tracked_results[0] if n_choices == 1 else tracked_results, api_params, metadata

View File

@@ -1,12 +1,12 @@
import numpy as np
import pytest
from ell.lstr import lstr
from ell._lstr import _lstr
class TestLstr:
def test_init(self):
# Test initialization with string content only
s = lstr("hello")
s = _lstr("hello")
assert str(s) == "hello"
assert s.logits is None
assert s._origin_trace == frozenset()
@@ -14,14 +14,14 @@ class TestLstr:
# Test initialization with logits and _origin_trace
logits = np.array([0.1, 0.2])
_origin_trace = "model1"
s = lstr("world", logits=logits, _origin_trace=_origin_trace)
s = _lstr("world", logits=logits, _origin_trace=_origin_trace)
assert str(s) == "world"
assert np.array_equal(s.logits, logits)
assert s._origin_trace == frozenset({_origin_trace})
def test_add(self):
s1 = lstr("hello")
s2 = lstr("world", _origin_trace="model2")
s1 = _lstr("hello")
s2 = _lstr("world", _origin_trace="model2")
assert isinstance(s1 + s2, str)
result = s1 + s2
assert str(result) == "helloworld"
@@ -29,21 +29,21 @@ class TestLstr:
assert result._origin_trace == frozenset({"model2"})
def test_mod(self):
s = lstr("hello %s")
s = _lstr("hello %s")
result = s % "world"
assert str(result) == "hello world"
assert result.logits is None
assert result._origin_trace == frozenset()
def test_mul(self):
s = lstr("ha", _origin_trace="model3")
s = _lstr("ha", _origin_trace="model3")
result = s * 3
assert str(result) == "hahaha"
assert result.logits is None
assert result._origin_trace == frozenset({"model3"})
def test_getitem(self):
s = lstr(
s = _lstr(
"hello", logits=np.array([0.1, 0.2, 0.3, 0.4, 0.5]), _origin_trace="model4"
)
result = s[1:4]
@@ -53,50 +53,50 @@ class TestLstr:
def test_upper(self):
# Test upper method without _origin_trace and logits
s = lstr("hello")
s = _lstr("hello")
result = s.upper()
assert str(result) == "HELLO"
assert result.logits is None
assert result._origin_trace == frozenset()
# Test upper method with _origin_trace
s = lstr("world", _origin_trace="model11")
s = _lstr("world", _origin_trace="model11")
result = s.upper()
assert str(result) == "WORLD"
assert result.logits is None
assert result._origin_trace == frozenset({"model11"})
def test_join(self):
s = lstr(", ", _origin_trace="model5")
parts = [lstr("hello"), lstr("world", _origin_trace="model6")]
s = _lstr(", ", _origin_trace="model5")
parts = [_lstr("hello"), _lstr("world", _origin_trace="model6")]
result = s.join(parts)
assert str(result) == "hello, world"
assert result.logits is None
assert result._origin_trace == frozenset({"model5", "model6"})
def test_split(self):
s = lstr("hello world", _origin_trace="model7")
s = _lstr("hello world", _origin_trace="model7")
parts = s.split()
assert [str(p) for p in parts] == ["hello", "world"]
assert all(p.logits is None for p in parts)
assert all(p._origin_trace == frozenset({"model7"}) for p in parts)
def test_partition(self):
s = lstr("hello, world", _origin_trace="model8")
s = _lstr("hello, world", _origin_trace="model8")
part1, sep, part2 = s.partition(", ")
assert (str(part1), str(sep), str(part2)) == ("hello", ", ", "world")
assert all(p.logits is None for p in (part1, sep, part2))
assert all(p._origin_trace == frozenset({"model8"}) for p in (part1, sep, part2))
def test_formatting(self):
s = lstr("Hello {}!")
filled = s.format(lstr("world", _origin_trace="model9"))
s = _lstr("Hello {}!")
filled = s.format(_lstr("world", _origin_trace="model9"))
assert str(filled) == "Hello world!"
assert filled.logits is None
assert filled._origin_trace == frozenset({"model9"})
def test_repr(self):
s = lstr("test", logits=np.array([1.0]), _origin_trace="model10")
s = _lstr("test", logits=np.array([1.0]), _origin_trace="model10")
assert "test" in repr(s)
assert "model10" in repr(s._origin_trace)