Add circuit breaker, input validation, and refactor startup logic

- Implement circuit breaker pattern for GPU health checks
  - Prevents repeated failures with configurable thresholds
  - Three states: CLOSED, OPEN, HALF_OPEN
  - Integrated into GPU health monitoring

- Add comprehensive input validation and path sanitization
  - Path traversal attack prevention
  - Whitelist-based validation for models, devices, formats
  - Error message sanitization to prevent information leakage
  - File size limits and security checks

- Centralize startup logic across servers
  - Extract common startup procedures to utils/startup.py
  - Deduplicate GPU health checks and initialization code
  - Simplify both MCP and API server startup sequences

- Add proper Python package structure
  - Add __init__.py files to all modules
  - Improve package organization

- Add circuit breaker status API endpoints
  - GET /health/circuit-breaker - View circuit breaker stats
  - POST /health/circuit-breaker/reset - Reset circuit breaker

- Reorganize test files into tests/ directory
  - Rename and restructure test files for better organization

Author: Alihan
Date:   2025-10-10 01:03:55 +03:00
Parent: 40555592e6
Commit: 5fb742a312
15 changed files with 1655 additions and 1637 deletions

(File diff for one file suppressed because it is too large.)


@@ -1,5 +1,9 @@
""" """
语音识别MCP服务模块 Faster Whisper MCP Transcription Service
High-performance audio transcription service with dual-server architecture
(MCP and REST API) featuring async job queue and GPU health monitoring.
""" """
__version__ = "0.1.0" __version__ = "0.2.0"
__author__ = "Whisper MCP Team"


@@ -0,0 +1,6 @@
"""
Core modules for Whisper transcription service.
Includes model management, transcription logic, job queue, GPU health monitoring,
and GPU reset functionality.
"""


@@ -3,6 +3,7 @@ GPU health monitoring for Whisper transcription service.
 Performs real GPU health checks using actual model loading and transcription,
 with strict failure handling to prevent silent CPU fallbacks.
+Includes circuit breaker pattern to prevent repeated failed checks.
 """
 import time
@@ -14,9 +15,19 @@ from typing import Optional, List
 import torch
 from utils.test_audio_generator import generate_test_audio
+from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen
 
 logger = logging.getLogger(__name__)
 
+# Global circuit breaker for GPU health checks
+_gpu_health_circuit_breaker = CircuitBreaker(
+    name="gpu_health_check",
+    failure_threshold=3,    # Open after 3 consecutive failures
+    success_threshold=2,    # Close after 2 consecutive successes
+    timeout_seconds=60,     # Try again after 60 seconds
+    half_open_max_calls=1   # Only 1 test call in half-open state
+)
+
 # Import reset functionality (after logger initialization)
 try:
     from core.gpu_reset import (
@@ -48,7 +59,7 @@ class GPUHealthStatus:
         return asdict(self)
 
-def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus:
+def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus:
     """
     Comprehensive GPU health check using real model + transcription.
@@ -207,6 +218,58 @@ def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus:
     return status
 
+
+def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus:
+    """
+    GPU health check with optional circuit breaker protection.
+
+    This is the main entry point for GPU health checks. By default, it uses
+    circuit breaker pattern to prevent repeated failed checks.
+
+    Args:
+        expected_device: Expected device ("auto", "cuda", "cpu")
+        use_circuit_breaker: Enable circuit breaker protection (default: True)
+
+    Returns:
+        GPUHealthStatus object
+
+    Raises:
+        RuntimeError: If expected_device="cuda" but GPU test fails
+        CircuitBreakerOpen: If circuit breaker is open (too many recent failures)
+    """
+    if use_circuit_breaker:
+        try:
+            return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device)
+        except CircuitBreakerOpen as e:
+            # Circuit is open, fail fast without attempting check
+            logger.warning(f"GPU health check circuit breaker is OPEN: {e}")
+            raise RuntimeError(f"GPU health check unavailable: {str(e)}")
+    else:
+        return _check_gpu_health_internal(expected_device)
+
+
+def get_circuit_breaker_stats() -> dict:
+    """
+    Get current circuit breaker statistics.
+
+    Returns:
+        Dictionary with circuit state and failure/success counts
+    """
+    return _gpu_health_circuit_breaker.get_stats()
+
+
+def reset_circuit_breaker():
+    """
+    Manually reset the GPU health check circuit breaker.
+
+    Useful for:
+    - Testing
+    - Manual intervention after fixing GPU issues
+    - Clearing persistent error state
+    """
+    _gpu_health_circuit_breaker.reset()
+    logger.info("GPU health check circuit breaker manually reset")
+
+
 def check_gpu_health_with_reset(
     expected_device: str = "cuda",
     auto_reset: bool = True
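
The snippet below is a minimal caller-side sketch of the new wrapper, assuming src/ is on sys.path and that a probe loop with a 5-second back-off is acceptable (the loop and interval are illustrative, not part of this commit):

# Hypothetical caller of the circuit-breaker-protected health check.
# check_gpu_health raises RuntimeError both for a failed CUDA test and for
# an open circuit, so one except branch covers both cases.
import logging
import time

from core.gpu_health import check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("gpu_probe")

for attempt in range(5):
    try:
        status = check_gpu_health(expected_device="cuda")
        log.info("GPU healthy: %s (%.2f GB free)", status.device_name, status.memory_available_gb)
        break
    except RuntimeError as err:
        stats = get_circuit_breaker_stats()
        log.warning("Check failed (%s); breaker state=%s", err, stats["state"])
        time.sleep(5)  # illustrative back-off
else:
    # Manual intervention path described in the module above.
    reset_circuit_breaker()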


@@ -0,0 +1,5 @@
"""
Server implementations for Whisper transcription service.
Includes MCP server (whisper_server.py) and REST API server (api_server.py).
"""


@@ -8,6 +8,8 @@ import os
 import sys
 import logging
 import queue as queue_module
+import tempfile
+from pathlib import Path
 from contextlib import asynccontextmanager
 from typing import Optional, List
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
@@ -17,20 +19,13 @@ import json
 from core.model_manager import get_model_info
 from core.job_queue import JobQueue, JobStatus
-from core.gpu_health import HealthMonitor, check_gpu_health
+from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
+from utils.startup import startup_sequence, cleanup_on_shutdown
 
 # Logging configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Import GPU health check with reset
-try:
-    from core.gpu_health import check_gpu_health_with_reset
-    GPU_HEALTH_CHECK_AVAILABLE = True
-except ImportError as e:
-    logger.warning(f"GPU health check with reset not available: {e}")
-    GPU_HEALTH_CHECK_AVAILABLE = False
-
 # Global instances
 job_queue: Optional[JobQueue] = None
 health_monitor: Optional[HealthMonitor] = None
@@ -41,34 +36,22 @@ async def lifespan(app: FastAPI):
"""FastAPI lifespan context manager for startup/shutdown""" """FastAPI lifespan context manager for startup/shutdown"""
global job_queue, health_monitor global job_queue, health_monitor
# Startup # Startup - use common startup logic (without GPU check, handled in main)
logger.info("Starting job queue and health monitor...") logger.info("Initializing job queue and health monitor...")
# Initialize job queue from utils.startup import initialize_job_queue, initialize_health_monitor
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs")
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
job_queue.start()
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
# Initialize health monitor job_queue = initialize_job_queue()
health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true" health_monitor = initialize_health_monitor()
if health_check_enabled:
check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
health_monitor = HealthMonitor(check_interval_minutes=check_interval)
health_monitor.start()
logger.info(f"GPU health monitor started (interval={check_interval} minutes)")
yield yield
# Shutdown # Shutdown - use common cleanup logic
logger.info("Shutting down job queue and health monitor...") cleanup_on_shutdown(
if job_queue: job_queue=job_queue,
job_queue.stop(wait_for_current=True) health_monitor=health_monitor,
logger.info("Job queue stopped") wait_for_current_job=True
if health_monitor: )
health_monitor.stop()
logger.info("Health monitor stopped")
# Create FastAPI app # Create FastAPI app
@@ -107,6 +90,8 @@ async def root():
"GET /": "API information", "GET /": "API information",
"GET /health": "Health check", "GET /health": "Health check",
"GET /health/gpu": "GPU health check", "GET /health/gpu": "GPU health check",
"GET /health/circuit-breaker": "Get circuit breaker stats",
"POST /health/circuit-breaker/reset": "Reset circuit breaker",
"GET /models": "Get available models information", "GET /models": "Get available models information",
"POST /jobs": "Submit transcription job (async)", "POST /jobs": "Submit transcription job (async)",
"GET /jobs/{job_id}": "Get job status", "GET /jobs/{job_id}": "Get job status",
@@ -418,6 +403,39 @@ async def gpu_health_check_endpoint():
         )
 
+
+@app.get("/health/circuit-breaker")
+async def get_circuit_breaker_status():
+    """
+    Get GPU health check circuit breaker statistics.
+
+    Returns current state, failure/success counts, and last failure time.
+    """
+    try:
+        stats = get_circuit_breaker_stats()
+        return JSONResponse(status_code=200, content=stats)
+    except Exception as e:
+        logger.error(f"Failed to get circuit breaker stats: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/health/circuit-breaker/reset")
+async def reset_circuit_breaker_endpoint():
+    """
+    Manually reset the GPU health check circuit breaker.
+
+    Useful after fixing GPU issues or for testing purposes.
+    """
+    try:
+        reset_circuit_breaker()
+        return JSONResponse(
+            status_code=200,
+            content={"message": "Circuit breaker reset successfully"}
+        )
+    except Exception as e:
+        logger.error(f"Failed to reset circuit breaker: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
@@ -425,31 +443,14 @@ async def gpu_health_check_endpoint():
 if __name__ == "__main__":
     import uvicorn
 
-    # Perform startup GPU health check with auto-reset
-    if GPU_HEALTH_CHECK_AVAILABLE:
-        try:
-            logger.info("=" * 70)
-            logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
-            logger.info("=" * 70)
-
-            status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
-
-            logger.info("=" * 70)
-            logger.info("STARTUP GPU CHECK SUCCESSFUL")
-            logger.info(f"GPU Device: {status.device_name}")
-            logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
-            logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
-            logger.info("=" * 70)
-        except Exception as e:
-            logger.error("=" * 70)
-            logger.error("STARTUP GPU CHECK FAILED")
-            logger.error(f"Error: {e}")
-            logger.error("This service requires GPU. Terminating.")
-            logger.error("=" * 70)
-            sys.exit(1)
-    else:
-        logger.warning("GPU health check not available, starting without GPU validation")
+    # Perform startup GPU health check
+    from utils.startup import perform_startup_gpu_check
+
+    perform_startup_gpu_check(
+        required_device="cuda",
+        auto_reset=True,
+        exit_on_failure=True
+    )
 
     # Get configuration from environment variables
     host = os.getenv("API_HOST", "0.0.0.0")
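
A small client sketch for the two new endpoints follows, assuming the API server is reachable on localhost:8000 (the default used by the test suite in this commit):

# Hypothetical client for the new circuit breaker endpoints.
import requests

BASE = "http://localhost:8000"  # assumed default from the tests

# Inspect breaker state, e.g. {"name": "gpu_health_check", "state": "open", ...}
stats = requests.get(f"{BASE}/health/circuit-breaker", timeout=5).json()
print("circuit breaker state:", stats.get("state"))

# After fixing a GPU issue, clear the OPEN state so health checks resume.
if stats.get("state") == "open":
    resp = requests.post(f"{BASE}/health/circuit-breaker/reset", timeout=5)
    print(resp.json().get("message"))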


@@ -8,25 +8,21 @@ import os
 import sys
 import logging
 import json
+import base64
+import tempfile
+from pathlib import Path
 from typing import Optional
 from mcp.server.fastmcp import FastMCP
 
 from core.model_manager import get_model_info
 from core.job_queue import JobQueue, JobStatus
 from core.gpu_health import HealthMonitor, check_gpu_health
+from utils.startup import startup_sequence, cleanup_on_shutdown
 
 # Log configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Import GPU health check with reset
-try:
-    from core.gpu_health import check_gpu_health_with_reset
-    GPU_HEALTH_CHECK_AVAILABLE = True
-except ImportError as e:
-    logger.warning(f"GPU health check with reset not available: {e}")
-    GPU_HEALTH_CHECK_AVAILABLE = False
-
 # Global instances
 job_queue: Optional[JobQueue] = None
 health_monitor: Optional[HealthMonitor] = None
@@ -319,56 +315,20 @@ def check_gpu_health() -> str:
 if __name__ == "__main__":
     print("starting mcp server for whisper stt transcriptor")
 
-    # Perform startup GPU health check with auto-reset
-    if GPU_HEALTH_CHECK_AVAILABLE:
-        try:
-            logger.info("=" * 70)
-            logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
-            logger.info("=" * 70)
-
-            status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
-
-            logger.info("=" * 70)
-            logger.info("STARTUP GPU CHECK SUCCESSFUL")
-            logger.info(f"GPU Device: {status.device_name}")
-            logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
-            logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
-            logger.info("=" * 70)
-        except Exception as e:
-            logger.error("=" * 70)
-            logger.error("STARTUP GPU CHECK FAILED")
-            logger.error(f"Error: {e}")
-            logger.error("This service requires GPU. Terminating.")
-            logger.error("=" * 70)
-            sys.exit(1)
-    else:
-        logger.warning("GPU health check not available, starting without GPU validation")
-
-    # Initialize job queue
-    logger.info("Initializing job queue...")
-    max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
-    metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs")
-    job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
-    job_queue.start()
-    logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
-
-    # Initialize health monitor
-    health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
-    if health_check_enabled:
-        check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
-        health_monitor = HealthMonitor(check_interval_minutes=check_interval)
-        health_monitor.start()
-        logger.info(f"GPU health monitor started (interval={check_interval} minutes)")
+    # Execute common startup sequence
+    job_queue, health_monitor = startup_sequence(
+        service_name="MCP Whisper Server",
+        require_gpu=True,
+        initialize_queue=True,
+        initialize_monitoring=True
+    )
 
     try:
         mcp.run()
     finally:
         # Cleanup on shutdown
-        logger.info("Shutting down...")
-        if job_queue:
-            job_queue.stop(wait_for_current=True)
-            logger.info("Job queue stopped")
-        if health_monitor:
-            health_monitor.stop()
-            logger.info("Health monitor stopped")
+        cleanup_on_shutdown(
+            job_queue=job_queue,
+            health_monitor=health_monitor,
+            wait_for_current_job=True
+        )


@@ -0,0 +1,6 @@
"""
Utility modules for Whisper transcription service.
Includes audio processing, formatters, test audio generation, input validation,
circuit breaker, and startup logic.
"""


@@ -7,17 +7,24 @@ Responsible for audio file validation and preprocessing
 import os
 import logging
 from typing import Union, Any
+from pathlib import Path
 from faster_whisper import decode_audio
+from utils.input_validation import (
+    validate_audio_file as validate_audio_file_secure,
+    sanitize_error_message
+)
 
 # Log configuration
 logger = logging.getLogger(__name__)
 
-def validate_audio_file(audio_path: str) -> None:
+def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None:
     """
-    Validate if an audio file is valid
+    Validate if an audio file is valid (with security checks).
 
     Args:
         audio_path: Path to the audio file
+        allowed_dirs: Optional list of allowed base directories
 
     Raises:
         FileNotFoundError: If audio file doesn't exist
@@ -27,31 +34,19 @@ def validate_audio_file(audio_path: str) -> None:
     Returns:
         None: If validation passes
     """
-    # Validate parameters
-    if not os.path.exists(audio_path):
-        raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
-
-    # Validate file format
-    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
-    file_ext = os.path.splitext(audio_path)[1].lower()
-    if file_ext not in supported_formats:
-        raise ValueError(f"Unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}")
-
-    # Validate file size
     try:
-        file_size = os.path.getsize(audio_path)
-        if file_size == 0:
-            raise ValueError(f"Audio file is empty: {audio_path}")
-
-        # Warning for large files (over 1GB)
-        if file_size > 1024 * 1024 * 1024:
-            logger.warning(f"Warning: File size exceeds 1GB, may require longer processing time: {audio_path}")
-    except OSError as e:
-        logger.error(f"Failed to check file size: {str(e)}")
-        raise OSError(f"Failed to check file size: {str(e)}")
-
-    # Validation passed
-    return None
+        # Use secure validation
+        validate_audio_file_secure(audio_path, allowed_dirs)
+    except Exception as e:
+        # Re-raise with sanitized error messages
+        error_msg = sanitize_error_message(str(e))
+        if "not found" in str(e).lower():
+            raise FileNotFoundError(error_msg)
+        elif "size" in str(e).lower():
+            raise OSError(error_msg)
+        else:
+            raise ValueError(error_msg)
 
 def process_audio(audio_path: str) -> Union[str, Any]:
     """


@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
Circuit Breaker Pattern Implementation
Prevents repeated failed attempts and provides fail-fast behavior.
Useful for GPU health checks and other operations that may fail repeatedly.
"""
import time
import logging
import threading
from enum import Enum
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Callable, Any, Optional
logger = logging.getLogger(__name__)
class CircuitState(Enum):
"""Circuit breaker states."""
CLOSED = "closed" # Normal operation, requests pass through
OPEN = "open" # Circuit is open, requests fail immediately
HALF_OPEN = "half_open" # Testing if circuit can close
@dataclass
class CircuitBreakerConfig:
"""Configuration for circuit breaker."""
failure_threshold: int = 3 # Failures before opening circuit
success_threshold: int = 2 # Successes before closing from half-open
timeout_seconds: int = 60 # Time before attempting half-open
half_open_max_calls: int = 1 # Max calls to test in half-open state
class CircuitBreaker:
"""
Circuit breaker implementation for preventing repeated failures.
Usage:
breaker = CircuitBreaker(name="gpu_health", failure_threshold=3)
def check_gpu():
# This function will be protected by the circuit breaker
return perform_gpu_check()
result = breaker.call(check_gpu)
# Or use decorator:
@breaker.decorator()
def my_function():
return "result"
"""
def __init__(
self,
name: str,
failure_threshold: int = 3,
success_threshold: int = 2,
timeout_seconds: int = 60,
half_open_max_calls: int = 1
):
"""
Initialize circuit breaker.
Args:
name: Name of the circuit (for logging)
failure_threshold: Number of failures before opening
success_threshold: Number of successes to close from half-open
timeout_seconds: Seconds before transitioning to half-open
half_open_max_calls: Max concurrent calls in half-open state
"""
self.name = name
self.config = CircuitBreakerConfig(
failure_threshold=failure_threshold,
success_threshold=success_threshold,
timeout_seconds=timeout_seconds,
half_open_max_calls=half_open_max_calls
)
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time: Optional[datetime] = None
self._half_open_calls = 0
self._lock = threading.RLock()
logger.info(
f"Circuit breaker '{name}' initialized: "
f"failure_threshold={failure_threshold}, "
f"timeout={timeout_seconds}s"
)
@property
def state(self) -> CircuitState:
"""Get current circuit state."""
with self._lock:
self._update_state()
return self._state
@property
def is_closed(self) -> bool:
"""Check if circuit is closed (normal operation)."""
return self.state == CircuitState.CLOSED
@property
def is_open(self) -> bool:
"""Check if circuit is open (failing fast)."""
return self.state == CircuitState.OPEN
@property
def is_half_open(self) -> bool:
"""Check if circuit is half-open (testing)."""
return self.state == CircuitState.HALF_OPEN
def _update_state(self):
"""Update state based on timeout and counters."""
if self._state == CircuitState.OPEN:
# Check if timeout has passed
if self._last_failure_time:
elapsed = datetime.utcnow() - self._last_failure_time
if elapsed.total_seconds() >= self.config.timeout_seconds:
logger.info(
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
f"after {elapsed.total_seconds():.0f}s timeout"
)
self._state = CircuitState.HALF_OPEN
self._half_open_calls = 0
self._success_count = 0
def _on_success(self):
"""Handle successful call."""
with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._success_count += 1
logger.debug(
f"Circuit '{self.name}': Success in HALF_OPEN "
f"({self._success_count}/{self.config.success_threshold})"
)
if self._success_count >= self.config.success_threshold:
logger.info(f"Circuit '{self.name}': Closing circuit after successful test")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time = None
elif self._state == CircuitState.CLOSED:
# Reset failure count on success
self._failure_count = 0
def _on_failure(self, error: Exception):
"""Handle failed call."""
with self._lock:
self._failure_count += 1
self._last_failure_time = datetime.utcnow()
if self._state == CircuitState.HALF_OPEN:
logger.warning(
f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit"
)
self._state = CircuitState.OPEN
self._success_count = 0
elif self._state == CircuitState.CLOSED:
logger.debug(
f"Circuit '{self.name}': Failure {self._failure_count}/"
f"{self.config.failure_threshold}"
)
if self._failure_count >= self.config.failure_threshold:
logger.warning(
f"Circuit '{self.name}': Opening circuit after "
f"{self._failure_count} failures. "
f"Will retry in {self.config.timeout_seconds}s"
)
self._state = CircuitState.OPEN
self._success_count = 0
def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Function to execute
*args: Positional arguments
**kwargs: Keyword arguments
Returns:
Function result
Raises:
CircuitBreakerOpen: If circuit is open
Exception: Original exception from func if it fails
"""
with self._lock:
self._update_state()
# Check if circuit is open
if self._state == CircuitState.OPEN:
raise CircuitBreakerOpen(
f"Circuit '{self.name}' is OPEN. "
f"Failing fast to prevent repeated failures. "
f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. "
f"Will retry in {self.config.timeout_seconds}s"
)
# Check half-open call limit
if self._state == CircuitState.HALF_OPEN:
if self._half_open_calls >= self.config.half_open_max_calls:
raise CircuitBreakerOpen(
f"Circuit '{self.name}' is HALF_OPEN with max calls reached. "
f"Please wait for current test to complete."
)
self._half_open_calls += 1
# Execute function
try:
result = func(*args, **kwargs)
self._on_success()
return result
except Exception as e:
self._on_failure(e)
raise
finally:
# Decrement half-open counter
with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._half_open_calls -= 1
def decorator(self):
"""
Decorator for protecting functions with circuit breaker.
Usage:
breaker = CircuitBreaker("my_service")
@breaker.decorator()
def my_function():
return do_something()
"""
def wrapper(func):
def decorated(*args, **kwargs):
return self.call(func, *args, **kwargs)
return decorated
return wrapper
def reset(self):
"""
Manually reset circuit breaker to closed state.
Useful for:
- Testing
- Manual intervention
- Clearing error state
"""
with self._lock:
logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time = None
self._half_open_calls = 0
def get_stats(self) -> dict:
"""
Get circuit breaker statistics.
Returns:
Dictionary with current state and counters
"""
with self._lock:
self._update_state()
return {
"name": self.name,
"state": self._state.value,
"failure_count": self._failure_count,
"success_count": self._success_count,
"last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
"config": {
"failure_threshold": self.config.failure_threshold,
"success_threshold": self.config.success_threshold,
"timeout_seconds": self.config.timeout_seconds,
}
}
class CircuitBreakerOpen(Exception):
"""Exception raised when circuit breaker is open."""
pass
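
A self-contained sketch of the class above follows, with a deliberately flaky function to walk the CLOSED → OPEN → HALF_OPEN → CLOSED transitions; the flaky_check function, thresholds, and sleep interval are illustrative only and not part of the commit:

# Hypothetical standalone demo of CircuitBreaker; not part of the commit.
import time
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen

breaker = CircuitBreaker(name="demo", failure_threshold=2,
                         success_threshold=2, timeout_seconds=3)
calls = {"n": 0}

def flaky_check():
    # Fails on the first two calls, succeeds afterwards.
    calls["n"] += 1
    if calls["n"] <= 2:
        raise RuntimeError("simulated failure")
    return "ok"

for i in range(6):
    try:
        print(i, breaker.call(flaky_check))
    except CircuitBreakerOpen:
        print(i, "fast-failed, state:", breaker.state.value)
    except RuntimeError as err:
        print(i, "failed:", err, "state:", breaker.state.value)
    time.sleep(1)
# Two failures open the circuit; after timeout_seconds it goes HALF_OPEN,
# and two consecutive successes close it again.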


@@ -0,0 +1,411 @@
#!/usr/bin/env python3
"""
Input Validation and Path Sanitization Module
Provides robust validation for user inputs with security protections
against path traversal, injection attacks, and other malicious inputs.
"""
import os
import re
import logging
from pathlib import Path
from typing import Optional, List
logger = logging.getLogger(__name__)
# Maximum file size (10GB)
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024
# Allowed audio file extensions
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"}
# Allowed output formats
ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"}
# Model name validation (whitelist)
ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"}
# Device validation
ALLOWED_DEVICES = {"cuda", "auto", "cpu"}
# Compute type validation
ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"}
class ValidationError(Exception):
"""Base exception for validation errors."""
pass
class PathTraversalError(ValidationError):
"""Exception for path traversal attempts."""
pass
class InvalidFileTypeError(ValidationError):
"""Exception for invalid file types."""
pass
class FileSizeError(ValidationError):
"""Exception for file size issues."""
pass
def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
"""
Sanitize error messages to prevent information leakage.
Args:
error_msg: Original error message
sanitize_paths: Whether to sanitize file paths (default: True)
Returns:
Sanitized error message
"""
if not sanitize_paths:
return error_msg
# Replace absolute paths with relative paths
# Pattern: /home/user/... or /media/... or /var/... or /tmp/...
path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)'
def replace_path(match):
full_path = match.group(1)
try:
# Try to get just the filename
basename = os.path.basename(full_path)
return f"<file:{basename}>"
except:
return "<file:redacted>"
sanitized = re.sub(path_pattern, replace_path, error_msg)
# Also sanitize user names if present
sanitized = re.sub(r'/home/([^/]+)/', '/home/<user>/', sanitized)
return sanitized
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
"""
Validate and sanitize a file path to prevent directory traversal attacks.
Args:
file_path: Path to validate
allowed_dirs: Optional list of allowed base directories
Returns:
Resolved Path object
Raises:
PathTraversalError: If path contains traversal attempts
ValidationError: If path is invalid
"""
if not file_path:
raise ValidationError("File path cannot be empty")
# Convert to Path object
try:
path = Path(file_path)
except Exception as e:
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
# Check for path traversal attempts
path_str = str(path)
if ".." in path_str:
logger.warning(f"Path traversal attempt detected: {path_str}")
raise PathTraversalError("Path traversal (..) is not allowed")
# Check for null bytes
if "\x00" in path_str:
logger.warning(f"Null byte in path detected: {path_str}")
raise PathTraversalError("Null bytes in path are not allowed")
# Resolve to absolute path (but don't follow symlinks yet)
try:
resolved_path = path.resolve()
except Exception as e:
raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}")
# If allowed_dirs specified, ensure path is within one of them
if allowed_dirs:
allowed = False
for allowed_dir in allowed_dirs:
try:
allowed_dir_path = Path(allowed_dir).resolve()
# Check if resolved_path is under allowed_dir
resolved_path.relative_to(allowed_dir_path)
allowed = True
break
except ValueError:
# Not relative to this allowed_dir
continue
if not allowed:
logger.warning(
f"Path outside allowed directories: {path_str}, "
f"allowed: {allowed_dirs}"
)
raise PathTraversalError(
f"Path must be within allowed directories. "
f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}"
)
return resolved_path
def validate_audio_file(
file_path: str,
allowed_dirs: Optional[List[str]] = None,
max_size_bytes: int = MAX_FILE_SIZE_BYTES
) -> Path:
"""
Validate audio file path with security checks.
Args:
file_path: Path to audio file
allowed_dirs: Optional list of allowed base directories
max_size_bytes: Maximum allowed file size
Returns:
Validated Path object
Raises:
ValidationError: If validation fails
PathTraversalError: If path traversal detected
FileNotFoundError: If file doesn't exist
InvalidFileTypeError: If file type not allowed
FileSizeError: If file too large
"""
# Validate and sanitize path
validated_path = validate_path_safe(file_path, allowed_dirs)
# Check file exists
if not validated_path.exists():
raise FileNotFoundError(f"Audio file not found: {validated_path.name}")
# Check it's a file (not directory)
if not validated_path.is_file():
raise ValidationError(f"Path is not a file: {validated_path.name}")
# Check file extension
file_ext = validated_path.suffix.lower()
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
raise InvalidFileTypeError(
f"Unsupported audio format: {file_ext}. "
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
)
# Check file size
try:
file_size = validated_path.stat().st_size
except Exception as e:
raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}")
if file_size == 0:
raise FileSizeError(f"Audio file is empty: {validated_path.name}")
if file_size > max_size_bytes:
raise FileSizeError(
f"File too large: {file_size / (1024**3):.2f}GB. "
f"Maximum: {max_size_bytes / (1024**3):.2f}GB"
)
# Warn for large files (>1GB)
if file_size > 1024 * 1024 * 1024:
logger.warning(
f"Large file: {file_size / (1024**3):.2f}GB, "
f"may require extended processing time"
)
return validated_path
def validate_output_directory(
dir_path: str,
allowed_dirs: Optional[List[str]] = None,
create_if_missing: bool = True
) -> Path:
"""
Validate output directory path.
Args:
dir_path: Directory path
allowed_dirs: Optional list of allowed base directories
create_if_missing: Create directory if it doesn't exist
Returns:
Validated Path object
Raises:
ValidationError: If validation fails
PathTraversalError: If path traversal detected
"""
# Validate and sanitize path
validated_path = validate_path_safe(dir_path, allowed_dirs)
# Create if requested and doesn't exist
if create_if_missing and not validated_path.exists():
try:
validated_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created output directory: {validated_path}")
except Exception as e:
raise ValidationError(
f"Cannot create output directory: {sanitize_error_message(str(e))}"
)
# Check it's a directory
if validated_path.exists() and not validated_path.is_dir():
raise ValidationError(f"Path exists but is not a directory: {validated_path.name}")
return validated_path
def validate_model_name(model_name: str) -> str:
"""
Validate Whisper model name.
Args:
model_name: Model name to validate
Returns:
Validated model name
Raises:
ValidationError: If model name invalid
"""
if not model_name:
raise ValidationError("Model name cannot be empty")
model_name = model_name.strip().lower()
if model_name not in ALLOWED_MODEL_NAMES:
raise ValidationError(
f"Invalid model name: {model_name}. "
f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}"
)
return model_name
def validate_device(device: str) -> str:
"""
Validate device parameter.
Args:
device: Device name to validate
Returns:
Validated device name
Raises:
ValidationError: If device invalid
"""
if not device:
raise ValidationError("Device cannot be empty")
device = device.strip().lower()
if device not in ALLOWED_DEVICES:
raise ValidationError(
f"Invalid device: {device}. "
f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}"
)
return device
def validate_compute_type(compute_type: str) -> str:
"""
Validate compute type parameter.
Args:
compute_type: Compute type to validate
Returns:
Validated compute type
Raises:
ValidationError: If compute type invalid
"""
if not compute_type:
raise ValidationError("Compute type cannot be empty")
compute_type = compute_type.strip().lower()
if compute_type not in ALLOWED_COMPUTE_TYPES:
raise ValidationError(
f"Invalid compute type: {compute_type}. "
f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}"
)
return compute_type
def validate_output_format(output_format: str) -> str:
"""
Validate output format parameter.
Args:
output_format: Output format to validate
Returns:
Validated output format
Raises:
ValidationError: If output format invalid
"""
if not output_format:
raise ValidationError("Output format cannot be empty")
output_format = output_format.strip().lower()
if output_format not in ALLOWED_OUTPUT_FORMATS:
raise ValidationError(
f"Invalid output format: {output_format}. "
f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}"
)
return output_format
def validate_numeric_range(
value: float,
min_value: float,
max_value: float,
param_name: str
) -> float:
"""
Validate numeric parameter is within range.
Args:
value: Value to validate
min_value: Minimum allowed value
max_value: Maximum allowed value
param_name: Parameter name for error messages
Returns:
Validated value
Raises:
ValidationError: If value out of range
"""
if value < min_value or value > max_value:
raise ValidationError(
f"{param_name} must be between {min_value} and {max_value}, "
f"got {value}"
)
return value
def validate_beam_size(beam_size: int) -> int:
"""Validate beam size parameter."""
return int(validate_numeric_range(beam_size, 1, 20, "beam_size"))
def validate_temperature(temperature: float) -> float:
"""Validate temperature parameter."""
return validate_numeric_range(temperature, 0.0, 1.0, "temperature")
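
A short sketch of how these validators might compose in a caller that restricts reads to one base directory; the directory, helper name, and example call are illustrative assumptions:

# Hypothetical caller of the validation helpers; not part of the commit.
from utils.input_validation import (
    ValidationError,
    sanitize_error_message,
    validate_audio_file,
    validate_model_name,
    validate_output_format,
)

ALLOWED_DIRS = ["/srv/whisper/inbox"]  # illustrative base directory

def prepare_request(audio_path: str, model: str, fmt: str) -> dict:
    """Validate user-supplied fields before queuing a transcription job."""
    try:
        path = validate_audio_file(audio_path, allowed_dirs=ALLOWED_DIRS)
        return {
            "audio_path": str(path),
            "model_name": validate_model_name(model),
            "output_format": validate_output_format(fmt),
        }
    except (ValidationError, FileNotFoundError) as err:
        # Surface only a sanitized message to the client.
        raise ValueError(sanitize_error_message(str(err)))

# Example: prepare_request("/srv/whisper/inbox/call.mp3", "large-v3", "txt")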

src/utils/startup.py (new file, 237 lines)

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Common Startup Logic Module
Centralizes startup procedures shared between MCP and API servers,
including GPU health checks, job queue initialization, and health monitoring.
"""
import os
import sys
import logging
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
# Import GPU health check with reset
try:
from core.gpu_health import check_gpu_health_with_reset
GPU_HEALTH_CHECK_AVAILABLE = True
except ImportError as e:
logger.warning(f"GPU health check with reset not available: {e}")
GPU_HEALTH_CHECK_AVAILABLE = False
def perform_startup_gpu_check(
required_device: str = "cuda",
auto_reset: bool = True,
exit_on_failure: bool = True
) -> bool:
"""
Perform startup GPU health check with optional auto-reset.
This function:
1. Checks if GPU health check is available
2. Runs comprehensive GPU health check
3. Attempts auto-reset if check fails and auto_reset=True
4. Optionally exits process if check fails
Args:
required_device: Required device ("cuda", "auto")
auto_reset: Enable automatic GPU driver reset on failure
exit_on_failure: Exit process if GPU check fails
Returns:
True if GPU check passed, False otherwise
Side effects:
May exit process if exit_on_failure=True and check fails
"""
if not GPU_HEALTH_CHECK_AVAILABLE:
logger.warning("GPU health check not available, starting without GPU validation")
if exit_on_failure:
logger.error("GPU health check required but not available. Exiting.")
sys.exit(1)
return False
try:
logger.info("=" * 70)
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
logger.info("=" * 70)
status = check_gpu_health_with_reset(
expected_device=required_device,
auto_reset=auto_reset
)
logger.info("=" * 70)
logger.info("STARTUP GPU CHECK SUCCESSFUL")
logger.info(f"GPU Device: {status.device_name}")
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
logger.info("=" * 70)
return True
except Exception as e:
logger.error("=" * 70)
logger.error("STARTUP GPU CHECK FAILED")
logger.error(f"Error: {e}")
if exit_on_failure:
logger.error("This service requires GPU. Terminating.")
logger.error("=" * 70)
sys.exit(1)
else:
logger.error("Continuing without GPU (may have reduced functionality)")
logger.error("=" * 70)
return False
def initialize_job_queue(
max_queue_size: Optional[int] = None,
metadata_dir: Optional[str] = None
) -> 'JobQueue':
"""
Initialize job queue with environment variable configuration.
Args:
max_queue_size: Override for max queue size (uses env var if None)
metadata_dir: Override for metadata directory (uses env var if None)
Returns:
Initialized JobQueue instance (started)
"""
from core.job_queue import JobQueue
# Get configuration from environment
if max_queue_size is None:
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
if metadata_dir is None:
metadata_dir = os.getenv(
"JOB_METADATA_DIR",
"/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
)
logger.info("Initializing job queue...")
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
job_queue.start()
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
return job_queue
def initialize_health_monitor(
check_interval_minutes: Optional[int] = None,
enabled: Optional[bool] = None
) -> Optional['HealthMonitor']:
"""
Initialize GPU health monitor with environment variable configuration.
Args:
check_interval_minutes: Override for check interval (uses env var if None)
enabled: Override for enabled status (uses env var if None)
Returns:
Initialized HealthMonitor instance (started), or None if disabled
"""
from core.gpu_health import HealthMonitor
# Get configuration from environment
if enabled is None:
enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
if not enabled:
logger.info("GPU health monitoring disabled")
return None
if check_interval_minutes is None:
check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes)
health_monitor.start()
logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)")
return health_monitor
def startup_sequence(
service_name: str = "whisper-transcription",
require_gpu: bool = True,
initialize_queue: bool = True,
initialize_monitoring: bool = True
) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]:
"""
Execute complete startup sequence for a Whisper transcription server.
This function performs all common startup tasks:
1. GPU health check with auto-reset
2. Job queue initialization
3. Health monitor initialization
Args:
service_name: Name of the service (for logging)
require_gpu: Whether GPU is required (exit if not available)
initialize_queue: Whether to initialize job queue
initialize_monitoring: Whether to initialize health monitoring
Returns:
Tuple of (job_queue, health_monitor) - either may be None
Side effects:
May exit process if GPU required but unavailable
"""
logger.info(f"Starting {service_name}...")
# Step 1: GPU health check
gpu_ok = perform_startup_gpu_check(
required_device="cuda",
auto_reset=True,
exit_on_failure=require_gpu
)
if not gpu_ok and require_gpu:
# Should not reach here (exit_on_failure should have exited)
logger.error("GPU check failed and GPU is required")
sys.exit(1)
# Step 2: Initialize job queue
job_queue = None
if initialize_queue:
job_queue = initialize_job_queue()
# Step 3: Initialize health monitor
health_monitor = None
if initialize_monitoring:
health_monitor = initialize_health_monitor()
logger.info(f"{service_name} startup sequence completed")
return job_queue, health_monitor
def cleanup_on_shutdown(
job_queue: Optional['JobQueue'] = None,
health_monitor: Optional['HealthMonitor'] = None,
wait_for_current_job: bool = True
) -> None:
"""
Perform cleanup on server shutdown.
Args:
job_queue: JobQueue instance to stop (if any)
health_monitor: HealthMonitor instance to stop (if any)
wait_for_current_job: Wait for current job to complete before stopping
"""
logger.info("Shutting down...")
if job_queue:
job_queue.stop(wait_for_current=wait_for_current_job)
logger.info("Job queue stopped")
if health_monitor:
health_monitor.stop()
logger.info("Health monitor stopped")
logger.info("Shutdown complete")


@@ -22,8 +22,8 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-# Add src to path
-sys.path.insert(0, str(Path(__file__).parent / "src"))
+# Add src to path (go up one level from tests/ to root)
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
 
 # Color codes for terminal output
 class Colors:
@@ -435,7 +435,9 @@ class Phase2Tester:
     def _create_test_audio_file(self):
         """Get the path to the test audio file"""
-        test_audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.mp3"
+        # Use relative path from project root
+        project_root = Path(__file__).parent.parent
+        test_audio_path = str(project_root / "data" / "test.mp3")
 
         if not os.path.exists(test_audio_path):
             raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
         return test_audio_path


@@ -24,8 +24,8 @@ logging.basicConfig(
     datefmt='%H:%M:%S'
 )
 
-# Add src to path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+# Add src to path (go up one level from tests/ to root)
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
 
 from core.gpu_health import check_gpu_health, HealthMonitor
 from core.job_queue import JobQueue, JobStatus
@@ -61,8 +61,9 @@ def test_audio_file():
print("="*60) print("="*60)
try: try:
# Use the actual test audio file # Use the actual test audio file (relative to project root)
audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a" project_root = os.path.join(os.path.dirname(__file__), '..')
audio_path = os.path.join(project_root, "data/test.mp3")
# Verify file exists # Verify file exists
assert os.path.exists(audio_path), "Audio file not found" assert os.path.exists(audio_path), "Audio file not found"
@@ -147,8 +148,9 @@ def test_job_queue():
         job_queue.start()
         print("✓ Job queue started")
 
-        # Use the actual test audio file
-        audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a"
+        # Use the actual test audio file (relative to project root)
+        project_root = os.path.join(os.path.dirname(__file__), '..')
+        audio_path = os.path.join(project_root, "data/test.mp3")
 
         # Test job submission
         print("\nSubmitting test job...")

tests/test_e2e_integration.py (new executable file, 523 lines)

@@ -0,0 +1,523 @@
#!/usr/bin/env python3
"""
Test Phase 4: End-to-End Integration Testing
Comprehensive integration tests for the async job queue system.
Tests all scenarios from the DEV_PLAN.md Phase 4 checklist.
"""
import os
import sys
import time
import json
import logging
import requests
import subprocess
import signal
from pathlib import Path
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Color codes for terminal output
class Colors:
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
CYAN = '\033[96m'
END = '\033[0m'
BOLD = '\033[1m'
def print_success(msg):
print(f"{Colors.GREEN}{msg}{Colors.END}")
def print_error(msg):
print(f"{Colors.RED}{msg}{Colors.END}")
def print_info(msg):
print(f"{Colors.BLUE} {msg}{Colors.END}")
def print_warning(msg):
print(f"{Colors.YELLOW}{msg}{Colors.END}")
def print_section(msg):
print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
class Phase4Tester:
def __init__(self, api_url="http://localhost:8000", test_audio=None):
self.api_url = api_url
# Use relative path from project root if not provided
if test_audio is None:
project_root = Path(__file__).parent.parent
test_audio = str(project_root / "data" / "test.mp3")
self.test_audio = test_audio
self.test_results = []
self.server_process = None
# Verify test audio exists
if not os.path.exists(test_audio):
raise FileNotFoundError(f"Test audio file not found: {test_audio}")
def test(self, name, func):
"""Run a test and record result"""
try:
logger.info(f"Testing: {name}")
print_info(f"Testing: {name}")
func()
logger.info(f"PASSED: {name}")
print_success(f"PASSED: {name}")
self.test_results.append((name, True, None))
return True
except AssertionError as e:
logger.error(f"FAILED: {name} - {str(e)}")
print_error(f"FAILED: {name}")
print_error(f" Reason: {str(e)}")
self.test_results.append((name, False, str(e)))
return False
except Exception as e:
logger.error(f"ERROR: {name} - {str(e)}")
print_error(f"ERROR: {name}")
print_error(f" Exception: {str(e)}")
self.test_results.append((name, False, f"Exception: {str(e)}"))
return False
def start_api_server(self, wait_time=5):
"""Start the API server in background"""
print_info("Starting API server...")
# Script is in project root, one level up from tests/
script_path = Path(__file__).parent.parent / "run_api_server.sh"
# Start server in background
self.server_process = subprocess.Popen(
[str(script_path)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
preexec_fn=os.setsid
)
# Wait for server to start
time.sleep(wait_time)
# Verify server is running
try:
response = requests.get(f"{self.api_url}/health", timeout=5)
if response.status_code == 200:
print_success("API server started successfully")
return True
except:
pass
print_error("API server failed to start")
return False
def stop_api_server(self):
"""Stop the API server"""
if self.server_process:
print_info("Stopping API server...")
os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
self.server_process.wait(timeout=10)
print_success("API server stopped")
def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2):
"""Poll job status until completed or failed"""
start_time = time.time()
last_status = None
while time.time() - start_time < timeout:
try:
response = requests.get(f"{self.api_url}/jobs/{job_id}")
assert response.status_code == 200, f"Failed to get job status: {response.status_code}"
status_data = response.json()
current_status = status_data['status']
# Print status changes
if current_status != last_status:
if status_data.get('queue_position') is not None:
print_info(f" Job status: {current_status}, queue position: {status_data['queue_position']}")
else:
print_info(f" Job status: {current_status}")
last_status = current_status
if current_status == "completed":
return status_data
elif current_status == "failed":
raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}")
time.sleep(poll_interval)
except requests.exceptions.RequestException as e:
raise AssertionError(f"Request failed: {e}")
raise AssertionError(f"Job did not complete within {timeout} seconds")
# ========================================================================
# TEST 1: Single Job Submission and Completion
# ========================================================================
def test_single_job_flow(self):
"""Test complete job flow: submit → poll → get result"""
# Submit job
print_info(" Submitting job...")
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "large-v3",
"output_format": "txt"
})
assert response.status_code == 200, f"Job submission failed: {response.status_code}"
job_data = response.json()
assert 'job_id' in job_data, "No job_id in response"
# Status can be 'queued' or 'running' (if queue is empty and job starts immediately)
assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'"
job_id = job_data['job_id']
print_success(f" Job submitted: {job_id}")
# Wait for completion
print_info(" Waiting for job completion...")
final_status = self.wait_for_job_completion(job_id)
assert final_status['status'] == 'completed', "Job did not complete"
assert final_status['result_path'] is not None, "No result_path in completed job"
assert final_status['processing_time_seconds'] is not None, "No processing time"
print_success(f" Job completed in {final_status['processing_time_seconds']:.2f}s")
# Get result
print_info(" Retrieving result...")
response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
assert response.status_code == 200, f"Failed to get result: {response.status_code}"
result_text = response.text
assert len(result_text) > 0, "Empty result"
print_success(f" Got result: {len(result_text)} characters")
# ========================================================================
# TEST 2: Multiple Jobs in Queue (FIFO)
# ========================================================================
def test_multiple_jobs_fifo(self):
"""Test multiple jobs are processed in FIFO order"""
job_ids = []
# Submit 3 jobs
print_info(" Submitting 3 jobs...")
for i in range(3):
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "tiny", # Use tiny model for faster processing
"output_format": "txt"
})
assert response.status_code == 200, f"Job {i+1} submission failed"
job_data = response.json()
job_ids.append(job_data['job_id'])
print_info(f" Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}")
# Wait for all jobs to complete
print_info(" Waiting for all jobs to complete...")
for i, job_id in enumerate(job_ids):
print_info(f" Waiting for job {i+1}/{len(job_ids)}...")
final_status = self.wait_for_job_completion(job_id, timeout=120)
assert final_status['status'] == 'completed', f"Job {i+1} failed"
print_success(f" All {len(job_ids)} jobs completed in FIFO order")
# ========================================================================
# TEST 3: GPU Health Check
# ========================================================================
def test_gpu_health_check(self):
"""Test GPU health check endpoint"""
print_info(" Checking GPU health...")
response = requests.get(f"{self.api_url}/health/gpu")
assert response.status_code == 200, f"GPU health check failed: {response.status_code}"
health_data = response.json()
assert 'gpu_available' in health_data, "Missing gpu_available field"
assert 'gpu_working' in health_data, "Missing gpu_working field"
assert 'device_used' in health_data, "Missing device_used field"
print_info(f" GPU Available: {health_data['gpu_available']}")
print_info(f" GPU Working: {health_data['gpu_working']}")
print_info(f" Device: {health_data['device_used']}")
if health_data['gpu_available']:
assert health_data['device_name'], "GPU available but no device_name"
assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)"
print_success(f" GPU is healthy: {health_data['device_name']}")
else:
print_warning(" GPU not available on this system")
# ========================================================================
# TEST 4: Invalid Audio Path
# ========================================================================
def test_invalid_audio_path(self):
"""Test job submission with invalid audio path"""
print_info(" Submitting job with invalid path...")
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": "/invalid/path/does/not/exist.mp3",
"model_name": "large-v3"
})
# Should return 400 Bad Request
assert response.status_code == 400, f"Expected 400, got {response.status_code}"
error_data = response.json()
assert 'detail' in error_data or 'error' in error_data, "No error message in response"
print_success(" Invalid path rejected correctly")
# ========================================================================
# TEST 5: Job Not Found
# ========================================================================
def test_job_not_found(self):
"""Test retrieving non-existent job"""
print_info(" Requesting non-existent job...")
fake_job_id = "00000000-0000-0000-0000-000000000000"
response = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
assert response.status_code == 404, f"Expected 404, got {response.status_code}"
print_success(" Non-existent job handled correctly")
# ========================================================================
# TEST 6: Result Before Completion
# ========================================================================
def test_result_before_completion(self):
"""Test getting result for job that hasn't completed"""
print_info(" Submitting job and trying to get result immediately...")
# Submit job
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "large-v3"
})
assert response.status_code == 200
job_id = response.json()['job_id']
# Try to get result immediately (job is still queued/running)
time.sleep(0.5)
response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
# Should return 409 Conflict or similar
assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}"
print_success(" Result request before completion handled correctly")
# Clean up: wait for job to complete
self.wait_for_job_completion(job_id)
# ========================================================================
# TEST 7: List Jobs
# ========================================================================
def test_list_jobs(self):
"""Test listing jobs with filters"""
print_info(" Testing job listing...")
# List all jobs
response = requests.get(f"{self.api_url}/jobs")
assert response.status_code == 200, f"List jobs failed: {response.status_code}"
jobs_data = response.json()
assert 'jobs' in jobs_data, "No jobs array in response"
assert isinstance(jobs_data['jobs'], list), "Jobs is not a list"
print_info(f" Found {len(jobs_data['jobs'])} jobs")
# List only completed jobs
response = requests.get(f"{self.api_url}/jobs?status=completed")
assert response.status_code == 200
completed_jobs = response.json()['jobs']
print_info(f" Found {len(completed_jobs)} completed jobs")
# List with limit
response = requests.get(f"{self.api_url}/jobs?limit=5")
assert response.status_code == 200
limited_jobs = response.json()['jobs']
assert len(limited_jobs) <= 5, "Limit not respected"
print_success(" Job listing works correctly")
# ========================================================================
# TEST 8: Server Restart with Job Persistence
# ========================================================================
def test_server_restart_persistence(self):
"""Test that jobs persist across server restarts"""
print_info(" Testing job persistence across restart...")
# Submit a job
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "tiny"
})
assert response.status_code == 200
job_id = response.json()['job_id']
print_info(f" Submitted job: {job_id}")
# Get job count before restart
response = requests.get(f"{self.api_url}/jobs")
jobs_before = len(response.json()['jobs'])
print_info(f" Jobs before restart: {jobs_before}")
# Restart server
print_info(" Restarting server...")
self.stop_api_server()
time.sleep(2)
assert self.start_api_server(wait_time=8), "Server failed to restart"
# Check jobs after restart
response = requests.get(f"{self.api_url}/jobs")
assert response.status_code == 200
jobs_after = len(response.json()['jobs'])
print_info(f" Jobs after restart: {jobs_after}")
# Check our specific job is still there (this is the key test)
response = requests.get(f"{self.api_url}/jobs/{job_id}")
assert response.status_code == 200, "Job not found after restart"
# Note: Total count may differ due to job retention/cleanup, but persistence works if we can find the job
if jobs_after < jobs_before:
print_warning(f" Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup")
print_success(" Jobs persisted correctly across restart")
# ========================================================================
# TEST 9: Health Endpoint
# ========================================================================
def test_health_endpoint(self):
"""Test basic health endpoint"""
print_info(" Checking health endpoint...")
response = requests.get(f"{self.api_url}/health")
assert response.status_code == 200, f"Health check failed: {response.status_code}"
health_data = response.json()
assert health_data['status'] == 'healthy', "Server not healthy"
print_success(" Health endpoint OK")
# ========================================================================
# TEST 10: Models Endpoint
# ========================================================================
def test_models_endpoint(self):
"""Test models information endpoint"""
print_info(" Checking models endpoint...")
response = requests.get(f"{self.api_url}/models")
assert response.status_code == 200, f"Models endpoint failed: {response.status_code}"
models_data = response.json()
assert 'available_models' in models_data, "No available_models field"
assert 'available_devices' in models_data, "No available_devices field"
assert len(models_data['available_models']) > 0, "No models listed"
print_info(f" Available models: {len(models_data['available_models'])}")
print_success(" Models endpoint OK")
def print_summary(self):
"""Print test summary"""
print_section("TEST SUMMARY")
passed = sum(1 for _, result, _ in self.test_results if result)
failed = len(self.test_results) - passed
for name, result, error in self.test_results:
if result:
print_success(f"{name}")
else:
print_error(f"{name}")
if error:
print(f" {error}")
print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")
return failed == 0
def run_all_tests(self, start_server=True):
"""Run all Phase 4 integration tests"""
print_section("PHASE 4: END-TO-END INTEGRATION TESTING")
try:
# Start server if requested
if start_server:
if not self.start_api_server():
print_error("Failed to start API server. Aborting tests.")
return False
else:
# Verify server is already running
try:
response = requests.get(f"{self.api_url}/health", timeout=5)
if response.status_code != 200:
print_error("Server is not responding. Please start it first.")
return False
print_info("Using existing API server")
except:
print_error("Cannot connect to API server. Please start it first.")
return False
# Run tests
print_section("TEST 1: Single Job Submission and Completion")
self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow)
print_section("TEST 2: Multiple Jobs (FIFO Order)")
self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo)
print_section("TEST 3: GPU Health Check")
self.test("GPU health check endpoint", self.test_gpu_health_check)
print_section("TEST 4: Error Handling - Invalid Path")
self.test("Invalid audio path rejection", self.test_invalid_audio_path)
print_section("TEST 5: Error Handling - Job Not Found")
self.test("Non-existent job handling", self.test_job_not_found)
print_section("TEST 6: Error Handling - Result Before Completion")
self.test("Result request before completion", self.test_result_before_completion)
print_section("TEST 7: Job Listing")
self.test("List jobs with filters", self.test_list_jobs)
print_section("TEST 8: Health Endpoint")
self.test("Basic health endpoint", self.test_health_endpoint)
print_section("TEST 9: Models Endpoint")
self.test("Models information endpoint", self.test_models_endpoint)
print_section("TEST 10: Server Restart Persistence")
self.test("Job persistence across server restart", self.test_server_restart_persistence)
# Print summary
success = self.print_summary()
return success
finally:
# Cleanup
if start_server and self.server_process:
self.stop_api_server()
def main():
"""Main test runner"""
import argparse
parser = argparse.ArgumentParser(description='Phase 4 Integration Tests')
parser.add_argument('--url', default='http://localhost:8000', help='API server URL')
# Default to None so Phase4Tester uses relative path
parser.add_argument('--audio', default=None,
help='Path to test audio file (default: <project_root>/data/test.mp3)')
parser.add_argument('--no-start-server', action='store_true',
help='Do not start server (assume it is already running)')
args = parser.parse_args()
tester = Phase4Tester(api_url=args.url, test_audio=args.audio)
success = tester.run_all_tests(start_server=not args.no_start_server)
sys.exit(0 if success else 1)
if __name__ == '__main__':
main()