diff --git a/DEV_PLAN.md b/DEV_PLAN.md deleted file mode 100644 index 5738244..0000000 --- a/DEV_PLAN.md +++ /dev/null @@ -1,1488 +0,0 @@ -# Development Plan: Async Job Queue & GPU Health Monitoring - -**Version:** 1.0 -**Date:** 2025-10-07 -**Status:** Implementation Ready - ---- - -## Table of Contents - -1. [Executive Summary](#executive-summary) -2. [Problem Statement](#problem-statement) -3. [Solution Architecture](#solution-architecture) -4. [Component Specifications](#component-specifications) -5. [Data Structures](#data-structures) -6. [API Specifications](#api-specifications) -7. [Implementation Phases](#implementation-phases) -8. [Testing Strategy](#testing-strategy) -9. [Environment Variables](#environment-variables) -10. [Error Handling](#error-handling) - ---- - -## Executive Summary - -This plan introduces an **asynchronous job queue system** with **GPU health monitoring** to address two critical production issues: - -1. **HTTP Request Timeouts**: Long audio transcriptions (10+ minutes) cause client timeouts -2. **Silent GPU Failures**: GPU driver issues cause models to fall back to CPU silently, resulting in 10-100x slower processing - -### Solution Overview - -- **Async Job Queue**: FIFO queue with immediate response, background processing, and disk persistence -- **GPU Health Monitoring**: Real transcription tests with tiny model, periodic monitoring, and **strict failure handling** -- **Clean API**: Async-only endpoints (REST + MCP) optimized for LLM agents -- **Zero External Dependencies**: Uses Python stdlib (threading, queue, json) only - ---- - -## Problem Statement - -### Problem 1: Request Timeout Issues - -**Current Behavior:** -``` -Client → POST /transcribe → [waits 10+ minutes] → Timeout -``` - -**Impact:** -- Clients experience HTTP timeouts on long audio files -- No way to check progress -- Failed requests waste GPU time - -**Root Cause:** -- Synchronous request/response pattern -- HTTP clients have timeout limits (30-120 seconds typical) -- Transcription can take 5-20+ minutes for long files - -### Problem 2: Silent GPU Fallback - -**Current Behavior:** -```python -# model_manager.py:64-66 -if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" - compute_type = "float16" if device == "cuda" else "int8" -``` - -**Issue:** -- `torch.cuda.is_available()` can return `True` but model loading can still fail -- GPU driver issues, OOM errors, or CUDA incompatibilities cause silent fallback to CPU -- Current `test_gpu_driver()` only tests tensor operations, not model loading -- Processing becomes 10-100x slower without notification - -**Impact:** -- Users expect 2-minute transcription, get 30-minute CPU transcription -- No error message, just extremely slow processing -- Wastes resources and user time - ---- - -## Solution Architecture - -### High-Level Design - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ Client (HTTP / MCP) │ -└───────────────────┬─────────────────────────────────────────────┘ - │ - │ Submit Job - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ API Server / MCP Server │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ 1. Validate Request │ │ -│ │ 2. Check GPU Health (if device=cuda) │ │ -│ │ 3. Generate job_id │ │ -│ │ 4. Add to Queue │ │ -│ │ 5. Save Job Metadata to Disk │ │ -│ │ 6. 
Return {job_id, status: "queued", queue_position} │ │ -│ └──────────────────────────────────────────────────────────┘ │ -└───────────────────┬─────────────────────────────────────────────┘ - │ - │ Job in Queue - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Job Queue Manager │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ In-Memory Queue (queue.Queue) │ │ -│ │ - FIFO ordering │ │ -│ │ - Thread-safe │ │ -│ │ - Max size limit from env var │ │ -│ └──────────────────────────────────────────────────────────┘ │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ Background Worker Thread (Single Worker) │ │ -│ │ 1. Pop job from queue │ │ -│ │ 2. Update status → "running" │ │ -│ │ 3. Call transcribe_audio() │ │ -│ │ 4. Update status → "completed"/"failed" │ │ -│ │ 5. Save result metadata │ │ -│ └──────────────────────────────────────────────────────────┘ │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ Disk Persistence Layer │ │ -│ │ - One JSON file per job: {job_id}.json │ │ -│ │ - Survives server restarts │ │ -│ │ - Load on startup │ │ -│ └──────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ - -┌─────────────────────────────────────────────────────────────────┐ -│ GPU Health Monitor │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ Startup Health Check │ │ -│ │ 1. Generate 1-second test audio │ │ -│ │ 2. Load tiny model │ │ -│ │ 3. Transcribe test audio │ │ -│ │ 4. Time execution (GPU: <1s, CPU: 5-10s) │ │ -│ │ 5. If expected=cuda but got=cpu → REJECT │ │ -│ └──────────────────────────────────────────────────────────┘ │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ Periodic Background Monitoring (Every 10 min) │ │ -│ │ - Re-run health check │ │ -│ │ - Log warnings if degradation detected │ │ -│ │ - Store health history │ │ -│ └──────────────────────────────────────────────────────────┘ │ -│ ┌──────────────────────────────────────────────────────────┐ │ -│ │ On-Demand Health Check │ │ -│ │ - Exposed via /health/gpu endpoint │ │ -│ │ - Returns detailed GPU status │ │ -│ └──────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ -``` - -### Client Interaction Flow - -``` -Client Server Background Worker - | | | - |-- POST /jobs (submit) ------->| | - | |-- Add to queue --------------->| - | |-- Save to disk | - |<-- {job_id, queued, pos:3} ---| | - | | | - | | [Processing job 1] - | | | - |-- GET /jobs/{id} (poll) ----->| | - |<-- {status: queued, pos:2} ---| | - | | | - | [Wait 10 seconds] | [Processing job 2] - | | | - |-- GET /jobs/{id} (poll) ----->| | - |<-- {status: queued, pos:1} ---| | - | | | - | [Wait 10 seconds] | | - | | | - |-- GET /jobs/{id} (poll) ----->| [Start our job] - |<-- {status: running} ---------| | - | | | - | [Wait 30 seconds] | [Transcribing...] 
- | | | - |-- GET /jobs/{id} (poll) ----->| | - |<-- {status: running} ---------| | - | | | - | [Wait 30 seconds] | [Job completed] - | |<-- Update status --------------| - | |<-- Save result | - | | | - |-- GET /jobs/{id} (poll) ----->| | - |<-- {status: completed} -------| | - | | | - |-- GET /jobs/{id}/result ----->| | - |<-- Transcription text --------| | - | | | -``` - ---- - -## Component Specifications - -### Component 1: Test Audio Generator - -**File:** `src/utils/test_audio_generator.py` - -**Purpose:** Generate synthetic test audio programmatically (no need to bundle .mp3 files) - -**Key Functions:** -```python -def generate_test_audio() -> str: - """ - Generate a 1-second test audio file for GPU health checks. - - Returns: - str: Path to temporary audio file - - Implementation: - - Generate 1 second of 440Hz sine wave (A note) - - 16kHz sample rate, mono - - Save as WAV format (simplest) - - Store in system temp directory - - Reuse same file if exists (cache) - """ -``` - -**Dependencies:** -- numpy (already installed) -- scipy.io.wavfile (stdlib) -- tempfile (stdlib) - ---- - -### Component 2: GPU Health Monitor - -**File:** `src/core/gpu_health.py` - -**Purpose:** Test GPU functionality with actual model loading and transcription - -**Key Classes/Functions:** - -```python -class GPUHealthStatus: - """Data class for health check results""" - gpu_available: bool # torch.cuda.is_available() - gpu_working: bool # Model actually loaded on GPU - device_used: str # "cuda" or "cpu" - device_name: str # GPU name if available - memory_total_gb: float # Total GPU memory - memory_available_gb: float # Available GPU memory - test_duration_seconds: float # How long test took - timestamp: str # ISO timestamp - error: str | None # Error message if any - -def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus: - """ - Comprehensive GPU health check using real model + transcription. - - Args: - expected_device: Expected device ("auto", "cuda", "cpu") - - Returns: - GPUHealthStatus object - - Raises: - RuntimeError: If expected_device="cuda" but GPU test fails - - Implementation Steps: - 1. Generate test audio (1 second) - 2. Load tiny model with requested device - 3. Transcribe test audio - 4. Time the operation - 5. Verify model actually ran on GPU (check torch.cuda.memory_allocated) - 6. CRITICAL: If expected_device="cuda" but used="cpu" → raise RuntimeError - 7. Return detailed status - - Performance Expectations: - - GPU (tiny model): 0.3-1.0 seconds - - CPU (tiny model): 3-10 seconds - - If GPU test takes >2 seconds, likely running on CPU - """ - -class HealthMonitor: - """Background thread for periodic GPU health monitoring""" - - def __init__(self, check_interval_minutes: int = 10): - """Initialize health monitor""" - - def start(self): - """Start background monitoring thread""" - - def stop(self): - """Stop background monitoring thread""" - - def get_latest_status(self) -> GPUHealthStatus: - """Get most recent health check result""" - - def get_health_history(self, limit: int = 10) -> List[GPUHealthStatus]: - """Get recent health check history""" -``` - -**Critical Error Handling:** - -```python -# In check_gpu_health() -if expected_device == "cuda" and actual_device == "cpu": - error_msg = ( - "GPU device requested but model loaded on CPU. " - "This indicates GPU driver issues or insufficient memory. " - "Transcription would be 10-100x slower than expected. " - "Please check CUDA installation and GPU availability." 
- ) - raise RuntimeError(error_msg) -``` - ---- - -### Component 3: Job Queue Manager - -**File:** `src/core/job_queue.py` - -**Purpose:** Manage async job queue with FIFO processing and disk persistence - -**Key Classes:** - -```python -class JobStatus(Enum): - """Job status enumeration""" - QUEUED = "queued" # In queue, waiting - RUNNING = "running" # Currently processing - COMPLETED = "completed" # Successfully finished - FAILED = "failed" # Error occurred - -class Job: - """Represents a transcription job""" - job_id: str # UUID - status: JobStatus # Current status - created_at: datetime # When job was created - started_at: datetime | None # When processing started - completed_at: datetime | None # When processing finished - queue_position: int # Position in queue (0 if running) - - # Request parameters - audio_path: str - model_name: str - device: str - compute_type: str - language: str | None - output_format: str - beam_size: int - temperature: float - initial_prompt: str | None - output_directory: str | None - - # Results - result_path: str | None # Path to transcription file - error: str | None # Error message if failed - processing_time_seconds: float | None - - def to_dict(self) -> dict: - """Serialize to dictionary for JSON storage""" - - @classmethod - def from_dict(cls, data: dict) -> 'Job': - """Deserialize from dictionary""" - - def save_to_disk(self, metadata_dir: str): - """Save job metadata to {metadata_dir}/{job_id}.json""" - -class JobQueue: - """Manages job queue with background worker""" - - def __init__(self, - max_queue_size: int = 100, - metadata_dir: str = "/outputs/jobs"): - """ - Initialize job queue. - - Args: - max_queue_size: Maximum number of jobs in queue - metadata_dir: Directory to store job metadata JSON files - """ - self._queue = queue.Queue(maxsize=max_queue_size) - self._jobs = {} # job_id -> Job - self._metadata_dir = metadata_dir - self._worker_thread = None - self._stop_event = threading.Event() - self._current_job_id = None - self._lock = threading.Lock() - - def start(self): - """ - Start background worker thread. - Load existing jobs from disk on startup. - """ - - def stop(self, wait_for_current: bool = True): - """ - Stop background worker. - - Args: - wait_for_current: If True, wait for current job to complete - """ - - def submit_job(self, - audio_path: str, - model_name: str = "large-v3", - device: str = "auto", - compute_type: str = "auto", - language: str | None = None, - output_format: str = "txt", - beam_size: int = 5, - temperature: float = 0.0, - initial_prompt: str | None = None, - output_directory: str | None = None) -> dict: - """ - Submit a new transcription job. - - Returns: - dict: { - "job_id": str, - "status": str, - "queue_position": int, - "created_at": str - } - - Raises: - queue.Full: If queue is at max capacity - RuntimeError: If GPU health check fails (when device="cuda") - """ - # 1. Validate audio file exists - # 2. Check GPU health if device="cuda" (raises if fails) - # 3. Generate job_id - # 4. Create Job object - # 5. Add to queue (raises queue.Full if full) - # 6. Save to disk - # 7. Return job info - - def get_job_status(self, job_id: str) -> dict: - """ - Get current status of a job. 
- - Returns: - dict: { - "job_id": str, - "status": str, - "queue_position": int | None, - "created_at": str, - "started_at": str | None, - "completed_at": str | None, - "result_path": str | None, - "error": str | None, - "processing_time_seconds": float | None - } - - Raises: - KeyError: If job_id not found - """ - - def get_job_result(self, job_id: str) -> str: - """ - Get transcription result text for completed job. - - Returns: - str: Content of transcription file - - Raises: - KeyError: If job_id not found - ValueError: If job not completed - FileNotFoundError: If result file missing - """ - - def list_jobs(self, - status_filter: JobStatus | None = None, - limit: int = 100) -> List[dict]: - """ - List jobs with optional status filter. - - Args: - status_filter: Only return jobs with this status - limit: Maximum number of jobs to return - - Returns: - List of job status dictionaries - """ - - def _worker_loop(self): - """ - Background worker thread function. - Processes jobs from queue in FIFO order. - """ - while not self._stop_event.is_set(): - try: - # Get job from queue (with timeout to check stop_event) - job = self._queue.get(timeout=1.0) - - with self._lock: - self._current_job_id = job.job_id - job.status = JobStatus.RUNNING - job.started_at = datetime.utcnow() - job.save_to_disk(self._metadata_dir) - - # Process job - start_time = time.time() - try: - result = transcribe_audio( - audio_path=job.audio_path, - model_name=job.model_name, - device=job.device, - compute_type=job.compute_type, - language=job.language, - output_format=job.output_format, - beam_size=job.beam_size, - temperature=job.temperature, - initial_prompt=job.initial_prompt, - output_directory=job.output_directory - ) - - # Parse result - if "saved to:" in result: - job.result_path = result.split("saved to:")[1].strip() - job.status = JobStatus.COMPLETED - else: - job.status = JobStatus.FAILED - job.error = result - - except Exception as e: - job.status = JobStatus.FAILED - job.error = str(e) - logger.error(f"Job {job.job_id} failed: {e}") - - finally: - job.completed_at = datetime.utcnow() - job.processing_time_seconds = time.time() - start_time - job.save_to_disk(self._metadata_dir) - - with self._lock: - self._current_job_id = None - - self._queue.task_done() - - except queue.Empty: - continue - - def _load_jobs_from_disk(self): - """Load existing job metadata from disk on startup""" - - def _calculate_queue_positions(self): - """Update queue_position for all queued jobs""" -``` - ---- - -## Data Structures - -### Job Metadata JSON Format - -**File:** `{JOB_METADATA_DIR}/{job_id}.json` - -```json -{ - "job_id": "550e8400-e29b-41d4-a716-446655440000", - "status": "completed", - "created_at": "2025-10-07T10:30:00.123456Z", - "started_at": "2025-10-07T10:30:05.234567Z", - "completed_at": "2025-10-07T10:32:15.345678Z", - "queue_position": 0, - - "request_params": { - "audio_path": "/media/raid/audio/interview.mp3", - "model_name": "large-v3", - "device": "cuda", - "compute_type": "float16", - "language": "en", - "output_format": "txt", - "beam_size": 5, - "temperature": 0.0, - "initial_prompt": null, - "output_directory": "/media/raid/outputs" - }, - - "result_path": "/media/raid/outputs/interview.txt", - "error": null, - "processing_time_seconds": 130.22 -} -``` - -### GPU Health Status JSON Format - -```json -{ - "gpu_available": true, - "gpu_working": true, - "device_used": "cuda", - "device_name": "NVIDIA RTX 3090", - "memory_total_gb": 24.0, - "memory_available_gb": 20.5, - "test_duration_seconds": 0.87, - 
"timestamp": "2025-10-07T10:30:00.123456Z", - "error": null -} -``` - ---- - -## API Specifications - -### REST API Endpoints - -#### 1. Submit Job (Async Transcription) - -**Endpoint:** `POST /jobs` - -**Request Body:** -```json -{ - "audio_path": "/path/to/audio.mp3", - "model_name": "large-v3", - "device": "auto", - "compute_type": "auto", - "language": "en", - "output_format": "txt", - "beam_size": 5, - "temperature": 0.0, - "initial_prompt": null, - "output_directory": null -} -``` - -**Success Response (200):** -```json -{ - "job_id": "550e8400-e29b-41d4-a716-446655440000", - "status": "queued", - "queue_position": 3, - "created_at": "2025-10-07T10:30:00.123456Z", - "message": "Job submitted successfully. Poll /jobs/{job_id} for status." -} -``` - -**Error Responses:** -- **400 Bad Request:** Invalid parameters -- **503 Service Unavailable:** Queue is full -- **500 Internal Server Error:** GPU health check failed (if device=cuda) - -#### 2. Get Job Status - -**Endpoint:** `GET /jobs/{job_id}` - -**Success Response (200):** -```json -{ - "job_id": "550e8400-e29b-41d4-a716-446655440000", - "status": "running", - "queue_position": null, - "created_at": "2025-10-07T10:30:00.123456Z", - "started_at": "2025-10-07T10:30:05.234567Z", - "completed_at": null, - "result_path": null, - "error": null, - "processing_time_seconds": null -} -``` - -**Error Responses:** -- **404 Not Found:** Job ID not found - -#### 3. Get Job Result - -**Endpoint:** `GET /jobs/{job_id}/result` - -**Success Response (200):** -``` -Content-Type: text/plain - -This is the transcribed text from the audio file... -``` - -**Error Responses:** -- **404 Not Found:** Job ID not found -- **409 Conflict:** Job not completed yet -- **500 Internal Server Error:** Result file missing - -#### 4. List Jobs - -**Endpoint:** `GET /jobs?status={status}&limit={limit}` - -**Query Parameters:** -- `status` (optional): Filter by status (queued, running, completed, failed) -- `limit` (optional): Max results (default: 100) - -**Success Response (200):** -```json -{ - "jobs": [ - { - "job_id": "...", - "status": "completed", - "created_at": "...", - ... - } - ], - "total": 42 -} -``` - -#### 5. GPU Health Check - -**Endpoint:** `GET /health/gpu` - -**Success Response (200):** -```json -{ - "gpu_available": true, - "gpu_working": true, - "device_used": "cuda", - "device_name": "NVIDIA RTX 3090", - "memory_total_gb": 24.0, - "memory_available_gb": 20.5, - "test_duration_seconds": 0.87, - "timestamp": "2025-10-07T10:30:00.123456Z", - "error": null, - "interpretation": "GPU is healthy and working correctly" -} -``` - ---- - -### MCP Tools - -#### 1. transcribe_async - -```python -@mcp.tool() -def transcribe_async( - audio_path: str, - model_name: str = "large-v3", - device: str = "auto", - compute_type: str = "auto", - language: str = None, - output_format: str = "txt", - beam_size: int = 5, - temperature: float = 0.0, - initial_prompt: str = None, - output_directory: str = None -) -> str: - """ - Submit an audio file for asynchronous transcription. - - IMPORTANT: This tool returns immediately with a job_id. Use get_job_status() - to check progress and get_job_result() to retrieve the transcription. - - WORKFLOW FOR LLM AGENTS: - 1. Call this tool to submit the job - 2. You will receive a job_id and queue_position - 3. Poll get_job_status(job_id) every 5-10 seconds to check progress - 4. When status="completed", call get_job_result(job_id) to get transcription - - For long audio files (>10 minutes), expect processing to take several minutes. 
- You can check queue_position to estimate wait time (each job ~2-5 minutes). - - Args: - audio_path: Path to audio file on server - model_name: Whisper model (tiny, base, small, medium, large-v3) - device: Execution device (cpu, cuda, auto) - compute_type: Computation type (float16, int8, auto) - language: Language code (en, zh, ja, etc.) or auto-detect - output_format: Output format (txt, vtt, srt, json) - beam_size: Beam search size (larger=better quality, slower) - temperature: Sampling temperature (0.0=greedy) - initial_prompt: Optional prompt to guide transcription - output_directory: Where to save result (uses default if not specified) - - Returns: - JSON string with job_id, status, queue_position, and instructions - """ -``` - -#### 2. get_job_status - -```python -@mcp.tool() -def get_job_status(job_id: str) -> str: - """ - Check the status of a transcription job. - - Status values: - - "queued": Job is waiting in queue. Check queue_position. - - "running": Job is currently being processed. - - "completed": Transcription finished. Call get_job_result() to retrieve. - - "failed": Job failed. Check error field for details. - - Args: - job_id: Job ID from transcribe_async() - - Returns: - JSON string with detailed job status including: - - status, queue_position, timestamps, error (if any) - """ -``` - -#### 3. get_job_result - -```python -@mcp.tool() -def get_job_result(job_id: str) -> str: - """ - Retrieve the transcription result for a completed job. - - IMPORTANT: Only call this when get_job_status() returns status="completed". - If the job is not completed, this will return an error. - - Args: - job_id: Job ID from transcribe_async() - - Returns: - Transcription text as a string - - Errors: - - "Job not found" if invalid job_id - - "Job not completed yet" if status is not "completed" - - "Result file not found" if transcription file is missing - """ -``` - -#### 4. list_transcription_jobs - -```python -@mcp.tool() -def list_transcription_jobs( - status_filter: str = None, - limit: int = 20 -) -> str: - """ - List transcription jobs with optional filtering. - - Useful for: - - Checking all your submitted jobs - - Finding completed jobs - - Monitoring queue status - - Args: - status_filter: Filter by status (queued, running, completed, failed) - limit: Maximum number of jobs to return (default: 20) - - Returns: - JSON string with list of jobs - """ -``` - -#### 5. check_gpu_health - -```python -@mcp.tool() -def check_gpu_health() -> str: - """ - Test GPU availability and performance by running a quick transcription. - - This tool loads the tiny model and transcribes a 1-second test audio file - to verify the GPU is working correctly. - - Use this when: - - You want to verify GPU is available before submitting large jobs - - You suspect GPU performance issues - - For monitoring/debugging purposes - - Returns: - JSON string with detailed GPU status including: - - gpu_available, gpu_working, device_name, memory_info - - test_duration_seconds (GPU: <1s, CPU: 5-10s) - - interpretation message - - Note: If this returns gpu_working=false, transcriptions will be very slow. - """ -``` - ---- - -## Implementation Phases - -### Phase 1: Core Infrastructure (Estimate: 2-3 hours) - -**Tasks:** -1. ✅ Create `src/utils/test_audio_generator.py` - - Implement `generate_test_audio()` - - Test audio file generation - -2. 
✅ Create `src/core/gpu_health.py` - - Implement `GPUHealthStatus` dataclass - - Implement `check_gpu_health()` with strict failure handling - - Implement `HealthMonitor` class - - Test GPU health check (verify it raises error on GPU failure) - -3. ✅ Create `src/core/job_queue.py` - - Implement `Job` class with serialization - - Implement `JobQueue` class - - Test job submission, processing, status retrieval - - Test disk persistence (save/load) - -**Testing:** -- Unit test each component independently -- Verify GPU health check rejects when GPU fails -- Verify job queue persists and loads correctly - -### Phase 2: Server Integration (Estimate: 1-2 hours) - -**Tasks:** -4. ✅ Update `src/servers/api_server.py` - - Initialize JobQueue singleton - - Initialize HealthMonitor - - Add POST /jobs endpoint - - Add GET /jobs/{id} endpoint - - Add GET /jobs/{id}/result endpoint - - Add GET /jobs endpoint - - Add GET /health/gpu endpoint - - Remove or deprecate old sync endpoints - -5. ✅ Update `src/servers/whisper_server.py` - - Initialize JobQueue singleton - - Initialize HealthMonitor - - Replace old tools with async tools: - - transcribe_async() - - get_job_status() - - get_job_result() - - list_transcription_jobs() - - check_gpu_health() - - Remove old tools - -**Testing:** -- Test each endpoint with curl/httpie -- Test MCP tools with `mcp dev` command -- Verify error handling - -### Phase 3: Configuration & Environment (Estimate: 30 min) - -**Tasks:** -6. ✅ Update `run_api_server.sh` - - Add job queue env vars - - Add GPU health monitor env vars - - Create job metadata directory - -7. ✅ Update `run_mcp_server.sh` - - Add job queue env vars - - Add GPU health monitor env vars - - Create job metadata directory - -**Testing:** -- Test startup with new env vars -- Verify directories are created - -### Phase 4: Integration Testing (Estimate: 1-2 hours) - -**Tasks:** -8. ✅ Test end-to-end flow - - Submit job → Poll status → Get result - - Test with real audio files - - Test queue limits (submit 101 jobs) - - Test GPU health check - - Test server restart (verify job persistence) - -9. ✅ Test error conditions - - Invalid audio path - - Queue full - - GPU failure (mock by setting device=cuda on CPU-only machine) - - Job not found - - Result retrieval before completion - -10. 
✅ Test MCP integration - - Add to Claude Desktop config - - Test transcribe_async flow - - Test all MCP tools - -**Testing Checklist:** -- [ ] Single job submission and completion -- [ ] Multiple jobs in queue (FIFO ordering) -- [ ] Queue full rejection (503 error) -- [ ] GPU health check passes on GPU machine -- [ ] GPU health check fails on CPU-only machine (when device=cuda) -- [ ] Server restart with queued jobs (resume processing) -- [ ] Server restart with running job (mark as failed) -- [ ] Result retrieval for completed job -- [ ] Error handling for invalid job_id -- [ ] MCP tools work in Claude Desktop -- [ ] Periodic GPU monitoring runs in background - ---- - -## Testing Strategy - -### Unit Tests - -```python -# tests/test_gpu_health.py -def test_gpu_health_check_success(): - """Test GPU health check on working GPU""" - -def test_gpu_health_check_rejects_cpu_fallback(): - """Test that expected=cuda but actual=cpu raises error""" - -def test_health_monitor_periodic_checks(): - """Test background monitoring thread""" - -# tests/test_job_queue.py -def test_job_submission(): - """Test job submission returns job_id""" - -def test_job_processing_fifo(): - """Test jobs processed in FIFO order""" - -def test_queue_full_rejection(): - """Test queue rejects when full""" - -def test_job_persistence(): - """Test jobs saved and loaded from disk""" - -def test_job_status_retrieval(): - """Test get_job_status() returns correct info""" -``` - -### Integration Tests - -```bash -# Test API endpoints -curl -X POST http://localhost:8000/jobs \ - -H "Content-Type: application/json" \ - -d '{"audio_path": "/path/to/test.mp3"}' - -# Expected: {"job_id": "...", "status": "queued", "queue_position": 1} - -# Poll status -curl http://localhost:8000/jobs/{job_id} - -# Expected: {"status": "running", ...} - -# Get result (when completed) -curl http://localhost:8000/jobs/{job_id}/result - -# Expected: Transcription text -``` - -### MCP Tests - -```bash -# Test with MCP CLI -mcp dev src/servers/whisper_server.py - -# In MCP client, call: -transcribe_async(audio_path="/path/to/test.mp3") -# Returns: {job_id, status, queue_position} - -get_job_status(job_id="...") -# Returns: {status, ...} - -get_job_result(job_id="...") -# Returns: transcription text -``` - ---- - -## Environment Variables - -### New Environment Variables - -Add to `run_api_server.sh` and `run_mcp_server.sh`: - -```bash -# Job Queue Configuration -export JOB_QUEUE_MAX_SIZE=100 -export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs" -export JOB_RETENTION_DAYS=7 # Optional: auto-cleanup old jobs (0=disabled) - -# GPU Health Monitoring -export GPU_HEALTH_CHECK_ENABLED=true -export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10 -export GPU_HEALTH_TEST_MODEL="tiny" # Model to use for health checks - -# Create job metadata directory -mkdir -p "$JOB_METADATA_DIR" -``` - -### Existing Variables (Keep) - -```bash -export CUDA_VISIBLE_DEVICES=1 -export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models" -export TRANSCRIPTION_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs" -export TRANSCRIPTION_BATCH_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/batch" -export TRANSCRIPTION_MODEL="large-v3" -export TRANSCRIPTION_DEVICE="cuda" -export TRANSCRIPTION_COMPUTE_TYPE="float16" -export TRANSCRIPTION_OUTPUT_FORMAT="txt" -export TRANSCRIPTION_BEAM_SIZE="5" -export TRANSCRIPTION_TEMPERATURE="0.0" -``` - ---- - -## Error Handling - -### Critical: GPU Failure Rejection - -**Scenario:** User requests 
device=cuda but GPU is unavailable/failing - -**Current Behavior (BAD):** -```python -# model_manager.py:64-66 -if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" # Silent fallback! -``` - -**New Behavior (GOOD):** - -```python -# In job_queue.py:submit_job() -if device == "cuda": - try: - health_status = check_gpu_health(expected_device="cuda") - if not health_status.gpu_working: - raise RuntimeError( - f"GPU device requested but not available. " - f"GPU check failed: {health_status.error}. " - f"Transcription would run on CPU and be 10-100x slower. " - f"Please use device='cpu' explicitly if you want CPU processing." - ) - except RuntimeError as e: - # Re-raise with clear message - raise RuntimeError(f"Job rejected: {e}") - -# In gpu_health.py:check_gpu_health() -if expected_device == "cuda": - # Run health check - if actual_device != "cuda": - raise RuntimeError( - "GPU requested but model loaded on CPU. " - "Possible causes: GPU driver issues, insufficient memory, " - "CUDA version mismatch. Check logs for details." - ) -``` - -**Result:** -- Job submission fails immediately with clear error -- User knows GPU is not working -- User can decide to use CPU explicitly or fix GPU -- No wasted time on slow CPU processing - -### Other Error Scenarios - -**1. Queue Full** -```python -# Return 503 Service Unavailable -{ - "error": "Job queue is full", - "queue_size": 100, - "message": "Please try again later or contact administrator" -} -``` - -**2. Invalid Audio Path** -```python -# Return 400 Bad Request -{ - "error": "Audio file not found", - "audio_path": "/invalid/path.mp3", - "message": "Please verify the file exists and path is correct" -} -``` - -**3. Job Not Found** -```python -# Return 404 Not Found -{ - "error": "Job not found", - "job_id": "invalid-uuid", - "message": "Job ID does not exist or has been cleaned up" -} -``` - -**4. 
Result Not Ready** -```python -# Return 409 Conflict -{ - "error": "Job not completed", - "job_id": "...", - "current_status": "running", - "message": "Please wait for job to complete before requesting result" -} -``` - ---- - -## Monitoring & Observability - -### Logging Strategy - -**Log Levels:** -- **INFO**: Normal operations (job submitted, started, completed) -- **WARNING**: Performance issues (GPU slow, queue filling up) -- **ERROR**: Failures (job failed, GPU check failed) - -**Key Log Messages:** - -```python -# Job lifecycle -logger.info(f"Job {job_id} submitted: {audio_path}") -logger.info(f"Job {job_id} started processing (queue position was {pos})") -logger.info(f"Job {job_id} completed in {duration:.1f}s") -logger.error(f"Job {job_id} failed: {error}") - -# GPU health -logger.info(f"GPU health check passed: {device_name}, {test_duration:.2f}s") -logger.warning(f"GPU health check slow: {test_duration:.2f}s (expected <1s)") -logger.error(f"GPU health check failed: {error}") - -# Queue status -logger.warning(f"Job queue filling up: {queue_size}/{max_size}") -logger.error(f"Job queue full, rejecting request") -``` - -### Metrics to Track - -**Job Metrics:** -- Total jobs submitted -- Jobs completed successfully -- Jobs failed -- Average processing time -- Average queue wait time - -**Queue Metrics:** -- Current queue size -- Max queue size seen -- Queue full rejections - -**GPU Metrics:** -- GPU health check results (success/fail) -- GPU utilization (if available) -- Model loading failures - ---- - -## Migration Strategy - -### Backward Compatibility - -**Option 1: Deprecate Old Endpoints (Recommended)** -- Keep old endpoints for 1-2 releases with deprecation warnings -- Return warning header: `X-Deprecated: Use /jobs endpoint instead` -- Document migration path in CLAUDE.md - -**Option 2: Remove Old Endpoints Immediately** -- Clean break, simpler codebase -- Update CLAUDE.md with new API only -- Announce breaking change in release notes - -**Recommendation:** Option 1 for REST API, Option 2 for MCP tools (MCP users update config anyway) - -### Deployment Steps - -1. **Pre-deployment:** - - Test all components in development - - Verify GPU health check works - - Test job persistence - -2. **Deployment:** - - Update code - - Update environment variables in run scripts - - Create job metadata directory - - Restart services - -3. **Post-deployment:** - - Monitor logs for errors - - Check GPU health status - - Verify jobs are processing - - Test with real workload - -4. **Rollback Plan:** - - Keep old code in git branch - - Can quickly revert if issues found - - Job metadata on disk survives rollback - ---- - -## Future Enhancements - -### Phase 2 Features (Not in Initial Implementation) - -1. **Job Cancellation** - - Add `DELETE /jobs/{id}` endpoint - - Gracefully stop running job - -2. **Priority Queue** - - Add priority parameter to job submission - - Use PriorityQueue instead of Queue - -3. **Batch Job Submission** - - Submit multiple files as single batch - - Track as parent job with sub-jobs - -4. **Result Streaming** - - Stream partial results as transcription progresses - - Use Server-Sent Events or WebSockets - -5. **Distributed Workers** - - Multiple worker processes/machines - - Use Redis/RabbitMQ for queue - - Horizontal scaling - -6. **Job Expiration** - - Auto-delete old completed jobs - - Configurable retention policy - -7. **Retry Logic** - - Auto-retry failed jobs - - Exponential backoff - -8. 
**Progress Reporting** - - Report percentage complete - - Estimate time remaining - ---- - -## Appendix A: Code Examples - -### Example 1: Using REST API - -```python -import requests -import time - -# Submit job -response = requests.post('http://localhost:8000/jobs', json={ - 'audio_path': '/path/to/audio.mp3', - 'model_name': 'large-v3', - 'output_format': 'txt' -}) -job = response.json() -job_id = job['job_id'] -print(f"Job submitted: {job_id}, queue position: {job['queue_position']}") - -# Poll for completion -while True: - response = requests.get(f'http://localhost:8000/jobs/{job_id}') - status = response.json() - - if status['status'] == 'completed': - print("Job completed!") - break - elif status['status'] == 'failed': - print(f"Job failed: {status['error']}") - break - else: - print(f"Status: {status['status']}, queue_position: {status.get('queue_position', 'N/A')}") - time.sleep(10) # Poll every 10 seconds - -# Get result -response = requests.get(f'http://localhost:8000/jobs/{job_id}/result') -transcription = response.text -print(f"Transcription:\n{transcription}") -``` - -### Example 2: Using MCP Tools (LLM Agent) - -``` -LLM Agent workflow: - -1. Submit job: - transcribe_async(audio_path="/path/to/podcast.mp3", model_name="large-v3") - → Returns: {"job_id": "abc-123", "status": "queued", "queue_position": 2} - -2. Poll status: - get_job_status(job_id="abc-123") - → Returns: {"status": "queued", "queue_position": 1} - - [Wait 10 seconds] - - get_job_status(job_id="abc-123") - → Returns: {"status": "running"} - - [Wait 30 seconds] - - get_job_status(job_id="abc-123") - → Returns: {"status": "completed", "result_path": "/outputs/podcast.txt"} - -3. Get result: - get_job_result(job_id="abc-123") - → Returns: "Welcome to our podcast. Today we're discussing..." -``` - ---- - -## Appendix B: Architecture Decisions - -### Why In-Memory Queue Instead of Redis? - -**Pros of In-Memory:** -- Zero external dependencies -- Simple to implement and test -- Fast (no network overhead) -- Sufficient for single-machine deployment - -**Cons:** -- Not distributed (can't scale horizontally) -- Jobs lost if process crashes before saving to disk -- No shared queue across multiple processes - -**Decision:** Start with in-memory, migrate to Redis if scaling needed - -### Why Single Worker Thread? - -**Pros:** -- No concurrent GPU access (avoids memory issues) -- Simple to implement and debug -- Predictable resource usage -- FIFO ordering guaranteed - -**Cons:** -- Lower throughput (one job at a time) -- Can't utilize multiple GPUs - -**Decision:** Single worker is best for GPU processing. Can add multi-worker for CPU-only mode later. - -### Why JSON Files Instead of SQLite? - -**Pros of JSON Files:** -- Simple to inspect (just cat the file) -- No database corruption issues -- Easy to backup/restore -- No locking issues -- One file per job (no shared state) - -**Cons:** -- Slower for large job counts (10,000+) -- No complex queries -- No transactions - -**Decision:** JSON files sufficient for expected workload (<1000 jobs). Can migrate to SQLite if needed. 
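As a concrete illustration of the JSON-file approach, the startup loader can stay very small. The sketch below assumes the `{job_id}.json` layout shown in the Data Structures section; the real `_load_jobs_from_disk()` would rebuild `Job` objects via `Job.from_dict()` rather than keep raw dictionaries.

```python
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def load_jobs_from_disk(metadata_dir: str) -> dict:
    """Rebuild the in-memory job map from one {job_id}.json file per job."""
    jobs = {}
    for path in Path(metadata_dir).glob("*.json"):
        try:
            with path.open("r", encoding="utf-8") as f:
                data = json.load(f)
            jobs[data["job_id"]] = data  # production code: Job.from_dict(data)
        except (json.JSONDecodeError, KeyError, OSError) as exc:
            # A single corrupt metadata file must not block startup.
            logger.warning("Skipping unreadable job file %s: %s", path.name, exc)
    return jobs
```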
- ---- - -## Appendix C: Security Considerations - -### Input Validation - -**Audio Path:** -- Verify file exists -- Check file extension -- Verify file size (<10GB recommended) -- Consider path traversal attacks (validate no `../` in path) - -**Model Selection:** -- Validate against whitelist of allowed models -- Prevent arbitrary model loading - -**Output Directory:** -- Validate directory exists and is writable -- Consider restricting to specific base directories - -### Resource Limits - -**Queue Size:** -- Limit max queue size (prevent DOS) -- Return 503 when full - -**File Size:** -- Warn on files >1GB -- Consider max file size limit - -**Job Retention:** -- Implement cleanup of old jobs -- Prevent disk space exhaustion - ---- - -## Success Criteria - -Implementation is considered successful when: - -1. ✅ Jobs can be submitted and return immediately (no timeout) -2. ✅ Jobs are processed in FIFO order -3. ✅ GPU health check correctly detects GPU failures -4. ✅ GPU device=cuda requests are REJECTED if GPU unavailable -5. ✅ Jobs persist to disk and survive server restarts -6. ✅ Queue full scenario returns 503 error -7. ✅ MCP tools work correctly in Claude Desktop -8. ✅ All tests pass -9. ✅ Documentation is complete and accurate -10. ✅ Existing functionality is not broken - ---- - -**END OF DEVELOPMENT PLAN** diff --git a/src/__init__.py b/src/__init__.py index 2565d5d..a4c466d 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,5 +1,9 @@ """ -语音识别MCP服务模块 +Faster Whisper MCP Transcription Service + +High-performance audio transcription service with dual-server architecture +(MCP and REST API) featuring async job queue and GPU health monitoring. """ -__version__ = "0.1.0" +__version__ = "0.2.0" +__author__ = "Whisper MCP Team" diff --git a/src/core/__init__.py b/src/core/__init__.py index e69de29..f87c330 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -0,0 +1,6 @@ +""" +Core modules for Whisper transcription service. + +Includes model management, transcription logic, job queue, GPU health monitoring, +and GPU reset functionality. +""" diff --git a/src/core/gpu_health.py b/src/core/gpu_health.py index 7c9fa91..4cbe00d 100644 --- a/src/core/gpu_health.py +++ b/src/core/gpu_health.py @@ -3,6 +3,7 @@ GPU health monitoring for Whisper transcription service. Performs real GPU health checks using actual model loading and transcription, with strict failure handling to prevent silent CPU fallbacks. +Includes circuit breaker pattern to prevent repeated failed checks. """ import time @@ -14,9 +15,19 @@ from typing import Optional, List import torch from utils.test_audio_generator import generate_test_audio +from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen logger = logging.getLogger(__name__) +# Global circuit breaker for GPU health checks +_gpu_health_circuit_breaker = CircuitBreaker( + name="gpu_health_check", + failure_threshold=3, # Open after 3 consecutive failures + success_threshold=2, # Close after 2 consecutive successes + timeout_seconds=60, # Try again after 60 seconds + half_open_max_calls=1 # Only 1 test call in half-open state +) + # Import reset functionality (after logger initialization) try: from core.gpu_reset import ( @@ -48,7 +59,7 @@ class GPUHealthStatus: return asdict(self) -def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus: +def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus: """ Comprehensive GPU health check using real model + transcription. 
@@ -207,6 +218,58 @@ def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus: return status +def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus: + """ + GPU health check with optional circuit breaker protection. + + This is the main entry point for GPU health checks. By default, it uses + circuit breaker pattern to prevent repeated failed checks. + + Args: + expected_device: Expected device ("auto", "cuda", "cpu") + use_circuit_breaker: Enable circuit breaker protection (default: True) + + Returns: + GPUHealthStatus object + + Raises: + RuntimeError: If expected_device="cuda" but GPU test fails + CircuitBreakerOpen: If circuit breaker is open (too many recent failures) + """ + if use_circuit_breaker: + try: + return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device) + except CircuitBreakerOpen as e: + # Circuit is open, fail fast without attempting check + logger.warning(f"GPU health check circuit breaker is OPEN: {e}") + raise RuntimeError(f"GPU health check unavailable: {str(e)}") + else: + return _check_gpu_health_internal(expected_device) + + +def get_circuit_breaker_stats() -> dict: + """ + Get current circuit breaker statistics. + + Returns: + Dictionary with circuit state and failure/success counts + """ + return _gpu_health_circuit_breaker.get_stats() + + +def reset_circuit_breaker(): + """ + Manually reset the GPU health check circuit breaker. + + Useful for: + - Testing + - Manual intervention after fixing GPU issues + - Clearing persistent error state + """ + _gpu_health_circuit_breaker.reset() + logger.info("GPU health check circuit breaker manually reset") + + def check_gpu_health_with_reset( expected_device: str = "cuda", auto_reset: bool = True diff --git a/src/servers/__init__.py b/src/servers/__init__.py index e69de29..9683346 100644 --- a/src/servers/__init__.py +++ b/src/servers/__init__.py @@ -0,0 +1,5 @@ +""" +Server implementations for Whisper transcription service. + +Includes MCP server (whisper_server.py) and REST API server (api_server.py). 
+""" diff --git a/src/servers/api_server.py b/src/servers/api_server.py index 1bc7fa9..90c020e 100644 --- a/src/servers/api_server.py +++ b/src/servers/api_server.py @@ -8,6 +8,8 @@ import os import sys import logging import queue as queue_module +import tempfile +from pathlib import Path from contextlib import asynccontextmanager from typing import Optional, List from fastapi import FastAPI, HTTPException, UploadFile, File, Form @@ -17,20 +19,13 @@ import json from core.model_manager import get_model_info from core.job_queue import JobQueue, JobStatus -from core.gpu_health import HealthMonitor, check_gpu_health +from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker +from utils.startup import startup_sequence, cleanup_on_shutdown # Logging configuration logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Import GPU health check with reset -try: - from core.gpu_health import check_gpu_health_with_reset - GPU_HEALTH_CHECK_AVAILABLE = True -except ImportError as e: - logger.warning(f"GPU health check with reset not available: {e}") - GPU_HEALTH_CHECK_AVAILABLE = False - # Global instances job_queue: Optional[JobQueue] = None health_monitor: Optional[HealthMonitor] = None @@ -41,34 +36,22 @@ async def lifespan(app: FastAPI): """FastAPI lifespan context manager for startup/shutdown""" global job_queue, health_monitor - # Startup - logger.info("Starting job queue and health monitor...") + # Startup - use common startup logic (without GPU check, handled in main) + logger.info("Initializing job queue and health monitor...") - # Initialize job queue - max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100")) - metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs") - job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir) - job_queue.start() - logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})") + from utils.startup import initialize_job_queue, initialize_health_monitor - # Initialize health monitor - health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true" - if health_check_enabled: - check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10")) - health_monitor = HealthMonitor(check_interval_minutes=check_interval) - health_monitor.start() - logger.info(f"GPU health monitor started (interval={check_interval} minutes)") + job_queue = initialize_job_queue() + health_monitor = initialize_health_monitor() yield - # Shutdown - logger.info("Shutting down job queue and health monitor...") - if job_queue: - job_queue.stop(wait_for_current=True) - logger.info("Job queue stopped") - if health_monitor: - health_monitor.stop() - logger.info("Health monitor stopped") + # Shutdown - use common cleanup logic + cleanup_on_shutdown( + job_queue=job_queue, + health_monitor=health_monitor, + wait_for_current_job=True + ) # Create FastAPI app @@ -107,6 +90,8 @@ async def root(): "GET /": "API information", "GET /health": "Health check", "GET /health/gpu": "GPU health check", + "GET /health/circuit-breaker": "Get circuit breaker stats", + "POST /health/circuit-breaker/reset": "Reset circuit breaker", "GET /models": "Get available models information", "POST /jobs": "Submit transcription job (async)", "GET /jobs/{job_id}": "Get job status", @@ -418,6 +403,39 @@ async def gpu_health_check_endpoint(): ) +@app.get("/health/circuit-breaker") +async def get_circuit_breaker_status(): + """ + 
Get GPU health check circuit breaker statistics. + + Returns current state, failure/success counts, and last failure time. + """ + try: + stats = get_circuit_breaker_stats() + return JSONResponse(status_code=200, content=stats) + except Exception as e: + logger.error(f"Failed to get circuit breaker stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/health/circuit-breaker/reset") +async def reset_circuit_breaker_endpoint(): + """ + Manually reset the GPU health check circuit breaker. + + Useful after fixing GPU issues or for testing purposes. + """ + try: + reset_circuit_breaker() + return JSONResponse( + status_code=200, + content={"message": "Circuit breaker reset successfully"} + ) + except Exception as e: + logger.error(f"Failed to reset circuit breaker: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @@ -425,31 +443,14 @@ async def gpu_health_check_endpoint(): if __name__ == "__main__": import uvicorn - # Perform startup GPU health check with auto-reset - if GPU_HEALTH_CHECK_AVAILABLE: - try: - logger.info("=" * 70) - logger.info("PERFORMING STARTUP GPU HEALTH CHECK") - logger.info("=" * 70) + # Perform startup GPU health check + from utils.startup import perform_startup_gpu_check - status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True) - - logger.info("=" * 70) - logger.info("STARTUP GPU CHECK SUCCESSFUL") - logger.info(f"GPU Device: {status.device_name}") - logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB") - logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s") - logger.info("=" * 70) - - except Exception as e: - logger.error("=" * 70) - logger.error("STARTUP GPU CHECK FAILED") - logger.error(f"Error: {e}") - logger.error("This service requires GPU. 
Terminating.") - logger.error("=" * 70) - sys.exit(1) - else: - logger.warning("GPU health check not available, starting without GPU validation") + perform_startup_gpu_check( + required_device="cuda", + auto_reset=True, + exit_on_failure=True + ) # Get configuration from environment variables host = os.getenv("API_HOST", "0.0.0.0") diff --git a/src/servers/whisper_server.py b/src/servers/whisper_server.py index 65bc871..c1e7da5 100644 --- a/src/servers/whisper_server.py +++ b/src/servers/whisper_server.py @@ -8,25 +8,21 @@ import os import sys import logging import json +import base64 +import tempfile +from pathlib import Path from typing import Optional from mcp.server.fastmcp import FastMCP from core.model_manager import get_model_info from core.job_queue import JobQueue, JobStatus from core.gpu_health import HealthMonitor, check_gpu_health +from utils.startup import startup_sequence, cleanup_on_shutdown # Log configuration logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Import GPU health check with reset -try: - from core.gpu_health import check_gpu_health_with_reset - GPU_HEALTH_CHECK_AVAILABLE = True -except ImportError as e: - logger.warning(f"GPU health check with reset not available: {e}") - GPU_HEALTH_CHECK_AVAILABLE = False - # Global instances job_queue: Optional[JobQueue] = None health_monitor: Optional[HealthMonitor] = None @@ -319,56 +315,20 @@ def check_gpu_health() -> str: if __name__ == "__main__": print("starting mcp server for whisper stt transcriptor") - # Perform startup GPU health check with auto-reset - if GPU_HEALTH_CHECK_AVAILABLE: - try: - logger.info("=" * 70) - logger.info("PERFORMING STARTUP GPU HEALTH CHECK") - logger.info("=" * 70) - - status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True) - - logger.info("=" * 70) - logger.info("STARTUP GPU CHECK SUCCESSFUL") - logger.info(f"GPU Device: {status.device_name}") - logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB") - logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s") - logger.info("=" * 70) - - except Exception as e: - logger.error("=" * 70) - logger.error("STARTUP GPU CHECK FAILED") - logger.error(f"Error: {e}") - logger.error("This service requires GPU. 
Terminating.") - logger.error("=" * 70) - sys.exit(1) - else: - logger.warning("GPU health check not available, starting without GPU validation") - - # Initialize job queue - logger.info("Initializing job queue...") - max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100")) - metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs") - job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir) - job_queue.start() - logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})") - - # Initialize health monitor - health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true" - if health_check_enabled: - check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10")) - health_monitor = HealthMonitor(check_interval_minutes=check_interval) - health_monitor.start() - logger.info(f"GPU health monitor started (interval={check_interval} minutes)") + # Execute common startup sequence + job_queue, health_monitor = startup_sequence( + service_name="MCP Whisper Server", + require_gpu=True, + initialize_queue=True, + initialize_monitoring=True + ) try: mcp.run() finally: # Cleanup on shutdown - logger.info("Shutting down...") - if job_queue: - job_queue.stop(wait_for_current=True) - logger.info("Job queue stopped") - if health_monitor: - health_monitor.stop() - logger.info("Health monitor stopped") \ No newline at end of file + cleanup_on_shutdown( + job_queue=job_queue, + health_monitor=health_monitor, + wait_for_current_job=True + ) \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py index e69de29..9f2d98e 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -0,0 +1,6 @@ +""" +Utility modules for Whisper transcription service. + +Includes audio processing, formatters, test audio generation, input validation, +circuit breaker, and startup logic. +""" diff --git a/src/utils/audio_processor.py b/src/utils/audio_processor.py index 9801a46..6c386e8 100644 --- a/src/utils/audio_processor.py +++ b/src/utils/audio_processor.py @@ -7,17 +7,24 @@ Responsible for audio file validation and preprocessing import os import logging from typing import Union, Any +from pathlib import Path from faster_whisper import decode_audio +from utils.input_validation import ( + validate_audio_file as validate_audio_file_secure, + sanitize_error_message +) + # Log configuration logger = logging.getLogger(__name__) -def validate_audio_file(audio_path: str) -> None: +def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None: """ - Validate if an audio file is valid + Validate if an audio file is valid (with security checks). Args: audio_path: Path to the audio file + allowed_dirs: Optional list of allowed base directories Raises: FileNotFoundError: If audio file doesn't exist @@ -27,31 +34,19 @@ def validate_audio_file(audio_path: str) -> None: Returns: None: If validation passes """ - # Validate parameters - if not os.path.exists(audio_path): - raise FileNotFoundError(f"Audio file does not exist: {audio_path}") - - # Validate file format - supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"] - file_ext = os.path.splitext(audio_path)[1].lower() - if file_ext not in supported_formats: - raise ValueError(f"Unsupported audio format: {file_ext}. 
Supported formats: {', '.join(supported_formats)}") - - # Validate file size try: - file_size = os.path.getsize(audio_path) - if file_size == 0: - raise ValueError(f"Audio file is empty: {audio_path}") + # Use secure validation + validate_audio_file_secure(audio_path, allowed_dirs) + except Exception as e: + # Re-raise with sanitized error messages + error_msg = sanitize_error_message(str(e)) - # Warning for large files (over 1GB) - if file_size > 1024 * 1024 * 1024: - logger.warning(f"Warning: File size exceeds 1GB, may require longer processing time: {audio_path}") - except OSError as e: - logger.error(f"Failed to check file size: {str(e)}") - raise OSError(f"Failed to check file size: {str(e)}") - - # Validation passed - return None + if "not found" in str(e).lower(): + raise FileNotFoundError(error_msg) + elif "size" in str(e).lower(): + raise OSError(error_msg) + else: + raise ValueError(error_msg) def process_audio(audio_path: str) -> Union[str, Any]: """ diff --git a/src/utils/circuit_breaker.py b/src/utils/circuit_breaker.py new file mode 100644 index 0000000..b9295d0 --- /dev/null +++ b/src/utils/circuit_breaker.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +""" +Circuit Breaker Pattern Implementation + +Prevents repeated failed attempts and provides fail-fast behavior. +Useful for GPU health checks and other operations that may fail repeatedly. +""" + +import time +import logging +import threading +from enum import Enum +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Callable, Any, Optional + +logger = logging.getLogger(__name__) + + +class CircuitState(Enum): + """Circuit breaker states.""" + CLOSED = "closed" # Normal operation, requests pass through + OPEN = "open" # Circuit is open, requests fail immediately + HALF_OPEN = "half_open" # Testing if circuit can close + + +@dataclass +class CircuitBreakerConfig: + """Configuration for circuit breaker.""" + failure_threshold: int = 3 # Failures before opening circuit + success_threshold: int = 2 # Successes before closing from half-open + timeout_seconds: int = 60 # Time before attempting half-open + half_open_max_calls: int = 1 # Max calls to test in half-open state + + +class CircuitBreaker: + """ + Circuit breaker implementation for preventing repeated failures. + + Usage: + breaker = CircuitBreaker(name="gpu_health", failure_threshold=3) + + @breaker.call + def check_gpu(): + # This function will be protected by circuit breaker + return perform_gpu_check() + + # Or use decorator: + @breaker.decorator() + def my_function(): + return "result" + """ + + def __init__( + self, + name: str, + failure_threshold: int = 3, + success_threshold: int = 2, + timeout_seconds: int = 60, + half_open_max_calls: int = 1 + ): + """ + Initialize circuit breaker. 
+ + Args: + name: Name of the circuit (for logging) + failure_threshold: Number of failures before opening + success_threshold: Number of successes to close from half-open + timeout_seconds: Seconds before transitioning to half-open + half_open_max_calls: Max concurrent calls in half-open state + """ + self.name = name + self.config = CircuitBreakerConfig( + failure_threshold=failure_threshold, + success_threshold=success_threshold, + timeout_seconds=timeout_seconds, + half_open_max_calls=half_open_max_calls + ) + + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._success_count = 0 + self._last_failure_time: Optional[datetime] = None + self._half_open_calls = 0 + self._lock = threading.RLock() + + logger.info( + f"Circuit breaker '{name}' initialized: " + f"failure_threshold={failure_threshold}, " + f"timeout={timeout_seconds}s" + ) + + @property + def state(self) -> CircuitState: + """Get current circuit state.""" + with self._lock: + self._update_state() + return self._state + + @property + def is_closed(self) -> bool: + """Check if circuit is closed (normal operation).""" + return self.state == CircuitState.CLOSED + + @property + def is_open(self) -> bool: + """Check if circuit is open (failing fast).""" + return self.state == CircuitState.OPEN + + @property + def is_half_open(self) -> bool: + """Check if circuit is half-open (testing).""" + return self.state == CircuitState.HALF_OPEN + + def _update_state(self): + """Update state based on timeout and counters.""" + if self._state == CircuitState.OPEN: + # Check if timeout has passed + if self._last_failure_time: + elapsed = datetime.utcnow() - self._last_failure_time + if elapsed.total_seconds() >= self.config.timeout_seconds: + logger.info( + f"Circuit '{self.name}': Transitioning to HALF_OPEN " + f"after {elapsed.total_seconds():.0f}s timeout" + ) + self._state = CircuitState.HALF_OPEN + self._half_open_calls = 0 + self._success_count = 0 + + def _on_success(self): + """Handle successful call.""" + with self._lock: + if self._state == CircuitState.HALF_OPEN: + self._success_count += 1 + logger.debug( + f"Circuit '{self.name}': Success in HALF_OPEN " + f"({self._success_count}/{self.config.success_threshold})" + ) + + if self._success_count >= self.config.success_threshold: + logger.info(f"Circuit '{self.name}': Closing circuit after successful test") + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._success_count = 0 + self._last_failure_time = None + + elif self._state == CircuitState.CLOSED: + # Reset failure count on success + self._failure_count = 0 + + def _on_failure(self, error: Exception): + """Handle failed call.""" + with self._lock: + self._failure_count += 1 + self._last_failure_time = datetime.utcnow() + + if self._state == CircuitState.HALF_OPEN: + logger.warning( + f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit" + ) + self._state = CircuitState.OPEN + self._success_count = 0 + + elif self._state == CircuitState.CLOSED: + logger.debug( + f"Circuit '{self.name}': Failure {self._failure_count}/" + f"{self.config.failure_threshold}" + ) + + if self._failure_count >= self.config.failure_threshold: + logger.warning( + f"Circuit '{self.name}': Opening circuit after " + f"{self._failure_count} failures. " + f"Will retry in {self.config.timeout_seconds}s" + ) + self._state = CircuitState.OPEN + self._success_count = 0 + + def call(self, func: Callable, *args, **kwargs) -> Any: + """ + Execute function with circuit breaker protection. 
+ + Args: + func: Function to execute + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Function result + + Raises: + CircuitBreakerOpen: If circuit is open + Exception: Original exception from func if it fails + """ + with self._lock: + self._update_state() + + # Check if circuit is open + if self._state == CircuitState.OPEN: + raise CircuitBreakerOpen( + f"Circuit '{self.name}' is OPEN. " + f"Failing fast to prevent repeated failures. " + f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. " + f"Will retry in {self.config.timeout_seconds}s" + ) + + # Check half-open call limit + if self._state == CircuitState.HALF_OPEN: + if self._half_open_calls >= self.config.half_open_max_calls: + raise CircuitBreakerOpen( + f"Circuit '{self.name}' is HALF_OPEN with max calls reached. " + f"Please wait for current test to complete." + ) + self._half_open_calls += 1 + + # Execute function + try: + result = func(*args, **kwargs) + self._on_success() + return result + + except Exception as e: + self._on_failure(e) + raise + + finally: + # Decrement half-open counter + with self._lock: + if self._state == CircuitState.HALF_OPEN: + self._half_open_calls -= 1 + + def decorator(self): + """ + Decorator for protecting functions with circuit breaker. + + Usage: + breaker = CircuitBreaker("my_service") + + @breaker.decorator() + def my_function(): + return do_something() + """ + def wrapper(func): + def decorated(*args, **kwargs): + return self.call(func, *args, **kwargs) + return decorated + return wrapper + + def reset(self): + """ + Manually reset circuit breaker to closed state. + + Useful for: + - Testing + - Manual intervention + - Clearing error state + """ + with self._lock: + logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state") + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._success_count = 0 + self._last_failure_time = None + self._half_open_calls = 0 + + def get_stats(self) -> dict: + """ + Get circuit breaker statistics. + + Returns: + Dictionary with current state and counters + """ + with self._lock: + self._update_state() + return { + "name": self.name, + "state": self._state.value, + "failure_count": self._failure_count, + "success_count": self._success_count, + "last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None, + "config": { + "failure_threshold": self.config.failure_threshold, + "success_threshold": self.config.success_threshold, + "timeout_seconds": self.config.timeout_seconds, + } + } + + +class CircuitBreakerOpen(Exception): + """Exception raised when circuit breaker is open.""" + pass diff --git a/src/utils/input_validation.py b/src/utils/input_validation.py new file mode 100644 index 0000000..5daabfd --- /dev/null +++ b/src/utils/input_validation.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +""" +Input Validation and Path Sanitization Module + +Provides robust validation for user inputs with security protections +against path traversal, injection attacks, and other malicious inputs. 
+""" + +import os +import re +import logging +from pathlib import Path +from typing import Optional, List + +logger = logging.getLogger(__name__) + +# Maximum file size (10GB) +MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024 + +# Allowed audio file extensions +ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"} + +# Allowed output formats +ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"} + +# Model name validation (whitelist) +ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"} + +# Device validation +ALLOWED_DEVICES = {"cuda", "auto", "cpu"} + +# Compute type validation +ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"} + + +class ValidationError(Exception): + """Base exception for validation errors.""" + pass + + +class PathTraversalError(ValidationError): + """Exception for path traversal attempts.""" + pass + + +class InvalidFileTypeError(ValidationError): + """Exception for invalid file types.""" + pass + + +class FileSizeError(ValidationError): + """Exception for file size issues.""" + pass + + +def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str: + """ + Sanitize error messages to prevent information leakage. + + Args: + error_msg: Original error message + sanitize_paths: Whether to sanitize file paths (default: True) + + Returns: + Sanitized error message + """ + if not sanitize_paths: + return error_msg + + # Replace absolute paths with relative paths + # Pattern: /home/user/... or /media/... or /var/... or /tmp/... + path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)' + + def replace_path(match): + full_path = match.group(1) + try: + # Try to get just the filename + basename = os.path.basename(full_path) + return f"" + except: + return "" + + sanitized = re.sub(path_pattern, replace_path, error_msg) + + # Also sanitize user names if present + sanitized = re.sub(r'/home/([^/]+)/', '/home//', sanitized) + + return sanitized + + +def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path: + """ + Validate and sanitize a file path to prevent directory traversal attacks. + + Args: + file_path: Path to validate + allowed_dirs: Optional list of allowed base directories + + Returns: + Resolved Path object + + Raises: + PathTraversalError: If path contains traversal attempts + ValidationError: If path is invalid + """ + if not file_path: + raise ValidationError("File path cannot be empty") + + # Convert to Path object + try: + path = Path(file_path) + except Exception as e: + raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}") + + # Check for path traversal attempts + path_str = str(path) + if ".." in path_str: + logger.warning(f"Path traversal attempt detected: {path_str}") + raise PathTraversalError("Path traversal (..) 
is not allowed") + + # Check for null bytes + if "\x00" in path_str: + logger.warning(f"Null byte in path detected: {path_str}") + raise PathTraversalError("Null bytes in path are not allowed") + + # Resolve to absolute path (but don't follow symlinks yet) + try: + resolved_path = path.resolve() + except Exception as e: + raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}") + + # If allowed_dirs specified, ensure path is within one of them + if allowed_dirs: + allowed = False + for allowed_dir in allowed_dirs: + try: + allowed_dir_path = Path(allowed_dir).resolve() + # Check if resolved_path is under allowed_dir + resolved_path.relative_to(allowed_dir_path) + allowed = True + break + except ValueError: + # Not relative to this allowed_dir + continue + + if not allowed: + logger.warning( + f"Path outside allowed directories: {path_str}, " + f"allowed: {allowed_dirs}" + ) + raise PathTraversalError( + f"Path must be within allowed directories. " + f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}" + ) + + return resolved_path + + +def validate_audio_file( + file_path: str, + allowed_dirs: Optional[List[str]] = None, + max_size_bytes: int = MAX_FILE_SIZE_BYTES +) -> Path: + """ + Validate audio file path with security checks. + + Args: + file_path: Path to audio file + allowed_dirs: Optional list of allowed base directories + max_size_bytes: Maximum allowed file size + + Returns: + Validated Path object + + Raises: + ValidationError: If validation fails + PathTraversalError: If path traversal detected + FileNotFoundError: If file doesn't exist + InvalidFileTypeError: If file type not allowed + FileSizeError: If file too large + """ + # Validate and sanitize path + validated_path = validate_path_safe(file_path, allowed_dirs) + + # Check file exists + if not validated_path.exists(): + raise FileNotFoundError(f"Audio file not found: {validated_path.name}") + + # Check it's a file (not directory) + if not validated_path.is_file(): + raise ValidationError(f"Path is not a file: {validated_path.name}") + + # Check file extension + file_ext = validated_path.suffix.lower() + if file_ext not in ALLOWED_AUDIO_EXTENSIONS: + raise InvalidFileTypeError( + f"Unsupported audio format: {file_ext}. " + f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}" + ) + + # Check file size + try: + file_size = validated_path.stat().st_size + except Exception as e: + raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}") + + if file_size == 0: + raise FileSizeError(f"Audio file is empty: {validated_path.name}") + + if file_size > max_size_bytes: + raise FileSizeError( + f"File too large: {file_size / (1024**3):.2f}GB. " + f"Maximum: {max_size_bytes / (1024**3):.2f}GB" + ) + + # Warn for large files (>1GB) + if file_size > 1024 * 1024 * 1024: + logger.warning( + f"Large file: {file_size / (1024**3):.2f}GB, " + f"may require extended processing time" + ) + + return validated_path + + +def validate_output_directory( + dir_path: str, + allowed_dirs: Optional[List[str]] = None, + create_if_missing: bool = True +) -> Path: + """ + Validate output directory path. 
+ + Args: + dir_path: Directory path + allowed_dirs: Optional list of allowed base directories + create_if_missing: Create directory if it doesn't exist + + Returns: + Validated Path object + + Raises: + ValidationError: If validation fails + PathTraversalError: If path traversal detected + """ + # Validate and sanitize path + validated_path = validate_path_safe(dir_path, allowed_dirs) + + # Create if requested and doesn't exist + if create_if_missing and not validated_path.exists(): + try: + validated_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Created output directory: {validated_path}") + except Exception as e: + raise ValidationError( + f"Cannot create output directory: {sanitize_error_message(str(e))}" + ) + + # Check it's a directory + if validated_path.exists() and not validated_path.is_dir(): + raise ValidationError(f"Path exists but is not a directory: {validated_path.name}") + + return validated_path + + +def validate_model_name(model_name: str) -> str: + """ + Validate Whisper model name. + + Args: + model_name: Model name to validate + + Returns: + Validated model name + + Raises: + ValidationError: If model name invalid + """ + if not model_name: + raise ValidationError("Model name cannot be empty") + + model_name = model_name.strip().lower() + + if model_name not in ALLOWED_MODEL_NAMES: + raise ValidationError( + f"Invalid model name: {model_name}. " + f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}" + ) + + return model_name + + +def validate_device(device: str) -> str: + """ + Validate device parameter. + + Args: + device: Device name to validate + + Returns: + Validated device name + + Raises: + ValidationError: If device invalid + """ + if not device: + raise ValidationError("Device cannot be empty") + + device = device.strip().lower() + + if device not in ALLOWED_DEVICES: + raise ValidationError( + f"Invalid device: {device}. " + f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}" + ) + + return device + + +def validate_compute_type(compute_type: str) -> str: + """ + Validate compute type parameter. + + Args: + compute_type: Compute type to validate + + Returns: + Validated compute type + + Raises: + ValidationError: If compute type invalid + """ + if not compute_type: + raise ValidationError("Compute type cannot be empty") + + compute_type = compute_type.strip().lower() + + if compute_type not in ALLOWED_COMPUTE_TYPES: + raise ValidationError( + f"Invalid compute type: {compute_type}. " + f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}" + ) + + return compute_type + + +def validate_output_format(output_format: str) -> str: + """ + Validate output format parameter. + + Args: + output_format: Output format to validate + + Returns: + Validated output format + + Raises: + ValidationError: If output format invalid + """ + if not output_format: + raise ValidationError("Output format cannot be empty") + + output_format = output_format.strip().lower() + + if output_format not in ALLOWED_OUTPUT_FORMATS: + raise ValidationError( + f"Invalid output format: {output_format}. " + f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}" + ) + + return output_format + + +def validate_numeric_range( + value: float, + min_value: float, + max_value: float, + param_name: str +) -> float: + """ + Validate numeric parameter is within range. 
+ + Args: + value: Value to validate + min_value: Minimum allowed value + max_value: Maximum allowed value + param_name: Parameter name for error messages + + Returns: + Validated value + + Raises: + ValidationError: If value out of range + """ + if value < min_value or value > max_value: + raise ValidationError( + f"{param_name} must be between {min_value} and {max_value}, " + f"got {value}" + ) + + return value + + +def validate_beam_size(beam_size: int) -> int: + """Validate beam size parameter.""" + return int(validate_numeric_range(beam_size, 1, 20, "beam_size")) + + +def validate_temperature(temperature: float) -> float: + """Validate temperature parameter.""" + return validate_numeric_range(temperature, 0.0, 1.0, "temperature") diff --git a/src/utils/startup.py b/src/utils/startup.py new file mode 100644 index 0000000..e4ba391 --- /dev/null +++ b/src/utils/startup.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +Common Startup Logic Module + +Centralizes startup procedures shared between MCP and API servers, +including GPU health checks, job queue initialization, and health monitoring. +""" + +import os +import sys +import logging +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + +# Import GPU health check with reset +try: + from core.gpu_health import check_gpu_health_with_reset + GPU_HEALTH_CHECK_AVAILABLE = True +except ImportError as e: + logger.warning(f"GPU health check with reset not available: {e}") + GPU_HEALTH_CHECK_AVAILABLE = False + + +def perform_startup_gpu_check( + required_device: str = "cuda", + auto_reset: bool = True, + exit_on_failure: bool = True +) -> bool: + """ + Perform startup GPU health check with optional auto-reset. + + This function: + 1. Checks if GPU health check is available + 2. Runs comprehensive GPU health check + 3. Attempts auto-reset if check fails and auto_reset=True + 4. Optionally exits process if check fails + + Args: + required_device: Required device ("cuda", "auto") + auto_reset: Enable automatic GPU driver reset on failure + exit_on_failure: Exit process if GPU check fails + + Returns: + True if GPU check passed, False otherwise + + Side effects: + May exit process if exit_on_failure=True and check fails + """ + if not GPU_HEALTH_CHECK_AVAILABLE: + logger.warning("GPU health check not available, starting without GPU validation") + if exit_on_failure: + logger.error("GPU health check required but not available. Exiting.") + sys.exit(1) + return False + + try: + logger.info("=" * 70) + logger.info("PERFORMING STARTUP GPU HEALTH CHECK") + logger.info("=" * 70) + + status = check_gpu_health_with_reset( + expected_device=required_device, + auto_reset=auto_reset + ) + + logger.info("=" * 70) + logger.info("STARTUP GPU CHECK SUCCESSFUL") + logger.info(f"GPU Device: {status.device_name}") + logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB") + logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s") + logger.info("=" * 70) + + return True + + except Exception as e: + logger.error("=" * 70) + logger.error("STARTUP GPU CHECK FAILED") + logger.error(f"Error: {e}") + + if exit_on_failure: + logger.error("This service requires GPU. 
Terminating.") + logger.error("=" * 70) + sys.exit(1) + else: + logger.error("Continuing without GPU (may have reduced functionality)") + logger.error("=" * 70) + return False + + +def initialize_job_queue( + max_queue_size: Optional[int] = None, + metadata_dir: Optional[str] = None +) -> 'JobQueue': + """ + Initialize job queue with environment variable configuration. + + Args: + max_queue_size: Override for max queue size (uses env var if None) + metadata_dir: Override for metadata directory (uses env var if None) + + Returns: + Initialized JobQueue instance (started) + """ + from core.job_queue import JobQueue + + # Get configuration from environment + if max_queue_size is None: + max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100")) + + if metadata_dir is None: + metadata_dir = os.getenv( + "JOB_METADATA_DIR", + "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs" + ) + + logger.info("Initializing job queue...") + job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir) + job_queue.start() + logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})") + + return job_queue + + +def initialize_health_monitor( + check_interval_minutes: Optional[int] = None, + enabled: Optional[bool] = None +) -> Optional['HealthMonitor']: + """ + Initialize GPU health monitor with environment variable configuration. + + Args: + check_interval_minutes: Override for check interval (uses env var if None) + enabled: Override for enabled status (uses env var if None) + + Returns: + Initialized HealthMonitor instance (started), or None if disabled + """ + from core.gpu_health import HealthMonitor + + # Get configuration from environment + if enabled is None: + enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true" + + if not enabled: + logger.info("GPU health monitoring disabled") + return None + + if check_interval_minutes is None: + check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10")) + + health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes) + health_monitor.start() + logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)") + + return health_monitor + + +def startup_sequence( + service_name: str = "whisper-transcription", + require_gpu: bool = True, + initialize_queue: bool = True, + initialize_monitoring: bool = True +) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]: + """ + Execute complete startup sequence for a Whisper transcription server. + + This function performs all common startup tasks: + 1. GPU health check with auto-reset + 2. Job queue initialization + 3. 
Health monitor initialization + + Args: + service_name: Name of the service (for logging) + require_gpu: Whether GPU is required (exit if not available) + initialize_queue: Whether to initialize job queue + initialize_monitoring: Whether to initialize health monitoring + + Returns: + Tuple of (job_queue, health_monitor) - either may be None + + Side effects: + May exit process if GPU required but unavailable + """ + logger.info(f"Starting {service_name}...") + + # Step 1: GPU health check + gpu_ok = perform_startup_gpu_check( + required_device="cuda", + auto_reset=True, + exit_on_failure=require_gpu + ) + + if not gpu_ok and require_gpu: + # Should not reach here (exit_on_failure should have exited) + logger.error("GPU check failed and GPU is required") + sys.exit(1) + + # Step 2: Initialize job queue + job_queue = None + if initialize_queue: + job_queue = initialize_job_queue() + + # Step 3: Initialize health monitor + health_monitor = None + if initialize_monitoring: + health_monitor = initialize_health_monitor() + + logger.info(f"{service_name} startup sequence completed") + + return job_queue, health_monitor + + +def cleanup_on_shutdown( + job_queue: Optional['JobQueue'] = None, + health_monitor: Optional['HealthMonitor'] = None, + wait_for_current_job: bool = True +) -> None: + """ + Perform cleanup on server shutdown. + + Args: + job_queue: JobQueue instance to stop (if any) + health_monitor: HealthMonitor instance to stop (if any) + wait_for_current_job: Wait for current job to complete before stopping + """ + logger.info("Shutting down...") + + if job_queue: + job_queue.stop(wait_for_current=wait_for_current_job) + logger.info("Job queue stopped") + + if health_monitor: + health_monitor.stop() + logger.info("Health monitor stopped") + + logger.info("Shutdown complete") diff --git a/test_phase2.py b/tests/test_async_api_integration.py similarity index 98% rename from test_phase2.py rename to tests/test_async_api_integration.py index 0ccef18..b4a2018 100755 --- a/test_phase2.py +++ b/tests/test_async_api_integration.py @@ -22,8 +22,8 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) -# Add src to path -sys.path.insert(0, str(Path(__file__).parent / "src")) +# Add src to path (go up one level from tests/ to root) +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) # Color codes for terminal output class Colors: @@ -435,7 +435,9 @@ class Phase2Tester: def _create_test_audio_file(self): """Get the path to the test audio file""" - test_audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.mp3" + # Use relative path from project root + project_root = Path(__file__).parent.parent + test_audio_path = str(project_root / "data" / "test.mp3") if not os.path.exists(test_audio_path): raise FileNotFoundError(f"Test audio file not found: {test_audio_path}") return test_audio_path diff --git a/test_phase1.py b/tests/test_core_components.py similarity index 94% rename from test_phase1.py rename to tests/test_core_components.py index 995ca4c..32c0985 100755 --- a/test_phase1.py +++ b/tests/test_core_components.py @@ -24,8 +24,8 @@ logging.basicConfig( datefmt='%H:%M:%S' ) -# Add src to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) +# Add src to path (go up one level from tests/ to root) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) from core.gpu_health import check_gpu_health, HealthMonitor from core.job_queue import JobQueue, JobStatus @@ -61,8 +61,9 @@ def test_audio_file(): print("="*60) try: - # Use the 
actual test audio file - audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a" + # Use the actual test audio file (relative to project root) + project_root = os.path.join(os.path.dirname(__file__), '..') + audio_path = os.path.join(project_root, "data/test.mp3") # Verify file exists assert os.path.exists(audio_path), "Audio file not found" @@ -147,8 +148,9 @@ def test_job_queue(): job_queue.start() print("✓ Job queue started") - # Use the actual test audio file - audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a" + # Use the actual test audio file (relative to project root) + project_root = os.path.join(os.path.dirname(__file__), '..') + audio_path = os.path.join(project_root, "data/test.mp3") # Test job submission print("\nSubmitting test job...") diff --git a/tests/test_e2e_integration.py b/tests/test_e2e_integration.py new file mode 100755 index 0000000..58695ae --- /dev/null +++ b/tests/test_e2e_integration.py @@ -0,0 +1,523 @@ +#!/usr/bin/env python3 +""" +Test Phase 4: End-to-End Integration Testing + +Comprehensive integration tests for the async job queue system. +Tests all scenarios from the DEV_PLAN.md Phase 4 checklist. +""" + +import os +import sys +import time +import json +import logging +import requests +import subprocess +import signal +from pathlib import Path +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + +# Add src to path (go up one level from tests/ to root) +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Color codes for terminal output +class Colors: + GREEN = '\033[92m' + RED = '\033[91m' + YELLOW = '\033[93m' + BLUE = '\033[94m' + CYAN = '\033[96m' + END = '\033[0m' + BOLD = '\033[1m' + +def print_success(msg): + print(f"{Colors.GREEN}✓ {msg}{Colors.END}") + +def print_error(msg): + print(f"{Colors.RED}✗ {msg}{Colors.END}") + +def print_info(msg): + print(f"{Colors.BLUE}ℹ {msg}{Colors.END}") + +def print_warning(msg): + print(f"{Colors.YELLOW}⚠ {msg}{Colors.END}") + +def print_section(msg): + print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}") + print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}") + print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n") + + +class Phase4Tester: + def __init__(self, api_url="http://localhost:8000", test_audio=None): + self.api_url = api_url + # Use relative path from project root if not provided + if test_audio is None: + project_root = Path(__file__).parent.parent + test_audio = str(project_root / "data" / "test.mp3") + self.test_audio = test_audio + self.test_results = [] + self.server_process = None + + # Verify test audio exists + if not os.path.exists(test_audio): + raise FileNotFoundError(f"Test audio file not found: {test_audio}") + + def test(self, name, func): + """Run a test and record result""" + try: + logger.info(f"Testing: {name}") + print_info(f"Testing: {name}") + func() + logger.info(f"PASSED: {name}") + print_success(f"PASSED: {name}") + self.test_results.append((name, True, None)) + return True + except AssertionError as e: + logger.error(f"FAILED: {name} - {str(e)}") + print_error(f"FAILED: {name}") + print_error(f" Reason: {str(e)}") + self.test_results.append((name, False, str(e))) + return False + except Exception as e: + logger.error(f"ERROR: {name} - {str(e)}") + print_error(f"ERROR: {name}") + print_error(f" Exception: {str(e)}") + self.test_results.append((name, 
False, f"Exception: {str(e)}")) + return False + + def start_api_server(self, wait_time=5): + """Start the API server in background""" + print_info("Starting API server...") + # Script is in project root, one level up from tests/ + script_path = Path(__file__).parent.parent / "run_api_server.sh" + + # Start server in background + self.server_process = subprocess.Popen( + [str(script_path)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + preexec_fn=os.setsid + ) + + # Wait for server to start + time.sleep(wait_time) + + # Verify server is running + try: + response = requests.get(f"{self.api_url}/health", timeout=5) + if response.status_code == 200: + print_success("API server started successfully") + return True + except: + pass + + print_error("API server failed to start") + return False + + def stop_api_server(self): + """Stop the API server""" + if self.server_process: + print_info("Stopping API server...") + os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM) + self.server_process.wait(timeout=10) + print_success("API server stopped") + + def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2): + """Poll job status until completed or failed""" + start_time = time.time() + last_status = None + + while time.time() - start_time < timeout: + try: + response = requests.get(f"{self.api_url}/jobs/{job_id}") + assert response.status_code == 200, f"Failed to get job status: {response.status_code}" + + status_data = response.json() + current_status = status_data['status'] + + # Print status changes + if current_status != last_status: + if status_data.get('queue_position') is not None: + print_info(f" Job status: {current_status}, queue position: {status_data['queue_position']}") + else: + print_info(f" Job status: {current_status}") + last_status = current_status + + if current_status == "completed": + return status_data + elif current_status == "failed": + raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}") + + time.sleep(poll_interval) + except requests.exceptions.RequestException as e: + raise AssertionError(f"Request failed: {e}") + + raise AssertionError(f"Job did not complete within {timeout} seconds") + + # ======================================================================== + # TEST 1: Single Job Submission and Completion + # ======================================================================== + def test_single_job_flow(self): + """Test complete job flow: submit → poll → get result""" + # Submit job + print_info(" Submitting job...") + response = requests.post(f"{self.api_url}/jobs", json={ + "audio_path": self.test_audio, + "model_name": "large-v3", + "output_format": "txt" + }) + assert response.status_code == 200, f"Job submission failed: {response.status_code}" + + job_data = response.json() + assert 'job_id' in job_data, "No job_id in response" + # Status can be 'queued' or 'running' (if queue is empty and job starts immediately) + assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'" + + job_id = job_data['job_id'] + print_success(f" Job submitted: {job_id}") + + # Wait for completion + print_info(" Waiting for job completion...") + final_status = self.wait_for_job_completion(job_id) + + assert final_status['status'] == 'completed', "Job did not complete" + assert final_status['result_path'] is not None, "No result_path in completed job" + assert final_status['processing_time_seconds'] is not None, "No processing time" + print_success(f" Job 
completed in {final_status['processing_time_seconds']:.2f}s") + + # Get result + print_info(" Retrieving result...") + response = requests.get(f"{self.api_url}/jobs/{job_id}/result") + assert response.status_code == 200, f"Failed to get result: {response.status_code}" + + result_text = response.text + assert len(result_text) > 0, "Empty result" + print_success(f" Got result: {len(result_text)} characters") + + # ======================================================================== + # TEST 2: Multiple Jobs in Queue (FIFO) + # ======================================================================== + def test_multiple_jobs_fifo(self): + """Test multiple jobs are processed in FIFO order""" + job_ids = [] + + # Submit 3 jobs + print_info(" Submitting 3 jobs...") + for i in range(3): + response = requests.post(f"{self.api_url}/jobs", json={ + "audio_path": self.test_audio, + "model_name": "tiny", # Use tiny model for faster processing + "output_format": "txt" + }) + assert response.status_code == 200, f"Job {i+1} submission failed" + + job_data = response.json() + job_ids.append(job_data['job_id']) + print_info(f" Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}") + + # Wait for all jobs to complete + print_info(" Waiting for all jobs to complete...") + for i, job_id in enumerate(job_ids): + print_info(f" Waiting for job {i+1}/{len(job_ids)}...") + final_status = self.wait_for_job_completion(job_id, timeout=120) + assert final_status['status'] == 'completed', f"Job {i+1} failed" + + print_success(f" All {len(job_ids)} jobs completed in FIFO order") + + # ======================================================================== + # TEST 3: GPU Health Check + # ======================================================================== + def test_gpu_health_check(self): + """Test GPU health check endpoint""" + print_info(" Checking GPU health...") + response = requests.get(f"{self.api_url}/health/gpu") + assert response.status_code == 200, f"GPU health check failed: {response.status_code}" + + health_data = response.json() + assert 'gpu_available' in health_data, "Missing gpu_available field" + assert 'gpu_working' in health_data, "Missing gpu_working field" + assert 'device_used' in health_data, "Missing device_used field" + + print_info(f" GPU Available: {health_data['gpu_available']}") + print_info(f" GPU Working: {health_data['gpu_working']}") + print_info(f" Device: {health_data['device_used']}") + + if health_data['gpu_available']: + assert health_data['device_name'], "GPU available but no device_name" + assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)" + print_success(f" GPU is healthy: {health_data['device_name']}") + else: + print_warning(" GPU not available on this system") + + # ======================================================================== + # TEST 4: Invalid Audio Path + # ======================================================================== + def test_invalid_audio_path(self): + """Test job submission with invalid audio path""" + print_info(" Submitting job with invalid path...") + response = requests.post(f"{self.api_url}/jobs", json={ + "audio_path": "/invalid/path/does/not/exist.mp3", + "model_name": "large-v3" + }) + + # Should return 400 Bad Request + assert response.status_code == 400, f"Expected 400, got {response.status_code}" + + error_data = response.json() + assert 'detail' in error_data or 'error' in error_data, "No error message in response" + print_success(" Invalid 
path rejected correctly") + + # ======================================================================== + # TEST 5: Job Not Found + # ======================================================================== + def test_job_not_found(self): + """Test retrieving non-existent job""" + print_info(" Requesting non-existent job...") + fake_job_id = "00000000-0000-0000-0000-000000000000" + + response = requests.get(f"{self.api_url}/jobs/{fake_job_id}") + assert response.status_code == 404, f"Expected 404, got {response.status_code}" + print_success(" Non-existent job handled correctly") + + # ======================================================================== + # TEST 6: Result Before Completion + # ======================================================================== + def test_result_before_completion(self): + """Test getting result for job that hasn't completed""" + print_info(" Submitting job and trying to get result immediately...") + + # Submit job + response = requests.post(f"{self.api_url}/jobs", json={ + "audio_path": self.test_audio, + "model_name": "large-v3" + }) + assert response.status_code == 200 + job_id = response.json()['job_id'] + + # Try to get result immediately (job is still queued/running) + time.sleep(0.5) + response = requests.get(f"{self.api_url}/jobs/{job_id}/result") + + # Should return 409 Conflict or similar + assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}" + print_success(" Result request before completion handled correctly") + + # Clean up: wait for job to complete + self.wait_for_job_completion(job_id) + + # ======================================================================== + # TEST 7: List Jobs + # ======================================================================== + def test_list_jobs(self): + """Test listing jobs with filters""" + print_info(" Testing job listing...") + + # List all jobs + response = requests.get(f"{self.api_url}/jobs") + assert response.status_code == 200, f"List jobs failed: {response.status_code}" + + jobs_data = response.json() + assert 'jobs' in jobs_data, "No jobs array in response" + assert isinstance(jobs_data['jobs'], list), "Jobs is not a list" + print_info(f" Found {len(jobs_data['jobs'])} jobs") + + # List only completed jobs + response = requests.get(f"{self.api_url}/jobs?status=completed") + assert response.status_code == 200 + completed_jobs = response.json()['jobs'] + print_info(f" Found {len(completed_jobs)} completed jobs") + + # List with limit + response = requests.get(f"{self.api_url}/jobs?limit=5") + assert response.status_code == 200 + limited_jobs = response.json()['jobs'] + assert len(limited_jobs) <= 5, "Limit not respected" + print_success(" Job listing works correctly") + + # ======================================================================== + # TEST 8: Server Restart with Job Persistence + # ======================================================================== + def test_server_restart_persistence(self): + """Test that jobs persist across server restarts""" + print_info(" Testing job persistence across restart...") + + # Submit a job + response = requests.post(f"{self.api_url}/jobs", json={ + "audio_path": self.test_audio, + "model_name": "tiny" + }) + assert response.status_code == 200 + job_id = response.json()['job_id'] + print_info(f" Submitted job: {job_id}") + + # Get job count before restart + response = requests.get(f"{self.api_url}/jobs") + jobs_before = len(response.json()['jobs']) + print_info(f" Jobs before restart: 
{jobs_before}") + + # Restart server + print_info(" Restarting server...") + self.stop_api_server() + time.sleep(2) + assert self.start_api_server(wait_time=8), "Server failed to restart" + + # Check jobs after restart + response = requests.get(f"{self.api_url}/jobs") + assert response.status_code == 200 + jobs_after = len(response.json()['jobs']) + print_info(f" Jobs after restart: {jobs_after}") + + # Check our specific job is still there (this is the key test) + response = requests.get(f"{self.api_url}/jobs/{job_id}") + assert response.status_code == 200, "Job not found after restart" + + # Note: Total count may differ due to job retention/cleanup, but persistence works if we can find the job + if jobs_after < jobs_before: + print_warning(f" Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup") + + print_success(" Jobs persisted correctly across restart") + + # ======================================================================== + # TEST 9: Health Endpoint + # ======================================================================== + def test_health_endpoint(self): + """Test basic health endpoint""" + print_info(" Checking health endpoint...") + response = requests.get(f"{self.api_url}/health") + assert response.status_code == 200, f"Health check failed: {response.status_code}" + + health_data = response.json() + assert health_data['status'] == 'healthy', "Server not healthy" + print_success(" Health endpoint OK") + + # ======================================================================== + # TEST 10: Models Endpoint + # ======================================================================== + def test_models_endpoint(self): + """Test models information endpoint""" + print_info(" Checking models endpoint...") + response = requests.get(f"{self.api_url}/models") + assert response.status_code == 200, f"Models endpoint failed: {response.status_code}" + + models_data = response.json() + assert 'available_models' in models_data, "No available_models field" + assert 'available_devices' in models_data, "No available_devices field" + assert len(models_data['available_models']) > 0, "No models listed" + print_info(f" Available models: {len(models_data['available_models'])}") + print_success(" Models endpoint OK") + + def print_summary(self): + """Print test summary""" + print_section("TEST SUMMARY") + + passed = sum(1 for _, result, _ in self.test_results if result) + failed = len(self.test_results) - passed + + for name, result, error in self.test_results: + if result: + print_success(f"{name}") + else: + print_error(f"{name}") + if error: + print(f" {error}") + + print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="") + print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="") + print(f"{Colors.RED}Failed: {failed}{Colors.END}\n") + + return failed == 0 + + def run_all_tests(self, start_server=True): + """Run all Phase 4 integration tests""" + print_section("PHASE 4: END-TO-END INTEGRATION TESTING") + + try: + # Start server if requested + if start_server: + if not self.start_api_server(): + print_error("Failed to start API server. Aborting tests.") + return False + else: + # Verify server is already running + try: + response = requests.get(f"{self.api_url}/health", timeout=5) + if response.status_code != 200: + print_error("Server is not responding. Please start it first.") + return False + print_info("Using existing API server") + except: + print_error("Cannot connect to API server. 
Please start it first.") + return False + + # Run tests + print_section("TEST 1: Single Job Submission and Completion") + self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow) + + print_section("TEST 2: Multiple Jobs (FIFO Order)") + self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo) + + print_section("TEST 3: GPU Health Check") + self.test("GPU health check endpoint", self.test_gpu_health_check) + + print_section("TEST 4: Error Handling - Invalid Path") + self.test("Invalid audio path rejection", self.test_invalid_audio_path) + + print_section("TEST 5: Error Handling - Job Not Found") + self.test("Non-existent job handling", self.test_job_not_found) + + print_section("TEST 6: Error Handling - Result Before Completion") + self.test("Result request before completion", self.test_result_before_completion) + + print_section("TEST 7: Job Listing") + self.test("List jobs with filters", self.test_list_jobs) + + print_section("TEST 8: Health Endpoint") + self.test("Basic health endpoint", self.test_health_endpoint) + + print_section("TEST 9: Models Endpoint") + self.test("Models information endpoint", self.test_models_endpoint) + + print_section("TEST 10: Server Restart Persistence") + self.test("Job persistence across server restart", self.test_server_restart_persistence) + + # Print summary + success = self.print_summary() + + return success + + finally: + # Cleanup + if start_server and self.server_process: + self.stop_api_server() + + +def main(): + """Main test runner""" + import argparse + + parser = argparse.ArgumentParser(description='Phase 4 Integration Tests') + parser.add_argument('--url', default='http://localhost:8000', help='API server URL') + # Default to None so Phase4Tester uses relative path + parser.add_argument('--audio', default=None, + help='Path to test audio file (default: /data/test.mp3)') + parser.add_argument('--no-start-server', action='store_true', + help='Do not start server (assume it is already running)') + args = parser.parse_args() + + tester = Phase4Tester(api_url=args.url, test_audio=args.audio) + success = tester.run_all_tests(start_server=not args.no_start_server) + + sys.exit(0 if success else 1) + + +if __name__ == '__main__': + main()
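
The server change at the top of this diff reduces each entry point to a single `startup_sequence()` / `cleanup_on_shutdown()` pair. A minimal sketch of what such an entry point looks like with the helpers from `src/utils/startup.py`, assuming `src/` is on `sys.path`; `run_server()` is a hypothetical placeholder for `mcp.run()` or the API server loop:

```python
#!/usr/bin/env python3
"""Minimal entry-point sketch using src/utils/startup.py."""
import sys
from pathlib import Path

# Make src/ importable, mirroring how the test scripts do it (assumption)
sys.path.insert(0, str(Path(__file__).parent / "src"))

from utils.startup import startup_sequence, cleanup_on_shutdown


def run_server() -> None:
    """Hypothetical placeholder for the real serving loop (e.g. mcp.run())."""
    pass


def main() -> None:
    # Common startup: GPU check (exits on failure), job queue, health monitor
    job_queue, health_monitor = startup_sequence(
        service_name="Example Whisper Server",
        require_gpu=True,
        initialize_queue=True,
        initialize_monitoring=True,
    )
    try:
        run_server()
    finally:
        # Stop the worker thread and monitor, letting an in-flight job finish
        cleanup_on_shutdown(
            job_queue=job_queue,
            health_monitor=health_monitor,
            wait_for_current_job=True,
        )


if __name__ == "__main__":
    main()
```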
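The validators in `src/utils/input_validation.py` are intended to run at request-submission time, before a job is queued. A rough sketch of that pattern, again assuming `src/` is importable; `ALLOWED_DIRS` and `check_request` are illustrative names, not part of the diff:

```python
"""Sketch: validating a transcription request before queueing it."""
from utils.input_validation import (
    validate_audio_file,
    validate_model_name,
    validate_output_format,
    sanitize_error_message,
    ValidationError,
    PathTraversalError,
)

# Illustrative allow-list; the real servers would read this from configuration
ALLOWED_DIRS = ["/media/raid/agents/tools/mcp-transcriptor/outputs"]


def check_request(audio_path: str, model: str, fmt: str) -> dict:
    """Return cleaned parameters, or an error dict safe to send to the client."""
    try:
        path = validate_audio_file(audio_path, allowed_dirs=ALLOWED_DIRS)
        return {
            "audio_path": str(path),
            "model_name": validate_model_name(model),
            "output_format": validate_output_format(fmt),
        }
    except PathTraversalError as e:
        # Traversal attempts are rejected outright
        return {"error": f"rejected: {e}"}
    except (ValidationError, FileNotFoundError, OSError) as e:
        # Never echo raw filesystem paths back to the client
        return {"error": sanitize_error_message(str(e))}


print(check_request("../etc/passwd", "large-v3", "txt"))
```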
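`CircuitBreaker.call()` wraps any callable and starts failing fast once the failure threshold is reached, while `get_stats()` exposes the breaker state for health endpoints. An illustrative sketch with a deliberately failing probe standing in for the real GPU health check (assumes `src/` is on `sys.path`):

```python
"""Sketch: fail-fast behaviour of src/utils/circuit_breaker.py."""
import time

from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen

breaker = CircuitBreaker(
    name="gpu_health",
    failure_threshold=3,   # open the circuit after 3 consecutive failures
    timeout_seconds=60,    # wait 60s before a half-open retry
)


def flaky_gpu_probe() -> str:
    """Stand-in for the real GPU check; always fails to show the transition."""
    raise RuntimeError("CUDA driver error")


for attempt in range(5):
    try:
        breaker.call(flaky_gpu_probe)
    except CircuitBreakerOpen as e:
        # Circuit is open: fail fast instead of re-running the expensive check
        print("skipped:", e)
    except RuntimeError as e:
        print("check failed:", e)
    time.sleep(1)

# state, failure_count, last_failure_time, and config thresholds
print(breaker.get_stats())
```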