tensor_list_data = 2;
+ string other_data_json = 3;
+}
+
message CollectTopologyRequest {
repeated string visited = 1;
int32 max_depth = 2;
@@ -85,8 +97,9 @@ message DeviceCapabilities {
message SendNewTokenRequest {
string request_id = 1;
- int32 token = 2;
- bool is_finished = 3;
+ repeated int32 result = 2;
+ optional Tensor tensor = 3;
+ bool is_finished = 4;
}
message SendOpaqueStatusRequest {
diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py
index 7379eb69..6ff71086 100644
--- a/exo/networking/grpc/node_service_pb2.py
+++ b/exo/networking/grpc/node_service_pb2.py
@@ -24,55 +24,67 @@ _sym_db = _symbol_database.Default()
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"M\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\r\n\x05token\x18\x02 \x01(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x84\x01\n\x13SendNewTokenRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x99\x04\n\x0bNodeService\x12@\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x13.node_service.Empty\"\x00\x12@\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x13.node_service.Empty\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12H\n\x0cSendNewToken\x12!.node_service.SendNewTokenRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
+ _globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
+ _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001'
+ _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None
+ _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
- _globals['_PROMPTREQUEST']._serialized_start=121
- _globals['_PROMPTREQUEST']._serialized_end=228
- _globals['_TENSORREQUEST']._serialized_start=231
- _globals['_TENSORREQUEST']._serialized_end=360
- _globals['_EXAMPLEREQUEST']._serialized_start=363
- _globals['_EXAMPLEREQUEST']._serialized_end=585
- _globals['_LOSS']._serialized_start=587
- _globals['_LOSS']._serialized_end=659
- _globals['_TENSOR']._serialized_start=661
- _globals['_TENSOR']._serialized_end=720
- _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=722
- _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=782
- _globals['_TOPOLOGY']._serialized_start=785
- _globals['_TOPOLOGY']._serialized_end=1065
- _globals['_TOPOLOGY_NODESENTRY']._serialized_start=906
- _globals['_TOPOLOGY_NODESENTRY']._serialized_end=984
- _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=986
- _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1065
- _globals['_PEERCONNECTION']._serialized_start=1067
- _globals['_PEERCONNECTION']._serialized_end=1140
- _globals['_PEERCONNECTIONS']._serialized_start=1142
- _globals['_PEERCONNECTIONS']._serialized_end=1210
- _globals['_DEVICEFLOPS']._serialized_start=1212
- _globals['_DEVICEFLOPS']._serialized_end=1267
- _globals['_DEVICECAPABILITIES']._serialized_start=1269
- _globals['_DEVICECAPABILITIES']._serialized_end=1376
- _globals['_SENDNEWTOKENREQUEST']._serialized_start=1378
- _globals['_SENDNEWTOKENREQUEST']._serialized_end=1455
- _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1457
- _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1518
- _globals['_HEALTHCHECKREQUEST']._serialized_start=1520
- _globals['_HEALTHCHECKREQUEST']._serialized_end=1540
- _globals['_HEALTHCHECKRESPONSE']._serialized_start=1542
- _globals['_HEALTHCHECKRESPONSE']._serialized_end=1583
- _globals['_EMPTY']._serialized_start=1585
- _globals['_EMPTY']._serialized_end=1592
- _globals['_NODESERVICE']._serialized_start=1595
- _globals['_NODESERVICE']._serialized_end=2132
+ _globals['_PROMPTREQUEST']._serialized_start=122
+ _globals['_PROMPTREQUEST']._serialized_end=309
+ _globals['_TENSORREQUEST']._serialized_start=312
+ _globals['_TENSORREQUEST']._serialized_end=521
+ _globals['_EXAMPLEREQUEST']._serialized_start=524
+ _globals['_EXAMPLEREQUEST']._serialized_end=746
+ _globals['_LOSS']._serialized_start=748
+ _globals['_LOSS']._serialized_end=820
+ _globals['_TENSOR']._serialized_start=822
+ _globals['_TENSOR']._serialized_end=881
+ _globals['_TENSORLIST']._serialized_start=883
+ _globals['_TENSORLIST']._serialized_end=934
+ _globals['_INFERENCESTATE']._serialized_start=937
+ _globals['_INFERENCESTATE']._serialized_end=1275
+ _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123
+ _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194
+ _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196
+ _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275
+ _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277
+ _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337
+ _globals['_TOPOLOGY']._serialized_start=1340
+ _globals['_TOPOLOGY']._serialized_end=1620
+ _globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461
+ _globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539
+ _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541
+ _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620
+ _globals['_PEERCONNECTION']._serialized_start=1622
+ _globals['_PEERCONNECTION']._serialized_end=1695
+ _globals['_PEERCONNECTIONS']._serialized_start=1697
+ _globals['_PEERCONNECTIONS']._serialized_end=1765
+ _globals['_DEVICEFLOPS']._serialized_start=1767
+ _globals['_DEVICEFLOPS']._serialized_end=1822
+ _globals['_DEVICECAPABILITIES']._serialized_start=1824
+ _globals['_DEVICECAPABILITIES']._serialized_end=1931
+ _globals['_SENDNEWTOKENREQUEST']._serialized_start=1934
+ _globals['_SENDNEWTOKENREQUEST']._serialized_end=2066
+ _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2068
+ _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2129
+ _globals['_HEALTHCHECKREQUEST']._serialized_start=2131
+ _globals['_HEALTHCHECKREQUEST']._serialized_end=2151
+ _globals['_HEALTHCHECKRESPONSE']._serialized_start=2153
+ _globals['_HEALTHCHECKRESPONSE']._serialized_end=2194
+ _globals['_EMPTY']._serialized_start=2196
+ _globals['_EMPTY']._serialized_end=2203
+ _globals['_NODESERVICE']._serialized_start=2206
+ _globals['_NODESERVICE']._serialized_end=2743
# @@protoc_insertion_point(module_scope)
diff --git a/exo/networking/manual/manual_discovery.py b/exo/networking/manual/manual_discovery.py
index 8287605e..35a8fabe 100644
--- a/exo/networking/manual/manual_discovery.py
+++ b/exo/networking/manual/manual_discovery.py
@@ -1,7 +1,9 @@
+import os
import asyncio
-from exo.networking.discovery import Discovery
-from typing import Dict, List, Callable
+from typing import Dict, List, Callable, Optional
+from concurrent.futures import ThreadPoolExecutor
+from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
self,
network_config_path: str,
node_id: str,
- create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+ create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
- self.topology = NetworkTopology.from_path(network_config_path)
+ self.network_config_path = network_config_path
+ self.node_id = node_id
self.create_peer_handle = create_peer_handle
- if node_id not in self.topology.peers:
- raise ValueError(
- f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
- )
-
self.listen_task = None
-
self.known_peers: Dict[str, PeerHandle] = {}
- self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
- self.peers_in_network.pop(node_id)
+
+ self._cached_peers: Dict[str, PeerConfig] = {}
+ self._last_modified_time: Optional[float] = None
+ self._file_executor = ThreadPoolExecutor(max_workers=1)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
- if self.listen_task:
- self.listen_task.cancel()
+ if self.listen_task: self.listen_task.cancel()
+ self._file_executor.shutdown(wait=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
- for peer_id, peer_config in self.peers_in_network.items():
+ peers_from_config = await self._get_peers()
+ new_known_peers = {}
+ for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
@@ -57,15 +58,43 @@ class ManualDiscovery(Discovery):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
- self.known_peers[peer_id] = peer
- else:
- if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
- try:
- del self.known_peers[peer_id]
- except KeyError:
- pass
+ new_known_peers[peer_id] = peer
+ elif DEBUG_DISCOVERY >= 2:
+ print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
await asyncio.sleep(5.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
+
+ async def _get_peers(self):
+ try:
+ loop = asyncio.get_running_loop()
+ current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
+
+ if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
+ return self._cached_peers
+
+ topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
+
+ if self.node_id not in topology.peers:
+ raise ValueError(
+ f"Node ID {self.node_id} not found in network config file "
+ f"{self.network_config_path}. Please run with `node_id` set to "
+ f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
+ )
+
+ peers_in_network = topology.peers
+ peers_in_network.pop(self.node_id)
+
+ self._cached_peers = peers_in_network
+ self._last_modified_time = current_mtime
+
+ return peers_in_network
+
+ except Exception as e:
+ if DEBUG_DISCOVERY >= 2:
+ print(f"Error when loading network config file from {self.network_config_path}. "
+ f"Please update the config file in order to successfully discover peers. "
+ f"Exception: {e}")
+ return self._cached_peers
diff --git a/exo/networking/manual/test_data/test_config.json b/exo/networking/manual/test_data/test_config.json
index b50ef635..54eced72 100644
--- a/exo/networking/manual/test_data/test_config.json
+++ b/exo/networking/manual/test_data/test_config.json
@@ -29,4 +29,4 @@
}
}
}
-}
+}
\ No newline at end of file
diff --git a/exo/networking/manual/test_manual_discovery.py b/exo/networking/manual/test_manual_discovery.py
index 69f45fa1..317fba9d 100644
--- a/exo/networking/manual/test_manual_discovery.py
+++ b/exo/networking/manual/test_manual_discovery.py
@@ -1,3 +1,4 @@
+import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
- self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
- _ = self.discovery1.start()
+ self.discovery1 = ManualDiscovery(
+ root_path,
+ "node1",
+ create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+ )
+ await self.discovery1.start()
async def asyncTearDown(self):
await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
- self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
- self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
+ self.discovery1 = ManualDiscovery(
+ root_path,
+ "node1",
+ create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+ )
+ self.discovery2 = ManualDiscovery(
+ root_path,
+ "node2",
+ create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
+ )
await self.discovery1.start()
await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
- self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
- self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+ self.discovery1 = ManualDiscovery(
+ root_path,
+ "node1",
+ create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+ )
+ self.discovery2 = ManualDiscovery(
+ root_path,
+ "node2",
+ create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+ )
await self.discovery1.start()
await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
+ async def test_dynamic_config_update(self):
+ initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+ self.assertEqual(len(initial_peers), 1)
+
+ # Save original config for cleanup
+ with open(root_path, "r") as f:
+ original_config = json.load(f)
+
+ try:
+ updated_config = {
+ "peers": {
+ **original_config["peers"],
+ "node3": {
+ "address": "localhost",
+ "port": 50053,
+ "device_capabilities": {
+ "model": "Unknown Model",
+ "chip": "Unknown Chip",
+ "memory": 0,
+ "flops": {"fp32": 0, "fp16": 0, "int8": 0},
+ },
+ },
+ }
+ }
+
+ with open(root_path, "w") as f:
+ json.dump(updated_config, f, indent=2)
+
+ node3 = mock.AsyncMock(spec=Node)
+ server3 = GRPCServer(node3, "localhost", 50053)
+ await server3.start()
+
+ try:
+ # Wait for the config to be reloaded
+ await asyncio.sleep(1.5)
+
+ updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
+ self.assertEqual(len(updated_peers), 2)
+
+ for peer in updated_peers:
+ await peer.connect()
+ self.assertTrue(await peer.is_connected())
+
+ finally:
+ await server3.stop()
+
+ finally:
+ # Restore the original config file
+ with open(root_path, "w") as f:
+ json.dump(original_config, f, indent=2)
+
+ # Wait for the config to be reloaded again
+ await asyncio.sleep(1.5)
+
+ updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+ self.assertEqual(len(updated_peers), 1)
+
if __name__ == "__main__":
asyncio.run(unittest.main())
diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index ebf9b673..00453deb 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -118,44 +118,50 @@ class Node:
shard,
result: np.ndarray,
request_id: Optional[str] = None,
+ inference_state: Optional[dict] = None,
):
- if request_id not in self.buffered_token_output:
- self.buffered_token_output[request_id] = ([], False)
- is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
- if shard.is_last_layer() and not is_finished:
- self.token_count += 1
- if self.token_count == 1:
- self.first_token_time = time.perf_counter_ns()
- if self.token_count % 20 == 0:
- print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
-
- token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
- await self.inference_engine.ensure_shard(shard)
- self.buffered_token_output[request_id][0].append(token.item())
- is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
- if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
- forward = token.reshape(1, -1)
- self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
- asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
+ if shard.model_id != 'stable-diffusion-2-1-base':
+ if request_id not in self.buffered_token_output:
+ self.buffered_token_output[request_id] = ([], False)
+ is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+ if shard.is_last_layer() and not is_finished:
+ token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+ await self.inference_engine.ensure_shard(shard)
+ self.buffered_token_output[request_id][0].append(token.item())
+ is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+ if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+ asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
+ forward = token.reshape(1, -1)
+ intermediate_result = self.buffered_token_output[request_id][0]
+ else:
+ forward = result
else:
+ await self.inference_engine.ensure_shard(shard)
+ is_finished = inference_state.get("is_finished", False)
+ intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
+ if shard.is_last_layer():
+ self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
+ asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
if is_finished:
- self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+ if shard.model_id != 'stable-diffusion-2-1-base':
+ self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
- asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+ asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
+
+ return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
- return np.array(self.buffered_token_output[request_id][0])
async def process_prompt(
self,
base_shard: Shard,
prompt: str,
request_id: Optional[str] = None,
- ) -> None:
+ inference_state: Optional[dict] = {},
+ ) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
asyncio.create_task(
@@ -172,7 +178,8 @@ class Node:
}),
)
)
- await self._process_prompt(base_shard, prompt, request_id)
+ start_time = time.perf_counter_ns()
+ resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -192,7 +199,7 @@ class Node:
)
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
- async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
@@ -201,12 +208,13 @@ class Node:
if not shard.is_first_layer():
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
self.outstanding_requests[request_id] = "waiting"
- await self.forward_prompt(shard, prompt, request_id, 0)
+ resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
return None
-
- self.outstanding_requests[request_id] = "processing"
- result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
- await self.process_inference_result(shard, result, request_id)
+ else:
+ self.outstanding_requests[request_id] = "processing"
+ result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
+ ret = await self.process_inference_result(shard, result, request_id, inference_state)
+ return result
async def enqueue_example(
self,
@@ -350,10 +358,11 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
- ) -> None:
+ inference_state: Optional[dict] = None,
+ ) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
- await self._process_tensor(shard, tensor, request_id)
+ resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
@@ -363,15 +372,17 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
- ) -> None:
+ inference_state: Optional[dict] = None,
+ ) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
try:
self.outstanding_requests[request_id] = "processing"
- result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
- await self.process_inference_result(shard, result, request_id)
+ result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
+ ret = await self.process_inference_result(shard, result, request_id, inference_state)
+ return ret
except Exception as e:
self.outstanding_requests.pop(request_id)
print(f"Error processing tensor for shard {shard}: {e}")
@@ -404,19 +415,20 @@ class Node:
prompt: str,
request_id: str,
target_index: int,
+ inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
if target_id == self.id:
- await self.process_prompt(next_shard, prompt, request_id)
+ await self.process_prompt(next_shard, prompt, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
- await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+ await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
async def forward_tensor(
self,
@@ -424,19 +436,20 @@ class Node:
tensor: np.ndarray,
request_id: str,
target_index: int,
+ inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
if target_id == self.id:
- await self.process_tensor(next_shard, tensor, request_id)
+ await self.process_tensor(next_shard, tensor, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
- await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+ await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
def get_partition_index(self, offset: int = 0):
if not self.partitioning_strategy:
@@ -604,3 +617,12 @@ class Node:
@property
def current_topology(self) -> Topology:
return self.topology
+
+ def handle_stable_diffusion(self, inference_state, result):
+ if inference_state['is_step_finished']:
+ inference_state['step']+=1
+ progress = [inference_state['step'],inference_state['total_steps']]
+ intermediate_result = result
+ if progress[0] == progress[1]:
+ intermediate_result = result
+ return intermediate_result, inference_state
diff --git a/exo/orchestration/tracing.py b/exo/orchestration/tracing.py
new file mode 100644
index 00000000..4466fc7d
--- /dev/null
+++ b/exo/orchestration/tracing.py
@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Any
+from opentelemetry import trace, context
+from opentelemetry.trace import Status, StatusCode, SpanContext
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from contextlib import contextmanager
+import time
+from threading import Lock
+
+@dataclass
+class TraceContext:
+ request_id: str
+ sequence_number: int
+ current_span: Optional[trace.Span] = None
+ trace_parent: Optional[str] = None
+ token_group_span: Optional[trace.Span] = None
+ token_count: int = 0
+ token_group_size: int = 10 # Default group size
+ request_span: Optional[trace.Span] = None # Track the main request span
+
+class Tracer:
+ def __init__(self):
+ self.tracer = trace.get_tracer("exo")
+ self.contexts: Dict[str, TraceContext] = {}
+ self._lock = Lock()
+ self.propagator = TraceContextTextMapPropagator()
+
+ def get_context(self, request_id: str) -> Optional[TraceContext]:
+ with self._lock:
+ return self.contexts.get(request_id)
+
+ def set_context(self, request_id: str, context: TraceContext):
+ with self._lock:
+ self.contexts[request_id] = context
+
+ def inject_context(self, span: trace.Span) -> str:
+ """Inject current span context into carrier for propagation"""
+ carrier = {}
+ ctx = trace.set_span_in_context(span)
+ self.propagator.inject(carrier, context=ctx)
+ return carrier.get("traceparent", "")
+
+ def extract_context(self, trace_parent: str) -> Optional[context.Context]:
+ """Extract span context from carrier"""
+ if not trace_parent:
+ return None
+ carrier = {"traceparent": trace_parent}
+ return self.propagator.extract(carrier)
+
+ def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
+ """Create a new context with the given trace parent"""
+ parent_ctx = self.extract_context(trace_parent)
+ if parent_ctx:
+ # Create a new request span that links to the parent context
+ request_span = self.tracer.start_span(
+ "request",
+ context=parent_ctx,
+ attributes={
+ "request_id": request_id,
+ "sequence_number": sequence_number
+ }
+ )
+ return TraceContext(
+ request_id=request_id,
+ sequence_number=sequence_number,
+ request_span=request_span,
+ current_span=request_span,
+ trace_parent=trace_parent
+ )
+ return TraceContext(request_id=request_id, sequence_number=sequence_number)
+
+ def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
+ """Handle token generation and manage token group spans"""
+ context.token_count += 1
+
+ # Start a new token group span if needed
+ if not context.token_group_span and context.request_span:
+ group_number = (context.token_count - 1) // context.token_group_size + 1
+
+ # Create token group span as child of request span
+ parent_ctx = trace.set_span_in_context(context.request_span)
+ context.token_group_span = self.tracer.start_span(
+ f"token_group_{group_number}",
+ context=parent_ctx,
+ attributes={
+ "request_id": context.request_id,
+ "group.number": group_number,
+ "group.start_token": context.token_count,
+ "group.max_tokens": context.token_group_size
+ }
+ )
+
+ # Add token to current group span
+ if context.token_group_span:
+ relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
+ context.token_group_span.set_attribute(f"token.{relative_pos}", token)
+ context.token_group_span.set_attribute("token.count", relative_pos)
+
+ # End current group span if we've reached the group size or if generation is finished
+ if context.token_count % context.token_group_size == 0 or is_finished:
+ context.token_group_span.set_attribute("token.final_count", relative_pos)
+ context.token_group_span.end()
+ context.token_group_span = None
+
+ @contextmanager
+ def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
+ """Start a new span with proper parent context"""
+ attributes = {
+ "request_id": context.request_id,
+ "sequence_number": context.sequence_number
+ }
+ if extra_attributes:
+ attributes.update(extra_attributes)
+
+ # Use request span as parent if available
+ parent_ctx = None
+ if context.request_span:
+ parent_ctx = trace.set_span_in_context(context.request_span)
+ elif context.trace_parent:
+ parent_ctx = self.extract_context(context.trace_parent)
+ if parent_ctx and not context.request_span:
+ # Create a new request span that links to the parent context
+ context.request_span = self.tracer.start_span(
+ "request",
+ context=parent_ctx,
+ attributes={
+ "request_id": context.request_id,
+ "sequence_number": context.sequence_number
+ }
+ )
+ parent_ctx = trace.set_span_in_context(context.request_span)
+ elif context.current_span:
+ parent_ctx = trace.set_span_in_context(context.current_span)
+
+ # Create span with parent context if it exists
+ if parent_ctx:
+ span = self.tracer.start_span(
+ name,
+ context=parent_ctx,
+ attributes=attributes
+ )
+ else:
+ span = self.tracer.start_span(
+ name,
+ attributes=attributes
+ )
+
+ # Update context with current span
+ prev_span = context.current_span
+ context.current_span = span
+
+ try:
+ start_time = time.perf_counter()
+ yield span
+ duration = time.perf_counter() - start_time
+ span.set_attribute("duration_s", duration)
+ span.set_status(Status(StatusCode.OK))
+ except Exception as e:
+ span.set_status(Status(StatusCode.ERROR, str(e)))
+ raise
+ finally:
+ span.end()
+ context.current_span = prev_span
+
+# Global tracer instance
+tracer = Tracer()
\ No newline at end of file
diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html
index 4e0617e4..013d0d63 100644
--- a/exo/tinychat/index.html
+++ b/exo/tinychat/index.html
@@ -197,7 +197,25 @@
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
- div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+ if (content.includes('![Generated Image]')) {
+ const imageUrl = content.match(/\((.*?)\)/)[1];
+ const img = document.createElement('img');
+ img.src = imageUrl;
+ img.alt = 'Generated Image';
+ img.onclick = async () => {
+ try {
+ const response = await fetch(img.src);
+ const blob = await response.blob();
+ const file = new File([blob], 'image.png', { type: 'image/png' });
+ handleImageUpload({ target: { files: [file] } });
+ } catch (error) {
+ console.error('Error fetching image:', error);
+ }
+ };
+ div.appendChild(img);
+ } else {
+ div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+ }
} catch (e) {
console.log(content);
console.error(e);
@@ -281,7 +299,7 @@
-