i give up

William Guss
2024-08-21 19:33:28 -07:00
parent 5f9741b7e0
commit 02d431f78d
3 changed files with 47 additions and 32 deletions


@@ -52,15 +52,16 @@ def lm(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=
             return tracked_str, api_params, metadata

         # TODO: # we'll deal with type safety here later
         model_call.__ell_lm_kwargs__ = lm_kwargs
         # XXX: Do we need intermediate params?
         model_call.__ell_func__ = prompt
         model_call.__ell_type__ = LMPType.LM
         model_call.__ell_exempt_from_tracking = exempt_from_tracking

         if exempt_from_tracking:
             return model_call
         else:
-            return track(model_call, forced_dependencies=dict(tools=tools))
+            return track(model_call, forced_dependencies=dict(tools=tools), lmp_type=LMPType.LM, lm_kwargs=lm_kwargs)

     return parameterized_lm_decorator
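
Read together, the two halves of this hunk show the hand-off changing shape: the @lm decorator stops relying on track() to read __ell_type__ off the wrapped function and instead passes lmp_type and lm_kwargs as explicit keyword arguments. Below is a minimal, runnable sketch of that hand-off using simplified stand-ins for ell's LMPType, track, and lm (and an arbitrary model name); none of it is the library's real implementation.

# Illustrative sketch only: simplified stand-ins for ell's LMPType, track, and lm.
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, Optional

class LMPType(Enum):
    LM = "lm"
    OTHER = "other"

def track(func: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None,
          lm_kwargs: Optional[Dict[str, Any]] = None,
          lmp_type: LMPType = LMPType.OTHER) -> Callable:
    # The tracker receives the LMP type and model kwargs explicitly,
    # rather than reading __ell_type__ / __ell_lm_kwargs__ off the function.
    func.__tracked_with__ = {"lmp_type": lmp_type, "lm_kwargs": lm_kwargs or {}}
    return func

def lm(model: str, exempt_from_tracking: bool = False, **lm_kwargs):
    def parameterized_lm_decorator(prompt: Callable) -> Callable:
        @wraps(prompt)
        def model_call(*args, **kwargs):
            # Placeholder for the actual API call.
            return f"[{model}] " + str(prompt(*args, **kwargs))

        if exempt_from_tracking:
            return model_call
        # Both pieces of metadata now travel as keyword arguments to track().
        return track(model_call, forced_dependencies={}, lmp_type=LMPType.LM, lm_kwargs=lm_kwargs)

    return parameterized_lm_decorator

@lm(model="gpt-4o", temperature=0.3)
def hello(name: str) -> str:
    return f"Say hello to {name}."

print(hello("world"), hello.__tracked_with__)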


@@ -44,19 +44,15 @@ def exclude_var(v):
     # is module or is immutable
     return inspect.ismodule(v)

-def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None) -> Callable:
-    lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER)
+_has_serialized_lmp = {}
+_lmp_hash = {}
+
+def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None, lm_kwargs: Optional[Dict[str, Any]] = None, lmp_type: Optional[LMPType] = LMPType.OTHER) -> Callable:
     # see if it exists
-    if not hasattr(func_to_track, "_has_serialized_lmp"):
-        func_to_track._has_serialized_lmp = False
-    if not hasattr(func_to_track, "__ell_hash__") and not config.lazy_versioning:
+    if not ell.util.closure.has_closured_function(func_to_track) and not config.lazy_versioning:
         ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)

     @wraps(func_to_track)
     def tracked_func(*fn_args, **fn_kwargs) -> str:
         # XXX: Cache keys and global variable binding is not thread safe.
@@ -76,14 +72,15 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An
         if try_use_cache:
             # Todo: add nice logging if verbose for when using a cached invocation. In a different color with the args.
-            if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
+            if not ell.util.closure.has_closured_function(func_to_track) and config.lazy_versioning:
                 fn_closure, _ = ell.util.closure.lexically_closured_source(func_to_track)

             # compute the state cache key
-            state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
+            lexical_closure = ell.util.closure.get_lexical_closure(func_to_track)
+            state_cache_key = compute_state_cache_key(ipstr, lexical_closure.closure)

             cache_store = func_to_track.__wrapper__.__ell_use_cache__
-            cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key)
+            cached_invocations = cache_store.get_cached_invocations(lexical_closure.hash, state_cache_key)

             if len(cached_invocations) > 0:
@@ -115,12 +112,13 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An
             prompt_tokens=usage.get("prompt_tokens", 0)
             completion_tokens=usage.get("completion_tokens", 0)

-            if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
+            if not ell.util.closure.has_closured_function(func_to_track) and config.lazy_versioning:
                 ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)
             _serialize_lmp(func_to_track)

+            lexical_closure = ell.util.closure.get_lexical_closure(func_to_track)
             if not state_cache_key:
-                state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
+                state_cache_key = compute_state_cache_key(ipstr, lexical_closure.closure)

             _write_invocation(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens,
                               state_cache_key, invocation_kwargs, cleaned_invocation_params, consumes, result, parent_invocation_id)
@@ -131,8 +129,7 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An
     func_to_track.__wrapper__ = tracked_func

-    if hasattr(func_to_track, "__ell_lm_kwargs__"):
-        tracked_func.__ell_lm_kwargs__ = func_to_track.__ell_lm_kwargs__
+    # XXX: Move away from __ private declarations this should be object oriented.
     if hasattr(func_to_track, "__ell_params_model__"):
         tracked_func.__ell_params_model__ = func_to_track.__ell_params_model__

     tracked_func.__ell_func__ = func_to_track
@@ -142,12 +139,13 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An
 def _serialize_lmp(func):
     # Serialize depth-first all of the used lmps.
-    for f in func.__ell_uses__:
+    lexical_closure = ell.util.closure.get_lexical_closure(func)
+    for f in lexical_closure.uses:
         _serialize_lmp(f)

-    if getattr(func, "_has_serialized_lmp", False):
+    if _has_serialized_lmp.get(func, False):
         return
-    func._has_serialized_lmp = False
+    _has_serialized_lmp[func] = True

     fn_closure = func.__ell_closure__
     lmp_type = func.__ell_type__
     name = func.__qualname__
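
The last hunk also swaps the per-function _has_serialized_lmp attribute for the module-level dict declared at the top of the file, keyed by the function object itself. Here is a small sketch of that memoization pattern in isolation; the uses mapping and the print are hypothetical placeholders for ell's real dependency graph and serializer, not the library's API.

# Sketch of the dict-based "already serialized" memoization; the uses mapping
# and the print are placeholders, not ell's actual dependency graph or writer.
from typing import Callable, Dict, List

_has_serialized_lmp: Dict[Callable, bool] = {}

def _serialize_lmp(func: Callable, uses: Dict[Callable, List[Callable]]) -> None:
    # Keying the cache on the function object avoids mutating attributes
    # on user-defined functions.
    if _has_serialized_lmp.get(func, False):
        return
    # Serialize dependencies depth-first before the function itself.
    for dep in uses.get(func, []):
        _serialize_lmp(dep, uses)
    _has_serialized_lmp[func] = True
    print(f"serialized {func.__qualname__}")

def helper(): ...
def pipeline(): ...

_serialize_lmp(pipeline, {pipeline: [helper]})
_serialize_lmp(pipeline, {pipeline: [helper]})  # second call is a no-op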


@@ -30,6 +30,7 @@ def xD():
 """
 import collections
 import ast
+from dataclasses import dataclass
 import hashlib
 import itertools
 import os
@@ -279,14 +280,7 @@ def _generate_function_hash(source, dsrc, qualname):
     """Generate a hash for the function."""
     return "lmp-" + hashlib.md5("\n".join((source, dsrc, qualname)).encode()).hexdigest()

-def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
-    """Update the ell function attributes."""
-    formatted_source = _format_source(source)
-    formatted_dsrc = _format_source(dsrc)
-    if hasattr(outer_ell_func, "__ell_func__"):
-        outer_ell_func.__ell_closure__ = (formatted_source, formatted_dsrc, globals_dict, frees_dict)
-        outer_ell_func.__ell_hash__ = fn_hash
-        outer_ell_func.__ell_uses__ = uses

 def _raise_error(message, exception, recursion_stack):
     """Raise an error with detailed information."""
@@ -505,3 +499,25 @@ def globalvars(func, recurse=True, builtin=False):
     #NOTE: if name not in __globals__, then we skip it...
     return dict((name,globs[name]) for name in func if name in globs)

+def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
+    """Update the ell function attributes."""
+    formatted_source = _format_source(source)
+    formatted_dsrc = _format_source(dsrc)
+    if hasattr(outer_ell_func, "__ell_func__"):
+        function_closures[outer_ell_func] = LexicalClosure(hash=fn_hash, closure=(formatted_source, formatted_dsrc, globals_dict, frees_dict), uses=uses)
+
+@dataclass
+class LexicalClosure:
+    hash: str
+    closure: Tuple[str, str, Dict[str, Any], Dict[str, Any]]
+    uses: Set[str]
+
+# Cache of all the closured function closures.
+function_closures: Dict[Callable, LexicalClosure] = {}
+
+def has_closured_function(func: Callable) -> bool:
+    return func in function_closures
+
+def get_lexical_closure(func: Callable) -> LexicalClosure | None:
+    return function_closures.get(func)
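
The net effect of the closure.py changes is that __ell_closure__, __ell_hash__, and __ell_uses__ stop living on the decorated function and move into the module-level function_closures registry, queried through has_closured_function and get_lexical_closure. Below is a stripped-down, runnable sketch of that registry pattern; register_closure is a hypothetical helper playing the role of _update_ell_func, and the source capture and hashing are simplified stand-ins for ell's lexical-closure machinery.

# Stripped-down sketch of the closure registry; inspect.getsource + md5 stand in
# for ell's real lexical closure computation.
import hashlib
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Set, Tuple

@dataclass
class LexicalClosure:
    hash: str
    closure: Tuple[str, str, Dict[str, Any], Dict[str, Any]]
    uses: Set[str]

# Cache of every computed closure, keyed by the function object itself.
function_closures: Dict[Callable, LexicalClosure] = {}

def register_closure(func: Callable, uses: Optional[Set[str]] = None) -> LexicalClosure:
    source = inspect.getsource(func)
    fn_hash = "lmp-" + hashlib.md5(source.encode()).hexdigest()
    closure = LexicalClosure(hash=fn_hash, closure=(source, "", {}, {}), uses=uses or set())
    function_closures[func] = closure
    return closure

def has_closured_function(func: Callable) -> bool:
    return func in function_closures

def get_lexical_closure(func: Callable) -> Optional[LexicalClosure]:
    return function_closures.get(func)

def greet(name: str) -> str:
    return f"hello {name}"

register_closure(greet)
assert has_closured_function(greet)
print(get_lexical_closure(greet).hash)

One caveat worth noting about this shape: a plain dict keyed by function objects keeps those functions alive for the life of the process, so a weakref.WeakKeyDictionary would be the usual choice if that mattered.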