mirror of https://github.com/exo-explore/exo.git (synced 2025-10-23 02:57:14 +03:00)
by default find an ephemeral node port (fixes #35); more robust topology updates (fixes #15 and #14)
.gitignore (vendored), 1 addition:

@@ -1,6 +1,7 @@
 __pycache__/
 .venv
 test_weights.npz
+.exo_used_ports

 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -187,9 +187,6 @@ class ChatGPTAPI:
             headers={
                 "Content-Type": "application/json",
                 "Cache-Control": "no-cache",
-                # "Access-Control-Allow-Origin": "*",
-                # "Access-Control-Allow-Methods": "*",
-                # "Access-Control-Allow-Headers": "*",
             }
         )
         await response.prepare(request)
@@ -1,6 +1,8 @@
 import os
 import asyncio
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
+import socket
+import random

 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -13,6 +15,36 @@ exo_text = """
 \___/_/\_\___/
 """

+def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
+    used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
+
+    def read_used_ports():
+        if os.path.exists(used_ports_file):
+            with open(used_ports_file, 'r') as f:
+                return [int(line.strip()) for line in f if line.strip().isdigit()]
+        return []
+
+    def write_used_port(port, used_ports):
+        with open(used_ports_file, 'w') as f:
+            print(used_ports[-19:])
+            for p in used_ports[-19:] + [port]:
+                f.write(f"{p}\n")
+
+    used_ports = read_used_ports()
+    available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
+
+    while available_ports:
+        port = random.choice(list(available_ports))
+        if DEBUG >= 2: print(f"Trying to find available port {port=}")
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind((host, port))
+            write_used_port(port, used_ports)
+            return port
+        except socket.error:
+            available_ports.remove(port)
+
+    raise RuntimeError("No available ports in the specified range")

 def print_exo():
     print(exo_text)
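Usage note: the helper binds a throwaway socket to verify a randomly chosen port is free, then records it in .exo_used_ports next to helpers.py so the next call avoids the 20 most recently handed-out ports. A minimal sanity check (a sketch; only find_available_port comes from this diff):

    from exo.helpers import find_available_port

    # Pick a free port in the IANA ephemeral range (49152-65535).
    port = find_available_port("0.0.0.0")
    assert 49152 <= port <= 65535
    print(f"node will listen on {port}")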
@@ -81,4 +113,4 @@ class AsyncCallbackSystem(Generic[K, T]):

     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
+            callback.set(*args)
@@ -30,8 +30,7 @@ class GRPCDiscovery(Discovery):
         self.listen_port = listen_port
         self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
         self.broadcast_interval = broadcast_interval
-        self.known_peers: Dict[str, GRPCPeerHandle] = {}
-        self.peer_last_seen: Dict[str, float] = {}
+        self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float]] = {}
         self.broadcast_task = None
         self.listen_task = None
         self.cleanup_task = None
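The two parallel dicts (handle by id, last-seen by id) are collapsed into a single map of (handle, timestamp) tuples, so a peer's handle and its freshness can no longer drift out of sync. The shape in isolation (a sketch; the string annotation stands in for the real GRPCPeerHandle class, and the caller is assumed to have inserted the entry first):

    import time
    from typing import Dict, Tuple

    known_peers: Dict[str, Tuple["GRPCPeerHandle", float]] = {}

    def touch(peer_id: str) -> None:
        # Refresh last-seen while keeping the existing handle,
        # mirroring what on_listen_message does below.
        handle, _ = known_peers[peer_id]
        known_peers[peer_id] = (handle, time.time())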
@@ -74,7 +73,7 @@ class GRPCDiscovery(Discovery):
                 if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
                 break  # No new peers found in the grace period, we are done

-        return list(self.known_peers.values())
+        return [peer_handle for peer_handle, _ in self.known_peers.values()]

     async def task_broadcast_presence(self):
         transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
@@ -110,9 +109,9 @@ class GRPCDiscovery(Discovery):
         peer_port = message['grpc_port']
         device_capabilities = DeviceCapabilities(**message['device_capabilities'])
         if peer_id not in self.known_peers:
-            self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time())
             if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-        self.peer_last_seen[peer_id] = time.time()
+        self.known_peers[peer_id] = (self.known_peers[peer_id][0], time.time())

     async def task_listen_for_peers(self):
         await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
@@ -120,11 +119,17 @@ class GRPCDiscovery(Discovery):

     async def task_cleanup_peers(self):
         while True:
-            current_time = time.time()
-            timeout = 15 * self.broadcast_interval
-            peers_to_remove = [peer_id for peer_id, last_seen in self.peer_last_seen.items() if current_time - last_seen > timeout]
-            for peer_id in peers_to_remove:
-                del self.known_peers[peer_id]
-                del self.peer_last_seen[peer_id]
-                if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
-            await asyncio.sleep(self.broadcast_interval)
+            try:
+                current_time = time.time()
+                timeout = 15 * self.broadcast_interval
+                peers_to_remove = [peer_handle.id() for peer_handle, last_seen in self.known_peers.values() if not await peer_handle.is_connected() or current_time - last_seen > timeout]
+                if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, last_seen={last_seen}" for peer_handle, last_seen in self.known_peers.values()})
+                if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
+                for peer_id in peers_to_remove:
+                    if peer_id in self.known_peers: del self.known_peers[peer_id]
+                    if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+                await asyncio.sleep(self.broadcast_interval)
+            except Exception as e:
+                print(f"Error in cleanup peers: {e}")
+                import traceback
+                print(traceback.format_exc())
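The removal condition now checks liveness as well as staleness, and the loop body is wrapped in try/except so one misbehaving peer cannot kill the cleanup task. The condition in isolation (a runnable sketch with a stub handle; only the is_connected/last-seen logic comes from the diff):

    import asyncio, time

    class StubPeer:
        def __init__(self, peer_id: str, connected: bool):
            self._id, self._connected = peer_id, connected
        def id(self) -> str: return self._id
        async def is_connected(self) -> bool: return self._connected

    async def main():
        timeout = 15 * 1.0  # 15 broadcast intervals
        now = time.time()
        known_peers = {
            "a": (StubPeer("a", True), now),       # fresh and connected: kept
            "b": (StubPeer("b", False), now),      # disconnected: removed
            "c": (StubPeer("c", True), now - 60),  # stale: removed
        }
        peers_to_remove = [h.id() for h, seen in known_peers.values()
                           if not await h.is_connected() or now - seen > timeout]
        print(peers_to_remove)  # ['b', 'c']

    asyncio.run(main())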
@@ -17,7 +17,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):

     async def start(self) -> None:
         self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-            ('grpc.max_metadata_size', 32*1024*1024)
+            ('grpc.max_metadata_size', 32*1024*1024),
+            ('grpc.max_send_message_length', 128 * 1024 * 1024),
+            ('grpc.max_receive_message_length', 128 * 1024 * 1024),
         ])
         node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
         listen_addr = f'{self.host}:{self.port}'
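Raising the send/receive caps matters because gRPC rejects messages over 4 MB by default, and tensors shipped between shards can be far larger. Any channel a peer opens toward this server needs matching limits or large messages will still be rejected; roughly (a sketch, assuming an insecure aio channel with a hypothetical address, neither of which is shown in this diff):

    import grpc

    channel = grpc.aio.insecure_channel(
        "peer-host:node-port",  # hypothetical peer address
        options=[
            ("grpc.max_send_message_length", 128 * 1024 * 1024),
            ("grpc.max_receive_message_length", 128 * 1024 * 1024),
        ],
    )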
@@ -84,7 +84,7 @@ class StandardNode(Node):
         if result.size == 1:  # we got a new token out
             self.buffered_token_output[request_id][0].append(result.item())
             self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
+        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")

         if not is_finished:
             asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
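Context for the fix: buffered_token_output[request_id] is a (tokens, is_finished) pair, as the return statement in the next hunk shows, so len() of the pair itself is always 2; indexing [0] counts the buffered tokens. Illustrated:

    buffered = {"req-1": ([101, 102, 103], False)}  # (tokens, is_finished)
    len(buffered["req-1"])     # 2 (the old, misleading count)
    len(buffered["req-1"][0])  # 3 (actual number of buffered tokens)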
@@ -179,7 +179,8 @@ class StandardNode(Node):
         return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]

     async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
-        self.topology.update_node(self.id, self.device_capabilities)
+        next_topology = Topology()
+        next_topology.update_node(self.id, self.device_capabilities)

         if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
@@ -187,8 +188,8 @@ class StandardNode(Node):
         visited.update(p.id() for p in self.peers)

         for peer in self.peers:
-            self.topology.update_node(peer.id(), peer.device_capabilities())
-            self.topology.add_edge(self.id, peer.id())
+            next_topology.update_node(peer.id(), peer.device_capabilities())
+            next_topology.add_edge(self.id, peer.id())

             if peer.id() in prev_visited:
                 if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
@@ -205,7 +206,8 @@ class StandardNode(Node):
             except Exception as e:
                 print(f"Error collecting topology from {peer.id()}: {e}")

-        return self.topology
+        self.topology = next_topology
+        return next_topology

     # TODO: unify this and collect_topology as global actions
     async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
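Building into a fresh Topology and swapping it in only at the end is what makes the updates "more robust": peers that have gone away no longer linger (the old code only ever added to self.topology), and readers of self.topology never see a half-built graph. The pattern reduced to its core (a runnable sketch; the stub Topology stands in for the real class):

    class Topology:
        def __init__(self): self.nodes = {}
        def update_node(self, node_id, caps): self.nodes[node_id] = caps

    class NodeSketch:
        def __init__(self, node_id, caps):
            self.id, self.device_capabilities = node_id, caps
            self.topology = Topology()

        def collect_topology_sketch(self, reachable_peer_caps):
            next_topology = Topology()               # start empty every pass
            next_topology.update_node(self.id, self.device_capabilities)
            for peer_id, caps in reachable_peer_caps.items():
                next_topology.update_node(peer_id, caps)
            self.topology = next_topology            # swap in at the end
            return next_topology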
main.py, 8 lines changed:
@@ -11,13 +11,13 @@ from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
-parser.add_argument("--node-port", type=int, default=8080, help="Node port")
+parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
@@ -49,6 +49,10 @@ else:
     raise ValueError(f"Inference engine {args.inference_engine} not supported")
 print(f"Using inference engine {inference_engine.__class__.__name__}")

+
+if args.node_port is None:
+    args.node_port = find_available_port(args.node_host)
+    if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
 server = GRPCServer(node, args.node_host, args.node_port)
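Net effect on startup: with --node-port defaulting to None, each node now picks its own free ephemeral gRPC port unless one is pinned on the command line (usage sketch; flags as defined above):

    # new default: the node finds an ephemeral port itself
    python main.py

    # pin the port explicitly, e.g. to match a firewall rule
    python main.py --node-port 8080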