by default find an ephemeral node port fixes #35, more robust topology updates. both fix #15 and #14

This commit is contained in:
Alex Cheema
2024-07-18 19:59:57 -07:00
parent 54c98607ef
commit 35177690bd
7 changed files with 68 additions and 25 deletions

1
.gitignore vendored
View File

@@ -1,6 +1,7 @@
__pycache__/
.venv
test_weights.npz
.exo_used_ports
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@@ -187,9 +187,6 @@ class ChatGPTAPI:
headers={
"Content-Type": "application/json",
"Cache-Control": "no-cache",
# "Access-Control-Allow-Origin": "*",
# "Access-Control-Allow-Methods": "*",
# "Access-Control-Allow-Headers": "*",
}
)
await response.prepare(request)

View File

@@ -1,6 +1,8 @@
import os
import asyncio
from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
import socket
import random
DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -13,6 +15,36 @@ exo_text = """
\___/_/\_\___/
"""
def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
def read_used_ports():
if os.path.exists(used_ports_file):
with open(used_ports_file, 'r') as f:
return [int(line.strip()) for line in f if line.strip().isdigit()]
return []
def write_used_port(port, used_ports):
with open(used_ports_file, 'w') as f:
print(used_ports[-19:])
for p in used_ports[-19:] + [port]:
f.write(f"{p}\n")
used_ports = read_used_ports()
available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
while available_ports:
port = random.choice(list(available_ports))
if DEBUG >= 2: print(f"Trying to find available port {port=}")
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((host, port))
write_used_port(port, used_ports)
return port
except socket.error:
available_ports.remove(port)
raise RuntimeError("No available ports in the specified range")
def print_exo():
print(exo_text)
@@ -81,4 +113,4 @@ class AsyncCallbackSystem(Generic[K, T]):
def trigger_all(self, *args: T) -> None:
for callback in self.callbacks.values():
callback.set(*args)
callback.set(*args)

View File

@@ -30,8 +30,7 @@ class GRPCDiscovery(Discovery):
self.listen_port = listen_port
self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
self.broadcast_interval = broadcast_interval
self.known_peers: Dict[str, GRPCPeerHandle] = {}
self.peer_last_seen: Dict[str, float] = {}
self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
@@ -74,7 +73,7 @@ class GRPCDiscovery(Discovery):
if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
break # No new peers found in the grace period, we are done
return list(self.known_peers.values())
return [peer_handle for peer_handle, _ in self.known_peers.values()]
async def task_broadcast_presence(self):
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -110,9 +109,9 @@ class GRPCDiscovery(Discovery):
peer_port = message['grpc_port']
device_capabilities = DeviceCapabilities(**message['device_capabilities'])
if peer_id not in self.known_peers:
self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time())
if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
self.peer_last_seen[peer_id] = time.time()
self.known_peers[peer_id] = (self.known_peers[peer_id][0], time.time())
async def task_listen_for_peers(self):
await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
@@ -120,11 +119,17 @@ class GRPCDiscovery(Discovery):
async def task_cleanup_peers(self):
while True:
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
for peer_id in peers_to_remove:
del self.known_peers[peer_id]
del self.peer_last_seen[peer_id]
if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
await asyncio.sleep(self.broadcast_interval)
try:
current_time = time.time()
timeout = 15 * self.broadcast_interval
peers_to_remove = [peer_handle.id() for peer_handle, last_seen in self.known_peers.values() if not await peer_handle.is_connected() or current_time - last_seen > timeout]
if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, last_seen={last_seen}" for peer_handle, last_seen in self.known_peers.values()})
if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
for peer_id in peers_to_remove:
if peer_id in self.known_peers: del self.known_peers[peer_id]
if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
await asyncio.sleep(self.broadcast_interval)
except Exception as e:
print(f"Error in cleanup peers: {e}")
import traceback
print(traceback.format_exc())

View File

@@ -17,7 +17,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def start(self) -> None:
self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
('grpc.max_metadata_size', 32*1024*1024)
('grpc.max_metadata_size', 32*1024*1024),
('grpc.max_send_message_length', 128 * 1024 * 1024),
('grpc.max_receive_message_length', 128 * 1024 * 1024),
])
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
listen_addr = f'{self.host}:{self.port}'

View File

@@ -84,7 +84,7 @@ class StandardNode(Node):
if result.size == 1: # we got a new token out
self.buffered_token_output[request_id][0].append(result.item())
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
if not is_finished:
asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
@@ -179,7 +179,8 @@ class StandardNode(Node):
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
self.topology.update_node(self.id, self.device_capabilities)
next_topology = Topology()
next_topology.update_node(self.id, self.device_capabilities)
if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
@@ -187,8 +188,8 @@ class StandardNode(Node):
visited.update(p.id() for p in self.peers)
for peer in self.peers:
self.topology.update_node(peer.id(), peer.device_capabilities())
self.topology.add_edge(self.id, peer.id())
next_topology.update_node(peer.id(), peer.device_capabilities())
next_topology.add_edge(self.id, peer.id())
if peer.id() in prev_visited:
if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
@@ -205,7 +206,8 @@ class StandardNode(Node):
except Exception as e:
print(f"Error collecting topology from {peer.id()}: {e}")
return self.topology
self.topology = next_topology
return next_topology
# TODO: unify this and collect_topology as global actions
async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:

View File

@@ -11,13 +11,13 @@ from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.helpers import print_yellow_exo
from exo.helpers import print_yellow_exo, find_available_port, DEBUG
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
@@ -49,6 +49,10 @@ else:
raise ValueError(f"Inference engine {args.inference_engine} not supported")
print(f"Using inference engine {inference_engine.__class__.__name__}")
if args.node_port is None:
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
server = GRPCServer(node, args.node_host, args.node_port)