Merge pull request #368 from ianpaul10/feat/manual-disc-0

Manual networking with configuration files
This commit is contained in:
Alex Cheema
2024-10-23 16:21:55 -07:00
committed by GitHub
16 changed files with 351 additions and 11 deletions

1
.gitignore vendored
View File

@@ -170,3 +170,4 @@ cython_debug/
#.idea/
**/*.xcodeproj/*
.aider*

View File

@@ -6,7 +6,8 @@ import logging
import time
import traceback
import uuid
import sys
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.manual.network_topology_config import NetworkTopology
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -36,8 +37,9 @@ parser.add_argument("--download-quick-check", action="store_true", help="Quick c
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
@@ -80,6 +82,10 @@ if args.discovery_module == "udp":
discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
elif args.discovery_module == "tailscale":
discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
elif args.discovery_module == "manual":
if not args.discovery_config_path:
raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
node = StandardNode(
args.node_id,

View File

@@ -56,7 +56,7 @@ class GRPCPeerHandle(PeerHandle):
return response.is_healthy
except asyncio.TimeoutError:
return False
except:
except Exception:
if DEBUG >= 4:
print(f"Health check failed for {self._id}@{self.address}.")
import traceback

View File

View File

@@ -0,0 +1,81 @@
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
from exo.networking.peer_handle import PeerHandle
class ManualDiscovery(Discovery):
    """Discovery backend driven by a static network topology config file.

    The topology file is loaded once at construction time. A background task
    (created by ``start``) then repeatedly health-checks every configured peer
    other than this node, keeping ``known_peers`` in sync with the peers that
    currently report healthy.
    """

    def __init__(
        self,
        network_config_path: str,
        node_id: str,
        create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
        discovery_timeout: int = 30,
    ):
        """Load the topology and locate this node's own entry in it.

        Args:
            network_config_path: Path to the topology JSON file.
            node_id: Key of this node in the config's ``peers`` mapping.
            create_peer_handle: Factory called as
                ``create_peer_handle(peer_id, "address:port", device_capabilities)``.
            discovery_timeout: Seconds to sleep after probing each peer.

        Raises:
            KeyError: If ``node_id`` is not a key of the config's ``peers``.
        """
        self.topology = NetworkTopology.from_path(network_config_path)
        self.node_id = node_id
        self.create_peer_handle = create_peer_handle
        self.discovery_timeout = discovery_timeout

        try:
            self.node = self.topology.peers[node_id]
        except KeyError:
            # Iterating a dict yields its keys directly; list() gives the valid node ids.
            print(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {list(self.topology.peers)}")
            raise

        self.node_port = self.node.port

        self.listen_task = None
        self.cleanup_task = None

        self.known_peers: Dict[str, PeerHandle] = {}
        # NOTE: this aliases self.topology.peers, so popping our own entry below
        # also removes it from the loaded topology object.
        self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
        self.node_config = self.peers_in_network.pop(node_id)

    async def start(self) -> None:
        """Schedule the background task that polls the configured peers."""
        self.listen_task = asyncio.create_task(self.task_find_peers_from_config())

    async def stop(self) -> None:
        """Cancel the background polling task if it was started."""
        if self.listen_task:
            self.listen_task.cancel()

    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
        """Return handles for the peers currently considered healthy.

        When ``wait_for_peers`` is positive, polls every 0.1 seconds until at
        least that many peers are known before returning.
        """
        if DEBUG_DISCOVERY >= 2: print("Starting discovery...")
        if wait_for_peers > 0:
            while len(self.known_peers) < wait_for_peers:
                if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
                await asyncio.sleep(0.1)
        return list(self.known_peers.values())

    async def task_find_peers_from_config(self):
        """Loop forever, health-checking each configured peer in turn.

        A healthy peer is stored in ``known_peers``; an unhealthy one is
        removed. After each probe the task sleeps ``discovery_timeout`` seconds.
        """
        if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
        while True:
            if not self.peers_in_network:
                # With no peers configured the for-loop below never awaits;
                # sleep here so this task still yields to the event loop.
                await asyncio.sleep(self.discovery_timeout)

            for peer_id, peer_config in self.peers_in_network.items():
                try:
                    if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
                    peer = self.known_peers.get(peer_id)
                    if not peer:
                        if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
                        peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
                    is_healthy = await peer.health_check()
                    if is_healthy:
                        if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
                        self.known_peers[peer_id] = peer
                    else:
                        if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
                        # The peer may never have been added; a default makes pop a no-op then.
                        self.known_peers.pop(peer_id, None)
                except Exception as e:
                    if DEBUG_DISCOVERY >=2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
                # Sleep outside of a `finally` so a cancellation raised during the
                # probe propagates immediately instead of waiting out the sleep first.
                await asyncio.sleep(self.discovery_timeout)

            if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

View File

@@ -0,0 +1,32 @@
from typing import Dict
from pydantic import BaseModel, ValidationError
from exo.topology.device_capabilities import DeviceCapabilities
class PeerConfig(BaseModel):
    """One node's entry in the manual network topology config."""

    # Host part of the peer's endpoint; ManualDiscovery joins it with `port`
    # as "address:port" when creating a peer handle.
    address: str
    # Port part of the peer's endpoint.
    port: int
    # Capabilities declared for this peer; passed through to the peer handle factory.
    device_capabilities: DeviceCapabilities
class NetworkTopology(BaseModel):
    """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""

    peers: Dict[str, PeerConfig]
    """
    node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
    """

    @classmethod
    def from_path(cls, path: str) -> "NetworkTopology":
        """Read and validate a topology config from a JSON file.

        Args:
            path: Location of the JSON config file.

        Returns:
            The validated ``NetworkTopology``.

        Raises:
            FileNotFoundError: If no file exists at ``path``.
            ValueError: If the file content fails validation against this model.
        """
        try:
            # Decode explicitly as UTF-8 so parsing does not depend on the
            # platform's locale default encoding.
            with open(path, "r", encoding="utf-8") as f:
                config_data = f.read()
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Config file not found at {path}") from e

        try:
            return cls.model_validate_json(config_data)
        except ValidationError as e:
            raise ValueError(f"Error validating network topology config from {path}: {e}") from e

View File

@@ -0,0 +1,17 @@
{
"peers": {
"node1": {
"address": "localhost",
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {
"fp32": 0,
"fp16": 0,
"int8": 0
}
}
}
}
}

View File

@@ -0,0 +1,32 @@
{
"peers": {
"node1": {
"address": "localhost",
"port": 50051,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {
"fp32": 0,
"fp16": 0,
"int8": 0
}
}
},
"node2": {
"address": "localhost",
"port": 50052,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {
"fp32": 0,
"fp16": 0,
"int8": 0
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
{
"peers": {
"node1": {
"address": "localhost",
"port": 50051,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {
"fp32": 0,
"fp16": 0,
"int8": 0
}
}
}
}
}

View File

@@ -0,0 +1,103 @@
import asyncio
import unittest
from unittest import mock
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.manual.network_topology_config import NetworkTopology
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.networking.grpc.grpc_server import GRPCServer
from exo.orchestration.node import Node
root_path = "./exo/networking/manual/test_data/test_config.json"
class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
    """discover_peers(wait_for_peers=0) returns no peers and never connects."""

    async def asyncSetUp(self):
        self.peer1 = mock.AsyncMock()
        self.peer1.connect = mock.AsyncMock()

        def make_peer(peer_id, address, device_capabilities):
            # Every handle request resolves to the same mocked peer.
            return self.peer1

        self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=make_peer)
        # NOTE(review): start() is a coroutine function and is not awaited here,
        # so the background polling task is never created in this test — confirm
        # this is intentional before changing it.
        _ = self.discovery1.start()

    async def asyncTearDown(self):
        await self.discovery1.stop()

    async def test_discovery(self):
        found = await self.discovery1.discover_peers(wait_for_peers=0)
        assert len(found) == 0

        self.peer1.connect.assert_not_called()
class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
    """Two discovery instances sharing one config each find one peer, using mocked handles."""

    async def asyncSetUp(self):
        self.peer1 = mock.AsyncMock()
        self.peer2 = mock.AsyncMock()
        self.peer1.connect = mock.AsyncMock()
        self.peer2.connect = mock.AsyncMock()

        def make_peer1(peer_id, address, device_capabilities):
            return self.peer1

        def make_peer2(peer_id, address, device_capabilities):
            return self.peer2

        self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=make_peer1)
        self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=make_peer2)
        await self.discovery1.start()
        await self.discovery2.start()

    async def asyncTearDown(self):
        await self.discovery1.stop()
        await self.discovery2.stop()

    async def test_discovery(self):
        found_by_1 = await self.discovery1.discover_peers(wait_for_peers=1)
        assert len(found_by_1) == 1
        found_by_2 = await self.discovery2.discover_peers(wait_for_peers=1)
        assert len(found_by_2) == 1

        # connect has to be explicitly called after discovery
        self.peer1.connect.assert_not_called()
        self.peer2.connect.assert_not_called()
class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
    """Manual discovery against two real GRPC servers using GRPCPeerHandle."""

    async def asyncSetUp(self):
        topology = NetworkTopology.from_path(root_path)
        node1_cfg = topology.peers["node1"]
        node2_cfg = topology.peers["node2"]

        self.node1 = mock.AsyncMock(spec=Node)
        self.node2 = mock.AsyncMock(spec=Node)
        self.server1 = GRPCServer(self.node1, node1_cfg.address, node1_cfg.port)
        self.server2 = GRPCServer(self.node2, node2_cfg.address, node2_cfg.port)
        await self.server1.start()
        await self.server2.start()

        def make_grpc_handle(peer_id, address, device_capabilities):
            return GRPCPeerHandle(peer_id, address, device_capabilities)

        self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=make_grpc_handle)
        self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=make_grpc_handle)
        await self.discovery1.start()
        await self.discovery2.start()

    async def asyncTearDown(self):
        await self.discovery1.stop()
        await self.discovery2.stop()
        await self.server1.stop()
        await self.server2.stop()

    async def test_grpc_discovery(self):
        peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
        assert len(peers1) == 1
        peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
        assert len(peers2) == 1

        # The handle discovery1 holds and the handle discovery2 holds.
        handle_from_1 = peers1[0]
        handle_from_2 = peers2[0]

        # Connect
        await handle_from_1.connect()
        await handle_from_2.connect()
        self.assertTrue(await handle_from_1.is_connected())
        self.assertTrue(await handle_from_2.is_connected())

        # Kill server1
        await self.server1.stop()

        self.assertTrue(await handle_from_1.is_connected())
        self.assertFalse(await handle_from_2.is_connected())

        # Kill server2
        await self.server2.stop()

        self.assertFalse(await handle_from_1.is_connected())
        self.assertFalse(await handle_from_2.is_connected())
if __name__ == "__main__":
    # unittest.main() is a synchronous entry point, not a coroutine, so it must
    # not be wrapped in asyncio.run(); IsolatedAsyncioTestCase manages its own
    # event loop for the async tests.
    unittest.main()

View File

@@ -0,0 +1,49 @@
import unittest
from exo.networking.manual.network_topology_config import NetworkTopology
root_path = "./exo/networking/manual/test_data/"
class TestNetworkTopologyConfig(unittest.TestCase):
    """Covers NetworkTopology.from_path for missing, malformed, incomplete and valid files."""

    def test_from_path_invalid_path(self):
        with self.assertRaises(FileNotFoundError) as ctx:
            NetworkTopology.from_path("invalid_path")
        self.assertEqual(str(ctx.exception), "Config file not found at invalid_path")

    def test_from_path_invalid_json(self):
        with self.assertRaises(ValueError) as ctx:
            NetworkTopology.from_path(root_path + "invalid_json.json")
        message = str(ctx.exception)
        self.assertIn("Error validating network topology config from", message)
        self.assertIn("1 validation error for NetworkTopology\n  Invalid JSON: EOF while parsing a value at line 1 column 0", message)

    def test_from_path_invalid_config(self):
        with self.assertRaises(ValueError) as ctx:
            NetworkTopology.from_path(root_path + "invalid_config.json")
        message = str(ctx.exception)
        self.assertIn("Error validating network topology config from", message)
        self.assertIn("port\n  Field required", message)

    def test_from_path_valid(self):
        config = NetworkTopology.from_path(root_path + "test_config.json")

        # Both nodes share every value except the port.
        for node_id, expected_port in (("node1", 50051), ("node2", 50052)):
            peer = config.peers[node_id]
            self.assertEqual(peer.port, expected_port)
            self.assertEqual(peer.device_capabilities.model, "Unknown Model")
            self.assertEqual(peer.address, "localhost")
            self.assertEqual(peer.device_capabilities.chip, "Unknown Chip")
            self.assertEqual(peer.device_capabilities.memory, 0)
            self.assertEqual(peer.device_capabilities.flops.fp32, 0)
            self.assertEqual(peer.device_capabilities.flops.fp16, 0)
            self.assertEqual(peer.device_capabilities.flops.int8, 0)
if __name__ == "__main__":
unittest.main()

View File

@@ -6,6 +6,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.networking.grpc.grpc_server import GRPCServer
from exo.orchestration.node import Node
class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()

View File

@@ -205,4 +205,4 @@ class UDPDiscovery(Discovery):
(current_time - last_seen > self.discovery_timeout) or
(not health_ok)
)
return should_remove
return should_remove

View File

@@ -1,13 +1,13 @@
from typing import Any
from pydantic import BaseModel
from exo import DEBUG
from dataclasses import dataclass, asdict
import subprocess
import psutil
TFLOPS = 1.00
@dataclass
class DeviceFlops:
class DeviceFlops(BaseModel):
# units of TFLOPS
fp32: float
fp16: float
@@ -17,11 +17,10 @@ class DeviceFlops:
return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
def to_dict(self):
return asdict(self)
return self.model_dump()
@dataclass
class DeviceCapabilities:
class DeviceCapabilities(BaseModel):
model: str
chip: str
memory: int
@@ -30,7 +29,7 @@ class DeviceCapabilities:
def __str__(self):
return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
def __post_init__(self):
def model_post_init(self, __context: Any) -> None:
if isinstance(self.flops, dict):
self.flops = DeviceFlops(**self.flops)

View File

@@ -18,6 +18,7 @@ install_requires = [
"prometheus-client==0.20.0",
"protobuf==5.27.1",
"psutil==6.0.0",
"pydantic==2.9.2",
"requests==2.32.3",
"rich==13.7.1",
"safetensors==0.4.3",