mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
formatting and fixing tests after rebasing
This commit is contained in:
@@ -15,7 +15,7 @@ class ManualDiscovery(Discovery):
|
||||
self,
|
||||
network_config_path: str,
|
||||
node_id: str,
|
||||
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
|
||||
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
|
||||
):
|
||||
self.network_config_path = network_config_path
|
||||
self.node_id = node_id
|
||||
@@ -54,7 +54,7 @@ class ManualDiscovery(Discovery):
|
||||
peer = self.known_peers.get(peer_id)
|
||||
if not peer:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
|
||||
peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
|
||||
peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
|
||||
is_healthy = await peer.health_check()
|
||||
if is_healthy:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
|
||||
|
||||
@@ -12,187 +12,170 @@ root_path = "./exo/networking/manual/test_data/test_config.json"
|
||||
|
||||
|
||||
class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock(spec=Node)
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.server1 = GRPCServer(self.peer1, "localhost", 8000)
|
||||
await self.server1.start()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(
|
||||
peer_id, address, device_capabilities
|
||||
),
|
||||
)
|
||||
await self.discovery1.start()
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock()
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||
)
|
||||
await self.discovery1.start()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
await self.server1.stop()
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
|
||||
async def test_discovery(self):
|
||||
peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
|
||||
self.assertEqual(len(peers1), 0)
|
||||
async def test_discovery(self):
|
||||
peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
|
||||
assert len(peers1) == 0
|
||||
|
||||
self.peer1.connect.assert_not_called()
|
||||
self.peer1.connect.assert_not_called()
|
||||
|
||||
|
||||
class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock()
|
||||
self.peer2 = mock.AsyncMock()
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock()
|
||||
self.peer2 = mock.AsyncMock()
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.peer2.connect = mock.AsyncMock()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||
)
|
||||
self.discovery2 = ManualDiscovery(
|
||||
root_path,
|
||||
"node2",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
|
||||
)
|
||||
await self.discovery1.start()
|
||||
await self.discovery2.start()
|
||||
|
||||
self.peer1.id = mock.MagicMock(return_value="node2")
|
||||
self.peer2.id = mock.MagicMock(return_value="node1")
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
await self.discovery2.stop()
|
||||
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.peer2.connect = mock.AsyncMock()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||
)
|
||||
self.discovery2 = ManualDiscovery(
|
||||
root_path,
|
||||
"node2",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
|
||||
)
|
||||
await self.discovery1.start()
|
||||
await self.discovery2.start()
|
||||
async def test_discovery(self):
|
||||
peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
assert len(peers1) == 1
|
||||
peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
|
||||
assert len(peers2) == 1
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
await self.discovery2.stop()
|
||||
|
||||
async def test_discovery(self):
|
||||
peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(peers1), 1)
|
||||
peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(peers2), 1)
|
||||
|
||||
# connect has to be explicitly called after discovery
|
||||
self.peer1.connect.assert_not_called()
|
||||
self.peer2.connect.assert_not_called()
|
||||
# connect has to be explicitly called after discovery
|
||||
self.peer1.connect.assert_not_called()
|
||||
self.peer2.connect.assert_not_called()
|
||||
|
||||
|
||||
class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
config = NetworkTopology.from_path(root_path)
|
||||
async def asyncSetUp(self):
|
||||
config = NetworkTopology.from_path(root_path)
|
||||
|
||||
self.node1 = mock.AsyncMock(spec=Node)
|
||||
self.node2 = mock.AsyncMock(spec=Node)
|
||||
self.server1 = GRPCServer(
|
||||
self.node1, config.peers["node1"].address, config.peers["node1"].port
|
||||
)
|
||||
self.server2 = GRPCServer(
|
||||
self.node2, config.peers["node2"].address, config.peers["node2"].port
|
||||
)
|
||||
await self.server1.start()
|
||||
await self.server2.start()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(
|
||||
peer_id, address, description, device_capabilities
|
||||
),
|
||||
)
|
||||
self.discovery2 = ManualDiscovery(
|
||||
root_path,
|
||||
"node2",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(
|
||||
peer_id, address, description, device_capabilities
|
||||
),
|
||||
)
|
||||
await self.discovery1.start()
|
||||
await self.discovery2.start()
|
||||
self.node1 = mock.AsyncMock(spec=Node)
|
||||
self.node2 = mock.AsyncMock(spec=Node)
|
||||
self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
|
||||
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
|
||||
await self.server1.start()
|
||||
await self.server2.start()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
)
|
||||
self.discovery2 = ManualDiscovery(
|
||||
root_path,
|
||||
"node2",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
)
|
||||
await self.discovery1.start()
|
||||
await self.discovery2.start()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
await self.discovery2.stop()
|
||||
await self.server1.stop()
|
||||
await self.server2.stop()
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
await self.discovery2.stop()
|
||||
await self.server1.stop()
|
||||
await self.server2.stop()
|
||||
|
||||
async def test_grpc_discovery(self):
|
||||
peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(peers1), 1)
|
||||
peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(peers2), 1)
|
||||
async def test_grpc_discovery(self):
|
||||
peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
assert len(peers1) == 1
|
||||
peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
|
||||
assert len(peers2) == 1
|
||||
|
||||
# Connect
|
||||
await peers1[0].connect()
|
||||
await peers2[0].connect()
|
||||
self.assertTrue(await peers1[0].is_connected())
|
||||
self.assertTrue(await peers2[0].is_connected())
|
||||
# Connect
|
||||
await peers1[0].connect()
|
||||
await peers2[0].connect()
|
||||
self.assertTrue(await peers1[0].is_connected())
|
||||
self.assertTrue(await peers2[0].is_connected())
|
||||
|
||||
# Kill server1
|
||||
await self.server1.stop()
|
||||
# Kill server1
|
||||
await self.server1.stop()
|
||||
|
||||
self.assertTrue(await peers1[0].is_connected())
|
||||
self.assertFalse(await peers2[0].is_connected())
|
||||
self.assertTrue(await peers1[0].is_connected())
|
||||
self.assertFalse(await peers2[0].is_connected())
|
||||
|
||||
# Kill server2
|
||||
await self.server2.stop()
|
||||
# Kill server2
|
||||
await self.server2.stop()
|
||||
|
||||
self.assertFalse(await peers1[0].is_connected())
|
||||
self.assertFalse(await peers2[0].is_connected())
|
||||
self.assertFalse(await peers1[0].is_connected())
|
||||
self.assertFalse(await peers2[0].is_connected())
|
||||
|
||||
async def test_dynamic_config_update(self):
|
||||
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(initial_peers), 1)
|
||||
async def test_dynamic_config_update(self):
|
||||
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(initial_peers), 1)
|
||||
|
||||
# Save original config for cleanup
|
||||
with open(root_path, "r") as f:
|
||||
original_config = json.load(f)
|
||||
# Save original config for cleanup
|
||||
with open(root_path, "r") as f:
|
||||
original_config = json.load(f)
|
||||
|
||||
try:
|
||||
updated_config = {
|
||||
"peers": {
|
||||
**original_config["peers"],
|
||||
"node3": {
|
||||
"address": "localhost",
|
||||
"port": 50053,
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
updated_config = {
|
||||
"peers": {
|
||||
**original_config["peers"],
|
||||
"node3": {
|
||||
"address": "localhost",
|
||||
"port": 50053,
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
with open(root_path, "w") as f:
|
||||
json.dump(updated_config, f, indent=2)
|
||||
with open(root_path, "w") as f:
|
||||
json.dump(updated_config, f, indent=2)
|
||||
|
||||
node3 = mock.AsyncMock(spec=Node)
|
||||
server3 = GRPCServer(node3, "localhost", 50053)
|
||||
await server3.start()
|
||||
node3 = mock.AsyncMock(spec=Node)
|
||||
server3 = GRPCServer(node3, "localhost", 50053)
|
||||
await server3.start()
|
||||
|
||||
try:
|
||||
# Wait for the config to be reloaded
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
|
||||
self.assertEqual(len(updated_peers), 2)
|
||||
|
||||
for peer in updated_peers:
|
||||
await peer.connect()
|
||||
self.assertTrue(await peer.is_connected())
|
||||
|
||||
finally:
|
||||
await server3.stop()
|
||||
|
||||
finally:
|
||||
# Restore the original config file
|
||||
with open(root_path, "w") as f:
|
||||
json.dump(original_config, f, indent=2)
|
||||
|
||||
# Wait for the config to be reloaded again
|
||||
try:
|
||||
# Wait for the config to be reloaded
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(updated_peers), 1)
|
||||
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
|
||||
self.assertEqual(len(updated_peers), 2)
|
||||
|
||||
for peer in updated_peers:
|
||||
await peer.connect()
|
||||
self.assertTrue(await peer.is_connected())
|
||||
|
||||
finally:
|
||||
await server3.stop()
|
||||
|
||||
finally:
|
||||
# Restore the original config file
|
||||
with open(root_path, "w") as f:
|
||||
json.dump(original_config, f, indent=2)
|
||||
|
||||
# Wait for the config to be reloaded again
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
|
||||
self.assertEqual(len(updated_peers), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(unittest.main())
|
||||
asyncio.run(unittest.main())
|
||||
|
||||
Reference in New Issue
Block a user