super secret fixes

Alex Cheema
2025-10-09 14:33:05 +01:00
parent 7d79ea95c6
commit bb91381e95
4 changed files with 48 additions and 17 deletions

View File

@@ -254,18 +254,21 @@ def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> Sta
 def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
+    logger.warning(f"~~~ APPLY Node {event.node_id} created")
     topology = copy.copy(state.topology)
     topology.add_node(NodeInfo(node_id=event.node_id))
     return state.model_copy(update={"topology": topology})

 def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
+    logger.warning(f"~~~ APPLY Edge {event.edge.local_node_id} -> {event.edge.send_back_node_id} created")
     topology = copy.copy(state.topology)
     topology.add_connection(event.edge)
     return state.model_copy(update={"topology": topology})

 def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
+    logger.warning(f"~~~ APPLY Edge {event.edge.local_node_id} -> {event.edge.send_back_node_id} deleted")
     topology = copy.copy(state.topology)
     if not topology.contains_connection(event.edge):
         return state
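
These handlers share a copy-on-write pattern: shallow-copy the topology, mutate the copy, and return a fresh immutable state via pydantic's model_copy. A minimal runnable sketch of that pattern, with a stand-in Topology class (the real NodeInfo/Connection types are not reproduced here):

    import copy
    from pydantic import BaseModel, ConfigDict

    class Topology:
        def __init__(self, nodes: set[str] | None = None):
            self.nodes: set[str] = set(nodes or ())

        def __copy__(self) -> "Topology":
            # Copy the node set so mutating the new topology can't leak into old states.
            return Topology(self.nodes)

        def add_node(self, node_id: str) -> None:
            self.nodes.add(node_id)

    class State(BaseModel):
        model_config = ConfigDict(arbitrary_types_allowed=True)  # Topology is not a pydantic type
        topology: Topology

    def apply_node_created(node_id: str, state: State) -> State:
        topology = copy.copy(state.topology)
        topology.add_node(node_id)
        # The old State object keeps its old topology; callers get a new one.
        return state.model_copy(update={"topology": topology})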

View File

@@ -9,9 +9,9 @@ class OrderedBuffer[T]:
     source at a time.
     """

-    def __init__(self):
+    def __init__(self, start_idx: int = 0):
         self.store: dict[int, T] = {}
-        self.next_idx_to_release: int = 0
+        self.next_idx_to_release: int = start_idx

     def ingest(self, idx: int, t: T):
         """Ingest a sequence into the buffer"""
@@ -56,8 +56,15 @@ class MultiSourceBuffer[SourceId, T]:
     def ingest(self, idx: int, t: T, source: SourceId):
         if source not in self.stores:
-            self.stores[source] = OrderedBuffer()
+            # Seed the per-source buffer to start at the first observed index for that source.
+            self.stores[source] = OrderedBuffer(start_idx=idx)
         buffer = self.stores[source]
+        # Handle per-source sequence reset (e.g., worker restart resetting its local index to 0).
+        # If we observe idx == 0 from an existing source with a higher expected index,
+        # reset that source's buffer to accept the new sequence.
+        if idx == 0 and buffer.next_idx_to_release > 0:
+            self.stores[source] = OrderedBuffer(start_idx=0)
+            buffer = self.stores[source]
         buffer.ingest(idx, t)

     def drain(self) -> list[T]:
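
Taken together, the two buffer hunks mean a per-source buffer starts at the first index it actually observes, and a source whose indices restart from 0 (e.g., after a process restart) gets a fresh buffer instead of waiting forever on indices that will never arrive. A self-contained sketch of those semantics (the ingest/drain bodies are assumed from the class docstring, not copied from the repo):

    class OrderedBuffer[T]:  # PEP 695 generics, Python 3.12+
        def __init__(self, start_idx: int = 0):
            self.store: dict[int, T] = {}
            self.next_idx_to_release: int = start_idx

        def ingest(self, idx: int, t: T) -> None:
            # Drop anything already released; buffer the rest until contiguous.
            if idx >= self.next_idx_to_release:
                self.store[idx] = t

        def drain(self) -> list[T]:
            # Release items in index order, stopping at the first gap.
            out: list[T] = []
            while self.next_idx_to_release in self.store:
                out.append(self.store.pop(self.next_idx_to_release))
                self.next_idx_to_release += 1
            return out

    buf = OrderedBuffer[str](start_idx=5)  # seeded at the first observed index
    buf.ingest(5, "a")
    buf.ingest(6, "b")
    assert buf.drain() == ["a", "b"]

    # A restarted source re-sends from index 0; MultiSourceBuffer.ingest swaps in a fresh buffer.
    if buf.next_idx_to_release > 0:
        buf = OrderedBuffer[str](start_idx=0)
    buf.ingest(0, "after-restart")
    assert buf.drain() == ["after-restart"]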

View File

@@ -12,6 +12,7 @@ from urllib.parse import urljoin
 import aiofiles
 import aiofiles.os as aios
 import aiohttp
+from loguru import logger
 from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter

 from exo.shared.constants import EXO_HOME
@@ -165,13 +166,13 @@ async def seed_models(seed_dir: Union[str, Path]):
         if path.is_dir() and path.name.startswith("models--"):
             dest_path = dest_dir / path.name
             if await aios.path.exists(dest_path):
-                print("Skipping moving model to .cache directory")
+                logger.info("Skipping moving model to .cache directory")
             else:
                 try:
                     await aios.rename(str(path), str(dest_path))
                 except Exception:
-                    print(f"Error seeding model {path} to {dest_path}")
-                    traceback.print_exc()
+                    logger.error(f"Error seeding model {path} to {dest_path}")
+                    logger.error(traceback.format_exc())

 async def fetch_file_list_with_cache(
@@ -320,10 +321,10 @@ async def download_file_with_retry(
         except Exception as e:
             if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
                 raise e
-            print(
+            logger.error(
                 f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
             )
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
             await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
     raise Exception(
         f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
@@ -391,7 +392,7 @@ async def _download_file(
         try:
             await aios.remove(partial_path)
         except Exception as e:
-            print(f"Error removing partial file {partial_path}: {e}")
+            logger.error(f"Error removing partial file {partial_path}: {e}")
         raise Exception(
             f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
         )
@@ -461,8 +462,8 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> List[str]:
         weight_map = await get_weight_map(str(shard.model_meta.model_id))
         return get_allow_patterns(weight_map, shard)
     except Exception:
-        print(f"Error getting weight map for {shard.model_meta.model_id=}")
-        traceback.print_exc()
+        logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
+        logger.error(traceback.format_exc())
         return ["*"]
@@ -532,11 +533,11 @@ async def download_shard(
     allow_patterns: List[str] | None = None,
 ) -> tuple[Path, RepoDownloadProgress]:
     if not skip_download:
-        print(f"Downloading {shard.model_meta.model_id=}")
+        logger.info(f"Downloading {shard.model_meta.model_id=}")

     # Handle local paths
     if await aios.path.exists(str(shard.model_meta.model_id)):
-        print(f"Using local model path {shard.model_meta.model_id}")
+        logger.info(f"Using local model path {shard.model_meta.model_id}")
         local_path = Path(str(shard.model_meta.model_id))
         return local_path, await download_progress_for_local_path(
             str(shard.model_meta.model_id), shard, local_path
@@ -552,7 +553,7 @@ async def download_shard(
     if not allow_patterns:
         allow_patterns = await resolve_allow_patterns(shard)

-    print(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
+    logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")

     all_start_time = time.time()
     # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
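
Every print/traceback.print_exc pair in this file becomes two logger.error calls. loguru can also collapse the pair into one: called inside an except block, logger.exception logs at ERROR level and appends the active traceback automatically. A behavioral sketch of that alternative (not what this diff does):

    from loguru import logger

    try:
        raise RuntimeError("boom")
    except Exception:
        # One call: the message plus the current traceback.
        logger.exception("Error getting weight map")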

View File

@@ -136,6 +136,8 @@ class Worker:
             tg.start_soon(self._connection_message_event_writer)
             tg.start_soon(self._resend_out_for_delivery)
             tg.start_soon(self._event_applier)
+            # Proactively request a global event sync at startup to backfill any missed events.
+            tg.start_soon(self._request_full_event_log_once)

             # TODO: This is a little gross, but not too bad
             for msg in self._initial_connection_messages:
                 await self.event_publisher(
@@ -222,6 +224,7 @@ class Worker:
     def _convert_connection_message_to_event(self, msg: ConnectionMessage):
         match msg.connection_type:
             case ConnectionMessageType.Connected:
+                logger.warning(f"!!! Node {self.node_id} connected to {msg.node_id}")
                 return TopologyEdgeCreated(
                     edge=Connection(
                         local_node_id=self.node_id,
@@ -233,6 +236,7 @@ class Worker:
                 )
             case ConnectionMessageType.Disconnected:
+                logger.warning(f"!!! Node {self.node_id} disconnected from {msg.node_id}")
                 return TopologyEdgeDeleted(
                     edge=Connection(
                         local_node_id=self.node_id,
@@ -256,7 +260,7 @@ class Worker:
                 ForwarderCommand(
                     origin=self.node_id,
                     tagged_command=TaggedCommand.from_(
-                        RequestEventLog(since_idx=0)
+                        RequestEventLog(since_idx=self.event_buffer.next_idx_to_release)
                     ),
                 )
             )
@@ -264,6 +268,18 @@ class Worker:
             if self._nack_cancel_scope is scope:
                 self._nack_cancel_scope = None

+    async def _request_full_event_log_once(self) -> None:
+        # Fire-and-forget one-time sync shortly after startup.
+        await anyio.sleep(0.1)
+        await self.command_sender.send(
+            ForwarderCommand(
+                origin=self.node_id,
+                tagged_command=TaggedCommand.from_(
+                    RequestEventLog(since_idx=self.event_buffer.next_idx_to_release)
+                ),
+            )
+        )
+
     async def _resend_out_for_delivery(self) -> None:
         # This can also be massively tightened, we should check events are at least a certain age before resending.
         # Exponential backoff would also certainly help here.
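
Both sync paths now request the event log from event_buffer.next_idx_to_release rather than 0, so a node only asks for events it has not yet released, and the startup task makes that request once proactively instead of waiting for a NACK. A small sketch of the one-shot background task pattern, matching how run() wires it up with tg.start_soon (names and the print stand-in are illustrative):

    import anyio

    async def main() -> None:
        next_idx_to_release = 42  # illustrative stand-in for OrderedBuffer state

        async def request_full_event_log_once() -> None:
            # Fire-and-forget: give connections a moment to come up, request once, exit.
            await anyio.sleep(0.1)
            print(f"requesting event log since_idx={next_idx_to_release}")

        async with anyio.create_task_group() as tg:
            tg.start_soon(request_full_event_log_once)

    anyio.run(main)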
@@ -345,9 +361,13 @@ class Worker:
                     assigned_runner, download_progress_queue
                 ):
                     yield event
+                # In case the download needs to finish up, wait up to 15 secs for it to finish.
+                # This fixes a bug where the download gets cancelled before it can rename the .partial file on finish.
+                await asyncio.wait_for(download_task, timeout=15)
             finally:
                 if not download_task.done():
                     download_task.cancel()

     async def _monitor_download_progress(
         self,
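
The download fix: after the progress stream ends, give the download task a grace period to finish (so it can rename its .partial file) before the finally block cancels whatever is still running. A condensed sketch of that shape (the function name and events iterable are illustrative):

    import asyncio
    from typing import AsyncIterator

    async def stream_then_settle(
        events: AsyncIterator[str], download_task: asyncio.Task, grace: float = 15
    ) -> AsyncIterator[str]:
        try:
            async for event in events:
                yield event
            # Grace period so the task can complete its cleanup (e.g., the .partial rename).
            await asyncio.wait_for(download_task, timeout=grace)
        finally:
            # Reached on early generator exit or timeout; cancelling a finished task is a no-op.
            if not download_task.done():
                download_task.cancel()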
@@ -632,8 +652,8 @@ class Worker:
             )
             await self.local_event_sender.send(fe)
             self.out_for_delivery[event.event_id] = fe
-            logger.debug(
-                f"Worker published event {self.local_event_index}: {str(event)[:100]}"
+            logger.info(
+                f"Worker published event {self.local_event_index}: {str(event)[:100]}...{str(event)[-100:]}"
             )
             self.local_event_index += 1