Mirror of https://github.com/exo-explore/exo.git
topology with partitioning strategy
@@ -9,6 +9,7 @@ class InferenceEngine(ABC):
    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        pass

    @abstractmethod
    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
        pass
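For orientation, a minimal sketch of a concrete engine satisfying this interface. The class name and dummy return values are invented; only the two method signatures and the import paths (inference.shard, inference.inference_engine) come from the diff, and the full file may declare further abstract methods (e.g. a reset_shard matching the proto's ResetShard RPC) that a real subclass would also need.

import numpy as np

from inference.shard import Shard
from inference.inference_engine import InferenceEngine

class EchoInferenceEngine(InferenceEngine):
    # Hypothetical engine returning dummy activations, just to show the contract.
    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
        # A real engine would tokenize the prompt and run the shard's layers.
        return np.zeros((1, 1), dtype=np.float32)

    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        # A real engine would run input_data through layers start_layer..end_layer.
        return input_data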
@@ -1,4 +1,3 @@
import mlx.nn as nn
import numpy as np
import mlx.core as mx
from ..inference_engine import InferenceEngine
@@ -6,6 +6,7 @@ from typing import List, Dict
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities, mac_device_capabilities

class GRPCDiscovery(Discovery):
    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1):
@@ -61,13 +62,15 @@ class GRPCDiscovery(Discovery):
        return list(self.known_peers.values())

    async def _broadcast_presence(self):
        self.device_capabilities: DeviceCapabilities = mac_device_capabilities()
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.settimeout(0.5)
        message = json.dumps({
            "type": "discovery",
            "node_id": self.node_id,
            "grpc_port": self.node_port
            "grpc_port": self.node_port,
            "device_capabilities": self.device_capabilities.to_dict()
        }).encode('utf-8')

        while True:
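The broadcast now carries the sender's device capabilities alongside its id and gRPC port. A minimal sketch of how the receiving side might decode one of these datagrams; the function name and the surrounding listener/known_peers bookkeeping are assumptions, only the message keys come from the hunk above.

import json

def handle_discovery_datagram(data: bytes) -> dict:
    # Decode the JSON broadcast emitted by _broadcast_presence.
    message = json.loads(data.decode('utf-8'))
    if message.get("type") != "discovery":
        return {}
    # node_id, grpc_port and device_capabilities mirror the broadcast keys.
    return {
        "node_id": message["node_id"],
        "grpc_port": message["grpc_port"],
        "device_capabilities": message.get("device_capabilities", {}),
    }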
@@ -34,7 +34,7 @@ class GRPCPeerHandle(PeerHandle):
        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

    async def send_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.array]:
    async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
        request = node_service_pb2.TensorRequest(
            shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
            tensor = node_service_pb2.Tensor(
@@ -42,13 +42,8 @@ class GRPCPeerHandle(PeerHandle):
                shape=tensor.shape,
                dtype=str(tensor.dtype)
            ),
            target=target
        )
        response = await self.stub.SendTensor(request)
        if target:
            print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
        else:
            print(f"Sent tensor to {self.address}: shape {tensor.shape}")

        if not response.tensor_data or not response.shape or not response.dtype:
            return None
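The tensor payload travels as raw bytes plus a shape list and a dtype string. A standalone round-trip sketch of that serialization pattern (numpy only, the protobuf wrapper omitted):

import numpy as np

# Serialize the way the handle does: raw bytes + shape + dtype string.
tensor = np.arange(6, dtype=np.float32).reshape(2, 3)
tensor_data, shape, dtype = tensor.tobytes(), tensor.shape, str(tensor.dtype)

# Deserialize the way the receiving side does.
restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert np.array_equal(tensor, restored)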
@@ -31,16 +31,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
    async def SendPrompt(self, request, context):
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        prompt = request.prompt
        target = request.target if request.HasField('target') else None
        result = await self.node.process_prompt(shard, prompt, target)
        result = await self.node.process_prompt(shard, prompt)
        tensor_data = result.tobytes() if result is not None else None
        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))

    async def SendTensor(self, request, context):
        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
        tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
        target = request.target if request.HasField('target') else None
        result = await self.node.process_tensor(shard, tensor, target)
        result = await self.node.process_tensor(shard, tensor)
        print("SendTensor tensor result", result)
        tensor_data = result.tobytes() if result is not None else None
        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
@@ -18,13 +18,11 @@ message Shard {
message PromptRequest {
  Shard shard = 1;
  string prompt = 2;
  optional string target = 3;
}

message TensorRequest {
  Shard shard = 1;
  Tensor tensor = 2;
  optional string target = 3;
}

message Tensor {
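Read together with the hunk header (-13/+11) and the regenerated descriptor below, the two optional target fields appear to be the lines being dropped, so a request built from the Python side needs only the shard and the prompt (or tensor). A sketch of constructing one; the model id and layer numbers are made-up values, only the field names come from the diff.

import node_service_pb2

request = node_service_pb2.PromptRequest(
    shard=node_service_pb2.Shard(model_id="example-model", start_layer=0, end_layer=15, n_layers=32),
    prompt="hello world",
)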
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"c\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"C\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\"Y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,15 +24,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
  _globals['_SHARD']._serialized_start=36
  _globals['_SHARD']._serialized_end=119
  _globals['_PROMPTREQUEST']._serialized_start=121
  _globals['_PROMPTREQUEST']._serialized_end=220
  _globals['_TENSORREQUEST']._serialized_start=222
  _globals['_TENSORREQUEST']._serialized_end=343
  _globals['_TENSOR']._serialized_start=345
  _globals['_TENSOR']._serialized_end=404
  _globals['_RESETSHARDREQUEST']._serialized_start=406
  _globals['_RESETSHARDREQUEST']._serialized_end=461
  _globals['_EMPTY']._serialized_start=463
  _globals['_EMPTY']._serialized_end=470
  _globals['_NODESERVICE']._serialized_start=473
  _globals['_NODESERVICE']._serialized_end=690
  _globals['_PROMPTREQUEST']._serialized_end=188
  _globals['_TENSORREQUEST']._serialized_start=190
  _globals['_TENSORREQUEST']._serialized_end=279
  _globals['_TENSOR']._serialized_start=281
  _globals['_TENSOR']._serialized_end=340
  _globals['_RESETSHARDREQUEST']._serialized_start=342
  _globals['_RESETSHARDREQUEST']._serialized_end=397
  _globals['_EMPTY']._serialized_start=399
  _globals['_EMPTY']._serialized_end=406
  _globals['_NODESERVICE']._serialized_start=409
  _globals['_NODESERVICE']._serialized_end=626
# @@protoc_insertion_point(module_scope)
@@ -3,7 +3,7 @@
import grpc
import warnings

from . import node_service_pb2 as node__service__pb2
import node_service_pb2 as node__service__pb2

GRPC_GENERATED_VERSION = '1.64.1'
GRPC_VERSION = grpc.__version__
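The *_pb2 and *_pb2_grpc modules above are protoc output. Regenerating them after editing node_service.proto is typically done with grpcio-tools, along these lines; the include path and output directory are assumptions, only the .proto filename comes from the diff.

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto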
@@ -2,11 +2,17 @@ from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
from inference.shard import Shard
from topology.device_capabilities import DeviceCapabilities

class PeerHandle(ABC):
    @abstractmethod
    def id(self) -> str:
        pass

    @abstractmethod
    def device_capabilities(self) -> DeviceCapabilities:
        pass

    @abstractmethod
    async def connect(self) -> None:
        pass
@@ -13,11 +13,11 @@ class Node(ABC):
        pass

    @abstractmethod
    def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
    def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
        pass

    @abstractmethod
    def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> None:
    def process_prompt(self, shard: Shard, prompt: str) -> None:
        pass

    @abstractmethod
@@ -3,6 +3,7 @@ import numpy as np
from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
from .node import Node
from topology.topology import Topology

class StandardNode(Node):
    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
@@ -11,7 +12,8 @@ class StandardNode(Node):
        self.server = server
        self.discovery = discovery
        self.peers: List[PeerHandle] = {}
        self.ring_order: List[str] = []
        self.topology: Topology = Topology()
        self.successor: Optional[PeerHandle] = None

    async def start(self, wait_for_peers: int = 0) -> None:
        await self.server.start()
@@ -27,18 +29,14 @@ class StandardNode(Node):
        await self.discovery.stop()
        await self.server.stop()

    async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
        print("Process prompt", shard, prompt, target)
    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
        print("Process prompt", shard, prompt)
        result = await self.inference_engine.infer_prompt(shard, prompt)
        # Implement prompt processing logic
        print(f"Got result from prompt: {prompt}. Result: {result}")
        # You might want to initiate inference here
        if target:
            target_peer = next((p for p in self.peers if p.id() == target), None)
            if not target_peer:
                raise ValueError(f"Peer {target} not found")

            await target_peer.send_tensor(result)
        if self.successor:
            await self.successor.send_tensor()

        return result
topology/__init__.py (new file, 0 lines)

topology/device_capabilities.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from dataclasses import dataclass
import subprocess

@dataclass
class DeviceCapabilities:
    model: str
    chip: str
    memory: int

def mac_device_capabilities() -> DeviceCapabilities:
    # Fetch the model of the Mac using system_profiler
    model = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
    model_line = next((line for line in model.split('\n') if "Model Name" in line), None)
    model_id = model_line.split(': ')[1] if model_line else "Unknown Model"
    chip_line = next((line for line in model.split('\n') if "Chip" in line), None)
    chip_id = chip_line.split(': ')[1] if chip_line else "Unknown Chip"
    memory_line = next((line for line in model.split('\n') if "Memory" in line), None)
    memory_str = memory_line.split(': ')[1] if memory_line else "Unknown Memory"
    memory_units = memory_str.split()
    memory_value = int(memory_units[0])
    if memory_units[1] == "GB":
        memory = memory_value * 1024
    else:
        memory = memory_value

    # Assuming static values for other attributes for demonstration
    return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory)
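GRPCDiscovery calls device_capabilities.to_dict() when building the broadcast, but no to_dict appears in this hunk; if it is not defined elsewhere in the file, dataclasses.asdict would produce the same shape. A small sketch with invented hardware values:

from dataclasses import asdict
from topology.device_capabilities import DeviceCapabilities

caps = DeviceCapabilities(model="MacBook Pro", chip="Apple M3 Max", memory=128 * 1024)
print(asdict(caps))  # {'model': 'MacBook Pro', 'chip': 'Apple M3 Max', 'memory': 131072}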
topology/partitioning_strategy.py (new file, 10 lines)
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from typing import List
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from .topology import Topology

class PartitioningStrategy(ABC):
    @abstractmethod
    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
        pass
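For orientation, a throwaway strategy satisfying this interface. The class name and the even hand-off below are purely illustrative, and the Shard keyword arguments are assumed from how the rest of the diff constructs shards (model_id, start_layer, end_layer, n_layers).

from inference.shard import Shard
from topology.partitioning_strategy import PartitioningStrategy
from topology.topology import Topology

class RemainderStrategy(PartitioningStrategy):
    # Trivial strategy: the next shard is simply the rest of the model.
    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
        return Shard(
            model_id=current_shard.model_id,
            start_layer=current_shard.end_layer + 1,
            end_layer=current_shard.n_layers - 1,
            n_layers=current_shard.n_layers,
        )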
topology/ring_memory_weighted_partitioning_strategy.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from .partitioning_strategy import PartitioningStrategy
from inference.shard import Shard
from .topology import Topology

class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
        # Get all nodes from the topology and include the current node
        nodes = list(topology.all_nodes())
        nodes.append((self.id, None, node_stats))

        # Sort nodes by their IDs
        nodes.sort(key=lambda x: x[0])

        # Calculate the total memory of all nodes
        total_memory = sum(node[2]['memory'] for node in nodes)

        # Calculate the number of layers to assign to each node proportional to its memory
        layers_per_node = {node[0]: (node[2]['memory'] / total_memory) * current_shard.n_layers for node in nodes}

        # Find the successor node
        node_ids = [node[0] for node in nodes]
        current_index = node_ids.index(self.id)
        successor_index = (current_index + 1) % len(node_ids)
        successor_id = node_ids[successor_index]

        # Return the Shard calculated for the successor
        return Shard(successor_id, layers_per_node[successor_id])
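The weighting itself is just proportional allocation of n_layers by memory. A worked example with invented numbers, two nodes holding 128 GB and 64 GB sharing a 32-layer model:

# Invented node stats: ids mapped to memory in MB, plus a 32-layer model.
node_memory = {"node-a": 128 * 1024, "node-b": 64 * 1024}
n_layers = 32

total_memory = sum(node_memory.values())
layers_per_node = {nid: (mem / total_memory) * n_layers for nid, mem in node_memory.items()}
print(layers_per_node)  # {'node-a': 21.33..., 'node-b': 10.66...}

As in the strategy above, the allocation comes out fractional; rounding to whole layers is left to the caller in this commit.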
topology/test_device_capabilities.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import unittest
from unittest.mock import patch
from topology.device_capabilities import mac_device_capabilities, DeviceCapabilities

class TestMacDeviceCapabilities(unittest.TestCase):
    @patch('subprocess.check_output')
    def test_mac_device_capabilities(self, mock_check_output):
        # Mock the subprocess output
        mock_check_output.return_value = b"""
Hardware:

    Hardware Overview:

      Model Name: MacBook Pro
      Model Identifier: Mac15,9
      Model Number: Z1CM000EFB/A
      Chip: Apple M3 Max
      Total Number of Cores: 16 (12 performance and 4 efficiency)
      Memory: 128 GB
      System Firmware Version: 10000.000.0
      OS Loader Version: 10000.000.0
      Serial Number (system): XXXXXXXXXX
      Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
      Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
      Activation Lock Status: Enabled
"""

        # Call the function
        result = mac_device_capabilities()

        # Check the results
        self.assertIsInstance(result, DeviceCapabilities)
        self.assertEqual(result.model, "MacBook Pro")
        self.assertEqual(result.chip, "Apple M3 Max")
        self.assertEqual(result.memory, 131072)  # 128 GB in MB

    @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
    def test_mac_device_capabilities_real(self):
        # Call the function without mocking
        result = mac_device_capabilities()

        # Check the results
        self.assertIsInstance(result, DeviceCapabilities)
        self.assertEqual(result.model, "MacBook Pro")
        self.assertEqual(result.chip, "Apple M3 Max")
        self.assertEqual(result.memory, 131072)  # 128 GB in MB

if __name__ == '__main__':
    unittest.main()
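The mocked test never touches system_profiler, so it can run on any machine; the skipped variant only makes sense on the hardware named in the decorator. Assuming it is invoked from the repository root with the topology package importable, something like:

python -m unittest topology.test_device_capabilities -v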
topology/topology.py (new file, 12 lines)
@@ -0,0 +1,12 @@
class Topology:
    def __init__(self):
        self.nodes = {}  # Maps node IDs to a tuple of (host, port, stats)

    def update_node(self, node_id, stats):
        self.nodes[node_id] = stats

    def get_node(self, node_id):
        return self.nodes.get(node_id)

    def all_nodes(self):
        return self.nodes.items()
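A short sketch of how the pieces introduced in this commit might fit together. The node ids and stats dicts are invented; note also that update_node stores only stats even though the comment above mentions (host, port, stats), and that next_shard as committed references self.id, which the strategy class does not define, so the wiring below stops short of calling it.

from topology.topology import Topology
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy

topology = Topology()
topology.update_node("node-a", {"memory": 128 * 1024})
topology.update_node("node-b", {"memory": 64 * 1024})

for node_id, stats in topology.all_nodes():
    print(node_id, stats)

# A node would hand this topology plus its own stats to the strategy
# to decide which shard its ring successor should run next.
strategy = RingMemoryWeightedPartitioningStrategy()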