mirror of
https://github.com/MadcowD/ell.git
synced 2024-09-22 16:14:36 +03:00
readability refactor of closure
This commit is contained in:
129
README.md
129
README.md
@@ -128,132 +128,3 @@ You can then visualize your promtps by visiting the frontend on `http://localhos
|
||||
- Convert all of our todos into issues and milestones
|
||||
- Multimodality
|
||||
- Output parsing.
|
||||
|
||||
## Todos
|
||||
|
||||
### Metrics
|
||||
|
||||
- [ ] Design the metrics functionality. (maybe link WandB lol)
|
||||
|
||||
### Bugs
|
||||
|
||||
- [ ] Fix weird rehashing issue of the main prompt whenever subprompt changes? Or just make commits more of a background deal.
|
||||
- [x] Trace not writing on first invoc.
|
||||
- [ ] Serialize lkstrs in the jkson dumps in pyhton the same way as the db serializers them for the frontend (\_\_lstr vs SerialziedLstr) <- these are pydantic models and so we can reuse them
|
||||
- [ ] handle failure to serialize.
|
||||
- [ ] Unify cattrs deserialziation and serialization its fucked right now.
|
||||
- [ ] Fix URL not changing on invocation click
|
||||
|
||||
## Tests
|
||||
|
||||
- [ ] Add comprehensive unit tests for all core functionalities.
|
||||
- [ ] Implement integration tests for end-to-end workflows.
|
||||
- [ ] Optimize backend performance and run benchmarks.
|
||||
|
||||
## Trace Functionality
|
||||
|
||||
- [x] Visualize trace in graph
|
||||
- [x] Implement Langsmith-style invocations and traces
|
||||
- [x] Improve UX on traces
|
||||
- [o] Complete full trace implementation on invocation page
|
||||
- [x] Enhance UX around traces in dependency graphs
|
||||
- [o] Implement argument pass-through functionality
|
||||
- [ ] Working trace graph thing
|
||||
|
||||
## Ell Studio
|
||||
- [ ] Optimize front-end data fetching & api structure
|
||||
- [ ] Add websockets for live updates.
|
||||
- [ ] Add a section for the actual LLM call
|
||||
|
||||
## Version History
|
||||
|
||||
- [x] Auto-document commit changes
|
||||
- [x] Implement version history diff view (possibly with automatic commit messages using GPT-4)
|
||||
- [ ] Add interactive diff view
|
||||
- [ ] Highlight changes in source when switching versions
|
||||
- [ ] Automatically format python code on serialization (version doesnt change just because formatting does.)
|
||||
|
||||
## Caching and versioning
|
||||
|
||||
- [x] Get caching to work for bound global mutable types
|
||||
- [x] Decide if version not change just because I invoke an LMP using a global that was mutated some where else in the program. (In example where we had test being global with a list of objects of type Test) that shouldn’t change the version number. My vote is that it shouldn’t change.
|
||||
- [x] Bind global mutable at invocation, and update our Invocation database types
|
||||
- [x] Come up with a nice way to expand and collapse stored globals in dependency scope.disc
|
||||
- [x] Serialize globals free variables at time of definition
|
||||
- [x] Come up with a nice way to display bound globals of an invocation
|
||||
- [x] Custom syntax highlighting for these bound mutable objects
|
||||
- [x] When we click invocation, the global if its overridden at the time execution needs to be updated.
|
||||
- [x] Decide where the line between program state and program version changes. If we had a global variable which was a list of [“asd”]*10000 do we include the contents of this shit in the version.
|
||||
|
||||
- [ ] Better definition disclosure (and hiding) for bound global variables
|
||||
- [ ] Design or steal vs-code’s debug instance widget
|
||||
- [ ] Bind all arguments to their names in the function signature (pre v1)
|
||||
- [ ] Compute type lexical closures as needed (post v1)
|
||||
|
||||
## LM Functionality
|
||||
|
||||
- [x] Freezing and unfreezing LMPs
|
||||
- [ ] Support multimodal inputs
|
||||
- [ ] Implement function calling
|
||||
- [ ] Add persistent chatting capability
|
||||
- [ ] Integrate with various LLM providers
|
||||
|
||||
## Use Cases
|
||||
|
||||
- [ ] Develop RAG (Retrieval-Augmented Generation) example
|
||||
- [ ] Implement embeddings functionality
|
||||
- [ ] Create examples for tool use and agents
|
||||
- [ ] Demonstrate Chain of Thought (CoT) reasoning
|
||||
- [ ] Showcase optimization techniques
|
||||
|
||||
## Store
|
||||
|
||||
- [ ] Improve developer experience around logging mechanisms
|
||||
- [ ] Drastically improve query runtime and performance test.
|
||||
|
||||
## DX (Developer Experience)
|
||||
|
||||
- [x] Enhance UX for the LMP details page
|
||||
- [ ] Improve layout for the differnet parts
|
||||
- [ ] Add Dependency Graph on LMP page
|
||||
- [ ] Implement VSCode-style file explorer
|
||||
- [ ] Ensure and test Jupyter compatibility
|
||||
- [ ] Continue UI/UX improvements for the visualization component
|
||||
- [x] Update LMP Details to be function-based for easier result viewing
|
||||
- [ ] Implement easy navigation similar to VSCode (cmd+shift+p or spotlight)
|
||||
- [x] Optimize display of dependencies on prompt pages
|
||||
|
||||
- [ ] Automatic API key management if not specified as a nice experience. Should ask for the api key and store it in the ~/.ell/api_keys directory for the user's convenience.
|
||||
- [ ] Need a cli for managing api keys etc
|
||||
|
||||
## Packaging
|
||||
|
||||
- [ ] Write comprehensive documentation
|
||||
- [x] Prepare package for distribution
|
||||
- [ ] Refine and organize examples
|
||||
- [x] Create production build for ell studio
|
||||
- [ ] Draft contribution guidelines
|
||||
- [ ] Document the code
|
||||
|
||||
|
||||
## Misc
|
||||
- [ ] Commit cosine similarity graph for browsing.
|
||||
- [ ] Profile closures code.
|
||||
- [ ] Add max height to the dependencies.
|
||||
- [ ] Source code graph refractor (source & dependencies should be json tree of source code)
|
||||
- [ ] Implement metric tracking system
|
||||
- [ ] Add built-ins for classifiers (e.g., logit debiasing)
|
||||
- [ ] Develop evaluator framework
|
||||
- [ ] Create timeline visualization
|
||||
- [ ] Implement comment system
|
||||
- [ ] Add easy-to-use human evaluation tools
|
||||
- [ ] Implement keyboard shortcuts for navigating invocations
|
||||
- [ ] Ensure all components are linkable
|
||||
- [ ] Add comparison mode for language models and double-blind setup for evaluations
|
||||
- [ ] Integrate AI-assisted evaluations and metrics
|
||||
- [ ] Consider developing as a VSCode plugin
|
||||
- [ ] Implement organization system for multiple prompts (e.g., by module)
|
||||
- [ ] Add live updates and new content indicators
|
||||
- [ ] Force stores ot use the pydantic data types and dont use model dumping by default.
|
||||
- [ ] Update the serializer so that we prefer stringliterals when serialziing globals
|
||||
- [x] Update stores to use schema type hints and serialize to model dump in Flask (or consider switching to FastAPI)
|
||||
|
||||
@@ -31,27 +31,215 @@ def xD():
|
||||
import collections
|
||||
import ast
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Set, Tuple
|
||||
from typing import Any, Dict, Set, Tuple, Callable
|
||||
import dill
|
||||
import inspect
|
||||
import types
|
||||
from dill.source import getsource
|
||||
|
||||
import importlib.util
|
||||
|
||||
import ast
|
||||
|
||||
from collections import deque
|
||||
import inspect
|
||||
|
||||
import dill.source
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
DELIM = "$$$$$$$$$$$$$$$$$$$$$$$$$"
|
||||
SEPERATOR = "#------------------------"
|
||||
FORBIDDEN_NAMES = ["ell", "lstr"]
|
||||
|
||||
def lexical_closure(
|
||||
func: Any,
|
||||
already_closed: Set[int] = None,
|
||||
initial_call: bool = False,
|
||||
recursion_stack: list = None
|
||||
) -> Tuple[str, Tuple[str, str], Set[str]]:
|
||||
"""
|
||||
Generate a lexical closure for a given function or callable.
|
||||
|
||||
Args:
|
||||
func: The function or callable to process.
|
||||
already_closed: Set of already processed function hashes.
|
||||
initial_call: Whether this is the initial call to the function.
|
||||
recursion_stack: Stack to keep track of the recursion path.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The full source code of the closure
|
||||
- A tuple of (function source, dependencies source)
|
||||
- A set of function hashes that this closure uses
|
||||
"""
|
||||
already_closed = already_closed or set()
|
||||
uses = set()
|
||||
recursion_stack = recursion_stack or []
|
||||
|
||||
if hash(func) in already_closed:
|
||||
return "", ("", ""), set()
|
||||
|
||||
recursion_stack.append(getattr(func, '__qualname__', str(func)))
|
||||
|
||||
outer_ell_func = func
|
||||
while hasattr(func, "__ell_func__"):
|
||||
func = func.__ell_func__
|
||||
|
||||
source = getsource(func, lstrip=True)
|
||||
already_closed.add(hash(func))
|
||||
|
||||
globals_and_frees = _get_globals_and_frees(func)
|
||||
dependencies, imports, modules = _process_dependencies(func, globals_and_frees, already_closed, recursion_stack)
|
||||
|
||||
cur_src = _build_initial_source(imports, dependencies, source)
|
||||
|
||||
module_src = _process_modules(modules, cur_src, already_closed, recursion_stack, uses)
|
||||
|
||||
dirty_src = _build_final_source(imports, module_src, dependencies, source)
|
||||
dirty_src_without_func = _build_final_source(imports, module_src, dependencies, "")
|
||||
|
||||
CLOSURE_SOURCE[hash(func)] = dirty_src
|
||||
|
||||
dsrc = _clean_src(dirty_src_without_func)
|
||||
fn_hash = _generate_function_hash(source, dsrc, func.__qualname__)
|
||||
|
||||
_update_ell_func(outer_ell_func, source, dsrc, globals_and_frees['globals'], globals_and_frees['frees'], fn_hash, uses)
|
||||
|
||||
return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))
|
||||
|
||||
def _get_globals_and_frees(func: Callable) -> Dict[str, Dict]:
|
||||
"""Get global and free variables for a function."""
|
||||
globals_dict = collections.OrderedDict(dill.detect.globalvars(func))
|
||||
frees_dict = collections.OrderedDict(dill.detect.freevars(func))
|
||||
|
||||
if isinstance(func, type):
|
||||
for name, method in collections.OrderedDict(func.__dict__).items():
|
||||
if isinstance(method, (types.FunctionType, types.MethodType)):
|
||||
globals_dict.update(collections.OrderedDict(dill.detect.globalvars(method)))
|
||||
frees_dict.update(collections.OrderedDict(dill.detect.freevars(method)))
|
||||
|
||||
return {'globals': globals_dict, 'frees': frees_dict}
|
||||
|
||||
def _process_dependencies(func, globals_and_frees, already_closed, recursion_stack):
|
||||
"""Process function dependencies."""
|
||||
dependencies = []
|
||||
modules = deque()
|
||||
imports = []
|
||||
|
||||
if isinstance(func, (types.FunctionType, types.MethodType)):
|
||||
_process_default_kwargs(func, dependencies, already_closed, recursion_stack)
|
||||
|
||||
for var_name, var_value in {**globals_and_frees['globals'], **globals_and_frees['frees']}.items():
|
||||
_process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack)
|
||||
|
||||
return dependencies, imports, modules
|
||||
|
||||
def _process_default_kwargs(func, dependencies, already_closed, recursion_stack):
|
||||
"""Process default keyword arguments of a function."""
|
||||
ps = inspect.signature(func).parameters
|
||||
default_kwargs = collections.OrderedDict({k: v.default for k, v in ps.items() if v.default is not inspect.Parameter.empty})
|
||||
for name, val in default_kwargs.items():
|
||||
if name not in FORBIDDEN_NAMES:
|
||||
try:
|
||||
dep, _, _ = lexical_closure(type(val), already_closed=already_closed, recursion_stack=recursion_stack.copy())
|
||||
dependencies.append(dep)
|
||||
except Exception as e:
|
||||
_raise_error(f"Failed to capture the lexical closure of default parameter {name}", e, recursion_stack)
|
||||
|
||||
def _process_variable(var_name, var_value, dependencies, modules, imports, already_closed, recursion_stack):
|
||||
"""Process a single variable."""
|
||||
if isinstance(var_value, (types.FunctionType, type, types.MethodType)):
|
||||
_process_callable(var_name, var_value, dependencies, already_closed, recursion_stack)
|
||||
elif isinstance(var_value, types.ModuleType):
|
||||
_process_module(var_name, var_value, modules, imports)
|
||||
elif isinstance(var_value, types.BuiltinFunctionType):
|
||||
imports.append(dill.source.getimport(var_value, alias=var_name))
|
||||
else:
|
||||
_process_other_variable(var_name, var_value, dependencies)
|
||||
|
||||
def _process_callable(var_name, var_value, dependencies, already_closed, recursion_stack):
|
||||
"""Process a callable (function, method, or class)."""
|
||||
if var_name not in FORBIDDEN_NAMES:
|
||||
try:
|
||||
dep, _, _ = lexical_closure(var_value, already_closed=already_closed, recursion_stack=recursion_stack.copy())
|
||||
dependencies.append(dep)
|
||||
except Exception as e:
|
||||
_raise_error(f"Failed to capture the lexical closure of global or free variable {var_name}", e, recursion_stack)
|
||||
|
||||
def _process_module(var_name, var_value, modules, imports):
|
||||
"""Process a module."""
|
||||
if should_import(var_value):
|
||||
imports.append(dill.source.getimport(var_value, alias=var_name))
|
||||
else:
|
||||
modules.append((var_name, var_value))
|
||||
|
||||
def _process_other_variable(var_name, var_value, dependencies):
|
||||
"""Process variables that are not callables or modules."""
|
||||
if isinstance(var_value, str) and '\n' in var_value:
|
||||
dependencies.append(f"{var_name} = '''{var_value}'''")
|
||||
elif is_immutable_variable(var_value):
|
||||
dependencies.append(f"#<BV>\n{var_name} = {repr(var_value)}\n#</BV>")
|
||||
else:
|
||||
dependencies.append(f"#<BmV>\n{var_name} = <{type(var_value).__name__} object>\n#</BmV>")
|
||||
|
||||
def _build_initial_source(imports, dependencies, source):
|
||||
"""Build the initial source code."""
|
||||
return f"{DELIM}\n" + f"\n{DELIM}\n".join(imports + dependencies + [source]) + f"\n{DELIM}\n"
|
||||
|
||||
def _process_modules(modules, cur_src, already_closed, recursion_stack, uses):
|
||||
"""Process module dependencies."""
|
||||
reverse_module_src = deque()
|
||||
while modules:
|
||||
mname, mval = modules.popleft()
|
||||
mdeps = []
|
||||
attrs_to_extract = get_referenced_names(cur_src.replace(DELIM, ""), mname)
|
||||
for attr in attrs_to_extract:
|
||||
_process_module_attribute(mname, mval, attr, mdeps, modules, already_closed, recursion_stack, uses)
|
||||
|
||||
mdeps.insert(0, f"# Extracted from module: {mname}")
|
||||
reverse_module_src.appendleft("\n".join(mdeps))
|
||||
|
||||
cur_src = _dereference_module_names(cur_src, mname, attrs_to_extract)
|
||||
|
||||
return list(reverse_module_src)
|
||||
|
||||
def _process_module_attribute(mname, mval, attr, mdeps, modules, already_closed, recursion_stack, uses):
|
||||
"""Process a single attribute of a module."""
|
||||
val = getattr(mval, attr)
|
||||
if isinstance(val, (types.FunctionType, type, types.MethodType)):
|
||||
try:
|
||||
dep, _, dep_uses = lexical_closure(val, already_closed=already_closed, recursion_stack=recursion_stack.copy())
|
||||
mdeps.append(dep)
|
||||
uses.update(dep_uses)
|
||||
except Exception as e:
|
||||
_raise_error(f"Failed to capture the lexical closure of {mname}.{attr}", e, recursion_stack)
|
||||
elif isinstance(val, types.ModuleType):
|
||||
modules.append((attr, val))
|
||||
else:
|
||||
mdeps.append(f"{attr} = {repr(val)}")
|
||||
|
||||
def _dereference_module_names(cur_src, mname, attrs_to_extract):
|
||||
"""Dereference module names in the source code."""
|
||||
for attr in attrs_to_extract:
|
||||
cur_src = cur_src.replace(f"{mname}.{attr}", attr)
|
||||
return cur_src
|
||||
|
||||
def _build_final_source(imports, module_src, dependencies, source):
|
||||
"""Build the final source code."""
|
||||
seperated_dependencies = sorted(imports) + sorted(module_src) + sorted(dependencies) + ([source] if source else [])
|
||||
seperated_dependencies = list(dict.fromkeys(seperated_dependencies))
|
||||
return DELIM + "\n" + f"\n{DELIM}\n".join(seperated_dependencies) + "\n" + DELIM + "\n"
|
||||
|
||||
def _generate_function_hash(source, dsrc, qualname):
|
||||
"""Generate a hash for the function."""
|
||||
return "lmp-" + hashlib.md5("\n".join((source, dsrc, qualname)).encode()).hexdigest()
|
||||
|
||||
def _update_ell_func(outer_ell_func, source, dsrc, globals_dict, frees_dict, fn_hash, uses):
|
||||
"""Update the ell function attributes."""
|
||||
if hasattr(outer_ell_func, "__ell_func__"):
|
||||
outer_ell_func.__ell_closure__ = (source, dsrc, globals_dict, frees_dict)
|
||||
outer_ell_func.__ell_hash__ = fn_hash
|
||||
outer_ell_func.__ell_uses__ = uses
|
||||
|
||||
def _raise_error(message, exception, recursion_stack):
|
||||
"""Raise an error with detailed information."""
|
||||
error_msg = f"{message}. Error: {str(exception)}\n"
|
||||
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
def is_immutable_variable(value):
|
||||
"""
|
||||
Check if a value is immutable.
|
||||
@@ -89,7 +277,6 @@ def is_immutable_variable(value):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def should_import(module: types.ModuleType):
|
||||
"""
|
||||
This function checks if a module should be imported based on its origin.
|
||||
@@ -118,10 +305,6 @@ def should_import(module: types.ModuleType):
|
||||
# Otherwise, return True
|
||||
return True
|
||||
|
||||
|
||||
import ast
|
||||
|
||||
|
||||
def get_referenced_names(code: str, module_name: str):
|
||||
"""
|
||||
This function takes a block of code and a module name as input. It parses the code into an Abstract Syntax Tree (AST)
|
||||
@@ -144,232 +327,8 @@ def get_referenced_names(code: str, module_name: str):
|
||||
|
||||
return referenced_names
|
||||
|
||||
|
||||
CLOSURE_SOURCE: Dict[str, str] = {}
|
||||
|
||||
|
||||
def lexical_closure(func: Any, already_closed=None, initial_call=False, recursion_stack=None) -> Tuple[str, Tuple[str, str], Set[str]]:
|
||||
"""
|
||||
This function takes a function or any callable as input and returns a string representation of its lexical closure.
|
||||
The lexical closure includes the source code of the function itself, as well as the source code of any global variables,
|
||||
free variables, and dependencies (other functions or classes) that it references.
|
||||
|
||||
If the input function is a method of a class, this function will also find and include the source code of other methods
|
||||
in the same class that are referenced by the input function.
|
||||
|
||||
The resulting string can be used to recreate the function in a different context, with all of its dependencies.
|
||||
|
||||
Parameters:
|
||||
func (Callable): The function or callable whose lexical closure is to be found.
|
||||
already_closed (set): Set of already processed functions to avoid infinite recursion.
|
||||
initial_call (bool): Whether this is the initial call to the function.
|
||||
recursion_stack (list): Stack to keep track of the recursion path.
|
||||
|
||||
Returns:
|
||||
str: A string representation of the lexical closure of the input function.
|
||||
"""
|
||||
|
||||
already_closed = already_closed or set()
|
||||
uses = set()
|
||||
recursion_stack = recursion_stack or []
|
||||
|
||||
if hash(func) in already_closed:
|
||||
return "", ("", ""), {}
|
||||
|
||||
recursion_stack.append(func.__qualname__ if hasattr(func, '__qualname__') else str(func))
|
||||
|
||||
outer_ell_func = func
|
||||
while hasattr(func, "__ell_func__"):
|
||||
func = func.__ell_func__
|
||||
|
||||
source = getsource(func, lstrip=True)
|
||||
already_closed.add(hash(func))
|
||||
# if func is nested func
|
||||
# Parse the source code into an AST
|
||||
# tree = ast.parse(source)
|
||||
# Find all the global variables and free variables in the function
|
||||
|
||||
# These are not global variables these are globals, and other shit is actualy in cluded here
|
||||
_globals = collections.OrderedDict(dill.detect.globalvars(func))
|
||||
print(_globals)
|
||||
_frees = collections.OrderedDict(dill.detect.freevars(func))
|
||||
|
||||
# If func is a class we actually should check all the methods of the class for globalvars. Malekdiction (MSM) was here.
|
||||
# Add the default aprameter tpes to depndencies if they are not builtins
|
||||
|
||||
if isinstance(func, type):
|
||||
# Now we need to get all the global vars in the class
|
||||
for name, method in collections.OrderedDict(func.__dict__).items():
|
||||
if isinstance(method, types.FunctionType) or isinstance(
|
||||
method, types.MethodType
|
||||
):
|
||||
_globals.update(
|
||||
collections.OrderedDict(dill.detect.globalvars(method))
|
||||
)
|
||||
_frees.update(collections.OrderedDict(dill.detect.freevars(method)))
|
||||
|
||||
# Initialize a list to store the source code of the dependencies
|
||||
dependencies = []
|
||||
modules = deque()
|
||||
imports = []
|
||||
|
||||
if isinstance(func, (types.FunctionType, types.MethodType)):
|
||||
# Get all the the default kwargs
|
||||
ps = inspect.signature(func).parameters
|
||||
default_kwargs = collections.OrderedDict(
|
||||
{
|
||||
k: v.default
|
||||
for k, v in ps.items()
|
||||
if v.default is not inspect.Parameter.empty
|
||||
}
|
||||
)
|
||||
for name, val in default_kwargs.items():
|
||||
try:
|
||||
if name not in FORBIDDEN_NAMES:
|
||||
dep, _, dep_uses = lexical_closure(
|
||||
type(val), already_closed=already_closed,
|
||||
recursion_stack=recursion_stack.copy()
|
||||
)
|
||||
dependencies.append(dep)
|
||||
uses.update(dep_uses)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to capture the lexical closure of default parameter {name}. Error: {str(e)}\n"
|
||||
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
# Iterate over the global variables
|
||||
for var_name, var_value in {**_globals, **_frees}.items():
|
||||
is_free = var_name in _frees
|
||||
# If the variable is a function, get its source code
|
||||
if isinstance(var_value, (types.FunctionType, type, types.MethodType)):
|
||||
if var_name not in FORBIDDEN_NAMES:
|
||||
try:
|
||||
ret = lexical_closure(
|
||||
var_value, already_closed=already_closed,
|
||||
recursion_stack=recursion_stack.copy()
|
||||
)
|
||||
dep, _, dep_uses = ret
|
||||
dependencies.append(dep)
|
||||
# See if the function was called at all in the source code of the func
|
||||
# This is wrong because if its a referred call it won't track the dependency; so we actually need to trace all dependencies that are not ell funcs to see if they call it as well.
|
||||
if is_function_called(var_name, source):
|
||||
uses.update(dep_uses)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to capture the lexical closure of global or free variabl evariable {var_name}. Error: {str(e)}\n"
|
||||
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
elif isinstance(var_value, types.ModuleType):
|
||||
if should_import(var_value):
|
||||
imports += [dill.source.getimport(var_value, alias=var_name)]
|
||||
|
||||
else:
|
||||
# Now we need to find all the variables in this module that were referenced
|
||||
modules.append((var_name, var_value))
|
||||
elif isinstance(var_value, types.BuiltinFunctionType):
|
||||
# we need to get an import for it
|
||||
|
||||
imports += [dill.source.getimport(var_value, alias=var_name)]
|
||||
|
||||
else:
|
||||
json_default = lambda x: f"<Object of type {type(x).__name__}>"
|
||||
if isinstance(var_value, str) and '\n' in var_value:
|
||||
dependencies.append(f"{var_name} = '''{var_value}'''")
|
||||
else:
|
||||
# if is immutable
|
||||
if is_immutable_variable(var_value) and not is_free:
|
||||
dependencies.append(f"#<BV>\n{var_name} = {repr(var_value)}\n#</BV>")
|
||||
else:
|
||||
|
||||
dependencies.append(f"#<BmV>\n{var_name} = <{type(var_value).__name__} object>\n#</BmV>")
|
||||
|
||||
# We probably need to resovle things with topological sort & turn stuff into a dag but for now we can just do this
|
||||
|
||||
cur_src = (
|
||||
DELIM
|
||||
+ "\n"
|
||||
+ f"\n{DELIM}\n".join(imports + dependencies)
|
||||
+ "\n"
|
||||
+ DELIM
|
||||
+ "\n"
|
||||
+ source
|
||||
+ "\n"
|
||||
+ DELIM
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
reverse_module_src = deque()
|
||||
while len(modules) > 0:
|
||||
mname, mval = modules.popleft()
|
||||
mdeps = []
|
||||
attrs_to_extract = get_referenced_names(cur_src.replace(DELIM, ""), mname)
|
||||
for attr in attrs_to_extract:
|
||||
val = getattr(mval, attr)
|
||||
if isinstance(val, (types.FunctionType, type, types.MethodType)):
|
||||
try:
|
||||
dep, _, dep_uses = lexical_closure(
|
||||
val, already_closed=already_closed,
|
||||
recursion_stack=recursion_stack.copy()
|
||||
)
|
||||
mdeps.append(dep)
|
||||
uses.update(dep_uses)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to capture the lexical closure of {mname}.{attr}. Error: {str(e)}\n"
|
||||
error_msg += f"Recursion stack: {' -> '.join(recursion_stack)}"
|
||||
raise Exception(error_msg)
|
||||
elif isinstance(val, types.ModuleType):
|
||||
modules.append((attr, val))
|
||||
else:
|
||||
# If its another module we need to add it to the list of modules
|
||||
mdeps.append(f"{attr} = {repr(val)}")
|
||||
|
||||
mdeps.insert(0, f"# Extracted from module: {mname}")
|
||||
|
||||
# Now let's dereference all the module names in our cur_src
|
||||
for attr in attrs_to_extract:
|
||||
# Go throught hte dependencies and replace all the module names with the attr
|
||||
source = source.replace(f"{mname}.{attr}", attr)
|
||||
dependencies = [
|
||||
dep.replace(f"{mname}.{attr}", attr) for dep in dependencies
|
||||
]
|
||||
# Now add all the module dependencies to the top of the list
|
||||
reverse_module_src.appendleft("\n".join(mdeps))
|
||||
|
||||
# Now we need to add the module dependencies to the top of the source
|
||||
# Sort the dependencies
|
||||
dependencies = sorted(dependencies)
|
||||
imports = sorted(imports)
|
||||
reverse_module_src = sorted(reverse_module_src)
|
||||
seperated_dependencies = (
|
||||
imports
|
||||
+ list(reverse_module_src)
|
||||
+ dependencies
|
||||
+ [source]
|
||||
)
|
||||
# Remove duplicates and preserve order
|
||||
seperated_dependencies = list(dict.fromkeys(seperated_dependencies))
|
||||
|
||||
dirty_src = DELIM + "\n" + f"\n{DELIM}\n".join(seperated_dependencies) + "\n" + DELIM + "\n"
|
||||
dirty_src_without_func = DELIM + "\n" + f"\n{DELIM}\n".join(seperated_dependencies[:-1]) + "\n" + DELIM + "\n"
|
||||
|
||||
CLOSURE_SOURCE[hash(func)] = dirty_src
|
||||
|
||||
dsrc = _clean_src(dirty_src_without_func)
|
||||
fn_hash = "lmp-" + hashlib.md5(
|
||||
"\n".join((source, dsrc, func.__qualname__)).encode()
|
||||
).hexdigest()
|
||||
|
||||
if hasattr(outer_ell_func, "__ell_func__"):
|
||||
outer_ell_func.__ell_closure__ = (source, dsrc, _globals, _frees)
|
||||
outer_ell_func.__ell_hash__ = fn_hash
|
||||
outer_ell_func.__ell_uses__ = uses
|
||||
|
||||
|
||||
return (dirty_src, (source, dsrc), ({fn_hash} if not initial_call and hasattr(outer_ell_func, "__ell_func__") else uses))
|
||||
|
||||
|
||||
|
||||
def lexically_closured_source(func):
|
||||
_, fnclosure, uses = lexical_closure(func, initial_call=True, recursion_stack=[])
|
||||
return fnclosure, uses
|
||||
@@ -377,7 +336,6 @@ def lexically_closured_source(func):
|
||||
import ast
|
||||
|
||||
def _clean_src(dirty_src):
|
||||
|
||||
# Now remove all duplicates and preserve order
|
||||
split_by_setion = filter(lambda x: len(x.strip()) > 0, dirty_src.split(DELIM))
|
||||
|
||||
@@ -401,7 +359,6 @@ def _clean_src(dirty_src):
|
||||
|
||||
return final_src
|
||||
|
||||
|
||||
def is_function_called(func_name, source_code):
|
||||
"""
|
||||
Check if a function is called in the given source code.
|
||||
|
||||
Reference in New Issue
Block a user