mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
super secret update
This commit is contained in:
@@ -40,12 +40,4 @@ else
|
||||
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
|
||||
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
|
||||
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
|
||||
fi
|
||||
|
||||
if ifconfig bridge0 >/dev/null 2>&1; then
|
||||
echo "Thunderbolt bridge found"
|
||||
if ifconfig bridge0 | grep -q "status: active"; then
|
||||
sudo ifconfig bridge0 down
|
||||
echo "Thunderbolt bridge disabled"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
@@ -492,133 +492,6 @@
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
|
||||
/* Detailed download info */
|
||||
.download-details {
|
||||
margin-top: 8px;
|
||||
padding: 12px;
|
||||
background-color: #1a1a1a;
|
||||
border: 1px solid var(--exo-medium-gray);
|
||||
border-radius: 6px;
|
||||
box-sizing: border-box;
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
overflow: visible;
|
||||
}
|
||||
.download-runner-header {
|
||||
font-size: 11px;
|
||||
color: var(--exo-light-gray);
|
||||
opacity: 0.85;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.download-overview-row {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
flex-wrap: wrap;
|
||||
font-size: 12px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.download-overview-item strong {
|
||||
color: #E0E0E0;
|
||||
font-weight: 600;
|
||||
margin-right: 4px;
|
||||
}
|
||||
.progress-with-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.progress-with-label .progress-bar-container {
|
||||
flex: 1 1 auto;
|
||||
}
|
||||
.progress-percent {
|
||||
font-size: 12px;
|
||||
color: var(--exo-light-gray);
|
||||
font-variant-numeric: tabular-nums;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.download-overview-combined {
|
||||
font-size: 12px;
|
||||
color: var(--exo-light-gray);
|
||||
opacity: 0.9;
|
||||
}
|
||||
.instance-download-summary {
|
||||
font-size: 11px;
|
||||
color: var(--exo-light-gray);
|
||||
margin-top: 6px;
|
||||
opacity: 0.95;
|
||||
}
|
||||
.download-files-list {
|
||||
display: grid;
|
||||
gap: 8px;
|
||||
}
|
||||
.download-file {
|
||||
padding: 8px;
|
||||
background-color: var(--exo-dark-gray);
|
||||
border: 1px solid var(--exo-medium-gray);
|
||||
border-radius: 6px;
|
||||
box-sizing: border-box;
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
}
|
||||
.download-file-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
font-size: 11px;
|
||||
margin-bottom: 6px;
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
.download-file-name {
|
||||
color: #E0E0E0;
|
||||
font-weight: 500;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
min-width: 0;
|
||||
flex: 1 1 auto;
|
||||
}
|
||||
.download-file-stats {
|
||||
color: var(--exo-light-gray);
|
||||
text-align: right;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.download-file-percent {
|
||||
color: var(--exo-light-gray);
|
||||
white-space: nowrap;
|
||||
font-size: 11px;
|
||||
font-variant-numeric: tabular-nums;
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
.download-file-subtext {
|
||||
color: var(--exo-light-gray);
|
||||
font-size: 10px;
|
||||
opacity: 0.85;
|
||||
margin-bottom: 6px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
max-width: 100%;
|
||||
}
|
||||
.download-details, .download-files-list {
|
||||
box-sizing: border-box;
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
}
|
||||
.download-files-list {
|
||||
overflow: visible;
|
||||
padding-right: 2px; /* avoid edge clipping */
|
||||
}
|
||||
.download-file .progress-bar-container {
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
box-sizing: border-box;
|
||||
height: 5px;
|
||||
}
|
||||
|
||||
/* Launch instance section styles */
|
||||
.launch-instance-section {
|
||||
display: flex;
|
||||
@@ -877,7 +750,6 @@
|
||||
|
||||
const USE_MOCK_DATA = false; // <<< FLAG TO TOGGLE MOCK DATA
|
||||
let currentlySelectedNodeId = null; // To store the ID of the currently selected node
|
||||
let nodeIdToFriendlyName = {}; // Map nodeId -> friendly name for download sections
|
||||
|
||||
const API_ENDPOINT = window.location.origin + window.location.pathname.replace(/\/$/, "") + '/state';
|
||||
const REFRESH_INTERVAL = 1000; // 1 second
|
||||
@@ -983,36 +855,6 @@
|
||||
return days + (days === 1 ? ' day ago' : ' days ago');
|
||||
}
|
||||
|
||||
// --- Download formatting helpers ---
|
||||
function bytesFromValue(value) {
|
||||
if (typeof value === 'number') return value;
|
||||
if (!value || typeof value !== 'object') return 0;
|
||||
if (typeof value.in_bytes === 'number') return value.in_bytes;
|
||||
if (typeof value.inBytes === 'number') return value.inBytes;
|
||||
return 0;
|
||||
}
|
||||
|
||||
function formatDurationMs(ms) {
|
||||
if (ms == null || isNaN(ms) || ms < 0) return '—';
|
||||
const totalSeconds = Math.round(ms / 1000);
|
||||
const s = totalSeconds % 60;
|
||||
const m = Math.floor(totalSeconds / 60) % 60;
|
||||
const h = Math.floor(totalSeconds / 3600);
|
||||
if (h > 0) return `${h}h ${m}m ${s}s`;
|
||||
if (m > 0) return `${m}m ${s}s`;
|
||||
return `${s}s`;
|
||||
}
|
||||
|
||||
function formatPercent(value, digits = 2) {
|
||||
if (value == null || isNaN(value)) return '0.00%';
|
||||
return `${value.toFixed(digits)}%`;
|
||||
}
|
||||
|
||||
function formatBytesPerSecond(bps) {
|
||||
if (bps == null || isNaN(bps) || bps < 0) return '0 B/s';
|
||||
return `${formatBytes(bps)}/s`;
|
||||
}
|
||||
|
||||
// Sidebar toggle functionality
|
||||
let sidebarOpen = false;
|
||||
|
||||
@@ -1092,7 +934,7 @@
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ model_id: selectedModelId })
|
||||
body: JSON.stringify({ modelId: selectedModelId, model_id: selectedModelId })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1132,123 +974,75 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate download status for an instance based on its runners, with detailed per-file info
|
||||
// Calculate download status for an instance based on its runners
|
||||
function calculateInstanceDownloadStatus(instance, runners) {
|
||||
if (!instance.shard_assignments?.runner_to_shard || !runners) {
|
||||
return { isDownloading: false, progress: 0, details: [] };
|
||||
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
|
||||
const runnerToShard = shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard;
|
||||
if (!runnerToShard || !runners) {
|
||||
return { isDownloading: false, progress: 0 };
|
||||
}
|
||||
|
||||
const pick = (obj, snake, camel, fallback = undefined) => {
|
||||
if (!obj) return fallback;
|
||||
if (obj[snake] !== undefined) return obj[snake];
|
||||
if (obj[camel] !== undefined) return obj[camel];
|
||||
return fallback;
|
||||
};
|
||||
|
||||
function normalizeProgress(progressRaw) {
|
||||
if (!progressRaw) return null;
|
||||
const totalBytes = bytesFromValue(pick(progressRaw, 'total_bytes', 'totalBytes', 0));
|
||||
const downloadedBytes = bytesFromValue(pick(progressRaw, 'downloaded_bytes', 'downloadedBytes', 0));
|
||||
const downloadedBytesThisSession = bytesFromValue(pick(progressRaw, 'downloaded_bytes_this_session', 'downloadedBytesThisSession', 0));
|
||||
const completedFiles = Number(pick(progressRaw, 'completed_files', 'completedFiles', 0)) || 0;
|
||||
const totalFiles = Number(pick(progressRaw, 'total_files', 'totalFiles', 0)) || 0;
|
||||
const speed = Number(pick(progressRaw, 'speed', 'speed', 0)) || 0;
|
||||
const etaMs = Number(pick(progressRaw, 'eta_ms', 'etaMs', 0)) || 0;
|
||||
const filesObj = pick(progressRaw, 'files', 'files', {}) || {};
|
||||
const files = [];
|
||||
Object.keys(filesObj).forEach(name => {
|
||||
const f = filesObj[name];
|
||||
if (!f || typeof f !== 'object') return;
|
||||
const fTotal = bytesFromValue(pick(f, 'total_bytes', 'totalBytes', 0));
|
||||
const fDownloaded = bytesFromValue(pick(f, 'downloaded_bytes', 'downloadedBytes', 0));
|
||||
const fSpeed = Number(pick(f, 'speed', 'speed', 0)) || 0;
|
||||
const fEta = Number(pick(f, 'eta_ms', 'etaMs', 0)) || 0;
|
||||
const fPct = fTotal > 0 ? (fDownloaded / fTotal) * 100 : 0;
|
||||
files.push({ name, totalBytes: fTotal, downloadedBytes: fDownloaded, speed: fSpeed, etaMs: fEta, percentage: fPct });
|
||||
});
|
||||
const percentage = totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0;
|
||||
return { totalBytes, downloadedBytes, downloadedBytesThisSession, completedFiles, totalFiles, speed, etaMs, files, percentage };
|
||||
}
|
||||
|
||||
const runnerIds = Object.keys(instance.shard_assignments.runner_to_shard);
|
||||
const details = [];
|
||||
const runnerIds = Object.keys(runnerToShard);
|
||||
const downloadingRunners = [];
|
||||
let totalBytes = 0;
|
||||
let downloadedBytes = 0;
|
||||
|
||||
for (const runnerId of runnerIds) {
|
||||
const runner = runners[runnerId];
|
||||
if (!runner || runner.runner_status !== 'Downloading' || !runner.download_progress) continue;
|
||||
const dp = runner.download_progress;
|
||||
const isDownloading = (dp.download_status === 'Downloading') || (dp.downloadStatus === 'Downloading');
|
||||
if (!isDownloading) continue;
|
||||
const nodeId = (dp && (dp.node_id || dp.nodeId)) || undefined;
|
||||
const rawProg = pick(dp, 'download_progress', 'downloadProgress', null);
|
||||
const normalized = normalizeProgress(rawProg);
|
||||
if (!normalized) continue;
|
||||
details.push({ runnerId, nodeId, progress: normalized });
|
||||
totalBytes += normalized.totalBytes || 0;
|
||||
downloadedBytes += normalized.downloadedBytes || 0;
|
||||
let isRunnerDownloading = false;
|
||||
|
||||
// Legacy snake_case structure
|
||||
if (runner && runner.runner_status === 'Downloading' && runner.download_progress) {
|
||||
isRunnerDownloading = runner.download_progress.download_status === 'Downloading';
|
||||
if (isRunnerDownloading && runner.download_progress.download_progress) {
|
||||
totalBytes += runner.download_progress.download_progress.total_bytes || 0;
|
||||
downloadedBytes += runner.download_progress.download_progress.downloaded_bytes || 0;
|
||||
}
|
||||
} else if (runner && typeof runner === 'object') {
|
||||
// Tagged-union camelCase structure, e.g. { "DownloadingRunnerStatus": { downloadProgress: { totalBytes, downloadedBytes } } }
|
||||
const tag = Object.keys(runner)[0];
|
||||
if (tag && /DownloadingRunnerStatus$/i.test(tag)) {
|
||||
isRunnerDownloading = true;
|
||||
const inner = runner[tag] || {};
|
||||
const prog = inner.downloadProgress || inner.download_progress || {};
|
||||
const t = prog.totalBytes ?? prog.total_bytes ?? 0;
|
||||
const d = prog.downloadedBytes ?? prog.downloaded_bytes ?? 0;
|
||||
totalBytes += typeof t === 'number' ? t : 0;
|
||||
downloadedBytes += typeof d === 'number' ? d : 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (isRunnerDownloading) downloadingRunners.push(runner);
|
||||
}
|
||||
|
||||
const isDownloadingAny = details.length > 0;
|
||||
const progress = totalBytes > 0 ? ((downloadedBytes / totalBytes) * 100) : 0;
|
||||
return { isDownloading: isDownloadingAny, progress, details };
|
||||
}
|
||||
const isDownloading = downloadingRunners.length > 0;
|
||||
const progress = totalBytes > 0 ? Math.round((downloadedBytes / totalBytes) * 100) : 0;
|
||||
|
||||
function buildDownloadDetailsHTML(details) {
|
||||
if (!details || details.length === 0) return '';
|
||||
function shortId(id) { return (id && id.length > 8) ? id.slice(0, 8) + '…' : (id || ''); }
|
||||
return details.map(({ runnerId, nodeId, progress }) => {
|
||||
const etaStr = formatDurationMs(progress.etaMs);
|
||||
const pctStr = formatPercent(progress.percentage || 0, 2);
|
||||
const bytesStr = `${formatBytes(progress.downloadedBytes)} / ${formatBytes(progress.totalBytes)}`;
|
||||
const speedStr = formatBytesPerSecond(progress.speed);
|
||||
const filesSummary = `${progress.completedFiles}/${progress.totalFiles}`;
|
||||
|
||||
const filesHTML = (progress.files || []).map(f => {
|
||||
const fPct = f.percentage || 0;
|
||||
const fBytes = `${formatBytes(f.downloadedBytes)} / ${formatBytes(f.totalBytes)}`;
|
||||
const fEta = formatDurationMs(f.etaMs);
|
||||
const fSpeed = formatBytesPerSecond(f.speed);
|
||||
const pctText = formatPercent(fPct, 2);
|
||||
return `
|
||||
<div class="download-file">
|
||||
<div class="download-file-header">
|
||||
<span class="download-file-name" title="${f.name}">${f.name}</span>
|
||||
<span class="download-file-percent">${pctText}</span>
|
||||
</div>
|
||||
<div class="download-file-subtext">${fBytes} • ETA ${fEta} • ${fSpeed}</div>
|
||||
<div class="progress-bar-container"><div class="progress-bar" style="width: ${Math.max(0, Math.min(100, fPct)).toFixed(2)}%;"></div></div>
|
||||
</div>
|
||||
`;
|
||||
}).join('');
|
||||
|
||||
const runnerName = (nodeId && nodeIdToFriendlyName[nodeId]) ? nodeIdToFriendlyName[nodeId] : '?';
|
||||
const headerText = `${runnerName} (${shortId(nodeId || '')})`;
|
||||
return `
|
||||
<div class="download-details">
|
||||
<div class="download-runner-header">${headerText}</div>
|
||||
<div class="download-files-list">
|
||||
${filesHTML}
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
}).join('');
|
||||
return { isDownloading, progress, downloadingRunners: downloadingRunners.length };
|
||||
}
|
||||
|
||||
// Derive a display status for an instance from its runners.
|
||||
// Priority: FAILED > DOWNLOADING > STARTING > RUNNING > LOADED > INACTIVE
|
||||
function deriveInstanceStatus(instance, runners = {}) {
|
||||
const runnerIds = Object.keys(instance.shard_assignments?.runner_to_shard || {});
|
||||
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
|
||||
const runnerToShard = shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard ?? {};
|
||||
const runnerIds = Object.keys(runnerToShard);
|
||||
const statuses = runnerIds
|
||||
.map(rid => runners[rid]?.runner_status)
|
||||
.map(rid => {
|
||||
const r = runners[rid];
|
||||
if (!r || typeof r !== 'object') return undefined;
|
||||
if (typeof r.runner_status === 'string') return r.runner_status;
|
||||
const tag = Object.keys(r)[0];
|
||||
return typeof tag === 'string' ? tag.replace(/RunnerStatus$/,'') : undefined; // e.g. LoadedRunnerStatus -> Loaded
|
||||
})
|
||||
.filter(s => typeof s === 'string');
|
||||
|
||||
const has = (s) => statuses.includes(s);
|
||||
const every = (pred) => statuses.length > 0 && statuses.every(pred);
|
||||
|
||||
if (statuses.length === 0) {
|
||||
const inactive = instance.instance_type === 'INACTIVE';
|
||||
const instanceType = instance.instance_type ?? instance.instanceType;
|
||||
const inactive = instanceType === 'INACTIVE' || instanceType === 'Inactive';
|
||||
return { statusText: inactive ? 'INACTIVE' : 'LOADED', statusClass: inactive ? 'inactive' : 'loaded' };
|
||||
}
|
||||
|
||||
@@ -1278,10 +1072,12 @@
|
||||
}
|
||||
|
||||
const instancesHTML = instancesArray.map(instance => {
|
||||
const modelId = instance.shard_assignments?.model_id || 'Unknown Model';
|
||||
const truncatedInstanceId = instance.instance_id.length > 8
|
||||
? instance.instance_id.substring(0, 8) + '...'
|
||||
: instance.instance_id;
|
||||
const shardAssignments = instance.shard_assignments ?? instance.shardAssignments;
|
||||
const modelId = shardAssignments?.model_id ?? shardAssignments?.modelId ?? 'Unknown Model';
|
||||
const instanceId = instance.instance_id ?? instance.instanceId ?? '';
|
||||
const truncatedInstanceId = instanceId.length > 8
|
||||
? instanceId.substring(0, 8) + '...'
|
||||
: instanceId;
|
||||
|
||||
const hostsHTML = instance.hosts?.map(host =>
|
||||
`<span class="instance-host">${host.ip}:${host.port}</span>`
|
||||
@@ -1298,31 +1094,15 @@
|
||||
}
|
||||
|
||||
// Generate download progress HTML
|
||||
let downloadProgressHTML = '';
|
||||
let instanceDownloadSummary = '';
|
||||
if (downloadStatus.isDownloading) {
|
||||
const detailsHTML = buildDownloadDetailsHTML(downloadStatus.details || []);
|
||||
const pctText = (downloadStatus.progress || 0).toFixed(2);
|
||||
// Aggregate a compact summary from the first runner (they should be consistent in aggregate)
|
||||
const first = (downloadStatus.details || [])[0]?.progress;
|
||||
const etaStr = first ? formatDurationMs(first.etaMs) : '—';
|
||||
const bytesStr = first ? `${formatBytes(first.downloadedBytes)} / ${formatBytes(first.totalBytes)}` : '';
|
||||
const speedStr = first ? formatBytesPerSecond(first.speed) : '';
|
||||
const filesSummary = first ? `${first.completedFiles}/${first.totalFiles}` : '';
|
||||
instanceDownloadSummary = `${etaStr} · ${bytesStr} · ${speedStr} · ${filesSummary} files`;
|
||||
|
||||
downloadProgressHTML = `
|
||||
<div class="download-progress">
|
||||
<span>${pctText}%</span>
|
||||
<div class="progress-bar-container">
|
||||
<div class="progress-bar" style="width: ${pctText}%;"></div>
|
||||
</div>
|
||||
const downloadProgressHTML = downloadStatus.isDownloading
|
||||
? `<div class="download-progress">
|
||||
<span>${downloadStatus.progress}% downloaded</span>
|
||||
<div class="progress-bar-container">
|
||||
<div class="progress-bar" style="width: ${downloadStatus.progress}%;"></div>
|
||||
</div>
|
||||
${detailsHTML}
|
||||
`;
|
||||
}
|
||||
</div>`
|
||||
: '';
|
||||
|
||||
const shardCount = Object.keys(instance.shard_assignments?.runner_to_shard || {}).length;
|
||||
return `
|
||||
<div class="instance-item">
|
||||
<div class="instance-header">
|
||||
@@ -1331,14 +1111,15 @@
|
||||
<span class="instance-status ${statusClass}">${statusText}</span>
|
||||
</div>
|
||||
<div class="instance-actions">
|
||||
<button class="instance-delete-button" data-instance-id="${instance.instance_id}" title="Delete Instance">
|
||||
<button class="instance-delete-button" data-instance-id="${instanceId}" title="Delete Instance">
|
||||
Delete
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="instance-model">${modelId} <span style="color: var(--exo-light-gray); opacity: 0.8;">(${shardCount})</span></div>
|
||||
${instanceDownloadSummary ? `<div class="instance-download-summary">${instanceDownloadSummary}</div>` : ''}
|
||||
|
||||
<div class="instance-model">${modelId}</div>
|
||||
<div class="instance-details">
|
||||
Shards: ${Object.keys((shardAssignments?.runner_to_shard ?? shardAssignments?.runnerToShard) || {}).length}
|
||||
</div>
|
||||
${downloadProgressHTML}
|
||||
${hostsHTML ? `<div class="instance-hosts">${hostsHTML}</div>` : ''}
|
||||
</div>
|
||||
@@ -1395,12 +1176,10 @@
|
||||
}
|
||||
}
|
||||
|
||||
function renderNodes(topologyData) {
|
||||
function renderNodes(nodesData) {
|
||||
if (!topologyGraphContainer) return;
|
||||
topologyGraphContainer.innerHTML = ''; // Clear previous SVG content
|
||||
|
||||
const nodesData = (topologyData && topologyData.nodes) ? topologyData.nodes : {};
|
||||
const edgesData = (topologyData && Array.isArray(topologyData.edges)) ? topologyData.edges : [];
|
||||
const nodeIds = Object.keys(nodesData);
|
||||
|
||||
if (nodeIds.length === 0) {
|
||||
@@ -1435,128 +1214,23 @@
|
||||
};
|
||||
});
|
||||
|
||||
// Add arrowhead definition (supports bidirectional arrows on a single line)
|
||||
const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs');
|
||||
const marker = document.createElementNS('http://www.w3.org/2000/svg', 'marker');
|
||||
marker.setAttribute('id', 'arrowhead');
|
||||
marker.setAttribute('viewBox', '0 0 10 10');
|
||||
marker.setAttribute('refX', '10');
|
||||
marker.setAttribute('refY', '5');
|
||||
marker.setAttribute('markerWidth', '11');
|
||||
marker.setAttribute('markerHeight', '11');
|
||||
marker.setAttribute('orient', 'auto-start-reverse');
|
||||
// Draw a subtle V-tip (no filled body)
|
||||
const markerTip = document.createElementNS('http://www.w3.org/2000/svg', 'path');
|
||||
markerTip.setAttribute('d', 'M 0 0 L 10 5 L 0 10');
|
||||
markerTip.setAttribute('fill', 'none');
|
||||
markerTip.setAttribute('stroke', 'var(--exo-light-gray)');
|
||||
markerTip.setAttribute('stroke-width', '1.6');
|
||||
markerTip.setAttribute('stroke-linecap', 'round');
|
||||
markerTip.setAttribute('stroke-linejoin', 'round');
|
||||
markerTip.setAttribute('stroke-dasharray', 'none');
|
||||
markerTip.setAttribute('stroke-dashoffset', '0');
|
||||
markerTip.setAttribute('style', 'animation: none; pointer-events: none;');
|
||||
marker.appendChild(markerTip);
|
||||
defs.appendChild(marker);
|
||||
topologyGraphContainer.appendChild(defs);
|
||||
|
||||
// Create groups for links and separate arrow markers (so arrows are not affected by line animations)
|
||||
// Create group for links (drawn first, so they are behind nodes)
|
||||
const linksGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
|
||||
linksGroup.setAttribute('class', 'links-group');
|
||||
linksGroup.setAttribute('style', 'pointer-events: none;');
|
||||
const arrowsGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
|
||||
arrowsGroup.setAttribute('class', 'arrows-group');
|
||||
arrowsGroup.setAttribute('style', 'pointer-events: none;');
|
||||
|
||||
// Build quick lookup for node positions
|
||||
const positionById = {};
|
||||
nodesWithPositions.forEach(n => { positionById[n.id] = { x: n.x, y: n.y }; });
|
||||
|
||||
// Group directed edges into undirected pairs to support single line with two arrows
|
||||
const pairMap = new Map(); // key: "a|b" with a<b, value: { a, b, aToB, bToA }
|
||||
edgesData.forEach(edge => {
|
||||
if (!edge || !edge.source || !edge.target) return;
|
||||
if (!positionById[edge.source] || !positionById[edge.target]) return;
|
||||
if (edge.source === edge.target) return;
|
||||
const a = edge.source < edge.target ? edge.source : edge.target;
|
||||
const b = edge.source < edge.target ? edge.target : edge.source;
|
||||
const key = `${a}|${b}`;
|
||||
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false };
|
||||
if (edge.source === a && edge.target === b) entry.aToB = true; else entry.bToA = true;
|
||||
pairMap.set(key, entry);
|
||||
});
|
||||
|
||||
// Draw one line per undirected pair with separate arrow carrier lines
|
||||
pairMap.forEach(entry => {
|
||||
const posA = positionById[entry.a];
|
||||
const posB = positionById[entry.b];
|
||||
if (!posA || !posB) return;
|
||||
|
||||
// Full-length center-to-center lines
|
||||
const x1 = posA.x;
|
||||
const y1 = posA.y;
|
||||
const x2 = posB.x;
|
||||
const y2 = posB.y;
|
||||
|
||||
// Base animated dashed line (no markers)
|
||||
const baseLine = document.createElementNS('http://www.w3.org/2000/svg', 'line');
|
||||
baseLine.setAttribute('x1', x1);
|
||||
baseLine.setAttribute('y1', y1);
|
||||
baseLine.setAttribute('x2', x2);
|
||||
baseLine.setAttribute('y2', y2);
|
||||
baseLine.setAttribute('class', 'graph-link');
|
||||
linksGroup.appendChild(baseLine);
|
||||
|
||||
// Arrowheads centered on the line (tip lies exactly on the line),
|
||||
// offset along the tangent so opposite directions straddle the center.
|
||||
const dx = x2 - x1;
|
||||
const dy = y2 - y1;
|
||||
const len = Math.hypot(dx, dy) || 1;
|
||||
const ux = dx / len;
|
||||
const uy = dy / len;
|
||||
const mx = (x1 + x2) / 2;
|
||||
const my = (y1 + y2) / 2;
|
||||
const tipOffset = 16; // shift arrow tips away from the exact center along the line
|
||||
const carrier = 2; // short carrier segment length to define orientation
|
||||
|
||||
if (entry.aToB) {
|
||||
// Arrow pointing A -> B: place tip slightly before center along +tangent
|
||||
const tipX = mx - ux * tipOffset;
|
||||
const tipY = my - uy * tipOffset;
|
||||
const sx = tipX - ux * carrier;
|
||||
const sy = tipY - uy * carrier;
|
||||
const ex = tipX;
|
||||
const ey = tipY;
|
||||
const arrowSeg = document.createElementNS('http://www.w3.org/2000/svg', 'line');
|
||||
arrowSeg.setAttribute('x1', sx);
|
||||
arrowSeg.setAttribute('y1', sy);
|
||||
arrowSeg.setAttribute('x2', ex);
|
||||
arrowSeg.setAttribute('y2', ey);
|
||||
arrowSeg.setAttribute('stroke', 'none');
|
||||
arrowSeg.setAttribute('fill', 'none');
|
||||
arrowSeg.setAttribute('marker-end', 'url(#arrowhead)');
|
||||
arrowsGroup.appendChild(arrowSeg);
|
||||
for (let i = 0; i < numNodes; i++) {
|
||||
for (let j = i + 1; j < numNodes; j++) {
|
||||
const link = document.createElementNS('http://www.w3.org/2000/svg', 'line');
|
||||
link.setAttribute('x1', nodesWithPositions[i].x);
|
||||
link.setAttribute('y1', nodesWithPositions[i].y);
|
||||
link.setAttribute('x2', nodesWithPositions[j].x);
|
||||
link.setAttribute('y2', nodesWithPositions[j].y);
|
||||
link.setAttribute('class', 'graph-link');
|
||||
linksGroup.appendChild(link);
|
||||
}
|
||||
}
|
||||
topologyGraphContainer.appendChild(linksGroup);
|
||||
|
||||
if (entry.bToA) {
|
||||
// Arrow pointing B -> A: place tip slightly after center along -tangent
|
||||
const tipX = mx + ux * tipOffset;
|
||||
const tipY = my + uy * tipOffset;
|
||||
const sx = tipX + ux * carrier; // start ahead so the segment points toward tip
|
||||
const sy = tipY + uy * carrier;
|
||||
const ex = tipX;
|
||||
const ey = tipY;
|
||||
const arrowSeg = document.createElementNS('http://www.w3.org/2000/svg', 'line');
|
||||
arrowSeg.setAttribute('x1', sx);
|
||||
arrowSeg.setAttribute('y1', sy);
|
||||
arrowSeg.setAttribute('x2', ex);
|
||||
arrowSeg.setAttribute('y2', ey);
|
||||
arrowSeg.setAttribute('stroke', 'none');
|
||||
arrowSeg.setAttribute('fill', 'none');
|
||||
arrowSeg.setAttribute('marker-end', 'url(#arrowhead)');
|
||||
arrowsGroup.appendChild(arrowSeg);
|
||||
}
|
||||
});
|
||||
// Create group for nodes
|
||||
const nodesGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g');
|
||||
nodesGroup.setAttribute('class', 'nodes-group');
|
||||
@@ -2064,10 +1738,7 @@
|
||||
|
||||
nodesGroup.appendChild(nodeG);
|
||||
});
|
||||
// Draw order: lines at the very back, then nodes, then mid-line arrows on top
|
||||
topologyGraphContainer.appendChild(linksGroup);
|
||||
topologyGraphContainer.appendChild(nodesGroup);
|
||||
topologyGraphContainer.appendChild(arrowsGroup);
|
||||
}
|
||||
|
||||
function showNodeDetails(selectedNodeId, allNodesData) {
|
||||
@@ -2215,22 +1886,13 @@
|
||||
throw new Error(`HTTP error! status: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
const clusterState = await response.json();
|
||||
const topologyData = transformClusterStateToTopology(clusterState);
|
||||
// Build nodeId -> friendly name map
|
||||
nodeIdToFriendlyName = {};
|
||||
if (topologyData && topologyData.nodes) {
|
||||
Object.keys(topologyData.nodes).forEach(nid => {
|
||||
const n = topologyData.nodes[nid];
|
||||
const name = (n && (n.friendly_name || (n.system_info && n.system_info.model_id))) || null;
|
||||
if (name) nodeIdToFriendlyName[nid] = name;
|
||||
});
|
||||
}
|
||||
renderNodes(topologyData);
|
||||
const nodesData = transformClusterStateToTopology(clusterState);
|
||||
renderNodes(nodesData);
|
||||
|
||||
// If a node was selected, and it still exists, refresh its details
|
||||
if (currentlySelectedNodeId && topologyData.nodes[currentlySelectedNodeId]) {
|
||||
showNodeDetails(currentlySelectedNodeId, topologyData.nodes);
|
||||
} else if (currentlySelectedNodeId && !topologyData.nodes[currentlySelectedNodeId]) {
|
||||
if (currentlySelectedNodeId && nodesData[currentlySelectedNodeId]) {
|
||||
showNodeDetails(currentlySelectedNodeId, nodesData);
|
||||
} else if (currentlySelectedNodeId && !nodesData[currentlySelectedNodeId]) {
|
||||
// If selected node is gone, close panel and clear selection
|
||||
nodeDetailPanel.classList.remove('visible');
|
||||
currentlySelectedNodeId = null;
|
||||
@@ -2276,9 +1938,8 @@
|
||||
}
|
||||
|
||||
function transformClusterStateToTopology(clusterState) {
|
||||
const resultNodes = {};
|
||||
const resultEdges = [];
|
||||
if (!clusterState) return { nodes: resultNodes, edges: resultEdges };
|
||||
const result = {};
|
||||
if (!clusterState) return result;
|
||||
|
||||
// Helper: get numeric bytes from various shapes (number | {in_bytes}|{inBytes})
|
||||
function getBytes(value) {
|
||||
@@ -2298,21 +1959,18 @@
|
||||
return fallback;
|
||||
};
|
||||
|
||||
// Helper: detect API placeholders like "unknown" (case-insensitive)
|
||||
const isUnknown = (value) => {
|
||||
return typeof value === 'string' && value.trim().toLowerCase() === 'unknown';
|
||||
};
|
||||
|
||||
// Process nodes from topology or fallback to node_profiles directly
|
||||
// Process nodes from topology or fallback to node_profiles/nodeProfiles directly
|
||||
let nodesToProcess = {};
|
||||
if (clusterState.topology && Array.isArray(clusterState.topology.nodes)) {
|
||||
clusterState.topology.nodes.forEach(node => {
|
||||
if (node.node_id && node.node_profile) {
|
||||
nodesToProcess[node.node_id] = node.node_profile;
|
||||
const nid = node.node_id ?? node.nodeId;
|
||||
const nprof = node.node_profile ?? node.nodeProfile;
|
||||
if (nid && nprof) {
|
||||
nodesToProcess[nid] = nprof;
|
||||
}
|
||||
});
|
||||
} else if (clusterState.node_profiles) {
|
||||
nodesToProcess = clusterState.node_profiles;
|
||||
} else if (clusterState.node_profiles || clusterState.nodeProfiles) {
|
||||
nodesToProcess = clusterState.node_profiles ?? clusterState.nodeProfiles;
|
||||
}
|
||||
|
||||
// Transform each node
|
||||
@@ -2333,15 +1991,10 @@
|
||||
memBytesAvailable = getBytes(ramAvailVal);
|
||||
const memBytesUsed = Math.max(memBytesTotal - memBytesAvailable, 0);
|
||||
|
||||
// Extract model information with graceful placeholders while node is loading
|
||||
const rawModelId = pick(nodeProfile, 'model_id', 'modelId', 'Unknown');
|
||||
const rawChipId = pick(nodeProfile, 'chip_id', 'chipId', '');
|
||||
const rawFriendlyName = pick(nodeProfile, 'friendly_name', 'friendlyName', `${nodeId.substring(0, 8)}...`);
|
||||
|
||||
// When API has not fully loaded (reports "unknown"), present a nice default
|
||||
const modelId = isUnknown(rawModelId) ? 'Mac Studio' : rawModelId;
|
||||
const chipId = isUnknown(rawChipId) ? '' : rawChipId;
|
||||
const friendlyName = (!rawFriendlyName || isUnknown(rawFriendlyName)) ? 'Mac' : rawFriendlyName;
|
||||
// Extract model information
|
||||
const modelId = pick(nodeProfile, 'model_id', 'modelId', 'Unknown');
|
||||
const chipId = pick(nodeProfile, 'chip_id', 'chipId', '');
|
||||
const friendlyName = pick(nodeProfile, 'friendly_name', 'friendlyName', `${nodeId.substring(0, 8)}...`);
|
||||
|
||||
// Extract network addresses (support snake_case and camelCase)
|
||||
const addrList = [];
|
||||
@@ -2386,7 +2039,7 @@
|
||||
timestamp: new Date().toISOString()
|
||||
};
|
||||
|
||||
resultNodes[nodeId] = {
|
||||
result[nodeId] = {
|
||||
mem: memBytesTotal,
|
||||
addrs: addrList,
|
||||
last_addr_update: Date.now() / 1000,
|
||||
@@ -2400,21 +2053,7 @@
|
||||
};
|
||||
}
|
||||
|
||||
// Extract directed edges from topology.connections if present
|
||||
const connections = clusterState.topology && Array.isArray(clusterState.topology.connections)
|
||||
? clusterState.topology.connections
|
||||
: [];
|
||||
connections.forEach(conn => {
|
||||
if (!conn) return;
|
||||
const src = conn.local_node_id ?? conn.localNodeId;
|
||||
const dst = conn.send_back_node_id ?? conn.sendBackNodeId;
|
||||
if (!src || !dst) return;
|
||||
if (!resultNodes[src] || !resultNodes[dst]) return; // only draw edges between known nodes
|
||||
if (src === dst) return; // skip self loops for now
|
||||
resultEdges.push({ source: src, target: dst });
|
||||
});
|
||||
|
||||
return { nodes: resultNodes, edges: resultEdges };
|
||||
return result;
|
||||
}
|
||||
|
||||
// --- Conditional Data Handling ---
|
||||
@@ -2554,12 +2193,11 @@
|
||||
mi.timestamp = new Date().toISOString();
|
||||
}
|
||||
}
|
||||
const mockTopology = { nodes: mockData, edges: [] };
|
||||
renderNodes(mockTopology);
|
||||
renderNodes(mockData);
|
||||
lastUpdatedElement.textContent = `Last updated: ${new Date().toLocaleTimeString()} (Mock Data)`;
|
||||
|
||||
if (currentlySelectedNodeId && mockData[currentlySelectedNodeId]) {
|
||||
showNodeDetails(currentlySelectedNodeId, mockTopology.nodes);
|
||||
showNodeDetails(currentlySelectedNodeId, mockData);
|
||||
} else if (currentlySelectedNodeId && !mockData[currentlySelectedNodeId]) {
|
||||
nodeDetailPanel.classList.remove('visible');
|
||||
currentlySelectedNodeId = null;
|
||||
|
||||
@@ -57,13 +57,13 @@
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# JUST
|
||||
just
|
||||
]
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
|
||||
# MACMON
|
||||
macmon
|
||||
|
||||
# JUST
|
||||
just
|
||||
]);
|
||||
|
||||
shellHook = ''
|
||||
|
||||
5
justfile
5
justfile
@@ -15,3 +15,8 @@ sync:
|
||||
|
||||
sync-clean:
|
||||
uv sync --all-packages --force-reinstall --no-cache
|
||||
|
||||
clean:
|
||||
rm -rf **/__pycache__
|
||||
rm -rf rust/target
|
||||
rm -rf .venv
|
||||
|
||||
@@ -36,7 +36,6 @@ dependencies = [
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio>=4.10.0",
|
||||
"bidict>=0.23.1",
|
||||
"chainlit>=2.8.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -19,8 +19,7 @@ from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.main import Worker
|
||||
from exo.utils.browser import open_url_in_browser_when_ready
|
||||
from exo.utils.chainlit_ui import start_chainlit, chainlit_cleanup
|
||||
|
||||
|
||||
# TODO: Entrypoint refactor
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@@ -156,27 +155,17 @@ class Node:
|
||||
if self.api:
|
||||
self.api.reset()
|
||||
|
||||
|
||||
def main():
|
||||
args = Args.parse()
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
|
||||
chainlit_proc = (
|
||||
start_chainlit(args.chainlit_port, args.chainlit_host, args.headless)
|
||||
if args.with_chainlit
|
||||
else None
|
||||
)
|
||||
if args.spawn_api and not args.headless:
|
||||
open_url_in_browser_when_ready(f"http://localhost:{args.api_port}")
|
||||
|
||||
try:
|
||||
anyio.run(node.run)
|
||||
finally:
|
||||
chainlit_cleanup(chainlit_proc)
|
||||
logger_cleanup()
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
|
||||
logger_cleanup()
|
||||
|
||||
|
||||
class Args(CamelCaseModel):
|
||||
@@ -185,11 +174,6 @@ class Args(CamelCaseModel):
|
||||
spawn_api: bool = False
|
||||
api_port: PositiveInt = 8000
|
||||
tb_only: bool = False
|
||||
# Chainlit options
|
||||
with_chainlit: bool = True
|
||||
chainlit_port: PositiveInt = 8001
|
||||
chainlit_host: str = "127.0.0.1"
|
||||
headless: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -232,30 +216,6 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
dest="tb_only",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with-chainlit",
|
||||
action="store_true",
|
||||
dest="with_chainlit",
|
||||
default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chainlit-port",
|
||||
type=int,
|
||||
dest="chainlit_port",
|
||||
default=8001,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chainlit-host",
|
||||
type=str,
|
||||
dest="chainlit_host",
|
||||
default="127.0.0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--headless",
|
||||
action="store_true",
|
||||
dest="headless",
|
||||
help="Prevents the app from opening in the browser."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -33,7 +33,6 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
TaggedCommand,
|
||||
# TODO: SpinUpInstance
|
||||
TaskFinished,
|
||||
)
|
||||
@@ -306,9 +305,7 @@ class API:
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for event in events:
|
||||
if isinstance(event, ChunkGenerated):
|
||||
logger.info(f"API received ChunkGenerated: {str(event)[:100]}")
|
||||
self.event_buffer.ingest(event.origin_idx, event.tagged_event.c)
|
||||
self.event_buffer.ingest(event.origin_idx, event.event)
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if (
|
||||
@@ -319,7 +316,5 @@ class API:
|
||||
|
||||
async def _send(self, command: Command):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.node_id, tagged_command=TaggedCommand.from_(command)
|
||||
)
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
@@ -23,13 +23,12 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceDeleted,
|
||||
TaggedEvent,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
@@ -90,11 +89,9 @@ class Master:
|
||||
with self.command_receiver as commands:
|
||||
async for forwarder_command in commands:
|
||||
try:
|
||||
logger.info(
|
||||
f"Executing command: {forwarder_command.tagged_command.c}"
|
||||
)
|
||||
logger.info(f"Executing command: {forwarder_command.command}")
|
||||
generated_events: list[Event] = []
|
||||
command = forwarder_command.tagged_command.c
|
||||
command = forwarder_command.command
|
||||
match command:
|
||||
case ChatCompletion():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
@@ -130,11 +127,10 @@ class Master:
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ChatCompletionTask(
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
@@ -190,13 +186,14 @@ class Master:
|
||||
async for local_event in local_events:
|
||||
self._multi_buffer.ingest(
|
||||
local_event.origin_idx,
|
||||
local_event.tagged_event.c,
|
||||
local_event.event,
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
|
||||
# TODO: SQL
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
@@ -224,17 +221,18 @@ class Master:
|
||||
ForwarderEvent(
|
||||
origin=NodeId(f"master_{self.node_id}"),
|
||||
origin_idx=local_index,
|
||||
tagged_event=TaggedEvent.from_(event),
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
# Convenience method since this line is ugly
|
||||
await self.global_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin=self.node_id,
|
||||
origin_idx=event.idx,
|
||||
tagged_event=TaggedEvent.from_(event.event),
|
||||
event=event.event,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -88,7 +88,7 @@ def get_instance_placements_after_create(
|
||||
target_instances = dict(deepcopy(current_instances))
|
||||
target_instances[instance_id] = Instance(
|
||||
instance_id=instance_id,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[
|
||||
Host(
|
||||
|
||||
@@ -2,17 +2,10 @@ import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.tests.api_utils_test import (
|
||||
ChatMessage,
|
||||
stream_chatgpt_response,
|
||||
with_master_main,
|
||||
)
|
||||
|
||||
|
||||
@with_master_main
|
||||
@pytest.mark.asyncio
|
||||
async def test_master_api_multiple_response_sequential() -> None:
|
||||
# TODO: This hangs at the moment it seems.
|
||||
# TODO
|
||||
return
|
||||
messages = [ChatMessage(role="user", content="Hello, who are you?")]
|
||||
token_count = 0
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import List, Sequence
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
@@ -11,7 +12,6 @@ from exo.shared.types.commands import (
|
||||
CommandId,
|
||||
CreateInstance,
|
||||
ForwarderCommand,
|
||||
TaggedCommand,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
@@ -19,7 +19,6 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodePerformanceMeasured,
|
||||
TaggedEvent,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -29,9 +28,9 @@ from exo.shared.types.profiling import (
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus, TaskType
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskStatus
|
||||
from exo.shared.types.worker.instances import Instance, InstanceStatus, ShardAssignments
|
||||
from exo.shared.types.worker.shards import PartitionStrategy, PipelineShardMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.utils.channels import channel
|
||||
|
||||
|
||||
@@ -46,12 +45,12 @@ async def test_master():
|
||||
|
||||
all_events: List[IndexedEvent] = []
|
||||
|
||||
async def _get_events() -> Sequence[IndexedEvent]:
|
||||
def _get_events() -> Sequence[IndexedEvent]:
|
||||
orig_events = global_event_receiver.collect()
|
||||
for e in orig_events:
|
||||
all_events.append(
|
||||
IndexedEvent(
|
||||
event=e.tagged_event.c,
|
||||
event=e.event,
|
||||
idx=len(all_events), # origin=e.origin,
|
||||
)
|
||||
)
|
||||
@@ -64,133 +63,141 @@ async def test_master():
|
||||
command_receiver=co_receiver,
|
||||
tb_only=False,
|
||||
)
|
||||
asyncio.create_task(master.run())
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject a NodePerformanceProfile event
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=sender_node_id,
|
||||
tagged_event=TaggedEvent.from_(
|
||||
NodePerformanceMeasured(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject a NodePerformanceProfile event
|
||||
logger.info("inject a NodePerformanceProfile event")
|
||||
await local_event_sender.send(
|
||||
ForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=sender_node_id,
|
||||
event=(
|
||||
NodePerformanceMeasured(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(flops_fp16=0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(flops_fp16=0),
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# wait for initial topology event
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await asyncio.sleep(0.001)
|
||||
while len(master.state.node_profiles) == 0:
|
||||
await asyncio.sleep(0.001)
|
||||
# wait for initial topology event
|
||||
logger.info("wait for initial topology event")
|
||||
while len(list(master.state.topology.list_nodes())) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
while len(master.state.node_profiles) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=node_id,
|
||||
tagged_command=TaggedCommand.from_(
|
||||
CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
while len(master.state.instances.keys()) == 0:
|
||||
await asyncio.sleep(0.001)
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=node_id,
|
||||
tagged_command=TaggedCommand.from_(
|
||||
ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hello, how are you?"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
while len(await _get_events()) < 3:
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
events = await _get_events()
|
||||
assert len(events) == 3
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
runner_id = list(events[1].event.instance.shard_assignments.runner_to_shard.keys())[
|
||||
0
|
||||
]
|
||||
assert events[1].event == InstanceCreated(
|
||||
event_id=events[1].event.event_id,
|
||||
instance=Instance(
|
||||
instance_id=events[1].event.instance.instance_id,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
partition_strategy=PartitionStrategy.pipeline,
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
n_layers=16,
|
||||
logger.info("inject a CreateInstance Command")
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=node_id,
|
||||
command=(
|
||||
CreateInstance(
|
||||
command_id=CommandId(),
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
),
|
||||
)
|
||||
)
|
||||
logger.info("wait for an instance")
|
||||
while len(master.state.instances.keys()) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
logger.info("inject a ChatCompletion Command")
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=node_id,
|
||||
command=(
|
||||
ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hello, how are you?"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
while len(_get_events()) < 3:
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
events = _get_events()
|
||||
assert len(events) == 3
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
runner_id = list(
|
||||
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
||||
)[0]
|
||||
assert events[1].event == InstanceCreated(
|
||||
event_id=events[1].event.event_id,
|
||||
instance=Instance(
|
||||
instance_id=events[1].event.instance.instance_id,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
),
|
||||
hosts=[],
|
||||
),
|
||||
hosts=[],
|
||||
),
|
||||
)
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event == TaskCreated(
|
||||
event_id=events[2].event.event_id,
|
||||
task_id=events[2].event.task_id,
|
||||
task=ChatCompletionTask(
|
||||
)
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event == TaskCreated(
|
||||
event_id=events[2].event.event_id,
|
||||
task_id=events[2].event.task_id,
|
||||
command_id=events[2].event.task.command_id,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
instance_id=events[2].event.task.instance_id,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, how are you?")
|
||||
],
|
||||
task=ChatCompletionTask(
|
||||
task_id=events[2].event.task_id,
|
||||
command_id=events[2].event.task.command_id,
|
||||
instance_id=events[2].event.task.instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hello, how are you?"
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
await master.shutdown()
|
||||
|
||||
@@ -27,7 +27,7 @@ def topology() -> Topology:
|
||||
def instance() -> Instance:
|
||||
return Instance(
|
||||
instance_id=InstanceId(),
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
|
||||
),
|
||||
|
||||
@@ -104,7 +104,7 @@ def apply_task_state_updated(event: TaskStateUpdated, state: State) -> State:
|
||||
update: dict[str, TaskStatus | None] = {
|
||||
"task_status": event.task_status,
|
||||
}
|
||||
if event.task_status != TaskStatus.FAILED:
|
||||
if event.task_status != TaskStatus.Failed:
|
||||
update["error_type"] = None
|
||||
update["error_message"] = None
|
||||
|
||||
@@ -138,7 +138,7 @@ def apply_instance_activated(event: InstanceActivated, state: State) -> State:
|
||||
return state
|
||||
|
||||
updated_instance = state.instances[event.instance_id].model_copy(
|
||||
update={"instance_type": InstanceStatus.ACTIVE}
|
||||
update={"instance_type": InstanceStatus.Active}
|
||||
)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
@@ -152,7 +152,7 @@ def apply_instance_deactivated(event: InstanceDeactivated, state: State) -> Stat
|
||||
return state
|
||||
|
||||
updated_instance = state.instances[event.instance_id].model_copy(
|
||||
update={"instance_type": InstanceStatus.INACTIVE}
|
||||
update={"instance_type": InstanceStatus.Inactive}
|
||||
)
|
||||
new_instances: Mapping[InstanceId, Instance] = {
|
||||
**state.instances,
|
||||
@@ -254,21 +254,18 @@ def apply_worker_status_updated(event: WorkerStatusUpdated, state: State) -> Sta
|
||||
|
||||
|
||||
def apply_topology_node_created(event: TopologyNodeCreated, state: State) -> State:
|
||||
logger.warning(f"~~~ APPLY Node {event.node_id} created")
|
||||
topology = copy.copy(state.topology)
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
|
||||
logger.warning(f"~~~ APPLY Edge {event.edge.local_node_id} -> {event.edge.send_back_node_id} created")
|
||||
topology = copy.copy(state.topology)
|
||||
topology.add_connection(event.edge)
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
|
||||
logger.warning(f"~~~ APPLY Edge {event.edge.local_node_id} -> {event.edge.send_back_node_id} deleted")
|
||||
topology = copy.copy(state.topology)
|
||||
if not topology.contains_connection(event.edge):
|
||||
return state
|
||||
|
||||
@@ -15,21 +15,6 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# kimi k2
|
||||
# "kimi-k2:4bit": ModelCard(
|
||||
# short_id="kimi-k2:4bit",
|
||||
# model_id="mlx-community/Kimi-K2-Instruct-4bit",
|
||||
# name="Kimi K2 (4-bit)",
|
||||
# description="""Kimi K2 is a state-of-the-art mixture-of-experts (MoE) language model with 32 billion activated parameters and 1 trillion total parameters. Trained with the Muon optimizer, Kimi K2 achieves exceptional performance across frontier knowledge, reasoning, and coding tasks while being meticulously optimized for agentic capabilities.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
# pretty_name="Kimi K2 (4-bit)",
|
||||
# storage_size=Memory.from_kb(536870912),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
|
||||
# deepseek v3
|
||||
"deepseek-v3-0324:4bit": ModelCard(
|
||||
short_id="deepseek-v3-0324:4bit",
|
||||
@@ -110,6 +95,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
model_id="mlx-community/Kimi-K2-Instruct-4bit",
|
||||
name="Kimi K2 Instruct (4-bit)",
|
||||
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
pretty_name="Kimi K2 Instruct (4-bit)",
|
||||
storage_size=Memory.from_bytes(577597603840),
|
||||
n_layers=61,
|
||||
),
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
short_id="llama-3.1-8b",
|
||||
|
||||
@@ -1,35 +1,30 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.openai_compat import FinishReason
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
class ChunkType(str, Enum):
|
||||
token = "token"
|
||||
image = "image"
|
||||
Token = "Token"
|
||||
Image = "Image"
|
||||
|
||||
|
||||
class BaseChunk[ChunkTypeT: ChunkType](BaseModel):
|
||||
chunk_type: ChunkTypeT
|
||||
class BaseChunk(TaggedModel):
|
||||
command_id: CommandId
|
||||
idx: int
|
||||
model: ModelId
|
||||
|
||||
|
||||
class TokenChunk(BaseChunk[ChunkType.token]):
|
||||
chunk_type: Literal[ChunkType.token] = Field(default=ChunkType.token, frozen=True)
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk[ChunkType.image]):
|
||||
chunk_type: Literal[ChunkType.image] = Field(default=ChunkType.image, frozen=True)
|
||||
class ImageChunk(BaseChunk):
|
||||
data: bytes
|
||||
|
||||
|
||||
GenerationChunk = Annotated[TokenChunk | ImageChunk, Field(discriminator="chunk_type")]
|
||||
GenerationChunk = TokenChunk | ImageChunk
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -7,8 +6,7 @@ from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.pydantic_tagged import Tagged, tagged_union
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
# TODO: We need to have a distinction between create instance and spin up instance.
|
||||
@@ -21,7 +19,7 @@ class CommandType(str, Enum):
|
||||
RequestEventLog = "RequestEventLog"
|
||||
|
||||
|
||||
class BaseCommand(CamelCaseModel):
|
||||
class BaseCommand(TaggedModel):
|
||||
command_id: CommandId = Field(default_factory=CommandId)
|
||||
|
||||
|
||||
@@ -49,30 +47,16 @@ class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
|
||||
Command = Union[
|
||||
RequestEventLog,
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
SpinUpInstance,
|
||||
DeleteInstance,
|
||||
TaskFinished,
|
||||
]
|
||||
|
||||
|
||||
@tagged_union(
|
||||
{
|
||||
CommandType.ChatCompletion: ChatCompletion,
|
||||
CommandType.CreateInstance: CreateInstance,
|
||||
CommandType.SpinUpInstance: SpinUpInstance,
|
||||
CommandType.DeleteInstance: DeleteInstance,
|
||||
CommandType.TaskFinished: TaskFinished,
|
||||
CommandType.RequestEventLog: RequestEventLog,
|
||||
}
|
||||
Command = (
|
||||
RequestEventLog
|
||||
| ChatCompletion
|
||||
| CreateInstance
|
||||
| SpinUpInstance
|
||||
| DeleteInstance
|
||||
| TaskFinished
|
||||
)
|
||||
class TaggedCommand(Tagged[Command]):
|
||||
pass
|
||||
|
||||
|
||||
class ForwarderCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
tagged_command: TaggedCommand
|
||||
command: Command
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Self
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, field_validator
|
||||
from pydantic import GetCoreSchemaHandler, field_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
class ID(str):
|
||||
|
||||
class Id(str):
|
||||
def __new__(cls, value: str | None = None) -> Self:
|
||||
return super().__new__(cls, value or str(uuid4()))
|
||||
|
||||
@@ -17,15 +19,15 @@ class ID(str):
|
||||
return core_schema.str_schema()
|
||||
|
||||
|
||||
class NodeId(ID):
|
||||
class NodeId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class CommandId(ID):
|
||||
class CommandId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
class Host(CamelCaseModel):
|
||||
ip: str
|
||||
port: int
|
||||
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.types.chunks import CommandId, GenerationChunk
|
||||
from exo.shared.types.common import ID, NodeId
|
||||
from exo.shared.types.common import Id, NodeId
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.common import InstanceId, WorkerStatus
|
||||
from exo.shared.types.worker.instances import Instance
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.pydantic_tagged import Tagged, tagged_union
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
class EventId(ID):
|
||||
class EventId(Id):
|
||||
"""
|
||||
Newtype around `ID`
|
||||
"""
|
||||
@@ -60,7 +58,7 @@ class EventType(str, Enum):
|
||||
TopologyEdgeDeleted = "TopologyEdgeDeleted"
|
||||
|
||||
|
||||
class BaseEvent(CamelCaseModel):
|
||||
class BaseEvent(TaggedModel):
|
||||
event_id: EventId = Field(default_factory=EventId)
|
||||
|
||||
|
||||
@@ -145,52 +143,26 @@ class TopologyEdgeDeleted(BaseEvent):
|
||||
edge: Connection
|
||||
|
||||
|
||||
Event = Union[
|
||||
TestEvent,
|
||||
TaskCreated,
|
||||
TaskStateUpdated,
|
||||
TaskFailed,
|
||||
TaskDeleted,
|
||||
InstanceCreated,
|
||||
InstanceActivated,
|
||||
InstanceDeactivated,
|
||||
InstanceDeleted,
|
||||
RunnerStatusUpdated,
|
||||
RunnerDeleted,
|
||||
NodePerformanceMeasured,
|
||||
NodeMemoryMeasured,
|
||||
WorkerStatusUpdated,
|
||||
ChunkGenerated,
|
||||
TopologyNodeCreated,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
]
|
||||
|
||||
|
||||
@tagged_union(
|
||||
{
|
||||
EventType.TestEvent: TestEvent,
|
||||
EventType.TaskCreated: TaskCreated,
|
||||
EventType.TaskStateUpdated: TaskStateUpdated,
|
||||
EventType.TaskFailed: TaskFailed,
|
||||
EventType.TaskDeleted: TaskDeleted,
|
||||
EventType.InstanceCreated: InstanceCreated,
|
||||
EventType.InstanceActivated: InstanceActivated,
|
||||
EventType.InstanceDeactivated: InstanceDeactivated,
|
||||
EventType.InstanceDeleted: InstanceDeleted,
|
||||
EventType.RunnerStatusUpdated: RunnerStatusUpdated,
|
||||
EventType.RunnerDeleted: RunnerDeleted,
|
||||
EventType.NodePerformanceMeasured: NodePerformanceMeasured,
|
||||
EventType.NodeMemoryMeasured: NodeMemoryMeasured,
|
||||
EventType.WorkerStatusUpdated: WorkerStatusUpdated,
|
||||
EventType.ChunkGenerated: ChunkGenerated,
|
||||
EventType.TopologyNodeCreated: TopologyNodeCreated,
|
||||
EventType.TopologyEdgeCreated: TopologyEdgeCreated,
|
||||
EventType.TopologyEdgeDeleted: TopologyEdgeDeleted,
|
||||
}
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
| TaskStateUpdated
|
||||
| TaskFailed
|
||||
| TaskDeleted
|
||||
| InstanceCreated
|
||||
| InstanceActivated
|
||||
| InstanceDeactivated
|
||||
| InstanceDeleted
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodePerformanceMeasured
|
||||
| NodeMemoryMeasured
|
||||
| WorkerStatusUpdated
|
||||
| ChunkGenerated
|
||||
| TopologyNodeCreated
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
)
|
||||
class TaggedEvent(Tagged[Event]):
|
||||
pass
|
||||
|
||||
|
||||
class IndexedEvent(CamelCaseModel):
|
||||
@@ -205,4 +177,4 @@ class ForwarderEvent(CamelCaseModel):
|
||||
|
||||
origin_idx: int = Field(ge=0)
|
||||
origin: NodeId
|
||||
tagged_event: TaggedEvent
|
||||
event: Event
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.shared.types.common import ID
|
||||
from exo.shared.types.common import Id
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class ModelId(ID):
|
||||
class ModelId(Id):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -1,25 +1,19 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import ConfigDict, Field, field_validator, field_serializer
|
||||
|
||||
from exo.shared.topology import Topology, TopologySnapshot
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.common import InstanceId, WorkerStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgressData
|
||||
from exo.shared.types.worker.instances import Instance
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
def _encode_topology(topo: "Topology") -> dict[str, Any]: # noqa: D401
|
||||
"""Serialise *topo* into a JSON-compatible dict."""
|
||||
|
||||
return topo.to_snapshot().model_dump()
|
||||
|
||||
|
||||
class State(BaseModel):
|
||||
class State(CamelCaseModel):
|
||||
"""Global system state.
|
||||
|
||||
The :class:`Topology` instance is encoded/decoded via an immutable
|
||||
@@ -29,9 +23,6 @@ class State(BaseModel):
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
json_encoders={
|
||||
Topology: _encode_topology,
|
||||
},
|
||||
)
|
||||
node_status: Mapping[NodeId, WorkerStatus] = {}
|
||||
instances: Mapping[InstanceId, Instance] = {}
|
||||
@@ -40,10 +31,12 @@ class State(BaseModel):
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
|
||||
topology: Topology = Topology()
|
||||
history: Sequence[Topology] = []
|
||||
# TODO: we want information about every model that is downloaded on each node
|
||||
node_downloads: Mapping[NodeId, Mapping[str, DownloadProgressData]] = {}
|
||||
last_event_applied_idx: int = Field(default=-1, ge=-1)
|
||||
|
||||
@field_serializer("topology", mode="plain")
|
||||
def _encode_topology(self, value: Topology) -> TopologySnapshot:
|
||||
return value.to_snapshot()
|
||||
|
||||
@field_validator("topology", mode="before")
|
||||
@classmethod
|
||||
def _deserialize_topology(cls, value: object) -> Topology: # noqa: D401 – Pydantic validator signature
|
||||
|
||||
@@ -1,30 +1,25 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.common import ID, CommandId
|
||||
from exo.shared.types.common import CommandId, Id
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
class TaskId(ID):
|
||||
class TaskId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
CHAT_COMPLETION = "CHAT_COMPLETION"
|
||||
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "PENDING"
|
||||
RUNNING = "RUNNING"
|
||||
COMPLETE = "COMPLETE"
|
||||
FAILED = "FAILED"
|
||||
Pending = "Pending"
|
||||
Running = "Running"
|
||||
Complete = "Complete"
|
||||
Failed = "Failed"
|
||||
|
||||
|
||||
class ChatCompletionTask(BaseModel):
|
||||
task_type: Literal[TaskType.CHAT_COMPLETION] = TaskType.CHAT_COMPLETION
|
||||
class ChatCompletionTask(TaggedModel):
|
||||
task_id: TaskId
|
||||
command_id: CommandId
|
||||
instance_id: InstanceId
|
||||
@@ -35,4 +30,4 @@ class ChatCompletionTask(BaseModel):
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
Task = Annotated[ChatCompletionTask, Field(discriminator="task_type")]
|
||||
Task = ChatCompletionTask
|
||||
|
||||
@@ -1,116 +1,69 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from exo.shared.openai_compat import FinishReason
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
## Messages passed TO the runner
|
||||
class MessageType(str, Enum):
|
||||
Setup = "setup"
|
||||
ChatTask = "chat_task"
|
||||
Exit = "exit"
|
||||
|
||||
|
||||
class BaseRunnerMessage[MT: MessageType](BaseModel):
|
||||
class BaseRunnerMessage(TaggedModel):
|
||||
pass
|
||||
|
||||
|
||||
class SetupMessage(BaseRunnerMessage[MessageType.Setup]):
|
||||
type: Literal[MessageType.Setup] = Field(default=MessageType.Setup, frozen=True)
|
||||
class SetupMessage(BaseRunnerMessage):
|
||||
model_shard_meta: ShardMetadata
|
||||
hosts: list[Host]
|
||||
|
||||
|
||||
# TODO: We probably want a general task message that can take any task type. Can be fixed later.
|
||||
class ChatTaskMessage(BaseRunnerMessage[MessageType.ChatTask]):
|
||||
type: Literal[MessageType.ChatTask] = Field(
|
||||
default=MessageType.ChatTask, frozen=True
|
||||
)
|
||||
class ChatTaskMessage(BaseRunnerMessage):
|
||||
task_data: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class ExitMessage(BaseRunnerMessage[MessageType.Exit]):
|
||||
type: Literal[MessageType.Exit] = Field(default=MessageType.Exit, frozen=True)
|
||||
|
||||
|
||||
RunnerMessage = Annotated[
|
||||
SetupMessage | ChatTaskMessage | ExitMessage, Field(discriminator="type")
|
||||
]
|
||||
RunnerMessageTypeAdapter: TypeAdapter[RunnerMessage] = TypeAdapter(RunnerMessage)
|
||||
|
||||
|
||||
## Responses passed FROM the runner
|
||||
class RunnerResponseType(str, Enum):
|
||||
InitializedResponse = "initialized_response"
|
||||
TokenizedResponse = "tokenized_response"
|
||||
GenerationResponse = "generation_response"
|
||||
FinishedResponse = "finished_response"
|
||||
PrintResponse = "print_response"
|
||||
ErrorResponse = "error_response"
|
||||
|
||||
|
||||
class BaseRunnerResponse[RRT: RunnerResponseType](BaseModel):
|
||||
class ExitMessage(BaseRunnerMessage):
|
||||
pass
|
||||
|
||||
|
||||
class InitializedResponse(BaseRunnerResponse[RunnerResponseType.InitializedResponse]):
|
||||
type: Literal[RunnerResponseType.InitializedResponse] = Field(
|
||||
default=RunnerResponseType.InitializedResponse, frozen=True
|
||||
)
|
||||
RunnerMessage = SetupMessage | ChatTaskMessage | ExitMessage
|
||||
|
||||
|
||||
class BaseRunnerResponse(TaggedModel):
|
||||
pass
|
||||
|
||||
|
||||
class InitializedResponse(BaseRunnerResponse):
|
||||
time_taken: float
|
||||
|
||||
|
||||
class TokenizedResponse(BaseRunnerResponse[RunnerResponseType.TokenizedResponse]):
|
||||
type: Literal[RunnerResponseType.TokenizedResponse] = Field(
|
||||
default=RunnerResponseType.TokenizedResponse, frozen=True
|
||||
)
|
||||
class TokenizedResponse(BaseRunnerResponse):
|
||||
prompt_tokens: int
|
||||
|
||||
|
||||
class GenerationResponse(BaseRunnerResponse[RunnerResponseType.GenerationResponse]):
|
||||
type: Literal[RunnerResponseType.GenerationResponse] = Field(
|
||||
default=RunnerResponseType.GenerationResponse, frozen=True
|
||||
)
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
# logprobs: Optional[list[float]] = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class PrintResponse(BaseRunnerResponse[RunnerResponseType.PrintResponse]):
|
||||
type: Literal[RunnerResponseType.PrintResponse] = Field(
|
||||
default=RunnerResponseType.PrintResponse, frozen=True
|
||||
)
|
||||
class PrintResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse[RunnerResponseType.FinishedResponse]):
|
||||
type: Literal[RunnerResponseType.FinishedResponse] = Field(
|
||||
default=RunnerResponseType.FinishedResponse, frozen=True
|
||||
)
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
|
||||
class ErrorResponse(BaseRunnerResponse[RunnerResponseType.ErrorResponse]):
|
||||
type: Literal[RunnerResponseType.ErrorResponse] = Field(
|
||||
default=RunnerResponseType.ErrorResponse, frozen=True
|
||||
)
|
||||
class ErrorResponse(BaseRunnerResponse):
|
||||
error_type: str
|
||||
error_message: str
|
||||
traceback: str
|
||||
|
||||
|
||||
RunnerResponse = Annotated[
|
||||
RunnerResponse = (
|
||||
InitializedResponse
|
||||
| TokenizedResponse
|
||||
| GenerationResponse
|
||||
| PrintResponse
|
||||
| FinishedResponse
|
||||
| ErrorResponse,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
RunnerResponseTypeAdapter: TypeAdapter[RunnerResponse] = TypeAdapter(RunnerResponse)
|
||||
| ErrorResponse
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.common import ID
|
||||
from exo.shared.types.common import Id
|
||||
|
||||
|
||||
class InstanceId(ID):
|
||||
class InstanceId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerId(ID):
|
||||
class RunnerId(Id):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from exo.shared.types.worker.commands_runner import (
|
||||
PrintResponse,
|
||||
RunnerMessage,
|
||||
RunnerResponse,
|
||||
RunnerResponseType,
|
||||
)
|
||||
|
||||
### Utils - Runner Prints
|
||||
@@ -17,7 +16,6 @@ from exo.shared.types.worker.commands_runner import (
|
||||
|
||||
def runner_print(text: str) -> None:
|
||||
obj = PrintResponse(
|
||||
type=RunnerResponseType.PrintResponse,
|
||||
text=text,
|
||||
)
|
||||
|
||||
@@ -27,7 +25,6 @@ def runner_print(text: str) -> None:
|
||||
|
||||
def runner_write_error(error: Exception) -> None:
|
||||
error_response: ErrorResponse = ErrorResponse(
|
||||
type=RunnerResponseType.ErrorResponse,
|
||||
error_type=type(error).__name__,
|
||||
error_message=str(error),
|
||||
traceback=traceback.format_exc(),
|
||||
|
||||
@@ -1,73 +1,33 @@
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Literal,
|
||||
Union,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
class DownloadProgressData(CamelCaseModel):
|
||||
total_bytes: Memory
|
||||
downloaded_bytes: Memory
|
||||
downloaded_bytes_this_session: Memory
|
||||
|
||||
completed_files: int
|
||||
total_files: int
|
||||
|
||||
speed: float
|
||||
eta_ms: int
|
||||
|
||||
files: dict[str, "DownloadProgressData"]
|
||||
|
||||
class DownloadStatus(str, Enum):
|
||||
Pending = "Pending"
|
||||
Downloading = "Downloading"
|
||||
Completed = "Completed"
|
||||
Failed = "Failed"
|
||||
|
||||
|
||||
class BaseDownloadProgress[DownloadStatusT: DownloadStatus](CamelCaseModel):
|
||||
class BaseDownloadProgress(TaggedModel):
|
||||
node_id: NodeId
|
||||
download_status: DownloadStatusT
|
||||
|
||||
|
||||
class DownloadPending(BaseDownloadProgress[DownloadStatus.Pending]):
|
||||
download_status: Literal[DownloadStatus.Pending] = Field(
|
||||
default=DownloadStatus.Pending
|
||||
)
|
||||
class DownloadPending(BaseDownloadProgress):
|
||||
pass
|
||||
|
||||
|
||||
class DownloadCompleted(BaseDownloadProgress[DownloadStatus.Completed]):
|
||||
download_status: Literal[DownloadStatus.Completed] = Field(
|
||||
default=DownloadStatus.Completed
|
||||
)
|
||||
class DownloadCompleted(BaseDownloadProgress):
|
||||
pass
|
||||
|
||||
|
||||
class DownloadFailed(BaseDownloadProgress[DownloadStatus.Failed]):
|
||||
download_status: Literal[DownloadStatus.Failed] = Field(
|
||||
default=DownloadStatus.Failed
|
||||
)
|
||||
class DownloadFailed(BaseDownloadProgress):
|
||||
error_message: str
|
||||
|
||||
|
||||
class DownloadOngoing(BaseDownloadProgress[DownloadStatus.Downloading]):
|
||||
download_status: Literal[DownloadStatus.Downloading] = Field(
|
||||
default=DownloadStatus.Downloading
|
||||
)
|
||||
class DownloadOngoing(BaseDownloadProgress):
|
||||
download_progress: DownloadProgressData
|
||||
|
||||
|
||||
DownloadProgress = Annotated[
|
||||
Union[
|
||||
DownloadPending,
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
],
|
||||
Field(discriminator="download_status"),
|
||||
]
|
||||
DownloadProgress = (
|
||||
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
|
||||
)
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.runners import (
|
||||
ShardAssignments,
|
||||
)
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class InstanceStatus(str, Enum):
|
||||
ACTIVE = "ACTIVE"
|
||||
INACTIVE = "INACTIVE"
|
||||
Active = "Active"
|
||||
Inactive = "Inactive"
|
||||
|
||||
|
||||
class Instance(BaseModel):
|
||||
class Instance(CamelCaseModel):
|
||||
instance_id: InstanceId
|
||||
instance_type: InstanceStatus
|
||||
shard_assignments: ShardAssignments
|
||||
|
||||
@@ -1,86 +1,49 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Generic, Literal, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.events import InstanceId
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.common import RunnerId
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
class RunnerOpType(str, Enum):
|
||||
ASSIGN_RUNNER = "assign_runner"
|
||||
UNASSIGN_RUNNER = "unassign_runner"
|
||||
RUNNER_UP = "runner_up"
|
||||
RUNNER_DOWN = "runner_down"
|
||||
RUNNER_FAILED = "runner_failed"
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
class BaseRunnerOp(TaggedModel):
|
||||
pass
|
||||
|
||||
|
||||
RunnerOpT = TypeVar("RunnerOpT", bound=RunnerOpType)
|
||||
|
||||
|
||||
class BaseRunnerOp(BaseModel, Generic[RunnerOpT]):
|
||||
op_type: RunnerOpT
|
||||
|
||||
|
||||
class AssignRunnerOp(BaseRunnerOp[Literal[RunnerOpType.ASSIGN_RUNNER]]):
|
||||
op_type: Literal[RunnerOpType.ASSIGN_RUNNER] = Field(
|
||||
default=RunnerOpType.ASSIGN_RUNNER, frozen=True
|
||||
)
|
||||
class AssignRunnerOp(BaseRunnerOp):
|
||||
instance_id: InstanceId
|
||||
runner_id: RunnerId
|
||||
shard_metadata: ShardMetadata
|
||||
hosts: list[Host]
|
||||
|
||||
|
||||
class UnassignRunnerOp(BaseRunnerOp[Literal[RunnerOpType.UNASSIGN_RUNNER]]):
|
||||
op_type: Literal[RunnerOpType.UNASSIGN_RUNNER] = Field(
|
||||
default=RunnerOpType.UNASSIGN_RUNNER, frozen=True
|
||||
)
|
||||
class UnassignRunnerOp(BaseRunnerOp):
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class RunnerUpOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_UP]]):
|
||||
op_type: Literal[RunnerOpType.RUNNER_UP] = Field(
|
||||
default=RunnerOpType.RUNNER_UP, frozen=True
|
||||
)
|
||||
class RunnerUpOp(BaseRunnerOp):
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class RunnerDownOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_DOWN]]):
|
||||
op_type: Literal[RunnerOpType.RUNNER_DOWN] = Field(
|
||||
default=RunnerOpType.RUNNER_DOWN, frozen=True
|
||||
)
|
||||
class RunnerDownOp(BaseRunnerOp):
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class RunnerFailedOp(BaseRunnerOp[Literal[RunnerOpType.RUNNER_FAILED]]):
|
||||
op_type: Literal[RunnerOpType.RUNNER_FAILED] = Field(
|
||||
default=RunnerOpType.RUNNER_FAILED, frozen=True
|
||||
)
|
||||
class RunnerFailedOp(BaseRunnerOp):
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
class ExecuteTaskOp(BaseRunnerOp[Literal[RunnerOpType.CHAT_COMPLETION]]):
|
||||
op_type: Literal[RunnerOpType.CHAT_COMPLETION] = Field(
|
||||
default=RunnerOpType.CHAT_COMPLETION, frozen=True
|
||||
)
|
||||
class ExecuteTaskOp(BaseRunnerOp):
|
||||
runner_id: RunnerId
|
||||
task: Task
|
||||
|
||||
|
||||
# Aggregate all runner operations into a single, strictly-typed union for dispatching.
|
||||
RunnerOp = Annotated[
|
||||
Union[
|
||||
AssignRunnerOp,
|
||||
UnassignRunnerOp,
|
||||
RunnerUpOp,
|
||||
RunnerDownOp,
|
||||
RunnerFailedOp,
|
||||
ExecuteTaskOp,
|
||||
],
|
||||
Field(discriminator="op_type"),
|
||||
]
|
||||
RunnerOp = (
|
||||
AssignRunnerOp
|
||||
| UnassignRunnerOp
|
||||
| RunnerUpOp
|
||||
| RunnerDownOp
|
||||
| RunnerFailedOp
|
||||
| ExecuteTaskOp
|
||||
)
|
||||
|
||||
@@ -1,80 +1,54 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.worker.common import RunnerId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
class RunnerStatusType(str, Enum):
|
||||
Downloading = "Downloading"
|
||||
Inactive = "Inactive"
|
||||
Starting = "Starting"
|
||||
Loaded = "Loaded"
|
||||
Running = "Running"
|
||||
Failed = "Failed"
|
||||
class BaseRunnerStatus(TaggedModel):
|
||||
pass
|
||||
|
||||
|
||||
class BaseRunnerStatus[T: RunnerStatusType](BaseModel):
|
||||
runner_status: T
|
||||
|
||||
|
||||
class DownloadingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Downloading]):
|
||||
runner_status: Literal[RunnerStatusType.Downloading] = Field(
|
||||
default=RunnerStatusType.Downloading
|
||||
)
|
||||
class DownloadingRunnerStatus(BaseRunnerStatus):
|
||||
download_progress: DownloadProgress
|
||||
|
||||
|
||||
class InactiveRunnerStatus(BaseRunnerStatus[RunnerStatusType.Inactive]):
|
||||
runner_status: Literal[RunnerStatusType.Inactive] = Field(
|
||||
default=RunnerStatusType.Inactive
|
||||
)
|
||||
class InactiveRunnerStatus(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class StartingRunnerStatus(BaseRunnerStatus[RunnerStatusType.Starting]):
|
||||
runner_status: Literal[RunnerStatusType.Starting] = Field(
|
||||
default=RunnerStatusType.Starting
|
||||
)
|
||||
class StartingRunnerStatus(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class LoadedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Loaded]):
|
||||
runner_status: Literal[RunnerStatusType.Loaded] = Field(
|
||||
default=RunnerStatusType.Loaded
|
||||
)
|
||||
class LoadedRunnerStatus(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunningRunnerStatus(BaseRunnerStatus[RunnerStatusType.Running]):
|
||||
runner_status: Literal[RunnerStatusType.Running] = Field(
|
||||
default=RunnerStatusType.Running
|
||||
)
|
||||
class RunningRunnerStatus(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class FailedRunnerStatus(BaseRunnerStatus[RunnerStatusType.Failed]):
|
||||
runner_status: Literal[RunnerStatusType.Failed] = Field(
|
||||
default=RunnerStatusType.Failed
|
||||
)
|
||||
class FailedRunnerStatus(BaseRunnerStatus):
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
RunnerStatus = Annotated[
|
||||
RunnerStatus = (
|
||||
DownloadingRunnerStatus
|
||||
| InactiveRunnerStatus
|
||||
| StartingRunnerStatus
|
||||
| LoadedRunnerStatus
|
||||
| RunningRunnerStatus
|
||||
| FailedRunnerStatus,
|
||||
Field,
|
||||
]
|
||||
RunnerStatusParser: TypeAdapter[RunnerStatus] = TypeAdapter(RunnerStatus)
|
||||
| FailedRunnerStatus
|
||||
)
|
||||
|
||||
|
||||
class ShardAssignments(BaseModel):
|
||||
class ShardAssignments(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
runner_to_shard: Mapping[RunnerId, ShardMetadata]
|
||||
node_to_runner: Mapping[NodeId, RunnerId]
|
||||
|
||||
@@ -1,39 +1,26 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Generic, Literal, Optional, TypeVar
|
||||
from pydantic import Field
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
class PartitionStrategy(str, Enum):
|
||||
pipeline = "pipeline"
|
||||
|
||||
|
||||
PartitionStrategyT = TypeVar(
|
||||
"PartitionStrategyT", bound=PartitionStrategy, covariant=True
|
||||
)
|
||||
|
||||
|
||||
class BaseShardMetadata(BaseModel, Generic[PartitionStrategyT]):
|
||||
class BaseShardMetadata(TaggedModel):
|
||||
"""
|
||||
Defines a specific shard of the model that is ready to be run on a device.
|
||||
Replaces previous `Shard` object.
|
||||
"""
|
||||
|
||||
model_meta: ModelMetadata
|
||||
partition_strategy: PartitionStrategyT
|
||||
device_rank: int
|
||||
world_size: int
|
||||
|
||||
# Error handling; equivalent to monkey-patch, but we can't monkey-patch runner.py
|
||||
# This is kinda annoying because it allocates memory in the ShardMetadata object. Can be rethought after Shanghai.
|
||||
immediate_exception: bool = False
|
||||
should_timeout: Optional[float] = None
|
||||
should_timeout: float | None = None
|
||||
|
||||
|
||||
class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline]]):
|
||||
class PipelineShardMetadata(BaseShardMetadata):
|
||||
"""
|
||||
Pipeline parallelism shard meta.
|
||||
|
||||
@@ -41,12 +28,9 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
|
||||
where start_layer is inclusive and end_layer is exclusive.
|
||||
"""
|
||||
|
||||
partition_strategy: Literal[PartitionStrategy.pipeline] = Field(
|
||||
default=PartitionStrategy.pipeline, frozen=True
|
||||
)
|
||||
start_layer: Annotated[int, Field(ge=0)]
|
||||
end_layer: Annotated[int, Field(ge=0)]
|
||||
n_layers: Annotated[int, Field(ge=0)]
|
||||
start_layer: int = Field(ge=0)
|
||||
end_layer: int = Field(ge=0)
|
||||
n_layers: int = Field(ge=0)
|
||||
|
||||
@property
|
||||
def is_first_layer(self) -> bool:
|
||||
@@ -62,17 +46,4 @@ class PipelineShardMetadata(BaseShardMetadata[Literal[PartitionStrategy.pipeline
|
||||
)
|
||||
|
||||
|
||||
ShardMetadata = Annotated[
|
||||
PipelineShardMetadata, Field(discriminator="partition_strategy")
|
||||
]
|
||||
ShardMetadataParser: TypeAdapter[ShardMetadata] = TypeAdapter(ShardMetadata)
|
||||
|
||||
|
||||
class ShardPlacement(BaseModel, Generic[PartitionStrategyT]):
|
||||
"""
|
||||
A shard placement is the description of a model distributed across a set of nodes.
|
||||
The Generic[PartitionStrategyT] enforces that the shard assignments all use the same partition strategy.
|
||||
"""
|
||||
|
||||
model_id: ModelId
|
||||
shard_assignments: dict[NodeId, BaseShardMetadata[PartitionStrategyT]]
|
||||
ShardMetadata = PipelineShardMetadata
|
||||
|
||||
@@ -9,9 +9,9 @@ class OrderedBuffer[T]:
|
||||
source at a time.
|
||||
"""
|
||||
|
||||
def __init__(self, start_idx: int = 0):
|
||||
def __init__(self):
|
||||
self.store: dict[int, T] = {}
|
||||
self.next_idx_to_release: int = start_idx
|
||||
self.next_idx_to_release: int = 0
|
||||
|
||||
def ingest(self, idx: int, t: T):
|
||||
"""Ingest a sequence into the buffer"""
|
||||
@@ -19,6 +19,9 @@ class OrderedBuffer[T]:
|
||||
if idx < self.next_idx_to_release:
|
||||
return
|
||||
if idx in self.store:
|
||||
assert self.store[idx] == t, (
|
||||
"Received different messages with identical indices, probable race condition"
|
||||
)
|
||||
return
|
||||
self.store[idx] = t
|
||||
|
||||
@@ -56,15 +59,8 @@ class MultiSourceBuffer[SourceId, T]:
|
||||
|
||||
def ingest(self, idx: int, t: T, source: SourceId):
|
||||
if source not in self.stores:
|
||||
# Seed the per-source buffer to start at the first observed index for that source.
|
||||
self.stores[source] = OrderedBuffer(start_idx=idx)
|
||||
self.stores[source] = OrderedBuffer()
|
||||
buffer = self.stores[source]
|
||||
# Handle per-source sequence reset (e.g., worker restart resetting its local index to 0).
|
||||
# If we observe idx == 0 from an existing source with a higher expected index,
|
||||
# reset that source's buffer to accept the new sequence.
|
||||
if idx == 0 and buffer.next_idx_to_release > 0:
|
||||
self.stores[source] = OrderedBuffer(start_idx=0)
|
||||
buffer = self.stores[source]
|
||||
buffer.ingest(idx, t)
|
||||
|
||||
def drain(self) -> list[T]:
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
# pyright: reportAny=false, reportUnknownArgumentType=false, reportUnknownVariableType=false
|
||||
|
||||
from typing import Any, Self
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_serializer, model_validator
|
||||
from pydantic.alias_generators import to_camel
|
||||
from pydantic_core.core_schema import (
|
||||
SerializerFunctionWrapHandler,
|
||||
ValidatorFunctionWrapHandler,
|
||||
)
|
||||
|
||||
|
||||
class CamelCaseModel(BaseModel):
|
||||
@@ -12,5 +20,20 @@ class CamelCaseModel(BaseModel):
|
||||
validate_by_name=True,
|
||||
extra="forbid",
|
||||
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
|
||||
# strict=True,
|
||||
strict=True,
|
||||
)
|
||||
|
||||
|
||||
class TaggedModel(CamelCaseModel):
|
||||
@model_serializer(mode="wrap")
|
||||
def _serialize(self, handler: SerializerFunctionWrapHandler):
|
||||
inner = handler(self)
|
||||
return {self.__class__.__name__: inner}
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def _validate(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> Self:
|
||||
if isinstance(v, dict) and len(v) == 1 and cls.__name__ in v:
|
||||
return handler(v[cls.__name__])
|
||||
|
||||
return handler(v)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
|
||||
from exo.utils.pydantic_tagged import Tagged, tagged_union # ← CHANGE ME
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
def test_plain_union_prefers_first_member_when_shapes_are_identical():
|
||||
@@ -22,161 +21,230 @@ def test_plain_union_prefers_first_member_when_shapes_are_identical():
|
||||
|
||||
|
||||
def test_tagged_union_serializes_and_deserializes_two_identical_shapes_correctly():
|
||||
class Foo1(BaseModel):
|
||||
class Foo1(TaggedModel):
|
||||
x: int
|
||||
|
||||
class Foo2(BaseModel):
|
||||
class Foo2(TaggedModel):
|
||||
x: int
|
||||
|
||||
foos = Union[Foo1, Foo2]
|
||||
t1 = Foo1(x=1)
|
||||
assert t1.model_dump() == {"Foo1": {"x": 1}}
|
||||
|
||||
@tagged_union({"Foo1": Foo1, "Foo2": Foo2})
|
||||
class TaggedFoos(Tagged[foos]):
|
||||
pass
|
||||
|
||||
# ---- serialize (via custom model_serializer) ----
|
||||
t1 = TaggedFoos.from_(Foo1(x=1))
|
||||
assert t1.model_dump() == {"t": "Foo1", "c": {"x": 1}}
|
||||
|
||||
t2 = TaggedFoos.from_(Foo2(x=2))
|
||||
assert t2.model_dump() == {"t": "Foo2", "c": {"x": 2}}
|
||||
t2 = Foo2(x=2)
|
||||
assert t2.model_dump() == {"Foo2": {"x": 2}}
|
||||
|
||||
# ---- deserialize (TypeAdapter -> model_validator(before)) ----
|
||||
ta = TypeAdapter(TaggedFoos)
|
||||
ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)
|
||||
|
||||
out1 = ta.validate_python({"t": "Foo1", "c": {"x": 10}})
|
||||
assert isinstance(out1.c, Foo1) and out1.c.x == 10
|
||||
out1 = ta.validate_python({"Foo1": {"x": 10}})
|
||||
assert isinstance(out1, Foo1) and out1.x == 10
|
||||
|
||||
out2 = ta.validate_python({"t": "Foo2", "c": {"x": 20}})
|
||||
assert isinstance(out2.c, Foo2) and out2.c.x == 20
|
||||
out2 = ta.validate_python({"Foo2": {"x": 20}})
|
||||
assert isinstance(out2, Foo2) and out2.x == 20
|
||||
|
||||
|
||||
def test_tagged_union_rejects_unknown_tag():
|
||||
class Foo1(BaseModel):
|
||||
class Foo1(TaggedModel):
|
||||
x: int
|
||||
|
||||
class Foo2(BaseModel):
|
||||
class Foo2(TaggedModel):
|
||||
x: int
|
||||
|
||||
foos = Union[Foo1, Foo2]
|
||||
|
||||
@tagged_union({"Foo1": Foo1, "Foo2": Foo2})
|
||||
class TaggedFoos(Tagged[foos]):
|
||||
pass
|
||||
|
||||
ta = TypeAdapter(TaggedFoos)
|
||||
ta = TypeAdapter[Foo1 | Foo2](Foo1 | Foo2)
|
||||
with pytest.raises(ValidationError):
|
||||
ta.validate_python({"t": "NotARealTag", "c": {"x": 0}})
|
||||
|
||||
|
||||
def test_multiple_tagged_classes_do_not_override_each_others_mappings():
|
||||
"""
|
||||
Creating a *new* Tagged[T] class must not mutate the previously defined one.
|
||||
This checks both the tag mapping and the per-class adapter dicts.
|
||||
"""
|
||||
|
||||
class Foo1(BaseModel):
|
||||
x: int
|
||||
|
||||
class Foo2(BaseModel):
|
||||
x: int
|
||||
|
||||
foos = Union[Foo1, Foo2]
|
||||
|
||||
@tagged_union({"One": Foo1, "Two": Foo2})
|
||||
class TaggedEN(Tagged[foos]):
|
||||
pass
|
||||
|
||||
# Sanity: initial mapping/behavior
|
||||
obj_en_1 = TaggedEN.from_(Foo1(x=5))
|
||||
assert obj_en_1.t == "One"
|
||||
obj_en_2 = TaggedEN.from_(Foo2(x=6))
|
||||
assert obj_en_2.t == "Two"
|
||||
|
||||
# Define a second, different mapping
|
||||
@tagged_union({"Uno": Foo1, "Dos": Foo2})
|
||||
class TaggedES(Tagged[foos]):
|
||||
pass
|
||||
|
||||
# The two classes should have *independent* mappings
|
||||
# (not the same object, and not equal content)
|
||||
assert TaggedEN._type_bidict is not TaggedES._type_bidict # pyright: ignore
|
||||
assert TaggedEN._type_bidict != TaggedES._type_bidict # pyright: ignore
|
||||
|
||||
# Their adapters dicts should also be distinct objects
|
||||
assert TaggedEN._adapter_dict is not TaggedES._adapter_dict # pyright: ignore
|
||||
# And both should cover the same set of member types
|
||||
assert set(TaggedEN._adapter_dict.keys()) == {Foo1, Foo2} # pyright: ignore
|
||||
assert set(TaggedES._adapter_dict.keys()) == {Foo1, Foo2} # pyright: ignore
|
||||
|
||||
# Re-check that EN behavior has NOT changed after ES was created
|
||||
obj_en_1_again = TaggedEN.from_(Foo1(x=7))
|
||||
obj_en_2_again = TaggedEN.from_(Foo2(x=8))
|
||||
assert obj_en_1_again.t == "One"
|
||||
assert obj_en_2_again.t == "Two"
|
||||
|
||||
# ES behavior is per its *own* mapping
|
||||
obj_es_1 = TaggedES.from_(Foo1(x=9))
|
||||
obj_es_2 = TaggedES.from_(Foo2(x=10))
|
||||
assert obj_es_1.t == "Uno"
|
||||
assert obj_es_2.t == "Dos"
|
||||
|
||||
# And deserialization respects each class's mapping independently
|
||||
ta_en = TypeAdapter(TaggedEN)
|
||||
ta_es = TypeAdapter(TaggedES)
|
||||
|
||||
out_en = ta_en.validate_python({"t": "Two", "c": {"x": 123}})
|
||||
assert isinstance(out_en.c, Foo2) and out_en.c.x == 123
|
||||
|
||||
out_es = ta_es.validate_python({"t": "Dos", "c": {"x": 456}})
|
||||
assert isinstance(out_es.c, Foo2) and out_es.c.x == 456
|
||||
ta.validate_python({"NotARealTag": {"x": 0}})
|
||||
|
||||
|
||||
def test_two_tagged_classes_with_different_shapes_are_independent_and_not_cross_deserializable():
|
||||
class A1(BaseModel):
|
||||
class A1(TaggedModel):
|
||||
x: int
|
||||
|
||||
class A2(BaseModel):
|
||||
class A2(TaggedModel):
|
||||
name: str
|
||||
|
||||
union_a = Union[A1, A2]
|
||||
|
||||
@tagged_union({"One": A1, "Two": A2})
|
||||
class TaggedA(Tagged[union_a]):
|
||||
pass
|
||||
|
||||
class B1(BaseModel):
|
||||
class B1(TaggedModel):
|
||||
name: str
|
||||
|
||||
class B2(BaseModel):
|
||||
class B2(TaggedModel):
|
||||
active: bool
|
||||
|
||||
union_b = Union[B1, B2]
|
||||
a_payload = A1(x=123).model_dump()
|
||||
b_payload = B1(name="neo").model_dump()
|
||||
|
||||
# Note: using the SAME tag strings intentionally to ensure mappings are per-class
|
||||
@tagged_union({"One": B1, "Two": B2})
|
||||
class TaggedB(Tagged[union_b]):
|
||||
pass
|
||||
assert a_payload == {"A1": {"x": 123}}
|
||||
assert b_payload == {"B1": {"name": "neo"}}
|
||||
|
||||
# --- Per-class state must be independent ---
|
||||
assert TaggedA._type_bidict is not TaggedB._type_bidict # pyright: ignore
|
||||
assert TaggedA._adapter_dict is not TaggedB._adapter_dict # pyright: ignore
|
||||
assert set(TaggedA._adapter_dict.keys()) == {A1, A2} # pyright: ignore
|
||||
assert set(TaggedB._adapter_dict.keys()) == {B1, B2} # pyright: ignore
|
||||
|
||||
# --- Round-trip for each class with overlapping tag strings ---
|
||||
a_payload = TaggedA.from_(A1(x=123)).model_dump()
|
||||
b_payload = TaggedB.from_(B1(name="neo")).model_dump()
|
||||
|
||||
assert a_payload == {"t": "One", "c": {"x": 123}}
|
||||
assert b_payload == {"t": "One", "c": {"name": "neo"}}
|
||||
|
||||
# --- Cross-deserialization must fail despite overlapping "t" values ---
|
||||
ta_a = TypeAdapter(TaggedA)
|
||||
ta_b = TypeAdapter(TaggedB)
|
||||
ta_a = TypeAdapter[A1 | A2](A1 | A2)
|
||||
ta_b = TypeAdapter[B1 | B2](B1 | B2)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ta_a.validate_python(b_payload) # TaggedA expects {"x": ...} for tag "One"
|
||||
ta_a.validate_python(b_payload)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ta_b.validate_python(a_payload) # TaggedB expects {"name": ...} for tag "One"
|
||||
ta_b.validate_python(a_payload)
|
||||
|
||||
|
||||
class Inner(TaggedModel):
|
||||
x: int
|
||||
|
||||
|
||||
class Outer(TaggedModel):
|
||||
inner: Inner
|
||||
|
||||
|
||||
class Wrapper(TaggedModel):
|
||||
outer: Outer
|
||||
label: str
|
||||
|
||||
|
||||
class Container(TaggedModel):
|
||||
items: list[Inner]
|
||||
nested: Wrapper
|
||||
|
||||
|
||||
def test_single_level_tagging():
|
||||
inner = Inner(x=10)
|
||||
dumped = inner.model_dump()
|
||||
assert dumped == {"Inner": {"x": 10}}
|
||||
|
||||
restored = Inner.model_validate(dumped)
|
||||
assert isinstance(restored, Inner)
|
||||
assert restored.x == 10
|
||||
|
||||
|
||||
def test_nested_externally_tagged_union_serializes_recursively():
|
||||
outer = Outer(inner=Inner(x=42))
|
||||
dumped = outer.model_dump()
|
||||
|
||||
assert dumped == {"Outer": {"inner": {"Inner": {"x": 42}}}}
|
||||
|
||||
restored = Outer.model_validate(dumped)
|
||||
assert isinstance(restored.inner, Inner)
|
||||
assert restored.inner.x == 42
|
||||
|
||||
|
||||
def test_two_level_nested_tagging():
|
||||
outer = Outer(inner=Inner(x=123))
|
||||
dumped = outer.model_dump()
|
||||
assert dumped == {"Outer": {"inner": {"Inner": {"x": 123}}}}
|
||||
|
||||
restored = Outer.model_validate(dumped)
|
||||
assert isinstance(restored.inner, Inner)
|
||||
assert restored.inner.x == 123
|
||||
|
||||
|
||||
def test_three_level_nested_tagging():
|
||||
wrapper = Wrapper(label="deep", outer=Outer(inner=Inner(x=7)))
|
||||
dumped = wrapper.model_dump()
|
||||
# 3-level structure, each with exactly one tag
|
||||
assert dumped == {
|
||||
"Wrapper": {
|
||||
"label": "deep",
|
||||
"outer": {"Outer": {"inner": {"Inner": {"x": 7}}}},
|
||||
}
|
||||
}
|
||||
|
||||
restored = Wrapper.model_validate(dumped)
|
||||
assert isinstance(restored.outer.inner, Inner)
|
||||
assert restored.outer.inner.x == 7
|
||||
assert restored.label == "deep"
|
||||
|
||||
|
||||
def test_lists_and_mixed_nested_structures():
|
||||
container = Container(
|
||||
items=[Inner(x=1), Inner(x=2)],
|
||||
nested=Wrapper(label="mix", outer=Outer(inner=Inner(x=9))),
|
||||
)
|
||||
dumped = container.model_dump()
|
||||
|
||||
assert dumped == {
|
||||
"Container": {
|
||||
"items": [
|
||||
{"Inner": {"x": 1}},
|
||||
{"Inner": {"x": 2}},
|
||||
],
|
||||
"nested": {
|
||||
"Wrapper": {
|
||||
"label": "mix",
|
||||
"outer": {"Outer": {"inner": {"Inner": {"x": 9}}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
restored = Container.model_validate(dumped)
|
||||
assert isinstance(restored.nested.outer.inner, Inner)
|
||||
assert [i.x for i in restored.items] == [1, 2]
|
||||
|
||||
|
||||
def test_no_double_tagging_on_repeated_calls():
|
||||
"""Ensure multiple model_dump calls don't stack tags."""
|
||||
inner = Inner(x=11)
|
||||
dumped1 = inner.model_dump()
|
||||
dumped2 = inner.model_dump()
|
||||
assert dumped1 == dumped2 == {"Inner": {"x": 11}}
|
||||
|
||||
outer = Outer(inner=inner)
|
||||
d1 = outer.model_dump()
|
||||
d2 = outer.model_dump()
|
||||
assert d1 == d2 == {"Outer": {"inner": {"Inner": {"x": 11}}}}
|
||||
|
||||
|
||||
class L3A(TaggedModel):
|
||||
x: int
|
||||
|
||||
|
||||
class L3B(TaggedModel):
|
||||
x: int
|
||||
|
||||
|
||||
class L3C(TaggedModel):
|
||||
x: int
|
||||
|
||||
|
||||
L3 = L3A | L3B | L3C
|
||||
|
||||
|
||||
class L2A(TaggedModel):
|
||||
child: L3
|
||||
|
||||
|
||||
class L2B(TaggedModel):
|
||||
child: L3
|
||||
|
||||
|
||||
class L2C(TaggedModel):
|
||||
child: L3
|
||||
|
||||
|
||||
L2 = L2A | L2B | L2C
|
||||
|
||||
|
||||
class L1A(TaggedModel):
|
||||
child: L2
|
||||
|
||||
|
||||
class L1B(TaggedModel):
|
||||
child: L2
|
||||
|
||||
|
||||
class L1C(TaggedModel):
|
||||
child: L2
|
||||
|
||||
|
||||
L1 = L1A | L1B | L1C
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tagged_union_is_fast():
|
||||
# payload along the "C" path (worst case for DFS if branches are tried A->B->C)
|
||||
payload = {"L1C": {"child": {"L2C": {"child": {"L3C": {"x": 123}}}}}}
|
||||
|
||||
with anyio.fail_after(0.1):
|
||||
out = TypeAdapter(L1).validate_python(payload) # type: ignore
|
||||
|
||||
# Sanity check the result
|
||||
assert out.__class__.__name__ == "L1C" # type: ignore
|
||||
assert out.child.__class__.__name__ == "L2C" # type: ignore
|
||||
assert out.child.child.__class__.__name__ == "L3C" # type: ignore
|
||||
assert out.child.child.x == 123 # type: ignore
|
||||
|
||||
@@ -12,12 +12,9 @@ from urllib.parse import urljoin
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter
|
||||
from pydantic import BaseModel, DirectoryPath, Field, PositiveInt, TypeAdapter, ConfigDict
|
||||
|
||||
from exo.shared.constants import EXO_HOME
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import DownloadProgressData
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
@@ -56,8 +53,7 @@ class RepoFileDownloadProgress(BaseModel):
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
model_config = ConfigDict(frozen = True)
|
||||
|
||||
|
||||
class RepoDownloadProgress(BaseModel):
|
||||
@@ -91,31 +87,10 @@ class RepoDownloadProgress(BaseModel):
|
||||
# fine-grained file progress keyed by file_path
|
||||
file_progress: Dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
model_config = ConfigDict(
|
||||
frozen = True # allow use as dict keys if desired
|
||||
)
|
||||
|
||||
def map_repo_file_download_progress_to_download_progress_data(repo_file_download_progress: RepoFileDownloadProgress) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
downloaded_bytes=Memory.from_bytes(repo_file_download_progress.downloaded),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(repo_file_download_progress.downloaded_this_session),
|
||||
total_bytes=Memory.from_bytes(repo_file_download_progress.total),
|
||||
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
|
||||
total_files=1,
|
||||
speed=repo_file_download_progress.speed,
|
||||
eta_ms=int(repo_file_download_progress.eta.total_seconds() * 1000),
|
||||
files={},
|
||||
)
|
||||
def map_repo_download_progress_to_download_progress_data(repo_download_progress: RepoDownloadProgress) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
total_bytes=Memory.from_bytes(repo_download_progress.total_bytes),
|
||||
downloaded_bytes=Memory.from_bytes(repo_download_progress.downloaded_bytes),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(repo_download_progress.downloaded_bytes_this_session),
|
||||
completed_files=repo_download_progress.completed_files,
|
||||
total_files=repo_download_progress.total_files,
|
||||
speed=repo_download_progress.overall_speed,
|
||||
eta_ms=int(repo_download_progress.overall_eta.total_seconds() * 1000),
|
||||
files={file_path: map_repo_file_download_progress_to_download_progress_data(file_progress) for file_path, file_progress in repo_download_progress.file_progress.items()},
|
||||
)
|
||||
|
||||
def build_model_path(model_id: str) -> DirectoryPath:
|
||||
return EXO_HOME / "models" / model_id.replace("/", "--")
|
||||
@@ -166,13 +141,13 @@ async def seed_models(seed_dir: Union[str, Path]):
|
||||
if path.is_dir() and path.name.startswith("models--"):
|
||||
dest_path = dest_dir / path.name
|
||||
if await aios.path.exists(dest_path):
|
||||
logger.info("Skipping moving model to .cache directory")
|
||||
print("Skipping moving model to .cache directory")
|
||||
else:
|
||||
try:
|
||||
await aios.rename(str(path), str(dest_path))
|
||||
except Exception:
|
||||
logger.error(f"Error seeding model {path} to {dest_path}")
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"Error seeding model {path} to {dest_path}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def fetch_file_list_with_cache(
|
||||
@@ -262,13 +237,9 @@ async def file_meta(
|
||||
if redirected_location is None
|
||||
else f"{get_hf_endpoint()}{redirected_location}"
|
||||
)
|
||||
# Ensure identity transfer to keep Content-Length and byte accounting
|
||||
# consistent with on-disk sizes and progress totals.
|
||||
headers = {**(await get_auth_headers()), "Accept-Encoding": "identity"}
|
||||
headers = await get_auth_headers()
|
||||
async with (
|
||||
aiohttp.ClientSession(
|
||||
# Disable transparent decompression; we want raw bytes as served.
|
||||
auto_decompress=False,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=1800, connect=60, sock_read=1800, sock_connect=60
|
||||
)
|
||||
@@ -276,18 +247,22 @@ async def file_meta(
|
||||
session.head(url, headers=headers) as r,
|
||||
):
|
||||
if r.status == 307:
|
||||
# On redirect, only trust Hugging Face's x-linked-* headers.
|
||||
x_linked_size = r.headers.get("x-linked-size")
|
||||
x_linked_etag = r.headers.get("X-Linked-ETag")
|
||||
if x_linked_size and x_linked_etag:
|
||||
content_length = int(x_linked_size)
|
||||
etag = x_linked_etag
|
||||
# Try to extract from X-Linked headers first (common for HF redirects)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
etag = (
|
||||
r.headers.get("X-Linked-ETag")
|
||||
or r.headers.get("ETag")
|
||||
or r.headers.get("Etag")
|
||||
)
|
||||
if content_length > 0 and etag is not None:
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (
|
||||
etag[0] == "'" and etag[-1] == "'"
|
||||
):
|
||||
etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
# If not available, recurse with the redirect
|
||||
redirected_location = r.headers.get("Location")
|
||||
return await file_meta(repo_id, revision, path, redirected_location)
|
||||
content_length = int(
|
||||
@@ -321,10 +296,10 @@ async def download_file_with_retry(
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
raise e
|
||||
logger.error(
|
||||
print(
|
||||
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
|
||||
raise Exception(
|
||||
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
|
||||
@@ -351,15 +326,12 @@ async def _download_file(
|
||||
)
|
||||
if resume_byte_pos != length:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
# Request identity encoding so received byte counts match on-disk size
|
||||
headers = {**(await get_auth_headers()), "Accept-Encoding": "identity"}
|
||||
headers = await get_auth_headers()
|
||||
if resume_byte_pos:
|
||||
headers["Range"] = f"bytes={resume_byte_pos}-"
|
||||
n_read = resume_byte_pos or 0
|
||||
async with (
|
||||
aiohttp.ClientSession(
|
||||
# Keep raw transfer semantics (no transparent decompression)
|
||||
auto_decompress=False,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=1800, connect=60, sock_read=1800, sock_connect=60
|
||||
)
|
||||
@@ -392,7 +364,7 @@ async def _download_file(
|
||||
try:
|
||||
await aios.remove(partial_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing partial file {partial_path}: {e}")
|
||||
print(f"Error removing partial file {partial_path}: {e}")
|
||||
raise Exception(
|
||||
f"Downloaded file {target_dir / path} has hash {final_hash} but remote hash is {remote_hash}"
|
||||
)
|
||||
@@ -462,8 +434,8 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> List[str]:
|
||||
weight_map = await get_weight_map(str(shard.model_meta.model_id))
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
except Exception:
|
||||
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"Error getting weight map for {shard.model_meta.model_id=}")
|
||||
traceback.print_exc()
|
||||
return ["*"]
|
||||
|
||||
|
||||
@@ -533,11 +505,11 @@ async def download_shard(
|
||||
allow_patterns: List[str] | None = None,
|
||||
) -> tuple[Path, RepoDownloadProgress]:
|
||||
if not skip_download:
|
||||
logger.info(f"Downloading {shard.model_meta.model_id=}")
|
||||
print(f"Downloading {shard.model_meta.model_id=}")
|
||||
|
||||
# Handle local paths
|
||||
if await aios.path.exists(str(shard.model_meta.model_id)):
|
||||
logger.info(f"Using local model path {shard.model_meta.model_id}")
|
||||
print(f"Using local model path {shard.model_meta.model_id}")
|
||||
local_path = Path(str(shard.model_meta.model_id))
|
||||
return local_path, await download_progress_for_local_path(
|
||||
str(shard.model_meta.model_id), shard, local_path
|
||||
@@ -553,7 +525,7 @@ async def download_shard(
|
||||
if not allow_patterns:
|
||||
allow_patterns = await resolve_allow_patterns(shard)
|
||||
|
||||
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
|
||||
print(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import AsyncIterator, Callable, Dict, List, Optional
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.worker.shards import (
|
||||
PartitionStrategy,
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
@@ -24,7 +23,6 @@ async def build_base_shard(model_id: str) -> Optional[ShardMetadata]:
|
||||
# print(f"build_base_shard {model_id=} {model_meta=}")
|
||||
return PipelineShardMetadata(
|
||||
model_meta=model_meta,
|
||||
partition_strategy=PartitionStrategy.pipeline,
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
@@ -39,7 +37,6 @@ async def build_full_shard(model_id: str) -> Optional[PipelineShardMetadata]:
|
||||
return None
|
||||
return PipelineShardMetadata(
|
||||
model_meta=base_shard.model_meta,
|
||||
partition_strategy=base_shard.partition_strategy,
|
||||
device_rank=base_shard.device_rank,
|
||||
world_size=base_shard.world_size,
|
||||
start_layer=base_shard.start_layer,
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import AsyncIterator, Callable
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.shards import (
|
||||
PartitionStrategy,
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
@@ -57,7 +56,6 @@ class ShardDownloader(ABC):
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
),
|
||||
partition_strategy=PartitionStrategy.pipeline,
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
@@ -107,7 +105,6 @@ class NoopShardDownloader(ShardDownloader):
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
),
|
||||
partition_strategy=PartitionStrategy.pipeline,
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
from asyncio import Queue
|
||||
from functools import partial
|
||||
@@ -13,8 +12,7 @@ from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.worker.download.download_utils import map_repo_download_progress_to_download_progress_data
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog, TaggedCommand
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
@@ -27,12 +25,12 @@ from exo.shared.types.events import (
|
||||
NodePerformanceMeasured,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaggedEvent,
|
||||
TaskFailed,
|
||||
TaskStateUpdated,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
@@ -43,6 +41,7 @@ from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgressData,
|
||||
)
|
||||
from exo.shared.types.worker.ops import (
|
||||
AssignRunnerOp,
|
||||
@@ -50,7 +49,6 @@ from exo.shared.types.worker.ops import (
|
||||
RunnerDownOp,
|
||||
RunnerFailedOp,
|
||||
RunnerOp,
|
||||
RunnerOpType,
|
||||
RunnerUpOp,
|
||||
UnassignRunnerOp,
|
||||
)
|
||||
@@ -120,25 +118,23 @@ class Worker:
|
||||
),
|
||||
)
|
||||
|
||||
async def memory_monitor_callback(
|
||||
memory_profile: MemoryPerformanceProfile,
|
||||
) -> None:
|
||||
await self.event_publisher(
|
||||
NodeMemoryMeasured(node_id=self.node_id, memory=memory_profile)
|
||||
)
|
||||
|
||||
# END CLEANUP
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
async def memory_monitor_callback(
|
||||
memory_profile: MemoryPerformanceProfile,
|
||||
) -> None:
|
||||
await self.event_publisher(
|
||||
NodeMemoryMeasured(node_id=self.node_id, memory=memory_profile)
|
||||
)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
# Proactively request a global event sync at startup to backfill any missed events.
|
||||
tg.start_soon(self._request_full_event_log_once)
|
||||
# TODO: This is a little gross, but not too bad
|
||||
for msg in self._initial_connection_messages:
|
||||
await self.event_publisher(
|
||||
@@ -156,8 +152,8 @@ class Worker:
|
||||
async def _event_applier(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for event in events:
|
||||
self.event_buffer.ingest(event.origin_idx, event.tagged_event.c)
|
||||
event_id = event.tagged_event.c.event_id
|
||||
self.event_buffer.ingest(event.origin_idx, event.event)
|
||||
event_id = event.event.event_id
|
||||
if event_id in self.out_for_delivery:
|
||||
del self.out_for_delivery[event_id]
|
||||
|
||||
@@ -201,8 +197,6 @@ class Worker:
|
||||
async for event in self.execute_op(op):
|
||||
await self.event_publisher(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing op: {str(op)[:100]}")
|
||||
logger.error(traceback.format_exc())
|
||||
if isinstance(op, ExecuteTaskOp):
|
||||
generator = self.fail_task(
|
||||
e, runner_id=op.runner_id, task_id=op.task.task_id
|
||||
@@ -227,7 +221,6 @@ class Worker:
|
||||
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
match msg.connection_type:
|
||||
case ConnectionMessageType.Connected:
|
||||
logger.warning(f"!!! Node {self.node_id} connected to {msg.node_id}")
|
||||
return TopologyEdgeCreated(
|
||||
edge=Connection(
|
||||
local_node_id=self.node_id,
|
||||
@@ -239,7 +232,6 @@ class Worker:
|
||||
)
|
||||
|
||||
case ConnectionMessageType.Disconnected:
|
||||
logger.warning(f"!!! Node {self.node_id} disconnected from {msg.node_id}")
|
||||
return TopologyEdgeDeleted(
|
||||
edge=Connection(
|
||||
local_node_id=self.node_id,
|
||||
@@ -262,27 +254,13 @@ class Worker:
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.node_id,
|
||||
tagged_command=TaggedCommand.from_(
|
||||
RequestEventLog(since_idx=self.event_buffer.next_idx_to_release)
|
||||
),
|
||||
command=RequestEventLog(since_idx=0),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if self._nack_cancel_scope is scope:
|
||||
self._nack_cancel_scope = None
|
||||
|
||||
async def _request_full_event_log_once(self) -> None:
|
||||
# Fire-and-forget one-time sync shortly after startup.
|
||||
await anyio.sleep(0.1)
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.node_id,
|
||||
tagged_command=TaggedCommand.from_(
|
||||
RequestEventLog(since_idx=self.event_buffer.next_idx_to_release)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def _resend_out_for_delivery(self) -> None:
|
||||
# This can also be massively tightened, we should check events are at least a certain age before resending.
|
||||
# Exponential backoff would also certainly help here.
|
||||
@@ -340,8 +318,13 @@ class Worker:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(initial_progress),
|
||||
),
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=Memory.from_bytes(initial_progress.total_bytes),
|
||||
downloaded_bytes=Memory.from_bytes(
|
||||
initial_progress.downloaded_bytes
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
|
||||
@@ -357,24 +340,15 @@ class Worker:
|
||||
download_task = asyncio.create_task(
|
||||
self.shard_downloader.ensure_shard(op.shard_metadata)
|
||||
)
|
||||
logger.info(f"Started download for {op.shard_metadata.model_meta.model_id}")
|
||||
|
||||
try:
|
||||
async for event in self._monitor_download_progress(
|
||||
assigned_runner, download_progress_queue
|
||||
):
|
||||
yield event
|
||||
# in case the download needs to finish up, wait up to 60 secs for it to finish
|
||||
# this fixes a bug where the download gets cancelled before it can rename .partial file on finish
|
||||
await asyncio.wait_for(download_task, timeout=15)
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring download progress: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise e
|
||||
finally:
|
||||
if not download_task.done():
|
||||
download_task.cancel()
|
||||
|
||||
|
||||
async def _monitor_download_progress(
|
||||
self,
|
||||
@@ -403,7 +377,12 @@ class Worker:
|
||||
assigned_runner.status = DownloadingRunnerStatus(
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(progress),
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=Memory.from_bytes(progress.total_bytes),
|
||||
downloaded_bytes=Memory.from_bytes(
|
||||
progress.downloaded_bytes
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
yield assigned_runner.status_update_event()
|
||||
@@ -424,11 +403,9 @@ class Worker:
|
||||
)
|
||||
|
||||
if initial_progress.status == "complete":
|
||||
logger.info(f"Shard {op.shard_metadata.model_meta.model_id} already downloaded")
|
||||
async for event in self._handle_already_downloaded_shard(assigned_runner):
|
||||
yield event
|
||||
else:
|
||||
logger.info(f"Shard {op.shard_metadata.model_meta.model_id} not downloaded, starting download.")
|
||||
async for event in self._handle_shard_download_process(
|
||||
assigned_runner, op, initial_progress
|
||||
):
|
||||
@@ -526,7 +503,7 @@ class Worker:
|
||||
await queue.put(
|
||||
TaskStateUpdated(
|
||||
task_id=op.task.task_id,
|
||||
task_status=TaskStatus.RUNNING,
|
||||
task_status=TaskStatus.Running,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -547,14 +524,14 @@ class Worker:
|
||||
)
|
||||
|
||||
if op.task.task_id in self.state.tasks:
|
||||
self.state.tasks[op.task.task_id].task_status = TaskStatus.COMPLETE
|
||||
self.state.tasks[op.task.task_id].task_status = TaskStatus.Complete
|
||||
|
||||
if assigned_runner.shard_metadata.device_rank == 0:
|
||||
# kind of hack - we don't want to wait for the round trip for this to complete
|
||||
await queue.put(
|
||||
TaskStateUpdated(
|
||||
task_id=op.task.task_id,
|
||||
task_status=TaskStatus.COMPLETE,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -601,18 +578,18 @@ class Worker:
|
||||
|
||||
async def execute_op(self, op: RunnerOp) -> AsyncGenerator[Event, None]:
|
||||
## It would be great if we can get rid of this async for ... yield pattern.
|
||||
match op.op_type:
|
||||
case RunnerOpType.ASSIGN_RUNNER:
|
||||
match op:
|
||||
case AssignRunnerOp():
|
||||
event_generator = self._execute_assign_op(op)
|
||||
case RunnerOpType.UNASSIGN_RUNNER:
|
||||
case UnassignRunnerOp():
|
||||
event_generator = self._execute_unassign_op(op)
|
||||
case RunnerOpType.RUNNER_UP:
|
||||
case RunnerUpOp():
|
||||
event_generator = self._execute_runner_up_op(op)
|
||||
case RunnerOpType.RUNNER_DOWN:
|
||||
case RunnerDownOp():
|
||||
event_generator = self._execute_runner_down_op(op)
|
||||
case RunnerOpType.RUNNER_FAILED:
|
||||
case RunnerFailedOp():
|
||||
event_generator = self._execute_runner_failed_op(op)
|
||||
case RunnerOpType.CHAT_COMPLETION:
|
||||
case ExecuteTaskOp():
|
||||
event_generator = self._execute_task_op(op)
|
||||
|
||||
async for event in event_generator:
|
||||
@@ -643,7 +620,7 @@ class Worker:
|
||||
if runner_id in self.assigned_runners:
|
||||
yield TaskStateUpdated(
|
||||
task_id=task_id,
|
||||
task_status=TaskStatus.FAILED,
|
||||
task_status=TaskStatus.Failed,
|
||||
)
|
||||
|
||||
yield TaskFailed(
|
||||
@@ -653,15 +630,21 @@ class Worker:
|
||||
async for event in self.fail_runner(e, runner_id):
|
||||
yield event
|
||||
|
||||
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def event_publisher(self, event: Event) -> None:
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=self.local_event_index,
|
||||
origin=self.node_id,
|
||||
tagged_event=TaggedEvent.from_(event),
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
|
||||
)
|
||||
self.local_event_index += 1
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
self.local_event_index += 1
|
||||
|
||||
|
||||
def event_relevant_to_worker(event: Event, worker: Worker):
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.common import RunnerId
|
||||
from exo.shared.types.worker.downloads import DownloadStatus
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import Instance, InstanceStatus
|
||||
from exo.shared.types.worker.ops import (
|
||||
AssignRunnerOp,
|
||||
@@ -23,8 +23,8 @@ from exo.shared.types.worker.runners import (
|
||||
InactiveRunnerStatus,
|
||||
LoadedRunnerStatus,
|
||||
RunnerStatus,
|
||||
RunnerStatusType,
|
||||
RunningRunnerStatus,
|
||||
StartingRunnerStatus,
|
||||
)
|
||||
from exo.worker.common import AssignedRunner
|
||||
|
||||
@@ -45,14 +45,12 @@ def unassign_runners(
|
||||
|
||||
# If our instance is in 'downloading' or 'assigned' state, then we know the runner is stale. These are part of AssignRunnerOp and should be blocking.
|
||||
for assigned_runner_id in assigned_runners:
|
||||
if (
|
||||
assigned_runner_id in state_runners
|
||||
and isinstance(state_runners[assigned_runner_id], DownloadingRunnerStatus)
|
||||
# Not sure about this type ignore, i don't think it should be necessary
|
||||
and state_runners[assigned_runner_id].download_progress.download_status # type: ignore
|
||||
!= DownloadStatus.Completed
|
||||
):
|
||||
return UnassignRunnerOp(runner_id=assigned_runner_id)
|
||||
if assigned_runner_id in state_runners:
|
||||
status = state_runners[assigned_runner_id]
|
||||
if isinstance(status, DownloadingRunnerStatus) and not isinstance(
|
||||
status.download_progress, DownloadCompleted
|
||||
):
|
||||
return UnassignRunnerOp(runner_id=assigned_runner_id)
|
||||
|
||||
return None
|
||||
|
||||
@@ -85,7 +83,7 @@ def spin_down_runners(
|
||||
if (
|
||||
runner_id in assigned_runners
|
||||
and isinstance(assigned_runners[runner_id].status, LoadedRunnerStatus)
|
||||
and instance.instance_type == InstanceStatus.INACTIVE
|
||||
and instance.instance_type == InstanceStatus.Inactive
|
||||
):
|
||||
return RunnerDownOp(runner_id=runner_id)
|
||||
|
||||
@@ -195,18 +193,19 @@ def spin_up_runners(
|
||||
instance.shard_assignments.node_to_runner[worker_node_id]
|
||||
].runner
|
||||
is None
|
||||
and instance.instance_type == InstanceStatus.ACTIVE
|
||||
and instance.instance_type == InstanceStatus.Active
|
||||
):
|
||||
# We are part of this instance, we want it up but it hasn't been spun up yet.
|
||||
# Need to assert all other runners are ready before we can spin up.
|
||||
ready_to_spin = True
|
||||
for runner_id in instance.shard_assignments.node_to_runner.values():
|
||||
if runner_id in state_runners and state_runners[
|
||||
runner_id
|
||||
].runner_status not in [
|
||||
RunnerStatusType.Inactive,
|
||||
RunnerStatusType.Starting,
|
||||
]:
|
||||
if runner_id in state_runners and isinstance(
|
||||
state_runners[runner_id],
|
||||
(
|
||||
InactiveRunnerStatus,
|
||||
StartingRunnerStatus,
|
||||
),
|
||||
):
|
||||
ready_to_spin = False
|
||||
|
||||
if ready_to_spin:
|
||||
@@ -229,13 +228,12 @@ def execute_task_op(
|
||||
continue
|
||||
assert runner_id in assigned_runners
|
||||
runner = assigned_runners[runner_id]
|
||||
if runner.status.runner_status != RunnerStatusType.Loaded:
|
||||
if not isinstance(runner.status, LoadedRunnerStatus):
|
||||
continue # The only previous state to get to Running is from Loaded
|
||||
|
||||
for _, task in tasks.items():
|
||||
if task.instance_id == instance_id and (
|
||||
task.task_status == TaskStatus.PENDING
|
||||
or task.task_status == TaskStatus.FAILED
|
||||
task.task_status in (TaskStatus.Pending, TaskStatus.Failed)
|
||||
):
|
||||
if (
|
||||
runner.shard_metadata.device_rank >= 1
|
||||
|
||||
@@ -10,7 +10,6 @@ from exo.shared.types.tasks import (
|
||||
ChatCompletionTask,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceStatus
|
||||
@@ -131,7 +130,7 @@ def instance(
|
||||
|
||||
return Instance(
|
||||
instance_id=resolved_instance_id,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(1),
|
||||
)
|
||||
@@ -161,8 +160,7 @@ def chat_completion_task(completion_create_params: ChatCompletionTaskParams):
|
||||
task_id=resolved_task_id,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=resolved_instance_id,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=completion_create_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,29 +1,22 @@
|
||||
import time
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.download.shard_downloader import ShardDownloader
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_shard_downloader(
|
||||
pipeline_shard_meta: Callable[[int, int], PipelineShardMetadata],
|
||||
):
|
||||
shutil.rmtree(Path(os.path.expanduser("~/.exo/models/mlx-community--Llama-3.2-1B-Instruct-4bit")))
|
||||
|
||||
progress_log: list[RepoDownloadProgress] = []
|
||||
shard_downloader: ShardDownloader = exo_shard_downloader()
|
||||
def _on_progress(shard: ShardMetadata, progress: RepoDownloadProgress):
|
||||
print(f"Download progress: {progress}")
|
||||
progress_log.append(progress)
|
||||
shard_downloader.on_progress(_on_progress)
|
||||
shard_downloader.on_progress(
|
||||
lambda shard, progress: print(f"Download progress: {progress}")
|
||||
)
|
||||
|
||||
shard_metadata = pipeline_shard_meta(1, 0)
|
||||
path = await shard_downloader.ensure_shard(shard_metadata)
|
||||
@@ -54,12 +47,3 @@ async def test_shard_downloader(
|
||||
duration = time.monotonic() - start_time
|
||||
assert path_again == path
|
||||
assert duration < 5, f"Second call to ensure_shard took too long: {duration:.2f}s"
|
||||
|
||||
print(progress_log[-1].file_progress)
|
||||
|
||||
assert len(progress_log) > 0
|
||||
assert progress_log[-1].status == "complete"
|
||||
assert progress_log[-1].completed_files == 6
|
||||
assert progress_log[-1].total_files == 6
|
||||
assert progress_log[-1].downloaded_bytes == sum(file_size for _, file_size in expected_files_and_sizes)
|
||||
assert progress_log[-1].total_bytes == sum(file_size for _, file_size in expected_files_and_sizes)
|
||||
|
||||
@@ -145,10 +145,10 @@ async def test_execute_task_op(
|
||||
assert isinstance(events[0].runner_status, RunningRunnerStatus)
|
||||
|
||||
assert isinstance(events[1], TaskStateUpdated)
|
||||
assert events[1].task_status == TaskStatus.RUNNING # It tried to start.
|
||||
assert events[1].task_status == TaskStatus.Running # It tried to start.
|
||||
|
||||
assert isinstance(events[-2], TaskStateUpdated)
|
||||
assert events[-2].task_status == TaskStatus.COMPLETE # It tried to start.
|
||||
assert events[-2].task_status == TaskStatus.Complete # It tried to start.
|
||||
|
||||
assert isinstance(events[-1], RunnerStatusUpdated)
|
||||
assert isinstance(
|
||||
|
||||
@@ -17,7 +17,6 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from exo.shared.types.worker.common import InstanceId, RunnerId
|
||||
from exo.shared.types.worker.instances import (
|
||||
@@ -57,7 +56,7 @@ async def test_runner_inference(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
await global_events.append_events(
|
||||
@@ -120,7 +119,7 @@ async def test_2_runner_inference(
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2),
|
||||
)
|
||||
@@ -190,7 +189,7 @@ async def test_2_runner_multi_message(
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2),
|
||||
)
|
||||
@@ -218,8 +217,7 @@ async def test_2_runner_multi_message(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=CommandId(),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=completion_create_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ async def test_stream_response_failed_always(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
|
||||
async def mock_stream_response(
|
||||
self: RunnerSupervisor,
|
||||
@@ -88,8 +88,8 @@ async def test_stream_response_failed_always(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
|
||||
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
|
||||
if isinstance(x.event, RunnerStatusUpdated)
|
||||
and isinstance(x.event.runner_status, FailedRunnerStatus)
|
||||
]
|
||||
)
|
||||
== 3
|
||||
@@ -99,13 +99,13 @@ async def test_stream_response_failed_always(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, TaskStateUpdated)
|
||||
and x.tagged_event.c.task_status == TaskStatus.FAILED
|
||||
if isinstance(x.event, TaskStateUpdated)
|
||||
and x.event.task_status == TaskStatus.Failed
|
||||
]
|
||||
)
|
||||
== 3
|
||||
)
|
||||
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
|
||||
await global_events.append_events(
|
||||
[
|
||||
@@ -152,7 +152,7 @@ async def test_stream_response_failed_once(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
await global_events.append_events(
|
||||
@@ -186,8 +186,8 @@ async def test_stream_response_failed_once(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
|
||||
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
|
||||
if isinstance(x.event, RunnerStatusUpdated)
|
||||
and isinstance(x.event.runner_status, FailedRunnerStatus)
|
||||
]
|
||||
)
|
||||
== 1
|
||||
@@ -197,8 +197,8 @@ async def test_stream_response_failed_once(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, TaskStateUpdated)
|
||||
and x.tagged_event.c.task_status == TaskStatus.FAILED
|
||||
if isinstance(x.event, TaskStateUpdated)
|
||||
and x.event.task_status == TaskStatus.Failed
|
||||
]
|
||||
)
|
||||
== 1
|
||||
@@ -209,11 +209,11 @@ async def test_stream_response_failed_once(
|
||||
|
||||
seen_task_started, seen_task_finished = False, False
|
||||
for wrapped_event in events:
|
||||
event = wrapped_event.tagged_event.c
|
||||
event = wrapped_event.event
|
||||
if isinstance(event, TaskStateUpdated):
|
||||
if event.task_status == TaskStatus.RUNNING:
|
||||
if event.task_status == TaskStatus.Running:
|
||||
seen_task_started = True
|
||||
if event.task_status == TaskStatus.COMPLETE:
|
||||
if event.task_status == TaskStatus.Complete:
|
||||
seen_task_finished = True
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
@@ -246,7 +246,7 @@ async def test_stream_response_timeout(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
|
||||
task: Task = chat_completion_task(INSTANCE_1_ID, TASK_1_ID)
|
||||
task.task_params.messages[0].content = "EXO RUNNER MUST TIMEOUT"
|
||||
@@ -269,8 +269,8 @@ async def test_stream_response_timeout(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
|
||||
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
|
||||
if isinstance(x.event, RunnerStatusUpdated)
|
||||
and isinstance(x.event.runner_status, FailedRunnerStatus)
|
||||
]
|
||||
)
|
||||
== 3
|
||||
@@ -280,8 +280,8 @@ async def test_stream_response_timeout(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, TaskStateUpdated)
|
||||
and x.tagged_event.c.task_status == TaskStatus.FAILED
|
||||
if isinstance(x.event, TaskStateUpdated)
|
||||
and x.event.task_status == TaskStatus.Failed
|
||||
]
|
||||
)
|
||||
== 3
|
||||
@@ -291,8 +291,8 @@ async def test_stream_response_timeout(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, TaskFailed)
|
||||
and "timeouterror" in x.tagged_event.c.error_type.lower()
|
||||
if isinstance(x.event, TaskFailed)
|
||||
and "timeouterror" in x.event.error_type.lower()
|
||||
]
|
||||
)
|
||||
== 3
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_runner_spinup_timeout(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
instance_value.shard_assignments.runner_to_shard[
|
||||
RUNNER_1_ID
|
||||
].should_timeout = 10
|
||||
@@ -61,11 +61,11 @@ async def test_runner_spinup_timeout(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
|
||||
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
|
||||
if isinstance(x.event, RunnerStatusUpdated)
|
||||
and isinstance(x.event.runner_status, FailedRunnerStatus)
|
||||
]
|
||||
)
|
||||
== 3
|
||||
)
|
||||
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
worker.shutdown()
|
||||
|
||||
@@ -38,7 +38,7 @@ async def test_runner_spinup_exception(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
instance_value.shard_assignments.runner_to_shard[
|
||||
RUNNER_1_ID
|
||||
].immediate_exception = True
|
||||
@@ -57,13 +57,13 @@ async def test_runner_spinup_exception(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
|
||||
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
|
||||
if isinstance(x.event, RunnerStatusUpdated)
|
||||
and isinstance(x.event.runner_status, FailedRunnerStatus)
|
||||
]
|
||||
)
|
||||
== 3
|
||||
)
|
||||
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
worker.shutdown()
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ async def test_runner_spinup_timeout(
|
||||
async with create_task_group() as tg:
|
||||
tg.start_soon(worker.run)
|
||||
instance_value: Instance = instance(INSTANCE_1_ID, NODE_A, RUNNER_1_ID)
|
||||
instance_value.instance_type = InstanceStatus.ACTIVE
|
||||
instance_value.instance_type = InstanceStatus.Active
|
||||
instance_value.shard_assignments.runner_to_shard[
|
||||
RUNNER_1_ID
|
||||
].should_timeout = 10
|
||||
@@ -99,11 +99,11 @@ async def test_runner_spinup_timeout(
|
||||
[
|
||||
x
|
||||
for x in events
|
||||
if isinstance(x.tagged_event.c, RunnerStatusUpdated)
|
||||
and isinstance(x.tagged_event.c.runner_status, FailedRunnerStatus)
|
||||
if isinstance(x.event, RunnerStatusUpdated)
|
||||
and isinstance(x.event.runner_status, FailedRunnerStatus)
|
||||
]
|
||||
)
|
||||
== 3
|
||||
)
|
||||
assert any([isinstance(x.tagged_event.c, InstanceDeleted) for x in events])
|
||||
assert any([isinstance(x.event, InstanceDeleted) for x in events])
|
||||
worker.shutdown()
|
||||
|
||||
@@ -22,7 +22,6 @@ from exo.shared.types.tasks import (
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.instances import (
|
||||
@@ -107,7 +106,7 @@ async def test_ttft(
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(1),
|
||||
)
|
||||
@@ -139,8 +138,7 @@ async def test_ttft(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=task1_params,
|
||||
)
|
||||
|
||||
@@ -157,7 +155,7 @@ async def test_ttft(
|
||||
first_chunk_seen_1 = False
|
||||
time_to_first_token_1: None | float = None
|
||||
while not first_chunk_seen_1:
|
||||
event = (await global_events.receive()).tagged_event.c
|
||||
event = (await global_events.receive()).event
|
||||
if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"):
|
||||
first_chunk_time_1 = time.time()
|
||||
time_to_first_token_1 = first_chunk_time_1 - task_created_time_1
|
||||
@@ -192,8 +190,7 @@ async def test_ttft(
|
||||
task_id=TASK_2_ID,
|
||||
command_id=COMMAND_2_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=task2_params,
|
||||
)
|
||||
|
||||
@@ -211,7 +208,7 @@ async def test_ttft(
|
||||
first_chunk_seen_2 = False
|
||||
time_to_first_token_2: float | None = None
|
||||
while not first_chunk_seen_2:
|
||||
event = (await global_events.receive()).tagged_event.c
|
||||
event = (await global_events.receive()).event
|
||||
if isinstance(event, ChunkGenerated) and hasattr(event, "chunk"):
|
||||
first_chunk_time_2 = time.time()
|
||||
time_to_first_token_2 = first_chunk_time_2 - task_created_time_2
|
||||
@@ -344,7 +341,7 @@ async def test_2_runner_inference(
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2),
|
||||
)
|
||||
@@ -424,7 +421,7 @@ async def test_parallel_inference(
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2),
|
||||
)
|
||||
@@ -443,8 +440,7 @@ async def test_parallel_inference(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=completion_create_params_1,
|
||||
)
|
||||
|
||||
@@ -462,8 +458,7 @@ async def test_parallel_inference(
|
||||
task_id=TASK_2_ID,
|
||||
command_id=COMMAND_2_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=completion_create_params_2,
|
||||
)
|
||||
|
||||
@@ -485,7 +480,7 @@ async def test_parallel_inference(
|
||||
|
||||
incomplete_task = (
|
||||
TASK_2_ID
|
||||
if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.COMPLETE
|
||||
if worker1.state.tasks[TASK_1_ID].task_status == TaskStatus.Complete
|
||||
else TASK_2_ID
|
||||
)
|
||||
(
|
||||
|
||||
@@ -6,7 +6,6 @@ from exo.shared.types.tasks import (
|
||||
ChatCompletionTask,
|
||||
ChatCompletionTaskParams,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from exo.shared.types.worker.common import WorkerStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
@@ -85,7 +84,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": False,
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.INACTIVE,
|
||||
instance_status=InstanceStatus.Inactive,
|
||||
expected_op=UnassignRunnerOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -99,7 +98,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.INACTIVE,
|
||||
instance_status=InstanceStatus.Inactive,
|
||||
expected_op=None,
|
||||
),
|
||||
PlanTestCase(
|
||||
@@ -110,7 +109,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
INSTANCE_1_ID: [(RUNNER_1_ID, NODE_A, 0, InactiveRunnerStatus())]
|
||||
},
|
||||
model_id=MODEL_A_ID,
|
||||
instance_status=InstanceStatus.ACTIVE, # Either active or inactive should yield the same.
|
||||
instance_status=InstanceStatus.Active, # Either active or inactive should yield the same.
|
||||
),
|
||||
expected_op=AssignRunnerOp(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -153,7 +152,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -180,11 +179,11 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
{
|
||||
"task_id": TASK_1_ID,
|
||||
"instance_id": INSTANCE_1_ID,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": TaskStatus.Pending,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=None,
|
||||
),
|
||||
make_test_case(
|
||||
@@ -209,11 +208,11 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
{
|
||||
"task_id": TASK_1_ID,
|
||||
"instance_id": INSTANCE_1_ID,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": TaskStatus.Pending,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=RunnerUpOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -227,7 +226,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.INACTIVE,
|
||||
instance_status=InstanceStatus.Inactive,
|
||||
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -241,7 +240,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.INACTIVE,
|
||||
instance_status=InstanceStatus.Inactive,
|
||||
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -259,19 +258,18 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
{
|
||||
"task_id": TASK_1_ID,
|
||||
"instance_id": INSTANCE_1_ID,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": TaskStatus.Pending,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=ExecuteTaskOp(
|
||||
runner_id=RUNNER_1_ID,
|
||||
task=ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[
|
||||
@@ -304,11 +302,11 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
{
|
||||
"task_id": TASK_1_ID,
|
||||
"instance_id": INSTANCE_1_ID,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": TaskStatus.Pending,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=None,
|
||||
),
|
||||
make_test_case(
|
||||
@@ -333,25 +331,24 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
{
|
||||
"task_id": TASK_1_ID,
|
||||
"instance_id": INSTANCE_1_ID,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": TaskStatus.Pending,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=ExecuteTaskOp(
|
||||
runner_id=RUNNER_1_ID,
|
||||
task=ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, world!")
|
||||
],
|
||||
),
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -377,25 +374,24 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
{
|
||||
"task_id": TASK_1_ID,
|
||||
"instance_id": INSTANCE_1_ID,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": TaskStatus.Pending,
|
||||
"messages": [{"role": "user", "content": "Hello, world!"}],
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=ExecuteTaskOp(
|
||||
runner_id=RUNNER_1_ID,
|
||||
task=ChatCompletionTask(
|
||||
task_id=TASK_1_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, world!")
|
||||
],
|
||||
),
|
||||
task_status=TaskStatus.PENDING,
|
||||
task_status=TaskStatus.Pending,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -410,7 +406,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
}
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -431,7 +427,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
},
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
make_test_case(
|
||||
@@ -452,7 +448,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
},
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=None,
|
||||
),
|
||||
make_test_case(
|
||||
@@ -473,7 +469,7 @@ def _get_test_cases() -> list[PlanTestCase]:
|
||||
"downloaded": True,
|
||||
},
|
||||
],
|
||||
instance_status=InstanceStatus.ACTIVE,
|
||||
instance_status=InstanceStatus.Active,
|
||||
expected_op=RunnerDownOp(runner_id=RUNNER_1_ID),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -9,7 +9,7 @@ from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus, TaskType
|
||||
from exo.shared.types.tasks import ChatCompletionTask, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.common import InstanceId, RunnerId, WorkerStatus
|
||||
from exo.shared.types.worker.downloads import DownloadOngoing, DownloadProgressData
|
||||
from exo.shared.types.worker.instances import Instance, InstanceStatus
|
||||
@@ -117,7 +117,7 @@ def make_downloading_status(node_id: NodeId) -> DownloadingRunnerStatus:
|
||||
download_progress=DownloadOngoing(
|
||||
node_id=node_id,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0), downloaded_bytes_this_session=Memory.from_bytes(0), completed_files=0, total_files=1, speed=0, eta_ms=0, files={}
|
||||
total_bytes=Memory.from_bytes(1), downloaded_bytes=Memory.from_bytes(0)
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -146,7 +146,7 @@ def make_instance(
|
||||
instance_id: InstanceId,
|
||||
runner_specs: list[tuple[RunnerId, NodeId, int, RunnerStatus]],
|
||||
model_id: ModelId = MODEL_A_ID,
|
||||
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
|
||||
instance_status: InstanceStatus = InstanceStatus.Active,
|
||||
) -> tuple[Instance, dict[RunnerId, RunnerStatus], dict[NodeId, WorkerStatus]]:
|
||||
"""Creates an instance with one or more runners."""
|
||||
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
|
||||
@@ -189,7 +189,7 @@ def make_state(
|
||||
],
|
||||
tasks: dict[TaskId, ChatCompletionTask] | None = None,
|
||||
model_id: ModelId = MODEL_A_ID,
|
||||
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
|
||||
instance_status: InstanceStatus = InstanceStatus.Active,
|
||||
) -> State:
|
||||
"""Builds a full State from runner specs per instance, tasks, and defaults."""
|
||||
if tasks is None:
|
||||
@@ -224,7 +224,7 @@ def make_test_case(
|
||||
tasks: list[TaskSpecDict] | None = None,
|
||||
expected_op: Optional[RunnerOp] = None,
|
||||
instance_id: InstanceId = INSTANCE_1_ID,
|
||||
instance_status: InstanceStatus = InstanceStatus.ACTIVE,
|
||||
instance_status: InstanceStatus = InstanceStatus.Active,
|
||||
model_id: ModelId = MODEL_A_ID,
|
||||
command_id: CommandId = COMMAND_1_ID, # Default for tasks
|
||||
) -> PlanTestCase:
|
||||
@@ -244,8 +244,7 @@ def make_test_case(
|
||||
instance_id=instance_id,
|
||||
task_id=t["task_id"],
|
||||
command_id=t.get("command_id", command_id),
|
||||
task_type=TaskType.CHAT_COMPLETION,
|
||||
task_status=t.get("status", TaskStatus.PENDING),
|
||||
task_status=t.get("status", TaskStatus.Pending),
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=t.get("model", str(model_id)),
|
||||
messages=[
|
||||
|
||||
@@ -72,7 +72,7 @@ async def check_runner_connection(
|
||||
|
||||
instance = Instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
instance_type=InstanceStatus.ACTIVE,
|
||||
instance_type=InstanceStatus.Active,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=hosts(2),
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.common import Host
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.commands_runner import (
|
||||
ChatTaskMessage,
|
||||
RunnerMessageTypeAdapter,
|
||||
RunnerMessage,
|
||||
SetupMessage,
|
||||
)
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
@@ -30,7 +30,7 @@ def test_supervisor_setup_message_serdes(
|
||||
model_shard_meta=pipeline_shard_meta(1, 0),
|
||||
hosts=hosts(1),
|
||||
)
|
||||
assert_equal_serdes(setup_message, RunnerMessageTypeAdapter)
|
||||
assert_equal_serdes(setup_message, TypeAdapter(RunnerMessage))
|
||||
|
||||
|
||||
def test_supervisor_task_message_serdes(
|
||||
@@ -40,4 +40,4 @@ def test_supervisor_task_message_serdes(
|
||||
task_message = ChatTaskMessage(
|
||||
task_data=task.task_params,
|
||||
)
|
||||
assert_equal_serdes(task_message, RunnerMessageTypeAdapter)
|
||||
assert_equal_serdes(task_message, TypeAdapter(RunnerMessage))
|
||||
|
||||
@@ -7,10 +7,10 @@ from exo.shared.openai_compat import FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletionTask,
|
||||
ChatCompletionTaskParams,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskType,
|
||||
)
|
||||
from exo.shared.types.worker.common import InstanceId
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
@@ -143,7 +143,7 @@ async def test_supervisor_early_stopping(
|
||||
task = chat_completion_task(instance_id, TaskId())
|
||||
|
||||
max_tokens = 50
|
||||
assert task.task_type == TaskType.CHAT_COMPLETION
|
||||
assert isinstance(task, ChatCompletionTask)
|
||||
print(f"chat_completion_task.task_params: {task.task_params}")
|
||||
assert isinstance(task.task_params, ChatCompletionTaskParams)
|
||||
task_params: ChatCompletionTaskParams = task.task_params
|
||||
|
||||
@@ -6,7 +6,7 @@ from anyio import fail_after
|
||||
from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, TaggedEvent, TaskStateUpdated
|
||||
from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader
|
||||
@@ -24,7 +24,7 @@ class WorkerMailbox:
|
||||
await self.sender.send(
|
||||
ForwarderEvent(
|
||||
origin=origin,
|
||||
tagged_event=TaggedEvent.from_(event),
|
||||
event=event,
|
||||
origin_idx=self.counter,
|
||||
)
|
||||
)
|
||||
@@ -105,7 +105,7 @@ async def read_streaming_response(
|
||||
token_count = 0
|
||||
extra_events: list[Event] = []
|
||||
|
||||
event = (await global_event_receiver.receive()).tagged_event.c
|
||||
event = (await global_event_receiver.receive()).event
|
||||
extra_events.append(event)
|
||||
|
||||
from loguru import logger
|
||||
@@ -116,17 +116,17 @@ async def read_streaming_response(
|
||||
if filter_task:
|
||||
while not (
|
||||
isinstance(event, TaskStateUpdated)
|
||||
and event.task_status == TaskStatus.RUNNING
|
||||
and event.task_status == TaskStatus.Running
|
||||
and event.task_id == filter_task
|
||||
):
|
||||
event = (await global_event_receiver.receive()).tagged_event.c
|
||||
event = (await global_event_receiver.receive()).event
|
||||
extra_events.append(event)
|
||||
|
||||
for event in extra_events:
|
||||
if isinstance(event, TaskStateUpdated):
|
||||
if event.task_status == TaskStatus.RUNNING:
|
||||
if event.task_status == TaskStatus.Running:
|
||||
seen_task_started += 1
|
||||
if event.task_status == TaskStatus.COMPLETE:
|
||||
if event.task_status == TaskStatus.Complete:
|
||||
seen_task_finished += 1
|
||||
if isinstance(event, ChunkGenerated) and isinstance(
|
||||
event.chunk, TokenChunk
|
||||
@@ -137,11 +137,11 @@ async def read_streaming_response(
|
||||
finish_reason = event.chunk.finish_reason
|
||||
|
||||
while not seen_task_finished:
|
||||
event = (await global_event_receiver.receive()).tagged_event.c
|
||||
event = (await global_event_receiver.receive()).event
|
||||
if isinstance(event, TaskStateUpdated):
|
||||
if event.task_status == TaskStatus.RUNNING:
|
||||
if event.task_status == TaskStatus.Running:
|
||||
seen_task_started += 1
|
||||
if event.task_status == TaskStatus.COMPLETE:
|
||||
if event.task_status == TaskStatus.Complete:
|
||||
seen_task_finished += 1
|
||||
if isinstance(event, ChunkGenerated) and isinstance(
|
||||
event.chunk, TokenChunk
|
||||
@@ -167,7 +167,7 @@ async def until_event_with_timeout[T](
|
||||
|
||||
with fail_after(timeout):
|
||||
while times_seen < multiplicity:
|
||||
event = (await global_event_receiver.receive()).tagged_event.c
|
||||
event = (await global_event_receiver.receive()).event
|
||||
if isinstance(event, event_type):
|
||||
print(f"Wow! We got a {event}")
|
||||
print(
|
||||
|
||||
@@ -99,13 +99,13 @@ async def start_polling_node_metrics(
|
||||
system_info,
|
||||
network_interfaces,
|
||||
mac_friendly_name,
|
||||
memory_profile,
|
||||
) = await asyncio.gather(
|
||||
get_mac_system_info_async(),
|
||||
get_network_interface_info_async(),
|
||||
get_mac_friendly_name_async(),
|
||||
get_memory_profile_async(),
|
||||
)
|
||||
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
|
||||
memory_profile = await get_memory_profile_async()
|
||||
|
||||
await callback(
|
||||
NodePerformanceProfile(
|
||||
|
||||
0
typings/.gitkeep
Normal file
0
typings/.gitkeep
Normal file
Reference in New Issue
Block a user