make StatefulShardedModel callable, add some tests for mlx sharded inference

Alex Cheema
2024-07-13 15:41:15 -07:00
parent 6ee0547eff
commit 850b72d3ea
8 changed files with 103 additions and 2 deletions

.gitignore

@@ -1,2 +1,3 @@
 __pycache__/
 .venv
+test_weights.npz

inference/__init__.py (new, empty file)

@@ -60,4 +60,4 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
         model_shard, self.tokenizer = load_shard(shard.model_id, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-        self.shard = shard
+        self.shard = shard

@@ -1,4 +1,4 @@
-from typing import Dict, Generator, Optional, Tuple
+from typing import Any, Dict, Generator, Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
@@ -47,6 +47,15 @@ class StatefulShardedModel:
         else:
             return output
 
+    def __call__(
+        self,
+        x,
+        temp: float = 0.0,
+        top_p: float = 1.0,
+        logit_bias: Optional[Dict[int, float]] = None,
+    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+        return self.step(x, temp, top_p, logit_bias)
+
     def reset(self):
         kv_heads = (
             [self.model.n_kv_heads] * len(self.model.layers)
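
With this change a StatefulShardedModel can be invoked directly; __call__ simply forwards its arguments to step. A minimal usage sketch, assuming the same Shard/load_shard setup and model id used by the tests added in this commit (the prompt and variable names here are illustrative only):

import mlx.core as mx
from inference.mlx.sharded_model import StatefulShardedModel
from inference.mlx.sharded_utils import load_shard
from inference.shard import Shard

shard = Shard("llama", 0, 31, 32)
model_shard, tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard)
model = StatefulShardedModel(shard, model_shard)

tokens = mx.array(tokenizer.encode("hello"))
first = model(tokens)    # equivalent to model.step(tokens)
second = model(first)    # feed the sampled token back in for the next step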

@@ -0,0 +1,40 @@
import mlx.core as mx
from inference.mlx.sharded_model import StatefulShardedModel
from inference.mlx.sharded_utils import load_shard
from inference.shard import Shard

shard_full = Shard("llama", 0, 31, 32)
shard1 = Shard("llama", 0, 12, 32)
shard2 = Shard("llama", 13, 31, 32)

full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full)
model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)

full = StatefulShardedModel(shard_full, full_model_shard)
m1 = StatefulShardedModel(shard1, model_shard1)
m2 = StatefulShardedModel(shard2, model_shard2)

prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
prompt_tokens = mx.array(tokenizer1.encode(prompt))
max_tokens = 50

# Greedy generation with the single, unsharded pipeline.
resp = prompt_tokens
full_generated_tokens = []
for _ in range(max_tokens):
    resp = full.step(resp)
    full_generated_tokens.append(resp.item())

print("full response: ", tokenizer1.decode(full_generated_tokens))

# Same generation, but each token flows through the two partial models in sequence.
sharded_generated_tokens = []
sharded_resp = prompt_tokens
for _ in range(max_tokens):
    resp1 = m1.step(sharded_resp)
    sharded_resp = m2.step(resp1)
    sharded_generated_tokens.append(sharded_resp.item())

print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))

# Both pipelines run the same weights, so the decoded outputs should match exactly.
assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)

@@ -0,0 +1,51 @@
from inference.shard import Shard
from inference.mlx.sharded_model import StatefulShardedModel
import mlx.core as mx
import mlx.nn as nn
from typing import Optional
import numpy as np


class DummyModel(nn.Module):
    def __init__(self, shard: Optional[Shard] = None):
        self.shard = shard
        self.layers = [
            nn.Linear(8, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 8),
        ]
        self.n_kv_heads = 4
        self.head_dim = 4

    def __call__(self, x, cache=None):
        if self.shard:
            # Only run the layers this shard owns; end_layer is inclusive.
            for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
                x = layer(x)
            if self.shard.is_last_layer():
                x = x.reshape((1, 2, 4))
        else:
            for layer in self.layers:
                x = layer(x)
            x = x.reshape((1, 2, 4))
        return x


model = DummyModel()
model.save_weights("./test_weights.npz")

n_layers = 5
shard1 = Shard("test", 0, n_layers // 2, n_layers)
sharded_model1 = DummyModel(shard1)
shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
sharded_model2 = DummyModel(shard2)

# All three models load the same saved weights.
model.load_weights("./test_weights.npz")
sharded_model1.load_weights("./test_weights.npz")
sharded_model2.load_weights("./test_weights.npz")

fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
resp2 = sharded_model2(resp1)

# Chaining the two shards must reproduce the full model's output exactly.
assert np.all(np.array(fullresp) == np.array(resp2))
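
Both tests build Shard positionally as (model_id, start_layer, end_layer, n_layers), and DummyModel above reads shard.start_layer, shard.end_layer (inclusive, hence the end_layer + 1 slice) and shard.is_last_layer(). A minimal sketch consistent with that usage; the dataclass form is an assumption and the actual definition lives in inference/shard.py:

from dataclasses import dataclass

# Sketch only, inferred from how the tests above use Shard.
@dataclass
class Shard:
    model_id: str
    start_layer: int
    end_layer: int  # inclusive; slicing therefore uses end_layer + 1
    n_layers: int

    def is_last_layer(self) -> bool:
        return self.end_layer == self.n_layers - 1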