topology with partitioning strategy

Alex Cheema
2024-06-24 20:56:50 +01:00
parent 563dcb56b0
commit 6c8c9ee7b1
17 changed files with 161 additions and 38 deletions

View File

@@ -9,6 +9,7 @@ class InferenceEngine(ABC):
async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
pass
@abstractmethod
async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
pass
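
Note: the new infer_prompt method makes prompt handling part of the engine contract. A minimal sketch of a conforming engine (DummyInferenceEngine is hypothetical, not part of this commit):

import numpy as np

class DummyInferenceEngine(InferenceEngine):
    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
        # Pass activations through unchanged for the layers this shard owns.
        return input_data

    async def infer_prompt(self, shard: Shard, prompt: str) -> np.ndarray:
        # A real engine would tokenize the prompt and run the shard's layers;
        # here we just return a placeholder activation.
        return np.zeros((1, len(prompt)), dtype=np.float32)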

View File

@@ -1,4 +1,3 @@
import mlx.nn as nn
import numpy as np
import mlx.core as mx
from ..inference_engine import InferenceEngine

View File

@@ -6,6 +6,7 @@ from typing import List, Dict
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities, mac_device_capabilities
class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1):
@@ -61,13 +62,15 @@ class GRPCDiscovery(Discovery):
return list(self.known_peers.values())
async def _broadcast_presence(self):
self.device_capabilities: DeviceCapabilities = mac_device_capabilities()
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.settimeout(0.5)
message = json.dumps({
"type": "discovery",
"node_id": self.node_id,
"grpc_port": self.node_port
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict()
}).encode('utf-8')
while True:
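
The listen side of this broadcast is not shown in this diff; a hedged sketch of how a receiver could decode the message and record the advertised capabilities (peer_capabilities, the method name, and the socket setup are assumptions):

import asyncio
import json
import socket

async def _listen_for_peers(self):
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('', self.listen_port))
    sock.setblocking(False)
    loop = asyncio.get_event_loop()
    while True:
        data, addr = await loop.sock_recvfrom(sock, 4096)  # Python 3.11+
        message = json.loads(data.decode('utf-8'))
        if message.get('type') == 'discovery' and message['node_id'] != self.node_id:
            # Record the capabilities the peer advertised alongside its address.
            self.peer_capabilities[message['node_id']] = message.get('device_capabilities', {})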

View File

@@ -34,7 +34,7 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
tensor = node_service_pb2.Tensor(
@@ -42,13 +42,8 @@ class GRPCPeerHandle(PeerHandle):
shape=tensor.shape,
dtype=str(tensor.dtype)
),
target=target
)
response = await self.stub.SendTensor(request)
if target:
print(f"Sent tensor to {self.address} with target {target}: shape {tensor.shape}")
else:
print(f"Sent tensor to {self.address}: shape {tensor.shape}")
if not response.tensor_data or not response.shape or not response.dtype:
return None

View File

@@ -31,16 +31,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def SendPrompt(self, request, context):
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
prompt = request.prompt
target = request.target if request.HasField('target') else None
result = await self.node.process_prompt(shard, prompt, target)
result = await self.node.process_prompt(shard, prompt)
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape if result is not None else [], dtype=str(result.dtype) if result is not None else "")
async def SendTensor(self, request, context):
shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
target = request.target if request.HasField('target') else None
result = await self.node.process_tensor(shard, tensor, target)
result = await self.node.process_tensor(shard, tensor)
print("SendTensor tensor result", result)
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape if result is not None else [], dtype=str(result.dtype) if result is not None else "")

View File

@@ -18,13 +18,11 @@ message Shard {
message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string target = 3;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string target = 3;
}
message Tensor {

View File

@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"c\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\"y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x13\n\x06target\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\t\n\x07_target\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"C\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\"Y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -24,15 +24,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
_globals['_PROMPTREQUEST']._serialized_start=121
_globals['_PROMPTREQUEST']._serialized_end=220
_globals['_TENSORREQUEST']._serialized_start=222
_globals['_TENSORREQUEST']._serialized_end=343
_globals['_TENSOR']._serialized_start=345
_globals['_TENSOR']._serialized_end=404
_globals['_RESETSHARDREQUEST']._serialized_start=406
_globals['_RESETSHARDREQUEST']._serialized_end=461
_globals['_EMPTY']._serialized_start=463
_globals['_EMPTY']._serialized_end=470
_globals['_NODESERVICE']._serialized_start=473
_globals['_NODESERVICE']._serialized_end=690
_globals['_PROMPTREQUEST']._serialized_end=188
_globals['_TENSORREQUEST']._serialized_start=190
_globals['_TENSORREQUEST']._serialized_end=279
_globals['_TENSOR']._serialized_start=281
_globals['_TENSOR']._serialized_end=340
_globals['_RESETSHARDREQUEST']._serialized_start=342
_globals['_RESETSHARDREQUEST']._serialized_end=397
_globals['_EMPTY']._serialized_start=399
_globals['_EMPTY']._serialized_end=406
_globals['_NODESERVICE']._serialized_start=409
_globals['_NODESERVICE']._serialized_end=626
# @@protoc_insertion_point(module_scope)

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from . import node_service_pb2 as node__service__pb2
import node_service_pb2 as node__service__pb2
GRPC_GENERATED_VERSION = '1.64.1'
GRPC_VERSION = grpc.__version__

View File

@@ -2,11 +2,17 @@ from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
from inference.shard import Shard
from topology.device_capabilities import DeviceCapabilities
class PeerHandle(ABC):
@abstractmethod
def id(self) -> str:
pass
@abstractmethod
def device_capabilities(self) -> DeviceCapabilities:
pass
@abstractmethod
async def connect(self) -> None:
pass

View File

@@ -13,11 +13,11 @@ class Node(ABC):
pass
@abstractmethod
def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
pass
@abstractmethod
def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> None:
def process_prompt(self, shard: Shard, prompt: str) -> None:
pass
@abstractmethod

View File

@@ -3,6 +3,7 @@ import numpy as np
from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
from .node import Node
from topology.topology import Topology
class StandardNode(Node):
def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
@@ -11,7 +12,8 @@ class StandardNode(Node):
self.server = server
self.discovery = discovery
self.peers: List[PeerHandle] = []
self.ring_order: List[str] = []
self.topology: Topology = Topology()
self.successor: Optional[PeerHandle] = None
async def start(self, wait_for_peers: int = 0) -> None:
await self.server.start()
@@ -27,18 +29,14 @@ class StandardNode(Node):
await self.discovery.stop()
await self.server.stop()
async def process_prompt(self, shard: Shard, prompt: str, target: Optional[str] = None) -> Optional[np.array]:
print("Process prompt", shard, prompt, target)
async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
print("Process prompt", shard, prompt)
result = await self.inference_engine.infer_prompt(shard, prompt)
print(f"Got result from prompt: {prompt}. Result: {result}")
if target:
target_peer = next((p for p in self.peers if p.id() == target), None)
if not target_peer:
raise ValueError(f"Peer {target} not found")
await target_peer.send_tensor(result)
if self.successor:
await self.successor.send_tensor(shard, result)
return result
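
How self.successor is populated is not shown in this diff; one hedged sketch, ordering peers into a ring by sorted ID to match RingMemoryWeightedPartitioningStrategy below (update_successor is hypothetical, not part of this commit):

def update_successor(self) -> None:
    # Ring order: sort all known node IDs and pick the one after ours.
    node_ids = sorted([self.id] + [p.id() for p in self.peers])
    successor_id = node_ids[(node_ids.index(self.id) + 1) % len(node_ids)]
    self.successor = next((p for p in self.peers if p.id() == successor_id), None)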

topology/__init__.py Normal file
View File

View File

@@ -0,0 +1,27 @@
from dataclasses import dataclass
import subprocess
@dataclass
class DeviceCapabilities:
model: str
chip: str
memory: int
def mac_device_capabilities() -> DeviceCapabilities:
# Fetch the model of the Mac using system_profiler
model = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
model_line = next((line for line in model.split('\n') if "Model Name" in line), None)
model_id = model_line.split(': ')[1] if model_line else "Unknown Model"
chip_line = next((line for line in model.split('\n') if "Chip" in line), None)
chip_id = chip_line.split(': ')[1] if chip_line else "Unknown Chip"
memory_line = next((line for line in model.split('\n') if "Memory" in line), None)
memory_str = memory_line.split(': ')[1] if memory_line else "0 MB"
memory_units = memory_str.split()
memory_value = int(memory_units[0])
if memory_units[1] == "GB":
memory = memory_value * 1024
else:
memory = memory_value
# All values above are parsed from system_profiler output; memory is normalized to MB
return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory)
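
_broadcast_presence above calls device_capabilities.to_dict(), which does not appear in this diff; a minimal sketch of that method on DeviceCapabilities, assuming a plain field dump is all the broadcast needs:

from dataclasses import asdict

def to_dict(self) -> dict:
    # asdict yields {'model': ..., 'chip': ..., 'memory': ...}.
    return asdict(self)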

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from typing import List
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from .topology import Topology
class PartitioningStrategy(ABC):
@abstractmethod
def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
pass

View File

@@ -0,0 +1,27 @@
from .partitioning_strategy import PartitioningStrategy
from inference.shard import Shard
from .topology import Topology
class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
# Get all nodes from the topology and include the current node
nodes = list(topology.all_nodes())
nodes.append((self.id, None, node_stats))
# Sort nodes by their IDs
nodes.sort(key=lambda x: x[0])
# Calculate the total memory of all nodes
total_memory = sum(node[2]['memory'] for node in nodes)
# Calculate the number of layers to assign to each node proportional to its memory
layers_per_node = {node[0]: (node[2]['memory'] / total_memory) * current_shard.n_layers for node in nodes}
# Find the successor node
node_ids = [node[0] for node in nodes]
current_index = node_ids.index(self.id)
successor_index = (current_index + 1) % len(node_ids)
successor_id = node_ids[successor_index]
# Return the Shard calculated for the successor
return Shard(successor_id, layers_per_node[successor_id])
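
As written, this return does not match the Shard(model_id, start_layer, end_layer, n_layers) shape used elsewhere in this commit, and self.id is not defined on the strategy. A hedged sketch of how the final return could instead derive the successor's contiguous layer range from the memory weights (rounding remainders are not handled):

# Walk the ring in sorted order, assigning each node a contiguous block of
# layers proportional to its memory, and return the successor's block.
start = 0
for node_id in node_ids:
    count = round(layers_per_node[node_id])
    if node_id == successor_id:
        return Shard(model_id=current_shard.model_id, start_layer=start,
                     end_layer=start + count - 1, n_layers=current_shard.n_layers)
    start += count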

View File

@@ -0,0 +1,49 @@
import unittest
from unittest.mock import patch
from topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
class TestMacDeviceCapabilities(unittest.TestCase):
@patch('subprocess.check_output')
def test_mac_device_capabilities(self, mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:
Hardware Overview:
Model Name: MacBook Pro
Model Identifier: Mac15,9
Model Number: Z1CM000EFB/A
Chip: Apple M3 Max
Total Number of Cores: 16 (12 performance and 4 efficiency)
Memory: 128 GB
System Firmware Version: 10000.000.0
OS Loader Version: 10000.000.0
Serial Number (system): XXXXXXXXXX
Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
Activation Lock Status: Enabled
"""
# Call the function
result = mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Pro")
self.assertEqual(result.chip, "Apple M3 Max")
self.assertEqual(result.memory, 131072) # 128 GB in MB
@unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
def test_mac_device_capabilities_real(self):
# Call the function without mocking
result = mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Pro")
self.assertEqual(result.chip, "Apple M3 Max")
self.assertEqual(result.memory, 131072) # 128 GB in MB
if __name__ == '__main__':
unittest.main()

topology/topology.py Normal file
View File

@@ -0,0 +1,12 @@
class Topology:
def __init__(self):
self.nodes = {} # Maps node IDs to their stats dict, e.g. {'memory': ...}
def update_node(self, node_id, stats):
self.nodes[node_id] = stats
def get_node(self, node_id):
return self.nodes.get(node_id)
def all_nodes(self):
return self.nodes.items()
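
A short usage sketch of the new class (node IDs and memory figures are made up):

topology = Topology()
topology.update_node("node-a", {"memory": 65536})   # 64 GB in MB
topology.update_node("node-b", {"memory": 131072})  # 128 GB in MB
for node_id, stats in topology.all_nodes():
    print(node_id, stats["memory"])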