Mirror of https://github.com/exo-explore/exo.git
mlx sharded implementation with example of distributed inference

example_user.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# In this example, a user is running a home cluster with 2 shards.
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.

from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from typing import List
import asyncio
import argparse

path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)

peers: List[PeerHandle] = [
    GRPCPeerHandle(
        "node1",
        "localhost:8080",
    ),
    GRPCPeerHandle(
        "node2",
        "localhost:8081",
    )
]
shards: List[Shard] = [
    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
    # Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
    Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
]

async def run_prompt(prompt: str):
    if tokenizer.chat_template is None:
        tokenizer.chat_template = tokenizer.default_chat_template
    if (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    for peer, shard in zip(peers, shards):
        await peer.connect()
        await peer.reset_shard(shard)

    tokens = []
    last_output = prompt

    for _ in range(20):
        for peer, shard in zip(peers, shards):
            if isinstance(last_output, str):
                last_output = await peer.send_prompt(shard, last_output)
                print("prompt output:", last_output)
            else:
                last_output = await peer.send_tensor(shard, last_output)
                print("tensor output:", last_output)

        if not last_output:
            break

        tokens.append(last_output.item())

    print(tokenizer.decode(tokens))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run prompt")
    parser.add_argument("--prompt", type=str, help="The prompt to run")
    args = parser.parse_args()

    asyncio.run(run_prompt(args.prompt))
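For context, one hypothetical way to bring up the two nodes this example connects to, using only the CLI flags that main.py defines later in this diff (this launcher script is not part of the commit; ports and layer ranges simply mirror the values above, and discovery is left on main.py's defaults):

```python
# Hypothetical launcher, not part of this commit: starts the two nodes that
# example_user.py talks to, via the argparse flags main.py declares below.
import subprocess

node_cmds = [
    # node1 serves layers 0-30 and listens on localhost:8080
    ["python", "main.py", "--node-id", "node1", "--node-port", "8080",
     "--start-layer", "0", "--end-layer", "30", "--n-layers", "32"],
    # node2 serves the final layer 31 and listens on localhost:8081
    ["python", "main.py", "--node-id", "node2", "--node-port", "8081",
     "--start-layer", "31", "--end-layer", "31", "--n-layers", "32"],
]
procs = [subprocess.Popen(cmd) for cmd in node_cmds]

# Once both nodes are up, in another shell:
#   python example_user.py --prompt "who are you?"
```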

inference/inference_engine.py
@@ -9,23 +9,9 @@ class InferenceEngine(ABC):
    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        pass

    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
        pass

    @abstractmethod
    async def reset_shard(self, shard: Shard):
        pass

class MLXFixedShardInferenceEngine(InferenceEngine):
    def __init__(self, model: nn.Module, shard: Shard):
        self.model = model
        self.shard = shard

    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        if shard != self.shard:
            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

        output_data = self.model.process(input_data)
        print("Processed data through model shard")
        return output_data

    async def reset_shard(self, shard: Shard):
        # TODO
        print(f"Resetting shard: {shard}")

inference/mlx/models/sharded_llama.py (new file, 244 lines)
@@ -0,0 +1,244 @@
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.models.base import BaseModelArgs, create_additive_causal_mask
from ...shard import Shard


@dataclass
class NormalModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@dataclass
class ModelArgs(NormalModelArgs):
    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

    def __post_init__(self):
        super().__post_init__()  # Ensure parent initializations are respected

        if isinstance(self.shard, Shard):
            return
        if not isinstance(self.shard, dict):
            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

        self.shard = Shard(**self.shard)

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        if hasattr(args, "mlp_bias"):
            mlp_bias = args.mlp_bias
        else:
            mlp_bias = False

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.shard.n_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        if self.args.shard.is_first_layer():
            h = self.embed_tokens(inputs)
        else:
            h = inputs

        mask = None
        if h.shape[1] > 1:
            mask = create_additive_causal_mask(
                h.shape[1], cache[0].offset if cache is not None else 0
            )
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        if self.args.shard.is_last_layer():
            return self.norm(h)
        else:
            return h


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)

        if self.args.shard.is_last_layer():
            if self.args.tie_word_embeddings:
                out = self.model.embed_tokens.as_linear(out)
            else:
                out = self.lm_head(out)

        return out


    def sanitize(self, weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
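A minimal sketch of how the shard reaches these model arguments: load_model_shard (in sharded_utils.py below) injects a "shard" dict into the model config, and ModelArgs.__post_init__ above converts it into a Shard. The config numbers here are illustrative placeholders, not values taken from this commit:

```python
# Sketch only; the config values are made-up placeholders.
from inference.mlx.models.sharded_llama import ModelArgs

args = ModelArgs.from_dict({
    "model_type": "sharded_llama",
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "rms_norm_eps": 1e-5,
    "vocab_size": 128256,
    "num_key_value_heads": 8,
    # load_model_shard builds this dict; __post_init__ turns it into a Shard
    "shard": {"model_id": "example", "start_layer": 0, "end_layer": 30, "n_layers": 32},
})
assert args.shard.is_first_layer() and not args.shard.is_last_layer()
```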

inference/mlx/sharded_inference_engine.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import mlx.nn as nn
import numpy as np
import mlx.core as mx
from ..inference_engine import InferenceEngine
from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard
from ..shard import Shard

class MLXFixedShardInferenceEngine(InferenceEngine):
    def __init__(self, model_path: str, shard: Shard):
        print("initializing fixed shard inference", shard)
        self.shard = shard
        model_shard, self.tokenizer = load_shard(model_path, shard)
        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)

    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
        if shard != self.shard:
            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

        output_data = self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt)))
        return np.array(output_data)

    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        if shard != self.shard:
            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

        print("infer_shard", shard, input_data)

        output_data = self.stateful_sharded_model.step(mx.array(input_data))
        return np.array(output_data)

    async def reset_shard(self, shard: Shard):
        if shard != self.shard:
            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

        print(f"Resetting shard: {shard}")
        self.stateful_sharded_model.reset()

inference/mlx/sharded_model.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from typing import Dict, Generator, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.sample_utils import top_p_sampling

from ..shard import Shard

class StatefulShardedModel:
    def __init__(self, shard: Shard, model: nn.Module):
        self.shard = shard
        self.model = model
        self.reset()

    def step(
        self,
        x,
        temp: float = 0.0,
        top_p: float = 1.0,
        logit_bias: Optional[Dict[int, float]] = None,
    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
        def sample(logits: mx.array) -> Tuple[mx.array, float]:
            if logit_bias:
                indices = mx.array(list(logit_bias.keys()))
                values = mx.array(list(logit_bias.values()))
                logits[:, indices] += values

            if temp == 0:
                token = mx.argmax(logits, axis=-1)
            else:
                if top_p > 0 and top_p < 1.0:
                    token = top_p_sampling(logits, top_p, temp)
                else:
                    token = mx.random.categorical(logits * (1 / temp))

            return token

        y = x

        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)

        if self.shard.is_last_layer():
            logits = output[:, -1, :]
            y = sample(logits)
            return y
        else:
            return output

    def reset(self):
        kv_heads = (
            [self.model.n_kv_heads] * len(self.model.layers)
            if isinstance(self.model.n_kv_heads, int)
            else self.model.n_kv_heads
        )
        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
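A small standalone sketch of the sampling rule step() applies on the last shard, using the same calls as above (assumes mlx and mlx_lm are installed; the logits are made up):

```python
# Illustrative only: the three branches of sample() in step(), on fake logits.
import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling

logits = mx.array([[0.1, 2.0, 0.3, 1.5]])               # fake logits, 4-token vocab

greedy = mx.argmax(logits, axis=-1)                      # temp == 0 path
temp, top_p = 0.7, 0.9
nucleus = top_p_sampling(logits, top_p, temp)            # 0 < top_p < 1 path (same arg order as step())
fallback = mx.random.categorical(logits * (1 / temp))    # plain temperature sampling
```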

inference/mlx/sharded_utils.py (new file, 230 lines)
@@ -0,0 +1,230 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py

import glob
import importlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers

from ..shard import Shard

class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}

def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"inference.mlx.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs

def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config

def load_model_shard(
    model_path: Path,
    shard: Shard,
    lazy: bool = False,
    model_config: dict = {},
) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        model_config(dict, optional): Configuration parameters for the model.
            Defaults to an empty dictionary.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """

    config = load_config(model_path)
    config.update(model_config)

    # TODO hack
    config["model_type"] = f"sharded_{config['model_type']}"
    config["shard"] = {
        "model_id": model_path.name,
        "start_layer": shard.start_layer,
        "end_layer": shard.end_layer,
        "n_layers": shard.n_layers
    }

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = _get_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized
        def class_predicate(p, m):
            if not hasattr(m, "to_quantized"):
                return False
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            **quantization,
            class_predicate=class_predicate,
        )

    filtered_weights = {}
    for k, v in weights.items():
        if k.startswith("model.layers."):
            layer_num = int(k.split('.')[2])
            if shard.start_layer <= layer_num <= shard.end_layer:
                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
                filtered_weights[new_key] = v
        else:
            filtered_weights[k] = v
    weights = filtered_weights

    model.load_weights(list(weights.items()), strict=False)

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model

def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except RepositoryNotFoundError:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path


def load_shard(
    path_or_hf_repo: str,
    shard: Shard,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config(dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model = load_model_shard(model_path, shard, lazy, model_config)
    if adapter_path is not None:
        model = apply_lora_layers(model, adapter_path)
        model.eval()
    tokenizer = load_tokenizer(model_path, tokenizer_config)

    return model, tokenizer
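To make the weight filtering in load_model_shard concrete, here is the same remapping applied to a few made-up keys for the second shard from example_user.py (start_layer=31, end_layer=31): only layer 31 survives and is renumbered to 0, while non-layer weights pass through unchanged.

```python
# Key names are illustrative; the loop mirrors load_model_shard's filtering above.
start_layer, end_layer = 31, 31
weights = {
    "model.embed_tokens.weight": "...",
    "model.layers.30.self_attn.q_proj.weight": "...",
    "model.layers.31.self_attn.q_proj.weight": "...",
    "model.norm.weight": "...",
}
filtered = {}
for k, v in weights.items():
    if k.startswith("model.layers."):
        layer_num = int(k.split(".")[2])
        if start_layer <= layer_num <= end_layer:
            filtered[f"model.layers.{layer_num - start_layer}." + ".".join(k.split(".")[3:])] = v
    else:
        filtered[k] = v

assert sorted(filtered) == [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.norm.weight",
]
```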

inference/shard.py
@@ -3,6 +3,12 @@ from dataclasses import dataclass
@dataclass
class Shard:
    model_id: str
    n_layers: int
    start_layer: int
    end_layer: int
    n_layers: int

    def is_first_layer(self) -> bool:
        return self.start_layer == 0

    def is_last_layer(self) -> bool:
        return self.end_layer == self.n_layers - 1
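A minimal sketch, using only the dataclass above, of the two shards that example_user.py builds and what the helpers report for them:

```python
from inference.shard import Shard

first = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
              start_layer=0, end_layer=30, n_layers=32)
last = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
             start_layer=31, end_layer=31, n_layers=32)

assert first.is_first_layer() and not first.is_last_layer()  # embeds tokens, emits hidden states
assert last.is_last_layer() and not last.is_first_layer()    # consumes hidden states, emits logits
```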

main.py (23 lines changed)
@@ -5,19 +5,10 @@ import mlx.core as mx
import mlx.nn as nn
from orchestration.standard_node import StandardNode
from networking.grpc.grpc_server import GRPCServer
from inference.inference_engine import MLXFixedShardInferenceEngine
from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
from inference.shard import Shard
from networking.grpc.grpc_discovery import GRPCDiscovery

class SimpleMLXModel(nn.Module):
    def __init__(self):
        super(SimpleMLXModel, self).__init__()
        self.linear = nn.Linear(10, 5)  # Example dimensions

    def forward(self, x):
        return self.linear(x)


# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
@@ -25,15 +16,19 @@ parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host"
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--model-id", type=str, default="mlx-community/Meta-Llama-3-8B-Instruct-4bit", help="Path to the model")
parser.add_argument("--n-layers", type=int, default=32, help="Number of layers in the model")
parser.add_argument("--start-layer", type=int, default=0, help="Start layer index")
parser.add_argument("--end-layer", type=int, default=31, help="End layer index")
args = parser.parse_args()

mlx_model = SimpleMLXModel()
inference_engine = MLXFixedShardInferenceEngine(mlx_model, shard=Shard(model_id="test", n_layers=32, start_layer=0, end_layer=31))
inference_engine = MLXFixedShardInferenceEngine(args.model_id, shard=Shard(model_id=args.model_id, n_layers=args.n_layers, start_layer=args.start_layer, end_layer=args.end_layer))
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server


async def shutdown(signal, loop):
    """Gracefully shutdown the server and close the asyncio loop."""
    print(f"Received exit signal {signal.name}...")
@@ -56,10 +51,6 @@ async def main():

    await node.start()

    await asyncio.sleep(5)
    print("Sending reset shard request")
    await node.peers[0].reset_shard(f"regards from {node.id}")

    await asyncio.Event().wait()

if __name__ == "__main__":

networking/grpc/grpc_peer_handle.py
@@ -7,6 +7,7 @@ from . import node_service_pb2
from . import node_service_pb2_grpc

from ..peer_handle import PeerHandle
from inference.shard import Shard

class GRPCPeerHandle(PeerHandle):
    def __init__(self, id: str, address: str):
@@ -23,25 +24,38 @@ class GRPCPeerHandle(PeerHandle):
    async def disconnect(self):
        await self.channel.close()

    async def send_prompt(self, prompt: str) -> None:
        request = node_service_pb2.PromptRequest(prompt=prompt)
        await self.stub.SendPrompt(request)
    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
        response = await self.stub.SendPrompt(request)
        print(f"Sent prompt to {self.address}: {prompt}")

    async def send_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        if not response.tensor_data or not response.shape or not response.dtype:
            return None

        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

    async def send_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.array]:
        request = node_service_pb2.TensorRequest(
            tensor_data=tensor.tobytes(),
            shape=tensor.shape,
            dtype=str(tensor.dtype),
            shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
            tensor = node_service_pb2.Tensor(
                tensor_data=tensor.tobytes(),
                shape=tensor.shape,
                dtype=str(tensor.dtype)
            ),
            target=target
        )
        await self.stub.SendTensor(request)
        response = await self.stub.SendTensor(request)
        if target:
            print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
        else:
            print(f"Sent tensor to {self.address}: shape {tensor.shape}")

    async def reset_shard(self, shard_id: str) -> None:
        request = node_service_pb2.ResetShardRequest(shard_id=shard_id)
        if not response.tensor_data or not response.shape or not response.dtype:
            return None

        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

    async def reset_shard(self, shard: Shard) -> None:
        request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
        await self.stub.ResetShard(request)
        print(f"Reset shard {shard_id} on {self.address}")
        print(f"Reset shard {shard} on {self.address}")

networking/grpc/grpc_server.py
@@ -4,6 +4,7 @@ import numpy as np

from . import node_service_pb2
from . import node_service_pb2_grpc
from inference.shard import Shard

from orchestration import Node
@@ -28,30 +29,24 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
        print("Server stopped")

    async def SendPrompt(self, request, context):
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        prompt = request.prompt
        target = request.target if request.HasField('target') else None
        if target and target != self.node.node_id:
            await self.node.process_prompt(prompt, target)
        else:
            # Process the prompt locally
            # You'd need to implement this method in the Node class
            await self.node.process_prompt(prompt)
        return node_service_pb2.Empty()
        result = await self.node.process_prompt(shard, prompt, target)
        tensor_data = result.tobytes() if result is not None else None
        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))

    async def SendTensor(self, request, context):
        tensor = np.frombuffer(request.tensor_data, dtype=np.dtype(request.dtype)).reshape(request.shape)
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
        target = request.target if request.HasField('target') else None
        if target and target != self.node.node_id:
            await self.node.process_tensor(tensor, target)
        else:
            # Process the tensor locally
            await self.node.inference_strategy.process_inference(tensor)
        return node_service_pb2.Empty()
        result = await self.node.process_tensor(shard, tensor, target)
        print("SendTensor tensor result", result)
        tensor_data = result.tobytes() if result is not None else None
        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))

    async def ResetShard(self, request, context):
        print(f"Received ResetShard request: {request}")
        # TODO
        # shard_id = request.shard_id
        # You'd need to implement this method in the Node class
        # await self.node.reset_shard(shard_id)
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        print(f"Received ResetShard request: {shard}")
        await self.node.reset_shard(shard)
        return node_service_pb2.Empty()

node_service.proto
@@ -3,25 +3,38 @@ syntax = "proto3";
package node_service;

service NodeService {
  rpc SendPrompt (PromptRequest) returns (Empty) {}
  rpc SendTensor (TensorRequest) returns (Empty) {}
  rpc SendPrompt (PromptRequest) returns (Tensor) {}
  rpc SendTensor (TensorRequest) returns (Tensor) {}
  rpc ResetShard (ResetShardRequest) returns (Empty) {}
}

message Shard {
  string model_id = 1;
  int32 start_layer = 2;
  int32 end_layer = 3;
  int32 n_layers = 4;
}

message PromptRequest {
  string prompt = 1;
  optional string target = 2;
  Shard shard = 1;
  string prompt = 2;
  optional string target = 3;
}

message TensorRequest {
  Shard shard = 1;
  Tensor tensor = 2;
  optional string target = 3;
}

message Tensor {
  bytes tensor_data = 1;
  repeated int32 shape = 2;
  string dtype = 3;
  optional string target = 4;
}

message ResetShardRequest {
  string shard_id = 1;
  Shard shard = 1;
}

message Empty {}
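The Tensor message carries raw bytes plus shape and dtype; GRPCPeerHandle and GRPCServer rebuild the array with np.frombuffer. A self-contained round-trip sketch using numpy only:

```python
# Round-trip sketch of the Tensor fields (tensor_data, shape, dtype).
import numpy as np

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
tensor_data, shape, dtype = arr.tobytes(), list(arr.shape), str(arr.dtype)

restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert (restored == arr).all()
```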

networking/grpc/node_service_pb2.py
@@ -14,21 +14,25 @@ _sym_db = _symbol_database.Default()



DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"?\n\rPromptRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x13\n\x06target\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"b\n\rTensorRequest\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\x12\x13\n\x06target\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"%\n\x11ResetShardRequest\x12\x10\n\x08shard_id\x18\x01 \x01(\t\"\x07\n\x05\x45mpty2\xd7\x01\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"c\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_PROMPTREQUEST']._serialized_start=36
  _globals['_PROMPTREQUEST']._serialized_end=99
  _globals['_TENSORREQUEST']._serialized_start=101
  _globals['_TENSORREQUEST']._serialized_end=199
  _globals['_RESETSHARDREQUEST']._serialized_start=201
  _globals['_RESETSHARDREQUEST']._serialized_end=238
  _globals['_EMPTY']._serialized_start=240
  _globals['_EMPTY']._serialized_end=247
  _globals['_NODESERVICE']._serialized_start=250
  _globals['_NODESERVICE']._serialized_end=465
  _globals['_SHARD']._serialized_start=36
  _globals['_SHARD']._serialized_end=119
  _globals['_PROMPTREQUEST']._serialized_start=121
  _globals['_PROMPTREQUEST']._serialized_end=220
  _globals['_TENSORREQUEST']._serialized_start=222
  _globals['_TENSORREQUEST']._serialized_end=343
  _globals['_TENSOR']._serialized_start=345
  _globals['_TENSOR']._serialized_end=404
  _globals['_RESETSHARDREQUEST']._serialized_start=406
  _globals['_RESETSHARDREQUEST']._serialized_end=461
  _globals['_EMPTY']._serialized_start=463
  _globals['_EMPTY']._serialized_end=470
  _globals['_NODESERVICE']._serialized_start=473
  _globals['_NODESERVICE']._serialized_end=690
# @@protoc_insertion_point(module_scope)

networking/grpc/node_service_pb2_grpc.py
@@ -42,12 +42,12 @@ class NodeServiceStub(object):
        self.SendPrompt = channel.unary_unary(
                '/node_service.NodeService/SendPrompt',
                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
                response_deserializer=node__service__pb2.Empty.FromString,
                response_deserializer=node__service__pb2.Tensor.FromString,
                _registered_method=True)
        self.SendTensor = channel.unary_unary(
                '/node_service.NodeService/SendTensor',
                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                response_deserializer=node__service__pb2.Empty.FromString,
                response_deserializer=node__service__pb2.Tensor.FromString,
                _registered_method=True)
        self.ResetShard = channel.unary_unary(
                '/node_service.NodeService/ResetShard',
@@ -83,12 +83,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
            'SendPrompt': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPrompt,
                    request_deserializer=node__service__pb2.PromptRequest.FromString,
                    response_serializer=node__service__pb2.Empty.SerializeToString,
                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'SendTensor': grpc.unary_unary_rpc_method_handler(
                    servicer.SendTensor,
                    request_deserializer=node__service__pb2.TensorRequest.FromString,
                    response_serializer=node__service__pb2.Empty.SerializeToString,
                    response_serializer=node__service__pb2.Tensor.SerializeToString,
            ),
            'ResetShard': grpc.unary_unary_rpc_method_handler(
                    servicer.ResetShard,
@@ -122,7 +122,7 @@ class NodeService(object):
            target,
            '/node_service.NodeService/SendPrompt',
            node__service__pb2.PromptRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,
@@ -149,7 +149,7 @@ class NodeService(object):
            target,
            '/node_service.NodeService/SendTensor',
            node__service__pb2.TensorRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            node__service__pb2.Tensor.FromString,
            options,
            channel_credentials,
            insecure,

networking/peer_handle.py
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Optional
import numpy as np
from inference.shard import Shard

class PeerHandle(ABC):
    def id(self) -> str:
@@ -14,13 +16,13 @@
        pass

    @abstractmethod
    async def send_prompt(self, prompt: str) -> None:
    async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
        pass

    @abstractmethod
    async def send_tensor(self, tensor: Any) -> None:
    async def send_tensor(self, shard: Shard, tensor: np.array) -> Optional[np.array]:
        pass

    @abstractmethod
    async def reset_shard(self, shard_id: str) -> None:
    async def reset_shard(self, shard: Shard) -> None:
        pass
@@ -1,10 +1,11 @@
from typing import Optional
import numpy as np
from abc import ABC, abstractmethod
from inference.shard import Shard

class Node(ABC):
    @abstractmethod
    def start(self) -> None:
    def start(self, wait_for_peers: int = 0) -> None:
        pass

    @abstractmethod
@@ -12,13 +13,13 @@
        pass

    @abstractmethod
    def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
    def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
        pass

    @abstractmethod
    def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
    def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> None:
        pass

    @abstractmethod
    def reset_shard(self, shard_id: str) -> None:
    def reset_shard(self, shard: Shard) -> None:
        pass

orchestration/standard_node.py
@@ -13,10 +13,10 @@ class StandardNode(Node):
        self.peers: List[PeerHandle] = {}
        self.ring_order: List[str] = []

    async def start(self) -> None:
    async def start(self, wait_for_peers: int = 0) -> None:
        await self.server.start()
        await self.discovery.start()
        self.peers = await self.discovery.discover_peers()
        self.peers = await self.discovery.discover_peers(wait_for_peers)
        print(f"Starting with the following peers: {self.peers}")
        print("Connecting to peers...")
        for peer in self.peers:
@@ -27,19 +27,35 @@
        await self.discovery.stop()
        await self.server.stop()

    async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
        result = await self.inference_engine.process_shard(tensor)

    async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
        print("Process prompt", shard, prompt, target)
        result = await self.inference_engine.infer_prompt(shard, prompt)
        # Implement prompt processing logic
        print(f"Got result from prompt: {prompt}. Result: {result}")
        # You might want to initiate inference here
        if target:
            if not filter(lambda p: p.id() == target, self.peers):
            target_peer = next((p for p in self.peers if p.id() == target), None)
            if not target_peer:
                raise ValueError(f"Peer {target} not found")

            await self.peers[target].send_tensor(result)
            await target_peer.send_tensor(result)

    async def process_prompt(self, prompt: str) -> None:
        return result

    async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
        print("Process tensor", shard, tensor)
        result = await self.inference_engine.infer_shard(shard, tensor)
        # Implement prompt processing logic
        print(f"Processing prompt: {prompt}")
        # You might want to initiate inference here
        print(f"Got result from prompt: {len(tensor)}. Result: {result}")

        if target:
            target_peer = next((p for p in self.peers if p.id() == target), None)
            if not target_peer:
                raise ValueError(f"Peer {target} not found")

            await target_peer.send_tensor(result)

        return result

    async def reset_shard(self, shard: Shard) -> None:
        # Implement shard reset logic