allow update to manual discovery file

re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running
This commit is contained in:
Ian Paul
2024-10-24 09:10:10 +07:00
parent 496a3b49f5
commit 98118babae
3 changed files with 148 additions and 57 deletions

View File

@@ -9,63 +9,107 @@ from exo.networking.peer_handle import PeerHandle
class ManualDiscovery(Discovery):
def __init__(
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.create_peer_handle = create_peer_handle
def __init__(
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
):
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
self.listen_task = None
self.listen_task = None
self.known_peers: Dict[str, PeerHandle] = {}
self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
while len(self.known_peers) < wait_for_peers:
if DEBUG_DISCOVERY >= 2:
print(
f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers..."
)
await asyncio.sleep(0.1)
if DEBUG_DISCOVERY >= 2:
print(
f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}"
)
return list(self.known_peers.values())
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
while len(self.known_peers) < wait_for_peers:
if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
await asyncio.sleep(0.1)
if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
return list(self.known_peers.values())
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2:
print("Starting task to find peers from config...")
while True:
peers = self._get_peers().items()
for peer_id, peer_config in peers:
try:
if DEBUG_DISCOVERY >= 2:
print(
f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}"
)
peer = self.known_peers.get(peer_id)
if not peer:
if DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} not found in known peers. Adding.")
peer = self.create_peer_handle(
peer_id,
f"{peer_config.address}:{peer_config.port}",
peer_config.device_capabilities,
)
peer = self.create_peer_handle(
peer_id,
f"{peer_config.address}:{peer_config.port}",
peer_config.device_capabilities,
)
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2:
print(
f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy."
)
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2:
print(
f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy."
)
try:
del self.known_peers[peer_id]
except KeyError:
pass
except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(
f"Exception occured when attempting to add {peer_id=}: {e}"
)
await asyncio.sleep(1.0)
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
if not peer:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
await asyncio.sleep(1.0)
if DEBUG_DISCOVERY >= 2:
print(
f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}"
)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
def _get_peers(self):
topology = NetworkTopology.from_path(self.network_config_path)
if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file {self.network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in topology.peers]}"
)
peers_in_network: Dict[str, PeerConfig] = topology.peers
peers_in_network.pop(self.node_id)
return peers_in_network

View File

@@ -29,4 +29,4 @@
}
}
}
}
}

View File

@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -44,9 +45,9 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def test_discovery(self):
peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
assert len(peers1) == 1
self.assertEqual(len(peers1), 1)
peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
assert len(peers2) == 1
self.assertEqual(len(peers2), 1)
# connect has to be explicitly called after discovery
self.peer1.connect.assert_not_called()
@@ -76,9 +77,9 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
async def test_grpc_discovery(self):
peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
assert len(peers1) == 1
self.assertEqual(len(peers1), 1)
peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
assert len(peers2) == 1
self.assertEqual(len(peers2), 1)
# Connect
await peers1[0].connect()
@@ -98,6 +99,52 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)
# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)
try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {"model": "Unknown Model", "chip": "Unknown Chip", "memory": 0, "flops": {"fp32": 0, "fp16": 0, "int8": 0}},
},
}
}
with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)
node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()
try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)
for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())
finally:
await server3.stop()
finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)
if __name__ == "__main__":
asyncio.run(unittest.main())