mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
switch to uvloop (faster asyncio event loop) and optimise grpc settings
This commit is contained in:
@@ -5,7 +5,7 @@ from mlx_lm.sample_utils import top_p_sampling, make_sampler
|
||||
import mlx.optimizers as optim
|
||||
from ..inference_engine import InferenceEngine
|
||||
from .sharded_utils import load_shard, get_image_from_str
|
||||
from .losses import loss_fns
|
||||
from .losses import loss_fns
|
||||
from ..shard import Shard
|
||||
from typing import Dict, Optional, Tuple
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
@@ -56,7 +56,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
async def load_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
self.model.load_weights(path)
|
||||
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
state = await self.poll_state(request_id)
|
||||
@@ -102,7 +102,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
|
||||
score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
|
||||
#print(f"{score=}")
|
||||
|
||||
|
||||
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
||||
#print(layers[0])
|
||||
|
||||
@@ -117,7 +117,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
if self.shard != shard:
|
||||
model_shard, self.tokenizer = await load_shard(model_path, shard)
|
||||
self.shard = shard
|
||||
self.model = model_shard
|
||||
self.model = model_shard
|
||||
self.caches = OrderedDict()
|
||||
self.session = {}
|
||||
|
||||
|
||||
72
exo/main.py
72
exo/main.py
@@ -13,7 +13,6 @@ import uuid
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from tqdm import tqdm
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from exo.train.dataset import load_dataset, iterate_batches, compose
|
||||
from exo.networking.manual.manual_discovery import ManualDiscovery
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology
|
||||
@@ -33,6 +32,41 @@ from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.models import build_base_shard, get_repo
|
||||
from exo.viz.topology_viz import TopologyViz
|
||||
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
|
||||
import uvloop
|
||||
from contextlib import asynccontextmanager
|
||||
import concurrent.futures
|
||||
import socket
|
||||
import resource
|
||||
import psutil
|
||||
|
||||
# Configure uvloop for maximum performance
|
||||
def configure_uvloop():
|
||||
# Install uvloop as event loop policy
|
||||
uvloop.install()
|
||||
|
||||
# Create new event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Increase file descriptor limits on Unix systems
|
||||
if not psutil.WINDOWS:
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
try:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
|
||||
except ValueError:
|
||||
try:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Configure thread pool for blocking operations
|
||||
loop.set_default_executor(
|
||||
concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=min(32, (os.cpu_count() or 1) * 4)
|
||||
)
|
||||
)
|
||||
|
||||
return loop
|
||||
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
||||
@@ -223,7 +257,7 @@ def clean_path(path):
|
||||
async def hold_outstanding(node: Node):
|
||||
while node.outstanding_requests:
|
||||
await asyncio.sleep(.5)
|
||||
return
|
||||
return
|
||||
|
||||
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
|
||||
losses = []
|
||||
@@ -234,7 +268,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
|
||||
tokens.append(np.sum(lengths))
|
||||
total_tokens = np.sum(tokens)
|
||||
total_loss = np.sum(losses) / total_tokens
|
||||
|
||||
|
||||
return total_loss, total_tokens
|
||||
|
||||
async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
|
||||
@@ -270,7 +304,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
|
||||
await hold_outstanding(node)
|
||||
await hold_outstanding(node)
|
||||
|
||||
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
@@ -285,7 +319,7 @@ async def main():
|
||||
{"❌ No read access" if not has_read else ""}
|
||||
{"❌ No write access" if not has_write else ""}
|
||||
""")
|
||||
|
||||
|
||||
if not args.models_seed_dir is None:
|
||||
try:
|
||||
models_seed_dir = clean_path(args.models_seed_dir)
|
||||
@@ -330,29 +364,31 @@ async def main():
|
||||
print("Error: This train ain't leaving the station without a model")
|
||||
return
|
||||
await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
|
||||
|
||||
|
||||
else:
|
||||
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
if args.wait_for_peers > 0:
|
||||
print("Cooldown to allow peers to exit gracefully")
|
||||
for i in tqdm(range(50)):
|
||||
await asyncio.sleep(.1)
|
||||
|
||||
@asynccontextmanager
|
||||
async def setup_node(args):
|
||||
# Rest of setup_node implementation...
|
||||
pass
|
||||
|
||||
def run():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Received keyboard interrupt. Shutting down...")
|
||||
finally:
|
||||
loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
|
||||
loop.close()
|
||||
|
||||
loop = None
|
||||
try:
|
||||
loop = configure_uvloop()
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown requested... exiting")
|
||||
finally:
|
||||
if loop:
|
||||
loop.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
@@ -21,6 +21,19 @@ class GRPCPeerHandle(PeerHandle):
|
||||
self._device_capabilities = device_capabilities
|
||||
self.channel = None
|
||||
self.stub = None
|
||||
self.channel_options = [
|
||||
("grpc.max_metadata_size", 64 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 256 * 1024 * 1024),
|
||||
("grpc.max_send_message_length", 256 * 1024 * 1024),
|
||||
("grpc.max_concurrent_streams", 100),
|
||||
("grpc.http2.min_time_between_pings_ms", 10000),
|
||||
("grpc.keepalive_time_ms", 20000),
|
||||
("grpc.keepalive_timeout_ms", 10000),
|
||||
("grpc.keepalive_permit_without_calls", 1),
|
||||
("grpc.http2.max_pings_without_data", 0),
|
||||
("grpc.tcp_nodelay", 1),
|
||||
("grpc.optimization_target", "throughput"),
|
||||
]
|
||||
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
@@ -36,11 +49,11 @@ class GRPCPeerHandle(PeerHandle):
|
||||
|
||||
async def connect(self):
|
||||
if self.channel is None:
|
||||
self.channel = grpc.aio.insecure_channel(self.address, options=[
|
||||
("grpc.max_metadata_size", 32*1024*1024),
|
||||
('grpc.max_receive_message_length', 32*1024*1024),
|
||||
('grpc.max_send_message_length', 32*1024*1024)
|
||||
])
|
||||
self.channel = grpc.aio.insecure_channel(
|
||||
self.address,
|
||||
options=self.channel_options,
|
||||
compression=grpc.Compression.Gzip
|
||||
)
|
||||
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
||||
await self.channel.channel_ready()
|
||||
|
||||
@@ -54,7 +67,13 @@ class GRPCPeerHandle(PeerHandle):
|
||||
self.stub = None
|
||||
|
||||
async def _ensure_connected(self):
|
||||
if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
|
||||
if not await self.is_connected():
|
||||
try:
|
||||
await asyncio.wait_for(self.connect(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
|
||||
@@ -31,7 +31,7 @@ class BroadcastProtocol(asyncio.DatagramProtocol):
|
||||
def connection_made(self, transport):
|
||||
sock = transport.get_extra_info("socket")
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
|
||||
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
|
||||
|
||||
|
||||
class UDPDiscovery(Discovery):
|
||||
@@ -84,11 +84,7 @@ class UDPDiscovery(Discovery):
|
||||
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
|
||||
|
||||
async def task_broadcast_presence(self):
|
||||
if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
|
||||
|
||||
while True:
|
||||
# Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
|
||||
# the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
|
||||
for addr, interface_name in get_all_ip_addresses_and_interfaces():
|
||||
interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
|
||||
message = json.dumps({
|
||||
@@ -96,16 +92,23 @@ class UDPDiscovery(Discovery):
|
||||
"node_id": self.node_id,
|
||||
"grpc_port": self.node_port,
|
||||
"device_capabilities": self.device_capabilities.to_dict(),
|
||||
"priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
|
||||
"priority": interface_priority,
|
||||
"interface_name": interface_name,
|
||||
"interface_type": interface_type,
|
||||
})
|
||||
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
|
||||
|
||||
transport = None
|
||||
try:
|
||||
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
|
||||
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
|
||||
# Create socket with explicit broadcast permission
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
sock.bind((addr, 0))
|
||||
|
||||
# Create transport with the pre-configured socket
|
||||
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
|
||||
lambda: BroadcastProtocol(message, self.broadcast_port),
|
||||
sock=sock
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
|
||||
finally:
|
||||
@@ -113,7 +116,7 @@ class UDPDiscovery(Discovery):
|
||||
try: transport.close()
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
|
||||
if DEBUG_DISCOVERY >= 2: traceback.print_exc()
|
||||
|
||||
await asyncio.sleep(self.broadcast_interval)
|
||||
|
||||
async def on_listen_message(self, data, addr):
|
||||
|
||||
Reference in New Issue
Block a user