Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up

Support changing manual configuration while running
This commit is contained in:
Alex Cheema
2024-12-28 11:57:25 +00:00
committed by GitHub
3 changed files with 137 additions and 29 deletions

View File

@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor
from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
self.listen_task = None
self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)
self._cached_peers: Dict[str, PeerConfig] = {}
self._last_modified_time: Optional[float] = None
self._file_executor = ThreadPoolExecutor(max_workers=1)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
if self.listen_task: self.listen_task.cancel()
self._file_executor.shutdown(wait=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
peers_from_config = await self._get_peers()
new_known_peers = {}
for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
new_known_peers[peer_id] = peer
elif DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
self.known_peers = new_known_peers
await asyncio.sleep(1.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
async def _get_peers(self):
try:
loop = asyncio.get_running_loop()
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
return self._cached_peers
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file "
f"{self.network_config_path}. Please run with `node_id` set to "
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
)
peers_in_network = topology.peers
peers_in_network.pop(self.node_id)
self._cached_peers = peers_in_network
self._last_modified_time = current_mtime
return peers_in_network
except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(f"Error when loading network config file from {self.network_config_path}. "
f"Please update the config file in order to successfully discover peers. "
f"Exception: {e}")
return self._cached_peers

View File

@@ -29,4 +29,4 @@
}
}
}
}
}

View File

@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
_ = self.discovery1.start()
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
await self.discovery1.start()
async def asyncTearDown(self):
await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)
# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)
try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}
with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)
node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()
try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)
for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())
finally:
await server3.stop()
finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)
# Wait for the config to be reloaded again
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)
if __name__ == "__main__":
asyncio.run(unittest.main())