MLX sharded implementation with an example of distributed inference

Alex Cheema
2024-06-24 19:35:57 +01:00
parent a21f59ff45
commit 563dcb56b0
16 changed files with 775 additions and 105 deletions

example_user.py (new file)

@@ -0,0 +1,75 @@
# In this example, a user runs a home cluster of two nodes, each holding one shard of the model.
# The user sends a prompt to the cluster, the shards run inference in sequence,
# and the sampled tokens are returned to the user.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from typing import List
import asyncio
import argparse
path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)
peers: List[PeerHandle] = [
GRPCPeerHandle(
"node1",
"localhost:8080",
),
GRPCPeerHandle(
"node2",
"localhost:8081",
)
]
shards: List[Shard] = [
# Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
# Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
]
async def run_prompt(prompt: str):
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for peer, shard in zip(peers, shards):
await peer.connect()
await peer.reset_shard(shard)
tokens = []
last_output = prompt
for _ in range(20):
for peer, shard in zip(peers, shards):
if isinstance(last_output, str):
last_output = await peer.send_prompt(shard, last_output)
print("prompt output:", last_output)
else:
last_output = await peer.send_tensor(shard, last_output)
print("tensor output:", last_output)
if not last_output:
break
tokens.append(last_output.item())
print(tokenizer.decode(tokens))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run prompt")
parser.add_argument("--prompt", type=str, help="The prompt to run")
args = parser.parse_args()
asyncio.run(run_prompt(args.prompt))


@@ -9,23 +9,9 @@ class InferenceEngine(ABC):
async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
pass
async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
pass
@abstractmethod
async def reset_shard(self, shard: Shard):
pass
class MLXFixedShardInferenceEngine(InferenceEngine):
def __init__(self, model: nn.Module, shard: Shard):
self.model = model
self.shard = shard
async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
output_data = self.model.process(input_data)
print("Processed data through model shard")
return output_data
async def reset_shard(self, shard: Shard):
# TODO
print(f"Resetting shard: {shard}")


@@ -0,0 +1,244 @@
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, create_additive_causal_mask
from ...shard import Shard
@dataclass
class NormalModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@dataclass
class ModelArgs(NormalModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__() # Ensure parent initializations are respected
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.shard.n_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
if self.args.shard.is_first_layer():
h = self.embed_tokens(inputs)
else:
h = inputs
mask = None
if h.shape[1] > 1:
mask = create_additive_causal_mask(
h.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
if self.args.shard.is_last_layer():
return self.norm(h)
else:
return h
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = LlamaModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
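
For context, a short hedged sketch of how ModelArgs picks up the shard that load_model_shard injects under config["shard"]: a plain dict is converted into a Shard by __post_init__. The numeric values below are illustrative placeholders, not taken from a real config.

# Hedged sketch; the config values are illustrative.
args = ModelArgs.from_dict({
    "model_type": "sharded_llama",
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-5,
    "vocab_size": 128256,
    "shard": {"model_id": "Meta-Llama-3-8B-Instruct-4bit",
              "start_layer": 0, "end_layer": 15, "n_layers": 32},
})
assert isinstance(args.shard, Shard)   # dict converted in __post_init__
assert args.shard.is_first_layer()     # this shard owns the embedding step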


@@ -0,0 +1,37 @@
import mlx.nn as nn
import numpy as np
import mlx.core as mx
from ..inference_engine import InferenceEngine
from .sharded_model import StatefulShardedModel
from .sharded_utils import load_shard
from ..shard import Shard
class MLXFixedShardInferenceEngine(InferenceEngine):
def __init__(self, model_path: str, shard: Shard):
print("initializing fixed shard inference", shard)
self.shard = shard
model_shard, self.tokenizer = load_shard(model_path, shard)
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
output_data = self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt)))
return np.array(output_data)
async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
print("infer_shard", shard, input_data)
output_data = self.stateful_sharded_model.step(mx.array(input_data))
return np.array(output_data)
async def reset_shard(self, shard: Shard):
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
print(f"Resetting shard: {shard}")
self.stateful_sharded_model.reset()
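
A minimal single-node sketch of using this engine, assuming the repo below can be downloaded via get_model_path; with a shard covering all 32 layers, infer_prompt runs the full model and returns the sampled token as a numpy array.

# Hedged local sketch (no networking); repo id and layer range are illustrative.
import asyncio
from inference.shard import Shard
from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine

async def local_demo():
    repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
    shard = Shard(model_id=repo, start_layer=0, end_layer=31, n_layers=32)
    engine = MLXFixedShardInferenceEngine(repo, shard=shard)
    token = await engine.infer_prompt(shard, "hello")  # full model -> sampled token
    print(token)

asyncio.run(local_demo())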


@@ -0,0 +1,56 @@
from typing import Dict, Generator, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import KVCache
from mlx_lm.sample_utils import top_p_sampling
from ..shard import Shard
class StatefulShardedModel:
def __init__(self, shard: Shard, model: nn.Module):
self.shard = shard
self.model = model
self.reset()
def step(
self,
x,
temp: float = 0.0,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
def sample(logits: mx.array) -> Tuple[mx.array, float]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
else:
token = mx.random.categorical(logits * (1 / temp))
return token
y = x
output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
if self.shard.is_last_layer():
logits = output[:, -1, :]
y = sample(logits)
return y
else:
return output
def reset(self):
kv_heads = (
[self.model.n_kv_heads] * len(self.model.layers)
if isinstance(self.model.n_kv_heads, int)
else self.model.n_kv_heads
)
self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
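
Because step returns sampled tokens on the last shard and hidden states otherwise, a single-shard generation loop reduces to feeding the sampled token back in. A hedged sketch, assuming shard, model and tokenizer come from load_shard with start_layer=0 and end_layer=n_layers-1:

# Hedged sketch; `shard`, `model` and `tokenizer` are assumed to come from load_shard().
import mlx.core as mx

generator = StatefulShardedModel(shard, model)
y = generator.step(mx.array(tokenizer.encode("hello")))  # prompt pass -> first token
tokens = [y.item()]
for _ in range(20):
    y = generator.step(y)          # feed the sampled token back through the model
    tokens.append(y.item())
print(tokenizer.decode(tokens))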


@@ -0,0 +1,230 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
import glob
import importlib
import json
import logging
from pathlib import Path
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from mlx_lm.tuner.utils import apply_lora_layers
from ..shard import Shard
class ModelNotFoundError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
}
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"inference.mlx.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
return arch.Model, arch.ModelArgs
def load_config(model_path: Path) -> dict:
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
return config
def load_model_shard(
model_path: Path,
shard: Shard,
lazy: bool = False,
model_config: dict = {},
) -> nn.Module:
"""
Load and initialize the model from a given path.
Args:
model_path (Path): The path to load the model from.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
model_config(dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
Returns:
nn.Module: The loaded and initialized model.
Raises:
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
config = load_config(model_path)
config.update(model_config)
# TODO hack
config["model_type"] = f"sharded_{config['model_type']}"
config["shard"] = {
"model_id": model_path.name,
"start_layer": shard.start_layer,
"end_layer": shard.end_layer,
"n_layers": shard.n_layers
}
weight_files = glob.glob(str(model_path / "model*.safetensors"))
if not weight_files:
# Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
def class_predicate(p, m):
if not hasattr(m, "to_quantized"):
return False
return f"{p}.scales" in weights
nn.quantize(
model,
**quantization,
class_predicate=class_predicate,
)
filtered_weights = {}
for k, v in weights.items():
if k.startswith("model.layers."):
layer_num = int(k.split('.')[2])
if shard.start_layer <= layer_num <= shard.end_layer:
new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
filtered_weights[new_key] = v
else:
filtered_weights[k] = v
weights = filtered_weights
model.load_weights(list(weights.items()), strict=False)
if not lazy:
mx.eval(model.parameters())
model.eval()
return model
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
Args:
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
Returns:
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
)
except RepositoryNotFoundError:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
"Please make sure you specified the local path or Hugging Face"
" repo id correctly.\nIf you are trying to access a private or"
" gated Hugging Face repo, make sure you are authenticated:\n"
"https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
) from None
return model_path
def load_shard(
path_or_hf_repo: str,
shard: Shard,
tokenizer_config={},
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
Args:
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config(dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = get_model_path(path_or_hf_repo)
model = load_model_shard(model_path, shard, lazy, model_config)
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer
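
A hedged sketch of loading both halves of a 32-layer model in a single process, just to show how the shard's layer range selects weights: keys for layers outside [start_layer, end_layer] are dropped before load_weights and the remaining layer indices are renumbered to start at zero. The repo id and split are illustrative.

# Hedged sketch; repo id and layer split are illustrative.
from inference.shard import Shard
from inference.mlx.sharded_utils import load_shard

repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
front, tokenizer = load_shard(repo, Shard(model_id=repo, start_layer=0, end_layer=15, n_layers=32))
back, _ = load_shard(repo, Shard(model_id=repo, start_layer=16, end_layer=31, n_layers=32))
# front embeds the prompt tokens (is_first_layer); back applies the final norm and
# output projection (is_last_layer); hidden states from front are meant to be fed into back.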


@@ -3,6 +3,12 @@ from dataclasses import dataclass
@dataclass
class Shard:
model_id: str
n_layers: int
start_layer: int
end_layer: int
n_layers: int
def is_first_layer(self) -> bool:
return self.start_layer == 0
def is_last_layer(self) -> bool:
return self.end_layer == self.n_layers - 1
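
Since a Shard is a contiguous [start_layer, end_layer] slice of an n_layers-deep model, splitting a model across nodes is just a partition of the layer range. A small illustrative helper (not part of this commit) that builds an even split, giving the remainder to the last node:

# Illustrative helper, not part of this commit.
from typing import List
from inference.shard import Shard

def make_shards(model_id: str, n_layers: int, n_nodes: int) -> List[Shard]:
    per_node = n_layers // n_nodes
    shards, start = [], 0
    for i in range(n_nodes):
        end = n_layers - 1 if i == n_nodes - 1 else start + per_node - 1
        shards.append(Shard(model_id=model_id, start_layer=start, end_layer=end, n_layers=n_layers))
        start = end + 1
    return shards

# make_shards("mlx-community/Meta-Llama-3-8B-Instruct-4bit", n_layers=32, n_nodes=2)
# -> shards covering layers 0-15 and 16-31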

main.py

@@ -5,19 +5,10 @@ import mlx.core as mx
import mlx.nn as nn
from orchestration.standard_node import StandardNode
from networking.grpc.grpc_server import GRPCServer
from inference.inference_engine import MLXFixedShardInferenceEngine
from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
from inference.shard import Shard
from networking.grpc.grpc_discovery import GRPCDiscovery
class SimpleMLXModel(nn.Module):
def __init__(self):
super(SimpleMLXModel, self).__init__()
self.linear = nn.Linear(10, 5) # Example dimensions
def forward(self, x):
return self.linear(x)
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
@@ -25,15 +16,19 @@ parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host"
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--model-id", type=str, default="mlx-community/Meta-Llama-3-8B-Instruct-4bit", help="Path to the model")
parser.add_argument("--n-layers", type=int, default=32, help="Number of layers in the model")
parser.add_argument("--start-layer", type=int, default=0, help="Start layer index")
parser.add_argument("--end-layer", type=int, default=31, help="End layer index")
args = parser.parse_args()
mlx_model = SimpleMLXModel()
inference_engine = MLXFixedShardInferenceEngine(mlx_model, shard=Shard(model_id="test", n_layers=32, start_layer=0, end_layer=31))
inference_engine = MLXFixedShardInferenceEngine(args.model_id, shard=Shard(model_id=args.model_id, n_layers=args.n_layers, start_layer=args.start_layer, end_layer=args.end_layer))
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
async def shutdown(signal, loop):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -56,10 +51,6 @@ async def main():
await node.start()
await asyncio.sleep(5)
print("Sending reset shard request")
await node.peers[0].reset_shard(f"regards from {node.id}")
await asyncio.Event().wait()
if __name__ == "__main__":
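
For reference, a hedged sketch of launching the two nodes that example_user.py talks to, using the argparse flags above; the layer split mirrors example_user.py and the discovery flags are left at their defaults. This is illustrative, not a script shipped in this commit.

# Illustrative launcher; flag values are assumptions based on main.py's argparse.
import subprocess, sys

nodes = [("node1", 8080, 0, 30), ("node2", 8081, 31, 31)]
procs = [
    subprocess.Popen([
        sys.executable, "main.py",
        "--node-id", node_id,
        "--node-port", str(port),
        "--start-layer", str(start),
        "--end-layer", str(end),
    ])
    for node_id, port, start, end in nodes
]
for p in procs:
    p.wait()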


@@ -7,6 +7,7 @@ from . import node_service_pb2
from . import node_service_pb2_grpc
from ..peer_handle import PeerHandle
from inference.shard import Shard
class GRPCPeerHandle(PeerHandle):
def __init__(self, id: str, address: str):
@@ -23,25 +24,38 @@ class GRPCPeerHandle(PeerHandle):
async def disconnect(self):
await self.channel.close()
async def send_prompt(self, prompt: str) -> None:
request = node_service_pb2.PromptRequest(prompt=prompt)
await self.stub.SendPrompt(request)
async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
response = await self.stub.SendPrompt(request)
print(f"Sent prompt to {self.address}: {prompt}")
async def send_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
if not response.tensor_data or not response.shape or not response.dtype:
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
tensor_data=tensor.tobytes(),
shape=tensor.shape,
dtype=str(tensor.dtype),
shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
tensor = node_service_pb2.Tensor(
tensor_data=tensor.tobytes(),
shape=tensor.shape,
dtype=str(tensor.dtype)
),
target=target
)
await self.stub.SendTensor(request)
response = await self.stub.SendTensor(request)
if target:
print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
else:
print(f"Sent tensor to {self.address}: shape {tensor.shape}")
async def reset_shard(self, shard_id: str) -> None:
request = node_service_pb2.ResetShardRequest(shard_id=shard_id)
if not response.tensor_data or not response.shape or not response.dtype:
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def reset_shard(self, shard: Shard) -> None:
request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
await self.stub.ResetShard(request)
print(f"Reset shard {shard_id} on {self.address}")
print(f"Reset shard {shard} on {self.address}")


@@ -4,6 +4,7 @@ import numpy as np
from . import node_service_pb2
from . import node_service_pb2_grpc
from inference.shard import Shard
from orchestration import Node
@@ -28,30 +29,24 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
print("Server stopped")
async def SendPrompt(self, request, context):
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
prompt = request.prompt
target = request.target if request.HasField('target') else None
if target and target != self.node.node_id:
await self.node.process_prompt(prompt, target)
else:
# Process the prompt locally
# You'd need to implement this method in the Node class
await self.node.process_prompt(prompt)
return node_service_pb2.Empty()
result = await self.node.process_prompt(shard, prompt, target)
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
async def SendTensor(self, request, context):
tensor = np.frombuffer(request.tensor_data, dtype=np.dtype(request.dtype)).reshape(request.shape)
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
target = request.target if request.HasField('target') else None
if target and target != self.node.node_id:
await self.node.process_tensor(tensor, target)
else:
# Process the tensor locally
await self.node.inference_strategy.process_inference(tensor)
return node_service_pb2.Empty()
result = await self.node.process_tensor(shard, tensor, target)
print("SendTensor tensor result", result)
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
async def ResetShard(self, request, context):
print(f"Received ResetShard request: {request}")
# TODO
# shard_id = request.shard_id
# You'd need to implement this method in the Node class
# await self.node.reset_shard(shard_id)
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
print(f"Received ResetShard request: {shard}")
await self.node.reset_shard(shard)
return node_service_pb2.Empty()


@@ -3,25 +3,38 @@ syntax = "proto3";
package node_service;
service NodeService {
rpc SendPrompt (PromptRequest) returns (Empty) {}
rpc SendTensor (TensorRequest) returns (Empty) {}
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc ResetShard (ResetShardRequest) returns (Empty) {}
}
message Shard {
string model_id = 1;
int32 start_layer = 2;
int32 end_layer = 3;
int32 n_layers = 4;
}
message PromptRequest {
string prompt = 1;
optional string target = 2;
Shard shard = 1;
string prompt = 2;
optional string target = 3;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string target = 3;
}
message Tensor {
bytes tensor_data = 1;
repeated int32 shape = 2;
string dtype = 3;
optional string target = 4;
}
message ResetShardRequest {
string shard_id = 1;
Shard shard = 1;
}
message Empty {}
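
The generated node_service_pb2*.py modules below must stay in sync with this .proto. A hedged sketch of regenerating them with grpcio-tools, run from the directory that contains node_service.proto (assumed to be networking/grpc):

# Hedged sketch; requires `pip install grpcio-tools`.
from grpc_tools import protoc

protoc.main([
    "grpc_tools.protoc",
    "-I.",
    "--python_out=.",
    "--grpc_python_out=.",
    "node_service.proto",
])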


@@ -14,21 +14,25 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"?\n\rPromptRequest\x12\x0e\n\x06prompt\x18\x01 \x01(\t\x12\x13\n\x06target\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"b\n\rTensorRequest\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\x12\x13\n\x06target\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"%\n\x11ResetShardRequest\x12\x10\n\x08shard_id\x18\x01 \x01(\t\"\x07\n\x05\x45mpty2\xd7\x01\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"c\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_PROMPTREQUEST']._serialized_start=36
_globals['_PROMPTREQUEST']._serialized_end=99
_globals['_TENSORREQUEST']._serialized_start=101
_globals['_TENSORREQUEST']._serialized_end=199
_globals['_RESETSHARDREQUEST']._serialized_start=201
_globals['_RESETSHARDREQUEST']._serialized_end=238
_globals['_EMPTY']._serialized_start=240
_globals['_EMPTY']._serialized_end=247
_globals['_NODESERVICE']._serialized_start=250
_globals['_NODESERVICE']._serialized_end=465
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
_globals['_PROMPTREQUEST']._serialized_start=121
_globals['_PROMPTREQUEST']._serialized_end=220
_globals['_TENSORREQUEST']._serialized_start=222
_globals['_TENSORREQUEST']._serialized_end=343
_globals['_TENSOR']._serialized_start=345
_globals['_TENSOR']._serialized_end=404
_globals['_RESETSHARDREQUEST']._serialized_start=406
_globals['_RESETSHARDREQUEST']._serialized_end=461
_globals['_EMPTY']._serialized_start=463
_globals['_EMPTY']._serialized_end=470
_globals['_NODESERVICE']._serialized_start=473
_globals['_NODESERVICE']._serialized_end=690
# @@protoc_insertion_point(module_scope)


@@ -42,12 +42,12 @@ class NodeServiceStub(object):
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.ResetShard = channel.unary_unary(
'/node_service.NodeService/ResetShard',
@@ -83,12 +83,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'ResetShard': grpc.unary_unary_rpc_method_handler(
servicer.ResetShard,
@@ -122,7 +122,7 @@ class NodeService(object):
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Empty.FromString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -149,7 +149,7 @@ class NodeService(object):
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Empty.FromString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,


@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Optional
import numpy as np
from inference.shard import Shard
class PeerHandle(ABC):
def id(self) -> str:
@@ -14,13 +16,13 @@ class PeerHandle(ABC):
pass
@abstractmethod
async def send_prompt(self, prompt: str) -> None:
async def send_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
pass
@abstractmethod
async def send_tensor(self, tensor: Any) -> None:
async def send_tensor(self, shard: Shard, tensor: np.array) -> Optional[np.array]:
pass
@abstractmethod
async def reset_shard(self, shard_id: str) -> None:
async def reset_shard(self, shard: Shard) -> None:
pass


@@ -1,10 +1,11 @@
from typing import Optional
import numpy as np
from abc import ABC, abstractmethod
from inference.shard import Shard
class Node(ABC):
@abstractmethod
def start(self) -> None:
def start(self, wait_for_peers: int = 0) -> None:
pass
@abstractmethod
@@ -12,13 +13,13 @@ class Node(ABC):
pass
@abstractmethod
def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
pass
@abstractmethod
def process_prompt(self, prompt: str, target: Optional[str] = None) -> None:
def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> None:
pass
@abstractmethod
def reset_shard(self, shard_id: str) -> None:
def reset_shard(self, shard: Shard) -> None:
pass


@@ -13,10 +13,10 @@ class StandardNode(Node):
self.peers: List[PeerHandle] = {}
self.ring_order: List[str] = []
async def start(self) -> None:
async def start(self, wait_for_peers: int = 0) -> None:
await self.server.start()
await self.discovery.start()
self.peers = await self.discovery.discover_peers()
self.peers = await self.discovery.discover_peers(wait_for_peers)
print(f"Starting with the following peers: {self.peers}")
print("Connecting to peers...")
for peer in self.peers:
@@ -27,19 +27,35 @@ class StandardNode(Node):
await self.discovery.stop()
await self.server.stop()
async def process_tensor(self, tensor: np.ndarray, target: Optional[str] = None) -> None:
result = await self.inference_engine.process_shard(tensor)
async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
print("Process prompt", shard, prompt, target)
result = await self.inference_engine.infer_prompt(shard, prompt)
# Implement prompt processing logic
print(f"Got result from prompt: {prompt}. Result: {result}")
# You might want to initiate inference here
if target:
if not filter(lambda p: p.id() == target, self.peers):
target_peer = next((p for p in self.peers if p.id() == target), None)
if not target_peer:
raise ValueError(f"Peer {target} not found")
await self.peers[target].send_tensor(result)
await target_peer.send_tensor(result)
async def process_prompt(self, prompt: str) -> None:
return result
async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
print("Process tensor", shard, tensor)
result = await self.inference_engine.infer_shard(shard, tensor)
# Implement prompt processing logic
print(f"Processing prompt: {prompt}")
# You might want to initiate inference here
print(f"Got result from prompt: {len(tensor)}. Result: {result}")
if target:
target_peer = next((p for p in self.peers if p.id() == target), None)
if not target_peer:
raise ValueError(f"Peer {target} not found")
await target_peer.send_tensor(result)
return result
async def reset_shard(self, shard: Shard) -> None:
# Implement shard reset logic