mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #511 from roryclear/load_shard_only
only load layers in shard in tinygrad
This commit is contained in:
@@ -7,6 +7,7 @@ from exo.inference.shard import Shard
|
|||||||
from exo.helpers import DEBUG
|
from exo.helpers import DEBUG
|
||||||
from exo.download.hf.hf_helpers import get_allow_patterns
|
from exo.download.hf.hf_helpers import get_allow_patterns
|
||||||
from fnmatch import fnmatch
|
from fnmatch import fnmatch
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
# **** helper functions ****
|
# **** helper functions ****
|
||||||
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
|
|||||||
if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
|
if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
|
||||||
return {k: parts[n][k] for k, n in filtered_weight_map.items()}
|
return {k: parts[n][k] for k, n in filtered_weight_map.items()}
|
||||||
elif fn.endswith(".safetensors"):
|
elif fn.endswith(".safetensors"):
|
||||||
return safe_load(fn)
|
weight_map = safe_load(fn)
|
||||||
|
for k in list(weight_map):
|
||||||
|
if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
|
||||||
|
del weight_map[k]
|
||||||
|
return weight_map
|
||||||
else:
|
else:
|
||||||
return torch_load(fn)
|
return torch_load(fn)
|
||||||
|
|||||||
Reference in New Issue
Block a user