mirror of
https://github.com/MadcowD/ell.git
synced 2024-09-22 16:14:36 +03:00
beginning of lstr pydantization
This commit is contained in:
@@ -2,7 +2,7 @@ import dataclasses
|
||||
import random
|
||||
from typing import Callable, List, Tuple
|
||||
import ell
|
||||
from ell.lstr import lstr
|
||||
from ell._lstr import _lstr
|
||||
ell.config.verbose = True
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ ell.config.verbose = True
|
||||
# Option 0 (pythonic.)
|
||||
# This is going to be NON-native to ell
|
||||
|
||||
def parse_outputs(result : lstr) -> str:
|
||||
def parse_outputs(result : _lstr) -> str:
|
||||
name = result.split(":")[0]
|
||||
backstory = result.split(":")[1]
|
||||
return Personality(name, backstory)
|
||||
@@ -31,7 +31,7 @@ def get_personality():
|
||||
|
||||
# Option 1.
|
||||
|
||||
def parse_outputs(result : lstr) -> str:
|
||||
def parse_outputs(result : _lstr) -> str:
|
||||
name = result.split(":")[0]
|
||||
backstory = result.split(":")[1]
|
||||
return Personality(name, backstory)
|
||||
@@ -56,7 +56,7 @@ class Personality:
|
||||
backstory : str
|
||||
|
||||
@staticmethod
|
||||
def parse_outputs(result : lstr) -> str:
|
||||
def parse_outputs(result : _lstr) -> str:
|
||||
name = result.split(":")[0]
|
||||
backstory = result.split(":")[1]
|
||||
return Personality(name, backstory)
|
||||
@@ -73,7 +73,7 @@ Backstory: <3 sentence backstory>'""" # System prompt
|
||||
|
||||
# Option 3. Another decorator
|
||||
|
||||
def parse_outputs(result : lstr) -> str:
|
||||
def parse_outputs(result : _lstr) -> str:
|
||||
name = result.split(":")[0]
|
||||
backstory = result.split(":")[1]
|
||||
return Personality(name, backstory)
|
||||
@@ -103,7 +103,7 @@ class PersonalitySchema(ell.Schema):
|
||||
name : str
|
||||
backstory : str
|
||||
|
||||
def parse_outputs(result : lstr) -> str:
|
||||
def parse_outputs(result : _lstr) -> str:
|
||||
name = result.split(":")[0]
|
||||
backstory = result.split(":")[1]
|
||||
return Personality(name, backstory)
|
||||
@@ -126,7 +126,7 @@ def create_random_personality():
|
||||
|
||||
|
||||
# or
|
||||
def parser(result : lstr) -> OutputFormat:
|
||||
def parser(result : _lstr) -> OutputFormat:
|
||||
name = result.split(":")[0]
|
||||
backstory = result.split(":")[1]
|
||||
return OutputFormat(name, backstory)
|
||||
@@ -323,7 +323,7 @@ def internal_make_a_rpg_character(name : str):
|
||||
|
||||
@ell.track
|
||||
@retry(tries=3)
|
||||
def parse(result : lstr):
|
||||
def parse(result : _lstr):
|
||||
return json.parse(result)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import ell
|
||||
import requests
|
||||
|
||||
import ell.lmp.tool
|
||||
from ell.lstr import lstr
|
||||
from ell._lstr import _lstr
|
||||
from ell.stores.sql import SQLiteStore
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def get_html_content(
|
||||
url : str = Field(description="The URL to get the HTML content of. Never incldue the protocol (like http:// or https://)"),
|
||||
):
|
||||
"""Get the HTML content of a URL."""
|
||||
return lstr(requests.get(url))
|
||||
return _lstr(requests.get(url))
|
||||
|
||||
|
||||
@ell.text(model="gpt-4o", tools=[get_html_content], eager=True)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ell.stores.sql import SQLiteStore
|
||||
def get_random_length():
|
||||
return int(np.random.beta(2, 6) * 1500)
|
||||
|
||||
@ell.text(model="gpt-4o-mini")
|
||||
@ell.multimodal(model="gpt-4o-mini")
|
||||
def hello(world : str):
|
||||
"""Your goal is to be really meant to the other guy whiel say hello"""
|
||||
name = world.capitalize()
|
||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
ell.set_store(SQLiteStore('sqlite_example'), autocommit=True)
|
||||
|
||||
greeting = hello("sam altman") # > "hello sama! ... "
|
||||
print(greeting[0:5].__origin_trace__)
|
||||
# print(greeting[0:5].__origin_trace__)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
LM string that supports logits and keeps track of its _origin_trace even after mutation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import (
|
||||
Optional,
|
||||
@@ -16,12 +15,12 @@ from typing import (
|
||||
Callable,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
class lstr(str):
|
||||
class _lstr(str):
|
||||
"""
|
||||
A string class that supports logits and keeps track of its _origin_trace even after mutation.
|
||||
This class is designed to be used in prompt engineering libraries where it is essential to associate
|
||||
@@ -97,7 +96,7 @@ class lstr(str):
|
||||
logits (np.ndarray, optional): The logits associated with this string. Defaults to None.
|
||||
_origin_trace (Union[str, FrozenSet[str]], optional): The _origin_trace(s) of this string. Defaults to None.
|
||||
"""
|
||||
instance = super(lstr, cls).__new__(cls, content)
|
||||
instance = super(_lstr, cls).__new__(cls, content)
|
||||
# instance._logits = logits
|
||||
if isinstance(_origin_trace, str):
|
||||
instance.__origin_trace__ = frozenset({_origin_trace})
|
||||
@@ -110,23 +109,48 @@ class lstr(str):
|
||||
# _logits: Optional[np.ndarray]
|
||||
__origin_trace__: FrozenSet[str]
|
||||
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
# todo: make this work with originators & all
|
||||
return core_schema.no_info_after_validator_function(cls, handler(str))
|
||||
# @property
|
||||
# def logits(self) -> Optional[np.ndarray]:
|
||||
# """
|
||||
# Get the logits associated with this lstr instance.
|
||||
def validate_lstr(value):
|
||||
print("valiudating", value)
|
||||
import traceback
|
||||
print("Call stack:")
|
||||
for line in traceback.format_stack():
|
||||
print(line.strip())
|
||||
if isinstance(value, dict) and value.get('__lstr', False):
|
||||
content = value['content']
|
||||
_origin_trace = value['_originator'].split(',')
|
||||
return cls(content, _origin_trace=_origin_trace)
|
||||
elif isinstance(value, str):
|
||||
print("returning lstr")
|
||||
return cls(value)
|
||||
elif isinstance(value, cls):
|
||||
return value
|
||||
else:
|
||||
raise ValueError(f"Invalid value for lstr: {value}")
|
||||
|
||||
# Returns:
|
||||
# Optional[np.ndarray]: The logits associated with this lstr instance, or None if no logits are available.
|
||||
# """
|
||||
# return self._logits
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.typed_dict_schema({
|
||||
'content': core_schema.typed_dict_field(core_schema.str_schema()),
|
||||
'_originator': core_schema.typed_dict_field(core_schema.str_schema()),
|
||||
'__lstr': core_schema.typed_dict_field(core_schema.bool_schema()),
|
||||
}),
|
||||
python_schema=core_schema.union_schema([
|
||||
core_schema.is_instance_schema(cls),
|
||||
core_schema.no_info_plain_validator_function(validate_lstr),
|
||||
]),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: {
|
||||
"content": str(instance),
|
||||
"_originator": ",".join(instance._origin_trace),
|
||||
"__lstr": True
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def _origin_trace(self) -> FrozenSet[str]:
|
||||
"""
|
||||
@@ -149,7 +173,7 @@ class lstr(str):
|
||||
"""
|
||||
return super().__repr__()
|
||||
|
||||
def __add__(self, other: Union[str, "lstr"]) -> "lstr":
|
||||
def __add__(self, other: Union[str, "_lstr"]) -> "_lstr":
|
||||
"""
|
||||
Concatenate this lstr instance with another string or lstr instance.
|
||||
|
||||
@@ -159,21 +183,21 @@ class lstr(str):
|
||||
Returns:
|
||||
lstr: A new lstr instance containing the concatenated content, with the _origin_trace(s) updated accordingly.
|
||||
"""
|
||||
new_content = super(lstr, self).__add__(other)
|
||||
new_content = super(_lstr, self).__add__(other)
|
||||
self_origin = self.__origin_trace__
|
||||
|
||||
if isinstance(other, lstr):
|
||||
if isinstance(other, _lstr):
|
||||
new_origin = set(self_origin)
|
||||
new_origin.update(other.__origin_trace__)
|
||||
new_origin = frozenset(new_origin)
|
||||
else:
|
||||
new_origin = self_origin
|
||||
|
||||
return lstr(new_content, None, new_origin)
|
||||
return _lstr(new_content, None, new_origin)
|
||||
|
||||
def __mod__(
|
||||
self, other: Union[str, "lstr", Tuple[Union[str, "lstr"], ...]]
|
||||
) -> "lstr":
|
||||
self, other: Union[str, "_lstr", Tuple[Union[str, "_lstr"], ...]]
|
||||
) -> "_lstr":
|
||||
"""
|
||||
Perform a modulo operation between this lstr instance and another string, lstr, or a tuple of strings and lstrs,
|
||||
tracing the operation by logging the operands and the result.
|
||||
@@ -186,22 +210,22 @@ class lstr(str):
|
||||
"""
|
||||
# If 'other' is a tuple, we need to handle each element
|
||||
if isinstance(other, tuple):
|
||||
result_content = super(lstr, self).__mod__(tuple(str(o) for o in other))
|
||||
result_content = super(_lstr, self).__mod__(tuple(str(o) for o in other))
|
||||
new__origin_trace__s = set(self.__origin_trace__)
|
||||
for item in other:
|
||||
if isinstance(item, lstr):
|
||||
if isinstance(item, _lstr):
|
||||
new__origin_trace__s.update(item.__origin_trace__)
|
||||
new__origin_trace__ = frozenset(new__origin_trace__s)
|
||||
else:
|
||||
result_content = super(lstr, self).__mod__(other)
|
||||
if isinstance(other, lstr):
|
||||
result_content = super(_lstr, self).__mod__(other)
|
||||
if isinstance(other, _lstr):
|
||||
new__origin_trace__ = self.__origin_trace__.union(other.__origin_trace__)
|
||||
else:
|
||||
new__origin_trace__ = self.__origin_trace__
|
||||
|
||||
return lstr(result_content, None, new__origin_trace__)
|
||||
return _lstr(result_content, None, new__origin_trace__)
|
||||
|
||||
def __mul__(self, other: SupportsIndex) -> "lstr":
|
||||
def __mul__(self, other: SupportsIndex) -> "_lstr":
|
||||
"""
|
||||
Perform a multiplication operation between this lstr instance and an integer or another lstr,
|
||||
tracing the operation by logging the operands and the result.
|
||||
@@ -213,14 +237,14 @@ class lstr(str):
|
||||
lstr: A new lstr instance containing the result of the multiplication operation, with the _origin_trace(s) updated accordingly.
|
||||
"""
|
||||
if isinstance(other, SupportsIndex):
|
||||
result_content = super(lstr, self).__mul__(other)
|
||||
result_content = super(_lstr, self).__mul__(other)
|
||||
new__origin_trace__ = self.__origin_trace__
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
return lstr(result_content, None, new__origin_trace__)
|
||||
return _lstr(result_content, None, new__origin_trace__)
|
||||
|
||||
def __rmul__(self, other: SupportsIndex) -> "lstr":
|
||||
def __rmul__(self, other: SupportsIndex) -> "_lstr":
|
||||
"""
|
||||
Perform a right multiplication operation between an integer or another lstr and this lstr instance,
|
||||
tracing the operation by logging the operands and the result.
|
||||
@@ -233,7 +257,7 @@ class lstr(str):
|
||||
"""
|
||||
return self.__mul__(other) # Multiplication is commutative in this context
|
||||
|
||||
def __getitem__(self, key: Union[SupportsIndex, slice]) -> "lstr":
|
||||
def __getitem__(self, key: Union[SupportsIndex, slice]) -> "_lstr":
|
||||
"""
|
||||
Get a slice or index of this lstr instance.
|
||||
|
||||
@@ -243,14 +267,14 @@ class lstr(str):
|
||||
Returns:
|
||||
lstr: A new lstr instance containing the sliced or indexed content, with the _origin_trace(s) preserved.
|
||||
"""
|
||||
result = super(lstr, self).__getitem__(key)
|
||||
result = super(_lstr, self).__getitem__(key)
|
||||
# This is a matter of opinion. I believe that when you index into a language model output, you are divorcing the logits of the indexed result from the context which produced them. Therefore, it is only reasonable to directly index into the logits without changing the original context, and so any mutation on the string should invalidate the logits.
|
||||
# try:
|
||||
# logit_subset = self._logits[key] if self._logits else None
|
||||
# except:
|
||||
# logit_subset = None
|
||||
logit_subset = None
|
||||
return lstr(result, logit_subset, self.__origin_trace__)
|
||||
return _lstr(result, logit_subset, self.__origin_trace__)
|
||||
|
||||
def __getattribute__(self, name: str) -> Union[Callable, Any]:
|
||||
"""
|
||||
@@ -273,7 +297,7 @@ class lstr(str):
|
||||
if name == "__class__":
|
||||
return type(self)
|
||||
|
||||
if callable(attr) and name not in lstr.__dict__:
|
||||
if callable(attr) and name not in _lstr.__dict__:
|
||||
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
result = attr(*args, **kwargs)
|
||||
@@ -281,12 +305,12 @@ class lstr(str):
|
||||
if isinstance(result, str):
|
||||
_origin_traces = self.__origin_trace__
|
||||
for arg in args:
|
||||
if isinstance(arg, lstr):
|
||||
if isinstance(arg, _lstr):
|
||||
_origin_traces = _origin_traces.union(arg.__origin_trace__)
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, lstr):
|
||||
if isinstance(value, _lstr):
|
||||
_origin_traces = _origin_traces.union(value.__origin_trace__)
|
||||
return lstr(result, None, _origin_traces)
|
||||
return _lstr(result, None, _origin_traces)
|
||||
|
||||
return result
|
||||
|
||||
@@ -295,7 +319,7 @@ class lstr(str):
|
||||
return attr
|
||||
|
||||
@override
|
||||
def join(self, iterable: Iterable[Union[str, "lstr"]]) -> "lstr":
|
||||
def join(self, iterable: Iterable[Union[str, "_lstr"]]) -> "_lstr":
|
||||
"""
|
||||
Join a sequence of strings or lstr instances into a single lstr instance.
|
||||
|
||||
@@ -306,17 +330,17 @@ class lstr(str):
|
||||
lstr: A new lstr instance containing the joined content, with the _origin_trace(s) updated accordingly.
|
||||
"""
|
||||
parts = [str(item) for item in iterable]
|
||||
new_content = super(lstr, self).join(parts)
|
||||
new_content = super(_lstr, self).join(parts)
|
||||
new__origin_trace__ = self.__origin_trace__
|
||||
for item in iterable:
|
||||
if isinstance(item, lstr):
|
||||
if isinstance(item, _lstr):
|
||||
new__origin_trace__ = new__origin_trace__.union(item.__origin_trace__)
|
||||
return lstr(new_content, None, new__origin_trace__)
|
||||
return _lstr(new_content, None, new__origin_trace__)
|
||||
|
||||
@override
|
||||
def split(
|
||||
self, sep: Optional[Union[str, "lstr"]] = None, maxsplit: SupportsIndex = -1
|
||||
) -> List["lstr"]:
|
||||
self, sep: Optional[Union[str, "_lstr"]] = None, maxsplit: SupportsIndex = -1
|
||||
) -> List["_lstr"]:
|
||||
"""
|
||||
Split this lstr instance into a list of lstr instances based on a separator.
|
||||
|
||||
@@ -327,12 +351,12 @@ class lstr(str):
|
||||
Returns:
|
||||
List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved.
|
||||
"""
|
||||
return self._split_helper(super(lstr, self).split, sep, maxsplit)
|
||||
return self._split_helper(super(_lstr, self).split, sep, maxsplit)
|
||||
|
||||
@override
|
||||
def rsplit(
|
||||
self, sep: Optional[Union[str, "lstr"]] = None, maxsplit: SupportsIndex = -1
|
||||
) -> List["lstr"]:
|
||||
self, sep: Optional[Union[str, "_lstr"]] = None, maxsplit: SupportsIndex = -1
|
||||
) -> List["_lstr"]:
|
||||
"""
|
||||
Split this lstr instance into a list of lstr instances based on a separator, starting from the right.
|
||||
|
||||
@@ -343,10 +367,10 @@ class lstr(str):
|
||||
Returns:
|
||||
List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved.
|
||||
"""
|
||||
return self._split_helper(super(lstr, self).rsplit, sep, maxsplit)
|
||||
return self._split_helper(super(_lstr, self).rsplit, sep, maxsplit)
|
||||
|
||||
@override
|
||||
def splitlines(self, keepends: bool = False) -> List["lstr"]:
|
||||
def splitlines(self, keepends: bool = False) -> List["_lstr"]:
|
||||
"""
|
||||
Split this lstr instance into a list of lstr instances based on line breaks.
|
||||
|
||||
@@ -357,12 +381,12 @@ class lstr(str):
|
||||
List["lstr"]: A list of lstr instances containing the split content, with the _origin_trace(s) preserved.
|
||||
"""
|
||||
return [
|
||||
lstr(p, None, self.__origin_trace__)
|
||||
for p in super(lstr, self).splitlines(keepends=keepends)
|
||||
_lstr(p, None, self.__origin_trace__)
|
||||
for p in super(_lstr, self).splitlines(keepends=keepends)
|
||||
]
|
||||
|
||||
@override
|
||||
def partition(self, sep: Union[str, "lstr"]) -> Tuple["lstr", "lstr", "lstr"]:
|
||||
def partition(self, sep: Union[str, "_lstr"]) -> Tuple["_lstr", "_lstr", "_lstr"]:
|
||||
"""
|
||||
Partition this lstr instance into three lstr instances based on a separator.
|
||||
|
||||
@@ -372,10 +396,10 @@ class lstr(str):
|
||||
Returns:
|
||||
Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly.
|
||||
"""
|
||||
return self._partition_helper(super(lstr, self).partition, sep)
|
||||
return self._partition_helper(super(_lstr, self).partition, sep)
|
||||
|
||||
@override
|
||||
def rpartition(self, sep: Union[str, "lstr"]) -> Tuple["lstr", "lstr", "lstr"]:
|
||||
def rpartition(self, sep: Union[str, "_lstr"]) -> Tuple["_lstr", "_lstr", "_lstr"]:
|
||||
"""
|
||||
Partition this lstr instance into three lstr instances based on a separator, starting from the right.
|
||||
|
||||
@@ -385,11 +409,11 @@ class lstr(str):
|
||||
Returns:
|
||||
Tuple["lstr", "lstr", "lstr"]: A tuple of three lstr instances containing the content before the separator, the separator itself, and the content after the separator, with the _origin_trace(s) updated accordingly.
|
||||
"""
|
||||
return self._partition_helper(super(lstr, self).rpartition, sep)
|
||||
return self._partition_helper(super(_lstr, self).rpartition, sep)
|
||||
|
||||
def _partition_helper(
|
||||
self, method, sep: Union[str, "lstr"]
|
||||
) -> Tuple["lstr", "lstr", "lstr"]:
|
||||
self, method, sep: Union[str, "_lstr"]
|
||||
) -> Tuple["_lstr", "_lstr", "_lstr"]:
|
||||
"""
|
||||
Helper method for partitioning this lstr instance based on a separator.
|
||||
|
||||
@@ -403,21 +427,21 @@ class lstr(str):
|
||||
part1, part2, part3 = method(sep)
|
||||
new__origin_trace__ = (
|
||||
self.__origin_trace__ | sep.__origin_trace__
|
||||
if isinstance(sep, lstr)
|
||||
if isinstance(sep, _lstr)
|
||||
else self.__origin_trace__
|
||||
)
|
||||
return (
|
||||
lstr(part1, None, new__origin_trace__),
|
||||
lstr(part2, None, new__origin_trace__),
|
||||
lstr(part3, None, new__origin_trace__),
|
||||
_lstr(part1, None, new__origin_trace__),
|
||||
_lstr(part2, None, new__origin_trace__),
|
||||
_lstr(part3, None, new__origin_trace__),
|
||||
)
|
||||
|
||||
def _split_helper(
|
||||
self,
|
||||
method,
|
||||
sep: Optional[Union[str, "lstr"]] = None,
|
||||
sep: Optional[Union[str, "_lstr"]] = None,
|
||||
maxsplit: SupportsIndex = -1,
|
||||
) -> List["lstr"]:
|
||||
) -> List["_lstr"]:
|
||||
"""
|
||||
Helper method for splitting this lstr instance based on a separator.
|
||||
|
||||
@@ -431,11 +455,11 @@ class lstr(str):
|
||||
"""
|
||||
_origin_traces = (
|
||||
self.__origin_trace__ | sep.__origin_trace__
|
||||
if isinstance(sep, lstr)
|
||||
if isinstance(sep, _lstr)
|
||||
else self.__origin_trace__
|
||||
)
|
||||
parts = method(sep, maxsplit)
|
||||
return [lstr(part, None, _origin_traces) for part in parts]
|
||||
return [_lstr(part, None, _origin_traces) for part in parts]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -450,14 +474,14 @@ if __name__ == "__main__":
|
||||
s1 = generate_random_string(1000)
|
||||
s2 = generate_random_string(1000)
|
||||
|
||||
lstr_time = timeit.timeit(lambda: lstr(s1) + lstr(s2), number=10000)
|
||||
lstr_time = timeit.timeit(lambda: _lstr(s1) + _lstr(s2), number=10000)
|
||||
str_time = timeit.timeit(lambda: s1 + s2, number=10000)
|
||||
|
||||
print(f"Concatenation: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s")
|
||||
|
||||
def test_slicing():
|
||||
s = generate_random_string(10000)
|
||||
ls = lstr(s)
|
||||
ls = _lstr(s)
|
||||
|
||||
lstr_time = timeit.timeit(lambda: ls[1000:2000], number=10000)
|
||||
str_time = timeit.timeit(lambda: s[1000:2000], number=10000)
|
||||
@@ -466,7 +490,7 @@ if __name__ == "__main__":
|
||||
|
||||
def test_splitting():
|
||||
s = generate_random_string(10000)
|
||||
ls = lstr(s)
|
||||
ls = _lstr(s)
|
||||
|
||||
lstr_time = timeit.timeit(lambda: ls.split(), number=1000)
|
||||
str_time = timeit.timeit(lambda: s.split(), number=1000)
|
||||
@@ -475,9 +499,9 @@ if __name__ == "__main__":
|
||||
|
||||
def test_joining():
|
||||
words = [generate_random_string(10) for _ in range(1000)]
|
||||
lwords = [lstr(word) for word in words]
|
||||
lwords = [_lstr(word) for word in words]
|
||||
|
||||
lstr_time = timeit.timeit(lambda: lstr(' ').join(lwords), number=1000)
|
||||
lstr_time = timeit.timeit(lambda: _lstr(' ').join(lwords), number=1000)
|
||||
str_time = timeit.timeit(lambda: ' '.join(words), number=1000)
|
||||
|
||||
print(f"Joining: lstr: {lstr_time:.6f}s, str: {str_time:.6f}s")
|
||||
@@ -495,8 +519,8 @@ if __name__ == "__main__":
|
||||
def test_add():
|
||||
s1 = generate_random_string(1000)
|
||||
s2 = generate_random_string(1000)
|
||||
ls1 = lstr(s1, None, "origin1")
|
||||
ls2 = lstr(s2, None, "origin2")
|
||||
ls1 = _lstr(s1, None, "origin1")
|
||||
ls2 = _lstr(s2, None, "origin2")
|
||||
|
||||
for _ in range(100000):
|
||||
result = ls1 + ls2
|
||||
@@ -1,6 +1,6 @@
|
||||
from ell.configurator import config
|
||||
from ell.lmp.track import track
|
||||
from ell.lstr import lstr
|
||||
from ell._lstr import _lstr
|
||||
from ell.types import LMP, InvocableLM, LMPParams, LMPType, Message, MessageContentBlock, MessageOrDict, _lstr_generic
|
||||
from ell.util._warnings import _warnings
|
||||
from ell.util.api import call
|
||||
@@ -70,7 +70,7 @@ def _get_messages(prompt_ret: Union[str, list[MessageOrDict]], prompt: LMP) -> l
|
||||
"""
|
||||
if isinstance(prompt_ret, str):
|
||||
return [
|
||||
Message(role="system", content=[MessageContentBlock(text=lstr(prompt.__doc__) or config.default_system_prompt)]),
|
||||
Message(role="system", content=[MessageContentBlock(text=_lstr(prompt.__doc__) or config.default_system_prompt)]),
|
||||
Message(role="user", content=[MessageContentBlock(text=prompt_ret)]),
|
||||
]
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import threading
|
||||
from ell.types import LMPType, SerializedLStr, utc_now, SerializedLMP, Invocation, InvocationTrace
|
||||
from ell.types import LMPType, utc_now, SerializedLMP, Invocation, InvocationTrace
|
||||
import ell.util.closure
|
||||
from ell.configurator import config
|
||||
from ell.lstr import lstr
|
||||
from ell._lstr import _lstr
|
||||
|
||||
import inspect
|
||||
|
||||
@@ -202,25 +202,12 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion
|
||||
invocation_kwargs=invocation_kwargs,
|
||||
args=cleaned_invocation_params.get('args', []),
|
||||
kwargs=cleaned_invocation_params.get('kwargs', {}),
|
||||
used_by_id=parent_invocation_id
|
||||
used_by_id=parent_invocation_id,
|
||||
results=result
|
||||
)
|
||||
|
||||
results = []
|
||||
if isinstance(result, lstr):
|
||||
results = [result]
|
||||
elif isinstance(result, list):
|
||||
results = result
|
||||
else:
|
||||
raise TypeError("Result must be either lstr or List[lstr]")
|
||||
|
||||
serialized_results = [
|
||||
SerializedLStr(
|
||||
content=str(res),
|
||||
# logits=res.logits
|
||||
) for res in results
|
||||
]
|
||||
|
||||
config._store.write_invocation(invocation, serialized_results, consumes)
|
||||
|
||||
config._store.write_invocation(invocation, consumes)
|
||||
|
||||
def compute_state_cache_key(ipstr, fn_closure):
|
||||
_global_free_vars_str = f"{json.dumps(get_immutable_vars(fn_closure[2]), sort_keys=True, default=repr)}"
|
||||
@@ -269,7 +256,7 @@ def prepare_invocation_params(fn_args, fn_kwargs):
|
||||
lambda arr: arr.tolist()
|
||||
)
|
||||
invocation_converter.register_unstructure_hook(
|
||||
lstr,
|
||||
_lstr,
|
||||
process_lstr
|
||||
)
|
||||
invocation_converter.register_unstructure_hook(
|
||||
|
||||
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Dict, List, Set, Union
|
||||
from ell.lstr import lstr
|
||||
from ell.types import InvocableLM, SerializedLMP, Invocation, SerializedLStr
|
||||
from ell._lstr import _lstr
|
||||
from ell.types import InvocableLM, SerializedLMP, Invocation
|
||||
|
||||
|
||||
class Store(ABC):
|
||||
@@ -23,7 +23,7 @@ class Store(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
|
||||
def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
|
||||
"""
|
||||
Write an invocation of an LMP to the storage.
|
||||
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
"""
|
||||
This is an old stub from a previous jsonl store implementation that was
|
||||
used for testing. You can feel free to PR a full implementation if you'd like. It just needs an index.
|
||||
"""
|
||||
|
||||
# import os
|
||||
# import json
|
||||
# from typing import Any, Optional, Dict, List
|
||||
# from ell.lstr import lstr
|
||||
# import ell.store
|
||||
# import numpy as np
|
||||
# import glob
|
||||
# from operator import itemgetter
|
||||
# import warnings
|
||||
# import cattrs
|
||||
|
||||
# class JsonlStore(ell.store.Store):
|
||||
# def __init__(self, storage_dir: str, max_file_size: int = 1024 * 1024, check_empty: bool = False): # 1MB default
|
||||
# self.storage_dir = storage_dir
|
||||
# self.max_file_size = max_file_size
|
||||
# os.makedirs(storage_dir, exist_ok=True)
|
||||
# self.open_files = {}
|
||||
|
||||
# if check_empty and not os.path.exists(os.path.join(storage_dir, 'invocations')) and \
|
||||
# not os.path.exists(os.path.join(storage_dir, 'programs')):
|
||||
# warnings.warn(f"The ELL storage directory '{storage_dir}' is empty. No invocations or programs found.")
|
||||
|
||||
# self.converter = cattrs.Converter()
|
||||
# self._setup_cattrs()
|
||||
|
||||
# def lst_converter(self, obj: Any) -> Any:
|
||||
# # print(obj)
|
||||
# # return obj
|
||||
# return self.converter.unstructure(dict(content=str(obj), **obj.__dict__, __is_lstr=True))
|
||||
|
||||
# return obj
|
||||
# def _setup_cattrs(self):
|
||||
# self.converter.register_unstructure_hook(
|
||||
# np.ndarray,
|
||||
# lambda arr: arr.tolist()
|
||||
# )
|
||||
# self.converter.register_unstructure_hook(
|
||||
# lstr,
|
||||
# self.lst_converter
|
||||
# )
|
||||
# self.converter.register_unstructure_hook(
|
||||
# set,
|
||||
# lambda s: list(s)
|
||||
# )
|
||||
# self.converter.register_unstructure_hook(
|
||||
# frozenset,
|
||||
# lambda s: list(s)
|
||||
# )
|
||||
|
||||
# def _serialize(self, obj: Any) -> Any:
|
||||
# return self.converter.unstructure(obj)
|
||||
|
||||
# def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str],
|
||||
# created_at: float, is_lmp: bool, lm_kwargs: Optional[str],
|
||||
# uses: Dict[str, Any]) -> Optional[Any]:
|
||||
# """
|
||||
# Write the LMP (Language Model Program) to a JSON file.
|
||||
# """
|
||||
# file_path = os.path.join(self.storage_dir, 'programs', f"{name}_{lmp_id}.json")
|
||||
# os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
# lmp_data = {
|
||||
# 'lmp_id': lmp_id,
|
||||
# 'name': name,
|
||||
# 'source': source,
|
||||
# 'dependencies': dependencies,
|
||||
# 'created_at': created_at,
|
||||
# 'is_lmp': is_lmp,
|
||||
# 'lm_kwargs': lm_kwargs,
|
||||
# 'uses': uses
|
||||
# }
|
||||
|
||||
# with open(file_path, 'w') as f:
|
||||
# json.dump(self._serialize(lmp_data), f)
|
||||
|
||||
# return None
|
||||
|
||||
# def write_invocation(self, lmp_id: str, args: str, kwargs: str, result: str,
|
||||
# created_at: float, invocation_kwargs: Dict[str, Any]) -> Optional[Any]:
|
||||
# """
|
||||
# Write an LMP invocation to a JSONL file in a nested folder structure.
|
||||
# """
|
||||
# dir_path = os.path.join(self.storage_dir, 'invocations', lmp_id[:2], lmp_id[2:4], lmp_id[4:])
|
||||
# os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
# if lmp_id not in self.open_files:
|
||||
# index = 0
|
||||
# while True:
|
||||
# file_path = os.path.join(dir_path, f"invocations.{index}.jsonl")
|
||||
# if not os.path.exists(file_path) or os.path.getsize(file_path) < self.max_file_size:
|
||||
# self.open_files[lmp_id] = {'file': open(file_path, 'a'), 'path': file_path}
|
||||
# break
|
||||
# index += 1
|
||||
|
||||
# invocation_data = {
|
||||
# 'lmp_id': lmp_id,
|
||||
# 'args': args,
|
||||
# 'kwargs': kwargs,
|
||||
# 'result': result,
|
||||
# 'invocation_kwargs': invocation_kwargs,
|
||||
# 'created_at': created_at
|
||||
# }
|
||||
|
||||
# file_info = self.open_files[lmp_id]
|
||||
# file_info['file'].write(json.dumps(self._serialize(invocation_data)) + '\n')
|
||||
# file_info['file'].flush()
|
||||
|
||||
# if os.path.getsize(file_info['path']) >= self.max_file_size:
|
||||
# file_info['file'].close()
|
||||
# del self.open_files[lmp_id]
|
||||
|
||||
# return invocation_data
|
||||
|
||||
# def get_lmps(self, **filters: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# lmps = []
|
||||
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
|
||||
# with open(file_path, 'r') as f:
|
||||
# lmp = json.load(f)
|
||||
# if filters:
|
||||
# if all(lmp.get(k) == v for k, v in filters.items()):
|
||||
# lmps.append(lmp)
|
||||
# else:
|
||||
# lmps.append(lmp)
|
||||
# return lmps
|
||||
|
||||
# def get_invocations(self, lmp_id: str, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
# invocations = []
|
||||
# dir_path = os.path.join(self.storage_dir, 'invocations', lmp_id[:2], lmp_id[2:4], lmp_id[4:])
|
||||
# for file_path in glob.glob(os.path.join(dir_path, 'invocations.*.jsonl')):
|
||||
# with open(file_path, 'r') as f:
|
||||
# for line in f:
|
||||
# invocation = json.loads(line)
|
||||
# if filters:
|
||||
# if all(invocation.get(k) == v for k, v in filters.items()):
|
||||
# invocations.append(invocation)
|
||||
# else:
|
||||
# invocations.append(invocation)
|
||||
# return invocations
|
||||
|
||||
# def get_lmp(self, lmp_id: str) -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# Get a specific LMP by its ID.
|
||||
# """
|
||||
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
|
||||
# with open(file_path, 'r') as f:
|
||||
# lmp = json.load(f)
|
||||
# if lmp.get('lmp_id') == lmp_id:
|
||||
# return lmp
|
||||
# return None
|
||||
|
||||
# def search_lmps(self, query: str) -> List[Dict[str, Any]]:
|
||||
# lmps = []
|
||||
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
|
||||
# with open(file_path, 'r') as f:
|
||||
# lmp = json.load(f)
|
||||
# if query.lower() in json.dumps(lmp).lower():
|
||||
# lmps.append(lmp)
|
||||
# return lmps
|
||||
|
||||
# def search_invocations(self, query: str) -> List[Dict[str, Any]]:
|
||||
# invocations = []
|
||||
# for dir_path in glob.glob(os.path.join(self.storage_dir, 'invocations', '*', '*', '*')):
|
||||
# for file_path in glob.glob(os.path.join(dir_path, 'invocations.*.jsonl')):
|
||||
# with open(file_path, 'r') as f:
|
||||
# for line in f:
|
||||
# invocation = json.loads(line)
|
||||
# if query.lower() in json.dumps(invocation).lower():
|
||||
# invocations.append(invocation)
|
||||
# return invocations
|
||||
|
||||
# def get_lmp_versions(self, lmp_id: str) -> List[Dict[str, Any]]:
|
||||
# """
|
||||
# Get all versions of an LMP with the given lmp_id.
|
||||
# """
|
||||
# target_lmp = self.get_lmp(lmp_id)
|
||||
# if not target_lmp:
|
||||
# return []
|
||||
|
||||
# versions = []
|
||||
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', f"{target_lmp['name']}_*.json")):
|
||||
# with open(file_path, 'r') as f:
|
||||
# lmp = json.load(f)
|
||||
# versions.append(lmp)
|
||||
|
||||
# # Sort versions by created_at timestamp, newest first
|
||||
# return sorted(versions, key=lambda x: x['created_at'], reverse=True)
|
||||
|
||||
# def get_latest_lmps(self) -> List[Dict[str, Any]]:
|
||||
# """
|
||||
# Get the latest version of each unique LMP.
|
||||
# """
|
||||
# lmps_by_name = {}
|
||||
# for file_path in glob.glob(os.path.join(self.storage_dir, 'programs', '*.json')):
|
||||
# with open(file_path, 'r') as f:
|
||||
# lmp = json.load(f)
|
||||
# name = lmp['name']
|
||||
# if name not in lmps_by_name or lmp['created_at'] > lmps_by_name[name]['created_at']:
|
||||
# lmps_by_name[name] = lmp
|
||||
|
||||
# # Return the list of latest LMPs, sorted by name
|
||||
# return sorted(lmps_by_name.values(), key=itemgetter('name'))
|
||||
|
||||
# def __del__(self):
|
||||
# """
|
||||
# Close all open files when the serializer is destroyed.
|
||||
# """
|
||||
# for file_info in self.open_files.values():
|
||||
# file_info['file'].close()
|
||||
|
||||
# def install(self):
|
||||
# """
|
||||
# Install the serializer into all invocations of the ell wrapper.
|
||||
# """
|
||||
# ell.config.register_serializer(self)
|
||||
@@ -7,16 +7,14 @@ import ell.store
|
||||
import cattrs
|
||||
import numpy as np
|
||||
from sqlalchemy.sql import text
|
||||
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
|
||||
from ell.lstr import lstr
|
||||
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, utc_now
|
||||
from ell._lstr import _lstr
|
||||
from sqlalchemy import or_, func, and_, extract, case
|
||||
|
||||
class SQLStore(ell.store.Store):
|
||||
def __init__(self, db_uri: str):
|
||||
self.engine = create_engine(db_uri)
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
|
||||
|
||||
self.open_files: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
@@ -39,7 +37,7 @@ class SQLStore(ell.store.Store):
|
||||
session.commit()
|
||||
return None
|
||||
|
||||
def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
|
||||
def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]:
|
||||
with Session(self.engine) as session:
|
||||
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id).first()
|
||||
assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously"
|
||||
@@ -52,10 +50,6 @@ class SQLStore(ell.store.Store):
|
||||
|
||||
session.add(invocation)
|
||||
|
||||
for result in results:
|
||||
result.producer_invocation = invocation
|
||||
session.add(result)
|
||||
|
||||
# Now create traces.
|
||||
for consumed_id in consumes:
|
||||
session.add(InvocationTrace(
|
||||
@@ -116,9 +110,8 @@ class SQLStore(ell.store.Store):
|
||||
def get_invocations(self, session: Session, lmp_filters: Dict[str, Any], skip: int = 0, limit: int = 10, filters: Optional[Dict[str, Any]] = None, hierarchical: bool = False) -> List[Dict[str, Any]]:
|
||||
def fetch_invocation(inv_id):
|
||||
query = (
|
||||
select(Invocation, SerializedLStr, SerializedLMP)
|
||||
select(Invocation, SerializedLMP)
|
||||
.join(SerializedLMP)
|
||||
.outerjoin(SerializedLStr)
|
||||
.where(Invocation.id == inv_id)
|
||||
)
|
||||
results = session.exec(query).all()
|
||||
@@ -126,10 +119,10 @@ class SQLStore(ell.store.Store):
|
||||
if not results:
|
||||
return None
|
||||
|
||||
inv, lstr, lmp = results[0]
|
||||
inv, lmp = results[0]
|
||||
inv_dict = inv.model_dump()
|
||||
inv_dict['lmp'] = lmp.model_dump()
|
||||
inv_dict['results'] = [dict(**l.model_dump(), __lstr=True) for l in [r[1] for r in results if r[1]]]
|
||||
inv_dict['results'] = inv_dict['lmp']['results']
|
||||
|
||||
# Fetch consumes and consumed_by invocation IDs
|
||||
consumes_query = select(InvocationTrace.invocation_consuming_id).where(InvocationTrace.invocation_consumer_id == inv_id)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlmodel import SQLModel
|
||||
from ell.types import SerializedLMPBase, InvocationBase, SerializedLStrBase
|
||||
from ell.types import SerializedLMPBase, InvocationBase
|
||||
|
||||
class SerializedLMPPublic(SerializedLMPBase):
|
||||
pass
|
||||
@@ -25,7 +25,6 @@ class SerializedLMPUpdate(SQLModel):
|
||||
|
||||
class InvocationPublic(InvocationBase):
|
||||
lmp: SerializedLMPPublic
|
||||
results: List[SerializedLStrBase]
|
||||
consumes: List[str]
|
||||
consumed_by: List[str]
|
||||
uses: List[str]
|
||||
@@ -44,11 +43,6 @@ class InvocationUpdate(SQLModel):
|
||||
state_cache_key: Optional[str] = None
|
||||
invocation_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
class SerializedLStrPublic(SerializedLStrBase):
|
||||
pass
|
||||
|
||||
class SerializedLStrCreate(SerializedLStrBase):
|
||||
pass
|
||||
|
||||
class SerializedLStrUpdate(SQLModel):
|
||||
content: Optional[str] = None
|
||||
|
||||
@@ -4,7 +4,9 @@ from typing import Callable, Dict, List, Type, Union, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator, validator
|
||||
|
||||
from ell.lstr import lstr
|
||||
import enum
|
||||
|
||||
from ell._lstr import _lstr
|
||||
from ell.util.dict_sync_meta import DictSyncMeta
|
||||
|
||||
from datetime import datetime, timezone
|
||||
@@ -13,7 +15,7 @@ from sqlmodel import Field, SQLModel, Relationship, JSON, Column
|
||||
from sqlalchemy import Index, func
|
||||
import sqlalchemy.types as types
|
||||
|
||||
_lstr_generic = Union[lstr, str]
|
||||
_lstr_generic = Union[_lstr, str]
|
||||
|
||||
OneTurn = Callable[..., _lstr_generic]
|
||||
|
||||
@@ -24,6 +26,8 @@ LMPParams = Dict[str, Any]
|
||||
|
||||
InvocableTool = Callable[..., _lstr_generic]
|
||||
|
||||
|
||||
|
||||
# todo: implement tracing for structured outputs. This is a v2 feature.
|
||||
class ToolCall(BaseModel):
|
||||
tool : InvocableTool
|
||||
@@ -36,7 +40,7 @@ class MessageContentBlock(BaseModel):
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||
structured: Optional[Type[BaseModel]] = Field(default=None)
|
||||
|
||||
# XXX: Todo implement this.
|
||||
# XXX: Should validate.
|
||||
# @field_validator('*')
|
||||
# @classmethod
|
||||
# def check_at_least_one_field_set(cls, v, info):
|
||||
@@ -47,14 +51,15 @@ class MessageContentBlock(BaseModel):
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: List[MessageContentBlock] = Field(min_length=1)
|
||||
|
||||
def __json__(self):
|
||||
print("dumping", self.model_dump())
|
||||
return self.model_dump()
|
||||
def to_openai_message(self):
|
||||
assert len(self.content) == 1
|
||||
return {
|
||||
"role": self.role,
|
||||
"content": self.content[0].text
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Well this is disappointing, I wanted to effectively type hint by doing that dict sync meta, but eh, at least we can still reference role or content this way. Probably will can the dict sync meta.
|
||||
@@ -103,8 +108,6 @@ def UTCTimestampField(index:bool=False, **kwargs:Any):
|
||||
sa_column=Column(UTCTimestamp(timezone=True), index=index, **kwargs))
|
||||
|
||||
|
||||
import enum
|
||||
|
||||
class LMPType(str, enum.Enum):
|
||||
LM = "LM"
|
||||
TOOL = "TOOL"
|
||||
@@ -112,6 +115,7 @@ class LMPType(str, enum.Enum):
|
||||
OTHER = "OTHER"
|
||||
|
||||
|
||||
|
||||
class SerializedLMPBase(SQLModel):
|
||||
lmp_id: Optional[str] = Field(default=None, primary_key=True)
|
||||
name: str = Field(index=True)
|
||||
@@ -156,15 +160,6 @@ class InvocationTrace(SQLModel, table=True):
|
||||
invocation_consumer_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True)
|
||||
invocation_consuming_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True)
|
||||
|
||||
class SerializedLStrBase(SQLModel):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
content: str
|
||||
logits: List[float] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
producer_invocation_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True)
|
||||
|
||||
class SerializedLStr(SerializedLStrBase, table=True):
|
||||
producer_invocation: Optional["Invocation"] = Relationship(back_populates="results")
|
||||
|
||||
# Should be subtyped for different kinds of LMPs.
|
||||
class InvocationBase(SQLModel):
|
||||
id: Optional[str] = Field(default=None, primary_key=True)
|
||||
@@ -180,10 +175,10 @@ class InvocationBase(SQLModel):
|
||||
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
|
||||
invocation_kwargs: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
used_by_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True)
|
||||
results : Union[List[Message], Any] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
class Invocation(InvocationBase, table=True):
|
||||
lmp: SerializedLMP = Relationship(back_populates="invocations")
|
||||
results: List[SerializedLStr] = Relationship(back_populates="producer_invocation")
|
||||
consumed_by: List["Invocation"] = Relationship(
|
||||
back_populates="consumes",
|
||||
link_model=InvocationTrace,
|
||||
|
||||
@@ -4,7 +4,7 @@ from functools import partial
|
||||
from ell.configurator import config
|
||||
import openai
|
||||
from collections import defaultdict
|
||||
from ell.lstr import lstr
|
||||
from ell._lstr import _lstr
|
||||
from ell.types import LMP, LMPParams, Message, MessageContentBlock, MessageOrDict
|
||||
|
||||
|
||||
@@ -22,9 +22,11 @@ def process_messages_for_client(messages: list[Message], client: Any):
|
||||
return [
|
||||
message.to_openai_message()
|
||||
for message in messages]
|
||||
|
||||
# elif isinstance(client, anthropic.Anthropic):
|
||||
# return messages
|
||||
# XXX: or some such.
|
||||
|
||||
|
||||
def call(
|
||||
*,
|
||||
model: str,
|
||||
@@ -36,7 +38,7 @@ def call(
|
||||
_exempt_from_tracking: bool,
|
||||
_logging_color=None,
|
||||
_name: str = None,
|
||||
) -> Tuple[Union[lstr, Iterable[lstr]], Optional[Dict[str, Any]]]:
|
||||
) -> Tuple[Union[_lstr, Iterable[_lstr]], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Helper function to run the language model with the provided messages and parameters.
|
||||
"""
|
||||
@@ -114,13 +116,14 @@ def call(
|
||||
tracked_results = [
|
||||
# TODO: Remove hardcoding
|
||||
# TODO: Universal message format
|
||||
Message(role=choice.message.role if not streaming else choice_deltas[0].delta.role, content=[MessageContentBlock(text=lstr(
|
||||
Message(role=choice.message.role if not streaming else choice_deltas[0].delta.role, content=[MessageContentBlock(text=_lstr(
|
||||
content="".join((choice.delta.content or "" for choice in choice_deltas)) if streaming else choice.message.content,
|
||||
_origin_trace=_invocation_origin,
|
||||
))])
|
||||
for _, choice_deltas in sorted(choices_progress.items(), key= lambda x: x[0],)
|
||||
]
|
||||
|
||||
print(tracked_results)
|
||||
api_params= dict(model=model, messages=client_safe_messages_messages, lm_kwargs=lm_kwargs)
|
||||
|
||||
return tracked_results[0] if n_choices == 1 else tracked_results, api_params, metadata
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from ell.lstr import lstr
|
||||
from ell._lstr import _lstr
|
||||
|
||||
|
||||
class TestLstr:
|
||||
def test_init(self):
|
||||
# Test initialization with string content only
|
||||
s = lstr("hello")
|
||||
s = _lstr("hello")
|
||||
assert str(s) == "hello"
|
||||
assert s.logits is None
|
||||
assert s._origin_trace == frozenset()
|
||||
@@ -14,14 +14,14 @@ class TestLstr:
|
||||
# Test initialization with logits and _origin_trace
|
||||
logits = np.array([0.1, 0.2])
|
||||
_origin_trace = "model1"
|
||||
s = lstr("world", logits=logits, _origin_trace=_origin_trace)
|
||||
s = _lstr("world", logits=logits, _origin_trace=_origin_trace)
|
||||
assert str(s) == "world"
|
||||
assert np.array_equal(s.logits, logits)
|
||||
assert s._origin_trace == frozenset({_origin_trace})
|
||||
|
||||
def test_add(self):
|
||||
s1 = lstr("hello")
|
||||
s2 = lstr("world", _origin_trace="model2")
|
||||
s1 = _lstr("hello")
|
||||
s2 = _lstr("world", _origin_trace="model2")
|
||||
assert isinstance(s1 + s2, str)
|
||||
result = s1 + s2
|
||||
assert str(result) == "helloworld"
|
||||
@@ -29,21 +29,21 @@ class TestLstr:
|
||||
assert result._origin_trace == frozenset({"model2"})
|
||||
|
||||
def test_mod(self):
|
||||
s = lstr("hello %s")
|
||||
s = _lstr("hello %s")
|
||||
result = s % "world"
|
||||
assert str(result) == "hello world"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset()
|
||||
|
||||
def test_mul(self):
|
||||
s = lstr("ha", _origin_trace="model3")
|
||||
s = _lstr("ha", _origin_trace="model3")
|
||||
result = s * 3
|
||||
assert str(result) == "hahaha"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model3"})
|
||||
|
||||
def test_getitem(self):
|
||||
s = lstr(
|
||||
s = _lstr(
|
||||
"hello", logits=np.array([0.1, 0.2, 0.3, 0.4, 0.5]), _origin_trace="model4"
|
||||
)
|
||||
result = s[1:4]
|
||||
@@ -53,50 +53,50 @@ class TestLstr:
|
||||
|
||||
def test_upper(self):
|
||||
# Test upper method without _origin_trace and logits
|
||||
s = lstr("hello")
|
||||
s = _lstr("hello")
|
||||
result = s.upper()
|
||||
assert str(result) == "HELLO"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset()
|
||||
|
||||
# Test upper method with _origin_trace
|
||||
s = lstr("world", _origin_trace="model11")
|
||||
s = _lstr("world", _origin_trace="model11")
|
||||
result = s.upper()
|
||||
assert str(result) == "WORLD"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model11"})
|
||||
|
||||
def test_join(self):
|
||||
s = lstr(", ", _origin_trace="model5")
|
||||
parts = [lstr("hello"), lstr("world", _origin_trace="model6")]
|
||||
s = _lstr(", ", _origin_trace="model5")
|
||||
parts = [_lstr("hello"), _lstr("world", _origin_trace="model6")]
|
||||
result = s.join(parts)
|
||||
assert str(result) == "hello, world"
|
||||
assert result.logits is None
|
||||
assert result._origin_trace == frozenset({"model5", "model6"})
|
||||
|
||||
def test_split(self):
|
||||
s = lstr("hello world", _origin_trace="model7")
|
||||
s = _lstr("hello world", _origin_trace="model7")
|
||||
parts = s.split()
|
||||
assert [str(p) for p in parts] == ["hello", "world"]
|
||||
assert all(p.logits is None for p in parts)
|
||||
assert all(p._origin_trace == frozenset({"model7"}) for p in parts)
|
||||
|
||||
def test_partition(self):
|
||||
s = lstr("hello, world", _origin_trace="model8")
|
||||
s = _lstr("hello, world", _origin_trace="model8")
|
||||
part1, sep, part2 = s.partition(", ")
|
||||
assert (str(part1), str(sep), str(part2)) == ("hello", ", ", "world")
|
||||
assert all(p.logits is None for p in (part1, sep, part2))
|
||||
assert all(p._origin_trace == frozenset({"model8"}) for p in (part1, sep, part2))
|
||||
|
||||
def test_formatting(self):
|
||||
s = lstr("Hello {}!")
|
||||
filled = s.format(lstr("world", _origin_trace="model9"))
|
||||
s = _lstr("Hello {}!")
|
||||
filled = s.format(_lstr("world", _origin_trace="model9"))
|
||||
assert str(filled) == "Hello world!"
|
||||
assert filled.logits is None
|
||||
assert filled._origin_trace == frozenset({"model9"})
|
||||
|
||||
def test_repr(self):
|
||||
s = lstr("test", logits=np.array([1.0]), _origin_trace="model10")
|
||||
s = _lstr("test", logits=np.array([1.0]), _origin_trace="model10")
|
||||
assert "test" in repr(s)
|
||||
assert "model10" in repr(s._origin_trace)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user