mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up
Support changing manual configuration while running
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
import asyncio
|
||||
from exo.networking.discovery import Discovery
|
||||
from typing import Dict, List, Callable
|
||||
from typing import Dict, List, Callable, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from exo.networking.discovery import Discovery
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
|
||||
from exo.helpers import DEBUG_DISCOVERY
|
||||
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
|
||||
self,
|
||||
network_config_path: str,
|
||||
node_id: str,
|
||||
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
|
||||
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
|
||||
):
|
||||
self.topology = NetworkTopology.from_path(network_config_path)
|
||||
self.network_config_path = network_config_path
|
||||
self.node_id = node_id
|
||||
self.create_peer_handle = create_peer_handle
|
||||
|
||||
if node_id not in self.topology.peers:
|
||||
raise ValueError(
|
||||
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
|
||||
)
|
||||
|
||||
self.listen_task = None
|
||||
|
||||
self.known_peers: Dict[str, PeerHandle] = {}
|
||||
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
|
||||
self.peers_in_network.pop(node_id)
|
||||
|
||||
self._cached_peers: Dict[str, PeerConfig] = {}
|
||||
self._last_modified_time: Optional[float] = None
|
||||
self._file_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
async def start(self) -> None:
|
||||
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self.listen_task:
|
||||
self.listen_task.cancel()
|
||||
if self.listen_task: self.listen_task.cancel()
|
||||
self._file_executor.shutdown(wait=True)
|
||||
|
||||
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
||||
if wait_for_peers > 0:
|
||||
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
|
||||
async def task_find_peers_from_config(self):
|
||||
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
|
||||
while True:
|
||||
for peer_id, peer_config in self.peers_in_network.items():
|
||||
peers_from_config = await self._get_peers()
|
||||
new_known_peers = {}
|
||||
for peer_id, peer_config in peers_from_config.items():
|
||||
try:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
|
||||
peer = self.known_peers.get(peer_id)
|
||||
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
|
||||
is_healthy = await peer.health_check()
|
||||
if is_healthy:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
|
||||
self.known_peers[peer_id] = peer
|
||||
else:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
|
||||
try:
|
||||
del self.known_peers[peer_id]
|
||||
except KeyError:
|
||||
pass
|
||||
new_known_peers[peer_id] = peer
|
||||
elif DEBUG_DISCOVERY >= 2:
|
||||
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
|
||||
self.known_peers = new_known_peers
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
|
||||
|
||||
async def _get_peers(self):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
|
||||
|
||||
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
|
||||
return self._cached_peers
|
||||
|
||||
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
|
||||
|
||||
if self.node_id not in topology.peers:
|
||||
raise ValueError(
|
||||
f"Node ID {self.node_id} not found in network config file "
|
||||
f"{self.network_config_path}. Please run with `node_id` set to "
|
||||
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
|
||||
)
|
||||
|
||||
peers_in_network = topology.peers
|
||||
peers_in_network.pop(self.node_id)
|
||||
|
||||
self._cached_peers = peers_in_network
|
||||
self._last_modified_time = current_mtime
|
||||
|
||||
return peers_in_network
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2:
|
||||
print(f"Error when loading network config file from {self.network_config_path}. "
|
||||
f"Please update the config file in order to successfully discover peers. "
|
||||
f"Exception: {e}")
|
||||
return self._cached_peers
|
||||
|
||||
@@ -29,4 +29,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest import mock
|
||||
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock()
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
|
||||
_ = self.discovery1.start()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||
)
|
||||
await self.discovery1.start()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
self.peer2 = mock.AsyncMock()
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.peer2.connect = mock.AsyncMock()
|
||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
|
||||
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||
)
|
||||
self.discovery2 = ManualDiscovery(
|
||||
root_path,
|
||||
"node2",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
|
||||
)
|
||||
await self.discovery1.start()
|
||||
await self.discovery2.start()
|
||||
|
||||
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
|
||||
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
|
||||
await self.server1.start()
|
||||
await self.server2.start()
|
||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
|
||||
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
)
|
||||
self.discovery2 = ManualDiscovery(
|
||||
root_path,
|
||||
"node2",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
)
|
||||
await self.discovery1.start()
|
||||
await self.discovery2.start()
|
||||
|
||||
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertFalse(await peers1[0].is_connected())
|
||||
self.assertFalse(await peers2[0].is_connected())
|
||||
|
||||
async def test_dynamic_config_update(self):
|
||||
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(initial_peers), 1)
|
||||
|
||||
# Save original config for cleanup
|
||||
with open(root_path, "r") as f:
|
||||
original_config = json.load(f)
|
||||
|
||||
try:
|
||||
updated_config = {
|
||||
"peers": {
|
||||
**original_config["peers"],
|
||||
"node3": {
|
||||
"address": "localhost",
|
||||
"port": 50053,
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
with open(root_path, "w") as f:
|
||||
json.dump(updated_config, f, indent=2)
|
||||
|
||||
node3 = mock.AsyncMock(spec=Node)
|
||||
server3 = GRPCServer(node3, "localhost", 50053)
|
||||
await server3.start()
|
||||
|
||||
try:
|
||||
# Wait for the config to be reloaded
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
|
||||
self.assertEqual(len(updated_peers), 2)
|
||||
|
||||
for peer in updated_peers:
|
||||
await peer.connect()
|
||||
self.assertTrue(await peer.is_connected())
|
||||
|
||||
finally:
|
||||
await server3.stop()
|
||||
|
||||
finally:
|
||||
# Restore the original config file
|
||||
with open(root_path, "w") as f:
|
||||
json.dump(original_config, f, indent=2)
|
||||
|
||||
# Wait for the config to be reloaded again
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(updated_peers), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(unittest.main())
|
||||
|
||||
Reference in New Issue
Block a user