readability refactor of closure

This commit is contained in:
William Guss
2024-08-02 20:50:21 -07:00
parent 6483ba7ad7
commit 36d4354043
2 changed files with 199 additions and 371 deletions

129
README.md
View File

@@ -128,132 +128,3 @@ You can then visualize your promtps by visiting the frontend on `http://localhos
- Convert all of our todos into issues and milestones
- Multimodality
- Output parsing.
## Todos
### Metrics
- [ ] Design the metrics functionality. (maybe link WandB lol)
### Bugs
- [ ] Fix weird rehashing issue of the main prompt whenever subprompt changes? Or just make commits more of a background deal.
- [x] Trace not writing on first invoc.
- [ ] Serialize lkstrs in the jkson dumps in pyhton the same way as the db serializers them for the frontend (\_\_lstr vs SerialziedLstr) <- these are pydantic models and so we can reuse them
- [ ] handle failure to serialize.
- [ ] Unify cattrs deserialziation and serialization its fucked right now.
- [ ] Fix URL not changing on invocation click
## Tests
- [ ] Add comprehensive unit tests for all core functionalities.
- [ ] Implement integration tests for end-to-end workflows.
- [ ] Optimize backend performance and run benchmarks.
## Trace Functionality
- [x] Visualize trace in graph
- [x] Implement Langsmith-style invocations and traces
- [x] Improve UX on traces
- [o] Complete full trace implementation on invocation page
- [x] Enhance UX around traces in dependency graphs
- [o] Implement argument pass-through functionality
- [ ] Working trace graph thing
## Ell Studio
- [ ] Optimize front-end data fetching & api structure
- [ ] Add websockets for live updates.
- [ ] Add a section for the actual LLM call
## Version History
- [x] Auto-document commit changes
- [x] Implement version history diff view (possibly with automatic commit messages using GPT-4)
- [ ] Add interactive diff view
- [ ] Highlight changes in source when switching versions
- [ ] Automatically format python code on serialization (version doesnt change just because formatting does.)
## Caching and versioning
- [x] Get caching to work for bound global mutable types
- [x] Decide if version not change just because I invoke an LMP using a global that was mutated some where else in the program. (In example where we had test being global with a list of objects of type Test) that shouldnt change the version number. My vote is that it shouldnt change.
- [x] Bind global mutable at invocation, and update our Invocation database types
- [x] Come up with a nice way to expand and collapse stored globals in dependency scope.disc
- [x] Serialize globals free variables at time of definition
- [x] Come up with a nice way to display bound globals of an invocation
- [x] Custom syntax highlighting for these bound mutable objects
- [x] When we click invocation, the global if its overridden at the time execution needs to be updated.
- [x] Decide where the line between program state and program version changes. If we had a global variable which was a list of [“asd”]*10000 do we include the contents of this shit in the version.
- [ ] Better definition disclosure (and hiding) for bound global variables
- [ ] Design or steal vs-codes debug instance widget
- [ ] Bind all arguments to their names in the function signature (pre v1)
- [ ] Compute type lexical closures as needed (post v1)
## LM Functionality
- [x] Freezing and unfreezing LMPs
- [ ] Support multimodal inputs
- [ ] Implement function calling
- [ ] Add persistent chatting capability
- [ ] Integrate with various LLM providers
## Use Cases
- [ ] Develop RAG (Retrieval-Augmented Generation) example
- [ ] Implement embeddings functionality
- [ ] Create examples for tool use and agents
- [ ] Demonstrate Chain of Thought (CoT) reasoning
- [ ] Showcase optimization techniques
## Store
- [ ] Improve developer experience around logging mechanisms
- [ ] Drastically improve query runtime and performance test.
## DX (Developer Experience)
- [x] Enhance UX for the LMP details page
- [ ] Improve layout for the differnet parts
- [ ] Add Dependency Graph on LMP page
- [ ] Implement VSCode-style file explorer
- [ ] Ensure and test Jupyter compatibility
- [ ] Continue UI/UX improvements for the visualization component
- [x] Update LMP Details to be function-based for easier result viewing
- [ ] Implement easy navigation similar to VSCode (cmd+shift+p or spotlight)
- [x] Optimize display of dependencies on prompt pages
- [ ] Automatic API key management if not specified as a nice experience. Should ask for the api key and store it in the ~/.ell/api_keys directory for the user's convenience.
- [ ] Need a cli for managing api keys etc
## Packaging
- [ ] Write comprehensive documentation
- [x] Prepare package for distribution
- [ ] Refine and organize examples
- [x] Create production build for ell studio
- [ ] Draft contribution guidelines
- [ ] Document the code
## Misc
- [ ] Commit cosine similarity graph for browsing.
- [ ] Profile closures code.
- [ ] Add max height to the dependencies.
- [ ] Source code graph refractor (source & dependencies should be json tree of source code)
- [ ] Implement metric tracking system
- [ ] Add built-ins for classifiers (e.g., logit debiasing)
- [ ] Develop evaluator framework
- [ ] Create timeline visualization
- [ ] Implement comment system
- [ ] Add easy-to-use human evaluation tools
- [ ] Implement keyboard shortcuts for navigating invocations
- [ ] Ensure all components are linkable
- [ ] Add comparison mode for language models and double-blind setup for evaluations
- [ ] Integrate AI-assisted evaluations and metrics
- [ ] Consider developing as a VSCode plugin
- [ ] Implement organization system for multiple prompts (e.g., by module)
- [ ] Add live updates and new content indicators
- [ ] Force stores ot use the pydantic data types and dont use model dumping by default.
- [ ] Update the serializer so that we prefer stringliterals when serialziing globals
- [x] Update stores to use schema type hints and serialize to model dump in Flask (or consider switching to FastAPI)

View File

@@ -31,27 +31,215 @@ def xD():
import collections
import ast
import hashlib
import json
import os
from typing import Any, Dict, Set, Tuple
from typing import Any, Dict, Set, Tuple, Callable
import dill
import inspect
import types
from dill.source import getsource
import importlib.util
import ast
from collections import deque
import inspect
import dill.source
import re
from collections import deque
DELIM = "$$$$$$$$$$$$$$$$$$$$$$$$$"
SEPERATOR = "#------------------------"
FORBIDDEN_NAMES = ["ell", "lstr"]
def lexical_closure(
func: Any,
already_closed: Set[int] = None,
initial_call: bool = False,
recursion_stack: list = None
) -> Tuple[str, Tuple[str, str], Set[str]]:
"""
Generate a lexical closure for a given function or callable.
Args:
func: The function or callable to process.
already_closed: Set of already processed function hashes.
initial_call: Whether this is the initial call to the function.
recursion_stack: Stack to keep track of the recursion path.
Returns:
A tuple containing:
- The full source code of the closure
- A tuple of (function source, dependencies source)
- A set of function hashes that this closure uses
"""
already_closed = already_closed or set()
uses = set()
recursion_stack = recursion_stack or []
if hash(func) in already_closed:
return "", ("", ""), set()
recursion_stack.append(getattr(func, '__qualname__', str(func)))
outer_ell_func = func
while hasattr(func, "__ell_func__"):
func = func.__ell_func__
source = getsource(func, lstrip=True)
already_closed.add(hash(func))
globals_and_frees = _get_globals_and_frees(func)
dependencies, imports, modules = _process_dependencies(func, globals_and_frees, already_closed, recursion_stack)
cur_src = _build_initial_source(imports, dependencies, source)
module_src = _process_modules(modules, cur_src, already_closed, recursion_stack, uses)
dirty_src = _build_final_source(imports, module_src, dependencies, source)
dirty_src_without_func = _build_final_source(imports, module_src, dependencies, "")
CLOSURE_SOURCE[hash(func)] = dirty_src
dsrc = _clean_src(dirty_src_without_func)
fn_hash = _generate_function_hash(source, dsrc, func.__qualname__)
_update_ell_func(outer_ell_func, source, dsrc, globals_and_frees['globals'], globals_and_frees['frees'], fn_hash, uses)
return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))
def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
"""Get global and free variables for a function."""
globals_dict = collections.OrderedDict(dill.detect.globalvars(func))
frees_dict = collections.OrderedDict(dill.detect.freevars(func))
if isinstance(func, type):
for name, method in collections.OrderedDict(func.__dict__).items():
if isinstance(method, (types.FunctionType, types.MethodType)):
globals_dict.update(collections.OrderedDict(dill.detect.globalvars(method)))
frees_dict.update(collections.OrderedDict(dill.detect.freevars(method)))
return {'globals': globals_dict, 'frees': frees_dict}
def _process_dependencies(func, globals_and_frees, already_closed, recursion_stack):
"""Process function dependencies."""
dependencies = []
modules = deque()
imports = []
if isinstance(func, (types.FunctionType, types.MethodType)):
_process_default_kwargs(func, dependencies, already_closed, recursion_stack)
for var_name, var_value in {**globals_and_frees['globals'], **globals_and_frees['frees']}.items():
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack)
return dependencies, imports, modules
def _process_default_kwargs(func, dependencies, already_closed, recursion_stack):
"""Process default keyword arguments of a function."""
ps = inspect.signature(func).parameters
default_kwargs = collections.OrderedDict({k: v.default for k, v in ps.items() if v.default is not inspect.Parameter.empty})
for name, val in default_kwargs.items():
if name not in FORBIDDEN_NAMES:
try:
dep, _, _ = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
dependencies.append(dep)
except Exception as e:
_raise_error(f"Failed to capture the lexical closure of default parameter {name}", e, recursion_stack)
def _process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack):
"""Process a single variable."""
if isinstance(var_value, (types.FunctionType, type, types.MethodType)):
_process_callable(var_name, var_value, dependencies, already_closed, recursion_stack)
elif isinstance(var_value, types.ModuleType):
_process_module(var_name, var_value, modules, imports)
elif isinstance(var_value, types.BuiltinFunctionType):
imports.append(dill.source.getimport(var_value, alias=var_name))
else:
_process_other_variable(var_name, var_value, dependencies)
def _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack):
"""Process a callable (function, method, or class)."""
if var_name not in FORBIDDEN_NAMES:
try:
dep, _, _ = lexical_closure(var_value, already_closed=already_closed, recursion_stack=recursion_stack.copy())
dependencies.append(dep)
except Exception as e:
_raise_error(f"Failed to capture the lexical closure of global or free variable {var_name}", e, recursion_stack)
def _process_module(var_name, var_value, modules, imports):
"""Process a module."""
if should_import(var_value):
imports.append(dill.source.getimport(var_value, alias=var_name))
else:
modules.append((var_name, var_value))
def _process_other_variable(var_name, var_value, dependencies):
"""Process variables that are not callables or modules."""
if isinstance(var_value, str) and '\n' in var_value:
dependencies.append(f"{var_name} = '''{var_value}'''")
elif is_immutable_variable(var_value):
dependencies.append(f"#<BV>\n{var_name} = {repr(var_value)}\n#</BV>")
else:
dependencies.append(f"#<BmV>\n{var_name} = <{type(var_value).__name__} object>\n#</BmV>")
def _build_initial_source(imports, dependencies, source):
"""Build the initial source code."""
return f"{DELIM}\n" + f"\n{DELIM}\n".join(imports + dependencies + [source]) + f"\n{DELIM}\n"
def _process_modules(modules, cur_src, already_closed, recursion_stack, uses):
"""Process module dependencies."""
reverse_module_src = deque()
while modules:
mname, mval = modules.popleft()
mdeps = []
attrs_to_extract = get_referenced_names(cur_src.replace(DELIM, ""), mname)
for attr in attrs_to_extract:
_process_module_attribute(mname, mval, attr, mdeps, modules, already_closed, recursion_stack, uses)
mdeps.insert(0, f"# Extracted from module: {mname}")
reverse_module_src.appendleft("\n".join(mdeps))
cur_src = _dereference_module_names(cur_src, mname, attrs_to_extract)
return list(reverse_module_src)
def _process_module_attribute(mname, mval, attr, mdeps, modules, already_closed, recursion_stack, uses):
"""Process a single attribute of a module."""
val = getattr(mval, attr)
if isinstance(val, (types.FunctionType, type, types.MethodType)):
try:
dep, _, dep_uses = lexical_closure(val, already_closed=already_closed, recursion_stack=recursion_stack.copy())
mdeps.append(dep)
uses.update(dep_uses)
except Exception as e:
_raise_error(f"Failed to capture the lexical closure of {mname}.{attr}", e, recursion_stack)
elif isinstance(val, types.ModuleType):
modules.append((attr, val))
else:
mdeps.append(f"{attr} = {repr(val)}")
def _dereference_module_names(cur_src, mname, attrs_to_extract):
"""Dereference module names in the source code."""
for attr in attrs_to_extract:
cur_src = cur_src.replace(f"{mname}.{attr}", attr)
return cur_src
def _build_final_source(imports, module_src, dependencies, source):
"""Build the final source code."""
seperated_dependencies = sorted(imports) + sorted(module_src) + sorted(dependencies) + ([source] if source else [])
seperated_dependencies = list(dict.fromkeys(seperated_dependencies))
return DELIM + "\n" + f"\n{DELIM}\n".join(seperated_dependencies) + "\n" + DELIM + "\n"
def _generate_function_hash(source, dsrc, qualname):
"""Generate a hash for the function."""
return "lmp-" + hashlib.md5("\n".join((source, dsrc, qualname)).encode()).hexdigest()
def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
"""Update the ell function attributes."""
if hasattr(outer_ell_func, "__ell_func__"):
outer_ell_func.__ell_closure__ = (source, dsrc, globals_dict, frees_dict)
outer_ell_func.__ell_hash__ = fn_hash
outer_ell_func.__ell_uses__ = uses
def _raise_error(message, exception, recursion_stack):
"""Raise an error with detailed information."""
error_msg = f"{message}. Error: {str(exception)}\n"
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
raise Exception(error_msg)
def is_immutable_variable(value):
"""
Check if a value is immutable.
@@ -89,7 +277,6 @@ def is_immutable_variable(value):
return False
def should_import(module: types.ModuleType):
"""
This function checks if a module should be imported based on its origin.
@@ -118,10 +305,6 @@ def should_import(module: types.ModuleType):
# Otherwise, return True
return True
import ast
def get_referenced_names(code: str, module_name: str):
"""
This function takes a block of code and a module name as input. It parses the code into an Abstract Syntax Tree (AST)
@@ -144,232 +327,8 @@ def get_referenced_names(code: str, module_name: str):
return referenced_names
CLOSURE_SOURCE: Dict[str, str] = {}
def lexical_closure(func: Any, already_closed=None, initial_call=False, recursion_stack=None) -> Tuple[str, Tuple[str, str], Set[str]]:
"""
This function takes a function or any callable as input and returns a string representation of its lexical closure.
The lexical closure includes the source code of the function itself, as well as the source code of any global variables,
free variables, and dependencies (other functions or classes) that it references.
If the input function is a method of a class, this function will also find and include the source code of other methods
in the same class that are referenced by the input function.
The resulting string can be used to recreate the function in a different context, with all of its dependencies.
Parameters:
func (Callable): The function or callable whose lexical closure is to be found.
already_closed (set): Set of already processed functions to avoid infinite recursion.
initial_call (bool): Whether this is the initial call to the function.
recursion_stack (list): Stack to keep track of the recursion path.
Returns:
str: A string representation of the lexical closure of the input function.
"""
already_closed = already_closed or set()
uses = set()
recursion_stack = recursion_stack or []
if hash(func) in already_closed:
return "", ("", ""), {}
recursion_stack.append(func.__qualname__ if hasattr(func, '__qualname__') else str(func))
outer_ell_func = func
while hasattr(func, "__ell_func__"):
func = func.__ell_func__
source = getsource(func, lstrip=True)
already_closed.add(hash(func))
# if func is nested func
# Parse the source code into an AST
# tree = ast.parse(source)
# Find all the global variables and free variables in the function
# These are not global variables these are globals, and other shit is actualy in cluded here
_globals = collections.OrderedDict(dill.detect.globalvars(func))
print(_globals)
_frees = collections.OrderedDict(dill.detect.freevars(func))
# If func is a class we actually should check all the methods of the class for globalvars. Malekdiction (MSM) was here.
# Add the default aprameter tpes to depndencies if they are not builtins
if isinstance(func, type):
# Now we need to get all the global vars in the class
for name, method in collections.OrderedDict(func.__dict__).items():
if isinstance(method, types.FunctionType) or isinstance(
method, types.MethodType
):
_globals.update(
collections.OrderedDict(dill.detect.globalvars(method))
)
_frees.update(collections.OrderedDict(dill.detect.freevars(method)))
# Initialize a list to store the source code of the dependencies
dependencies = []
modules = deque()
imports = []
if isinstance(func, (types.FunctionType, types.MethodType)):
# Get all the the default kwargs
ps = inspect.signature(func).parameters
default_kwargs = collections.OrderedDict(
{
k: v.default
for k, v in ps.items()
if v.default is not inspect.Parameter.empty
}
)
for name, val in default_kwargs.items():
try:
if name not in FORBIDDEN_NAMES:
dep, _, dep_uses = lexical_closure(
type(val), already_closed=already_closed,
recursion_stack=recursion_stack.copy()
)
dependencies.append(dep)
uses.update(dep_uses)
except Exception as e:
error_msg = f"Failed to capture the lexical closure of default parameter {name}. Error: {str(e)}\n"
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
raise Exception(error_msg)
# Iterate over the global variables
for var_name, var_value in {**_globals, **_frees}.items():
is_free = var_name in _frees
# If the variable is a function, get its source code
if isinstance(var_value, (types.FunctionType, type, types.MethodType)):
if var_name not in FORBIDDEN_NAMES:
try:
ret = lexical_closure(
var_value, already_closed=already_closed,
recursion_stack=recursion_stack.copy()
)
dep, _, dep_uses = ret
dependencies.append(dep)
# See if the function was called at all in the source code of the func
# This is wrong because if its a referred call it won't track the dependency; so we actually need to trace all dependencies that are not ell funcs to see if they call it as well.
if is_function_called(var_name, source):
uses.update(dep_uses)
except Exception as e:
error_msg = f"Failed to capture the lexical closure of global or free variabl evariable {var_name}. Error: {str(e)}\n"
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
raise Exception(error_msg)
elif isinstance(var_value, types.ModuleType):
if should_import(var_value):
imports += [dill.source.getimport(var_value, alias=var_name)]
else:
# Now we need to find all the variables in this module that were referenced
modules.append((var_name, var_value))
elif isinstance(var_value, types.BuiltinFunctionType):
# we need to get an import for it
imports += [dill.source.getimport(var_value, alias=var_name)]
else:
json_default = lambda x: f"<Object of type {type(x).__name__}>"
if isinstance(var_value, str) and '\n' in var_value:
dependencies.append(f"{var_name} = '''{var_value}'''")
else:
# if is immutable
if is_immutable_variable(var_value) and not is_free:
dependencies.append(f"#<BV>\n{var_name} = {repr(var_value)}\n#</BV>")
else:
dependencies.append(f"#<BmV>\n{var_name} = <{type(var_value).__name__} object>\n#</BmV>")
# We probably need to resovle things with topological sort & turn stuff into a dag but for now we can just do this
cur_src = (
DELIM
+ "\n"
+ f"\n{DELIM}\n".join(imports + dependencies)
+ "\n"
+ DELIM
+ "\n"
+ source
+ "\n"
+ DELIM
+ "\n"
)
reverse_module_src = deque()
while len(modules) > 0:
mname, mval = modules.popleft()
mdeps = []
attrs_to_extract = get_referenced_names(cur_src.replace(DELIM, ""), mname)
for attr in attrs_to_extract:
val = getattr(mval, attr)
if isinstance(val, (types.FunctionType, type, types.MethodType)):
try:
dep, _, dep_uses = lexical_closure(
val, already_closed=already_closed,
recursion_stack=recursion_stack.copy()
)
mdeps.append(dep)
uses.update(dep_uses)
except Exception as e:
error_msg = f"Failed to capture the lexical closure of {mname}.{attr}. Error: {str(e)}\n"
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
raise Exception(error_msg)
elif isinstance(val, types.ModuleType):
modules.append((attr, val))
else:
# If its another module we need to add it to the list of modules
mdeps.append(f"{attr} = {repr(val)}")
mdeps.insert(0, f"# Extracted from module: {mname}")
# Now let's dereference all the module names in our cur_src
for attr in attrs_to_extract:
# Go throught hte dependencies and replace all the module names with the attr
source = source.replace(f"{mname}.{attr}", attr)
dependencies = [
dep.replace(f"{mname}.{attr}", attr) for dep in dependencies
]
# Now add all the module dependencies to the top of the list
reverse_module_src.appendleft("\n".join(mdeps))
# Now we need to add the module dependencies to the top of the source
# Sort the dependencies
dependencies = sorted(dependencies)
imports = sorted(imports)
reverse_module_src = sorted(reverse_module_src)
seperated_dependencies = (
imports
+ list(reverse_module_src)
+ dependencies
+ [source]
)
# Remove duplicates and preserve order
seperated_dependencies = list(dict.fromkeys(seperated_dependencies))
dirty_src = DELIM + "\n" + f"\n{DELIM}\n".join(seperated_dependencies) + "\n" + DELIM + "\n"
dirty_src_without_func = DELIM + "\n" + f"\n{DELIM}\n".join(seperated_dependencies[:-1]) + "\n" + DELIM + "\n"
CLOSURE_SOURCE[hash(func)] = dirty_src
dsrc = _clean_src(dirty_src_without_func)
fn_hash = "lmp-" + hashlib.md5(
"\n".join((source, dsrc, func.__qualname__)).encode()
).hexdigest()
if hasattr(outer_ell_func, "__ell_func__"):
outer_ell_func.__ell_closure__ = (source, dsrc, _globals, _frees)
outer_ell_func.__ell_hash__ = fn_hash
outer_ell_func.__ell_uses__ = uses
return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))
def lexically_closured_source(func):
_, fnclosure, uses = lexical_closure(func, initial_call=True, recursion_stack=[])
return fnclosure, uses
@@ -377,7 +336,6 @@ def lexically_closured_source(func):
import ast
def _clean_src(dirty_src):
# Now remove all duplicates and preserve order
split_by_setion = filter(lambda x: len(x.strip()) > 0, dirty_src.split(DELIM))
@@ -401,7 +359,6 @@ def _clean_src(dirty_src):
return final_src
def is_function_called(func_name, source_code):
"""
Check if a function is called in the given source code.