Mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23)
collect global topology with local peer visibility, ring memory weighted partitioning strategy
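In short: every node now learns the full cluster topology (each peer reports the peers it can see locally, and the views are merged recursively up to a max depth), and model layers are assigned around the ring in proportion to each node's memory. A minimal standalone sketch of the memory-weighted split, using the same numbers as the new unit test below; the Partition dataclass mirrors the one added in topology/partitioning_strategy.py:

from dataclasses import dataclass

@dataclass
class Partition:
    node_id: str
    start: float  # fraction of the layer space, inclusive
    end: float    # fraction of the layer space, exclusive

def ring_memory_weighted(memories: dict[str, int]) -> list[Partition]:
    # Sort by node ID so every node derives the same ring order independently.
    nodes = sorted(memories.items())
    total = sum(m for _, m in nodes)
    partitions, start = [], 0.0
    for node_id, memory in nodes:
        end = start + memory / total
        partitions.append(Partition(node_id, start, end))
        start = end
    return partitions

# memory 100/300/600 -> fractions 0.1/0.3/0.6 of the layer space
print(ring_memory_weighted({"node1": 100, "node2": 300, "node3": 600}))
# [Partition(node_id='node1', start=0.0, end=0.1),
#  Partition(node_id='node2', start=0.1, end=0.4),
#  Partition(node_id='node3', start=0.4, end=1.0)]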
@@ -6,6 +6,7 @@ from inference.mlx.sharded_utils import get_model_path, load_tokenizer
 from inference.shard import Shard
 from networking.peer_handle import PeerHandle
 from networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from topology.device_capabilities import DeviceCapabilities
 from typing import List
 import asyncio
 import argparse
@@ -19,17 +20,19 @@ peers: List[PeerHandle] = [
     GRPCPeerHandle(
         "node1",
         "localhost:8080",
+        DeviceCapabilities(model="test1", chip="test1", memory=10000)
     ),
     GRPCPeerHandle(
         "node2",
         "localhost:8081",
+        DeviceCapabilities(model="test2", chip="test2", memory=20000)
     )
 ]
 shards: List[Shard] = [
-    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
-    # Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
-    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
-    Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
+    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
+    Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
+    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
+    # Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
 ]

 async def run_prompt(prompt: str):

example_user_2.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+# In this example, a user is running a home cluster with 3 shards.
+# They are prompting the cluster to generate a response to a question.
+# The cluster is given the question, and the user is given the response.
+
+from inference.mlx.sharded_utils import get_model_path, load_tokenizer
+from inference.shard import Shard
+from networking.peer_handle import PeerHandle
+from networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from topology.device_capabilities import DeviceCapabilities
+from typing import List
+import asyncio
+import argparse
+
+path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+model_path = get_model_path(path_or_hf_repo)
+tokenizer_config = {}
+tokenizer = load_tokenizer(model_path, tokenizer_config)
+
+peers: List[PeerHandle] = [
+    GRPCPeerHandle(
+        "node1",
+        "localhost:8080",
+        DeviceCapabilities(model="test1", chip="test1", memory=10000)
+    ),
+]
+shards: List[Shard] = [
+    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
+    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
+    # Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
+]
+
+async def run_prompt(prompt: str):
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = tokenizer.default_chat_template
+    if (
+        hasattr(tokenizer, "apply_chat_template")
+        and tokenizer.chat_template is not None
+    ):
+        messages = [{"role": "user", "content": prompt}]
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+
+    for peer, shard in zip(peers, shards):
+        await peer.connect()
+        await peer.reset_shard(shard)
+
+    tokens = []
+    last_output = prompt
+
+    for _ in range(20):
+        for peer, shard in zip(peers, shards):
+            if isinstance(last_output, str):
+                last_output = await peer.send_prompt(shard, last_output)
+                print("prompt output:", last_output)
+            else:
+                last_output = await peer.send_tensor(shard, last_output)
+                print("tensor output:", last_output)
+
+        if not last_output:
+            break
+
+        tokens.append(last_output.item())
+
+    print(tokenizer.decode(tokens))
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run prompt")
+    parser.add_argument("--prompt", type=str, help="The prompt to run")
+    args = parser.parse_args()
+
+    asyncio.run(run_prompt(args.prompt))

@@ -6,7 +6,7 @@ from .shard import Shard

 class InferenceEngine(ABC):
     @abstractmethod
-    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
         pass

     @abstractmethod
@@ -19,11 +19,11 @@ class MLXFixedShardInferenceEngine(InferenceEngine):
         output_data = self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt)))
         return np.array(output_data)

-    async def infer_shard(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> np.ndarray:
         if shard != self.shard:
             raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

-        print("infer_shard", shard, input_data)
+        print("infer_tensor", shard, input_data)

         output_data = self.stateful_sharded_model.step(mx.array(input_data))
         return np.array(output_data)

main.py (6 changed lines)
@@ -8,6 +8,7 @@ from networking.grpc.grpc_server import GRPCServer
 from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
 from inference.shard import Shard
 from networking.grpc.grpc_discovery import GRPCDiscovery
+from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -20,11 +21,12 @@ parser.add_argument("--model-id", type=str, default="mlx-community/Meta-Llama-3-
 parser.add_argument("--n-layers", type=int, default=32, help="Number of layers in the model")
 parser.add_argument("--start-layer", type=int, default=0, help="Start layer index")
 parser.add_argument("--end-layer", type=int, default=31, help="End layer index")
+parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 args = parser.parse_args()

 inference_engine = MLXFixedShardInferenceEngine(args.model_id, shard=Shard(model_id=args.model_id, n_layers=args.n_layers, start_layer=args.start_layer, end_layer=args.end_layer))
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
-node = StandardNode(args.node_id, None, inference_engine, discovery)
+node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

@@ -49,7 +51,7 @@ async def main():
     for s in [signal.SIGINT, signal.SIGTERM]:
         loop.add_signal_handler(s, handle_exit)

-    await node.start()
+    await node.start(wait_for_peers=args.wait_for_peers)

     await asyncio.Event().wait()
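The new --wait-for-peers flag is threaded into node.start(...), but the updated start body is not fully captured in this diff. The gating it implies could be as simple as polling discovery until enough peers have been seen; a hypothetical sketch, not the repo's code (a get_peers()-style accessor is assumed, matching GRPCDiscovery's known_peers below):

import asyncio

async def wait_for_peers(discovery, n_peers: int, poll_s: float = 0.5):
    # Hypothetical helper: block until discovery reports at least n_peers peers.
    while True:
        peers = await discovery.get_peers()
        if len(peers) >= n_peers:
            return peers
        await asyncio.sleep(poll_s)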

@@ -6,12 +6,13 @@ from typing import List, Dict
 from ..discovery import Discovery
 from ..peer_handle import PeerHandle
 from .grpc_peer_handle import GRPCPeerHandle
-from topology.device_capabilities import DeviceCapabilities, mac_device_capabilities
+from topology.device_capabilities import DeviceCapabilities, device_capabilities

 class GRPCDiscovery(Discovery):
-    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1):
+    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
         self.node_id = node_id
         self.node_port = node_port
+        self.device_capabilities = device_capabilities
         self.listen_port = listen_port
         self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
         self.broadcast_interval = broadcast_interval
@@ -62,7 +63,9 @@ class GRPCDiscovery(Discovery):
         return list(self.known_peers.values())

     async def _broadcast_presence(self):
-        self.device_capabilities: DeviceCapabilities = mac_device_capabilities()
+        if not self.device_capabilities:
+            self.device_capabilities = device_capabilities()

         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
         sock.settimeout(0.5)
@@ -70,7 +73,11 @@ class GRPCDiscovery(Discovery):
             "type": "discovery",
             "node_id": self.node_id,
             "grpc_port": self.node_port,
-            "device_capabilities": self.device_capabilities.to_dict()
+            "device_capabilities": {
+                "model": self.device_capabilities.model,
+                "chip": self.device_capabilities.chip,
+                "memory": self.device_capabilities.memory
+            }
         }).encode('utf-8')

         while True:
@@ -90,7 +97,8 @@ class GRPCDiscovery(Discovery):
                     peer_id = message['node_id']
                     peer_host = addr[0]
                     peer_port = message['grpc_port']
-                    self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}")
+                    device_capabilities = DeviceCapabilities(**message['device_capabilities'])
+                    self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
                     self.peer_last_seen[peer_id] = time.time()
             except Exception as e:
                 print(f"Error in peer discovery: {e}")

@@ -8,15 +8,21 @@ from . import node_service_pb2_grpc

 from ..peer_handle import PeerHandle
 from inference.shard import Shard
+from topology.topology import Topology
+from topology.device_capabilities import DeviceCapabilities

 class GRPCPeerHandle(PeerHandle):
-    def __init__(self, id: str, address: str):
+    def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
         self._id = id
         self.address = address
+        self._device_capabilities = device_capabilities

     def id(self) -> str:
         return self._id

+    def device_capabilities(self) -> DeviceCapabilities:
+        return self._device_capabilities
+
     async def connect(self):
         self.channel = grpc.aio.insecure_channel(self.address)
         self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
@@ -54,3 +60,15 @@ class GRPCPeerHandle(PeerHandle):
         request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
         await self.stub.ResetShard(request)
         print(f"Reset shard {shard} on {self.address}")
+
+    async def collect_topology(self, max_depth: int) -> Topology:
+        request = node_service_pb2.CollectTopologyRequest(max_depth=max_depth)
+        response = await self.stub.CollectTopology(request)
+        topology = Topology()
+        for node_id, capabilities in response.nodes.items():
+            device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory)
+            topology.update_node(node_id, device_capabilities)
+        for node_id, peers in response.peer_graph.items():
+            for peer_id in peers.peer_ids:
+                topology.add_edge(node_id, peer_id)
+        return topology

@@ -48,3 +48,10 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         print(f"Received ResetShard request: {shard}")
         await self.node.reset_shard(shard)
         return node_service_pb2.Empty()
+
+    async def CollectTopology(self, request, context):
+        max_depth = request.max_depth
+        topology = await self.node.collect_topology(max_depth)
+        nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
+        peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+        return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)

@@ -6,6 +6,7 @@ service NodeService {
     rpc SendPrompt (PromptRequest) returns (Tensor) {}
     rpc SendTensor (TensorRequest) returns (Tensor) {}
     rpc ResetShard (ResetShardRequest) returns (Empty) {}
+    rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
 }

 message Shard {
@@ -35,4 +36,23 @@ message ResetShardRequest {
     Shard shard = 1;
 }

+message CollectTopologyRequest {
+    int32 max_depth = 1;
+}
+
+message Topology {
+    map<string, DeviceCapabilities> nodes = 1;
+    map<string, Peers> peer_graph = 2;
+}
+
+message Peers {
+    repeated string peer_ids = 1;
+}
+
+message DeviceCapabilities {
+    string model = 1;
+    string chip = 2;
+    int32 memory = 3;
+}
+
 message Empty {}
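For reference, protobuf map fields can be populated straight from Python dicts, so round-tripping the new Topology message looks like this (a quick sketch against the regenerated bindings; the import path is assumed from the package layout shown above):

from networking.grpc import node_service_pb2 as pb

topo = pb.Topology(
    nodes={
        "node1": pb.DeviceCapabilities(model="test1", chip="test1", memory=100),
        "node2": pb.DeviceCapabilities(model="test2", chip="test2", memory=300),
    },
    peer_graph={
        "node1": pb.Peers(peer_ids=["node2"]),
        "node2": pb.Peers(peer_ids=["node1"]),
    },
)
data = topo.SerializeToString()              # wire format sent over gRPC
assert pb.Topology.FromString(data) == topo  # lossless round trip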

@@ -14,13 +14,17 @@ _sym_db = _symbol_database.Default()


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"C\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\"Y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"\x07\n\x05\x45mpty2\xd9\x01\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"C\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\"Y\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"+\n\x16\x43ollectTopologyRequest\x12\x11\n\tmax_depth\x18\x01 \x01(\x05\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"A\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\"\x07\n\x05\x45mpty2\xac\x02\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x62\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
 if not _descriptor._USE_C_DESCRIPTORS:
     DESCRIPTOR._loaded_options = None
+    _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
+    _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
+    _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
+    _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
     _globals['_SHARD']._serialized_start=36
     _globals['_SHARD']._serialized_end=119
     _globals['_PROMPTREQUEST']._serialized_start=121
@@ -31,8 +35,20 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals['_TENSOR']._serialized_end=340
     _globals['_RESETSHARDREQUEST']._serialized_start=342
     _globals['_RESETSHARDREQUEST']._serialized_end=397
-    _globals['_EMPTY']._serialized_start=399
-    _globals['_EMPTY']._serialized_end=406
-    _globals['_NODESERVICE']._serialized_start=409
-    _globals['_NODESERVICE']._serialized_end=626
+    _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=399
+    _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=442
+    _globals['_PEERS']._serialized_start=444
+    _globals['_PEERS']._serialized_end=469
+    _globals['_TOPOLOGY']._serialized_start=472
+    _globals['_TOPOLOGY']._serialized_end=742
+    _globals['_TOPOLOGY_NODESENTRY']._serialized_start=593
+    _globals['_TOPOLOGY_NODESENTRY']._serialized_end=671
+    _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=673
+    _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=742
+    _globals['_DEVICECAPABILITIES']._serialized_start=744
+    _globals['_DEVICECAPABILITIES']._serialized_end=809
+    _globals['_EMPTY']._serialized_start=811
+    _globals['_EMPTY']._serialized_end=818
+    _globals['_NODESERVICE']._serialized_start=821
+    _globals['_NODESERVICE']._serialized_end=1121
 # @@protoc_insertion_point(module_scope)

@@ -3,7 +3,7 @@
 import grpc
 import warnings

-import node_service_pb2 as node__service__pb2
+from . import node_service_pb2 as node__service__pb2

 GRPC_GENERATED_VERSION = '1.64.1'
 GRPC_VERSION = grpc.__version__
@@ -54,6 +54,11 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)


 class NodeServiceServicer(object):
@@ -77,6 +82,12 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')

+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+

 def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -95,6 +106,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
                 request_deserializer=node__service__pb2.ResetShardRequest.FromString,
                 response_serializer=node__service__pb2.Empty.SerializeToString,
         ),
+        'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                servicer.CollectTopology,
+                request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                response_serializer=node__service__pb2.Topology.SerializeToString,
+        ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'node_service.NodeService', rpc_method_handlers)
@@ -186,3 +202,30 @@ class NodeService(object):
            timeout,
            metadata,
            _registered_method=True)
+
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

@@ -3,6 +3,7 @@ from typing import Optional
 import numpy as np
 from inference.shard import Shard
 from topology.device_capabilities import DeviceCapabilities
+from topology.topology import Topology

 class PeerHandle(ABC):
     @abstractmethod
@@ -32,3 +33,6 @@ class PeerHandle(ABC):
     @abstractmethod
     async def reset_shard(self, shard: Shard) -> None:
         pass
+
+    async def collect_topology(self, max_depth: int) -> Topology:
+        pass

@@ -2,24 +2,28 @@ from typing import Optional
 import numpy as np
 from abc import ABC, abstractmethod
 from inference.shard import Shard
+from topology.topology import Topology

 class Node(ABC):
     @abstractmethod
-    def start(self, wait_for_peers: int = 0) -> None:
+    async def start(self, wait_for_peers: int = 0) -> None:
         pass

     @abstractmethod
-    def stop(self) -> None:
+    async def stop(self) -> None:
         pass

     @abstractmethod
-    def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
         pass

     @abstractmethod
-    def process_prompt(self, shard: Shard, prompt: str) -> None:
+    async def process_prompt(self, shard: Shard, prompt: str) -> None:
         pass

     @abstractmethod
-    def reset_shard(self, shard: Shard) -> None:
+    async def reset_shard(self, shard: Shard) -> None:
         pass
+
+    async def collect_topology(self, max_depth: int = 2) -> Topology:
+        pass

@@ -4,16 +4,20 @@ from networking import Discovery, PeerHandle, Server
 from inference.inference_engine import InferenceEngine, Shard
 from .node import Node
 from topology.topology import Topology
+from topology.device_capabilities import device_capabilities
+from topology.partitioning_strategy import PartitioningStrategy
+from topology.partitioning_strategy import Partition

 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
         self.discovery = discovery
+        self.partitioning_strategy = partitioning_strategy
         self.peers: List[PeerHandle] = {}
         self.topology: Topology = Topology()
         self.successor: Optional[PeerHandle] = None
+        self.device_capabilities = device_capabilities()

     async def start(self, wait_for_peers: int = 0) -> None:
         await self.server.start()
@@ -24,6 +28,8 @@ class StandardNode(Node):
         for peer in self.peers:
             await peer.connect()
             print(f"Connected to {peer.id()}")
+        await self.collect_topology()
+        print(f"Collected topology: {self.topology}")

     async def stop(self) -> None:
         await self.discovery.stop()
@@ -32,30 +38,62 @@ class StandardNode(Node):
     async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.array]:
         print("Process prompt", shard, prompt)
         result = await self.inference_engine.infer_prompt(shard, prompt)
-        # Implement prompt processing logic
         print(f"Got result from prompt: {prompt}. Result: {result}")
-        # You might want to initiate inference here
-        if self.successor:
-            await self.succesor.send_tensor()

+        await self.forward_tensor_to_next_shard(shard, result)

         return result

-    async def process_tensor(self, shard: Shard, tensor: np.ndarray, target: Optional[str] = None) -> None:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> None:
         print("Process tensor", shard, tensor)
-        result = await self.inference_engine.infer_shard(shard, tensor)
-        # Implement prompt processing logic
-        print(f"Got result from prompt: {len(tensor)}. Result: {result}")
+        result = await self.inference_engine.infer_tensor(shard, tensor)
+        print(f"Got result from tensor: {len(tensor)}. Result: {result}")

-        if target:
-            target_peer = next((p for p in self.peers if p.id() == target), None)
-            if not target_peer:
-                raise ValueError(f"Peer {target} not found")
-
-            await target_peer.send_tensor(result)
+        await self.forward_tensor_to_next_shard(shard, result)

         return result

+    async def forward_tensor_to_next_shard(self, shard: Shard, tensor: np.ndarray) -> None:
+        if not self.partitioning_strategy:
+            print("No partitioning strategy found. Skipping forward.")
+            return
+
+        partitions = self.partitioning_strategy.partition(self.topology)
+        current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+        print(f"Current partition index: {current_partition_index}")
+        if current_partition_index is not None:
+            next_partition_index = (current_partition_index + 1) % len(partitions)
+            next_partition: Partition = partitions[next_partition_index]
+            print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
+
+            if next_partition:
+                target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
+                if not target_peer:
+                    raise ValueError(f"Peer for {next_partition} not found")
+
+                start_layer = int(next_partition.start * shard.n_layers)
+                end_layer = int(next_partition.end * shard.n_layers) - 1
+                next_shard = Shard(shard.model_id, start_layer, end_layer, shard.n_layers)
+
+                print(f"Sending tensor to {target_peer.id()} for shard: {next_shard}")
+
+                await target_peer.send_tensor(next_shard, tensor)

     async def reset_shard(self, shard: Shard) -> None:
         # Implement shard reset logic
         print(f"Resetting shard: {shard}")
         await self.inference_engine.reset_shard(shard)

+    async def collect_topology(self, max_depth: int = 4) -> Topology:
+        self.topology.update_node(self.id, self.device_capabilities)
+
+        for peer in self.peers:
+            self.topology.update_node(peer.id(), peer.device_capabilities())
+            self.topology.add_edge(self.id, peer.id())
+
+            if max_depth > 0:
+                other_topology = await peer.collect_topology(max_depth = max_depth - 1)
+                print(f"Collected topology from: {peer.id()}: {other_topology}")
+                self.topology.merge(other_topology)
+
+        return self.topology
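The fraction-to-layer mapping in forward_tensor_to_next_shard is worth a concrete check. With n_layers=32 and the partitions from the new test (0.1/0.3/0.6 memory weights), the same arithmetic yields inclusive layer ranges that tile the 32 layers exactly:

n_layers = 32
partitions = [("node1", 0.0, 0.1), ("node2", 0.1, 0.4), ("node3", 0.4, 1.0)]
for node_id, start, end in partitions:
    start_layer = int(start * n_layers)   # same arithmetic as forward_tensor_to_next_shard
    end_layer = int(end * n_layers) - 1
    print(node_id, start_layer, end_layer)
# node1 0 2
# node2 3 11
# node3 12 31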

@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 import subprocess
+import platform

 @dataclass
 class DeviceCapabilities:
@@ -7,6 +8,17 @@ class DeviceCapabilities:
     chip: str
     memory: int

+def device_capabilities() -> DeviceCapabilities:
+    system = platform.system()
+    if system == 'Darwin':
+        return mac_device_capabilities()
+    # elif system == 'Linux':
+    #     return linux_device_capabilities()
+    # elif system == 'Windows':
+    #     return windows_device_capabilities()
+    else:
+        return DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0)
+
 def mac_device_capabilities() -> DeviceCapabilities:
     # Fetch the model of the Mac using system_profiler
     model = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
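The hunk ends before the parsing of that output. Purely as an illustration (not the repo's code), pulling the model, chip, and memory fields out of system_profiler's text could look like this; the field names match SPHardwareDataType output on Apple Silicon, and the MB unit is an assumption:

import re
import subprocess

def mac_device_capabilities_sketch():
    # Illustrative only: regex over `system_profiler SPHardwareDataType` text.
    out = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
    model = re.search(r'Model Name: (.+)', out)
    chip = re.search(r'Chip: (.+)', out)
    memory = re.search(r'Memory: (\d+) GB', out)
    return (
        model.group(1) if model else "Unknown Model",
        chip.group(1) if chip else "Unknown Chip",
        int(memory.group(1)) * 1024 if memory else 0,  # assuming memory is tracked in MB
    )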

@@ -1,10 +1,22 @@
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
+from dataclasses import dataclass
 from inference.shard import Shard
 from networking.peer_handle import PeerHandle
 from .topology import Topology

+# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
+@dataclass
+class Partition:
+    node_id: str
+    start: float
+    end: float
+
+class PartitioningStrategy(ABC):
+    def node_id(self) -> str:
+        pass
+
 class PartitioningStrategy(ABC):
     @abstractmethod
-    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
+    def partition(self, topology: Topology) -> List[Partition]:
         pass

@@ -1,27 +1,18 @@
 from typing import List
 from .partitioning_strategy import PartitioningStrategy
-from inference.shard import Shard
 from .topology import Topology
+from .partitioning_strategy import Partition

 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
-    def next_shard(self, current_shard: Shard, topology: Topology, node_stats: dict) -> Shard:
-        # Get all nodes from the topology and include the current node
+    def partition(self, topology: Topology) -> List[Partition]:
         nodes = list(topology.all_nodes())
-        nodes.append((self.id, None, node_stats))
-
-        # Sort nodes by their IDs
         nodes.sort(key=lambda x: x[0])
-
-        # Calculate the total memory of all nodes
-        total_memory = sum(node[2]['memory'] for node in nodes)
-
-        # Calculate the number of layers to assign to each node proportional to its memory
-        layers_per_node = {node[0]: (node[2]['memory'] / total_memory) * current_shard.n_layers for node in nodes}
-
-        # Find the successor node
-        node_ids = [node[0] for node in nodes]
-        current_index = node_ids.index(self.id)
-        successor_index = (current_index + 1) % len(node_ids)
-        successor_id = node_ids[successor_index]
-
-        # Return the Shard calculated for the successor
-        return Shard(successor_id, layers_per_node[successor_id])
+        total_memory = sum(node[1].memory for node in nodes)
+        partitions = []
+        start = 0
+        for node in nodes:
+            end = start + (node[1].memory / total_memory)
+            partitions.append(Partition(node[0], start, end))
+            start = end
+        return partitions

topology/test_ring_memory_weighted_partitioning_strategy.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import unittest
+from unittest.mock import MagicMock
+from .ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from .topology import Topology, DeviceCapabilities
+from .partitioning_strategy import Partition
+
+class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
+    def test_partition(self):
+        # triangle
+        # node1 -> node2 -> node3 -> node1
+        topology = Topology()
+        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=100))
+        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=300))
+        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=600))
+        topology.add_edge('node1', 'node2')
+        topology.add_edge('node2', 'node3')
+        topology.add_edge('node3', 'node1')
+        topology.add_edge('node1', 'node3')
+
+        strategy = RingMemoryWeightedPartitioningStrategy()
+        partitions = strategy.partition(topology)
+
+        self.assertEqual(len(partitions), 3)
+        self.assertEqual(partitions, [
+            Partition('node1', 0.0, 0.1),
+            Partition('node2', 0.1, 0.4),
+            Partition('node3', 0.4, 1.0)
+        ])
+
+if __name__ == '__main__':
+    unittest.main()

@@ -1,12 +1,47 @@
+from .device_capabilities import DeviceCapabilities
+from typing import Dict, Set

 class Topology:
     def __init__(self):
-        self.nodes = {}  # Maps node IDs to a tuple of (host, port, stats)
+        self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
+        self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph

-    def update_node(self, node_id, stats):
-        self.nodes[node_id] = stats
+    def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
+        self.nodes[node_id] = device_capabilities

-    def get_node(self, node_id):
+    def get_node(self, node_id: str) -> DeviceCapabilities:
         return self.nodes.get(node_id)

     def all_nodes(self):
         return self.nodes.items()

+    def add_edge(self, node1_id: str, node2_id: str):
+        if node1_id not in self.peer_graph:
+            self.peer_graph[node1_id] = set()
+        if node2_id not in self.peer_graph:
+            self.peer_graph[node2_id] = set()
+        self.peer_graph[node1_id].add(node2_id)
+        self.peer_graph[node2_id].add(node1_id)
+
+    def get_neighbors(self, node_id: str) -> Set[str]:
+        return self.peer_graph.get(node_id, set())
+
+    def all_edges(self):
+        edges = []
+        for node, neighbors in self.peer_graph.items():
+            for neighbor in neighbors:
+                if (neighbor, node) not in edges:  # Avoid duplicate edges
+                    edges.append((node, neighbor))
+        return edges
+
+    def merge(self, other: 'Topology'):
+        for node_id, capabilities in other.nodes.items():
+            self.update_node(node_id, capabilities)
+        for node_id, neighbors in other.peer_graph.items():
+            for neighbor in neighbors:
+                self.add_edge(node_id, neighbor)
+
+    def __str__(self):
+        nodes_str = ', '.join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
+        edges_str = ', '.join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+        return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"