From 2fdda5174dee5ba7fab930f62b0905389d156255 Mon Sep 17 00:00:00 2001
From: Rory Clear
Date: Wed, 27 Nov 2024 15:48:31 +0000
Subject: [PATCH 1/3] remove unused layers

---
 exo/inference/tinygrad/inference.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py
index afa9c5df..ad5deabe 100644
--- a/exo/inference/tinygrad/inference.py
+++ b/exo/inference/tinygrad/inference.py
@@ -13,6 +13,7 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
+import re
 
 Tensor.no_grad = True
 # default settings
@@ -51,7 +52,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
   weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
-
+  for k in list(weights):
+    if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+      del weights[k]
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True

From 1d1fa8c608e81e3cbffdbb756a8b3f99e70db4b7 Mon Sep 17 00:00:00 2001
From: Rory Clear
Date: Wed, 27 Nov 2024 16:41:12 +0000
Subject: [PATCH 2/3] move to tinygrad_helpers

---
 exo/inference/tinygrad/inference.py        | 4 ----
 exo/inference/tinygrad/tinygrad_helpers.py | 7 ++++++-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py
index ad5deabe..24cddc92 100644
--- a/exo/inference/tinygrad/inference.py
+++ b/exo/inference/tinygrad/inference.py
@@ -13,7 +13,6 @@ from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from .stateful_model import StatefulModel
 import asyncio
-import re
 
 Tensor.no_grad = True
 # default settings
@@ -52,9 +51,6 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
   weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
-  for k in list(weights):
-    if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
-      del weights[k]
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
diff --git a/exo/inference/tinygrad/tinygrad_helpers.py b/exo/inference/tinygrad/tinygrad_helpers.py
index d3aa234e..a33b01ae 100644
--- a/exo/inference/tinygrad/tinygrad_helpers.py
+++ b/exo/inference/tinygrad/tinygrad_helpers.py
@@ -7,6 +7,7 @@ from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 from exo.download.hf.hf_helpers import get_allow_patterns
 from fnmatch import fnmatch
+import re
 
 # **** helper functions ****
 
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
     if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}
   elif fn.endswith(".safetensors"):
-    return safe_load(fn)
+    weight_map = safe_load(fn)
+    for k in list(weight_map):
+      if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+        del weight_map[k]
+    return weight_map
   else:
     return torch_load(fn)

From 3c81845ab74fb25052b32babaa249420c84d60e9 Mon Sep 17 00:00:00 2001
From: Rory Clear
Date: Wed, 27 Nov 2024 16:42:19 +0000
Subject: [PATCH 3/3] undo diff

---
 exo/inference/tinygrad/inference.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py
index 24cddc92..afa9c5df 100644
--- a/exo/inference/tinygrad/inference.py
+++ b/exo/inference/tinygrad/inference.py
@@ -51,6 +51,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
   weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = fix_bf16(weights)
+
   with Context(BEAM=0):
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
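
Note on the filtering logic this series moves into load(): it drops any tensor whose key contains a ".<layer index>." segment outside the shard's [start_layer, end_layer] range, and keeps keys with no layer index (embeddings, norms, output head). The standalone sketch below is illustrative only, not part of the patch: filter_weights_to_shard is a hypothetical helper name, and the demo keys assume the usual "model.layers.<n>...." naming; only the regex and bounds check are taken from the diff above.

import re

# Hypothetical helper mirroring the check added to load() in PATCH 2/3.
def filter_weights_to_shard(weights: dict, start_layer: int, end_layer: int) -> dict:
  for k in list(weights):
    # Keys without a ".<digits>." segment (embeddings, norms, lm head) are kept.
    if (m := re.search(r"\.(\d+)\.", k)) and not (start_layer <= int(m.group(1)) <= end_layer):
      del weights[k]  # tensor belongs to a layer outside this shard
  return weights

# Example: a 4-layer checkpoint sharded to layers 1..2 keeps only those blocks.
demo = {f"model.layers.{i}.mlp.weight": None for i in range(4)}
demo["model.embed_tokens.weight"] = None
print(sorted(filter_weights_to_shard(demo, 1, 2)))
# -> ['model.embed_tokens.weight', 'model.layers.1.mlp.weight', 'model.layers.2.mlp.weight']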