switch to uvloop (faster asyncio event loop) and optimise grpc settings

Alex Cheema
2024-12-17 16:10:56 +00:00
parent 58f0a0f547
commit 0a07223074
5 changed files with 97 additions and 38 deletions
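At its core, the uvloop switch is a two-line change: install uvloop as the asyncio event loop policy, then create loops as usual. A minimal, self-contained sketch of the pattern this commit adopts (names and the placeholder coroutine are illustrative, not taken from the diff):

import asyncio
import uvloop

# Install uvloop as the event loop policy: every loop created afterwards
# is a libuv-backed uvloop.Loop instead of the pure-Python default.
uvloop.install()

async def main():
  await asyncio.sleep(0)  # placeholder coroutine

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
  loop.run_until_complete(main())
finally:
  loop.close()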

View File

@@ -5,7 +5,7 @@ from mlx_lm.sample_utils import top_p_sampling, make_sampler
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
@@ -56,7 +56,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
  async def load_checkpoint(self, shard: Shard, path: str):
    await self.ensure_shard(shard)
    self.model.load_weights(path)

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
    await self.ensure_shard(shard)
    state = await self.poll_state(request_id)
@@ -102,7 +102,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
    #print(f"{score=}")
    layers = [{k: v["weight"] for k, v in l.items() if 'weight' in v} for l in gradients if l]
    #print(layers[0])
@@ -117,7 +117,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
    if self.shard != shard:
      model_shard, self.tokenizer = await load_shard(model_path, shard)
      self.shard = shard
      self.model = model_shard
      self.caches = OrderedDict()
      self.session = {}

View File

@@ -13,7 +13,6 @@ import uuid
import numpy as np
from functools import partial
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from exo.train.dataset import load_dataset, iterate_batches, compose
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.manual.network_topology_config import NetworkTopology
@@ -33,6 +32,41 @@ from exo.inference.tokenizers import resolve_tokenizer
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
import uvloop
from contextlib import asynccontextmanager
import concurrent.futures
import socket
import resource
import psutil
# Configure uvloop for maximum performance
def configure_uvloop():
  # Install uvloop as the event loop policy
  uvloop.install()

  # Create a new event loop
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)

  # Increase file descriptor limits on Unix systems
  if not psutil.WINDOWS:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    try:
      resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    except ValueError:
      try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
      except ValueError:
        pass

  # Configure a thread pool for blocking operations
  loop.set_default_executor(
    concurrent.futures.ThreadPoolExecutor(
      max_workers=min(32, (os.cpu_count() or 1) * 4)
    )
  )
  return loop

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -223,7 +257,7 @@ def clean_path(path):
async def hold_outstanding(node: Node):
  while node.outstanding_requests:
    await asyncio.sleep(.5)
  return
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
  losses = []
@@ -234,7 +268,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
    tokens.append(np.sum(lengths))
  total_tokens = np.sum(tokens)
  total_loss = np.sum(losses) / total_tokens

  return total_loss, total_tokens
async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
@@ -270,7 +304,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
  await hold_outstanding(node)
async def main():
  loop = asyncio.get_running_loop()
@@ -285,7 +319,7 @@ async def main():
{"❌ No read access" if not has_read else ""}
{"❌ No write access" if not has_write else ""}
""")
  if not args.models_seed_dir is None:
    try:
      models_seed_dir = clean_path(args.models_seed_dir)
@@ -330,29 +364,31 @@ async def main():
print("Error: This train ain't leaving the station without a model")
return
await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()
if args.wait_for_peers > 0:
print("Cooldown to allow peers to exit gracefully")
for i in tqdm(range(50)):
await asyncio.sleep(.1)
@asynccontextmanager
async def setup_node(args):
  # Rest of setup_node implementation...
  pass
def run():
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)
  try:
    loop.run_until_complete(main())
  except KeyboardInterrupt:
    print("Received keyboard interrupt. Shutting down...")
  finally:
    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
    loop.close()

  loop = None
  try:
    loop = configure_uvloop()
    loop.run_until_complete(main())
  except KeyboardInterrupt:
    print("\nShutdown requested... exiting")
  finally:
    if loop:
      loop.close()
if __name__ == "__main__":
  run()
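A quick way to sanity-check the new run() path is to confirm that configure_uvloop() really hands back a libuv-backed loop and that the descriptor limit was raised. A small sketch, assuming configure_uvloop is importable from this module:

import resource
import uvloop

loop = configure_uvloop()
assert isinstance(loop, uvloop.Loop)  # uvloop exposes its loop class as uvloop.Loop

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"RLIMIT_NOFILE soft={soft} hard={hard}")  # soft should now match hard (or be >= 8192)
loop.close()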

View File

@@ -21,6 +21,19 @@ class GRPCPeerHandle(PeerHandle):
    self._device_capabilities = device_capabilities
    self.channel = None
    self.stub = None
    self.channel_options = [
      ("grpc.max_metadata_size", 64 * 1024 * 1024),
      ("grpc.max_receive_message_length", 256 * 1024 * 1024),
      ("grpc.max_send_message_length", 256 * 1024 * 1024),
      ("grpc.max_concurrent_streams", 100),
      ("grpc.http2.min_time_between_pings_ms", 10000),
      ("grpc.keepalive_time_ms", 20000),
      ("grpc.keepalive_timeout_ms", 10000),
      ("grpc.keepalive_permit_without_calls", 1),
      ("grpc.http2.max_pings_without_data", 0),
      ("grpc.tcp_nodelay", 1),
      ("grpc.optimization_target", "throughput"),
    ]
  def id(self) -> str:
    return self._id
@@ -36,11 +49,11 @@ class GRPCPeerHandle(PeerHandle):
  async def connect(self):
    if self.channel is None:
      self.channel = grpc.aio.insecure_channel(self.address, options=[
        ("grpc.max_metadata_size", 32*1024*1024),
        ('grpc.max_receive_message_length', 32*1024*1024),
        ('grpc.max_send_message_length', 32*1024*1024)
      ])
      self.channel = grpc.aio.insecure_channel(
        self.address,
        options=self.channel_options,
        compression=grpc.Compression.Gzip
      )
      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
      await self.channel.channel_ready()
@@ -54,7 +67,13 @@ class GRPCPeerHandle(PeerHandle):
    self.stub = None

  async def _ensure_connected(self):
    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
    if not await self.is_connected():
      try:
        await asyncio.wait_for(self.connect(), timeout=10.0)
      except asyncio.TimeoutError:
        if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
        await self.disconnect()
        raise
  async def health_check(self) -> bool:
    try:

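One caveat with the keepalive settings above: a gRPC server that has not been configured to tolerate them may answer aggressive client pings with GOAWAY (too_many_pings). A hedged sketch of complementary server-side options, using standard gRPC channel arguments (the server setup itself is not part of this diff):

import grpc.aio

server = grpc.aio.server(options=[
  ("grpc.max_receive_message_length", 256 * 1024 * 1024),
  ("grpc.max_send_message_length", 256 * 1024 * 1024),
  # Let clients ping even when no calls are in flight, matching the client config.
  ("grpc.keepalive_permit_without_calls", 1),
  # Accept pings at the client's 20s keepalive interval without penalty.
  ("grpc.http2.min_ping_interval_without_data_ms", 10000),
])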
View File

@@ -31,7 +31,7 @@ class BroadcastProtocol(asyncio.DatagramProtocol):
  def connection_made(self, transport):
    sock = transport.get_extra_info("socket")
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
    transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
    transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
class UDPDiscovery(Discovery):
@@ -84,11 +84,7 @@ class UDPDiscovery(Discovery):
    return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

  async def task_broadcast_presence(self):
    if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
    while True:
      # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
      # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
      for addr, interface_name in get_all_ip_addresses_and_interfaces():
        interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
        message = json.dumps({
@@ -96,16 +92,23 @@ class UDPDiscovery(Discovery):
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
"priority": interface_priority,
"interface_name": interface_name,
"interface_type": interface_type,
})
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
        transport = None
        try:
          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
          # Create socket with explicit broadcast permission
          sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
          sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
          sock.bind((addr, 0))

          # Create transport with the pre-configured socket
          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
            lambda: BroadcastProtocol(message, self.broadcast_port),
            sock=sock
          )
        except Exception as e:
          print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
        finally:
@@ -113,7 +116,7 @@ class UDPDiscovery(Discovery):
          try: transport.close()
          except Exception as e:
            if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
            if DEBUG_DISCOVERY >= 2: traceback.print_exc()
      await asyncio.sleep(self.broadcast_interval)
  async def on_listen_message(self, data, addr):

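Taken on its own, the per-interface broadcast pattern introduced above reduces to: pre-configure a UDP socket bound to one interface address, then hand it to the event loop. A self-contained sketch (the function name, port, and payload are placeholders):

import asyncio
import socket

async def broadcast_once(addr: str, port: int, payload: bytes):
  # Bind to a specific local address so the datagram leaves via that interface.
  sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
  sock.bind((addr, 0))
  # create_datagram_endpoint accepts a pre-configured socket via sock=
  # (mutually exclusive with local_addr/family).
  transport, _ = await asyncio.get_running_loop().create_datagram_endpoint(
    asyncio.DatagramProtocol, sock=sock
  )
  try:
    transport.sendto(payload, ("255.255.255.255", port))
  finally:
    transport.close()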
View File

@@ -27,6 +27,7 @@ install_requires = [
"tqdm==4.66.4",
"transformers==4.46.3",
"uuid==1.30",
"uvloop==0.21.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
]
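Since configure_uvloop already skips the resource tweaks on Windows via psutil.WINDOWS, and uvloop does not support Windows, the new dependency could carry a matching PEP 508 environment marker. A hedged suggestion, not part of this commit:

  "uvloop==0.21.0; sys_platform != 'win32'",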