Merge pull request #617 from exo-explore/runners2

Lots of fixes and QoL improvements.
Alex Cheema
2025-01-23 02:05:17 +00:00
committed by GitHub
39 changed files with 2163 additions and 592 deletions


@@ -254,6 +254,33 @@ jobs:
prompt: "Keep responses concise. Who was the king of pop?"
expected_output: "Michael Jackson"
chatgpt_api_integration_test_tinygrad_linux:
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Set up Python
command: |
sudo apt-get update
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get update
sudo apt-get install -y python3.12 python3.12-venv clang
python3.12 -m venv env
source env/bin/activate
- run:
name: Install dependencies
command: |
source env/bin/activate
pip install --upgrade pip
pip install .
- run_chatgpt_api_test:
inference_engine: tinygrad
model_id: llama-3.2-1b
prompt: "Keep responses concise. Who was the king of pop?"
expected_output: "Michael Jackson"
measure_pip_sizes:
macos:
xcode: "16.0.0"
@@ -342,5 +369,6 @@ workflows:
- discovery_integration_test
- chatgpt_api_integration_test_mlx
- chatgpt_api_integration_test_tinygrad
- chatgpt_api_integration_test_tinygrad_linux
- chatgpt_api_integration_test_dummy
- measure_pip_sizes

.github/bench.py (new file, 401 lines)

@@ -0,0 +1,401 @@
import aiohttp
import asyncio
import time
import json
import os
import boto3
from typing import Dict, Any
from datetime import datetime
import subprocess
import psutil
import platform
from pathlib import Path
def check_system_state():
print("\n=== System State Check ===", flush=True)
# Add macOS-specific checks
try:
# Check powermetrics with sudo
try:
power_metrics = subprocess.run(
['sudo', 'powermetrics', '-n', '1', '-i', '1000', '--samplers', 'cpu_power'],
capture_output=True, text=True
)
print("\nPower Metrics:", power_metrics.stdout, flush=True)
except Exception as e:
print(f"Error getting power metrics: {e}", flush=True)
# Check thermal state
thermal_state = subprocess.run(['pmset', '-g', 'therm'], capture_output=True, text=True)
print("\nThermal State:", thermal_state.stdout, flush=True)
# Check if running under Rosetta
arch = subprocess.run(['arch'], capture_output=True, text=True)
print("\nArchitecture:", arch.stdout, flush=True)
# Check MLX compilation mode - only if mlx is available
try:
import mlx.core as mx
if hasattr(mx, 'build_info'):
print("\nMLX Build Info:", mx.build_info(), flush=True)
else:
print("\nMLX Build Info: Not available in this version", flush=True)
except ImportError:
print("\nMLX: Not installed", flush=True)
except Exception as e:
print(f"\nError checking MLX: {e}", flush=True)
except Exception as e:
print(f"Error in macOS checks: {e}", flush=True)
# CPU Info
print("\nCPU Information:", flush=True)
try:
if platform.system() == 'Darwin' and platform.processor() == 'arm':
# Use sysctl for Apple Silicon Macs
cpu_info = subprocess.run(['sysctl', 'machdep.cpu'], capture_output=True, text=True)
if cpu_info.returncode == 0:
print(f"CPU Info (Apple Silicon):", cpu_info.stdout, flush=True)
# Parse powermetrics output for clearer CPU frequency display
try:
power_metrics = subprocess.run(
['sudo', 'powermetrics', '-n', '1', '-i', '100', '--samplers', 'cpu_power'],
capture_output=True, text=True
)
if power_metrics.returncode == 0:
output = power_metrics.stdout
print("\nDetailed CPU Frequency Information:")
# Extract cluster frequencies and max frequencies
current_cluster = None
max_freqs = {'E': 0, 'P0': 0, 'P1': 0}
for line in output.split('\n'):
# Track which cluster we're processing
if "E-Cluster" in line:
current_cluster = 'E'
elif "P0-Cluster" in line:
current_cluster = 'P0'
elif "P1-Cluster" in line:
current_cluster = 'P1'
# Get current frequencies
if "HW active frequency:" in line:
freq = line.split(':')[1].strip()
if freq != "0 MHz":
print(f"Current {current_cluster}-Cluster Frequency: {freq}")
# Get max frequencies from residency lines
if current_cluster and "active residency:" in line and "MHz:" in line:
try:
# Extract all frequency values
freqs = []
parts = line.split('MHz:')[:-1] # Skip last part as it's not a frequency
for part in parts:
freq_str = part.split()[-1]
try:
freq = float(freq_str)
freqs.append(freq)
except ValueError:
continue
if freqs:
max_freqs[current_cluster] = max(max_freqs[current_cluster], max(freqs))
except Exception:
continue
# Print max frequencies
print("\nMaximum Available Frequencies:")
for cluster, max_freq in max_freqs.items():
if max_freq > 0:
print(f"{cluster}-Cluster Max: {max_freq:.0f} MHz")
except Exception as e:
print(f"Error parsing powermetrics: {e}", flush=True)
else:
# Use psutil for other systems
cpu_freq = psutil.cpu_freq()
print(f"CPU Frequency - Current: {cpu_freq.current:.2f}MHz, Min: {cpu_freq.min:.2f}MHz, Max: {cpu_freq.max:.2f}MHz", flush=True)
print(f"\nCPU Usage per Core: {psutil.cpu_percent(percpu=True)}%", flush=True)
# Check if running in low power mode
power_mode = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
print("\nPower Settings:", power_mode.stdout, flush=True)
except Exception as e:
print(f"Error getting CPU info: {e}", flush=True)
# Memory Info
print("\nMemory Information:", flush=True)
try:
mem = psutil.virtual_memory()
print(f"Total: {mem.total/1024/1024/1024:.2f}GB", flush=True)
print(f"Available: {mem.available/1024/1024/1024:.2f}GB", flush=True)
print(f"Used: {mem.used/1024/1024/1024:.2f}GB ({mem.percent}%)", flush=True)
# Check swap
swap = psutil.swap_memory()
print(f"Swap Used: {swap.used/1024/1024/1024:.2f}GB of {swap.total/1024/1024/1024:.2f}GB", flush=True)
except Exception as e:
print(f"Error getting memory info: {e}", flush=True)
# GPU Info
print("\nGPU Information:", flush=True)
try:
# Check MLX GPU settings
print("MLX Environment Variables:", flush=True)
mlx_vars = {k: v for k, v in os.environ.items() if k.startswith('MLX')}
print(json.dumps(mlx_vars, indent=2), flush=True)
# Check Metal GPU memory allocation
gpu_mem = subprocess.run(['sysctl', 'iogpu'], capture_output=True, text=True)
print("GPU Memory Settings:", gpu_mem.stdout, flush=True)
except Exception as e:
print(f"Error getting GPU info: {e}", flush=True)
# Process Priority
print("\nProcess Priority Information:", flush=True)
try:
current_process = psutil.Process()
print(f"Process Nice Value: {current_process.nice()}", flush=True)
# Only try to get ionice if the platform supports it
if hasattr(current_process, 'ionice'):
print(f"Process IO Nice Value: {current_process.ionice()}", flush=True)
except Exception as e:
print(f"Error getting process priority info: {e}", flush=True)
# System Load
print("\nSystem Load:", flush=True)
try:
load_avg = psutil.getloadavg()
print(f"Load Average: {load_avg}", flush=True)
# Get top processes by CPU and Memory
print("\nTop Processes by CPU Usage:", flush=True)
processes = []
for proc in psutil.process_iter(['pid', 'name', 'cpu_percent', 'memory_percent']):
try:
pinfo = proc.info
if pinfo['cpu_percent'] is not None and pinfo['memory_percent'] is not None:
processes.append(pinfo)
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
# Sort and display top 5 CPU-consuming processes
sorted_by_cpu = sorted(processes, key=lambda x: x['cpu_percent'] or 0, reverse=True)[:5]
for proc in sorted_by_cpu:
print(f"PID: {proc['pid']}, Name: {proc['name']}, CPU: {proc['cpu_percent']}%, Memory: {proc['memory_percent']:.1f}%")
except Exception as e:
print(f"Error getting system load info: {e}", flush=True)
print("\n=== End System State Check ===\n", flush=True)
def check_gpu_access():
try:
# Check if MLX can see the GPU
import mlx.core as mx
print("MLX device info:", mx.default_device())
# Check Metal device availability
result = subprocess.run(['system_profiler', 'SPDisplaysDataType'], capture_output=True, text=True)
print("GPU Info:", result.stdout)
except Exception as e:
print(f"Failed to check GPU access: {e}")
async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
"""
Measures the performance of an API endpoint by sending a prompt and recording metrics.
Args:
api_endpoint (str): The API endpoint URL.
prompt (str): The prompt to send to the API.
model (str): The model identifier to benchmark.
Returns:
Dict[str, Any]: A dictionary containing performance metrics or error information.
"""
results = {
'model': model,
'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
'branch': os.environ.get('GITHUB_REF_NAME', 'unknown'),
'commit': os.environ.get('GITHUB_SHA', 'unknown'),
'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
}
# Get token count
session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600, connect=10, sock_read=600, sock_connect=10))
try:
response = await session.post(
"http://localhost:52415/v1/chat/token/encode",
json={
"model": model,
"messages": [{"role": "user", "content": prompt}]
}
)
response.raise_for_status()
token_data = await response.json()
results['prompt_len'] = token_data['num_tokens']
except Exception as e:
await session.close()
raise RuntimeError(f"Failed to get token count: {str(e)}")
# Measure completion performance
try:
start_time = time.time()
response = await session.post(
api_endpoint,
json={
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0,
"stream": True
}
)
response.raise_for_status()
first_token_time = None
total_tokens = 0
async for line in response.content.iter_chunks():
line = line[0].decode('utf-8').strip()
if not line.startswith('data: '):
continue
data = json.loads(line[6:]) # Skip 'data: ' prefix
if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
print(f"Received content: {content}", flush=True)
if first_token_time is None:
first_token_time = time.time()
ttft = first_token_time - start_time
results.update({
'ttft': ttft,
'prompt_tps': results['prompt_len'] / ttft
})
total_tokens += 1
total_time = time.time() - start_time
results.update({
'generation_tps': total_tokens / total_time,
'response_len': total_tokens,
'total_time': total_time
})
except Exception as e:
raise RuntimeError(f"Performance measurement failed: {str(e)}")
finally:
await session.close()
return results
async def main() -> None:
api_endpoint = "http://localhost:52415/v1/chat/completions"
# Define prompts
prompt_warmup = "what is the capital of France?"
prompt_essay = "write an essay about cats"
model = os.environ.get('model', 'llama-3.2-1b')
# Warmup request
print("\nPerforming warmup request...", flush=True)
try:
warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
print("Warmup completed successfully", flush=True)
except Exception as e:
print(f"Warmup request failed: {e}", flush=True)
# Measure performance for the essay prompt
print("\nMeasuring performance for the essay prompt...", flush=True)
results = await measure_performance(api_endpoint, prompt_essay, model)
try:
s3_client = boto3.client(
's3',
aws_access_key_id=os.environ.get('aws_access_key_id'),
aws_secret_access_key=os.environ.get('aws_secret_key')
)
job_name = os.environ.get('GITHUB_JOB')
# Create S3 key with timestamp and commit info
now = datetime.utcnow()
timestamp = now.strftime('%H-%M-%S')
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
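# Illustrative key layout (example values, not real run data):
# "single-m4-pro/llama-3.2-1b/2025/1/23/02-05-17_abc1234.json"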
# Upload to S3
s3_client.put_object(
Bucket='exo-benchmarks',
Key=s3_key,
Body=json.dumps(results),
ContentType='application/json'
)
print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
except Exception as e:
print(f"Failed to upload metrics to S3: {e}", flush=True)
# Optionally print the metrics for visibility
print("Performance metrics:", flush=True)
print(json.dumps(results, indent=4), flush=True)
def optimize_system_performance():
"""Set optimal system performance settings before running benchmark."""
try:
# Try to set high performance power mode
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
# Ensure MLX uses performance cores and GPU
os.environ['MLX_FORCE_P_CORES'] = '1'
os.environ['MLX_METAL_PREWARM'] = '1'
os.environ['MLX_USE_GPU'] = '1'
# Set process priority
current_process = psutil.Process()
try:
# Set highest priority
subprocess.run(['sudo', 'renice', '-n', '-20', '-p', str(current_process.pid)], check=False)
# Print current process state
print("\nProcess State Before Benchmark:", flush=True)
proc_info = subprocess.run(
['ps', '-o', 'pid,ppid,user,%cpu,%mem,nice,stat,pri,command', '-p', str(current_process.pid)],
capture_output=True, text=True
)
print(proc_info.stdout, flush=True)
# Verify power mode
power_info = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
if 'powermode 0' in power_info.stdout:
print("\nWarning: System still in normal power mode. Trying to set high performance mode again...", flush=True)
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
except Exception as e:
print(f"Warning: Could not set process priority: {e}", flush=True)
except Exception as e:
print(f"Warning: Could not optimize system performance: {e}", flush=True)
# Print optimization status
print("\nOptimization Settings:", flush=True)
print("MLX Environment Variables:", flush=True)
for var in ['MLX_FORCE_P_CORES', 'MLX_METAL_PREWARM', 'MLX_USE_GPU']:
print(f"{var}: {os.environ.get(var, 'Not set')}", flush=True)
try:
nice_value = psutil.Process().nice()
print(f"Process Nice Value: {nice_value}", flush=True)
if nice_value != -20:
print("Warning: Process not running at highest priority", flush=True)
except Exception:
pass
if __name__ == "__main__":
check_system_state()
check_gpu_access()
optimize_system_performance()
asyncio.run(main())

.github/bootstrap.sh (new executable file, 330 lines)

@@ -0,0 +1,330 @@
#!/bin/bash
set -e
command_exists() {
command -v "$1" >/dev/null 2>&1
}
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
if [ "$EUID" -eq 0 ]; then
log "Please do not run as root. Run as regular user with sudo access."
exit 1
fi
# Check for required arguments
if [ -z "$1" ]; then
log "Error: Runner token is required"
log "Usage: $0 <runner-token> [tailscale-auth-key]"
exit 1
fi
RUNNER_TOKEN=$1
TAILSCALE_AUTH_KEY=$2
REPO="exo-explore/exo"
# Add sudoers configuration
log "Configuring sudo access..."
SUDOERS_CONTENT="$(whoami) ALL=(ALL) NOPASSWD: ALL"
echo "$SUDOERS_CONTENT" | sudo tee /etc/sudoers.d/github-runner > /dev/null
sudo chmod 440 /etc/sudoers.d/github-runner
log "Configuring privacy permissions..."
sudo tccutil reset All
sudo tccutil reset SystemPolicyAllFiles
sudo tccutil reset SystemPolicyNetworkVolumes
# Configure power management for maximum performance
log "Configuring power management..."
sudo pmset -a powermode 2 # Force highest performance mode
sudo pmset -a gpuswitch 2 # Force discrete/high-performance GPU
sudo pmset -a lowpowermode 0
sudo pmset -a lessbright 0
sudo pmset -a disablesleep 1
sudo pmset -a sleep 0
sudo pmset -a hibernatemode 0
sudo pmset -a autopoweroff 0
sudo pmset -a standby 0
sudo pmset -a powernap 0
# For Python specifically
PYTHON_PATH="/opt/homebrew/bin/python3.12"
sudo chmod 755 "$PYTHON_PATH"
# Add to firewall
log "Configuring firewall access..."
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$PYTHON_PATH"
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$PYTHON_PATH"
# Set Homebrew paths based on architecture
if [ "$(uname -p)" = "arm" ]; then
BREW_PREFIX="/opt/homebrew"
else
BREW_PREFIX="/usr/local"
fi
# Install Homebrew if not present
if ! command_exists brew; then
log "Installing Homebrew..."
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zshrc
eval "$(/opt/homebrew/bin/brew shellenv)"
fi
# Install required packages
log "Installing required packages..."
export HOMEBREW_NO_AUTO_UPDATE=1
brew install python@3.12 coreutils
# Optional Tailscale setup if auth key is provided
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
log "Installing and configuring Tailscale..."
brew install --quiet tailscale
sudo brew services stop tailscale 2>/dev/null || true
sudo rm -f /var/db/tailscale/tailscaled.state 2>/dev/null || true
sudo brew services start tailscale
sleep 2
sudo tailscale up --authkey=$TAILSCALE_AUTH_KEY
# Enable SSH and Screen Sharing
log "Enabling remote access services..."
sudo launchctl load -w /System/Library/LaunchDaemons/ssh.plist
sudo /System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart \
-activate \
-configure -access -on \
-configure -allowAccessFor -allUsers \
-configure -restart -agent -privs -all
# Create launch daemon for remote access
sudo bash -c 'cat > /Library/LaunchDaemons/com.remote.access.setup.plist' << 'EOL'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.remote.access.setup</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>-c</string>
<string>
launchctl load -w /System/Library/LaunchDaemons/ssh.plist;
/System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart -activate -configure -access -on
</string>
</array>
<key>RunAtLoad</key>
<true/>
</dict>
</plist>
EOL
sudo chmod 644 /Library/LaunchDaemons/com.remote.access.setup.plist
sudo launchctl load -w /Library/LaunchDaemons/com.remote.access.setup.plist
fi
# Configure GitHub Actions Runner
log "Gathering system metadata..."
MACHINE_NAME=$(scutil --get ComputerName)
MACHINE_NAME="runner-$(echo -n "$MACHINE_NAME" | tr '[:upper:]' '[:lower:]' | tr -cd '[:alnum:]-')"
# Enhanced Apple Silicon detection
MACHINE_INFO=$(system_profiler SPHardwareDataType)
CHIP_FULL=$(echo "$MACHINE_INFO" | grep "Chip" | cut -d: -f2 | xargs)
if [[ $CHIP_FULL =~ "Apple" ]]; then
CHIP_MODEL=$(echo "$CHIP_FULL" | sed 's/^Apple //' | tr -d ' ' | tr '[:lower:]' '[:upper:]')
GPU_CORES=$(ioreg -l | grep "gpu-core-count" | awk -F'= ' '{print $2}')
if [ -z "$GPU_CORES" ]; then
GPU_CORES="N/A"
fi
else
CHIP_MODEL="Intel"
GPU_CORES="N/A"
fi
MEMORY=$(($(sysctl -n hw.memsize) / 1024 / 1024 / 1024))
# Set up GitHub Runner
RUNNER_DIR="$HOME/actions-runner"
# Check if runner is already configured
if [ -f "$RUNNER_DIR/.runner" ]; then
log "Runner already configured. Stopping existing service..."
sudo launchctl unload /Library/LaunchDaemons/com.github.runner.plist 2>/dev/null || true
fi
# Create runner directory if it doesn't exist
mkdir -p "$RUNNER_DIR"
cd "$RUNNER_DIR"
CUSTOM_LABELS="self-hosted,macos,arm64,${CHIP_MODEL}_GPU${GPU_CORES}_${MEMORY}GB"
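# Illustrative example: an Apple M4 Pro with a 16-core GPU and 24GB of RAM yields
# "self-hosted,macos,arm64,M4PRO_GPU16_24GB", matching the runs-on labels used in bench_job.yml.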
# Only download and extract if not already present or if forced
if [ ! -f "$RUNNER_DIR/run.sh" ] || [ "${FORCE_SETUP:-false}" = "true" ]; then
log "Downloading GitHub Actions runner..."
RUNNER_VERSION=$(curl -s https://api.github.com/repos/actions/runner/releases/latest | grep '"tag_name":' | cut -d'"' -f4)
curl -o actions-runner.tar.gz -L "https://github.com/actions/runner/releases/download/${RUNNER_VERSION}/actions-runner-osx-arm64-${RUNNER_VERSION#v}.tar.gz"
tar xzf actions-runner.tar.gz
rm actions-runner.tar.gz
else
log "Runner already downloaded, skipping download step"
fi
log "Configuring runner with labels: $CUSTOM_LABELS"
./config.sh --unattended \
--url "https://github.com/${REPO}" \
--token "${RUNNER_TOKEN}" \
--name "${MACHINE_NAME}" \
--labels "${CUSTOM_LABELS}" \
--work "_work"
# Set optimal performance settings
log "Configuring system for optimal performance..."
# Configure CPU performance
log "Setting CPU performance controls..."
# Disable timer coalescing
sudo sysctl -w kern.timer.coalescing_enabled=0
sudo sysctl -w kern.timer_coalesce_bg_scale=-5
sudo sysctl -w kern.timer_resort_threshold_ns=0
# Set minimum timer intervals
sudo sysctl -w kern.wq_max_timer_interval_usecs=1000
sudo sysctl -w kern.timer_coalesce_bg_ns_max=1000
# Set minimum timer coalescing for all tiers
sudo sysctl -w kern.timer_coalesce_tier0_scale=-5
sudo sysctl -w kern.timer_coalesce_tier0_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier1_scale=-5
sudo sysctl -w kern.timer_coalesce_tier1_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier2_scale=-5
sudo sysctl -w kern.timer_coalesce_tier2_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier3_scale=-5
sudo sysctl -w kern.timer_coalesce_tier3_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier4_scale=-5
sudo sysctl -w kern.timer_coalesce_tier4_ns_max=1000
# Disable QoS restrictions
sudo sysctl -w net.qos.policy.restricted=0
sudo sysctl -w net.qos.policy.restrict_avapps=0
sudo sysctl -w net.qos.policy.wifi_enabled=0
sudo sysctl -w net.qos.policy.capable_enabled=0
# Set scheduler parameters
sudo sysctl -w kern.sched_rt_avoid_cpu0=0
sudo sysctl -w debug.sched=2
sudo sysctl -w net.pktsched.netem.sched_output_ival_ms=1
# Clean up any existing runner services
log "Cleaning up existing runner services..."
for service in com.github.runner com.github.runner.monitor com.github.runner.cpuaffinity com.github.runner.affinity; do
sudo launchctl bootout system/$service 2>/dev/null || true
sudo rm -f /Library/LaunchDaemons/$service.plist
done
# Create a simple runner service configuration
sudo tee /Library/LaunchDaemons/com.github.runner.plist > /dev/null << EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.github.runner</string>
<key>UserName</key>
<string>$(whoami)</string>
<key>GroupName</key>
<string>staff</string>
<key>WorkingDirectory</key>
<string>$RUNNER_DIR</string>
<key>ProgramArguments</key>
<array>
<string>$RUNNER_DIR/run.sh</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<dict>
<key>SuccessfulExit</key>
<false/>
<key>Crashed</key>
<true/>
</dict>
<key>ProcessType</key>
<string>Interactive</string>
<key>LowPriorityIO</key>
<false/>
<key>AbandonProcessGroup</key>
<false/>
<key>EnableTransactions</key>
<true/>
<key>ThrottleInterval</key>
<integer>0</integer>
<key>HardResourceLimits</key>
<dict>
<key>NumberOfFiles</key>
<integer>524288</integer>
<key>MemoryLock</key>
<integer>-1</integer>
</dict>
<key>SoftResourceLimits</key>
<dict>
<key>NumberOfFiles</key>
<integer>524288</integer>
<key>MemoryLock</key>
<integer>-1</integer>
</dict>
<key>QOSClass</key>
<string>User-Interactive</string>
<key>StandardOutPath</key>
<string>$RUNNER_DIR/_diag/runner.log</string>
<key>StandardErrorPath</key>
<string>$RUNNER_DIR/_diag/runner.err</string>
<key>EnvironmentVariables</key>
<dict>
<key>PATH</key>
<string>/usr/local/bin:/opt/homebrew/bin:/usr/bin:/bin:/usr/sbin:/sbin</string>
</dict>
<key>Nice</key>
<integer>-20</integer>
</dict>
</plist>
EOF
# Set proper permissions for the LaunchDaemon
sudo chown root:wheel /Library/LaunchDaemons/com.github.runner.plist
sudo chmod 644 /Library/LaunchDaemons/com.github.runner.plist
# Remove any existing service
sudo launchctl bootout system/com.github.runner 2>/dev/null || true
# Load the new service using bootstrap
sudo launchctl bootstrap system /Library/LaunchDaemons/com.github.runner.plist
# Add Runner.Listener permissions (after runner installation)
RUNNER_PATH="$RUNNER_DIR/bin/Runner.Listener"
sudo chmod 755 "$RUNNER_PATH"
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$RUNNER_PATH"
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$RUNNER_PATH"
# Create connection info file if Tailscale is configured
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
TAILSCALE_IP=$(tailscale ip)
cat > "$HOME/remote_access_info.txt" << EOL
Mac Remote Access Information
============================
Computer Name: $MACHINE_NAME
Username: $USER
Tailscale IP: $TAILSCALE_IP
SSH Command: ssh $USER@$TAILSCALE_IP
Screen Sharing: vnc://$TAILSCALE_IP
EOL
chmod 600 "$HOME/remote_access_info.txt"
fi
log "Verifying runner service status..."
if sudo launchctl list | grep com.github.runner > /dev/null; then
log "GitHub Actions runner service is running successfully!"
log "Runner labels: $CUSTOM_LABELS"
[ -n "$TAILSCALE_AUTH_KEY" ] && log "Remote access details saved to: $HOME/remote_access_info.txt"
else
log "Error: Failed to start GitHub Actions runner service"
exit 1
fi

.github/optimize_performance.sh (new executable file, 95 lines)

@@ -0,0 +1,95 @@
#!/bin/bash
set -e
# Function to log with timestamp
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
log "Applying comprehensive performance optimizations..."
# System-wide power management
log "Configuring power management..."
sudo pmset -a lessbright 0
sudo pmset -a disablesleep 1
sudo pmset -a sleep 0
sudo pmset -a hibernatemode 0
sudo pmset -a autopoweroff 0
sudo pmset -a standby 0
sudo pmset -a powernap 0
sudo pmset -a proximitywake 0
sudo pmset -a tcpkeepalive 1
sudo pmset -a powermode 2
sudo pmset -a gpuswitch 2
sudo pmset -a displaysleep 0
sudo pmset -a disksleep 0
# Memory and kernel optimizations
log "Configuring memory and kernel settings..."
sudo sysctl -w kern.memorystatus_purge_on_warning=0
sudo sysctl -w kern.memorystatus_purge_on_critical=0
sudo sysctl -w kern.timer.coalescing_enabled=0
# Metal and GPU optimizations
log "Configuring Metal and GPU settings..."
defaults write com.apple.CoreML MPSEnableGPUValidation -bool false
defaults write com.apple.CoreML MPSEnableMetalValidation -bool false
defaults write com.apple.CoreML MPSEnableGPUDebug -bool false
defaults write com.apple.Metal GPUDebug -bool false
defaults write com.apple.Metal GPUValidation -bool false
defaults write com.apple.Metal MetalValidation -bool false
defaults write com.apple.Metal MetalCaptureEnabled -bool false
defaults write com.apple.Metal MTLValidationBehavior -string "Disabled"
defaults write com.apple.Metal EnableMTLDebugLayer -bool false
defaults write com.apple.Metal MTLDebugLevel -int 0
defaults write com.apple.Metal PreferIntegratedGPU -bool false
defaults write com.apple.Metal ForceMaximumPerformance -bool true
defaults write com.apple.Metal MTLPreferredDeviceGPUFrame -bool true
# Create MPS cache directory with proper permissions
sudo mkdir -p /tmp/mps_cache
sudo chmod 777 /tmp/mps_cache
# Process and resource limits
log "Configuring process limits..."
sudo launchctl limit maxfiles 524288 524288
ulimit -n 524288 || log "Warning: Could not set file descriptor limit"
ulimit -c 0
ulimit -l unlimited || log "Warning: Could not set memory lock limit"
# Export performance-related environment variables
cat << 'EOF' > /tmp/performance_env.sh
# Metal optimizations
export MTL_DEBUG_LAYER=0
export METAL_DEVICE_WRAPPER_TYPE=1
export METAL_DEBUG_ERROR_MODE=0
export METAL_FORCE_PERFORMANCE_MODE=1
export METAL_DEVICE_PRIORITY=high
export METAL_MAX_COMMAND_QUEUES=1024
export METAL_LOAD_LIMIT=0
export METAL_VALIDATION_ENABLED=0
export METAL_ENABLE_VALIDATION_LAYER=0
export OBJC_DEBUG_MISSING_POOLS=NO
export MPS_CACHEDIR=/tmp/mps_cache
# MLX optimizations
export MLX_USE_GPU=1
export MLX_METAL_COMPILE_ASYNC=1
export MLX_METAL_PREALLOCATE=1
export MLX_METAL_MEMORY_GUARD=0
export MLX_METAL_CACHE_KERNELS=1
export MLX_PLACEMENT_POLICY=metal
export MLX_METAL_VALIDATION=0
export MLX_METAL_DEBUG=0
export MLX_FORCE_P_CORES=1
export MLX_METAL_MEMORY_BUDGET=0
export MLX_METAL_PREWARM=1
# Python optimizations
export PYTHONUNBUFFERED=1
export PYTHONOPTIMIZE=2
export PYTHONHASHSEED=0
export PYTHONDONTWRITEBYTECODE=1
EOF
log "Performance optimizations completed. Environment variables written to /tmp/performance_env.sh"

.github/workflows/bench_job.yml (new file, 206 lines)

@@ -0,0 +1,206 @@
# Reusable workflow invoked via workflow_call (e.g. from benchmarks.yml) to run a distributed benchmark job
name: Distributed Job Runner
on:
workflow_call:
inputs:
config:
required: true
type: string
model:
required: true
type: string
calling_job_name:
required: true
type: string
network_interface:
required: true
type: string
jobs:
generate-matrix:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
env:
CONFIG: ${{ inputs.config }}
run: |
MATRIX=$(echo $CONFIG | jq -c '{cpu: [to_entries | .[] | .key as $k | range(.value) | $k]}')
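# Illustrative expansion: CONFIG='{"M4PRO_GPU16_24GB": 2}' becomes
# {"cpu":["M4PRO_GPU16_24GB","M4PRO_GPU16_24GB"]}, i.e. one matrix entry per requested machine.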
echo "matrix=$MATRIX" >> $GITHUB_OUTPUT
run-distributed-job:
needs: generate-matrix
strategy:
matrix: ${{fromJson(needs.generate-matrix.outputs.matrix)}}
runs-on: ['self-hosted', 'macOS', '${{ matrix.cpu }}']
env:
HARDWARE_CONFIG: ${{ inputs.config }}
model: ${{ inputs.model }}
# Add performance-related environment variables
MTL_DEBUG_LAYER: 0
METAL_VALIDATION_ENABLED: 0
MLX_METAL_VALIDATION: 0
MLX_METAL_DEBUG: 0
MLX_FORCE_P_CORES: 1
MLX_METAL_PREWARM: 1
PYTHONOPTIMIZE: 2
steps:
- name: Cleanup workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"
sudo mkdir -p "$GITHUB_WORKSPACE"
sudo chown -R $(whoami):$(id -g) "$GITHUB_WORKSPACE"
- uses: actions/checkout@v4
- name: Install dependencies
run: |
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
python3.12 -m venv .venv || {
echo "Failed to find python3.12. Checking installation locations:"
ls -l /usr/local/bin/python* /opt/homebrew/bin/python* 2>/dev/null || true
exit 1
}
source .venv/bin/activate
pip install --upgrade pip
pip install -e .
pip install boto3==1.35.76
- name: Apply Performance Optimizations
run: |
# Export performance-related environment variables
cat << 'EOF' > /tmp/performance_env.sh
# MLX and Metal optimizations
export MTL_DEBUG_LAYER=0
export METAL_VALIDATION_ENABLED=0
export MLX_METAL_VALIDATION=0
export MLX_METAL_DEBUG=0
export MLX_FORCE_P_CORES=1
export MLX_METAL_PREWARM=1
export PYTHONOPTIMIZE=2
EOF
# Source the performance environment variables
source /tmp/performance_env.sh
# MLX Memory Settings
./configure_mlx.sh
# Verify optimizations
echo "Verifying performance settings..."
env | grep -E "MLX_|METAL_|MTL_"
- name: Run exo
env:
aws_access_key_id: ${{ secrets.S3_EXO_BENCHMARKS_AWS_ACCESS_KEY_ID }}
aws_secret_key: ${{ secrets.S3_EXO_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
run: |
# Source performance environment variables
source /tmp/performance_env.sh
# Debug information
echo "Current commit SHA: $GITHUB_SHA"
git rev-parse HEAD
git status
CALLING_JOB="${{ inputs.calling_job_name }}"
UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
source .venv/bin/activate
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
echo "=== Before starting exo ==="
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep -i python
echo "Starting exo daemon..."
echo "Power mode settings:"
sudo pmset -g
# Start exo with explicit process control
sudo taskpolicy -d default -g default -a -t 0 -l 0 .venv/bin/exo \
--node-id="${MY_NODE_ID}" \
--node-id-filter="${ALL_NODE_IDS}" \
--interface-type-filter="${{ inputs.network_interface }}" \
--disable-tui \
--max-generate-tokens 250 \
--chatgpt-api-port 52415 > output1.log 2>&1 &
PID1=$!
echo "Exo process started with PID: $PID1"
tail -f output1.log &
TAIL1=$!
# Give process time to start
sleep 2
# Set additional process priorities
sudo renice -n -20 -p $PID1
sudo taskpolicy -t 4 -p $PID1
echo "=== After starting exo ==="
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep $PID1
echo "Additional process details:"
sudo powermetrics -n 1 -i 1000 --show-process-energy | grep -A 5 $PID1 || true
trap 'kill $TAIL1' EXIT
trap 'kill $PID1' EXIT
echo "Waiting for all nodes to connect..."
for i in {1..20}; do
echo "Attempt $i: Checking node count..."
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
echo "Current node count: $nodes"
if [ "$nodes" -eq "${{ strategy.job-total }}" ]; then
echo "All nodes connected successfully!"
break
fi
if [ $i -eq 20 ]; then
echo "ERROR: Failed to connect all nodes after 20 attempts. Expected ${{ strategy.job-total }} nodes, but got $nodes"
exit 1
fi
sleep 5
done
if ! kill -0 $PID1 2>/dev/null; then
echo "ERROR: Instance (PID $PID1) died unexpectedly. Full log output:"
cat output1.log
exit 1
fi
if [ "${{ strategy.job-index }}" -eq "0" ]; then
sleep 10
echo "This is the primary node (index 0). Running benchmark..."
GITHUB_JOB=$CALLING_JOB python .github/bench.py
else
echo "This is a secondary node (index ${{ strategy.job-index }}). Waiting for completion..."
sleep 10
while true; do
echo "Checking if primary node is still running..."
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
echo "Current node count: $nodes"
if [ "$nodes" -lt "${{ strategy.job-total }}" ]; then
echo "Primary node completed, exiting..."
break
fi
sleep 5
done
fi
- name: Check Final System State
if: always()
run: |
echo "=== Final System State ==="
sudo pmset -g
sudo powermetrics -n 1 -i 1000 --show-process-energy || true
system_profiler SPDisplaysDataType
sysctl iogpu
ps -eo pid,ppid,user,%cpu,%mem,nice,state,command | grep -i python
env | grep -E "MLX_|METAL_|MTL_"
echo "=== End Final System State ==="

.github/workflows/benchmarks.yml (new file, 71 lines)

@@ -0,0 +1,71 @@
name: Build and Test
on:
push:
branches: [ '*' ]
tags: [ '*' ]
pull_request:
branches: [ '*' ]
jobs:
single-m4-pro:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 1}'
model: ${{ matrix.model }}
calling_job_name: 'single-m4-pro'
network_interface: 'Ethernet'
secrets: inherit
two-m4-pro-cluster:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 2}'
model: ${{ matrix.model }}
calling_job_name: 'two-m4-pro-cluster'
network_interface: 'Ethernet'
secrets: inherit
# two-m4-pro-cluster-thunderbolt:
# strategy:
# matrix:
# model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
# uses: ./.github/workflows/bench_job.yml
# with:
# config: '{"M4PRO_GPU16_24GB": 2}'
# model: ${{ matrix.model }}
# calling_job_name: 'two-m4-pro-cluster-thunderbolt'
# network_interface: 'Thunderbolt'
# secrets: inherit
three-m4-pro-cluster:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', 'llama-3.3-70b']
fail-fast: false
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 3}'
model: ${{ matrix.model }}
calling_job_name: 'three-m4-pro-cluster'
network_interface: 'Ethernet'
secrets: inherit
# test-m3-single-node:
# strategy:
# matrix:
# model: ['llama-3.2-1b']
# fail-fast: false
# uses: ./.github/workflows/bench_job.yml
# with:
# config: '{"M3MAX_GPU40_128GB": 1}'
# model: ${{ matrix.model }}
# calling_job_name: 'test-m3-cluster'
# network_interface: 'Ethernet'
# secrets: inherit


@@ -3,16 +3,41 @@
# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
# Set WIRED_LIMIT_MB to 80%
WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
# Set WIRED_LWM_MB to 70%
WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
# Calculate 80% and TOTAL_MEM_GB-5GB in MB
EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))
# Calculate 70% and TOTAL_MEM_GB-8GB in MB
SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))
# Set WIRED_LIMIT_MB to higher value
if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
WIRED_LIMIT_MB=$EIGHTY_PERCENT
else
WIRED_LIMIT_MB=$MINUS_5GB
fi
# Set WIRED_LWM_MB to higher value
if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
WIRED_LWM_MB=$SEVENTY_PERCENT
else
WIRED_LWM_MB=$MINUS_8GB
fi
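# Illustrative example: on a 24GB (24576 MB) machine, 80% is 19660 MB vs 24576-5120=19456 MB,
# so the wired limit becomes 19660 MB; 70% is 17203 MB vs 24576-8192=16384 MB, so the
# low watermark becomes 17203 MB.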
# Display the calculated values
echo "Total memory: $TOTAL_MEM_MB MB"
echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
# Apply the values with sysctl
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
# Apply the values with sysctl, but check if we're already root
if [ "$EUID" -eq 0 ]; then
sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
else
# Try without sudo first, fall back to sudo if needed
sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
fi


@@ -33,6 +33,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
from collections import defaultdict
class Message:
@@ -199,6 +200,11 @@ class ChatGPTAPI:
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.token_queues = defaultdict(asyncio.Queue)
# Get the callback system and register our handler
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
@@ -223,6 +229,7 @@ class ChatGPTAPI:
cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
# Add static routes
@@ -348,13 +355,13 @@ class ChatGPTAPI:
async def handle_post_chat_completions(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
chat_request = parse_chat_request(data, self.default_model)
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
chat_request.model = self.default_model
if not chat_request.model or chat_request.model not in model_cards:
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
chat_request.model = self.default_model
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
if not shard:
@@ -365,7 +372,7 @@ class ChatGPTAPI:
)
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
@@ -378,28 +385,13 @@ class ChatGPTAPI:
self.on_chat_completion_request(request_id, chat_request, prompt)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
# request_id = None
# match = self.prompts.find_longest_prefix(prompt)
# if match and len(prompt) > len(match[1].prompt):
# if DEBUG >= 2:
# print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
# request_id = match[1].request_id
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
# # remove the matching prefix from the prompt
# prompt = prompt[len(match[1].prompt):]
# else:
# request_id = str(uuid.uuid4())
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
try:
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
if stream:
response = web.StreamResponse(
@@ -412,62 +404,74 @@ class ChatGPTAPI:
)
await response.prepare(request)
async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
finish_reason = "length"
try:
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
tokens, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
eos_token_id = None
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
finish_reason = None
if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
tokens,
stream,
finish_reason,
"chat.completion",
)
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
new_tokens,
stream,
finish_reason,
"chat.completion",
)
if DEBUG >= 2: print(f"Streaming completion: {completion}")
try:
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
except Exception as e:
if DEBUG >= 2: print(f"Error streaming completion: {e}")
if DEBUG >= 2: traceback.print_exc()
def on_result(_request_id: str, tokens: List[int], is_finished: bool):
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
if is_finished:
break
return _request_id == request_id and is_finished
await response.write_eof()
return response
_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
except asyncio.TimeoutError:
print("WARNING: Stream task timed out. This should not happen.")
await response.write_eof()
return response
except asyncio.TimeoutError:
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
return web.json_response({"detail": "Response generation timed out"}, status=408)
except Exception as e:
if DEBUG >= 2:
print(f"[ChatGPTAPI] Error processing prompt: {e}")
traceback.print_exc()
return web.json_response(
{"detail": f"Error processing prompt: {str(e)}"},
status=500
)
finally:
# Clean up the queue for this request
if request_id in self.token_queues:
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
del self.token_queues[request_id]
else:
_, tokens, _ = await callback.wait(
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
timeout=self.response_timeout,
)
tokens = []
while True:
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
tokens.extend(_tokens)
if is_finished:
break
finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
eos_token_id = None
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
tokens = tokens[:-1]
finish_reason = "stop"
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
@@ -476,9 +480,6 @@ class ChatGPTAPI:
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
finally:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
async def handle_post_image_generations(self, request):
data = await request.json()
@@ -678,6 +679,9 @@ class ChatGPTAPI:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
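# Called via the node.on_token callback registered in __init__; pushes token batches into the
# per-request queue that handle_post_chat_completions drains for streaming and non-streaming responses.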
await self.token_queues[request_id].put((tokens, is_finished))
async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
await runner.setup()


@@ -441,7 +441,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
shard_specific_patterns.add(sorted_file_names[-1])
else:
shard_specific_patterns = set(["*.safetensors"])
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
async def get_file_download_percentage(


@@ -159,13 +159,14 @@ class HFShardDownloader(ShardDownloader):
print(f"Download calculation for {self.current_repo_id}:")
print(f"Total bytes: {total_bytes}")
print(f"Downloaded bytes: {downloaded_bytes}")
if DEBUG >= 3:
for file in relevant_files:
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
return status
except Exception as e:
if DEBUG >= 2:
if DEBUG >= 3:
print(f"Error getting shard download status: {e}")
traceback.print_exc()
return None


@@ -14,6 +14,7 @@ from pathlib import Path
import tempfile
import json
from concurrent.futures import ThreadPoolExecutor
import traceback
DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -230,20 +231,21 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
def get_all_ip_addresses_and_interfaces():
try:
ip_addresses = []
for interface in get_if_list():
ip = get_if_addr(interface)
# Include all addresses, including loopback
# Filter out link-local addresses
if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
# Remove "\\Device\\NPF_" prefix from interface name
try:
ip = get_if_addr(interface)
if ip.startswith("0.0."): continue
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
ip_addresses.append((ip, simplified_interface))
except:
if DEBUG >= 1: print(f"Failed to get IP address for interface {interface}")
if DEBUG >= 1: traceback.print_exc()
if not ip_addresses:
if DEBUG >= 1: print("Failed to get any IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
@@ -329,6 +331,30 @@ def is_frozen():
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
async def get_mac_system_info() -> Tuple[str, str, int]:
"""Get Mac system information using system_profiler."""
try:
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool,
lambda: subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
)
model_line = next((line for line in output.split("\n") if "Model Name" in line), None)
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
chip_line = next((line for line in output.split("\n") if "Chip" in line), None)
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
memory_line = next((line for line in output.split("\n") if "Memory" in line), None)
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
memory_units = memory_str.split()
memory_value = int(memory_units[0])
memory = memory_value * 1024 if memory_units[1] == "GB" else memory_value
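# The memory value above is normalized to MB (GB inputs are multiplied by 1024).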
return model_id, chip_id, memory
except Exception as e:
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
return "Unknown Model", "Unknown Chip", 0
def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"


@@ -0,0 +1,7 @@
# Perf improvements
Target: 460 tok/sec
- removing sample: 369 -> 402 tok/sec
- performance degrades as more tokens are generated
- making the MLX inference engine synchronous by removing the thread pool executor (see the sketch below): 402 -> 413 tok/sec
- removing self.on_opaque_status.trigger_all: 413 -> 418 tok/sec
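For context on the "synchronous engine" bullet, here is a minimal sketch of the pattern this PR adopts in MLXDynamicShardInferenceEngine (the helper name eval_on_mlx_thread is illustrative, not part of the codebase): all MLX evaluation is funneled through a single dedicated worker thread so the asyncio event loop never blocks on Metal work.

import asyncio
from concurrent.futures import ThreadPoolExecutor

import mlx.core as mx

# One dedicated thread serializes all MLX/Metal work and keeps it off the event loop,
# mirroring the _mlx_thread used by the engine in this PR.
mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")

async def eval_on_mlx_thread(*arrays):
  # mx.eval forces lazy MLX computations to materialize; running it on the MLX
  # thread avoids stalling the asyncio loop that serves API requests.
  await asyncio.get_running_loop().run_in_executor(mlx_thread, mx.eval, *arrays)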


@@ -1,155 +1,167 @@
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.sample_utils import top_p_sampling
from mlx_lm.sample_utils import top_p_sampling, make_sampler
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
from .losses import loss_fns
from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections import OrderedDict
from mlx_lm.models.cache import make_prompt_cache
def sample_logits(
logits: mx.array,
temp: float = 0.0,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None
) -> Tuple[mx.array, float]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
else:
token = mx.random.categorical(logits*(1/temp))
return token
from concurrent.futures import ThreadPoolExecutor
class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)
self.caches = OrderedDict()
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
self.session = {}
async def _eval_mlx(self, *args):
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
async def poll_state(self, request_id: str, max_caches=2):
if request_id in self.caches:
self.caches.move_to_end(request_id)
else:
newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
newcache = make_prompt_cache(self.model)
if len(self.caches) > max_caches:
self.caches.popitem(last=False)
self.caches[request_id] = newcache
return {"cache": self.caches[request_id]}
async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
y = mx.array(x)
logits = y[:, -1, :]
out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
return out
async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
if (temp, top_p, 0.0, 1) != self.sampler_params:
self.sampler_params = (temp, top_p, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
logits = mx.array(x)
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, keepdims=True)
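# Normalize raw logits into log-probabilities, which the sampler returned by make_sampler expects.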
result = self.sampler(logprobs)
await self._eval_mlx(result)
return np.asarray(result, dtype=int)
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
return np.array(tokens)
return np.asarray(
await asyncio.get_running_loop().run_in_executor(
self._tokenizer_thread,
self.tokenizer.encode,
prompt
)
)
async def decode(self, shard: Shard, tokens) -> str:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
return tokens
return await asyncio.get_running_loop().run_in_executor(
self._tokenizer_thread,
self.tokenizer.decode,
tokens
)
async def save_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
async def load_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
if self.model.model_type != 'StableDiffusionPipeline':
output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
output_data = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: self.model(x, **state, **(inference_state or {}))
)
inference_state = None
else:
output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
output_data = np.array(output_data)
result = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: self.model(x, **state, **(inference_state or {}))
)
output_data, inference_state = result
output_data = np.array(output_data, copy=False)
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)
await self.save_session('loss', loss_fns[loss])
loop = asyncio.get_running_loop()
#print(f"evaluate in <- {inputs}")
x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)
score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
#print(f"evaluate out -> {score}")
score = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: self.session['loss'](self.model, x, y, l)
)
return score
async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
await self.ensure_shard(shard)
if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
await self.save_session('train_layers', trainable_layers)
self.model.freeze()
self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
def freeze_unfreeze():
self.model.freeze()
self.model.apply_to_modules(
lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
)
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
await self.save_session('lossname', loss)
await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
if 'opt' not in self.session:
await self.save_session('opt', opt(lr))
return True
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
loop = asyncio.get_running_loop()
nothin = await self.ensure_train(shard, loss, opt, lr)
await self.ensure_train(shard, loss, opt, lr)
def train_step(inp, tar, lng):
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
gradlayers = grad['model']['layers']
self.session['opt'].update(self.model, grad)
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
return lval, gradlayers
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)
score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: train_step(x, y, l)
)
await self._eval_mlx(*eval_args)
score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
#print(f"{score=}")
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
#print(layers[0])
return score, np.array(layers[0]['input_layernorm'])
layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
await self._eval_mlx(first_layer)
return score, first_layer
async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
if self.shard != shard:
def load_shard_wrapper():
return asyncio.run(load_shard(model_path, shard))
model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
model_shard, self.tokenizer = await load_shard(model_path, shard)
self.shard = shard
self.model = model_shard
self.caches = OrderedDict()
self.session = {}
async def cleanup(self):
self._mlx_thread.shutdown(wait=True)
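All MLX work above is funneled through a single dedicated worker thread so the asyncio event loop never blocks on model execution. A minimal, self-contained sketch of that pattern (the _mlx_thread attribute is assumed to be a one-worker ThreadPoolExecutor created in the engine's constructor, which is outside this hunk):
import asyncio
from concurrent.futures import ThreadPoolExecutor
class SingleThreadOffloader:
  """Run blocking calls on one dedicated thread, keeping the event loop responsive."""
  def __init__(self):
    # One worker so all offloaded calls are serialized on the same thread.
    self._worker = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
  async def run(self, fn, *args):
    # Hand the blocking call to the worker and await its result without blocking the loop.
    return await asyncio.get_running_loop().run_in_executor(self._worker, lambda: fn(*args))
  def shutdown(self):
    self._worker.shutdown(wait=True)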

View File

@@ -0,0 +1,81 @@
import asyncio
import time
import numpy as np
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.shard import Shard
from exo.models import build_base_shard
from collections import deque
from statistics import mean, median
async def test_non_blocking():
# Setup
shard_downloader = HFShardDownloader()
engine = MLXDynamicShardInferenceEngine(shard_downloader)
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
await engine.ensure_shard(shard)
queue = asyncio.Queue()
measurements = deque(maxlen=1000000)
running = True
async def mlx_worker():
try:
start_time = time.time()
count = 0
while running and (time.time() - start_time) < 5: # Hard time limit
start = time.perf_counter_ns()
await engine.infer_prompt("req1", shard, "test prompt")
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
count += 1
print(f"MLX operation {count} took: {duration:.3f}ms")
except asyncio.CancelledError:
pass
finally:
print(f"\nTotal MLX operations completed: {count}")
print(f"Average rate: {count/5:.1f} ops/second")
async def latency_producer():
try:
start_time = time.perf_counter_ns()
count = 0
while running:
await queue.put(time.perf_counter_ns())
count += 1
await asyncio.sleep(0) # Yield to event loop without delay
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
print(f"\nProducer iterations: {count}")
print(f"Producer rate: {count/duration:.1f} iterations/second")
except asyncio.CancelledError:
pass
async def latency_consumer():
try:
while running:
timestamp = await queue.get()
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
measurements.append(latency)
queue.task_done()
except asyncio.CancelledError:
pass
tasks = [
asyncio.create_task(mlx_worker()),
asyncio.create_task(latency_producer()),
asyncio.create_task(latency_consumer())
]
try:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
except asyncio.TimeoutError:
print("\nTest timed out")
finally:
running = False
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
print(f"\nFinal measurement count: {len(measurements)}")
if __name__ == "__main__":
asyncio.run(test_non_blocking())
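statistics.mean and median are imported above but the final report only prints the measurement count. If a latency summary is wanted, a small illustrative addition at the end of test_non_blocking(), where measurements is still in scope, could look like this (not part of the diff):
# Illustrative summary of the queue latencies collected in `measurements`.
if measurements:
  print(f"Queue latency mean: {mean(measurements):.3f}ms")
  print(f"Queue latency median: {median(measurements):.3f}ms")
  print(f"Queue latency p99: {sorted(measurements)[int(len(measurements) * 0.99)]:.3f}ms")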

View File

@@ -13,7 +13,6 @@ import uuid
import numpy as np
from functools import partial
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from exo.train.dataset import load_dataset, iterate_batches, compose
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.manual.network_topology_config import NetworkTopology
@@ -33,6 +32,46 @@ from exo.inference.tokenizers import resolve_tokenizer
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
import uvloop
from contextlib import asynccontextmanager
import concurrent.futures
import socket
import resource
import psutil
# TODO: figure out why this is happening
os.environ["GRPC_VERBOSITY"] = "error"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Configure uvloop for maximum performance
def configure_uvloop():
# Install uvloop as event loop policy
uvloop.install()
# Create new event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Increase file descriptor limits on Unix systems
if not psutil.WINDOWS:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
try:
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
except ValueError:
try:
resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
except ValueError:
pass
# Configure thread pool for blocking operations
loop.set_default_executor(
concurrent.futures.ThreadPoolExecutor(
max_workers=min(32, (os.cpu_count() or 1) * 4)
)
)
return loop
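As a quick illustrative sanity check (not part of the change, and relying on the asyncio, uvloop and resource imports already in this file), the tuned loop can be inspected right after configure_uvloop() returns on a Unix machine:
# Illustrative only: confirm uvloop is active and inspect the NOFILE soft limit.
loop = configure_uvloop()
assert isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy)
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"loop={type(loop).__name__} NOFILE soft={soft} hard={hard}")
loop.close()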
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -52,7 +91,6 @@ parser.add_argument("--models-seed-dir", type=str, default=None, help="Model see
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -69,6 +107,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -101,8 +140,9 @@ if DEBUG >= 0:
for chatgpt_api_endpoint in chatgpt_api_endpoints:
print(f" - {terminal_link(chatgpt_api_endpoint)}")
# Convert node-id-filter to list if provided
# Convert node-id-filter and interface-type-filter to lists if provided
allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
if args.discovery_module == "udp":
discovery = UDPDiscovery(
@@ -112,7 +152,8 @@ if args.discovery_module == "udp":
args.broadcast_port,
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
discovery_timeout=args.discovery_timeout,
allowed_node_ids=allowed_node_ids
allowed_node_ids=allowed_node_ids,
allowed_interface_types=allowed_interface_types
)
elif args.discovery_module == "tailscale":
discovery = TailscaleDiscovery(
@@ -150,9 +191,16 @@ api = ChatGPTAPI(
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
)
buffered_token_output = {}
def update_topology_viz(req_id, tokens, __):
if not topology_viz: return
if not inference_engine.shard: return
if inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
else: buffered_token_output[req_id] = tokens
topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
node.on_token.register("update_topology_viz").on_next(update_topology_viz)
def preemptively_start_download(request_id: str, opaque_status: str):
try:
@@ -169,10 +217,6 @@ def preemptively_start_download(request_id: str, opaque_status: str):
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
if args.prometheus_client_port:
from exo.stats.metrics import start_metrics_server
start_metrics_server(node, args.prometheus_client_port)
last_broadcast_time = 0
@@ -204,7 +248,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, request_id=request_id)
_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
tokens = []
def on_token(_request_id, _tokens, _is_finished):
tokens.extend(_tokens)
return _request_id == request_id and _is_finished
await callback.wait(on_token, timeout=300)
print("\nGenerated response:")
print(tokenizer.decode(tokens))
@@ -223,7 +271,7 @@ def clean_path(path):
async def hold_outstanding(node: Node):
while node.outstanding_requests:
await asyncio.sleep(.5)
return
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
losses = []
@@ -234,7 +282,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
tokens.append(np.sum(lengths))
total_tokens = np.sum(tokens)
total_loss = np.sum(losses) / total_tokens
return total_loss, total_tokens
async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
@@ -270,7 +318,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
await hold_outstanding(node)
await hold_outstanding(node)
async def main():
loop = asyncio.get_running_loop()
@@ -285,7 +333,7 @@ async def main():
{"❌ No read access" if not has_read else ""}
{"❌ No write access" if not has_write else ""}
""")
if not args.models_seed_dir is None:
try:
models_seed_dir = clean_path(args.models_seed_dir)
@@ -330,29 +378,31 @@ async def main():
print("Error: This train ain't leaving the station without a model")
return
await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()
if args.wait_for_peers > 0:
print("Cooldown to allow peers to exit gracefully")
for i in tqdm(range(50)):
await asyncio.sleep(.1)
@asynccontextmanager
async def setup_node(args):
# Rest of setup_node implementation...
pass
def run():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
loop.close()
loop = None
try:
loop = configure_uvloop()
loop.run_until_complete(main())
except KeyboardInterrupt:
print("\nShutdown requested... exiting")
finally:
if loop:
loop.close()
if __name__ == "__main__":
run()

View File

@@ -28,6 +28,19 @@ class GRPCPeerHandle(PeerHandle):
self._device_capabilities = device_capabilities
self.channel = None
self.stub = None
self.channel_options = [
("grpc.max_metadata_size", 64 * 1024 * 1024),
("grpc.max_receive_message_length", 256 * 1024 * 1024),
("grpc.max_send_message_length", 256 * 1024 * 1024),
("grpc.max_concurrent_streams", 100),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.keepalive_time_ms", 20000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_pings_without_data", 0),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
]
def id(self) -> str:
return self._id
@@ -44,7 +57,9 @@ class GRPCPeerHandle(PeerHandle):
async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
@@ -59,7 +74,13 @@ class GRPCPeerHandle(PeerHandle):
self.stub = None
async def _ensure_connected(self):
if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
if not await self.is_connected():
try:
await asyncio.wait_for(self.connect(), timeout=10.0)
except asyncio.TimeoutError:
if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
await self.disconnect()
raise
async def health_check(self) -> bool:
try:
@@ -88,12 +109,7 @@ class GRPCPeerHandle(PeerHandle):
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendPrompt(request)
if not response.tensor_data or not response.shape or not response.dtype:
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
await self.stub.SendPrompt(request)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
@@ -154,16 +170,6 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
response = await self.stub.GetInferenceResult(request)
if response.tensor is None:
return None, response.is_finished
return (
np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
response.is_finished,
)
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
response = await self.stub.CollectTopology(request)

View File

@@ -27,11 +27,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def start(self) -> None:
self.server = grpc.aio.server(
futures.ThreadPoolExecutor(max_workers=10),
futures.ThreadPoolExecutor(max_workers=32),
options=[
("grpc.max_metadata_size", 32*1024*1024),
("grpc.max_send_message_length", 128*1024*1024),
("grpc.max_receive_message_length", 128*1024*1024),
("grpc.max_send_message_length", 256*1024*1024),
("grpc.max_receive_message_length", 256*1024*1024),
("grpc.keepalive_time_ms", 10000),
("grpc.keepalive_timeout_ms", 5000),
("grpc.http2.max_pings_without_data", 0),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.http2.min_ping_interval_without_data_ms", 5000),
("grpc.max_concurrent_streams", 100),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
],
)
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)

View File

@@ -6,7 +6,6 @@ service NodeService {
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc SendExample (ExampleRequest) returns (Loss) {}
rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
rpc SendResult (SendResultRequest) returns (Empty) {}
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
@@ -47,15 +46,6 @@ message Loss {
float loss = 1;
optional Tensor grads = 2;
}
message GetInferenceResultRequest {
string request_id = 1;
}
message InferenceResult {
optional Tensor tensor = 1;
bool is_finished = 2;
}
message Tensor {
bytes tensor_data = 1;

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
from . import node_service_pb2 as node__service__pb2
GRPC_GENERATED_VERSION = '1.68.0'
GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+ f' but the generated code in node_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,38 @@ class NodeServiceStub(object):
"""
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendExample = channel.unary_unary(
'/node_service.NodeService/SendExample',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=node__service__pb2.Loss.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
_registered_method=True)
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
'/node_service.NodeService/SendOpaqueStatus',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
_registered_method=True)
self.HealthCheck = channel.unary_unary(
'/node_service.NodeService/HealthCheck',
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
_registered_method=True)
@@ -97,12 +92,6 @@ class NodeServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetInferenceResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CollectTopology(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -132,43 +121,38 @@ def add_NodeServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
),
'SendExample': grpc.unary_unary_rpc_method_handler(
servicer.SendExample,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
request_deserializer=node__service__pb2.ExampleRequest.FromString,
response_serializer=node__service__pb2.Loss.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
),
'SendResult': grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
servicer.SendOpaqueStatus,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
),
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +180,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendPrompt',
exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -223,8 +207,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendTensor',
exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -250,35 +234,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendExample',
exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetInferenceResult(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/GetInferenceResult',
exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
node__service__pb2.ExampleRequest.SerializeToString,
node__service__pb2.Loss.FromString,
options,
channel_credentials,
insecure,
@@ -304,8 +261,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/CollectTopology',
exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
node__service__pb2.CollectTopologyRequest.SerializeToString,
node__service__pb2.Topology.FromString,
options,
channel_credentials,
insecure,
@@ -331,8 +288,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendResult',
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -358,8 +315,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendOpaqueStatus',
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -385,8 +342,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/HealthCheck',
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
node__service__pb2.HealthCheckRequest.SerializeToString,
node__service__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,

View File

@@ -63,8 +63,7 @@ class ManualDiscovery(Discovery):
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
self.known_peers = new_known_peers
await asyncio.sleep(1.0)
await asyncio.sleep(5.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

View File

@@ -51,10 +51,6 @@ class PeerHandle(ABC):
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
pass
@abstractmethod
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
pass
@abstractmethod
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
pass

View File

@@ -40,7 +40,7 @@ class TailscaleDiscovery(Discovery):
self.update_task = None
async def start(self):
self.device_capabilities = device_capabilities()
self.device_capabilities = await device_capabilities()
self.discovery_task = asyncio.create_task(self.task_discover_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())

View File

@@ -3,7 +3,7 @@ import json
import socket
import time
import traceback
from typing import List, Dict, Callable, Tuple, Coroutine
from typing import List, Dict, Callable, Tuple, Coroutine, Optional
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
@@ -23,15 +23,29 @@ class ListenProtocol(asyncio.DatagramProtocol):
asyncio.create_task(self.on_message(data, addr))
def get_broadcast_address(ip_addr: str) -> str:
try:
# Split IP into octets and create broadcast address for the subnet
ip_parts = ip_addr.split('.')
return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255"
except:
return "255.255.255.255"
class BroadcastProtocol(asyncio.DatagramProtocol):
def __init__(self, message: str, broadcast_port: int):
def __init__(self, message: str, broadcast_port: int, source_ip: str):
self.message = message
self.broadcast_port = broadcast_port
self.source_ip = source_ip
def connection_made(self, transport):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
# Try both subnet-specific and global broadcast
broadcast_addr = get_broadcast_address(self.source_ip)
transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port))
if broadcast_addr != "255.255.255.255":
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
class UDPDiscovery(Discovery):
@@ -45,7 +59,8 @@ class UDPDiscovery(Discovery):
broadcast_interval: int = 2.5,
discovery_timeout: int = 30,
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
allowed_node_ids: List[str] = None,
allowed_node_ids: Optional[List[str]] = None,
allowed_interface_types: Optional[List[str]] = None,
):
self.node_id = node_id
self.node_port = node_port
@@ -56,13 +71,14 @@ class UDPDiscovery(Discovery):
self.discovery_timeout = discovery_timeout
self.device_capabilities = device_capabilities
self.allowed_node_ids = allowed_node_ids
self.allowed_interface_types = allowed_interface_types
self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
async def start(self):
self.device_capabilities = device_capabilities()
self.device_capabilities = await device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
self.listen_task = asyncio.create_task(self.task_listen_for_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
@@ -82,11 +98,7 @@ class UDPDiscovery(Discovery):
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
async def task_broadcast_presence(self):
if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
while True:
# Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
# the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
for addr, interface_name in get_all_ip_addresses_and_interfaces():
interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
message = json.dumps({
@@ -94,16 +106,26 @@ class UDPDiscovery(Discovery):
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
"priority": interface_priority,
"interface_name": interface_name,
"interface_type": interface_type,
})
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
transport = None
try:
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
pass
sock.bind((addr, 0))
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: BroadcastProtocol(message, self.broadcast_port, addr),
sock=sock
)
except Exception as e:
print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
finally:
@@ -111,7 +133,7 @@ class UDPDiscovery(Discovery):
try: transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
if DEBUG_DISCOVERY >= 2: traceback.print_exc()
await asyncio.sleep(self.broadcast_interval)
async def on_listen_message(self, data, addr):
@@ -147,6 +169,12 @@ class UDPDiscovery(Discovery):
peer_prio = message["priority"]
peer_interface_name = message["interface_name"]
peer_interface_type = message["interface_type"]
# Skip if interface type is not in allowed list
if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
return
device_capabilities = DeviceCapabilities(**message["device_capabilities"])
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":

View File

@@ -8,7 +8,7 @@ from typing import List, Dict, Optional, Tuple, Union, Set
from exo.networking import Discovery, PeerHandle, Server
from exo.inference.inference_engine import InferenceEngine, Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import device_capabilities
from exo.topology.device_capabilities import device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
from exo import DEBUG
from exo.helpers import AsyncCallbackSystem
@@ -37,7 +37,7 @@ class Node:
self.partitioning_strategy = partitioning_strategy
self.peers: List[PeerHandle] = {}
self.topology: Topology = Topology()
self.device_capabilities = device_capabilities()
self.device_capabilities = UNKNOWN_DEVICE_CAPABILITIES
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
self.buffered_logits: Dict[str, List[np.ndarray]] = {}
self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
@@ -56,6 +56,7 @@ class Node:
self.outstanding_requests = {}
async def start(self, wait_for_peers: int = 0) -> None:
self.device_capabilities = await device_capabilities()
await self.server.start()
await self.discovery.start()
await self.update_peers(wait_for_peers)
@@ -70,25 +71,28 @@ class Node:
def on_node_status(self, request_id, opaque_status):
try:
status_data = json.loads(opaque_status)
if status_data.get("type", "") == "supported_inference_engines":
status_type = status_data.get("type", "")
if status_type == "supported_inference_engines":
node_id = status_data.get("node_id")
engines = status_data.get("engines", [])
self.topology_inference_engines_pool.append(engines)
if status_data.get("type", "") == "node_status":
elif status_type == "node_status":
if status_data.get("status", "").startswith("start_"):
self.current_topology.active_node_id = status_data.get("node_id")
elif status_data.get("status", "").startswith("end_"):
if status_data.get("node_id") == self.current_topology.active_node_id:
self.current_topology.active_node_id = None
download_progress = None
if status_data.get("type", "") == "download_progress":
if status_type == "download_progress":
if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
self.node_download_progress[status_data.get('node_id')] = download_progress
if self.topology_viz:
self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
except Exception as e:
if DEBUG >= 1: print(f"Error updating visualization: {e}")
if DEBUG >= 1: print(f"Error on_node_status: {e}")
if DEBUG >= 1: traceback.print_exc()
def get_supported_inference_engines(self):
@@ -107,6 +111,8 @@ class Node:
def get_topology_inference_engines(self) -> List[List[str]]:
return self.topology_inference_engines_pool
token_count = 0
first_token_time = 0
async def process_inference_result(
self,
shard,
@@ -124,9 +130,8 @@ class Node:
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
intermediate_result = self.buffered_token_output[request_id][0]
intermediate_result = [self.buffered_token_output[request_id][0][-1]]
else:
forward = result
else:
@@ -157,6 +162,7 @@ class Node:
inference_state: Optional[dict] = {},
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
@@ -187,18 +193,17 @@ class Node:
"prompt": prompt,
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
"result_size": resp.size if resp is not None else 0,
}),
)
)
return resp
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
if not shard.is_first_layer():
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
self.outstanding_requests[request_id] = "waiting"
@@ -355,41 +360,11 @@ class Node:
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
json.dumps({
"type": "node_status",
"node_id": self.id,
"status": "start_process_tensor",
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"tensor_size": tensor.size,
"tensor_shape": tensor.shape,
"request_id": request_id,
}),
)
)
start_time = time.perf_counter_ns()
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
json.dumps({
"type": "node_status",
"node_id": self.id,
"status": "end_process_tensor",
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
"result_size": resp.size if resp is not None else 0,
}),
)
)
return resp
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
async def _process_tensor(
self,
@@ -402,7 +377,6 @@ class Node:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
self.outstanding_requests[request_id] = "processing"
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
@@ -412,7 +386,6 @@ class Node:
self.outstanding_requests.pop(request_id)
print(f"Error processing tensor for shard {shard}: {e}")
traceback.print_exc()
return None
async def forward_example(
self,
@@ -558,18 +531,13 @@ class Node:
try:
did_peers_change = await self.update_peers()
if DEBUG >= 2: print(f"{did_peers_change=}")
await self.collect_topology(set())
if did_peers_change:
await self.collect_topology(set())
await self.select_best_inference_engine()
except Exception as e:
print(f"Error collecting topology: {e}")
traceback.print_exc()
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
if request_id not in self.buffered_token_output:
return None, False
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
next_topology = Topology()
next_topology.update_node(self.id, self.device_capabilities)
@@ -614,7 +582,7 @@ class Node:
return self._on_opaque_status
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
self.on_token.trigger_all(request_id, tokens, is_finished)
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:

View File

@@ -1,6 +1,7 @@
import unittest
from unittest.mock import Mock, AsyncMock
import numpy as np
import pytest
from .node import Node
from exo.networking.peer_handle import PeerHandle
@@ -55,3 +56,11 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
await self.node.process_tensor(input_tensor, None)
self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
@pytest.mark.asyncio
async def test_node_capabilities():
node = Node()
await node.initialize()
caps = await node.get_device_capabilities()
assert caps is not None
assert caps.model != ""

View File

@@ -0,0 +1,166 @@
from dataclasses import dataclass
from typing import Dict, Optional, Any
from opentelemetry import trace, context
from opentelemetry.trace import Status, StatusCode, SpanContext
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from contextlib import contextmanager
import time
from threading import Lock
@dataclass
class TraceContext:
request_id: str
sequence_number: int
current_span: Optional[trace.Span] = None
trace_parent: Optional[str] = None
token_group_span: Optional[trace.Span] = None
token_count: int = 0
token_group_size: int = 10 # Default group size
request_span: Optional[trace.Span] = None # Track the main request span
class Tracer:
def __init__(self):
self.tracer = trace.get_tracer("exo")
self.contexts: Dict[str, TraceContext] = {}
self._lock = Lock()
self.propagator = TraceContextTextMapPropagator()
def get_context(self, request_id: str) -> Optional[TraceContext]:
with self._lock:
return self.contexts.get(request_id)
def set_context(self, request_id: str, context: TraceContext):
with self._lock:
self.contexts[request_id] = context
def inject_context(self, span: trace.Span) -> str:
"""Inject current span context into carrier for propagation"""
carrier = {}
ctx = trace.set_span_in_context(span)
self.propagator.inject(carrier, context=ctx)
return carrier.get("traceparent", "")
def extract_context(self, trace_parent: str) -> Optional[context.Context]:
"""Extract span context from carrier"""
if not trace_parent:
return None
carrier = {"traceparent": trace_parent}
return self.propagator.extract(carrier)
def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
"""Create a new context with the given trace parent"""
parent_ctx = self.extract_context(trace_parent)
if parent_ctx:
# Create a new request span that links to the parent context
request_span = self.tracer.start_span(
"request",
context=parent_ctx,
attributes={
"request_id": request_id,
"sequence_number": sequence_number
}
)
return TraceContext(
request_id=request_id,
sequence_number=sequence_number,
request_span=request_span,
current_span=request_span,
trace_parent=trace_parent
)
return TraceContext(request_id=request_id, sequence_number=sequence_number)
def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
"""Handle token generation and manage token group spans"""
context.token_count += 1
# Start a new token group span if needed
if not context.token_group_span and context.request_span:
group_number = (context.token_count - 1) // context.token_group_size + 1
# Create token group span as child of request span
parent_ctx = trace.set_span_in_context(context.request_span)
context.token_group_span = self.tracer.start_span(
f"token_group_{group_number}",
context=parent_ctx,
attributes={
"request_id": context.request_id,
"group.number": group_number,
"group.start_token": context.token_count,
"group.max_tokens": context.token_group_size
}
)
# Add token to current group span
if context.token_group_span:
relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
context.token_group_span.set_attribute(f"token.{relative_pos}", token)
context.token_group_span.set_attribute("token.count", relative_pos)
# End current group span if we've reached the group size or if generation is finished
if context.token_count % context.token_group_size == 0 or is_finished:
context.token_group_span.set_attribute("token.final_count", relative_pos)
context.token_group_span.end()
context.token_group_span = None
@contextmanager
def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
"""Start a new span with proper parent context"""
attributes = {
"request_id": context.request_id,
"sequence_number": context.sequence_number
}
if extra_attributes:
attributes.update(extra_attributes)
# Use request span as parent if available
parent_ctx = None
if context.request_span:
parent_ctx = trace.set_span_in_context(context.request_span)
elif context.trace_parent:
parent_ctx = self.extract_context(context.trace_parent)
if parent_ctx and not context.request_span:
# Create a new request span that links to the parent context
context.request_span = self.tracer.start_span(
"request",
context=parent_ctx,
attributes={
"request_id": context.request_id,
"sequence_number": context.sequence_number
}
)
parent_ctx = trace.set_span_in_context(context.request_span)
elif context.current_span:
parent_ctx = trace.set_span_in_context(context.current_span)
# Create span with parent context if it exists
if parent_ctx:
span = self.tracer.start_span(
name,
context=parent_ctx,
attributes=attributes
)
else:
span = self.tracer.start_span(
name,
attributes=attributes
)
# Update context with current span
prev_span = context.current_span
context.current_span = span
try:
start_time = time.perf_counter()
yield span
duration = time.perf_counter() - start_time
span.set_attribute("duration_s", duration)
span.set_status(Status(StatusCode.OK))
except Exception as e:
span.set_status(Status(StatusCode.ERROR, str(e)))
raise
finally:
span.end()
context.current_span = prev_span
# Global tracer instance
tracer = Tracer()
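A short, hedged usage sketch of the tracer above (the traceparent value, span name and attributes are illustrative; a real setup would also configure an OpenTelemetry exporter):
# Illustrative usage: attach to an upstream trace, then record a span and grouped tokens.
ctx = tracer.create_context_from_parent(
  request_id="req-123",
  trace_parent="00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
)
tracer.set_context("req-123", ctx)
with tracer.start_span("process_prompt", ctx, extra_attributes={"model": "llama-3.2-1b"}):
  for i, token in enumerate([101, 102, 103]):
    tracer.handle_token(ctx, token, is_finished=(i == 2))
if ctx.request_span:
  ctx.request_span.end()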

View File

View File

@@ -1,27 +0,0 @@
version: '3.8'
services:
prometheus:
image: prom/prometheus:latest
container_name: prometheus
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
ports:
- "9090:9090"
networks:
- monitoring
grafana:
image: grafana/grafana:latest
container_name: grafana
ports:
- "3000:3000"
networks:
- monitoring
depends_on:
- prometheus
networks:
monitoring:

View File

@@ -1,29 +0,0 @@
from exo.orchestration import Node
from prometheus_client import start_http_server, Counter, Histogram
import json
# Create metrics to track time spent and requests made.
PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
def start_metrics_server(node: Node, port: int):
start_http_server(port)
def _on_opaque_status(request_id, opaque_status: str):
status_data = json.loads(opaque_status)
_type = status_data.get("type", "")
node_id = status_data.get("node_id", "")
if _type != "node_status":
return
status = status_data.get("status", "")
if status == "end_process_prompt":
PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
elif status == "end_process_tensor":
elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds
node.on_opaque_status.register("stats").on_next(_on_opaque_status)

View File

@@ -1,7 +0,0 @@
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'exo-node'
static_configs:
- targets: ['host.docker.internal:8005']

View File

@@ -654,4 +654,92 @@ main {
.model-download-button i {
font-size: 0.9em;
}
.topology-section {
margin-bottom: 30px;
padding: 15px;
background: rgba(255, 255, 255, 0.05);
border-radius: 8px;
}
.topology-visualization {
min-height: 150px;
position: relative;
margin-top: 10px;
}
.topology-loading {
display: flex;
align-items: center;
gap: 10px;
color: #666;
font-size: 0.9em;
}
.topology-node {
padding: 8px;
background: rgba(255, 255, 255, 0.05);
border-radius: 4px;
margin: 4px 0;
display: flex;
flex-direction: column;
gap: 4px;
}
.node-info {
display: flex;
align-items: center;
gap: 6px;
font-size: 0.9em;
}
.topology-node .status {
width: 6px;
height: 6px;
border-radius: 50%;
flex-shrink: 0;
}
.topology-node .status.active {
background: #4CAF50;
}
.topology-node .status.inactive {
background: #666;
}
.node-details {
padding-left: 12px;
display: flex;
flex-direction: column;
gap: 2px;
font-size: 0.8em;
opacity: 0.6;
}
.node-details span {
display: flex;
align-items: center;
}
.peer-connections {
margin-top: 8px;
padding-left: 12px;
display: flex;
flex-direction: column;
gap: 4px;
}
.peer-connection {
display: flex;
align-items: center;
gap: 8px;
font-size: 0.85em;
color: #a0a0a0;
}
.peer-connection i {
font-size: 0.8em;
color: #666;
}

View File

@@ -26,21 +26,36 @@
<body>
<main x-data="state" x-init="console.log(endpoint)">
<div class="sidebar">
<!-- Add topology section -->
<div class="topology-section">
<h2 class="megrim-regular">Network Topology</h2>
<div class="topology-visualization"
x-init="initTopology()"
x-ref="topologyViz">
<!-- Loading indicator for topology -->
<div class="topology-loading" x-show="!topology">
<i class="fas fa-spinner fa-spin"></i>
<span>Loading topology...</span>
</div>
<!-- Topology visualization will be rendered here -->
</div>
</div>
<h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
<!-- Loading indicator -->
<div class="loading-container" x-show="Object.keys(models).length === 0">
<i class="fas fa-spinner fa-spin"></i>
<span>Loading models...</span>
</div>
<template x-for="(model, key) in models" :key="key">
<div class="model-option"
<div class="model-option"
:class="{ 'selected': cstate.selectedModel === key }"
@click="cstate.selectedModel = key">
<div class="model-header">
<div class="model-name" x-text="model.name"></div>
<button
<button
@click.stop="deleteModel(key, model)"
class="model-delete-button"
x-show="model.download_percentage > 0">
@@ -56,7 +71,7 @@
<template x-if="!model.loading && model.download_percentage != null">
<span>
<!-- Check if there's an active download for this model -->
<template x-if="downloadProgress?.some(p =>
<template x-if="downloadProgress?.some(p =>
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
)">
<i class="fas fa-circle-notch fa-spin"></i>
@@ -65,7 +80,7 @@
</span>
</template>
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
<button
<button
@click.stop="handleDownload(key)"
class="model-download-button">
<i class="fas fa-download"></i>
@@ -75,22 +90,22 @@
</div>
</div>
<template x-if="model.total_size">
<div class="model-size" x-text="model.total_downloaded ?
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
<div class="model-size" x-text="model.total_downloaded ?
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
formatBytes(model.total_size)">
</div>
</template>
</div>
</div>
</template>
</div>
</div>
<!-- Error Toast -->
<div x-show="errorMessage !== null" x-transition.opacity class="toast">
<div class="toast-header">
<span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
<div class="toast-header-buttons">
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
x-show="errorMessage?.stack">
<span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
</button>
@@ -119,8 +134,8 @@
" x-show="home === 0" x-transition="">
<h1 class="title megrim-regular">tinychat</h1>
<template x-if="histories.length">
<button
@click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
<button
@click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
class="clear-history-button">
<i class="fas fa-trash"></i> Clear All History
</button>
@@ -162,14 +177,14 @@
</template>
</div>
</div>
<button
<button
@click="
home = 0;
cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
time_till_first = 0;
tokens_per_second = 0;
total_tokens = 0;
"
"
class="back-button"
x-show="home === 2">
<i class="fas fa-arrow-left"></i>
@@ -250,7 +265,7 @@
<p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
<p><strong>Status:</strong> <span x-text="progress.status"></span></p>
<div class="progress-bar-container">
<div class="progress-bar"
<div class="progress-bar"
:class="progress.isComplete ? 'complete' : 'in-progress'"
:style="`width: ${progress.percentage}%;`">
</div>
@@ -294,10 +309,10 @@
<i class="fas fa-times"></i>
</button>
</div>
<textarea
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
:placeholder="
generating ? 'Generating...' :
(downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
'Say something'
"
@@ -329,9 +344,9 @@
});
"
x-ref="inputForm"></textarea>
<button
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
@click="await handleSend()"
class="input-button">
<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
</button>
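For reference, the markup above expects each entry in downloadProgress to expose repo_id, repo_revision, status, percentage and isComplete. A minimal sketch of one such entry, written as the Python dict the API side might serialize; the field values are illustrative placeholders, not taken from this PR:
# Hypothetical shape of a single download-progress entry, inferred from the
# fields the template reads; values are placeholders for illustration only.
example_progress_entry = {
    "repo_id": "example-org/example-model-4bit",   # assumed repo identifier
    "repo_revision": "main",                       # shown as repo_id@repo_revision
    "status": "in_progress",                       # rendered in the Status line
    "percentage": 42.0,                            # drives the progress-bar width
    "isComplete": False,                           # gates spinners and disables input
}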

View File

@@ -5,7 +5,7 @@ document.addEventListener("alpine:init", () => {
time: null,
messages: [],
selectedModel: 'llama-3.2-1b',
},
// historical state
histories: JSON.parse(localStorage.getItem("histories")) || [],
@@ -13,7 +13,7 @@ document.addEventListener("alpine:init", () => {
home: 0,
generating: false,
endpoint: `${window.location.origin}/v1`,
// Initialize error message structure
errorMessage: null,
errorExpanded: false,
@@ -39,6 +39,9 @@ document.addEventListener("alpine:init", () => {
// Add models state alongside existing state
models: {},
topology: null,
topologyInterval: null,
init() {
// Clean up any pending messages
localStorage.removeItem("pendingMessage");
@@ -48,7 +51,7 @@ document.addEventListener("alpine:init", () => {
// Start polling for download progress
this.startDownloadProgressPolling();
// Start model polling with the new pattern
this.startModelPolling();
},
@@ -82,14 +85,14 @@ document.addEventListener("alpine:init", () => {
async populateSelector() {
return new Promise((resolve, reject) => {
const evtSource = new EventSource(`${window.location.origin}/modelpool`);
evtSource.onmessage = (event) => {
if (event.data === "[DONE]") {
evtSource.close();
resolve();
return;
}
const modelData = JSON.parse(event.data);
// Update existing model data while preserving other properties
Object.entries(modelData).forEach(([modelName, data]) => {
@@ -102,7 +105,7 @@ document.addEventListener("alpine:init", () => {
}
});
};
evtSource.onerror = (error) => {
console.error('EventSource failed:', error);
evtSource.close();
@@ -509,7 +512,7 @@ document.addEventListener("alpine:init", () => {
stack: error.stack || ""
};
this.errorExpanded = false;
if (this.errorTimeout) {
clearTimeout(this.errorTimeout);
}
@@ -524,10 +527,10 @@ document.addEventListener("alpine:init", () => {
async deleteModel(modelName, model) {
const downloadedSize = model.total_downloaded || 0;
const sizeMessage = downloadedSize > 0 ?
`This will free up ${this.formatBytes(downloadedSize)} of space.` :
'This will remove any partially downloaded files.';
if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
return;
}
@@ -541,7 +544,7 @@ document.addEventListener("alpine:init", () => {
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Failed to delete model');
}
@@ -600,6 +603,71 @@ document.addEventListener("alpine:init", () => {
console.error('Error starting download:', error);
this.setError(error);
}
},
async fetchTopology() {
try {
const response = await fetch(`${this.endpoint}/topology`);
if (!response.ok) throw new Error('Failed to fetch topology');
return await response.json();
} catch (error) {
console.error('Topology fetch error:', error);
return null;
}
},
initTopology() {
// Initial fetch
this.updateTopology();
// Set up periodic updates
this.topologyInterval = setInterval(() => this.updateTopology(), 5000);
// Cleanup on page unload
window.addEventListener('beforeunload', () => {
if (this.topologyInterval) {
clearInterval(this.topologyInterval);
}
});
},
async updateTopology() {
const topologyData = await this.fetchTopology();
if (!topologyData) return;
const vizElement = this.$refs.topologyViz;
vizElement.innerHTML = ''; // Clear existing visualization
// Create nodes from object
Object.entries(topologyData.nodes).forEach(([nodeId, node]) => {
const nodeElement = document.createElement('div');
nodeElement.className = 'topology-node';
// Get peer connections for this node
const peerConnections = topologyData.peer_graph[nodeId] || [];
const peerConnectionsHtml = peerConnections.map(peer => `
<div class="peer-connection">
<i class="fas fa-arrow-right"></i>
<span>To ${peer.to_id}: ${peer.description}</span>
</div>
`).join('');
nodeElement.innerHTML = `
<div class="node-info">
<span class="status ${nodeId === topologyData.active_node_id ? 'active' : 'inactive'}"></span>
<span>${node.model}</span>
</div>
<div class="node-details">
<span>${node.chip}</span>
<span>${(node.memory / 1024).toFixed(1)}GB RAM</span>
<span>${node.flops.fp32.toFixed(1)} TF</span>
</div>
<div class="peer-connections">
${peerConnectionsHtml}
</div>
`;
vizElement.appendChild(nodeElement);
});
}
}));
});
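The new fetchTopology/updateTopology pair assumes a particular response shape from `${endpoint}/topology`: an object with nodes keyed by node id, a peer_graph mapping node ids to peer links, and an active_node_id. A hedged sketch of that payload, written as a Python dict such as the server might return (values are illustrative only):
# Hypothetical /v1/topology payload, inferred from the fields updateTopology() reads.
example_topology = {
    "active_node_id": "node-a",                      # marks the 'active' status dot
    "nodes": {
        "node-a": {
            "model": "MacBook Pro",                  # shown in the node header
            "chip": "Apple M3 Max",
            "memory": 131072,                        # MB; the UI divides by 1024 for GB
            "flops": {"fp32": 14.2},                 # only fp32 (TFLOPS) is displayed
        },
    },
    "peer_graph": {
        "node-a": [
            {"to_id": "node-b", "description": "connected"},  # assumed link format
        ],
    },
}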

View File

@@ -3,6 +3,8 @@ from pydantic import BaseModel
from exo import DEBUG
import subprocess
import psutil
import asyncio
from exo.helpers import get_mac_system_info, subprocess_pool
TFLOPS = 1.00
@@ -144,13 +146,13 @@ CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items
CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})
def device_capabilities() -> DeviceCapabilities:
async def device_capabilities() -> DeviceCapabilities:
if psutil.MACOS:
return mac_device_capabilities()
return await mac_device_capabilities()
elif psutil.LINUX:
return linux_device_capabilities()
return await linux_device_capabilities()
elif psutil.WINDOWS:
return windows_device_capabilities()
return await windows_device_capabilities()
else:
return DeviceCapabilities(
model="Unknown Device",
@@ -160,27 +162,18 @@ def device_capabilities() -> DeviceCapabilities:
)
def mac_device_capabilities() -> DeviceCapabilities:
# Fetch the model of the Mac using system_profiler
model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
memory_units = memory_str.split()
memory_value = int(memory_units[0])
if memory_units[1] == "GB":
memory = memory_value*1024
else:
memory = memory_value
# Assuming static values for other attributes for demonstration
return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
async def mac_device_capabilities() -> DeviceCapabilities:
model_id, chip_id, memory = await get_mac_system_info()
return DeviceCapabilities(
model=model_id,
chip=chip_id,
memory=memory,
flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))
)
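get_mac_system_info is imported from exo.helpers, but its body is not part of this hunk. A minimal sketch of what such a helper could look like, assuming it runs system_profiler off the event loop and returns (model, chip, memory_mb) in the same way as the inline parsing it replaces:
import asyncio

async def get_mac_system_info_sketch() -> tuple[str, str, int]:
    # Illustrative stand-in for exo.helpers.get_mac_system_info, not the PR's code.
    proc = await asyncio.create_subprocess_exec(
        "system_profiler", "SPHardwareDataType",
        stdout=asyncio.subprocess.PIPE,
    )
    stdout, _ = await proc.communicate()
    output = stdout.decode("utf-8")

    def field(label: str, default: str) -> str:
        line = next((ln for ln in output.split("\n") if label in ln), None)
        return line.split(": ", 1)[1].strip() if line else default

    model = field("Model Name", "Unknown Model")
    chip = field("Chip", "Unknown Chip")
    memory_str = field("Memory", "0 GB")
    value, unit = memory_str.split()[:2]
    memory_mb = int(value) * 1024 if unit == "GB" else int(value)
    return model, chip, memory_mb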
def linux_device_capabilities() -> DeviceCapabilities:
async def linux_device_capabilities() -> DeviceCapabilities:
import psutil
from tinygrad import Device
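Since device_capabilities() is now a coroutine, callers need an event loop. A minimal usage sketch (attribute names taken from the tests below):
import asyncio
from exo.topology.device_capabilities import device_capabilities

async def main() -> None:
    caps = await device_capabilities()   # previously a plain synchronous call
    print(caps.model, caps.chip, f"{caps.memory}MB", caps.flops)

if __name__ == "__main__":
    asyncio.run(main())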

View File

@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import List
from typing import List, Dict
from dataclasses import dataclass
from .topology import Topology
from exo.inference.shard import Shard
from exo.topology.device_capabilities import device_capabilities
import asyncio
# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1

View File

@@ -1,11 +1,11 @@
import unittest
import pytest
from unittest.mock import patch
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS, device_capabilities
class TestMacDeviceCapabilities(unittest.TestCase):
@patch("subprocess.check_output")
def test_mac_device_capabilities_pro(self, mock_check_output):
@pytest.mark.asyncio
@patch("subprocess.check_output")
async def test_mac_device_capabilities_pro(mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:
@@ -27,20 +27,19 @@ Activation Lock Status: Enabled
"""
# Call the function
result = mac_device_capabilities()
result = await mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Pro")
self.assertEqual(result.chip, "Apple M3 Max")
self.assertEqual(result.memory, 131072) # 16 GB in MB
self.assertEqual(
str(result),
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
)
assert isinstance(result, DeviceCapabilities)
assert result.model == "MacBook Pro"
assert result.chip == "Apple M3 Max"
assert result.memory == 131072 # 128 GB in MB
assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
@patch("subprocess.check_output")
def test_mac_device_capabilities_air(self, mock_check_output):
@pytest.mark.asyncio
@patch("subprocess.check_output")
async def test_mac_device_capabilities_air(mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:
@@ -62,30 +61,34 @@ Activation Lock Status: Disabled
"""
# Call the function
result = mac_device_capabilities()
result = await mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Air")
self.assertEqual(result.chip, "Apple M2")
self.assertEqual(result.memory, 8192) # 8 GB in MB
assert isinstance(result, DeviceCapabilities)
assert result.model == "MacBook Air"
assert result.chip == "Apple M2"
assert result.memory == 8192 # 8 GB in MB
@unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
def test_mac_device_capabilities_real(self):
@pytest.mark.skip(reason="Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
@pytest.mark.asyncio
async def test_mac_device_capabilities_real():
# Call the function without mocking
result = mac_device_capabilities()
result = await mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Pro")
self.assertEqual(result.chip, "Apple M3 Max")
self.assertEqual(result.memory, 131072) # 128 GB in MB
self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
self.assertEqual(
str(result),
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
)
assert isinstance(result, DeviceCapabilities)
assert result.model == "MacBook Pro"
assert result.chip == "Apple M3 Max"
assert result.memory == 131072 # 128 GB in MB
assert result.flops == DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)
assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
if __name__ == "__main__":
unittest.main()
@pytest.mark.asyncio
async def test_device_capabilities():
caps = await device_capabilities()
assert caps.model != ""
assert caps.chip != ""
assert caps.memory > 0
assert caps.flops is not None

View File

@@ -74,9 +74,9 @@ def gen_diff(table_old, table_new):
def create_json_report(table, is_diff=False):
timestamp = datetime.now(timezone.utc).isoformat()
commit_sha = os.environ.get('CIRCLE_SHA1', 'unknown')
branch = os.environ.get('CIRCLE_BRANCH', 'unknown')
pr_number = os.environ.get('CIRCLE_PR_NUMBER', '')
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')
branch = os.environ.get('GITHUB_REF_NAME', 'unknown')
pr_number = os.environ.get('GITHUB_EVENT_NUMBER', '')
if is_diff:
files = [{

View File

@@ -28,14 +28,19 @@ install_requires = [
"tqdm==4.66.4",
"transformers==4.46.3",
"uuid==1.30",
"uvloop==0.21.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
]
extras_require = {
"formatting": ["yapf==0.40.2",], "apple_silicon": [
"mlx==0.20.0",
"mlx-lm==0.19.3",
], "windows": ["pywin32==308",], "nvidia-gpu": ["nvidia-ml-py==12.560.30",], "amd-gpu": ["pyrsmi==0.2.0"]
"formatting": ["yapf==0.40.2",],
"apple_silicon": [
"mlx==0.21.1",
"mlx-lm==0.20.4",
],
"windows": ["pywin32==308",],
"nvidia-gpu": ["nvidia-ml-py==12.560.30",],
"amd-gpu": ["pyrsmi==0.2.0"],
}
# Check if running on macOS with Apple Silicon