mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
Merge pull request #368 from ianpaul10/feat/manual-disc-0
Manual networking with configuration files
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -170,3 +170,4 @@ cython_debug/
|
||||
#.idea/
|
||||
|
||||
**/*.xcodeproj/*
|
||||
.aider*
|
||||
|
||||
10
exo/main.py
10
exo/main.py
@@ -6,7 +6,8 @@ import logging
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import sys
|
||||
from exo.networking.manual.manual_discovery import ManualDiscovery
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology
|
||||
from exo.orchestration.standard_node import StandardNode
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.networking.udp.udp_discovery import UDPDiscovery
|
||||
@@ -36,8 +37,9 @@ parser.add_argument("--download-quick-check", action="store_true", help="Quick c
|
||||
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
|
||||
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
|
||||
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
|
||||
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
|
||||
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
|
||||
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
|
||||
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
|
||||
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
|
||||
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
|
||||
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
|
||||
@@ -80,6 +82,10 @@ if args.discovery_module == "udp":
|
||||
discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
|
||||
elif args.discovery_module == "tailscale":
|
||||
discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
|
||||
elif args.discovery_module == "manual":
|
||||
if not args.discovery_config_path:
|
||||
raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
|
||||
discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
|
||||
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
|
||||
node = StandardNode(
|
||||
args.node_id,
|
||||
|
||||
@@ -56,7 +56,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return response.is_healthy
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
except:
|
||||
except Exception:
|
||||
if DEBUG >= 4:
|
||||
print(f"Health check failed for {self._id}@{self.address}.")
|
||||
import traceback
|
||||
|
||||
0
exo/networking/manual/__init__.py
Normal file
0
exo/networking/manual/__init__.py
Normal file
81
exo/networking/manual/manual_discovery.py
Normal file
81
exo/networking/manual/manual_discovery.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import asyncio
|
||||
from exo.networking.discovery import Discovery
|
||||
from typing import Dict, List, Callable
|
||||
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
|
||||
from exo.helpers import DEBUG_DISCOVERY
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
|
||||
|
||||
class ManualDiscovery(Discovery):
    """Discovery backend that reads its peers from a static topology config file.

    The config file lists every node in the network, including the local one.
    The local node's entry is removed from the peer list at construction time;
    every remaining entry is health-checked periodically by a background task
    and kept in ``known_peers`` only while its health check succeeds.
    """

    def __init__(
        self,
        network_config_path: str,
        node_id: str,
        create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
        discovery_timeout: int = 30,
    ):
        """Load the topology and split it into "this node" and "other peers".

        Args:
            network_config_path: Path to the JSON topology config file.
            node_id: Key of the local node inside the config's ``peers`` map.
            create_peer_handle: Factory called with ``(peer_id, "address:port",
                device_capabilities)`` to build a handle for a configured peer.
            discovery_timeout: Seconds to wait between two sweeps over the
                configured peers.

        Raises:
            KeyError: If ``node_id`` is not a key of the config's ``peers`` map.
        """
        self.topology = NetworkTopology.from_path(network_config_path)
        self.node_id = node_id
        self.create_peer_handle = create_peer_handle
        self.discovery_timeout = discovery_timeout

        try:
            self.node = self.topology.peers[node_id]
        except KeyError:
            # Iterating a dict yields plain keys, so list() them directly; trying to
            # unpack each key as a (key, value) pair would raise inside this handler
            # and hide the real "unknown node id" error.
            print(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {list(self.topology.peers)}")
            raise

        self.node_port = self.node.port

        self.listen_task = None
        self.cleanup_task = None

        self.known_peers: Dict[str, PeerHandle] = {}
        # NOTE: this aliases self.topology.peers, so the pop() below also removes
        # the local node from the loaded topology object.
        self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
        self.node_config = self.peers_in_network.pop(node_id)

    async def start(self) -> None:
        """Schedule the background task that health-checks the configured peers."""
        self.listen_task = asyncio.create_task(self.task_find_peers_from_config())

    async def stop(self) -> None:
        """Cancel the background task if it was started."""
        if self.listen_task:
            self.listen_task.cancel()

    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
        """Return handles for the peers currently considered healthy.

        Args:
            wait_for_peers: When greater than zero, poll every 0.1 seconds until at
                least this many peers are known before returning.
        """
        if DEBUG_DISCOVERY >= 2: print("Starting discovery...")
        if wait_for_peers > 0:
            while len(self.known_peers) < wait_for_peers:
                if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
                await asyncio.sleep(0.1)
        return list(self.known_peers.values())

    async def task_find_peers_from_config(self):
        """Sweep the configured peers forever, tracking the healthy ones.

        Each sweep health-checks every configured peer, adds healthy peers to
        ``known_peers`` and drops unhealthy ones, then sleeps ``discovery_timeout``
        seconds. The sleep happens once per sweep and outside any ``finally``
        block so that the loop always yields to the event loop (even when no
        other peers are configured) and so that cancellation is not delayed by
        a pending sleep.
        """
        if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
        while True:
            for peer_id, peer_config in self.peers_in_network.items():
                try:
                    if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
                    peer = self.known_peers.get(peer_id)
                    if not peer:
                        if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
                        peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
                    is_healthy = await peer.health_check()
                    if is_healthy:
                        if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
                        self.known_peers[peer_id] = peer
                    else:
                        if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
                        # The peer may never have been added; nothing to remove then.
                        self.known_peers.pop(peer_id, None)
                except Exception as e:
                    # Best effort: one failing peer must not stop the sweep.
                    if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")

            if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
            await asyncio.sleep(self.discovery_timeout)
|
||||
|
||||
32
exo/networking/manual/network_topology_config.py
Normal file
32
exo/networking/manual/network_topology_config.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
|
||||
|
||||
class PeerConfig(BaseModel):
    """One entry of the topology config: where a node listens and what hardware it declares."""

    # Host part of the node's endpoint.
    address: str
    # Port part of the node's endpoint.
    port: int
    # Capabilities declared for this node in the config file.
    device_capabilities: DeviceCapabilities
|
||||
|
||||
|
||||
class NetworkTopology(BaseModel):
    """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""

    # node_id to PeerConfig. The node_id is used to identify the peer in the
    # discovery process. The node that this is running from should be included
    # in this dict.
    peers: Dict[str, PeerConfig]

    @classmethod
    def from_path(cls, path: str) -> "NetworkTopology":
        """Load and validate a topology config from a JSON file.

        Args:
            path: Location of the JSON config file.

        Returns:
            The validated topology.

        Raises:
            FileNotFoundError: If no file exists at ``path``.
            ValueError: If the file content is not valid JSON or does not match
                the expected schema.
        """
        try:
            # JSON is UTF-8 by specification; do not depend on the platform's
            # locale default encoding.
            with open(path, "r", encoding="utf-8") as f:
                config_data = f.read()
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Config file not found at {path}") from e

        try:
            return cls.model_validate_json(config_data)
        except ValidationError as e:
            raise ValueError(f"Error validating network topology config from {path}: {e}") from e
|
||||
17
exo/networking/manual/test_data/invalid_config.json
Normal file
17
exo/networking/manual/test_data/invalid_config.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"peers": {
|
||||
"node1": {
|
||||
"address": "localhost",
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {
|
||||
"fp32": 0,
|
||||
"fp16": 0,
|
||||
"int8": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
0
exo/networking/manual/test_data/invalid_json.json
Normal file
0
exo/networking/manual/test_data/invalid_json.json
Normal file
32
exo/networking/manual/test_data/test_config.json
Normal file
32
exo/networking/manual/test_data/test_config.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"peers": {
|
||||
"node1": {
|
||||
"address": "localhost",
|
||||
"port": 50051,
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {
|
||||
"fp32": 0,
|
||||
"fp16": 0,
|
||||
"int8": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
"node2": {
|
||||
"address": "localhost",
|
||||
"port": 50052,
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {
|
||||
"fp32": 0,
|
||||
"fp16": 0,
|
||||
"int8": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
18
exo/networking/manual/test_data/test_config_single_node.json
Normal file
18
exo/networking/manual/test_data/test_config_single_node.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"peers": {
|
||||
"node1": {
|
||||
"address": "localhost",
|
||||
"port": 50051,
|
||||
"device_capabilities": {
|
||||
"model": "Unknown Model",
|
||||
"chip": "Unknown Chip",
|
||||
"memory": 0,
|
||||
"flops": {
|
||||
"fp32": 0,
|
||||
"fp16": 0,
|
||||
"int8": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
103
exo/networking/manual/test_manual_discovery.py
Normal file
103
exo/networking/manual/test_manual_discovery.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from exo.networking.manual.manual_discovery import ManualDiscovery
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology
|
||||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.orchestration.node import Node
|
||||
|
||||
root_path = "./exo/networking/manual/test_data/test_config.json"
|
||||
|
||||
|
||||
class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
    """Discovery as seen from ``node1`` when no peers have been found yet."""

    async def asyncSetUp(self):
        handle = mock.AsyncMock()
        handle.connect = mock.AsyncMock()
        self.peer1 = handle
        self.discovery1 = ManualDiscovery(
            root_path,
            "node1",
            create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1,
        )
        # NOTE(review): the coroutine returned by start() is not awaited here, so
        # the background discovery task is never scheduled -- confirm this is
        # intentional.
        _ = self.discovery1.start()

    async def asyncTearDown(self):
        await self.discovery1.stop()

    async def test_discovery(self):
        found = await self.discovery1.discover_peers(wait_for_peers=0)
        assert len(found) == 0

        self.peer1.connect.assert_not_called()
|
||||
|
||||
|
||||
class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
    """Two discovery instances sharing one config should each find the other node."""

    async def asyncSetUp(self):
        first_handle = mock.AsyncMock()
        second_handle = mock.AsyncMock()
        first_handle.connect = mock.AsyncMock()
        second_handle.connect = mock.AsyncMock()
        self.peer1 = first_handle
        self.peer2 = second_handle

        self.discovery1 = ManualDiscovery(
            root_path,
            "node1",
            create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1,
        )
        self.discovery2 = ManualDiscovery(
            root_path,
            "node2",
            create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2,
        )
        await self.discovery1.start()
        await self.discovery2.start()

    async def asyncTearDown(self):
        await self.discovery1.stop()
        await self.discovery2.stop()

    async def test_discovery(self):
        found_by_node1 = await self.discovery1.discover_peers(wait_for_peers=1)
        assert len(found_by_node1) == 1
        found_by_node2 = await self.discovery2.discover_peers(wait_for_peers=1)
        assert len(found_by_node2) == 1

        # Discovery alone must not open connections; connect() is the caller's job.
        self.peer1.connect.assert_not_called()
        self.peer2.connect.assert_not_called()
|
||||
|
||||
|
||||
class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
    """End-to-end check of manual discovery against real gRPC servers and handles."""

    async def asyncSetUp(self):
        config = NetworkTopology.from_path(root_path)
        node1_cfg = config.peers["node1"]
        node2_cfg = config.peers["node2"]

        self.node1 = mock.AsyncMock(spec=Node)
        self.node2 = mock.AsyncMock(spec=Node)
        self.server1 = GRPCServer(self.node1, node1_cfg.address, node1_cfg.port)
        self.server2 = GRPCServer(self.node2, node2_cfg.address, node2_cfg.port)
        await self.server1.start()
        await self.server2.start()

        def grpc_handle(peer_id, address, device_capabilities):
            return GRPCPeerHandle(peer_id, address, device_capabilities)

        self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=grpc_handle)
        self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=grpc_handle)
        await self.discovery1.start()
        await self.discovery2.start()

    async def asyncTearDown(self):
        await self.discovery1.stop()
        await self.discovery2.stop()
        await self.server1.stop()
        await self.server2.stop()

    async def test_grpc_discovery(self):
        found_by_node1 = await self.discovery1.discover_peers(wait_for_peers=1)
        assert len(found_by_node1) == 1
        found_by_node2 = await self.discovery2.discover_peers(wait_for_peers=1)
        assert len(found_by_node2) == 1

        # Each discovery instance excludes its own node, so its single peer is
        # the handle pointing at the other node's server.
        handle_from_node1 = found_by_node1[0]
        handle_from_node2 = found_by_node2[0]

        # Connect
        await handle_from_node1.connect()
        await handle_from_node2.connect()
        self.assertTrue(await handle_from_node1.is_connected())
        self.assertTrue(await handle_from_node2.is_connected())

        # Kill server1
        await self.server1.stop()

        self.assertTrue(await handle_from_node1.is_connected())
        self.assertFalse(await handle_from_node2.is_connected())

        # Kill server2
        await self.server2.stop()

        self.assertFalse(await handle_from_node1.is_connected())
        self.assertFalse(await handle_from_node2.is_connected())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # unittest.main() is a plain function that runs the tests and exits the
    # process; it does not return a coroutine, so it must not be wrapped in
    # asyncio.run(). IsolatedAsyncioTestCase manages its own event loop.
    unittest.main()
|
||||
49
exo/networking/manual/test_network_topology_config.py
Normal file
49
exo/networking/manual/test_network_topology_config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import unittest
|
||||
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology
|
||||
|
||||
root_path = "./exo/networking/manual/test_data/"
|
||||
|
||||
|
||||
class TestNetworkTopologyConfig(unittest.TestCase):
    """Loading behaviour of NetworkTopology.from_path for missing, malformed and valid files."""

    def test_from_path_invalid_path(self):
        with self.assertRaises(FileNotFoundError) as ctx:
            NetworkTopology.from_path("invalid_path")
        self.assertEqual(str(ctx.exception), "Config file not found at invalid_path")

    def test_from_path_invalid_json(self):
        with self.assertRaises(ValueError) as ctx:
            NetworkTopology.from_path(root_path + "invalid_json.json")
        message = str(ctx.exception)
        self.assertIn("Error validating network topology config from", message)
        self.assertIn("1 validation error for NetworkTopology\n Invalid JSON: EOF while parsing a value at line 1 column 0", message)

    def test_from_path_invalid_config(self):
        with self.assertRaises(ValueError) as ctx:
            NetworkTopology.from_path(root_path + "invalid_config.json")
        message = str(ctx.exception)
        self.assertIn("Error validating network topology config from", message)
        self.assertIn("port\n Field required", message)

    def test_from_path_valid(self):
        config = NetworkTopology.from_path(root_path + "test_config.json")

        # Both fixture nodes are identical except for their port.
        for node_id, expected_port in (("node1", 50051), ("node2", 50052)):
            peer = config.peers[node_id]
            self.assertEqual(peer.port, expected_port)
            self.assertEqual(peer.device_capabilities.model, "Unknown Model")
            self.assertEqual(peer.address, "localhost")
            self.assertEqual(peer.device_capabilities.chip, "Unknown Chip")
            self.assertEqual(peer.device_capabilities.memory, 0)
            self.assertEqual(peer.device_capabilities.flops.fp32, 0)
            self.assertEqual(peer.device_capabilities.flops.fp16, 0)
            self.assertEqual(peer.device_capabilities.flops.int8, 0)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -6,6 +6,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.orchestration.node import Node
|
||||
|
||||
|
||||
class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock()
|
||||
|
||||
@@ -205,4 +205,4 @@ class UDPDiscovery(Discovery):
|
||||
(current_time - last_seen > self.discovery_timeout) or
|
||||
(not health_ok)
|
||||
)
|
||||
return should_remove
|
||||
return should_remove
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
from exo import DEBUG
|
||||
from dataclasses import dataclass, asdict
|
||||
import subprocess
|
||||
import psutil
|
||||
|
||||
TFLOPS = 1.00
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceFlops:
|
||||
class DeviceFlops(BaseModel):
|
||||
# units of TFLOPS
|
||||
fp32: float
|
||||
fp16: float
|
||||
@@ -17,11 +17,10 @@ class DeviceFlops:
|
||||
return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceCapabilities:
|
||||
class DeviceCapabilities(BaseModel):
|
||||
model: str
|
||||
chip: str
|
||||
memory: int
|
||||
@@ -30,7 +29,7 @@ class DeviceCapabilities:
|
||||
def __str__(self):
|
||||
return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
|
||||
|
||||
def __post_init__(self):
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.flops, dict):
|
||||
self.flops = DeviceFlops(**self.flops)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user