diff --git a/exo/inference/tinygrad/tinygrad_helpers.py b/exo/inference/tinygrad/tinygrad_helpers.py index d3aa234e..a33b01ae 100644 --- a/exo/inference/tinygrad/tinygrad_helpers.py +++ b/exo/inference/tinygrad/tinygrad_helpers.py @@ -7,6 +7,7 @@ from exo.inference.shard import Shard from exo.helpers import DEBUG from exo.download.hf.hf_helpers import get_allow_patterns from fnmatch import fnmatch +import re # **** helper functions **** @@ -42,6 +43,10 @@ def load(fn: str, shard: Shard): if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}") return {k: parts[n][k] for k, n in filtered_weight_map.items()} elif fn.endswith(".safetensors"): - return safe_load(fn) + weight_map = safe_load(fn) + for k in list(weight_map): + if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer): + del weight_map[k] + return weight_map else: return torch_load(fn)