mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
dynamically assign shards to nodes deterministically weighted by memory
This commit is contained in:
@@ -16,18 +16,12 @@ model_path = get_model_path(path_or_hf_repo)
|
||||
tokenizer_config = {}
|
||||
tokenizer = load_tokenizer(model_path, tokenizer_config)
|
||||
|
||||
peers: List[PeerHandle] = [
|
||||
GRPCPeerHandle(
|
||||
"node1",
|
||||
"localhost:8080",
|
||||
DeviceCapabilities(model="test1", chip="test1", memory=10000)
|
||||
),
|
||||
]
|
||||
shards: List[Shard] = [
|
||||
Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
|
||||
# Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
|
||||
# Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
|
||||
]
|
||||
peer = GRPCPeerHandle(
|
||||
"node1",
|
||||
"localhost:8080",
|
||||
DeviceCapabilities(model="test1", chip="test1", memory=10000)
|
||||
)
|
||||
shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)
|
||||
|
||||
async def run_prompt(prompt: str):
|
||||
if tokenizer.chat_template is None:
|
||||
@@ -41,28 +35,11 @@ async def run_prompt(prompt: str):
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
for peer, shard in zip(peers, shards):
|
||||
await peer.connect()
|
||||
await peer.reset_shard(shard)
|
||||
await peer.connect()
|
||||
await peer.reset_shard(shard)
|
||||
|
||||
tokens = []
|
||||
last_output = prompt
|
||||
|
||||
for _ in range(20):
|
||||
for peer, shard in zip(peers, shards):
|
||||
if isinstance(last_output, str):
|
||||
last_output = await peer.send_prompt(shard, last_output)
|
||||
print("prompt output:", last_output)
|
||||
else:
|
||||
last_output = await peer.send_tensor(shard, last_output)
|
||||
print("tensor output:", last_output)
|
||||
|
||||
if not last_output:
|
||||
break
|
||||
|
||||
tokens.append(last_output.item())
|
||||
|
||||
print(tokenizer.decode(tokens))
|
||||
result = await peer.send_prompt(shard, prompt)
|
||||
print(tokenizer.decode(result))
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run prompt")
|
||||
|
||||
@@ -6,11 +6,11 @@ from .shard import Shard
|
||||
|
||||
class InferenceEngine(ABC):
|
||||
@abstractmethod
|
||||
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -12,21 +12,20 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
|
||||
model_shard, self.tokenizer = load_shard(model_path, shard)
|
||||
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
|
||||
|
||||
async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
|
||||
if shard != self.shard:
|
||||
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
|
||||
|
||||
output_data = self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt)))
|
||||
return np.array(output_data)
|
||||
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
|
||||
print(f"output_data size: {output_data.size}, output_data: {output_data}")
|
||||
return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
||||
|
||||
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
|
||||
if shard != self.shard:
|
||||
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
|
||||
|
||||
print("infer_tensor", shard, input_data)
|
||||
|
||||
output_data = self.stateful_sharded_model.step(mx.array(input_data))
|
||||
return np.array(output_data)
|
||||
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
|
||||
return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
||||
|
||||
async def reset_shard(self, shard: Shard):
|
||||
if shard != self.shard:
|
||||
@@ -34,3 +33,31 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
|
||||
|
||||
print(f"Resetting shard: {shard}")
|
||||
self.stateful_sharded_model.reset()
|
||||
|
||||
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
def __init__(self):
|
||||
self.shard = None
|
||||
|
||||
async def infer_prompt(self, shard: Shard, prompt: str) -> (np.ndarray, bool):
|
||||
await self.ensure_shard(shard)
|
||||
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
|
||||
return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
||||
|
||||
async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, bool):
|
||||
await self.ensure_shard(shard)
|
||||
output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
|
||||
return output_data, output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
|
||||
|
||||
async def reset_shard(self, shard: Shard):
|
||||
await self.ensure_shard(shard)
|
||||
|
||||
print(f"Resetting shard: {shard}")
|
||||
self.stateful_sharded_model.reset()
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard:
|
||||
return
|
||||
|
||||
model_shard, self.tokenizer = load_shard(shard.model_id, shard)
|
||||
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
|
||||
self.shard = shard
|
||||
66
main_dynamic.py
Normal file
66
main_dynamic.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import signal
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from typing import List
|
||||
from orchestration.standard_node import StandardNode
|
||||
from networking.grpc.grpc_server import GRPCServer
|
||||
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from inference.shard import Shard
|
||||
from networking.grpc.grpc_discovery import GRPCDiscovery
|
||||
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
||||
parser.add_argument("--node-id", type=str, default="node1", help="Node ID")
|
||||
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
|
||||
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
|
||||
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
|
||||
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
|
||||
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
inference_engine = MLXDynamicShardInferenceEngine()
|
||||
def on_token(tokens: List[int]):
|
||||
if inference_engine.tokenizer:
|
||||
print(inference_engine.tokenizer.decode(tokens))
|
||||
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
|
||||
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), on_token=on_token)
|
||||
server = GRPCServer(node, args.node_host, args.node_port)
|
||||
node.server = server
|
||||
|
||||
|
||||
async def shutdown(signal, loop):
|
||||
"""Gracefully shutdown the server and close the asyncio loop."""
|
||||
print(f"Received exit signal {signal.name}...")
|
||||
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
[task.cancel() for task in server_tasks]
|
||||
print(f"Cancelling {len(server_tasks)} outstanding tasks")
|
||||
await asyncio.gather(*server_tasks, return_exceptions=True)
|
||||
await server.shutdown()
|
||||
loop.stop()
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Use a more direct approach to handle signals
|
||||
def handle_exit():
|
||||
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
|
||||
|
||||
for s in [signal.SIGINT, signal.SIGTERM]:
|
||||
loop.add_signal_handler(s, handle_exit)
|
||||
|
||||
await node.start(wait_for_peers=args.wait_for_peers)
|
||||
|
||||
await asyncio.Event().wait()
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@@ -93,6 +93,7 @@ class GRPCDiscovery(Discovery):
|
||||
try:
|
||||
data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
|
||||
message = json.loads(data.decode('utf-8'))
|
||||
print(f"received from peer {addr}: {message}")
|
||||
if message['type'] == 'discovery' and message['node_id'] != self.node_id:
|
||||
peer_id = message['node_id']
|
||||
peer_host = addr[0]
|
||||
@@ -107,7 +108,7 @@ class GRPCDiscovery(Discovery):
|
||||
async def _cleanup_peers(self):
|
||||
while True:
|
||||
current_time = time.time()
|
||||
timeout = 5 * self.broadcast_interval
|
||||
timeout = 15 * self.broadcast_interval
|
||||
peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
|
||||
for peer_id in peers_to_remove:
|
||||
del self.known_peers[peer_id]
|
||||
|
||||
@@ -16,7 +16,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
self.server = None
|
||||
|
||||
async def start(self) -> None:
|
||||
self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
|
||||
('grpc.max_metadata_size', 128*1024)
|
||||
])
|
||||
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
|
||||
listen_addr = f'{self.host}:{self.port}'
|
||||
self.server.add_insecure_port(listen_addr)
|
||||
|
||||
@@ -14,11 +14,11 @@ class Node(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
|
||||
async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def process_prompt(self, shard: Shard, prompt: str) -> None:
|
||||
async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Callable
|
||||
import numpy as np
|
||||
from networking import Discovery, PeerHandle, Server
|
||||
from inference.inference_engine import InferenceEngine, Shard
|
||||
@@ -9,7 +9,7 @@ from topology.partitioning_strategy import PartitioningStrategy
|
||||
from topology.partitioning_strategy import Partition
|
||||
|
||||
class StandardNode(Node):
|
||||
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None):
|
||||
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
|
||||
self.id = id
|
||||
self.inference_engine = inference_engine
|
||||
self.server = server
|
||||
@@ -18,6 +18,9 @@ class StandardNode(Node):
|
||||
self.peers: List[PeerHandle] = {}
|
||||
self.topology: Topology = Topology()
|
||||
self.device_capabilities = device_capabilities()
|
||||
self.buffered_token_output: List[int] = []
|
||||
self.on_token = on_token
|
||||
self.max_generate_tokens = max_generate_tokens
|
||||
|
||||
async def start(self, wait_for_peers: int = 0) -> None:
|
||||
await self.server.start()
|
||||
@@ -35,23 +38,32 @@ class StandardNode(Node):
|
||||
await self.discovery.stop()
|
||||
await self.server.stop()
|
||||
|
||||
async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
|
||||
print("Process prompt", shard, prompt)
|
||||
result = await self.inference_engine.infer_prompt(shard, prompt)
|
||||
print(f"Got result from prompt: {prompt}. Result: {result}")
|
||||
async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
|
||||
print("process prompt", shard, prompt)
|
||||
result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
|
||||
|
||||
await self.forward_tensor_to_next_shard(shard, result)
|
||||
print(f"result size: {result.size}, is finished: {is_finished}")
|
||||
if result.size == 1:
|
||||
self.buffered_token_output.append(result.item())
|
||||
self.on_token(self.buffered_token_output)
|
||||
|
||||
return result
|
||||
if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
|
||||
await self.forward_tensor_to_next_shard(shard, result)
|
||||
|
||||
async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
|
||||
print("Process tensor", shard, tensor)
|
||||
result = await self.inference_engine.infer_tensor(shard, tensor)
|
||||
print(f"Got result from tensor: {len(tensor)}. Result: {result}")
|
||||
return np.array(self.buffered_token_output) if self.buffered_token_output else None
|
||||
|
||||
await self.forward_tensor_to_next_shard(shard, result)
|
||||
async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
|
||||
result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
|
||||
|
||||
return result
|
||||
print(f"result size: {result.size}, is finished: {is_finished}")
|
||||
if result.size == 1:
|
||||
self.buffered_token_output.append(result.item())
|
||||
self.on_token(self.buffered_token_output)
|
||||
|
||||
if not is_finished and len(self.buffered_token_output) < self.max_generate_tokens:
|
||||
await self.forward_tensor_to_next_shard(shard, result)
|
||||
|
||||
return np.array(self.buffered_token_output) if self.buffered_token_output else None
|
||||
|
||||
async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
|
||||
if not self.partitioning_strategy:
|
||||
@@ -67,6 +79,10 @@ class StandardNode(Node):
|
||||
print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
|
||||
|
||||
if next_partition:
|
||||
if next_partition.node_id == self.id:
|
||||
await self.process_tensor(shard, tensor)
|
||||
return
|
||||
|
||||
target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
|
||||
if not target_peer:
|
||||
raise ValueError(f"Peer for {next_partition} not found")
|
||||
@@ -79,10 +95,23 @@ class StandardNode(Node):
|
||||
|
||||
await target_peer.send_tensor(next_shard, tensor)
|
||||
|
||||
def get_current_shard(self, shard: Shard) -> Shard:
|
||||
partitions = self.partitioning_strategy.partition(self.topology)
|
||||
current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
|
||||
if current_partition_index is None:
|
||||
raise ValueError(f"No current partition found for node: {self.id}")
|
||||
|
||||
current_partition = partitions[current_partition_index]
|
||||
start_layer = int(current_partition.start * shard.n_layers)
|
||||
end_layer = int(current_partition.end * shard.n_layers) - 1
|
||||
return Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
|
||||
|
||||
|
||||
async def reset_shard(self, shard: Shard) -> None:
|
||||
# Implement shard reset logic
|
||||
print(f"Resetting shard: {shard}")
|
||||
await self.inference_engine.reset_shard(shard)
|
||||
self.buffered_token_output = []
|
||||
await self.inference_engine.reset_shard(self.get_current_shard(shard))
|
||||
|
||||
async def collect_topology(self, max_depth: int = 4) -> Topology:
|
||||
self.topology.update_node(self.id, self.device_capabilities)
|
||||
|
||||
Reference in New Issue
Block a user