super secret update

This commit is contained in:
Alex Cheema
2025-10-15 12:49:49 +01:00
parent b516d89130
commit 191553a298
54 changed files with 851 additions and 2923 deletions

View File

@@ -40,12 +40,4 @@ else
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
fi
if ifconfig bridge0 >/dev/null 2>&1; then
echo "Thunderbolt bridge found"
if ifconfig bridge0 | grep -q "status: active"; then
sudo ifconfig bridge0 down
echo "Thunderbolt bridge disabled"
fi
fi
fi

View File

@@ -492,133 +492,6 @@
transition: width 0.3s ease;
}
/* Detailed download info */
.download-details {
margin-top: 8px;
padding: 12px;
background-color: #1a1a1a;
border: 1px solid var(--exo-medium-gray);
border-radius: 6px;
box-sizing: border-box;
width: 100%;
max-width: 100%;
overflow: visible;
}
.download-runner-header {
font-size: 11px;
color: var(--exo-light-gray);
opacity: 0.85;
margin-bottom: 4px;
}
.download-overview-row {
display: flex;
gap: 12px;
flex-wrap: wrap;
font-size: 12px;
margin-bottom: 8px;
}
.download-overview-item strong {
color: #E0E0E0;
font-weight: 600;
margin-right: 4px;
}
.progress-with-label {
display: flex;
align-items: center;
gap: 8px;
margin-bottom: 10px;
}
.progress-with-label .progress-bar-container {
flex: 1 1 auto;
}
.progress-percent {
font-size: 12px;
color: var(--exo-light-gray);
font-variant-numeric: tabular-nums;
white-space: nowrap;
}
.download-overview-combined {
font-size: 12px;
color: var(--exo-light-gray);
opacity: 0.9;
}
.instance-download-summary {
font-size: 11px;
color: var(--exo-light-gray);
margin-top: 6px;
opacity: 0.95;
}
.download-files-list {
display: grid;
gap: 8px;
}
.download-file {
padding: 8px;
background-color: var(--exo-dark-gray);
border: 1px solid var(--exo-medium-gray);
border-radius: 6px;
box-sizing: border-box;
width: 100%;
max-width: 100%;
}
.download-file-header {
display: flex;
justify-content: space-between;
align-items: center;
gap: 10px;
font-size: 11px;
margin-bottom: 6px;
width: 100%;
max-width: 100%;
overflow: hidden;
}
.download-file-name {
color: #E0E0E0;
font-weight: 500;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
min-width: 0;
flex: 1 1 auto;
}
.download-file-stats {
color: var(--exo-light-gray);
text-align: right;
white-space: nowrap;
}
.download-file-percent {
color: var(--exo-light-gray);
white-space: nowrap;
font-size: 11px;
font-variant-numeric: tabular-nums;
flex: 0 0 auto;
}
.download-file-subtext {
color: var(--exo-light-gray);
font-size: 10px;
opacity: 0.85;
margin-bottom: 6px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
max-width: 100%;
}
.download-details, .download-files-list {
box-sizing: border-box;
width: 100%;
max-width: 100%;
}
.download-files-list {
overflow: visible;
padding-right: 2px; /* avoid edge clipping */
}
.download-file .progress-bar-container {
width: 100%;
max-width: 100%;
box-sizing: border-box;
height: 5px;
}
/* Launch instance section styles */
.launch-instance-section {
display: flex;
@@ -877,7 +750,6 @@
const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA
let currentlySelectedNodeId = null; // To store the ID of the currently selected node
let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections
const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state';
const REFRESH_INTERVAL = 1000; // 1 second
@@ -983,36 +855,6 @@
return days + (days === 1 ? ' day ago' : ' days ago');
}
// --- Download formatting helpers ---
function bytesFromValue(value) {
if (typeof value === 'number') return value;
if (!value || typeof value !== 'object') return 0;
if (typeof value.in_bytes === 'number') return value.in_bytes;
if (typeof value.inBytes === 'number') return value.inBytes;
return 0;
}
function formatDurationMs(ms) {
if (ms == null || isNaN(ms) || ms < 0) return '—';
const totalSeconds = Math.round(ms / 1000);
const s = totalSeconds % 60;
const m = Math.floor(totalSeconds / 60) % 60;
const h = Math.floor(totalSeconds / 3600);
if (h > 0) return `${h}h ${m}m ${s}s`;
if (m > 0) return `${m}m ${s}s`;
return `${s}s`;
}
function formatPercent(value, digits = 2) {
if (value == null || isNaN(value)) return '0.00%';
return `${value.toFixed(digits)}%`;
}
function formatBytesPerSecond(bps) {
if (bps == null || isNaN(bps) || bps < 0) return '0 B/s';
return `${formatBytes(bps)}/s`;
}
// Sidebar toggle functionality
let sidebarOpen = false;
@@ -1092,7 +934,7 @@
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ model_id: selectedModelId })
body: JSON.stringify({ modelId: selectedModelId, model_id: selectedModelId })
});
if (!response.ok) {
@@ -1132,123 +974,75 @@
}
}
// Calculate download status for an instance based on its runners, with detailed per-file info
// Calculate download status for an instance based on its runners
function calculateInstanceDownloadStatus(instance, runners) {
if (!instance.shard_assignments?.runner_to_shard || !runners) {
return { isDownloading: false, progress: 0, details: [] };
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
const runnerToShard = shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard;
if (!runnerToShard || !runners) {
return { isDownloading: false, progress: 0 };
}
const pick = (obj, snake, camel, fallback = undefined) => {
if (!obj) return fallback;
if (obj[snake] !== undefined) return obj[snake];
if (obj[camel] !== undefined) return obj[camel];
return fallback;
};
function normalizeProgress(progressRaw) {
if (!progressRaw) return null;
const totalBytes = bytesFromValue(pick(progressRaw, 'total_bytes', 'totalBytes', 0));
const downloadedBytes = bytesFromValue(pick(progressRaw, 'downloaded_bytes', 'downloadedBytes', 0));
const downloadedBytesThisSession = bytesFromValue(pick(progressRaw, 'downloaded_bytes_this_session', 'downloadedBytesThisSession', 0));
const completedFiles = Number(pick(progressRaw, 'completed_files', 'completedFiles', 0)) || 0;
const totalFiles = Number(pick(progressRaw, 'total_files', 'totalFiles', 0)) || 0;
const speed = Number(pick(progressRaw, 'speed', 'speed', 0)) || 0;
const etaMs = Number(pick(progressRaw, 'eta_ms', 'etaMs', 0)) || 0;
const filesObj = pick(progressRaw, 'files', 'files', {}) || {};
const files = [];
Object.keys(filesObj).forEach(name => {
const f = filesObj[name];
if (!f || typeof f !== 'object') return;
const fTotal = bytesFromValue(pick(f, 'total_bytes', 'totalBytes', 0));
const fDownloaded = bytesFromValue(pick(f, 'downloaded_bytes', 'downloadedBytes', 0));
const fSpeed = Number(pick(f, 'speed', 'speed', 0)) || 0;
const fEta = Number(pick(f, 'eta_ms', 'etaMs', 0)) || 0;
const fPct = fTotal > 0 ? (fDownloaded / fTotal) * 100 : 0;
files.push({ name, totalBytes: fTotal, downloadedBytes: fDownloaded, speed: fSpeed, etaMs: fEta, percentage: fPct });
});
const percentage = totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0;
return { totalBytes, downloadedBytes, downloadedBytesThisSession, completedFiles, totalFiles, speed, etaMs, files, percentage };
}
const runnerIds = Object.keys(instance.shard_assignments.runner_to_shard);
const details = [];
const runnerIds = Object.keys(runnerToShard);
const downloadingRunners = [];
let totalBytes = 0;
let downloadedBytes = 0;
for (const runnerId of runnerIds) {
const runner = runners[runnerId];
if (!runner || runner.runner_status !== 'Downloading' || !runner.download_progress) continue;
const dp = runner.download_progress;
const isDownloading = (dp.download_status === 'Downloading') || (dp.downloadStatus === 'Downloading');
if (!isDownloading) continue;
const nodeId = (dp && (dp.node_id || dp.nodeId)) || undefined;
const rawProg = pick(dp, 'download_progress', 'downloadProgress', null);
const normalized = normalizeProgress(rawProg);
if (!normalized) continue;
details.push({ runnerId, nodeId, progress: normalized });
totalBytes += normalized.totalBytes || 0;
downloadedBytes += normalized.downloadedBytes || 0;
let isRunnerDownloading = false;
// Legacy snake_case structure
if (runner && runner.runner_status === 'Downloading' && runner.download_progress) {
isRunnerDownloading = runner.download_progress.download_status === 'Downloading';
if (isRunnerDownloading && runner.download_progress.download_progress) {
totalBytes += runner.download_progress.download_progress.total_bytes || 0;
downloadedBytes += runner.download_progress.download_progress.downloaded_bytes || 0;
}
} else if (runner && typeof runner === 'object') {
// Tagged-union camelCase structure, e.g. { "DownloadingRunnerStatus": { downloadProgress: { totalBytes, downloadedBytes } } }
const tag = Object.keys(runner)[0];
if (tag && /DownloadingRunnerStatus$/i.test(tag)) {
isRunnerDownloading = true;
const inner = runner[tag] || {};
const prog = inner.downloadProgress || inner.download_progress || {};
const t = prog.totalBytes ?? prog.total_bytes ?? 0;
const d = prog.downloadedBytes ?? prog.downloaded_bytes ?? 0;
totalBytes += typeof t === 'number' ? t : 0;
downloadedBytes += typeof d === 'number' ? d : 0;
}
}
if (isRunnerDownloading) downloadingRunners.push(runner);
}
const isDownloadingAny = details.length > 0;
const progress = totalBytes > 0 ? ((downloadedBytes / totalBytes) * 100) : 0;
return { isDownloading: isDownloadingAny, progress, details };
}
const isDownloading = downloadingRunners.length > 0;
const progress = totalBytes > 0 ? Math.round((downloadedBytes / totalBytes) * 100) : 0;
function buildDownloadDetailsHTML(details) {
if (!details || details.length === 0) return '';
function shortId(id) { return (id && id.length > 8) ? id.slice(0, 8) + '…' : (id || ''); }
return details.map(({ runnerId, nodeId, progress }) => {
const etaStr = formatDurationMs(progress.etaMs);
const pctStr = formatPercent(progress.percentage || 0, 2);
const bytesStr = `${formatBytes(progress.downloadedBytes)} / ${formatBytes(progress.totalBytes)}`;
const speedStr = formatBytesPerSecond(progress.speed);
const filesSummary = `${progress.completedFiles}/${progress.totalFiles}`;
const filesHTML = (progress.files || []).map(f => {
const fPct = f.percentage || 0;
const fBytes = `${formatBytes(f.downloadedBytes)} / ${formatBytes(f.totalBytes)}`;
const fEta = formatDurationMs(f.etaMs);
const fSpeed = formatBytesPerSecond(f.speed);
const pctText = formatPercent(fPct, 2);
return `
<div class="download-file">
<div class="download-file-header">
<span class="download-file-name" title="${f.name}">${f.name}</span>
<span class="download-file-percent">${pctText}</span>
</div>
<div class="download-file-subtext">${fBytes} • ETA ${fEta}${fSpeed}</div>
<div class="progress-bar-container"><div class="progress-bar" style="width: ${Math.max(0, Math.min(100, fPct)).toFixed(2)}%;"></div></div>
</div>
`;
}).join('');
const runnerName = (nodeId && nodeIdToFriendlyName[nodeId]) ? nodeIdToFriendlyName[nodeId] : '?';
const headerText = `${runnerName} (${shortId(nodeId || '')})`;
return `
<div class="download-details">
<div class="download-runner-header">${headerText}</div>
<div class="download-files-list">
${filesHTML}
</div>
</div>
`;
}).join('');
return { isDownloading, progress, downloadingRunners: downloadingRunners.length };
}
// Derive a display status for an instance from its runners.
// Priority: FAILED > DOWNLOADING > STARTING > RUNNING > LOADED > INACTIVE
function deriveInstanceStatus(instance, runners = {}) {
const runnerIds = Object.keys(instance.shard_assignments?.runner_to_shard || {});
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
const runnerToShard = shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard ?? {};
const runnerIds = Object.keys(runnerToShard);
const statuses = runnerIds
.map(rid => runners[rid]?.runner_status)
.map(rid => {
const r = runners[rid];
if (!r || typeof r !== 'object') return undefined;
if (typeof r.runner_status === 'string') return r.runner_status;
const tag = Object.keys(r)[0];
return typeof tag === 'string' ? tag.replace(/RunnerStatus$/,'') : undefined; // e.g. LoadedRunnerStatus -> Loaded
})
.filter(s => typeof s === 'string');
const has = (s) => statuses.includes(s);
const every = (pred) => statuses.length > 0 && statuses.every(pred);
if (statuses.length === 0) {
const inactive = instance.instance_type === 'INACTIVE';
const instanceType = instance.instance_type ?? instance.instanceType;
const inactive = instanceType === 'INACTIVE' || instanceType === 'Inactive';
return { statusText: inactive ? 'INACTIVE' : 'LOADED', statusClass: inactive ? 'inactive' : 'loaded' };
}
@@ -1278,10 +1072,12 @@
}
const instancesHTML = instancesArray.map(instance => {
const modelId = instance.shard_assignments?.model_id || 'Unknown Model';
const truncatedInstanceId = instance.instance_id.length > 8
? instance.instance_id.substring(0, 8) + '...'
: instance.instance_id;
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
const modelId = shardAssignments?.model_id ?? shardAssignments?.modelId ?? 'Unknown Model';
const instanceId = instance.instance_id ?? instance.instanceId ?? '';
const truncatedInstanceId = instanceId.length > 8
? instanceId.substring(0, 8) + '...'
: instanceId;
const hostsHTML = instance.hosts?.map(host =>
`<span class="instance-host">${host.ip}:${host.port}</span>`
@@ -1298,31 +1094,15 @@
}
// Generate download progress HTML
let downloadProgressHTML = '';
let instanceDownloadSummary = '';
if (downloadStatus.isDownloading) {
const detailsHTML = buildDownloadDetailsHTML(downloadStatus.details || []);
const pctText = (downloadStatus.progress || 0).toFixed(2);
// Aggregate a compact summary from the first runner (they should be consistent in aggregate)
const first = (downloadStatus.details || [])[0]?.progress;
const etaStr = first ? formatDurationMs(first.etaMs) : '—';
const bytesStr = first ? `${formatBytes(first.downloadedBytes)} / ${formatBytes(first.totalBytes)}` : '';
const speedStr = first ? formatBytesPerSecond(first.speed) : '';
const filesSummary = first ? `${first.completedFiles}/${first.totalFiles}` : '';
instanceDownloadSummary = `${etaStr} · ${bytesStr} · ${speedStr} · ${filesSummary} files`;
downloadProgressHTML = `
<div class="download-progress">
<span>${pctText}%</span>
<div class="progress-bar-container">
<div class="progress-bar" style="width: ${pctText}%;"></div>
</div>
const downloadProgressHTML = downloadStatus.isDownloading
? `<div class="download-progress">
<span>${downloadStatus.progress}% downloaded</span>
<div class="progress-bar-container">
<div class="progress-bar" style="width: ${downloadStatus.progress}%;"></div>
</div>
${detailsHTML}
`;
}
</div>`
: '';
const shardCount = Object.keys(instance.shard_assignments?.runner_to_shard || {}).length;
return `
<div class="instance-item">
<div class="instance-header">
@@ -1331,14 +1111,15 @@
<span class="instance-status ${statusClass}">${statusText}</span>
</div>
<div class="instance-actions">
<button class="instance-delete-button" data-instance-id="${instance.instance_id}" title="Delete Instance">
<button class="instance-delete-button" data-instance-id="${instanceId}" title="Delete Instance">
Delete
</button>
</div>
</div>
<div class="instance-model">${modelId} <span style="color: var(--exo-light-gray); opacity: 0.8;">(${shardCount})</span></div>
${instanceDownloadSummary ? `<div class="instance-download-summary">${instanceDownloadSummary}</div>` : ''}
<div class="instance-model">${modelId}</div>
<div class="instance-details">
Shards: ${Object.keys((shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard) || {}).length}
</div>
${downloadProgressHTML}
${hostsHTML ? `<div class="instance-hosts">${hostsHTML}</div>` : ''}
</div>
@@ -1395,12 +1176,10 @@
}
}
function renderNodes(topologyData) {
function renderNodes(nodesData) {
if (!topologyGraphContainer) return;
topologyGraphContainer.innerHTML = ''; // Clear previous SVG content
const nodesData = (topologyData && topologyData.nodes) ? topologyData.nodes : {};
const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : [];
const nodeIds = Object.keys(nodesData);
if (nodeIds.length === 0) {
@@ -1435,128 +1214,23 @@
};
});
// Add arrowhead definition (supports bidirectional arrows on a single line)
const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs');
const marker = document.createElementNS('http://www.w3.org/2000/svg', 'marker');
marker.setAttribute('id', 'arrowhead');
marker.setAttribute('viewBox', '0 0 10 10');
marker.setAttribute('refX', '10');
marker.setAttribute('refY', '5');
marker.setAttribute('markerWidth', '11');
marker.setAttribute('markerHeight', '11');
marker.setAttribute('orient', 'auto-start-reverse');
// Draw a subtle V-tip (no filled body)
const markerTip = document.createElementNS('http://www.w3.org/2000/svg', 'path');
markerTip.setAttribute('d', 'M 0 0 L 10 5 L 0 10');
markerTip.setAttribute('fill', 'none');
markerTip.setAttribute('stroke', 'var(--exo-light-gray)');
markerTip.setAttribute('stroke-width', '1.6');
markerTip.setAttribute('stroke-linecap', 'round');
markerTip.setAttribute('stroke-linejoin', 'round');
markerTip.setAttribute('stroke-dasharray', 'none');
markerTip.setAttribute('stroke-dashoffset', '0');
markerTip.setAttribute('style', 'animation: none; pointer-events: none;');
marker.appendChild(markerTip);
defs.appendChild(marker);
topologyGraphContainer.appendChild(defs);
// Create groups for links and separate arrow markers (so arrows are not affected by line animations)
// Create group for links (drawn first, so they are behind nodes)
const linksGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
linksGroup.setAttribute('class', 'links-group');
linksGroup.setAttribute('style', 'pointer-events: none;');
const arrowsGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
arrowsGroup.setAttribute('class', 'arrows-group');
arrowsGroup.setAttribute('style', 'pointer-events: none;');
// Build quick lookup for node positions
const positionById = {};
nodesWithPositions.forEach(n => { positionById[n.id] = { x: n.x, y: n.y }; });
// Group directed edges into undirected pairs to support single line with two arrows
const pairMap = new Map(); // key: "a|b" with a<b, value: { a, b, aToB, bToA }
edgesData.forEach(edge => {
if (!edge || !edge.source || !edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
if (edge.source === edge.target) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false };
if (edge.source === a && edge.target === b) entry.aToB = true; else entry.bToA = true;
pairMap.set(key, entry);
});
// Draw one line per undirected pair with separate arrow carrier lines
pairMap.forEach(entry => {
const posA = positionById[entry.a];
const posB = positionById[entry.b];
if (!posA || !posB) return;
// Full-length center-to-center lines
const x1 = posA.x;
const y1 = posA.y;
const x2 = posB.x;
const y2 = posB.y;
// Base animated dashed line (no markers)
const baseLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
baseLine.setAttribute('x1', x1);
baseLine.setAttribute('y1', y1);
baseLine.setAttribute('x2', x2);
baseLine.setAttribute('y2', y2);
baseLine.setAttribute('class', 'graph-link');
linksGroup.appendChild(baseLine);
// Arrowheads centered on the line (tip lies exactly on the line),
// offset along the tangent so opposite directions straddle the center.
const dx = x2 - x1;
const dy = y2 - y1;
const len = Math.hypot(dx, dy) || 1;
const ux = dx / len;
const uy = dy / len;
const mx = (x1 + x2) / 2;
const my = (y1 + y2) / 2;
const tipOffset = 16; // shift arrow tips away from the exact center along the line
const carrier = 2; // short carrier segment length to define orientation
if (entry.aToB) {
// Arrow pointing A -> B: place tip slightly before center along +tangent
const tipX = mx - ux * tipOffset;
const tipY = my - uy * tipOffset;
const sx = tipX - ux * carrier;
const sy = tipY - uy * carrier;
const ex = tipX;
const ey = tipY;
const arrowSeg = document.createElementNS('http://www.w3.org/2000/svg', 'line');
arrowSeg.setAttribute('x1', sx);
arrowSeg.setAttribute('y1', sy);
arrowSeg.setAttribute('x2', ex);
arrowSeg.setAttribute('y2', ey);
arrowSeg.setAttribute('stroke', 'none');
arrowSeg.setAttribute('fill', 'none');
arrowSeg.setAttribute('marker-end', 'url(#arrowhead)');
arrowsGroup.appendChild(arrowSeg);
for (let i = 0; i < numNodes; i++) {
for (let j = i + 1; j < numNodes; j++) {
const link = document.createElementNS('http://www.w3.org/2000/svg', 'line');
link.setAttribute('x1', nodesWithPositions[i].x);
link.setAttribute('y1', nodesWithPositions[i].y);
link.setAttribute('x2', nodesWithPositions[j].x);
link.setAttribute('y2', nodesWithPositions[j].y);
link.setAttribute('class', 'graph-link');
linksGroup.appendChild(link);
}
}
topologyGraphContainer.appendChild(linksGroup);
if (entry.bToA) {
// Arrow pointing B -> A: place tip slightly after center along -tangent
const tipX = mx + ux * tipOffset;
const tipY = my + uy * tipOffset;
const sx = tipX + ux * carrier; // start ahead so the segment points toward tip
const sy = tipY + uy * carrier;
const ex = tipX;
const ey = tipY;
const arrowSeg = document.createElementNS('http://www.w3.org/2000/svg', 'line');
arrowSeg.setAttribute('x1', sx);
arrowSeg.setAttribute('y1', sy);
arrowSeg.setAttribute('x2', ex);
arrowSeg.setAttribute('y2', ey);
arrowSeg.setAttribute('stroke', 'none');
arrowSeg.setAttribute('fill', 'none');
arrowSeg.setAttribute('marker-end', 'url(#arrowhead)');
arrowsGroup.appendChild(arrowSeg);
}
});
// Create group for nodes
const nodesGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
nodesGroup.setAttribute('class', 'nodes-group');
@@ -2064,10 +1738,7 @@
nodesGroup.appendChild(nodeG);
});
// Draw order: lines at the very back, then nodes, then mid-line arrows on top
topologyGraphContainer.appendChild(linksGroup);
topologyGraphContainer.appendChild(nodesGroup);
topologyGraphContainer.appendChild(arrowsGroup);
}
function showNodeDetails(selectedNodeId, allNodesData) {
@@ -2215,22 +1886,13 @@
throw new Error(`HTTP error! status: ${response.status} ${response.statusText}`);
}
const clusterState = await response.json();
const topologyData = transformClusterStateToTopology(clusterState);
// Build nodeId -> friendly name map
nodeIdToFriendlyName = {};
if (topologyData && topologyData.nodes) {
Object.keys(topologyData.nodes).forEach(nid => {
const n = topologyData.nodes[nid];
const name = (n && (n.friendly_name || (n.system_info && n.system_info.model_id))) || null;
if (name) nodeIdToFriendlyName[nid] = name;
});
}
renderNodes(topologyData);
const nodesData = transformClusterStateToTopology(clusterState);
renderNodes(nodesData);
// If a node was selected, and it still exists, refresh its details
if (currentlySelectedNodeId && topologyData.nodes[currentlySelectedNodeId]) {
showNodeDetails(currentlySelectedNodeId, topologyData.nodes);
} else if (currentlySelectedNodeId && !topologyData.nodes[currentlySelectedNodeId]) {
if (currentlySelectedNodeId && nodesData[currentlySelectedNodeId]) {
showNodeDetails(currentlySelectedNodeId, nodesData);
} else if (currentlySelectedNodeId && !nodesData[currentlySelectedNodeId]) {
// If selected node is gone, close panel and clear selection
nodeDetailPanel.classList.remove('visible');
currentlySelectedNodeId = null;
@@ -2276,9 +1938,8 @@
}
function transformClusterStateToTopology(clusterState) {
const resultNodes = {};
const resultEdges = [];
if (!clusterState) return { nodes: resultNodes, edges: resultEdges };
const result = {};
if (!clusterState) return result;
// Helper: get numeric bytes from various shapes (number | {in_bytes}|{inBytes})
function getBytes(value) {
@@ -2298,21 +1959,18 @@
return fallback;
};
// Helper: detect API placeholders like "unknown" (case-insensitive)
const isUnknown = (value) => {
return typeof value === 'string' && value.trim().toLowerCase() === 'unknown';
};
// Process nodes from topology or fallback to node_profiles directly
// Process nodes from topology or fallback to node_profiles/nodeProfiles directly
let nodesToProcess = {};
if (clusterState.topology && Array.isArray(clusterState.topology.nodes)) {
clusterState.topology.nodes.forEach(node => {
if (node.node_id && node.node_profile) {
nodesToProcess[node.node_id] = node.node_profile;
const nid = node.node_id ?? node.nodeId;
const nprof = node.node_profile ?? node.nodeProfile;
if (nid && nprof) {
nodesToProcess[nid] = nprof;
}
});
} else if (clusterState.node_profiles) {
nodesToProcess = clusterState.node_profiles;
} else if (clusterState.node_profiles || clusterState.nodeProfiles) {
nodesToProcess = clusterState.node_profiles ?? clusterState.nodeProfiles;
}
// Transform each node
@@ -2333,15 +1991,10 @@
memBytesAvailable = getBytes(ramAvailVal);
const memBytesUsed = Math.max(memBytesTotal - memBytesAvailable, 0);
// Extract model information with graceful placeholders while node is loading
const rawModelId = pick(nodeProfile, 'model_id', 'modelId', 'Unknown');
const rawChipId = pick(nodeProfile, 'chip_id', 'chipId', '');
const rawFriendlyName = pick(nodeProfile, 'friendly_name', 'friendlyName', `${nodeId.substring(0, 8)}...`);
// When API has not fully loaded (reports "unknown"), present a nice default
const modelId = isUnknown(rawModelId) ? 'Mac Studio' : rawModelId;
const chipId = isUnknown(rawChipId) ? '' : rawChipId;
const friendlyName = (!rawFriendlyName || isUnknown(rawFriendlyName)) ? 'Mac' : rawFriendlyName;
// Extract model information
const modelId = pick(nodeProfile, 'model_id', 'modelId', 'Unknown');
const chipId = pick(nodeProfile, 'chip_id', 'chipId', '');
const friendlyName = pick(nodeProfile, 'friendly_name', 'friendlyName', `${nodeId.substring(0, 8)}...`);
// Extract network addresses (support snake_case and camelCase)
const addrList = [];
@@ -2386,7 +2039,7 @@
timestamp: new Date().toISOString()
};
resultNodes[nodeId] = {
result[nodeId] = {
mem: memBytesTotal,
addrs: addrList,
last_addr_update: Date.now() / 1000,
@@ -2400,21 +2053,7 @@
};
}
// Extract directed edges from topology.connections if present
const connections = clusterState.topology && Array.isArray(clusterState.topology.connections)
? clusterState.topology.connections
: [];
connections.forEach(conn => {
if (!conn) return;
const src = conn.local_node_id ?? conn.localNodeId;
const dst = conn.send_back_node_id ?? conn.sendBackNodeId;
if (!src || !dst) return;
if (!resultNodes[src] || !resultNodes[dst]) return; // only draw edges between known nodes
if (src === dst) return; // skip self loops for now
resultEdges.push({ source: src, target: dst });
});
return { nodes: resultNodes, edges: resultEdges };
return result;
}
// --- Conditional Data Handling ---
@@ -2554,12 +2193,11 @@
mi.timestamp = new Date().toISOString();
}
}
const mockTopology = { nodes: mockData, edges: [] };
renderNodes(mockTopology);
renderNodes(mockData);
lastUpdatedElement.textContent = `Last updated: ${new Date().toLocaleTimeString()} (Mock Data)`;
if (currentlySelectedNodeId && mockData[currentlySelectedNodeId]) {
showNodeDetails(currentlySelectedNodeId, mockTopology.nodes);
showNodeDetails(currentlySelectedNodeId, mockData);
} else if (currentlySelectedNodeId && !mockData[currentlySelectedNodeId]) {
nodeDetailPanel.classList.remove('visible');
currentlySelectedNodeId = null;

View File

@@ -57,13 +57,13 @@
# NIX
nixpkgs-fmt
# JUST
just
]
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
# JUST
just
]);
shellHook = ''

View File

@@ -15,3 +15,8 @@ sync:
sync-clean:
uv sync --all-packages --force-reinstall --no-cache
clean:
rm -rf **/__pycache__
rm -rf rust/target
rm -rf .venv

View File

@@ -36,7 +36,6 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio>=4.10.0",
"bidict>=0.23.1",
"chainlit>=2.8.3",
]
[project.scripts]

View File

@@ -19,8 +19,7 @@ from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
from exo.utils.browser import open_url_in_browser_when_ready
from exo.utils.chainlit_ui import start_chainlit, chainlit_cleanup
# TODO: Entrypoint refactor
# I marked this as a dataclass as I want trivial constructors.
@@ -156,27 +155,17 @@ class Node:
if self.api:
self.api.reset()
def main():
args = Args.parse()
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
node = anyio.run(Node.create, args)
chainlit_proc = (
start_chainlit(args.chainlit_port, args.chainlit_host, args.headless)
if args.with_chainlit
else None
)
if args.spawn_api and not args.headless:
open_url_in_browser_when_ready(f"http://localhost:{args.api_port}")
try:
anyio.run(node.run)
finally:
chainlit_cleanup(chainlit_proc)
logger_cleanup()
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger_cleanup()
class Args(CamelCaseModel):
@@ -185,11 +174,6 @@ class Args(CamelCaseModel):
spawn_api: bool = False
api_port: PositiveInt = 8000
tb_only: bool = False
# Chainlit options
with_chainlit: bool = True
chainlit_port: PositiveInt = 8001
chainlit_host: str = "127.0.0.1"
headless: bool = False
@classmethod
def parse(cls) -> Self:
@@ -232,30 +216,6 @@ class Args(CamelCaseModel):
action="store_true",
dest="tb_only",
)
parser.add_argument(
"--with-chainlit",
action="store_true",
dest="with_chainlit",
default=True,
)
parser.add_argument(
"--chainlit-port",
type=int,
dest="chainlit_port",
default=8001,
)
parser.add_argument(
"--chainlit-host",
type=str,
dest="chainlit_host",
default="127.0.0.1",
)
parser.add_argument(
"--headless",
action="store_true",
dest="headless",
help="Prevents the app from opening in the browser."
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -33,7 +33,6 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
TaggedCommand,
# TODO: SpinUpInstance
TaskFinished,
)
@@ -306,9 +305,7 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for event in events:
if isinstance(event, ChunkGenerated):
logger.info(f"API received ChunkGenerated: {str(event)[:100]}")
self.event_buffer.ingest(event.origin_idx, event.tagged_event.c)
self.event_buffer.ingest(event.origin_idx, event.event)
for idx, event in self.event_buffer.drain_indexed():
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
@@ -319,7 +316,5 @@ class API:
async def _send(self, command: Command):
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id, tagged_command=TaggedCommand.from_(command)
)
ForwarderCommand(origin=self.node_id, command=command)
)

View File

@@ -23,13 +23,12 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
TaggedEvent,
TaskCreated,
TaskDeleted,
TopologyEdgeDeleted,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus
from exo.shared.types.worker.common import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -90,11 +89,9 @@ class Master:
with self.command_receiver as commands:
async for forwarder_command in commands:
try:
logger.info(
f"Executing command: {forwarder_command.tagged_command.c}"
)
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.tagged_command.c
command = forwarder_command.command
match command:
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
@@ -130,11 +127,10 @@ class Master:
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_type=TaskType.CHAT_COMPLETION,
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
@@ -190,13 +186,14 @@ class Master:
async for local_event in local_events:
self._multi_buffer.ingest(
local_event.origin_idx,
local_event.tagged_event.c,
local_event.event,
local_event.origin,
)
for event in self._multi_buffer.drain():
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
# TODO: SQL
self._event_log.append(event)
await self._send_event(indexed)
@@ -224,17 +221,18 @@ class Master:
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
tagged_event=TaggedEvent.from_(event),
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
tagged_event=TaggedEvent.from_(event.event),
event=event.event,
)
)

View File

@@ -88,7 +88,7 @@ def get_instance_placements_after_create(
target_instances = dict(deepcopy(current_instances))
target_instances[instance_id] = Instance(
instance_id=instance_id,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=[
Host(

View File

@@ -2,17 +2,10 @@ import asyncio
import pytest
from exo.master.tests.api_utils_test import (
ChatMessage,
stream_chatgpt_response,
with_master_main,
)
@with_master_main
@pytest.mark.asyncio
async def test_master_api_multiple_response_sequential() -> None:
# TODO: This hangs at the moment it seems.
# TODO
return
messages = [ChatMessage(role="user", content="Hello, who are you?")]
token_count = 0

View File

@@ -1,7 +1,8 @@
import asyncio
from typing import List, Sequence
import anyio
import pytest
from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
@@ -11,7 +12,6 @@ from exo.shared.types.commands import (
CommandId,
CreateInstance,
ForwarderCommand,
TaggedCommand,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
@@ -19,7 +19,6 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
TaggedEvent,
TaskCreated,
)
from exo.shared.types.memory import Memory
@@ -29,9 +28,9 @@ from exo.shared.types.profiling import (
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus
from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
from exo.shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.utils.channels import channel
@@ -46,12 +45,12 @@ async def test_master():
all_events: List[IndexedEvent] = []
async def _get_events() -> Sequence[IndexedEvent]:
def _get_events() -> Sequence[IndexedEvent]:
orig_events = global_event_receiver.collect()
for e in orig_events:
all_events.append(
IndexedEvent(
event=e.tagged_event.c,
event=e.event,
idx=len(all_events), # origin=e.origin,
)
)
@@ -64,133 +63,141 @@ async def test_master():
command_receiver=co_receiver,
tb_only=False,
)
asyncio.create_task(master.run())
logger.info("run the master")
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodePerformanceProfile event
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
tagged_event=TaggedEvent.from_(
NodePerformanceMeasured(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
event=(
NodePerformanceMeasured(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(flops_fp16=0),
),
network_interfaces=[],
system=SystemPerformanceProfile(flops_fp16=0),
),
)
),
)
),
)
)
)
# wait for initial topology event
while len(list(master.state.topology.list_nodes())) == 0:
await asyncio.sleep(0.001)
while len(master.state.node_profiles) == 0:
await asyncio.sleep(0.001)
# wait for initial topology event
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_profiles) == 0:
await anyio.sleep(0.001)
await command_sender.send(
ForwarderCommand(
origin=node_id,
tagged_command=TaggedCommand.from_(
CreateInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
)
),
)
)
while len(master.state.instances.keys()) == 0:
await asyncio.sleep(0.001)
await command_sender.send(
ForwarderCommand(
origin=node_id,
tagged_command=TaggedCommand.from_(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
),
)
),
)
)
while len(await _get_events()) < 3:
await asyncio.sleep(0.001)
events = await _get_events()
assert len(events) == 3
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(events[1].event.instance.shard_assignments.runner_to_shard.keys())[
0
]
assert events[1].event == InstanceCreated(
event_id=events[1].event.event_id,
instance=Instance(
instance_id=events[1].event.instance.instance_id,
instance_type=InstanceStatus.ACTIVE,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
partition_strategy=PartitionStrategy.pipeline,
start_layer=0,
end_layer=16,
n_layers=16,
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
command=(
CreateInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
)
)
logger.info("wait for an instance")
while len(master.state.instances.keys()) == 0:
await anyio.sleep(0.001)
logger.info("inject a ChatCompletion Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
command=(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
),
)
),
)
)
while len(_get_events()) < 3:
await anyio.sleep(0.01)
events = _get_events()
assert len(events) == 3
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event == InstanceCreated(
event_id=events[1].event.event_id,
instance=Instance(
instance_id=events[1].event.instance.instance_id,
instance_type=InstanceStatus.Active,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
hosts=[],
),
hosts=[],
),
)
assert isinstance(events[2].event, TaskCreated)
assert events[2].event == TaskCreated(
event_id=events[2].event.event_id,
task_id=events[2].event.task_id,
task=ChatCompletionTask(
)
assert isinstance(events[2].event, TaskCreated)
assert events[2].event == TaskCreated(
event_id=events[2].event.event_id,
task_id=events[2].event.task_id,
command_id=events[2].event.task.command_id,
task_type=TaskType.CHAT_COMPLETION,
instance_id=events[2].event.task.instance_id,
task_status=TaskStatus.PENDING,
task_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
task=ChatCompletionTask(
task_id=events[2].event.task_id,
command_id=events[2].event.task.command_id,
instance_id=events[2].event.task.instance_id,
task_status=TaskStatus.Pending,
task_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
),
),
),
)
)
await master.shutdown()

View File

@@ -27,7 +27,7 @@ def topology() -> Topology:
def instance() -> Instance:
return Instance(
instance_id=InstanceId(),
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=ShardAssignments(
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),

View File

@@ -104,7 +104,7 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
update: dict[str, TaskStatus | None] = {
"task_status": event.task_status,
}
if event.task_status != TaskStatus.FAILED:
if event.task_status != TaskStatus.Failed:
update["error_type"] = None
update["error_message"] = None
@@ -138,7 +138,7 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State:
return state
updated_instance = state.instances[event.instance_id].model_copy(
update={"instance_type": InstanceStatus.ACTIVE}
update={"instance_type": InstanceStatus.Active}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
@@ -152,7 +152,7 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat
return state
updated_instance = state.instances[event.instance_id].model_copy(
update={"instance_type": InstanceStatus.INACTIVE}
update={"instance_type": InstanceStatus.Inactive}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
@@ -254,21 +254,18 @@ def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> Sta
def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
logger.warning(f"~~~ APPLY Node {event.node_id} created")
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
logger.warning(f"~~~ APPLY Edge {event.edge.local_node_id} -> {event.edge.send_back_node_id} created")
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
logger.warning(f"~~~ APPLY Edge {event.edge.local_node_id} -> {event.edge.send_back_node_id} deleted")
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state

View File

@@ -15,21 +15,6 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# kimi k2
# "kimi-k2:4bit": ModelCard(
# short_id="kimi-k2:4bit",
# model_id="mlx-community/Kimi-K2-Instruct-4bit",
# name="Kimi K2 (4-bit)",
# description="""Kimi K2 is a state-of-the-art mixture-of-experts (MoE) language model with 32 billion activated parameters and 1 trillion total parameters. Trained with the Muon optimizer, Kimi K2 achieves exceptional performance across frontier knowledge, reasoning, and coding tasks while being meticulously optimized for agentic capabilities.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
# pretty_name="Kimi K2 (4-bit)",
# storage_size=Memory.from_kb(536870912),
# n_layers=61,
# ),
# ),
# deepseek v3
"deepseek-v3-0324:4bit": ModelCard(
short_id="deepseek-v3-0324:4bit",
@@ -110,6 +95,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
),
),
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id="mlx-community/Kimi-K2-Instruct-4bit",
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_bytes(577597603840),
n_layers=61,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",

View File

@@ -1,35 +1,30 @@
from enum import Enum
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from exo.shared.openai_compat import FinishReason
from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.utils.pydantic_ext import TaggedModel
class ChunkType(str, Enum):
token = "token"
image = "image"
Token = "Token"
Image = "Image"
class BaseChunk[ChunkTypeT: ChunkType](BaseModel):
chunk_type: ChunkTypeT
class BaseChunk(TaggedModel):
command_id: CommandId
idx: int
model: ModelId
class TokenChunk(BaseChunk[ChunkType.token]):
chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
class ImageChunk(BaseChunk[ChunkType.image]):
chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True)
class ImageChunk(BaseChunk):
data: bytes
GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")]
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,5 +1,4 @@
from enum import Enum
from typing import Union
from pydantic import Field
@@ -7,8 +6,7 @@ from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.common import InstanceId
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_tagged import Tagged, tagged_union
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
# TODO: We need to have a distinction between create instance and spin up instance.
@@ -21,7 +19,7 @@ class CommandType(str, Enum):
RequestEventLog = "RequestEventLog"
class BaseCommand(CamelCaseModel):
class BaseCommand(TaggedModel):
command_id: CommandId = Field(default_factory=CommandId)
@@ -49,30 +47,16 @@ class RequestEventLog(BaseCommand):
since_idx: int
Command = Union[
RequestEventLog,
ChatCompletion,
CreateInstance,
SpinUpInstance,
DeleteInstance,
TaskFinished,
]
@tagged_union(
{
CommandType.ChatCompletion: ChatCompletion,
CommandType.CreateInstance: CreateInstance,
CommandType.SpinUpInstance: SpinUpInstance,
CommandType.DeleteInstance: DeleteInstance,
CommandType.TaskFinished: TaskFinished,
CommandType.RequestEventLog: RequestEventLog,
}
Command = (
RequestEventLog
| ChatCompletion
| CreateInstance
| SpinUpInstance
| DeleteInstance
| TaskFinished
)
class TaggedCommand(Tagged[Command]):
pass
class ForwarderCommand(CamelCaseModel):
origin: NodeId
tagged_command: TaggedCommand
command: Command

View File

@@ -1,11 +1,13 @@
from typing import Self
from uuid import uuid4
from pydantic import BaseModel, GetCoreSchemaHandler, field_validator
from pydantic import GetCoreSchemaHandler, field_validator
from pydantic_core import core_schema
from exo.utils.pydantic_ext import CamelCaseModel
class ID(str):
class Id(str):
def __new__(cls, value: str | None = None) -> Self:
return super().__new__(cls, value or str(uuid4()))
@@ -17,15 +19,15 @@ class ID(str):
return core_schema.str_schema()
class NodeId(ID):
class NodeId(Id):
pass
class CommandId(ID):
class CommandId(Id):
pass
class Host(BaseModel):
class Host(CamelCaseModel):
ip: str
port: int

View File

@@ -1,21 +1,19 @@
from enum import Enum
from typing import Union
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import CommandId, GenerationChunk
from exo.shared.types.common import ID, NodeId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.common import InstanceId, WorkerStatus
from exo.shared.types.worker.instances import Instance
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_tagged import Tagged, tagged_union
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class EventId(ID):
class EventId(Id):
"""
Newtype around `ID`
"""
@@ -60,7 +58,7 @@ class EventType(str, Enum):
TopologyEdgeDeleted = "TopologyEdgeDeleted"
class BaseEvent(CamelCaseModel):
class BaseEvent(TaggedModel):
event_id: EventId = Field(default_factory=EventId)
@@ -145,52 +143,26 @@ class TopologyEdgeDeleted(BaseEvent):
edge: Connection
Event = Union[
TestEvent,
TaskCreated,
TaskStateUpdated,
TaskFailed,
TaskDeleted,
InstanceCreated,
InstanceActivated,
InstanceDeactivated,
InstanceDeleted,
RunnerStatusUpdated,
RunnerDeleted,
NodePerformanceMeasured,
NodeMemoryMeasured,
WorkerStatusUpdated,
ChunkGenerated,
TopologyNodeCreated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
]
@tagged_union(
{
EventType.TestEvent: TestEvent,
EventType.TaskCreated: TaskCreated,
EventType.TaskStateUpdated: TaskStateUpdated,
EventType.TaskFailed: TaskFailed,
EventType.TaskDeleted: TaskDeleted,
EventType.InstanceCreated: InstanceCreated,
EventType.InstanceActivated: InstanceActivated,
EventType.InstanceDeactivated: InstanceDeactivated,
EventType.InstanceDeleted: InstanceDeleted,
EventType.RunnerStatusUpdated: RunnerStatusUpdated,
EventType.RunnerDeleted: RunnerDeleted,
EventType.NodePerformanceMeasured: NodePerformanceMeasured,
EventType.NodeMemoryMeasured: NodeMemoryMeasured,
EventType.WorkerStatusUpdated: WorkerStatusUpdated,
EventType.ChunkGenerated: ChunkGenerated,
EventType.TopologyNodeCreated: TopologyNodeCreated,
EventType.TopologyEdgeCreated: TopologyEdgeCreated,
EventType.TopologyEdgeDeleted: TopologyEdgeDeleted,
}
Event = (
TestEvent
| TaskCreated
| TaskStateUpdated
| TaskFailed
| TaskDeleted
| InstanceCreated
| InstanceActivated
| InstanceDeactivated
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodePerformanceMeasured
| NodeMemoryMeasured
| WorkerStatusUpdated
| ChunkGenerated
| TopologyNodeCreated
| TopologyEdgeCreated
| TopologyEdgeDeleted
)
class TaggedEvent(Tagged[Event]):
pass
class IndexedEvent(CamelCaseModel):
@@ -205,4 +177,4 @@ class ForwarderEvent(CamelCaseModel):
origin_idx: int = Field(ge=0)
origin: NodeId
tagged_event: TaggedEvent
event: Event

View File

@@ -1,11 +1,11 @@
from pydantic import PositiveInt
from exo.shared.types.common import ID
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(ID):
class ModelId(Id):
pass

View File

@@ -1,25 +1,19 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import ConfigDict, Field, field_validator, field_serializer
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.common import InstanceId, WorkerStatus
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.instances import Instance
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.pydantic_ext import CamelCaseModel
def _encode_topology(topo: "Topology") -> dict[str, Any]: # noqa: D401
"""Serialise *topo* into a JSON-compatible dict."""
return topo.to_snapshot().model_dump()
class State(BaseModel):
class State(CamelCaseModel):
"""Global system state.
The :class:`Topology` instance is encoded/decoded via an immutable
@@ -29,9 +23,6 @@ class State(BaseModel):
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={
Topology: _encode_topology,
},
)
node_status: Mapping[NodeId, WorkerStatus] = {}
instances: Mapping[InstanceId, Instance] = {}
@@ -40,10 +31,12 @@ class State(BaseModel):
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
topology: Topology = Topology()
history: Sequence[Topology] = []
# TODO: we want information about every model that is downloaded on each node
node_downloads: Mapping[NodeId, Mapping[str, DownloadProgressData]] = {}
last_event_applied_idx: int = Field(default=-1, ge=-1)
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:
return value.to_snapshot()
@field_validator("topology", mode="before")
@classmethod
def _deserialize_topology(cls, value: object) -> Topology: # noqa: D401 Pydantic validator signature

View File

@@ -1,30 +1,25 @@
from enum import Enum
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import ID, CommandId
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.common import InstanceId
from exo.utils.pydantic_ext import TaggedModel
class TaskId(ID):
class TaskId(Id):
pass
class TaskType(str, Enum):
CHAT_COMPLETION = "CHAT_COMPLETION"
class TaskStatus(str, Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
COMPLETE = "COMPLETE"
FAILED = "FAILED"
Pending = "Pending"
Running = "Running"
Complete = "Complete"
Failed = "Failed"
class ChatCompletionTask(BaseModel):
task_type: Literal[TaskType.CHAT_COMPLETION] = TaskType.CHAT_COMPLETION
class ChatCompletionTask(TaggedModel):
task_id: TaskId
command_id: CommandId
instance_id: InstanceId
@@ -35,4 +30,4 @@ class ChatCompletionTask(BaseModel):
error_message: str | None = Field(default=None)
Task = Annotated[ChatCompletionTask, Field(discriminator="task_type")]
Task = ChatCompletionTask

View File

@@ -1,116 +1,69 @@
from enum import Enum
from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter
from exo.shared.openai_compat import FinishReason
from exo.shared.types.common import Host
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import TaggedModel
## Messages passed TO the runner
class MessageType(str, Enum):
Setup = "setup"
ChatTask = "chat_task"
Exit = "exit"
class BaseRunnerMessage[MT: MessageType](BaseModel):
class BaseRunnerMessage(TaggedModel):
pass
class SetupMessage(BaseRunnerMessage[MessageType.Setup]):
type: Literal[MessageType.Setup] = Field(default=MessageType.Setup, frozen=True)
class SetupMessage(BaseRunnerMessage):
model_shard_meta: ShardMetadata
hosts: list[Host]
# TODO: We probably want a general task message that can take any task type. Can be fixed later.
class ChatTaskMessage(BaseRunnerMessage[MessageType.ChatTask]):
type: Literal[MessageType.ChatTask] = Field(
default=MessageType.ChatTask, frozen=True
)
class ChatTaskMessage(BaseRunnerMessage):
task_data: ChatCompletionTaskParams
class ExitMessage(BaseRunnerMessage[MessageType.Exit]):
type: Literal[MessageType.Exit] = Field(default=MessageType.Exit, frozen=True)
RunnerMessage = Annotated[
SetupMessage | ChatTaskMessage | ExitMessage, Field(discriminator="type")
]
RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage)
## Responses passed FROM the runner
class RunnerResponseType(str, Enum):
InitializedResponse = "initialized_response"
TokenizedResponse = "tokenized_response"
GenerationResponse = "generation_response"
FinishedResponse = "finished_response"
PrintResponse = "print_response"
ErrorResponse = "error_response"
class BaseRunnerResponse[RRT: RunnerResponseType](BaseModel):
class ExitMessage(BaseRunnerMessage):
pass
class InitializedResponse(BaseRunnerResponse[RunnerResponseType.InitializedResponse]):
type: Literal[RunnerResponseType.InitializedResponse] = Field(
default=RunnerResponseType.InitializedResponse, frozen=True
)
RunnerMessage = SetupMessage | ChatTaskMessage | ExitMessage
class BaseRunnerResponse(TaggedModel):
pass
class InitializedResponse(BaseRunnerResponse):
time_taken: float
class TokenizedResponse(BaseRunnerResponse[RunnerResponseType.TokenizedResponse]):
type: Literal[RunnerResponseType.TokenizedResponse] = Field(
default=RunnerResponseType.TokenizedResponse, frozen=True
)
class TokenizedResponse(BaseRunnerResponse):
prompt_tokens: int
class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]):
type: Literal[RunnerResponseType.GenerationResponse] = Field(
default=RunnerResponseType.GenerationResponse, frozen=True
)
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: Optional[list[float]] = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
class PrintResponse(BaseRunnerResponse[RunnerResponseType.PrintResponse]):
type: Literal[RunnerResponseType.PrintResponse] = Field(
default=RunnerResponseType.PrintResponse, frozen=True
)
class PrintResponse(BaseRunnerResponse):
text: str
class FinishedResponse(BaseRunnerResponse[RunnerResponseType.FinishedResponse]):
type: Literal[RunnerResponseType.FinishedResponse] = Field(
default=RunnerResponseType.FinishedResponse, frozen=True
)
class FinishedResponse(BaseRunnerResponse):
pass
class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]):
type: Literal[RunnerResponseType.ErrorResponse] = Field(
default=RunnerResponseType.ErrorResponse, frozen=True
)
class ErrorResponse(BaseRunnerResponse):
error_type: str
error_message: str
traceback: str
RunnerResponse = Annotated[
RunnerResponse = (
InitializedResponse
| TokenizedResponse
| GenerationResponse
| PrintResponse
| FinishedResponse
| ErrorResponse,
Field(discriminator="type"),
]
RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse)
| ErrorResponse
)

View File

@@ -1,13 +1,13 @@
from enum import Enum
from exo.shared.types.common import ID
from exo.shared.types.common import Id
class InstanceId(ID):
class InstanceId(Id):
pass
class RunnerId(ID):
class RunnerId(Id):
pass

View File

@@ -9,7 +9,6 @@ from exo.shared.types.worker.commands_runner import (
PrintResponse,
RunnerMessage,
RunnerResponse,
RunnerResponseType,
)
### Utils - Runner Prints
@@ -17,7 +16,6 @@ from exo.shared.types.worker.commands_runner import (
def runner_print(text: str) -> None:
obj = PrintResponse(
type=RunnerResponseType.PrintResponse,
text=text,
)
@@ -27,7 +25,6 @@ def runner_print(text: str) -> None:
def runner_write_error(error: Exception) -> None:
error_response: ErrorResponse = ErrorResponse(
type=RunnerResponseType.ErrorResponse,
error_type=type(error).__name__,
error_message=str(error),
traceback=traceback.format_exc(),

View File

@@ -1,73 +1,33 @@
from enum import Enum
from typing import (
Annotated,
Literal,
Union,
)
from pydantic import Field
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
total_bytes: Memory
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
completed_files: int
total_files: int
speed: float
eta_ms: int
files: dict[str, "DownloadProgressData"]
class DownloadStatus(str, Enum):
Pending = "Pending"
Downloading = "Downloading"
Completed = "Completed"
Failed = "Failed"
class BaseDownloadProgress[DownloadStatusT: DownloadStatus](CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
download_status: DownloadStatusT
class DownloadPending(BaseDownloadProgress[DownloadStatus.Pending]):
download_status: Literal[DownloadStatus.Pending] = Field(
default=DownloadStatus.Pending
)
class DownloadPending(BaseDownloadProgress):
pass
class DownloadCompleted(BaseDownloadProgress[DownloadStatus.Completed]):
download_status: Literal[DownloadStatus.Completed] = Field(
default=DownloadStatus.Completed
)
class DownloadCompleted(BaseDownloadProgress):
pass
class DownloadFailed(BaseDownloadProgress[DownloadStatus.Failed]):
download_status: Literal[DownloadStatus.Failed] = Field(
default=DownloadStatus.Failed
)
class DownloadFailed(BaseDownloadProgress):
error_message: str
class DownloadOngoing(BaseDownloadProgress[DownloadStatus.Downloading]):
download_status: Literal[DownloadStatus.Downloading] = Field(
default=DownloadStatus.Downloading
)
class DownloadOngoing(BaseDownloadProgress):
download_progress: DownloadProgressData
DownloadProgress = Annotated[
Union[
DownloadPending,
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
],
Field(discriminator="download_status"),
]
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)

View File

@@ -1,20 +1,19 @@
from enum import Enum
from pydantic import BaseModel
from exo.shared.types.common import Host
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.runners import (
ShardAssignments,
)
from exo.utils.pydantic_ext import CamelCaseModel
class InstanceStatus(str, Enum):
ACTIVE = "ACTIVE"
INACTIVE = "INACTIVE"
Active = "Active"
Inactive = "Inactive"
class Instance(BaseModel):
class Instance(CamelCaseModel):
instance_id: InstanceId
instance_type: InstanceStatus
shard_assignments: ShardAssignments

View File

@@ -1,86 +1,49 @@
from enum import Enum
from typing import Annotated, Generic, Literal, TypeVar, Union
from pydantic import BaseModel, Field
from exo.shared.types.common import Host
from exo.shared.types.events import InstanceId
from exo.shared.types.tasks import Task
from exo.shared.types.worker.common import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import TaggedModel
class RunnerOpType(str, Enum):
ASSIGN_RUNNER = "assign_runner"
UNASSIGN_RUNNER = "unassign_runner"
RUNNER_UP = "runner_up"
RUNNER_DOWN = "runner_down"
RUNNER_FAILED = "runner_failed"
CHAT_COMPLETION = "chat_completion"
class BaseRunnerOp(TaggedModel):
pass
RunnerOpT = TypeVar("RunnerOpT", bound=RunnerOpType)
class BaseRunnerOp(BaseModel, Generic[RunnerOpT]):
op_type: RunnerOpT
class AssignRunnerOp(BaseRunnerOp[Literal[RunnerOpType.ASSIGN_RUNNER]]):
op_type: Literal[RunnerOpType.ASSIGN_RUNNER] = Field(
default=RunnerOpType.ASSIGN_RUNNER, frozen=True
)
class AssignRunnerOp(BaseRunnerOp):
instance_id: InstanceId
runner_id: RunnerId
shard_metadata: ShardMetadata
hosts: list[Host]
class UnassignRunnerOp(BaseRunnerOp[Literal[RunnerOpType.UNASSIGN_RUNNER]]):
op_type: Literal[RunnerOpType.UNASSIGN_RUNNER] = Field(
default=RunnerOpType.UNASSIGN_RUNNER, frozen=True
)
class UnassignRunnerOp(BaseRunnerOp):
runner_id: RunnerId
class RunnerUpOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_UP]]):
op_type: Literal[RunnerOpType.RUNNER_UP] = Field(
default=RunnerOpType.RUNNER_UP, frozen=True
)
class RunnerUpOp(BaseRunnerOp):
runner_id: RunnerId
class RunnerDownOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_DOWN]]):
op_type: Literal[RunnerOpType.RUNNER_DOWN] = Field(
default=RunnerOpType.RUNNER_DOWN, frozen=True
)
class RunnerDownOp(BaseRunnerOp):
runner_id: RunnerId
class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]):
op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(
default=RunnerOpType.RUNNER_FAILED, frozen=True
)
class RunnerFailedOp(BaseRunnerOp):
runner_id: RunnerId
class ExecuteTaskOp(BaseRunnerOp[Literal[RunnerOpType.CHAT_COMPLETION]]):
op_type: Literal[RunnerOpType.CHAT_COMPLETION] = Field(
default=RunnerOpType.CHAT_COMPLETION, frozen=True
)
class ExecuteTaskOp(BaseRunnerOp):
runner_id: RunnerId
task: Task
# Aggregate all runner operations into a single, strictly-typed union for dispatching.
RunnerOp = Annotated[
Union[
AssignRunnerOp,
UnassignRunnerOp,
RunnerUpOp,
RunnerDownOp,
RunnerFailedOp,
ExecuteTaskOp,
],
Field(discriminator="op_type"),
]
RunnerOp = (
AssignRunnerOp
| UnassignRunnerOp
| RunnerUpOp
| RunnerDownOp
| RunnerFailedOp
| ExecuteTaskOp
)

View File

@@ -1,80 +1,54 @@
from collections.abc import Mapping
from enum import Enum
from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from pydantic import model_validator
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.common import RunnerId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class RunnerStatusType(str, Enum):
Downloading = "Downloading"
Inactive = "Inactive"
Starting = "Starting"
Loaded = "Loaded"
Running = "Running"
Failed = "Failed"
class BaseRunnerStatus(TaggedModel):
pass
class BaseRunnerStatus[T: RunnerStatusType](BaseModel):
runner_status: T
class DownloadingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Downloading]):
runner_status: Literal[RunnerStatusType.Downloading] = Field(
default=RunnerStatusType.Downloading
)
class DownloadingRunnerStatus(BaseRunnerStatus):
download_progress: DownloadProgress
class InactiveRunnerStatus(BaseRunnerStatus[RunnerStatusType.Inactive]):
runner_status: Literal[RunnerStatusType.Inactive] = Field(
default=RunnerStatusType.Inactive
)
class InactiveRunnerStatus(BaseRunnerStatus):
pass
class StartingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Starting]):
runner_status: Literal[RunnerStatusType.Starting] = Field(
default=RunnerStatusType.Starting
)
class StartingRunnerStatus(BaseRunnerStatus):
pass
class LoadedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Loaded]):
runner_status: Literal[RunnerStatusType.Loaded] = Field(
default=RunnerStatusType.Loaded
)
class LoadedRunnerStatus(BaseRunnerStatus):
pass
class RunningRunnerStatus(BaseRunnerStatus[RunnerStatusType.Running]):
runner_status: Literal[RunnerStatusType.Running] = Field(
default=RunnerStatusType.Running
)
class RunningRunnerStatus(BaseRunnerStatus):
pass
class FailedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Failed]):
runner_status: Literal[RunnerStatusType.Failed] = Field(
default=RunnerStatusType.Failed
)
class FailedRunnerStatus(BaseRunnerStatus):
error_message: str | None = None
RunnerStatus = Annotated[
RunnerStatus = (
DownloadingRunnerStatus
| InactiveRunnerStatus
| StartingRunnerStatus
| LoadedRunnerStatus
| RunningRunnerStatus
| FailedRunnerStatus,
Field,
]
RunnerStatusParser: TypeAdapter[RunnerStatus] = TypeAdapter(RunnerStatus)
| FailedRunnerStatus
)
class ShardAssignments(BaseModel):
class ShardAssignments(CamelCaseModel):
model_id: ModelId
runner_to_shard: Mapping[RunnerId, ShardMetadata]
node_to_runner: Mapping[NodeId, RunnerId]

View File

@@ -1,39 +1,26 @@
from enum import Enum
from typing import Annotated, Generic, Literal, Optional, TypeVar
from pydantic import Field
from pydantic import BaseModel, Field, TypeAdapter
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel
class PartitionStrategy(str, Enum):
pipeline = "pipeline"
PartitionStrategyT = TypeVar(
"PartitionStrategyT", bound=PartitionStrategy, covariant=True
)
class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
class BaseShardMetadata(TaggedModel):
"""
Defines a specific shard of the model that is ready to be run on a device.
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
partition_strategy: PartitionStrategyT
device_rank: int
world_size: int
# Error handling; equivalent to monkey-patch, but we can't monkey-patch runner.py
# This is kinda annoying because it allocates memory in the ShardMetadata object. Can be rethought after Shanghai.
immediate_exception: bool = False
should_timeout: Optional[float] = None
should_timeout: float | None = None
class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
class PipelineShardMetadata(BaseShardMetadata):
"""
Pipeline parallelism shard meta.
@@ -41,12 +28,9 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
where start_layer is inclusive and end_layer is exclusive.
"""
partition_strategy: Literal[PartitionStrategy.pipeline] = Field(
default=PartitionStrategy.pipeline, frozen=True
)
start_layer: Annotated[int, Field(ge=0)]
end_layer: Annotated[int, Field(ge=0)]
n_layers: Annotated[int, Field(ge=0)]
start_layer: int = Field(ge=0)
end_layer: int = Field(ge=0)
n_layers: int = Field(ge=0)
@property
def is_first_layer(self) -> bool:
@@ -62,17 +46,4 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
)
ShardMetadata = Annotated[
PipelineShardMetadata, Field(discriminator="partition_strategy")
]
ShardMetadataParser: TypeAdapter[ShardMetadata] = TypeAdapter(ShardMetadata)
class ShardPlacement(BaseModel, Generic[PartitionStrategyT]):
"""
A shard placement is the description of a model distributed across a set of nodes.
The Generic[PartitionStrategyT] enforces that the shard assignments all use the same partition strategy.
"""
model_id: ModelId
shard_assignments: dict[NodeId, BaseShardMetadata[PartitionStrategyT]]
ShardMetadata = PipelineShardMetadata

View File

@@ -9,9 +9,9 @@ class OrderedBuffer[T]:
source at a time.
"""
def __init__(self, start_idx: int = 0):
def __init__(self):
self.store: dict[int, T] = {}
self.next_idx_to_release: int = start_idx
self.next_idx_to_release: int = 0
def ingest(self, idx: int, t: T):
"""Ingest a sequence into the buffer"""
@@ -19,6 +19,9 @@ class OrderedBuffer[T]:
if idx < self.next_idx_to_release:
return
if idx in self.store:
assert self.store[idx] == t, (
"Received different messages with identical indices, probable race condition"
)
return
self.store[idx] = t
@@ -56,15 +59,8 @@ class MultiSourceBuffer[SourceId, T]:
def ingest(self, idx: int, t: T, source: SourceId):
if source not in self.stores:
# Seed the per-source buffer to start at the first observed index for that source.
self.stores[source] = OrderedBuffer(start_idx=idx)
self.stores[source] = OrderedBuffer()
buffer = self.stores[source]
# Handle per-source sequence reset (e.g., worker restart resetting its local index to 0).
# If we observe idx == 0 from an existing source with a higher expected index,
# reset that source's buffer to accept the new sequence.
if idx == 0 and buffer.next_idx_to_release > 0:
self.stores[source] = OrderedBuffer(start_idx=0)
buffer = self.stores[source]
buffer.ingest(idx, t)
def drain(self) -> list[T]:

View File

@@ -1,5 +1,13 @@
from pydantic import BaseModel, ConfigDict
# pyright: reportAny=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
from typing import Any, Self
from pydantic import BaseModel, ConfigDict, model_serializer, model_validator
from pydantic.alias_generators import to_camel
from pydantic_core.core_schema import (
SerializerFunctionWrapHandler,
ValidatorFunctionWrapHandler,
)
class CamelCaseModel(BaseModel):
@@ -12,5 +20,20 @@ class CamelCaseModel(BaseModel):
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
# strict=True,
strict=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):
inner = handler(self)
return {self.__class__.__name__: inner}
@model_validator(mode="wrap")
@classmethod
def _validate(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> Self:
if isinstance(v, dict) and len(v) == 1 and cls.__name__ in v:
return handler(v[cls.__name__])
return handler(v)

View File

@@ -1,9 +1,8 @@
from typing import Union
import anyio
import pytest
from pydantic import BaseModel, TypeAdapter, ValidationError
from exo.utils.pydantic_tagged import Tagged, tagged_union # ← CHANGE ME
from exo.utils.pydantic_ext import TaggedModel
def test_plain_union_prefers_first_member_when_shapes_are_identical():
@@ -22,161 +21,230 @@ def test_plain_union_prefers_first_member_when_shapes_are_identical():
def test_tagged_union_serializes_and_deserializes_two_identical_shapes_correctly():
class Foo1(BaseModel):
class Foo1(TaggedModel):
x: int
class Foo2(BaseModel):
class Foo2(TaggedModel):
x: int
foos = Union[Foo1, Foo2]
t1 = Foo1(x=1)
assert t1.model_dump() == {"Foo1": {"x": 1}}
@tagged_union({"Foo1": Foo1, "Foo2": Foo2})
class TaggedFoos(Tagged[foos]):
pass
# ---- serialize (via custom model_serializer) ----
t1 = TaggedFoos.from_(Foo1(x=1))
assert t1.model_dump() == {"t": "Foo1", "c": {"x": 1}}
t2 = TaggedFoos.from_(Foo2(x=2))
assert t2.model_dump() == {"t": "Foo2", "c": {"x": 2}}
t2 = Foo2(x=2)
assert t2.model_dump() == {"Foo2": {"x": 2}}
# ---- deserialize (TypeAdapter -> model_validator(before)) ----
ta = TypeAdapter(TaggedFoos)
ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)
out1 = ta.validate_python({"t": "Foo1", "c": {"x": 10}})
assert isinstance(out1.c, Foo1) and out1.c.x == 10
out1 = ta.validate_python({"Foo1": {"x": 10}})
assert isinstance(out1, Foo1) and out1.x == 10
out2 = ta.validate_python({"t": "Foo2", "c": {"x": 20}})
assert isinstance(out2.c, Foo2) and out2.c.x == 20
out2 = ta.validate_python({"Foo2": {"x": 20}})
assert isinstance(out2, Foo2) and out2.x == 20
def test_tagged_union_rejects_unknown_tag():
class Foo1(BaseModel):
class Foo1(TaggedModel):
x: int
class Foo2(BaseModel):
class Foo2(TaggedModel):
x: int
foos = Union[Foo1, Foo2]
@tagged_union({"Foo1": Foo1, "Foo2": Foo2})
class TaggedFoos(Tagged[foos]):
pass
ta = TypeAdapter(TaggedFoos)
ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)
with pytest.raises(ValidationError):
ta.validate_python({"t": "NotARealTag", "c": {"x": 0}})
def test_multiple_tagged_classes_do_not_override_each_others_mappings():
"""
Creating a *new* Tagged[T] class must not mutate the previously defined one.
This checks both the tag mapping and the per-class adapter dicts.
"""
class Foo1(BaseModel):
x: int
class Foo2(BaseModel):
x: int
foos = Union[Foo1, Foo2]
@tagged_union({"One": Foo1, "Two": Foo2})
class TaggedEN(Tagged[foos]):
pass
# Sanity: initial mapping/behavior
obj_en_1 = TaggedEN.from_(Foo1(x=5))
assert obj_en_1.t == "One"
obj_en_2 = TaggedEN.from_(Foo2(x=6))
assert obj_en_2.t == "Two"
# Define a second, different mapping
@tagged_union({"Uno": Foo1, "Dos": Foo2})
class TaggedES(Tagged[foos]):
pass
# The two classes should have *independent* mappings
# (not the same object, and not equal content)
assert TaggedEN._type_bidict is not TaggedES._type_bidict # pyright: ignore
assert TaggedEN._type_bidict != TaggedES._type_bidict # pyright: ignore
# Their adapters dicts should also be distinct objects
assert TaggedEN._adapter_dict is not TaggedES._adapter_dict # pyright: ignore
# And both should cover the same set of member types
assert set(TaggedEN._adapter_dict.keys()) == {Foo1, Foo2} # pyright: ignore
assert set(TaggedES._adapter_dict.keys()) == {Foo1, Foo2} # pyright: ignore
# Re-check that EN behavior has NOT changed after ES was created
obj_en_1_again = TaggedEN.from_(Foo1(x=7))
obj_en_2_again = TaggedEN.from_(Foo2(x=8))
assert obj_en_1_again.t == "One"
assert obj_en_2_again.t == "Two"
# ES behavior is per its *own* mapping
obj_es_1 = TaggedES.from_(Foo1(x=9))
obj_es_2 = TaggedES.from_(Foo2(x=10))
assert obj_es_1.t == "Uno"
assert obj_es_2.t == "Dos"
# And deserialization respects each class's mapping independently
ta_en = TypeAdapter(TaggedEN)
ta_es = TypeAdapter(TaggedES)
out_en = ta_en.validate_python({"t": "Two", "c": {"x": 123}})
assert isinstance(out_en.c, Foo2) and out_en.c.x == 123
out_es = ta_es.validate_python({"t": "Dos", "c": {"x": 456}})
assert isinstance(out_es.c, Foo2) and out_es.c.x == 456
ta.validate_python({"NotARealTag": {"x": 0}})
def test_two_tagged_classes_with_different_shapes_are_independent_and_not_cross_deserializable():
class A1(BaseModel):
class A1(TaggedModel):
x: int
class A2(BaseModel):
class A2(TaggedModel):
name: str
union_a = Union[A1, A2]
@tagged_union({"One": A1, "Two": A2})
class TaggedA(Tagged[union_a]):
pass
class B1(BaseModel):
class B1(TaggedModel):
name: str
class B2(BaseModel):
class B2(TaggedModel):
active: bool
union_b = Union[B1, B2]
a_payload = A1(x=123).model_dump()
b_payload = B1(name="neo").model_dump()
# Note: using the SAME tag strings intentionally to ensure mappings are per-class
@tagged_union({"One": B1, "Two": B2})
class TaggedB(Tagged[union_b]):
pass
assert a_payload == {"A1": {"x": 123}}
assert b_payload == {"B1": {"name": "neo"}}
# --- Per-class state must be independent ---
assert TaggedA._type_bidict is not TaggedB._type_bidict # pyright: ignore
assert TaggedA._adapter_dict is not TaggedB._adapter_dict # pyright: ignore
assert set(TaggedA._adapter_dict.keys()) == {A1, A2} # pyright: ignore
assert set(TaggedB._adapter_dict.keys()) == {B1, B2} # pyright: ignore
# --- Round-trip for each class with overlapping tag strings ---
a_payload = TaggedA.from_(A1(x=123)).model_dump()
b_payload = TaggedB.from_(B1(name="neo")).model_dump()
assert a_payload == {"t": "One", "c": {"x": 123}}
assert b_payload == {"t": "One", "c": {"name": "neo"}}
# --- Cross-deserialization must fail despite overlapping "t" values ---
ta_a = TypeAdapter(TaggedA)
ta_b = TypeAdapter(TaggedB)
ta_a = TypeAdapter[A1 | A2](A1 | A2)
ta_b = TypeAdapter[B1 | B2](B1 | B2)
with pytest.raises(ValidationError):
ta_a.validate_python(b_payload) # TaggedA expects {"x": ...} for tag "One"
ta_a.validate_python(b_payload)
with pytest.raises(ValidationError):
ta_b.validate_python(a_payload) # TaggedB expects {"name": ...} for tag "One"
ta_b.validate_python(a_payload)
class Inner(TaggedModel):
x: int
class Outer(TaggedModel):
inner: Inner
class Wrapper(TaggedModel):
outer: Outer
label: str
class Container(TaggedModel):
items: list[Inner]
nested: Wrapper
def test_single_level_tagging():
inner = Inner(x=10)
dumped = inner.model_dump()
assert dumped == {"Inner": {"x": 10}}
restored = Inner.model_validate(dumped)
assert isinstance(restored, Inner)
assert restored.x == 10
def test_nested_externally_tagged_union_serializes_recursively():
outer = Outer(inner=Inner(x=42))
dumped = outer.model_dump()
assert dumped == {"Outer": {"inner": {"Inner": {"x": 42}}}}
restored = Outer.model_validate(dumped)
assert isinstance(restored.inner, Inner)
assert restored.inner.x == 42
def test_two_level_nested_tagging():
outer = Outer(inner=Inner(x=123))
dumped = outer.model_dump()
assert dumped == {"Outer": {"inner": {"Inner": {"x": 123}}}}
restored = Outer.model_validate(dumped)
assert isinstance(restored.inner, Inner)
assert restored.inner.x == 123
def test_three_level_nested_tagging():
wrapper = Wrapper(label="deep", outer=Outer(inner=Inner(x=7)))
dumped = wrapper.model_dump()
# 3-level structure, each with exactly one tag
assert dumped == {
"Wrapper": {
"label": "deep",
"outer": {"Outer": {"inner": {"Inner": {"x": 7}}}},
}
}
restored = Wrapper.model_validate(dumped)
assert isinstance(restored.outer.inner, Inner)
assert restored.outer.inner.x == 7
assert restored.label == "deep"
def test_lists_and_mixed_nested_structures():
container = Container(
items=[Inner(x=1), Inner(x=2)],
nested=Wrapper(label="mix", outer=Outer(inner=Inner(x=9))),
)
dumped = container.model_dump()
assert dumped == {
"Container": {
"items": [
{"Inner": {"x": 1}},
{"Inner": {"x": 2}},
],
"nested": {
"Wrapper": {
"label": "mix",
"outer": {"Outer": {"inner": {"Inner": {"x": 9}}}},
}
},
}
}
restored = Container.model_validate(dumped)
assert isinstance(restored.nested.outer.inner, Inner)
assert [i.x for i in restored.items] == [1, 2]
def test_no_double_tagging_on_repeated_calls():
"""Ensure multiple model_dump calls don't stack tags."""
inner = Inner(x=11)
dumped1 = inner.model_dump()
dumped2 = inner.model_dump()
assert dumped1 == dumped2 == {"Inner": {"x": 11}}
outer = Outer(inner=inner)
d1 = outer.model_dump()
d2 = outer.model_dump()
assert d1 == d2 == {"Outer": {"inner": {"Inner": {"x": 11}}}}
class L3A(TaggedModel):
x: int
class L3B(TaggedModel):
x: int
class L3C(TaggedModel):
x: int
L3 = L3A | L3B | L3C
class L2A(TaggedModel):
child: L3
class L2B(TaggedModel):
child: L3
class L2C(TaggedModel):
child: L3
L2 = L2A | L2B | L2C
class L1A(TaggedModel):
child: L2
class L1B(TaggedModel):
child: L2
class L1C(TaggedModel):
child: L2
L1 = L1A | L1B | L1C
@pytest.mark.anyio
async def test_tagged_union_is_fast():
# payload along the "C" path (worst case for DFS if branches are tried A->B->C)
payload = {"L1C": {"child": {"L2C": {"child": {"L3C": {"x": 123}}}}}}
with anyio.fail_after(0.1):
out = TypeAdapter(L1).validate_python(payload) # type: ignore
# Sanity check the result
assert out.__class__.__name__ == "L1C" # type: ignore
assert out.child.__class__.__name__ == "L2C" # type: ignore
assert out.child.child.__class__.__name__ == "L3C" # type: ignore
assert out.child.child.x == 123 # type: ignore

View File

@@ -12,12 +12,9 @@ from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
from loguru import logger
from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter
from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter, ConfigDict
from exo.shared.constants import EXO_HOME
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -56,8 +53,7 @@ class RepoFileDownloadProgress(BaseModel):
status: Literal["not_started", "in_progress", "complete"]
start_time: float
class Config:
frozen = True
model_config = ConfigDict(frozen = True)
class RepoDownloadProgress(BaseModel):
@@ -91,31 +87,10 @@ class RepoDownloadProgress(BaseModel):
# fine-grained file progress keyed by file_path
file_progress: Dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
class Config:
model_config = ConfigDict(
frozen = True # allow use as dict keys if desired
)
def map_repo_file_download_progress_to_download_progress_data(repo_file_download_progress: RepoFileDownloadProgress) -> DownloadProgressData:
return DownloadProgressData(
downloaded_bytes=Memory.from_bytes(repo_file_download_progress.downloaded),
downloaded_bytes_this_session=Memory.from_bytes(repo_file_download_progress.downloaded_this_session),
total_bytes=Memory.from_bytes(repo_file_download_progress.total),
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1,
speed=repo_file_download_progress.speed,
eta_ms=int(repo_file_download_progress.eta.total_seconds() * 1000),
files={},
)
def map_repo_download_progress_to_download_progress_data(repo_download_progress: RepoDownloadProgress) -> DownloadProgressData:
return DownloadProgressData(
total_bytes=Memory.from_bytes(repo_download_progress.total_bytes),
downloaded_bytes=Memory.from_bytes(repo_download_progress.downloaded_bytes),
downloaded_bytes_this_session=Memory.from_bytes(repo_download_progress.downloaded_bytes_this_session),
completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed,
eta_ms=int(repo_download_progress.overall_eta.total_seconds() * 1000),
files={file_path: map_repo_file_download_progress_to_download_progress_data(file_progress) for file_path, file_progress in repo_download_progress.file_progress.items()},
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_HOME / "models" / model_id.replace("/", "--")
@@ -166,13 +141,13 @@ async def seed_models(seed_dir: Union[str, Path]):
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if await aios.path.exists(dest_path):
logger.info("Skipping moving model to .cache directory")
print("Skipping moving model to .cache directory")
else:
try:
await aios.rename(str(path), str(dest_path))
except Exception:
logger.error(f"Error seeding model {path} to {dest_path}")
logger.error(traceback.format_exc())
print(f"Error seeding model {path} to {dest_path}")
traceback.print_exc()
async def fetch_file_list_with_cache(
@@ -262,13 +237,9 @@ async def file_meta(
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
# Ensure identity transfer to keep Content-Length and byte accounting
# consistent with on-disk sizes and progress totals.
headers = {**(await get_auth_headers()), "Accept-Encoding": "identity"}
headers = await get_auth_headers()
async with (
aiohttp.ClientSession(
# Disable transparent decompression; we want raw bytes as served.
auto_decompress=False,
timeout=aiohttp.ClientTimeout(
total=1800, connect=60, sock_read=1800, sock_connect=60
)
@@ -276,18 +247,22 @@ async def file_meta(
session.head(url, headers=headers) as r,
):
if r.status == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("X-Linked-ETag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = x_linked_etag
# Try to extract from X-Linked headers first (common for HF redirects)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
etag = (
r.headers.get("X-Linked-ETag")
or r.headers.get("ETag")
or r.headers.get("Etag")
)
if content_length > 0 and etag is not None:
if (etag[0] == '"' and etag[-1] == '"') or (
etag[0] == "'" and etag[-1] == "'"
):
etag = etag[1:-1]
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
# If not available, recurse with the redirect
redirected_location = r.headers.get("Location")
return await file_meta(repo_id, revision, path, redirected_location)
content_length = int(
@@ -321,10 +296,10 @@ async def download_file_with_retry(
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
print(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
traceback.print_exc()
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
@@ -351,15 +326,12 @@ async def _download_file(
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
# Request identity encoding so received byte counts match on-disk size
headers = {**(await get_auth_headers()), "Accept-Encoding": "identity"}
headers = await get_auth_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
n_read = resume_byte_pos or 0
async with (
aiohttp.ClientSession(
# Keep raw transfer semantics (no transparent decompression)
auto_decompress=False,
timeout=aiohttp.ClientTimeout(
total=1800, connect=60, sock_read=1800, sock_connect=60
)
@@ -392,7 +364,7 @@ async def _download_file(
try:
await aios.remove(partial_path)
except Exception as e:
logger.error(f"Error removing partial file {partial_path}: {e}")
print(f"Error removing partial file {partial_path}: {e}")
raise Exception(
f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
)
@@ -462,8 +434,8 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> List[str]:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(traceback.format_exc())
print(f"Error getting weight map for {shard.model_meta.model_id=}")
traceback.print_exc()
return ["*"]
@@ -533,11 +505,11 @@ async def download_shard(
allow_patterns: List[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_meta.model_id=}")
print(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
print(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_meta.model_id), shard, local_path
@@ -553,7 +525,7 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
print(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.

View File

@@ -5,7 +5,6 @@ from typing import AsyncIterator, Callable, Dict, List, Optional
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
PartitionStrategy,
PipelineShardMetadata,
ShardMetadata,
)
@@ -24,7 +23,6 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]:
# print(f"build_base_shard {model_id=} {model_meta=}")
return PipelineShardMetadata(
model_meta=model_meta,
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
start_layer=0,
@@ -39,7 +37,6 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]:
return None
return PipelineShardMetadata(
model_meta=base_shard.model_meta,
partition_strategy=base_shard.partition_strategy,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,

View File

@@ -6,7 +6,6 @@ from typing import AsyncIterator, Callable
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PartitionStrategy,
PipelineShardMetadata,
ShardMetadata,
)
@@ -57,7 +56,6 @@ class ShardDownloader(ABC):
storage_size=Memory.from_bytes(0),
n_layers=1,
),
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
start_layer=0,
@@ -107,7 +105,6 @@ class NoopShardDownloader(ShardDownloader):
storage_size=Memory.from_bytes(0),
n_layers=1,
),
partition_strategy=PartitionStrategy.pipeline,
device_rank=0,
world_size=1,
start_layer=0,

View File

@@ -1,5 +1,4 @@
import asyncio
import traceback
import time
from asyncio import Queue
from functools import partial
@@ -13,8 +12,7 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.worker.download.download_utils import map_repo_download_progress_to_download_progress_data
from exo.shared.types.commands import ForwarderCommand, RequestEventLog, TaggedCommand
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
ChunkGenerated,
@@ -27,12 +25,12 @@ from exo.shared.types.events import (
NodePerformanceMeasured,
RunnerDeleted,
RunnerStatusUpdated,
TaggedEvent,
TaskFailed,
TaskStateUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -43,6 +41,7 @@ from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgressData,
)
from exo.shared.types.worker.ops import (
AssignRunnerOp,
@@ -50,7 +49,6 @@ from exo.shared.types.worker.ops import (
RunnerDownOp,
RunnerFailedOp,
RunnerOp,
RunnerOpType,
RunnerUpOp,
UnassignRunnerOp,
)
@@ -120,25 +118,23 @@ class Worker:
),
)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_publisher(
NodeMemoryMeasured(node_id=self.node_id, memory=memory_profile)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_publisher(
NodeMemoryMeasured(node_id=self.node_id, memory=memory_profile)
)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
# Proactively request a global event sync at startup to backfill any missed events.
tg.start_soon(self._request_full_event_log_once)
# TODO: This is a little gross, but not too bad
for msg in self._initial_connection_messages:
await self.event_publisher(
@@ -156,8 +152,8 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for event in events:
self.event_buffer.ingest(event.origin_idx, event.tagged_event.c)
event_id = event.tagged_event.c.event_id
self.event_buffer.ingest(event.origin_idx, event.event)
event_id = event.event.event_id
if event_id in self.out_for_delivery:
del self.out_for_delivery[event_id]
@@ -201,8 +197,6 @@ class Worker:
async for event in self.execute_op(op):
await self.event_publisher(event)
except Exception as e:
logger.error(f"Error executing op: {str(op)[:100]}")
logger.error(traceback.format_exc())
if isinstance(op, ExecuteTaskOp):
generator = self.fail_task(
e, runner_id=op.runner_id, task_id=op.task.task_id
@@ -227,7 +221,6 @@ class Worker:
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
match msg.connection_type:
case ConnectionMessageType.Connected:
logger.warning(f"!!! Node {self.node_id} connected to {msg.node_id}")
return TopologyEdgeCreated(
edge=Connection(
local_node_id=self.node_id,
@@ -239,7 +232,6 @@ class Worker:
)
case ConnectionMessageType.Disconnected:
logger.warning(f"!!! Node {self.node_id} disconnected from {msg.node_id}")
return TopologyEdgeDeleted(
edge=Connection(
local_node_id=self.node_id,
@@ -262,27 +254,13 @@ class Worker:
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id,
tagged_command=TaggedCommand.from_(
RequestEventLog(since_idx=self.event_buffer.next_idx_to_release)
),
command=RequestEventLog(since_idx=0),
)
)
finally:
if self._nack_cancel_scope is scope:
self._nack_cancel_scope = None
async def _request_full_event_log_once(self) -> None:
# Fire-and-forget one-time sync shortly after startup.
await anyio.sleep(0.1)
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id,
tagged_command=TaggedCommand.from_(
RequestEventLog(since_idx=self.event_buffer.next_idx_to_release)
),
)
)
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened, we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
@@ -340,8 +318,13 @@ class Worker:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=map_repo_download_progress_to_download_progress_data(initial_progress),
),
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(initial_progress.total_bytes),
downloaded_bytes=Memory.from_bytes(
initial_progress.downloaded_bytes
),
),
)
)
yield assigned_runner.status_update_event()
@@ -357,24 +340,15 @@ class Worker:
download_task = asyncio.create_task(
self.shard_downloader.ensure_shard(op.shard_metadata)
)
logger.info(f"Started download for {op.shard_metadata.model_meta.model_id}")
try:
async for event in self._monitor_download_progress(
assigned_runner, download_progress_queue
):
yield event
# in case the download needs to finish up, wait up to 60 secs for it to finish
# this fixes a bug where the download gets cancelled before it can rename .partial file on finish
await asyncio.wait_for(download_task, timeout=15)
except Exception as e:
logger.error(f"Error monitoring download progress: {e}")
logger.error(traceback.format_exc())
raise e
finally:
if not download_task.done():
download_task.cancel()
async def _monitor_download_progress(
self,
@@ -403,7 +377,12 @@ class Worker:
assigned_runner.status = DownloadingRunnerStatus(
download_progress=DownloadOngoing(
node_id=self.node_id,
download_progress=map_repo_download_progress_to_download_progress_data(progress),
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(progress.total_bytes),
downloaded_bytes=Memory.from_bytes(
progress.downloaded_bytes
),
),
)
)
yield assigned_runner.status_update_event()
@@ -424,11 +403,9 @@ class Worker:
)
if initial_progress.status == "complete":
logger.info(f"Shard {op.shard_metadata.model_meta.model_id} already downloaded")
async for event in self._handle_already_downloaded_shard(assigned_runner):
yield event
else:
logger.info(f"Shard {op.shard_metadata.model_meta.model_id} not downloaded, starting download.")
async for event in self._handle_shard_download_process(
assigned_runner, op, initial_progress
):
@@ -526,7 +503,7 @@ class Worker:
await queue.put(
TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.RUNNING,
task_status=TaskStatus.Running,
)
)
@@ -547,14 +524,14 @@ class Worker:
)
if op.task.task_id in self.state.tasks:
self.state.tasks[op.task.task_id].task_status = TaskStatus.COMPLETE
self.state.tasks[op.task.task_id].task_status = TaskStatus.Complete
if assigned_runner.shard_metadata.device_rank == 0:
# kind of hack - we don't want to wait for the round trip for this to complete
await queue.put(
TaskStateUpdated(
task_id=op.task.task_id,
task_status=TaskStatus.COMPLETE,
task_status=TaskStatus.Complete,
)
)
@@ -601,18 +578,18 @@ class Worker:
async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
## It would be great if we can get rid of this async for ... yield pattern.
match op.op_type:
case RunnerOpType.ASSIGN_RUNNER:
match op:
case AssignRunnerOp():
event_generator = self._execute_assign_op(op)
case RunnerOpType.UNASSIGN_RUNNER:
case UnassignRunnerOp():
event_generator = self._execute_unassign_op(op)
case RunnerOpType.RUNNER_UP:
case RunnerUpOp():
event_generator = self._execute_runner_up_op(op)
case RunnerOpType.RUNNER_DOWN:
case RunnerDownOp():
event_generator = self._execute_runner_down_op(op)
case RunnerOpType.RUNNER_FAILED:
case RunnerFailedOp():
event_generator = self._execute_runner_failed_op(op)
case RunnerOpType.CHAT_COMPLETION:
case ExecuteTaskOp():
event_generator = self._execute_task_op(op)
async for event in event_generator:
@@ -643,7 +620,7 @@ class Worker:
if runner_id in self.assigned_runners:
yield TaskStateUpdated(
task_id=task_id,
task_status=TaskStatus.FAILED,
task_status=TaskStatus.Failed,
)
yield TaskFailed(
@@ -653,15 +630,21 @@ class Worker:
async for event in self.fail_runner(e, runner_id):
yield event
# This function is re-entrant, take care!
async def event_publisher(self, event: Event) -> None:
fe = ForwarderEvent(
origin_idx=self.local_event_index,
origin=self.node_id,
tagged_event=TaggedEvent.from_(event),
event=event,
)
logger.debug(
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
self.local_event_index += 1
def event_relevant_to_worker(event: Event, worker: Worker):

View File

@@ -6,7 +6,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.common import RunnerId
from exo.shared.types.worker.downloads import DownloadStatus
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceStatus
from exo.shared.types.worker.ops import (
AssignRunnerOp,
@@ -23,8 +23,8 @@ from exo.shared.types.worker.runners import (
InactiveRunnerStatus,
LoadedRunnerStatus,
RunnerStatus,
RunnerStatusType,
RunningRunnerStatus,
StartingRunnerStatus,
)
from exo.worker.common import AssignedRunner
@@ -45,14 +45,12 @@ def unassign_runners(
# If our instance is in 'downloading' or 'assigned' state, then we know the runner is stale. These are part of AssignRunnerOp and should be blocking.
for assigned_runner_id in assigned_runners:
if (
assigned_runner_id in state_runners
and isinstance(state_runners[assigned_runner_id], DownloadingRunnerStatus)
# Not sure about this type ignore, i don't think it should be necessary
and state_runners[assigned_runner_id].download_progress.download_status # type: ignore
!= DownloadStatus.Completed
):
return UnassignRunnerOp(runner_id=assigned_runner_id)
if assigned_runner_id in state_runners:
status = state_runners[assigned_runner_id]
if isinstance(status, DownloadingRunnerStatus) and not isinstance(
status.download_progress, DownloadCompleted
):
return UnassignRunnerOp(runner_id=assigned_runner_id)
return None
@@ -85,7 +83,7 @@ def spin_down_runners(
if (
runner_id in assigned_runners
and isinstance(assigned_runners[runner_id].status, LoadedRunnerStatus)
and instance.instance_type == InstanceStatus.INACTIVE
and instance.instance_type == InstanceStatus.Inactive
):
return RunnerDownOp(runner_id=runner_id)
@@ -195,18 +193,19 @@ def spin_up_runners(
instance.shard_assignments.node_to_runner[worker_node_id]
].runner
is None
and instance.instance_type == InstanceStatus.ACTIVE
and instance.instance_type == InstanceStatus.Active
):
# We are part of this instance, we want it up but it hasn't been spun up yet.
# Need to assert all other runners are ready before we can spin up.
ready_to_spin = True
for runner_id in instance.shard_assignments.node_to_runner.values():
if runner_id in state_runners and state_runners[
runner_id
].runner_status not in [
RunnerStatusType.Inactive,
RunnerStatusType.Starting,
]:
if runner_id in state_runners and isinstance(
state_runners[runner_id],
(
InactiveRunnerStatus,
StartingRunnerStatus,
),
):
ready_to_spin = False
if ready_to_spin:
@@ -229,13 +228,12 @@ def execute_task_op(
continue
assert runner_id in assigned_runners
runner = assigned_runners[runner_id]
if runner.status.runner_status != RunnerStatusType.Loaded:
if not isinstance(runner.status, LoadedRunnerStatus):
continue # The only previous state to get to Running is from Loaded
for _, task in tasks.items():
if task.instance_id == instance_id and (
task.task_status == TaskStatus.PENDING
or task.task_status == TaskStatus.FAILED
task.task_status in (TaskStatus.Pending, TaskStatus.Failed)
):
if (
runner.shard_metadata.device_rank >= 1

View File

@@ -10,7 +10,6 @@ from exo.shared.types.tasks import (
ChatCompletionTask,
TaskId,
TaskStatus,
TaskType,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import Instance, InstanceStatus
@@ -131,7 +130,7 @@ def instance(
return Instance(
instance_id=resolved_instance_id,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(1),
)
@@ -161,8 +160,7 @@ def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
task_id=resolved_task_id,
command_id=COMMAND_1_ID,
instance_id=resolved_instance_id,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=completion_create_params,
)

View File

@@ -1,29 +1,22 @@
import time
import os
from pathlib import Path
import shutil
from typing import Callable
import pytest
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.worker.download.download_utils import RepoDownloadProgress
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.download.shard_downloader import ShardDownloader
@pytest.mark.slow
@pytest.mark.asyncio
async def test_shard_downloader(
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
):
shutil.rmtree(Path(os.path.expanduser("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit")))
progress_log: list[RepoDownloadProgress] = []
shard_downloader: ShardDownloader = exo_shard_downloader()
def _on_progress(shard: ShardMetadata, progress: RepoDownloadProgress):
print(f"Download progress: {progress}")
progress_log.append(progress)
shard_downloader.on_progress(_on_progress)
shard_downloader.on_progress(
lambda shard, progress: print(f"Download progress: {progress}")
)
shard_metadata = pipeline_shard_meta(1, 0)
path = await shard_downloader.ensure_shard(shard_metadata)
@@ -54,12 +47,3 @@ async def test_shard_downloader(
duration = time.monotonic() - start_time
assert path_again == path
assert duration < 5, f"Second call to ensure_shard took too long: {duration:.2f}s"
print(progress_log[-1].file_progress)
assert len(progress_log) > 0
assert progress_log[-1].status == "complete"
assert progress_log[-1].completed_files == 6
assert progress_log[-1].total_files == 6
assert progress_log[-1].downloaded_bytes == sum(file_size for _, file_size in expected_files_and_sizes)
assert progress_log[-1].total_bytes == sum(file_size for _, file_size in expected_files_and_sizes)

View File

@@ -145,10 +145,10 @@ async def test_execute_task_op(
assert isinstance(events[0].runner_status, RunningRunnerStatus)
assert isinstance(events[1], TaskStateUpdated)
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
assert events[1].task_status == TaskStatus.Running # It tried to start.
assert isinstance(events[-2], TaskStateUpdated)
assert events[-2].task_status == TaskStatus.COMPLETE # It tried to start.
assert events[-2].task_status == TaskStatus.Complete # It tried to start.
assert isinstance(events[-1], RunnerStatusUpdated)
assert isinstance(

View File

@@ -17,7 +17,6 @@ from exo.shared.types.tasks import (
Task,
TaskId,
TaskStatus,
TaskType,
)
from exo.shared.types.worker.common import InstanceId, RunnerId
from exo.shared.types.worker.instances import (
@@ -57,7 +56,7 @@ async def test_runner_inference(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
@@ -120,7 +119,7 @@ async def test_2_runner_inference(
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
@@ -190,7 +189,7 @@ async def test_2_runner_multi_message(
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
@@ -218,8 +217,7 @@ async def test_2_runner_multi_message(
task_id=TASK_1_ID,
command_id=CommandId(),
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=completion_create_params,
)

View File

@@ -58,7 +58,7 @@ async def test_stream_response_failed_always(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
async def mock_stream_response(
self: RunnerSupervisor,
@@ -88,8 +88,8 @@ async def test_stream_response_failed_always(
[
x
for x in events
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
@@ -99,13 +99,13 @@ async def test_stream_response_failed_always(
[
x
for x in events
if isinstance(x.tagged_event.c, TaskStateUpdated)
and x.tagged_event.c.task_status == TaskStatus.FAILED
if isinstance(x.event, TaskStateUpdated)
and x.event.task_status == TaskStatus.Failed
]
)
== 3
)
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
assert any([isinstance(x.event, InstanceDeleted) for x in events])
await global_events.append_events(
[
@@ -152,7 +152,7 @@ async def test_stream_response_failed_once(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
await global_events.append_events(
@@ -186,8 +186,8 @@ async def test_stream_response_failed_once(
[
x
for x in events
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 1
@@ -197,8 +197,8 @@ async def test_stream_response_failed_once(
[
x
for x in events
if isinstance(x.tagged_event.c, TaskStateUpdated)
and x.tagged_event.c.task_status == TaskStatus.FAILED
if isinstance(x.event, TaskStateUpdated)
and x.event.task_status == TaskStatus.Failed
]
)
== 1
@@ -209,11 +209,11 @@ async def test_stream_response_failed_once(
seen_task_started, seen_task_finished = False, False
for wrapped_event in events:
event = wrapped_event.tagged_event.c
event = wrapped_event.event
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.RUNNING:
if event.task_status == TaskStatus.Running:
seen_task_started = True
if event.task_status == TaskStatus.COMPLETE:
if event.task_status == TaskStatus.Complete:
seen_task_finished = True
if isinstance(event, ChunkGenerated):
@@ -246,7 +246,7 @@ async def test_stream_response_timeout(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT"
@@ -269,8 +269,8 @@ async def test_stream_response_timeout(
[
x
for x in events
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
@@ -280,8 +280,8 @@ async def test_stream_response_timeout(
[
x
for x in events
if isinstance(x.tagged_event.c, TaskStateUpdated)
and x.tagged_event.c.task_status == TaskStatus.FAILED
if isinstance(x.event, TaskStateUpdated)
and x.event.task_status == TaskStatus.Failed
]
)
== 3
@@ -291,8 +291,8 @@ async def test_stream_response_timeout(
[
x
for x in events
if isinstance(x.tagged_event.c, TaskFailed)
and "timeouterror" in x.tagged_event.c.error_type.lower()
if isinstance(x.event, TaskFailed)
and "timeouterror" in x.event.error_type.lower()
]
)
== 3

View File

@@ -37,7 +37,7 @@ async def test_runner_spinup_timeout(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
instance_value.shard_assignments.runner_to_shard[
RUNNER_1_ID
].should_timeout = 10
@@ -61,11 +61,11 @@ async def test_runner_spinup_timeout(
[
x
for x in events
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
assert any([isinstance(x.event, InstanceDeleted) for x in events])
worker.shutdown()

View File

@@ -38,7 +38,7 @@ async def test_runner_spinup_exception(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
instance_value.shard_assignments.runner_to_shard[
RUNNER_1_ID
].immediate_exception = True
@@ -57,13 +57,13 @@ async def test_runner_spinup_exception(
[
x
for x in events
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
assert any([isinstance(x.event, InstanceDeleted) for x in events])
worker.shutdown()
@@ -75,7 +75,7 @@ async def test_runner_spinup_timeout(
async with create_task_group() as tg:
tg.start_soon(worker.run)
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
instance_value.instance_type = InstanceStatus.ACTIVE
instance_value.instance_type = InstanceStatus.Active
instance_value.shard_assignments.runner_to_shard[
RUNNER_1_ID
].should_timeout = 10
@@ -99,11 +99,11 @@ async def test_runner_spinup_timeout(
[
x
for x in events
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
if isinstance(x.event, RunnerStatusUpdated)
and isinstance(x.event.runner_status, FailedRunnerStatus)
]
)
== 3
)
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
assert any([isinstance(x.event, InstanceDeleted) for x in events])
worker.shutdown()

View File

@@ -22,7 +22,6 @@ from exo.shared.types.tasks import (
Task,
TaskId,
TaskStatus,
TaskType,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.instances import (
@@ -107,7 +106,7 @@ async def test_ttft(
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(1),
)
@@ -139,8 +138,7 @@ async def test_ttft(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=task1_params,
)
@@ -157,7 +155,7 @@ async def test_ttft(
first_chunk_seen_1 = False
time_to_first_token_1: None | float = None
while not first_chunk_seen_1:
event = (await global_events.receive()).tagged_event.c
event = (await global_events.receive()).event
if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"):
first_chunk_time_1 = time.time()
time_to_first_token_1 = first_chunk_time_1 - task_created_time_1
@@ -192,8 +190,7 @@ async def test_ttft(
task_id=TASK_2_ID,
command_id=COMMAND_2_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=task2_params,
)
@@ -211,7 +208,7 @@ async def test_ttft(
first_chunk_seen_2 = False
time_to_first_token_2: float | None = None
while not first_chunk_seen_2:
event = (await global_events.receive()).tagged_event.c
event = (await global_events.receive()).event
if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"):
first_chunk_time_2 = time.time()
time_to_first_token_2 = first_chunk_time_2 - task_created_time_2
@@ -344,7 +341,7 @@ async def test_2_runner_inference(
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
@@ -424,7 +421,7 @@ async def test_parallel_inference(
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)
@@ -443,8 +440,7 @@ async def test_parallel_inference(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=completion_create_params_1,
)
@@ -462,8 +458,7 @@ async def test_parallel_inference(
task_id=TASK_2_ID,
command_id=COMMAND_2_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=completion_create_params_2,
)
@@ -485,7 +480,7 @@ async def test_parallel_inference(
incomplete_task = (
TASK_2_ID
if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.COMPLETE
if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.Complete
else TASK_2_ID
)
(

View File

@@ -6,7 +6,6 @@ from exo.shared.types.tasks import (
ChatCompletionTask,
ChatCompletionTaskParams,
TaskStatus,
TaskType,
)
from exo.shared.types.worker.common import WorkerStatus
from exo.shared.types.worker.downloads import (
@@ -85,7 +84,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": False,
}
],
instance_status=InstanceStatus.INACTIVE,
instance_status=InstanceStatus.Inactive,
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -99,7 +98,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
}
],
instance_status=InstanceStatus.INACTIVE,
instance_status=InstanceStatus.Inactive,
expected_op=None,
),
PlanTestCase(
@@ -110,7 +109,7 @@ def _get_test_cases() -> list[PlanTestCase]:
INSTANCE_1_ID: [(RUNNER_1_ID, NODE_A, 0, InactiveRunnerStatus())]
},
model_id=MODEL_A_ID,
instance_status=InstanceStatus.ACTIVE, # Either active or inactive should yield the same.
instance_status=InstanceStatus.Active, # Either active or inactive should yield the same.
),
expected_op=AssignRunnerOp(
instance_id=INSTANCE_1_ID,
@@ -153,7 +152,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -180,11 +179,11 @@ def _get_test_cases() -> list[PlanTestCase]:
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.PENDING,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=None,
),
make_test_case(
@@ -209,11 +208,11 @@ def _get_test_cases() -> list[PlanTestCase]:
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.PENDING,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -227,7 +226,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
}
],
instance_status=InstanceStatus.INACTIVE,
instance_status=InstanceStatus.Inactive,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -241,7 +240,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
}
],
instance_status=InstanceStatus.INACTIVE,
instance_status=InstanceStatus.Inactive,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -259,19 +258,18 @@ def _get_test_cases() -> list[PlanTestCase]:
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.PENDING,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
@@ -304,11 +302,11 @@ def _get_test_cases() -> list[PlanTestCase]:
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.PENDING,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=None,
),
make_test_case(
@@ -333,25 +331,24 @@ def _get_test_cases() -> list[PlanTestCase]:
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.PENDING,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(role="user", content="Hello, world!")
],
),
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
),
),
),
@@ -377,25 +374,24 @@ def _get_test_cases() -> list[PlanTestCase]:
{
"task_id": TASK_1_ID,
"instance_id": INSTANCE_1_ID,
"status": TaskStatus.PENDING,
"status": TaskStatus.Pending,
"messages": [{"role": "user", "content": "Hello, world!"}],
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=ExecuteTaskOp(
runner_id=RUNNER_1_ID,
task=ChatCompletionTask(
task_id=TASK_1_ID,
command_id=COMMAND_1_ID,
instance_id=INSTANCE_1_ID,
task_type=TaskType.CHAT_COMPLETION,
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[
ChatCompletionMessage(role="user", content="Hello, world!")
],
),
task_status=TaskStatus.PENDING,
task_status=TaskStatus.Pending,
),
),
),
@@ -410,7 +406,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
}
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -431,7 +427,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
},
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
make_test_case(
@@ -452,7 +448,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
},
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=None,
),
make_test_case(
@@ -473,7 +469,7 @@ def _get_test_cases() -> list[PlanTestCase]:
"downloaded": True,
},
],
instance_status=InstanceStatus.ACTIVE,
instance_status=InstanceStatus.Active,
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
),
]

View File

@@ -9,7 +9,7 @@ from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus
from exo.shared.types.worker.common import InstanceId, RunnerId, WorkerStatus
from exo.shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
from exo.shared.types.worker.instances import Instance, InstanceStatus
@@ -117,7 +117,7 @@ def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
download_progress=DownloadOngoing(
node_id=node_id,
download_progress=DownloadProgressData(
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0), downloaded_bytes_this_session=Memory.from_bytes(0), completed_files=0, total_files=1, speed=0, eta_ms=0, files={}
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0)
),
)
)
@@ -146,7 +146,7 @@ def make_instance(
instance_id: InstanceId,
runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]],
model_id: ModelId = MODEL_A_ID,
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
instance_status: InstanceStatus = InstanceStatus.Active,
) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, WorkerStatus]]:
"""Creates an instance with one or more runners."""
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
@@ -189,7 +189,7 @@ def make_state(
],
tasks: dict[TaskId, ChatCompletionTask] | None = None,
model_id: ModelId = MODEL_A_ID,
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
instance_status: InstanceStatus = InstanceStatus.Active,
) -> State:
"""Builds a full State from runner specs per instance, tasks, and defaults."""
if tasks is None:
@@ -224,7 +224,7 @@ def make_test_case(
tasks: list[TaskSpecDict] | None = None,
expected_op: Optional[RunnerOp] = None,
instance_id: InstanceId = INSTANCE_1_ID,
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
instance_status: InstanceStatus = InstanceStatus.Active,
model_id: ModelId = MODEL_A_ID,
command_id: CommandId = COMMAND_1_ID, # Default for tasks
) -> PlanTestCase:
@@ -244,8 +244,7 @@ def make_test_case(
instance_id=instance_id,
task_id=t["task_id"],
command_id=t.get("command_id", command_id),
task_type=TaskType.CHAT_COMPLETION,
task_status=t.get("status", TaskStatus.PENDING),
task_status=t.get("status", TaskStatus.Pending),
task_params=ChatCompletionTaskParams(
model=t.get("model", str(model_id)),
messages=[

View File

@@ -72,7 +72,7 @@ async def check_runner_connection(
instance = Instance(
instance_id=INSTANCE_1_ID,
instance_type=InstanceStatus.ACTIVE,
instance_type=InstanceStatus.Active,
shard_assignments=shard_assignments,
hosts=hosts(2),
)

View File

@@ -6,7 +6,7 @@ from exo.shared.types.common import Host
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.commands_runner import (
ChatTaskMessage,
RunnerMessageTypeAdapter,
RunnerMessage,
SetupMessage,
)
from exo.shared.types.worker.common import InstanceId
@@ -30,7 +30,7 @@ def test_supervisor_setup_message_serdes(
model_shard_meta=pipeline_shard_meta(1, 0),
hosts=hosts(1),
)
assert_equal_serdes(setup_message, RunnerMessageTypeAdapter)
assert_equal_serdes(setup_message, TypeAdapter(RunnerMessage))
def test_supervisor_task_message_serdes(
@@ -40,4 +40,4 @@ def test_supervisor_task_message_serdes(
task_message = ChatTaskMessage(
task_data=task.task_params,
)
assert_equal_serdes(task_message, RunnerMessageTypeAdapter)
assert_equal_serdes(task_message, TypeAdapter(RunnerMessage))

View File

@@ -7,10 +7,10 @@ from exo.shared.openai_compat import FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import Host
from exo.shared.types.tasks import (
ChatCompletionTask,
ChatCompletionTaskParams,
Task,
TaskId,
TaskType,
)
from exo.shared.types.worker.common import InstanceId
from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -143,7 +143,7 @@ async def test_supervisor_early_stopping(
task = chat_completion_task(instance_id, TaskId())
max_tokens = 50
assert task.task_type == TaskType.CHAT_COMPLETION
assert isinstance(task, ChatCompletionTask)
print(f"chat_completion_task.task_params: {task.task_params}")
assert isinstance(task.task_params, ChatCompletionTaskParams)
task_params: ChatCompletionTaskParams = task.task_params

View File

@@ -6,7 +6,7 @@ from anyio import fail_after
from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import NodeId
from exo.shared.types.events import ChunkGenerated, Event, TaggedEvent, TaskStateUpdated
from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated
from exo.shared.types.tasks import TaskId, TaskStatus
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader
@@ -24,7 +24,7 @@ class WorkerMailbox:
await self.sender.send(
ForwarderEvent(
origin=origin,
tagged_event=TaggedEvent.from_(event),
event=event,
origin_idx=self.counter,
)
)
@@ -105,7 +105,7 @@ async def read_streaming_response(
token_count = 0
extra_events: list[Event] = []
event = (await global_event_receiver.receive()).tagged_event.c
event = (await global_event_receiver.receive()).event
extra_events.append(event)
from loguru import logger
@@ -116,17 +116,17 @@ async def read_streaming_response(
if filter_task:
while not (
isinstance(event, TaskStateUpdated)
and event.task_status == TaskStatus.RUNNING
and event.task_status == TaskStatus.Running
and event.task_id == filter_task
):
event = (await global_event_receiver.receive()).tagged_event.c
event = (await global_event_receiver.receive()).event
extra_events.append(event)
for event in extra_events:
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.RUNNING:
if event.task_status == TaskStatus.Running:
seen_task_started += 1
if event.task_status == TaskStatus.COMPLETE:
if event.task_status == TaskStatus.Complete:
seen_task_finished += 1
if isinstance(event, ChunkGenerated) and isinstance(
event.chunk, TokenChunk
@@ -137,11 +137,11 @@ async def read_streaming_response(
finish_reason = event.chunk.finish_reason
while not seen_task_finished:
event = (await global_event_receiver.receive()).tagged_event.c
event = (await global_event_receiver.receive()).event
if isinstance(event, TaskStateUpdated):
if event.task_status == TaskStatus.RUNNING:
if event.task_status == TaskStatus.Running:
seen_task_started += 1
if event.task_status == TaskStatus.COMPLETE:
if event.task_status == TaskStatus.Complete:
seen_task_finished += 1
if isinstance(event, ChunkGenerated) and isinstance(
event.chunk, TokenChunk
@@ -167,7 +167,7 @@ async def until_event_with_timeout[T](
with fail_after(timeout):
while times_seen < multiplicity:
event = (await global_event_receiver.receive()).tagged_event.c
event = (await global_event_receiver.receive()).event
if isinstance(event, event_type):
print(f"Wow! We got a {event}")
print(

View File

@@ -99,13 +99,13 @@ async def start_polling_node_metrics(
system_info,
network_interfaces,
mac_friendly_name,
memory_profile,
) = await asyncio.gather(
get_mac_system_info_async(),
get_network_interface_info_async(),
get_mac_friendly_name_async(),
get_memory_profile_async(),
)
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = await get_memory_profile_async()
await callback(
NodePerformanceProfile(

0
typings/.gitkeep Normal file
View File

1416
uv.lock generated

File diff suppressed because it is too large Load Diff