Commit "i give up" — mirror of https://github.com/MadcowD/ell.git
@@ -52,15 +52,16 @@ def lm(model: str, client: Optional[openai.Client] = None, exempt_from_tracking=
             return tracked_str, api_params, metadata

         # TODO: we'll deal with type safety here later
         model_call.__ell_lm_kwargs__ = lm_kwargs
         # XXX: Do we need intermediate params?
         model_call.__ell_func__ = prompt
         model_call.__ell_type__ = LMPType.LM
         model_call.__ell_exempt_from_tracking = exempt_from_tracking

         if exempt_from_tracking:
             return model_call
         else:
-            return track(model_call, forced_dependencies=dict(tools=tools))
+            return track(model_call, forced_dependencies=dict(tools=tools), lmp_type=LMPType.LM, lm_kwargs=lm_kwargs)

     return parameterized_lm_decorator
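What this hunk changes, sketched outside the diff: instead of having track() read metadata back off dunder attributes, the decorator now hands lmp_type and lm_kwargs to track() directly. Below is a minimal, self-contained sketch of that pattern; the names simple_track and make_lm are illustrative stand-ins, not ell's API.

    from functools import wraps
    from typing import Any, Callable, Dict, Optional

    def simple_track(func: Callable, *, lmp_type: str = "OTHER",
                     lm_kwargs: Optional[Dict[str, Any]] = None) -> Callable:
        # Metadata arrives as explicit arguments; nothing is stashed on func.
        @wraps(func)
        def tracked(*args, **kwargs):
            print(f"invoking {func.__qualname__} ({lmp_type}) with {lm_kwargs}")
            return func(*args, **kwargs)
        return tracked

    def make_lm(model: str, **lm_kwargs):
        def decorator(prompt: Callable) -> Callable:
            return simple_track(prompt, lmp_type="LM",
                                lm_kwargs={"model": model, **lm_kwargs})
        return decorator

    @make_lm("gpt-4o", temperature=0.1)
    def greet(name: str) -> str:
        return f"Say hello to {name}."

    greet("world")   # prints the tracking line, then returns the prompt string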
@@ -44,19 +44,15 @@ def exclude_var(v):
     # is module or is immutable
     return inspect.ismodule(v)

-def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None) -> Callable:
-
-    lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER)
+_has_serialized_lmp = {}
+_lmp_hash = {}
+
+def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None, lm_kwargs: Optional[Dict[str, Any]] = None, lmp_type: Optional[LMPType] = LMPType.OTHER) -> Callable:

-    # see if it exists
-    if not hasattr(func_to_track, "_has_serialized_lmp"):
-        func_to_track._has_serialized_lmp = False
-
-    if not hasattr(func_to_track, "__ell_hash__") and not config.lazy_versioning:
+    if not ell.util.closure.has_closured_function(func_to_track) and not config.lazy_versioning:
         ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)

     @wraps(func_to_track)
     def tracked_func(*fn_args, **fn_kwargs) -> str:
         # XXX: Cache keys and global variable binding is not thread safe.
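The bookkeeping change above moves per-function flags off the function object (func._has_serialized_lmp) into module-level dicts keyed by the function itself. A minimal sketch of that registry pattern follows; the WeakKeyDictionary is my suggestion for letting entries die with their functions, not something this commit uses.

    import weakref
    from typing import Callable

    # Keyed by the function object itself; weak keys avoid leaking entries
    # after a function is garbage collected (an assumption, not ell's choice).
    _has_serialized_lmp: "weakref.WeakKeyDictionary[Callable, bool]" = weakref.WeakKeyDictionary()

    def mark_serialized(func: Callable) -> None:
        _has_serialized_lmp[func] = True

    def is_serialized(func: Callable) -> bool:
        # .get() avoids hasattr() probing and never raises KeyError.
        return _has_serialized_lmp.get(func, False)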
@@ -76,14 +72,15 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An

         if try_use_cache:
             # TODO: if verbose, log nicely (in a different color, with the args) when a cached invocation is used.
-            if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
+            if not ell.util.closure.has_closured_function(func_to_track) and config.lazy_versioning:
                 fn_closure, _ = ell.util.closure.lexically_closured_source(func_to_track)

             # compute the state cache key
-            state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
+            lexical_closure = ell.util.closure.get_lexical_closure(func_to_track)
+            state_cache_key = compute_state_cache_key(ipstr, lexical_closure.closure)

             cache_store = func_to_track.__wrapper__.__ell_use_cache__
-            cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key)
+            cached_invocations = cache_store.get_cached_invocations(lexical_closure.hash, state_cache_key)

             if len(cached_invocations) > 0:
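For orientation, here is a toy version of what a state cache key over the closure tuple can look like: hash the rendered input together with the captured code and state, so a cache hit requires both the same arguments and the same surrounding definitions. This compute_state_cache_key is a stand-in consistent with the (source, dsrc, globals, frees) tuple used above, not ell's actual implementation.

    import hashlib
    from typing import Any, Dict, Tuple

    Closure = Tuple[str, str, Dict[str, Any], Dict[str, Any]]

    def compute_state_cache_key(ipstr: str, closure: Closure) -> str:
        source, dsrc, globals_dict, frees_dict = closure
        # Same inputs AND same captured code/state => same key => cache hit.
        state = "\n".join((ipstr, source, dsrc,
                           repr(sorted(globals_dict)), repr(sorted(frees_dict))))
        return hashlib.md5(state.encode()).hexdigest()

    key = compute_state_cache_key("hello", ("def f(): ...", "", {"X": 1}, {}))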
@@ -115,12 +112,13 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An
             prompt_tokens=usage.get("prompt_tokens", 0)
             completion_tokens=usage.get("completion_tokens", 0)

-        if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
+        if not ell.util.closure.has_closured_function(func_to_track) and config.lazy_versioning:
             ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)
         _serialize_lmp(func_to_track)

+        lexical_closure = ell.util.closure.get_lexical_closure(func_to_track)
         if not state_cache_key:
-            state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
+            state_cache_key = compute_state_cache_key(ipstr, lexical_closure.closure)

         _write_invocation(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens,
                           state_cache_key, invocation_kwargs, cleaned_invocation_params, consumes, result, parent_invocation_id)
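The fields threaded through _write_invocation above, collected in one place for readability. The dataclass is purely illustrative of the record's shape as implied by the call site; ell's real storage schema may differ.

    from dataclasses import dataclass, field
    from typing import Any, Dict, List, Optional

    @dataclass
    class InvocationRecord:
        invocation_id: str
        latency_ms: float
        prompt_tokens: int
        completion_tokens: int
        state_cache_key: str
        invocation_kwargs: Dict[str, Any] = field(default_factory=dict)
        cleaned_invocation_params: Dict[str, Any] = field(default_factory=dict)
        consumes: List[str] = field(default_factory=list)
        parent_invocation_id: Optional[str] = None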
@@ -131,8 +129,7 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An

     func_to_track.__wrapper__ = tracked_func
-    if hasattr(func_to_track, "__ell_lm_kwargs__"):
-        tracked_func.__ell_lm_kwargs__ = func_to_track.__ell_lm_kwargs__
+    # XXX: Move away from __-private declarations; this should be object oriented.
     if hasattr(func_to_track, "__ell_params_model__"):
         tracked_func.__ell_params_model__ = func_to_track.__ell_params_model__
     tracked_func.__ell_func__ = func_to_track
@@ -142,12 +139,13 @@ def track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, An

 def _serialize_lmp(func):
     # Serialize depth-first all of the used LMPs.
-    for f in func.__ell_uses__:
+    lexical_closure = ell.util.closure.get_lexical_closure(func)
+    for f in lexical_closure.uses:
         _serialize_lmp(f)

-    if getattr(func, "_has_serialized_lmp", False):
+    if _has_serialized_lmp.get(func, False):
         return
-    func._has_serialized_lmp = False
+    _has_serialized_lmp[func] = True
     fn_closure = func.__ell_closure__
     lmp_type = func.__ell_type__
     name = func.__qualname__
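The traversal _serialize_lmp implements, reduced to its shape: visit dependencies first, and use a shared map so each function is serialized exactly once. In this sketch the uses_of lookup and the emit callback are hypothetical stand-ins, and the guard is checked before recursing, which also protects against cyclic uses.

    from typing import Callable, Dict, List

    _seen: Dict[Callable, bool] = {}

    def serialize(func: Callable, uses_of: Dict[Callable, List[Callable]],
                  emit: Callable) -> None:
        if _seen.get(func, False):
            return                            # already serialized via another path
        _seen[func] = True                    # mark before recursing: cycle-safe
        for dep in uses_of.get(func, []):     # dependencies first (post-order emit)
            serialize(dep, uses_of, emit)
        emit(func)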
@@ -30,6 +30,7 @@ def xD():
 """
 import collections
 import ast
+from dataclasses import dataclass
 import hashlib
 import itertools
 import os
@@ -279,14 +280,7 @@ def _generate_function_hash(source, dsrc, qualname):
     """Generate a hash for the function."""
     return "lmp-" + hashlib.md5("\n".join((source, dsrc, qualname)).encode()).hexdigest()

-def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
-    """Update the ell function attributes."""
-    formatted_source = _format_source(source)
-    formatted_dsrc = _format_source(dsrc)
-    if hasattr(outer_ell_func, "__ell_func__"):
-        outer_ell_func.__ell_closure__ = (formatted_source, formatted_dsrc, globals_dict, frees_dict)
-        outer_ell_func.__ell_hash__ = fn_hash
-        outer_ell_func.__ell_uses__ = uses
-

 def _raise_error(message, exception, recursion_stack):
     """Raise an error with detailed information."""
@@ -505,3 +499,25 @@ def globalvars(func, recurse=True, builtin=False):
     #NOTE: if name not in __globals__, then we skip it...
     return dict((name,globs[name]) for name in func if name in globs)

+
+def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
+    """Update the ell function attributes."""
+    formatted_source = _format_source(source)
+    formatted_dsrc = _format_source(dsrc)
+    if hasattr(outer_ell_func, "__ell_func__"):
+        function_closures[outer_ell_func] = LexicalClosure(hash=fn_hash, closure=(formatted_source, formatted_dsrc, globals_dict, frees_dict), uses=uses)
+
+
+@dataclass
+class LexicalClosure:
+    hash: str
+    closure: Tuple[str, str, Dict[str, Any], Dict[str, Any]]
+    uses: Set[str]
+
+# Cache of the lexical closures of all closured functions.
+function_closures: Dict[Callable, LexicalClosure] = {}
+
+
+def has_closured_function(func: Callable) -> bool:
+    return func in function_closures
+
+
+def get_lexical_closure(func: Callable) -> LexicalClosure | None:
+    return function_closures.get(func)
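A usage sketch of the new registry, reusing LexicalClosure, function_closures, has_closured_function, and get_lexical_closure as added above. The hand-filled entry stands in for what lexically_closured_source would record as a side effect; the hash and source strings are made up for illustration.

    def my_func():
        return 42

    # What lexically_closured_source would record for my_func, filled by hand here.
    function_closures[my_func] = LexicalClosure(
        hash="lmp-deadbeef",
        closure=("def my_func():\n    return 42", "", {}, {}),
        uses=set(),
    )

    assert has_closured_function(my_func)
    lc = get_lexical_closure(my_func)                  # LexicalClosure | None
    source, dsrc, globals_dict, frees_dict = lc.closure
    print(lc.hash, lc.uses)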