Mirror of https://github.com/exo-explore/exo.git, synced 2025-10-23 02:57:14 +03:00
make StatefulShardedModel callable, add some tests for mlx sharded inference
.gitignore (vendored, 1 line changed)

@@ -1,2 +1,3 @@
 __pycache__/
 .venv
+test_weights.npz

inference/__init__.py (new file, 0 lines)
inference/mlx/__init__.py (new file, 0 lines)
inference/mlx/models/__init__.py (new file, 0 lines)

@@ -60,4 +60,4 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
    model_shard, self.tokenizer = load_shard(shard.model_id, shard)
    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
    self.shard = shard

inference/mlx/sharded_model.py

@@ -1,4 +1,4 @@
-from typing import Dict, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -47,6 +47,15 @@ class StatefulShardedModel:
     else:
       return output
 
+  def __call__(
+    self,
+    x,
+    temp: float = 0.0,
+    top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
+  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    return self.step(x, temp, top_p, logit_bias)
+
   def reset(self):
     kv_heads = (
       [self.model.n_kv_heads] * len(self.model.layers)
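
The net effect is that a StatefulShardedModel can now be invoked like a function, with __call__ simply forwarding to step(). A minimal usage sketch, not part of the commit; the model id and shard bounds are borrowed from the new test below:

import mlx.core as mx
from inference.mlx.sharded_model import StatefulShardedModel
from inference.mlx.sharded_utils import load_shard
from inference.shard import Shard

# Load the full 32-layer model as a single shard (same call as in test_sharded_llama.py).
shard = Shard("llama", 0, 31, 32)
model_shard, tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard)
model = StatefulShardedModel(shard, model_shard)

prompt_tokens = mx.array(tokenizer.encode("write a haiku:"))
# Equivalent calls: __call__ just delegates to step(x, temp, top_p, logit_bias).
next_token = model.step(prompt_tokens)
next_token = model(prompt_tokens)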

inference/mlx/test_sharded_llama.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import mlx.core as mx
from inference.mlx.sharded_model import StatefulShardedModel
from inference.mlx.sharded_utils import load_shard
from inference.shard import Shard

shard_full = Shard("llama", 0, 31, 32)
shard1 = Shard("llama", 0, 12, 32)
shard2 = Shard("llama", 13, 31, 32)

full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full)
model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)

full = StatefulShardedModel(shard_full, full_model_shard)
m1 = StatefulShardedModel(shard1, model_shard1)
m2 = StatefulShardedModel(shard2, model_shard2)

prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
prompt_tokens = mx.array(tokenizer1.encode(prompt))
max_tokens = 50

resp = prompt_tokens
full_generated_tokens = []
for _ in range(max_tokens):
  resp = full.step(resp)
  full_generated_tokens.append(resp.item())

print("full response: ", tokenizer1.decode(full_generated_tokens))


sharded_generated_tokens = []
sharded_resp = prompt_tokens
for _ in range(max_tokens):
  resp1 = m1.step(sharded_resp)
  sharded_resp = m2.step(resp1)
  sharded_generated_tokens.append(sharded_resp.item())

print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))

assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)
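
The assertion above checks that pipelined generation across two shards reproduces the single-model output token for token. The shard bounds appear to be inclusive layer indices (Shard(model_id, start_layer, end_layer, n_layers), judging from the start_layer/end_layer slicing in test_sharded_model.py below), so the two shards tile the 32 layers with no gap or overlap. A small sketch, not part of the commit:

from inference.shard import Shard

# Assumed field meanings: Shard(model_id, start_layer, end_layer, n_layers), inclusive bounds.
shard1 = Shard("llama", 0, 12, 32)   # first shard covers layers 0..12
shard2 = Shard("llama", 13, 31, 32)  # second shard covers layers 13..31

assert shard2.start_layer == shard1.end_layer + 1       # contiguous, no overlap
assert shard1.start_layer == 0 and shard2.end_layer == 32 - 1  # together they cover all 32 layers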

inference/mlx/test_sharded_model.py (new file, 51 lines)

@@ -0,0 +1,51 @@
from inference.shard import Shard
from inference.mlx.sharded_model import StatefulShardedModel
import mlx.core as mx
import mlx.nn as nn
from typing import Optional
import numpy as np

class DummyModel(nn.Module):
  def __init__(self, shard: Optional[Shard] = None):
    self.shard = shard
    self.layers = [
      nn.Linear(8, 128),
      nn.Linear(128, 128),
      nn.Linear(128, 128),
      nn.Linear(128, 128),
      nn.Linear(128, 8),
    ]

    self.n_kv_heads = 4
    self.head_dim = 4

  def __call__(self, x, cache=None):
    if self.shard:
      for layer in self.layers[self.shard.start_layer:self.shard.end_layer+1]:
        x = layer(x)
      if self.shard.is_last_layer():
        x = x.reshape((1, 2, 4))
    else:
      for layer in self.layers:
        x = layer(x)
      x = x.reshape((1, 2, 4))

    return x

model = DummyModel()
model.save_weights("./test_weights.npz")
n_layers = 5
shard1 = Shard("test", 0, n_layers // 2, n_layers)
sharded_model1 = DummyModel(shard1)
shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
sharded_model2 = DummyModel(shard2)

model.load_weights("./test_weights.npz")
sharded_model1.load_weights("./test_weights.npz")
sharded_model2.load_weights("./test_weights.npz")

fullresp = model(mx.array([1,2,3,4,5,6,7,8]))
resp1 = sharded_model1(mx.array([1,2,3,4,5,6,7,8]))
resp2 = sharded_model2(resp1)

assert np.all(np.array(fullresp) == np.array(resp2))
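
Here n_layers // 2 == 2, so sharded_model1 covers layers 0..2 and sharded_model2 covers layers 3..4. Only the shard holding the last layer applies the (1, 2, 4) reshape, so shard 1 hands its raw activations straight to shard 2, and all three models load identical weights from test_weights.npz (the file just added to .gitignore); the outputs therefore match exactly. A shape trace, continuing the script above (a sketch, not part of the commit):

# Continuing test_sharded_model.py; names are defined there.
x = mx.array([1, 2, 3, 4, 5, 6, 7, 8])
h = sharded_model1(x)    # layers 0..2: Linear(8,128) -> Linear(128,128) -> Linear(128,128); shape (128,)
out = sharded_model2(h)  # layers 3..4: Linear(128,128) -> Linear(128,8), then reshape; shape (1, 2, 4)
assert np.array(out).shape == (1, 2, 4)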