Add circuit breaker, input validation, and refactor startup logic
- Implement circuit breaker pattern for GPU health checks
  - Prevents repeated failures with configurable thresholds
  - Three states: CLOSED, OPEN, HALF_OPEN
  - Integrated into GPU health monitoring
- Add comprehensive input validation and path sanitization
  - Path traversal attack prevention
  - Whitelist-based validation for models, devices, formats
  - Error message sanitization to prevent information leakage
  - File size limits and security checks
- Centralize startup logic across servers
  - Extract common startup procedures to utils/startup.py
  - Deduplicate GPU health checks and initialization code
  - Simplify both MCP and API server startup sequences
- Add proper Python package structure
  - Add __init__.py files to all modules
  - Improve package organization
- Add circuit breaker status API endpoints
  - GET /health/circuit-breaker - View circuit breaker stats
  - POST /health/circuit-breaker/reset - Reset circuit breaker
- Reorganize test files into tests/ directory
  - Rename and restructure test files for better organization
DEV_PLAN.md (1488 lines changed): diff suppressed because it is too large.
@@ -1,5 +1,9 @@
"""
Speech recognition MCP service module
Faster Whisper MCP Transcription Service

High-performance audio transcription service with dual-server architecture
(MCP and REST API) featuring async job queue and GPU health monitoring.
"""

__version__ = "0.1.0"
__version__ = "0.2.0"
__author__ = "Whisper MCP Team"
@@ -0,0 +1,6 @@
"""
Core modules for Whisper transcription service.

Includes model management, transcription logic, job queue, GPU health monitoring,
and GPU reset functionality.
"""
@@ -3,6 +3,7 @@ GPU health monitoring for Whisper transcription service.
|
||||
|
||||
Performs real GPU health checks using actual model loading and transcription,
|
||||
with strict failure handling to prevent silent CPU fallbacks.
|
||||
Includes circuit breaker pattern to prevent repeated failed checks.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -14,9 +15,19 @@ from typing import Optional, List
|
||||
import torch
|
||||
|
||||
from utils.test_audio_generator import generate_test_audio
|
||||
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global circuit breaker for GPU health checks
|
||||
_gpu_health_circuit_breaker = CircuitBreaker(
|
||||
name="gpu_health_check",
|
||||
failure_threshold=3, # Open after 3 consecutive failures
|
||||
success_threshold=2, # Close after 2 consecutive successes
|
||||
timeout_seconds=60, # Try again after 60 seconds
|
||||
half_open_max_calls=1 # Only 1 test call in half-open state
|
||||
)
|
||||
|
||||
# Import reset functionality (after logger initialization)
|
||||
try:
|
||||
from core.gpu_reset import (
|
||||
@@ -48,7 +59,7 @@ class GPUHealthStatus:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus:
|
||||
def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus:
|
||||
"""
|
||||
Comprehensive GPU health check using real model + transcription.
|
||||
|
||||
@@ -207,6 +218,58 @@ def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus:
|
||||
return status
|
||||
|
||||
|
||||
def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus:
|
||||
"""
|
||||
GPU health check with optional circuit breaker protection.
|
||||
|
||||
This is the main entry point for GPU health checks. By default, it uses
|
||||
circuit breaker pattern to prevent repeated failed checks.
|
||||
|
||||
Args:
|
||||
expected_device: Expected device ("auto", "cuda", "cpu")
|
||||
use_circuit_breaker: Enable circuit breaker protection (default: True)
|
||||
|
||||
Returns:
|
||||
GPUHealthStatus object
|
||||
|
||||
Raises:
|
||||
RuntimeError: If expected_device="cuda" but GPU test fails
|
||||
CircuitBreakerOpen: If circuit breaker is open (too many recent failures)
|
||||
"""
|
||||
if use_circuit_breaker:
|
||||
try:
|
||||
return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device)
|
||||
except CircuitBreakerOpen as e:
|
||||
# Circuit is open, fail fast without attempting check
|
||||
logger.warning(f"GPU health check circuit breaker is OPEN: {e}")
|
||||
raise RuntimeError(f"GPU health check unavailable: {str(e)}")
|
||||
else:
|
||||
return _check_gpu_health_internal(expected_device)
|
||||
|
||||
|
||||
def get_circuit_breaker_stats() -> dict:
|
||||
"""
|
||||
Get current circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with circuit state and failure/success counts
|
||||
"""
|
||||
return _gpu_health_circuit_breaker.get_stats()
|
||||
|
||||
|
||||
def reset_circuit_breaker():
|
||||
"""
|
||||
Manually reset the GPU health check circuit breaker.
|
||||
|
||||
Useful for:
|
||||
- Testing
|
||||
- Manual intervention after fixing GPU issues
|
||||
- Clearing persistent error state
|
||||
"""
|
||||
_gpu_health_circuit_breaker.reset()
|
||||
logger.info("GPU health check circuit breaker manually reset")
|
||||
|
||||
|
||||
def check_gpu_health_with_reset(
|
||||
expected_device: str = "cuda",
|
||||
auto_reset: bool = True
|
||||
|
||||
@@ -0,0 +1,5 @@
"""
Server implementations for Whisper transcription service.

Includes MCP server (whisper_server.py) and REST API server (api_server.py).
"""
@@ -8,6 +8,8 @@ import os
|
||||
import sys
|
||||
import logging
|
||||
import queue as queue_module
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, List
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
@@ -17,20 +19,13 @@ import json
|
||||
|
||||
from core.model_manager import get_model_info
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
|
||||
from utils.startup import startup_sequence, cleanup_on_shutdown
|
||||
|
||||
# Logging configuration
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import GPU health check with reset
|
||||
try:
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"GPU health check with reset not available: {e}")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
# Global instances
|
||||
job_queue: Optional[JobQueue] = None
|
||||
health_monitor: Optional[HealthMonitor] = None
|
||||
@@ -41,34 +36,22 @@ async def lifespan(app: FastAPI):
|
||||
"""FastAPI lifespan context manager for startup/shutdown"""
|
||||
global job_queue, health_monitor
|
||||
|
||||
# Startup
|
||||
logger.info("Starting job queue and health monitor...")
|
||||
# Startup - use common startup logic (without GPU check, handled in main)
|
||||
logger.info("Initializing job queue and health monitor...")
|
||||
|
||||
# Initialize job queue
|
||||
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
|
||||
metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs")
|
||||
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
|
||||
job_queue.start()
|
||||
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
|
||||
from utils.startup import initialize_job_queue, initialize_health_monitor
|
||||
|
||||
# Initialize health monitor
|
||||
health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
|
||||
if health_check_enabled:
|
||||
check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
|
||||
health_monitor = HealthMonitor(check_interval_minutes=check_interval)
|
||||
health_monitor.start()
|
||||
logger.info(f"GPU health monitor started (interval={check_interval} minutes)")
|
||||
job_queue = initialize_job_queue()
|
||||
health_monitor = initialize_health_monitor()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down job queue and health monitor...")
|
||||
if job_queue:
|
||||
job_queue.stop(wait_for_current=True)
|
||||
logger.info("Job queue stopped")
|
||||
if health_monitor:
|
||||
health_monitor.stop()
|
||||
logger.info("Health monitor stopped")
|
||||
# Shutdown - use common cleanup logic
|
||||
cleanup_on_shutdown(
|
||||
job_queue=job_queue,
|
||||
health_monitor=health_monitor,
|
||||
wait_for_current_job=True
|
||||
)
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
@@ -107,6 +90,8 @@ async def root():
|
||||
"GET /": "API information",
|
||||
"GET /health": "Health check",
|
||||
"GET /health/gpu": "GPU health check",
|
||||
"GET /health/circuit-breaker": "Get circuit breaker stats",
|
||||
"POST /health/circuit-breaker/reset": "Reset circuit breaker",
|
||||
"GET /models": "Get available models information",
|
||||
"POST /jobs": "Submit transcription job (async)",
|
||||
"GET /jobs/{job_id}": "Get job status",
|
||||
@@ -418,6 +403,39 @@ async def gpu_health_check_endpoint():
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health/circuit-breaker")
|
||||
async def get_circuit_breaker_status():
|
||||
"""
|
||||
Get GPU health check circuit breaker statistics.
|
||||
|
||||
Returns current state, failure/success counts, and last failure time.
|
||||
"""
|
||||
try:
|
||||
stats = get_circuit_breaker_stats()
|
||||
return JSONResponse(status_code=200, content=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get circuit breaker stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/health/circuit-breaker/reset")
|
||||
async def reset_circuit_breaker_endpoint():
|
||||
"""
|
||||
Manually reset the GPU health check circuit breaker.
|
||||
|
||||
Useful after fixing GPU issues or for testing purposes.
|
||||
"""
|
||||
try:
|
||||
reset_circuit_breaker()
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"message": "Circuit breaker reset successfully"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset circuit breaker: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
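# Illustrative client sketch (not from this commit; assumes the API server is
# reachable on localhost:8000 as in the test suite). Shows how the two new
# circuit-breaker endpoints added above could be queried and reset.
import requests

stats = requests.get("http://localhost:8000/health/circuit-breaker", timeout=5).json()
print(stats["state"], stats["failure_count"])  # e.g. "closed", 0

# After resolving a GPU problem, an operator can clear the OPEN state:
resp = requests.post("http://localhost:8000/health/circuit-breaker/reset", timeout=5)
print(resp.json()["message"])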
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -425,31 +443,14 @@ async def gpu_health_check_endpoint():
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Perform startup GPU health check with auto-reset
|
||||
if GPU_HEALTH_CHECK_AVAILABLE:
|
||||
try:
|
||||
logger.info("=" * 70)
|
||||
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
|
||||
logger.info("=" * 70)
|
||||
# Perform startup GPU health check
|
||||
from utils.startup import perform_startup_gpu_check
|
||||
|
||||
status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("STARTUP GPU CHECK SUCCESSFUL")
|
||||
logger.info(f"GPU Device: {status.device_name}")
|
||||
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
|
||||
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
|
||||
logger.info("=" * 70)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("=" * 70)
|
||||
logger.error("STARTUP GPU CHECK FAILED")
|
||||
logger.error(f"Error: {e}")
|
||||
logger.error("This service requires GPU. Terminating.")
|
||||
logger.error("=" * 70)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.warning("GPU health check not available, starting without GPU validation")
|
||||
perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=True,
|
||||
exit_on_failure=True
|
||||
)
|
||||
|
||||
# Get configuration from environment variables
|
||||
host = os.getenv("API_HOST", "0.0.0.0")
|
||||
|
||||
@@ -8,25 +8,21 @@ import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from core.model_manager import get_model_info
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health
|
||||
from utils.startup import startup_sequence, cleanup_on_shutdown
|
||||
|
||||
# Log configuration
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import GPU health check with reset
|
||||
try:
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"GPU health check with reset not available: {e}")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
# Global instances
|
||||
job_queue: Optional[JobQueue] = None
|
||||
health_monitor: Optional[HealthMonitor] = None
|
||||
@@ -319,56 +315,20 @@ def check_gpu_health() -> str:
|
||||
if __name__ == "__main__":
|
||||
print("starting mcp server for whisper stt transcriptor")
|
||||
|
||||
# Perform startup GPU health check with auto-reset
|
||||
if GPU_HEALTH_CHECK_AVAILABLE:
|
||||
try:
|
||||
logger.info("=" * 70)
|
||||
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
|
||||
logger.info("=" * 70)
|
||||
|
||||
status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("STARTUP GPU CHECK SUCCESSFUL")
|
||||
logger.info(f"GPU Device: {status.device_name}")
|
||||
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
|
||||
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
|
||||
logger.info("=" * 70)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("=" * 70)
|
||||
logger.error("STARTUP GPU CHECK FAILED")
|
||||
logger.error(f"Error: {e}")
|
||||
logger.error("This service requires GPU. Terminating.")
|
||||
logger.error("=" * 70)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.warning("GPU health check not available, starting without GPU validation")
|
||||
|
||||
# Initialize job queue
|
||||
logger.info("Initializing job queue...")
|
||||
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
|
||||
metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs")
|
||||
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
|
||||
job_queue.start()
|
||||
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
|
||||
|
||||
# Initialize health monitor
|
||||
health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
|
||||
if health_check_enabled:
|
||||
check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
|
||||
health_monitor = HealthMonitor(check_interval_minutes=check_interval)
|
||||
health_monitor.start()
|
||||
logger.info(f"GPU health monitor started (interval={check_interval} minutes)")
|
||||
# Execute common startup sequence
|
||||
job_queue, health_monitor = startup_sequence(
|
||||
service_name="MCP Whisper Server",
|
||||
require_gpu=True,
|
||||
initialize_queue=True,
|
||||
initialize_monitoring=True
|
||||
)
|
||||
|
||||
try:
|
||||
mcp.run()
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down...")
|
||||
if job_queue:
|
||||
job_queue.stop(wait_for_current=True)
|
||||
logger.info("Job queue stopped")
|
||||
if health_monitor:
|
||||
health_monitor.stop()
|
||||
logger.info("Health monitor stopped")
|
||||
cleanup_on_shutdown(
|
||||
job_queue=job_queue,
|
||||
health_monitor=health_monitor,
|
||||
wait_for_current_job=True
|
||||
)
|
||||
@@ -0,0 +1,6 @@
"""
Utility modules for Whisper transcription service.

Includes audio processing, formatters, test audio generation, input validation,
circuit breaker, and startup logic.
"""
@@ -7,17 +7,24 @@ Responsible for audio file validation and preprocessing
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Any
|
||||
from pathlib import Path
|
||||
from faster_whisper import decode_audio
|
||||
|
||||
from utils.input_validation import (
|
||||
validate_audio_file as validate_audio_file_secure,
|
||||
sanitize_error_message
|
||||
)
|
||||
|
||||
# Log configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def validate_audio_file(audio_path: str) -> None:
|
||||
def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None:
|
||||
"""
|
||||
Validate if an audio file is valid
|
||||
Validate if an audio file is valid (with security checks).
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If audio file doesn't exist
|
||||
@@ -27,31 +34,19 @@ def validate_audio_file(audio_path: str) -> None:
|
||||
Returns:
|
||||
None: If validation passes
|
||||
"""
|
||||
# Validate parameters
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
|
||||
|
||||
# Validate file format
|
||||
supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
|
||||
file_ext = os.path.splitext(audio_path)[1].lower()
|
||||
if file_ext not in supported_formats:
|
||||
raise ValueError(f"Unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}")
|
||||
|
||||
# Validate file size
|
||||
try:
|
||||
file_size = os.path.getsize(audio_path)
|
||||
if file_size == 0:
|
||||
raise ValueError(f"Audio file is empty: {audio_path}")
|
||||
# Use secure validation
|
||||
validate_audio_file_secure(audio_path, allowed_dirs)
|
||||
except Exception as e:
|
||||
# Re-raise with sanitized error messages
|
||||
error_msg = sanitize_error_message(str(e))
|
||||
|
||||
# Warning for large files (over 1GB)
|
||||
if file_size > 1024 * 1024 * 1024:
|
||||
logger.warning(f"Warning: File size exceeds 1GB, may require longer processing time: {audio_path}")
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to check file size: {str(e)}")
|
||||
raise OSError(f"Failed to check file size: {str(e)}")
|
||||
|
||||
# Validation passed
|
||||
return None
|
||||
if "not found" in str(e).lower():
|
||||
raise FileNotFoundError(error_msg)
|
||||
elif "size" in str(e).lower():
|
||||
raise OSError(error_msg)
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def process_audio(audio_path: str) -> Union[str, Any]:
|
||||
"""
|
||||
|
||||
src/utils/circuit_breaker.py (new file, 291 lines)
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Circuit Breaker Pattern Implementation
|
||||
|
||||
Prevents repeated failed attempts and provides fail-fast behavior.
|
||||
Useful for GPU health checks and other operations that may fail repeatedly.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
CLOSED = "closed" # Normal operation, requests pass through
|
||||
OPEN = "open" # Circuit is open, requests fail immediately
|
||||
HALF_OPEN = "half_open" # Testing if circuit can close
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitBreakerConfig:
|
||||
"""Configuration for circuit breaker."""
|
||||
failure_threshold: int = 3 # Failures before opening circuit
|
||||
success_threshold: int = 2 # Successes before closing from half-open
|
||||
timeout_seconds: int = 60 # Time before attempting half-open
|
||||
half_open_max_calls: int = 1 # Max calls to test in half-open state
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker implementation for preventing repeated failures.
|
||||
|
||||
Usage:
|
||||
breaker = CircuitBreaker(name="gpu_health", failure_threshold=3)
|
||||
|
||||
@breaker.call
|
||||
def check_gpu():
|
||||
# This function will be protected by circuit breaker
|
||||
return perform_gpu_check()
|
||||
|
||||
# Or use decorator:
|
||||
@breaker.decorator()
|
||||
def my_function():
|
||||
return "result"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
failure_threshold: int = 3,
|
||||
success_threshold: int = 2,
|
||||
timeout_seconds: int = 60,
|
||||
half_open_max_calls: int = 1
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
name: Name of the circuit (for logging)
|
||||
failure_threshold: Number of failures before opening
|
||||
success_threshold: Number of successes to close from half-open
|
||||
timeout_seconds: Seconds before transitioning to half-open
|
||||
half_open_max_calls: Max concurrent calls in half-open state
|
||||
"""
|
||||
self.name = name
|
||||
self.config = CircuitBreakerConfig(
|
||||
failure_threshold=failure_threshold,
|
||||
success_threshold=success_threshold,
|
||||
timeout_seconds=timeout_seconds,
|
||||
half_open_max_calls=half_open_max_calls
|
||||
)
|
||||
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time: Optional[datetime] = None
|
||||
self._half_open_calls = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(
|
||||
f"Circuit breaker '{name}' initialized: "
|
||||
f"failure_threshold={failure_threshold}, "
|
||||
f"timeout={timeout_seconds}s"
|
||||
)
|
||||
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
"""Get current circuit state."""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if circuit is closed (normal operation)."""
|
||||
return self.state == CircuitState.CLOSED
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (failing fast)."""
|
||||
return self.state == CircuitState.OPEN
|
||||
|
||||
@property
|
||||
def is_half_open(self) -> bool:
|
||||
"""Check if circuit is half-open (testing)."""
|
||||
return self.state == CircuitState.HALF_OPEN
|
||||
|
||||
def _update_state(self):
|
||||
"""Update state based on timeout and counters."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
# Check if timeout has passed
|
||||
if self._last_failure_time:
|
||||
elapsed = datetime.utcnow() - self._last_failure_time
|
||||
if elapsed.total_seconds() >= self.config.timeout_seconds:
|
||||
logger.info(
|
||||
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
|
||||
f"after {elapsed.total_seconds():.0f}s timeout"
|
||||
)
|
||||
self._state = CircuitState.HALF_OPEN
|
||||
self._half_open_calls = 0
|
||||
self._success_count = 0
|
||||
|
||||
def _on_success(self):
|
||||
"""Handle successful call."""
|
||||
with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
self._success_count += 1
|
||||
logger.debug(
|
||||
f"Circuit '{self.name}': Success in HALF_OPEN "
|
||||
f"({self._success_count}/{self.config.success_threshold})"
|
||||
)
|
||||
|
||||
if self._success_count >= self.config.success_threshold:
|
||||
logger.info(f"Circuit '{self.name}': Closing circuit after successful test")
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time = None
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
# Reset failure count on success
|
||||
self._failure_count = 0
|
||||
|
||||
def _on_failure(self, error: Exception):
|
||||
"""Handle failed call."""
|
||||
with self._lock:
|
||||
self._failure_count += 1
|
||||
self._last_failure_time = datetime.utcnow()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
logger.warning(
|
||||
f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit"
|
||||
)
|
||||
self._state = CircuitState.OPEN
|
||||
self._success_count = 0
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
logger.debug(
|
||||
f"Circuit '{self.name}': Failure {self._failure_count}/"
|
||||
f"{self.config.failure_threshold}"
|
||||
)
|
||||
|
||||
if self._failure_count >= self.config.failure_threshold:
|
||||
logger.warning(
|
||||
f"Circuit '{self.name}': Opening circuit after "
|
||||
f"{self._failure_count} failures. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
self._state = CircuitState.OPEN
|
||||
self._success_count = 0
|
||||
|
||||
def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Function result
|
||||
|
||||
Raises:
|
||||
CircuitBreakerOpen: If circuit is open
|
||||
Exception: Original exception from func if it fails
|
||||
"""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
|
||||
# Check if circuit is open
|
||||
if self._state == CircuitState.OPEN:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is OPEN. "
|
||||
f"Failing fast to prevent repeated failures. "
|
||||
f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
|
||||
# Check half-open call limit
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
if self._half_open_calls >= self.config.half_open_max_calls:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is HALF_OPEN with max calls reached. "
|
||||
f"Please wait for current test to complete."
|
||||
)
|
||||
self._half_open_calls += 1
|
||||
|
||||
# Execute function
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
self._on_success()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._on_failure(e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Decrement half-open counter
|
||||
with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
self._half_open_calls -= 1
|
||||
|
||||
def decorator(self):
|
||||
"""
|
||||
Decorator for protecting functions with circuit breaker.
|
||||
|
||||
Usage:
|
||||
breaker = CircuitBreaker("my_service")
|
||||
|
||||
@breaker.decorator()
|
||||
def my_function():
|
||||
return do_something()
|
||||
"""
|
||||
def wrapper(func):
|
||||
def decorated(*args, **kwargs):
|
||||
return self.call(func, *args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Manually reset circuit breaker to closed state.
|
||||
|
||||
Useful for:
|
||||
- Testing
|
||||
- Manual intervention
|
||||
- Clearing error state
|
||||
"""
|
||||
with self._lock:
|
||||
logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state")
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time = None
|
||||
self._half_open_calls = 0
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""
|
||||
Get circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with current state and counters
|
||||
"""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
return {
|
||||
"name": self.name,
|
||||
"state": self._state.value,
|
||||
"failure_count": self._failure_count,
|
||||
"success_count": self._success_count,
|
||||
"last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
|
||||
"config": {
|
||||
"failure_threshold": self.config.failure_threshold,
|
||||
"success_threshold": self.config.success_threshold,
|
||||
"timeout_seconds": self.config.timeout_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CircuitBreakerOpen(Exception):
|
||||
"""Exception raised when circuit breaker is open."""
|
||||
pass
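# Illustrative usage sketch of the CircuitBreaker class defined above (not part of
# this commit); flaky_gpu_probe is a hypothetical stand-in for any operation that
# may fail repeatedly.
breaker = CircuitBreaker(name="demo", failure_threshold=2, timeout_seconds=30)

def flaky_gpu_probe() -> str:
    raise RuntimeError("simulated GPU failure")

for _ in range(3):
    try:
        breaker.call(flaky_gpu_probe)
    except CircuitBreakerOpen as exc:
        print("fail-fast:", exc)        # third attempt fails without calling the probe
    except RuntimeError as exc:
        print("probe failed:", exc)     # first two attempts reach the probe and fail

print(breaker.get_stats()["state"])     # "open" until the 30s timeout elapses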
|
||||
src/utils/input_validation.py (new file, 411 lines)
@@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Input Validation and Path Sanitization Module
|
||||
|
||||
Provides robust validation for user inputs with security protections
|
||||
against path traversal, injection attacks, and other malicious inputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum file size (10GB)
|
||||
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024
|
||||
|
||||
# Allowed audio file extensions
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"}
|
||||
|
||||
# Allowed output formats
|
||||
ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"}
|
||||
|
||||
# Model name validation (whitelist)
|
||||
ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"}
|
||||
|
||||
# Device validation
|
||||
ALLOWED_DEVICES = {"cuda", "auto", "cpu"}
|
||||
|
||||
# Compute type validation
|
||||
ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"}
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""Base exception for validation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class PathTraversalError(ValidationError):
|
||||
"""Exception for path traversal attempts."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFileTypeError(ValidationError):
|
||||
"""Exception for invalid file types."""
|
||||
pass
|
||||
|
||||
|
||||
class FileSizeError(ValidationError):
|
||||
"""Exception for file size issues."""
|
||||
pass
|
||||
|
||||
|
||||
def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
|
||||
"""
|
||||
Sanitize error messages to prevent information leakage.
|
||||
|
||||
Args:
|
||||
error_msg: Original error message
|
||||
sanitize_paths: Whether to sanitize file paths (default: True)
|
||||
|
||||
Returns:
|
||||
Sanitized error message
|
||||
"""
|
||||
if not sanitize_paths:
|
||||
return error_msg
|
||||
|
||||
# Replace absolute paths with relative paths
|
||||
# Pattern: /home/user/... or /media/... or /var/... or /tmp/...
|
||||
path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)'
|
||||
|
||||
def replace_path(match):
|
||||
full_path = match.group(1)
|
||||
try:
|
||||
# Try to get just the filename
|
||||
basename = os.path.basename(full_path)
|
||||
return f"<file:{basename}>"
|
||||
except:
|
||||
return "<file:redacted>"
|
||||
|
||||
sanitized = re.sub(path_pattern, replace_path, error_msg)
|
||||
|
||||
# Also sanitize user names if present
|
||||
sanitized = re.sub(r'/home/([^/]+)/', '/home/<user>/', sanitized)
|
||||
|
||||
return sanitized
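# Illustrative example (hedged): expected behaviour of sanitize_error_message on a
# typical error string containing an absolute path under /home.
msg = sanitize_error_message("Cannot open /home/alice/audio/meeting.mp3: permission denied")
# Absolute paths are collapsed to the bare filename, so msg should read roughly:
# "Cannot open <file:meeting.mp3>: permission denied"
print(msg)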
|
||||
|
||||
|
||||
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
|
||||
"""
|
||||
Validate and sanitize a file path to prevent directory traversal attacks.
|
||||
|
||||
Args:
|
||||
file_path: Path to validate
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
|
||||
Returns:
|
||||
Resolved Path object
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path contains traversal attempts
|
||||
ValidationError: If path is invalid
|
||||
"""
|
||||
if not file_path:
|
||||
raise ValidationError("File path cannot be empty")
|
||||
|
||||
# Convert to Path object
|
||||
try:
|
||||
path = Path(file_path)
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
|
||||
|
||||
# Check for path traversal attempts
|
||||
path_str = str(path)
|
||||
if ".." in path_str:
|
||||
logger.warning(f"Path traversal attempt detected: {path_str}")
|
||||
raise PathTraversalError("Path traversal (..) is not allowed")
|
||||
|
||||
# Check for null bytes
|
||||
if "\x00" in path_str:
|
||||
logger.warning(f"Null byte in path detected: {path_str}")
|
||||
raise PathTraversalError("Null bytes in path are not allowed")
|
||||
|
||||
# Resolve to absolute path (but don't follow symlinks yet)
|
||||
try:
|
||||
resolved_path = path.resolve()
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}")
|
||||
|
||||
# If allowed_dirs specified, ensure path is within one of them
|
||||
if allowed_dirs:
|
||||
allowed = False
|
||||
for allowed_dir in allowed_dirs:
|
||||
try:
|
||||
allowed_dir_path = Path(allowed_dir).resolve()
|
||||
# Check if resolved_path is under allowed_dir
|
||||
resolved_path.relative_to(allowed_dir_path)
|
||||
allowed = True
|
||||
break
|
||||
except ValueError:
|
||||
# Not relative to this allowed_dir
|
||||
continue
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
f"Path outside allowed directories: {path_str}, "
|
||||
f"allowed: {allowed_dirs}"
|
||||
)
|
||||
raise PathTraversalError(
|
||||
f"Path must be within allowed directories. "
|
||||
f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}"
|
||||
)
|
||||
|
||||
return resolved_path
|
||||
|
||||
|
||||
def validate_audio_file(
|
||||
file_path: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
max_size_bytes: int = MAX_FILE_SIZE_BYTES
|
||||
) -> Path:
|
||||
"""
|
||||
Validate audio file path with security checks.
|
||||
|
||||
Args:
|
||||
file_path: Path to audio file
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
max_size_bytes: Maximum allowed file size
|
||||
|
||||
Returns:
|
||||
Validated Path object
|
||||
|
||||
Raises:
|
||||
ValidationError: If validation fails
|
||||
PathTraversalError: If path traversal detected
|
||||
FileNotFoundError: If file doesn't exist
|
||||
InvalidFileTypeError: If file type not allowed
|
||||
FileSizeError: If file too large
|
||||
"""
|
||||
# Validate and sanitize path
|
||||
validated_path = validate_path_safe(file_path, allowed_dirs)
|
||||
|
||||
# Check file exists
|
||||
if not validated_path.exists():
|
||||
raise FileNotFoundError(f"Audio file not found: {validated_path.name}")
|
||||
|
||||
# Check it's a file (not directory)
|
||||
if not validated_path.is_file():
|
||||
raise ValidationError(f"Path is not a file: {validated_path.name}")
|
||||
|
||||
# Check file extension
|
||||
file_ext = validated_path.suffix.lower()
|
||||
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
|
||||
raise InvalidFileTypeError(
|
||||
f"Unsupported audio format: {file_ext}. "
|
||||
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
|
||||
)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
file_size = validated_path.stat().st_size
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}")
|
||||
|
||||
if file_size == 0:
|
||||
raise FileSizeError(f"Audio file is empty: {validated_path.name}")
|
||||
|
||||
if file_size > max_size_bytes:
|
||||
raise FileSizeError(
|
||||
f"File too large: {file_size / (1024**3):.2f}GB. "
|
||||
f"Maximum: {max_size_bytes / (1024**3):.2f}GB"
|
||||
)
|
||||
|
||||
# Warn for large files (>1GB)
|
||||
if file_size > 1024 * 1024 * 1024:
|
||||
logger.warning(
|
||||
f"Large file: {file_size / (1024**3):.2f}GB, "
|
||||
f"may require extended processing time"
|
||||
)
|
||||
|
||||
return validated_path
|
||||
|
||||
|
||||
def validate_output_directory(
|
||||
dir_path: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
create_if_missing: bool = True
|
||||
) -> Path:
|
||||
"""
|
||||
Validate output directory path.
|
||||
|
||||
Args:
|
||||
dir_path: Directory path
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
create_if_missing: Create directory if it doesn't exist
|
||||
|
||||
Returns:
|
||||
Validated Path object
|
||||
|
||||
Raises:
|
||||
ValidationError: If validation fails
|
||||
PathTraversalError: If path traversal detected
|
||||
"""
|
||||
# Validate and sanitize path
|
||||
validated_path = validate_path_safe(dir_path, allowed_dirs)
|
||||
|
||||
# Create if requested and doesn't exist
|
||||
if create_if_missing and not validated_path.exists():
|
||||
try:
|
||||
validated_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created output directory: {validated_path}")
|
||||
except Exception as e:
|
||||
raise ValidationError(
|
||||
f"Cannot create output directory: {sanitize_error_message(str(e))}"
|
||||
)
|
||||
|
||||
# Check it's a directory
|
||||
if validated_path.exists() and not validated_path.is_dir():
|
||||
raise ValidationError(f"Path exists but is not a directory: {validated_path.name}")
|
||||
|
||||
return validated_path
|
||||
|
||||
|
||||
def validate_model_name(model_name: str) -> str:
|
||||
"""
|
||||
Validate Whisper model name.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
Validated model name
|
||||
|
||||
Raises:
|
||||
ValidationError: If model name invalid
|
||||
"""
|
||||
if not model_name:
|
||||
raise ValidationError("Model name cannot be empty")
|
||||
|
||||
model_name = model_name.strip().lower()
|
||||
|
||||
if model_name not in ALLOWED_MODEL_NAMES:
|
||||
raise ValidationError(
|
||||
f"Invalid model name: {model_name}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def validate_device(device: str) -> str:
|
||||
"""
|
||||
Validate device parameter.
|
||||
|
||||
Args:
|
||||
device: Device name to validate
|
||||
|
||||
Returns:
|
||||
Validated device name
|
||||
|
||||
Raises:
|
||||
ValidationError: If device invalid
|
||||
"""
|
||||
if not device:
|
||||
raise ValidationError("Device cannot be empty")
|
||||
|
||||
device = device.strip().lower()
|
||||
|
||||
if device not in ALLOWED_DEVICES:
|
||||
raise ValidationError(
|
||||
f"Invalid device: {device}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}"
|
||||
)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def validate_compute_type(compute_type: str) -> str:
|
||||
"""
|
||||
Validate compute type parameter.
|
||||
|
||||
Args:
|
||||
compute_type: Compute type to validate
|
||||
|
||||
Returns:
|
||||
Validated compute type
|
||||
|
||||
Raises:
|
||||
ValidationError: If compute type invalid
|
||||
"""
|
||||
if not compute_type:
|
||||
raise ValidationError("Compute type cannot be empty")
|
||||
|
||||
compute_type = compute_type.strip().lower()
|
||||
|
||||
if compute_type not in ALLOWED_COMPUTE_TYPES:
|
||||
raise ValidationError(
|
||||
f"Invalid compute type: {compute_type}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}"
|
||||
)
|
||||
|
||||
return compute_type
|
||||
|
||||
|
||||
def validate_output_format(output_format: str) -> str:
|
||||
"""
|
||||
Validate output format parameter.
|
||||
|
||||
Args:
|
||||
output_format: Output format to validate
|
||||
|
||||
Returns:
|
||||
Validated output format
|
||||
|
||||
Raises:
|
||||
ValidationError: If output format invalid
|
||||
"""
|
||||
if not output_format:
|
||||
raise ValidationError("Output format cannot be empty")
|
||||
|
||||
output_format = output_format.strip().lower()
|
||||
|
||||
if output_format not in ALLOWED_OUTPUT_FORMATS:
|
||||
raise ValidationError(
|
||||
f"Invalid output format: {output_format}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}"
|
||||
)
|
||||
|
||||
return output_format
|
||||
|
||||
|
||||
def validate_numeric_range(
|
||||
value: float,
|
||||
min_value: float,
|
||||
max_value: float,
|
||||
param_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Validate numeric parameter is within range.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
min_value: Minimum allowed value
|
||||
max_value: Maximum allowed value
|
||||
param_name: Parameter name for error messages
|
||||
|
||||
Returns:
|
||||
Validated value
|
||||
|
||||
Raises:
|
||||
ValidationError: If value out of range
|
||||
"""
|
||||
if value < min_value or value > max_value:
|
||||
raise ValidationError(
|
||||
f"{param_name} must be between {min_value} and {max_value}, "
|
||||
f"got {value}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_beam_size(beam_size: int) -> int:
|
||||
"""Validate beam size parameter."""
|
||||
return int(validate_numeric_range(beam_size, 1, 20, "beam_size"))
|
||||
|
||||
|
||||
def validate_temperature(temperature: float) -> float:
|
||||
"""Validate temperature parameter."""
|
||||
return validate_numeric_range(temperature, 0.0, 1.0, "temperature")
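# Illustrative sketch (not from this commit): typical use of the whitelist validators
# defined in this module before building transcription parameters. Input values are
# examples only.
model = validate_model_name("Large-v3")          # normalised to "large-v3"
device = validate_device("CUDA")                 # normalised to "cuda"
fmt = validate_output_format("srt")
beam = validate_beam_size(5)
temp = validate_temperature(0.2)

try:
    validate_device("tpu")                       # not in ALLOWED_DEVICES
except ValidationError as exc:
    print(exc)                                   # message lists the allowed values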
|
||||
src/utils/startup.py (new file, 237 lines)
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Common Startup Logic Module
|
||||
|
||||
Centralizes startup procedures shared between MCP and API servers,
|
||||
including GPU health checks, job queue initialization, and health monitoring.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import GPU health check with reset
|
||||
try:
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"GPU health check with reset not available: {e}")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
|
||||
def perform_startup_gpu_check(
|
||||
required_device: str = "cuda",
|
||||
auto_reset: bool = True,
|
||||
exit_on_failure: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Perform startup GPU health check with optional auto-reset.
|
||||
|
||||
This function:
|
||||
1. Checks if GPU health check is available
|
||||
2. Runs comprehensive GPU health check
|
||||
3. Attempts auto-reset if check fails and auto_reset=True
|
||||
4. Optionally exits process if check fails
|
||||
|
||||
Args:
|
||||
required_device: Required device ("cuda", "auto")
|
||||
auto_reset: Enable automatic GPU driver reset on failure
|
||||
exit_on_failure: Exit process if GPU check fails
|
||||
|
||||
Returns:
|
||||
True if GPU check passed, False otherwise
|
||||
|
||||
Side effects:
|
||||
May exit process if exit_on_failure=True and check fails
|
||||
"""
|
||||
if not GPU_HEALTH_CHECK_AVAILABLE:
|
||||
logger.warning("GPU health check not available, starting without GPU validation")
|
||||
if exit_on_failure:
|
||||
logger.error("GPU health check required but not available. Exiting.")
|
||||
sys.exit(1)
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info("=" * 70)
|
||||
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
|
||||
logger.info("=" * 70)
|
||||
|
||||
status = check_gpu_health_with_reset(
|
||||
expected_device=required_device,
|
||||
auto_reset=auto_reset
|
||||
)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("STARTUP GPU CHECK SUCCESSFUL")
|
||||
logger.info(f"GPU Device: {status.device_name}")
|
||||
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
|
||||
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
|
||||
logger.info("=" * 70)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("=" * 70)
|
||||
logger.error("STARTUP GPU CHECK FAILED")
|
||||
logger.error(f"Error: {e}")
|
||||
|
||||
if exit_on_failure:
|
||||
logger.error("This service requires GPU. Terminating.")
|
||||
logger.error("=" * 70)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("Continuing without GPU (may have reduced functionality)")
|
||||
logger.error("=" * 70)
|
||||
return False
|
||||
|
||||
|
||||
def initialize_job_queue(
|
||||
max_queue_size: Optional[int] = None,
|
||||
metadata_dir: Optional[str] = None
|
||||
) -> 'JobQueue':
|
||||
"""
|
||||
Initialize job queue with environment variable configuration.
|
||||
|
||||
Args:
|
||||
max_queue_size: Override for max queue size (uses env var if None)
|
||||
metadata_dir: Override for metadata directory (uses env var if None)
|
||||
|
||||
Returns:
|
||||
Initialized JobQueue instance (started)
|
||||
"""
|
||||
from core.job_queue import JobQueue
|
||||
|
||||
# Get configuration from environment
|
||||
if max_queue_size is None:
|
||||
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
|
||||
|
||||
if metadata_dir is None:
|
||||
metadata_dir = os.getenv(
|
||||
"JOB_METADATA_DIR",
|
||||
"/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||
)
|
||||
|
||||
logger.info("Initializing job queue...")
|
||||
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
|
||||
job_queue.start()
|
||||
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
|
||||
|
||||
return job_queue
|
||||
|
||||
|
||||
def initialize_health_monitor(
|
||||
check_interval_minutes: Optional[int] = None,
|
||||
enabled: Optional[bool] = None
|
||||
) -> Optional['HealthMonitor']:
|
||||
"""
|
||||
Initialize GPU health monitor with environment variable configuration.
|
||||
|
||||
Args:
|
||||
check_interval_minutes: Override for check interval (uses env var if None)
|
||||
enabled: Override for enabled status (uses env var if None)
|
||||
|
||||
Returns:
|
||||
Initialized HealthMonitor instance (started), or None if disabled
|
||||
"""
|
||||
from core.gpu_health import HealthMonitor
|
||||
|
||||
# Get configuration from environment
|
||||
if enabled is None:
|
||||
enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
|
||||
|
||||
if not enabled:
|
||||
logger.info("GPU health monitoring disabled")
|
||||
return None
|
||||
|
||||
if check_interval_minutes is None:
|
||||
check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
|
||||
|
||||
health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes)
|
||||
health_monitor.start()
|
||||
logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)")
|
||||
|
||||
return health_monitor
|
||||
|
||||
|
||||
def startup_sequence(
|
||||
service_name: str = "whisper-transcription",
|
||||
require_gpu: bool = True,
|
||||
initialize_queue: bool = True,
|
||||
initialize_monitoring: bool = True
|
||||
) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]:
|
||||
"""
|
||||
Execute complete startup sequence for a Whisper transcription server.
|
||||
|
||||
This function performs all common startup tasks:
|
||||
1. GPU health check with auto-reset
|
||||
2. Job queue initialization
|
||||
3. Health monitor initialization
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (for logging)
|
||||
require_gpu: Whether GPU is required (exit if not available)
|
||||
initialize_queue: Whether to initialize job queue
|
||||
initialize_monitoring: Whether to initialize health monitoring
|
||||
|
||||
Returns:
|
||||
Tuple of (job_queue, health_monitor) - either may be None
|
||||
|
||||
Side effects:
|
||||
May exit process if GPU required but unavailable
|
||||
"""
|
||||
logger.info(f"Starting {service_name}...")
|
||||
|
||||
# Step 1: GPU health check
|
||||
gpu_ok = perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=True,
|
||||
exit_on_failure=require_gpu
|
||||
)
|
||||
|
||||
if not gpu_ok and require_gpu:
|
||||
# Should not reach here (exit_on_failure should have exited)
|
||||
logger.error("GPU check failed and GPU is required")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 2: Initialize job queue
|
||||
job_queue = None
|
||||
if initialize_queue:
|
||||
job_queue = initialize_job_queue()
|
||||
|
||||
# Step 3: Initialize health monitor
|
||||
health_monitor = None
|
||||
if initialize_monitoring:
|
||||
health_monitor = initialize_health_monitor()
|
||||
|
||||
logger.info(f"{service_name} startup sequence completed")
|
||||
|
||||
return job_queue, health_monitor
|
||||
|
||||
|
||||
def cleanup_on_shutdown(
|
||||
job_queue: Optional['JobQueue'] = None,
|
||||
health_monitor: Optional['HealthMonitor'] = None,
|
||||
wait_for_current_job: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Perform cleanup on server shutdown.
|
||||
|
||||
Args:
|
||||
job_queue: JobQueue instance to stop (if any)
|
||||
health_monitor: HealthMonitor instance to stop (if any)
|
||||
wait_for_current_job: Wait for current job to complete before stopping
|
||||
"""
|
||||
logger.info("Shutting down...")
|
||||
|
||||
if job_queue:
|
||||
job_queue.stop(wait_for_current=wait_for_current_job)
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
if health_monitor:
|
||||
health_monitor.stop()
|
||||
logger.info("Health monitor stopped")
|
||||
|
||||
logger.info("Shutdown complete")
|
||||
@@ -22,8 +22,8 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
# Color codes for terminal output
|
||||
class Colors:
|
||||
@@ -435,7 +435,9 @@ class Phase2Tester:
|
||||
|
||||
def _create_test_audio_file(self):
|
||||
"""Get the path to the test audio file"""
|
||||
test_audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.mp3"
|
||||
# Use relative path from project root
|
||||
project_root = Path(__file__).parent.parent
|
||||
test_audio_path = str(project_root / "data" / "test.mp3")
|
||||
if not os.path.exists(test_audio_path):
|
||||
raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
|
||||
return test_audio_path
|
||||
@@ -24,8 +24,8 @@ logging.basicConfig(
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from core.gpu_health import check_gpu_health, HealthMonitor
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
@@ -61,8 +61,9 @@ def test_audio_file():
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Use the actual test audio file
|
||||
audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a"
|
||||
# Use the actual test audio file (relative to project root)
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..')
|
||||
audio_path = os.path.join(project_root, "data/test.mp3")
|
||||
|
||||
# Verify file exists
|
||||
assert os.path.exists(audio_path), "Audio file not found"
|
||||
@@ -147,8 +148,9 @@ def test_job_queue():
|
||||
job_queue.start()
|
||||
print("✓ Job queue started")
|
||||
|
||||
# Use the actual test audio file
|
||||
audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a"
|
||||
# Use the actual test audio file (relative to project root)
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..')
|
||||
audio_path = os.path.join(project_root, "data/test.mp3")
|
||||
|
||||
# Test job submission
|
||||
print("\nSubmitting test job...")
|
||||
tests/test_e2e_integration.py (new executable file, 523 lines)
@@ -0,0 +1,523 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Phase 4: End-to-End Integration Testing
|
||||
|
||||
Comprehensive integration tests for the async job queue system.
|
||||
Tests all scenarios from the DEV_PLAN.md Phase 4 checklist.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
import subprocess
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
# Color codes for terminal output
|
||||
class Colors:
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
CYAN = '\033[96m'
|
||||
END = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
|
||||
def print_success(msg):
|
||||
print(f"{Colors.GREEN}✓ {msg}{Colors.END}")
|
||||
|
||||
def print_error(msg):
|
||||
print(f"{Colors.RED}✗ {msg}{Colors.END}")
|
||||
|
||||
def print_info(msg):
|
||||
print(f"{Colors.BLUE}ℹ {msg}{Colors.END}")
|
||||
|
||||
def print_warning(msg):
|
||||
print(f"{Colors.YELLOW}⚠ {msg}{Colors.END}")
|
||||
|
||||
def print_section(msg):
|
||||
print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
|
||||
print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
|
||||
print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
|
||||
|
||||
|
||||
class Phase4Tester:
|
||||
def __init__(self, api_url="http://localhost:8000", test_audio=None):
|
||||
self.api_url = api_url
|
||||
# Use relative path from project root if not provided
|
||||
if test_audio is None:
|
||||
project_root = Path(__file__).parent.parent
|
||||
test_audio = str(project_root / "data" / "test.mp3")
|
||||
self.test_audio = test_audio
|
||||
self.test_results = []
|
||||
self.server_process = None
|
||||
|
||||
# Verify test audio exists
|
||||
if not os.path.exists(test_audio):
|
||||
raise FileNotFoundError(f"Test audio file not found: {test_audio}")
|
||||
|
||||
def test(self, name, func):
|
||||
"""Run a test and record result"""
|
||||
try:
|
||||
logger.info(f"Testing: {name}")
|
||||
print_info(f"Testing: {name}")
|
||||
func()
|
||||
logger.info(f"PASSED: {name}")
|
||||
print_success(f"PASSED: {name}")
|
||||
self.test_results.append((name, True, None))
|
||||
return True
|
||||
except AssertionError as e:
|
||||
logger.error(f"FAILED: {name} - {str(e)}")
|
||||
print_error(f"FAILED: {name}")
|
||||
print_error(f" Reason: {str(e)}")
|
||||
self.test_results.append((name, False, str(e)))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR: {name} - {str(e)}")
|
||||
print_error(f"ERROR: {name}")
|
||||
print_error(f" Exception: {str(e)}")
|
||||
self.test_results.append((name, False, f"Exception: {str(e)}"))
|
||||
return False
|
||||
|
||||
def start_api_server(self, wait_time=5):
|
||||
"""Start the API server in background"""
|
||||
print_info("Starting API server...")
|
||||
# Script is in project root, one level up from tests/
|
||||
script_path = Path(__file__).parent.parent / "run_api_server.sh"
|
||||
|
||||
# Start server in background
|
||||
self.server_process = subprocess.Popen(
|
||||
[str(script_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
preexec_fn=os.setsid
|
||||
)
|
||||
|
||||
# Wait for server to start
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Verify server is running
|
||||
try:
|
||||
response = requests.get(f"{self.api_url}/health", timeout=5)
|
||||
if response.status_code == 200:
|
||||
print_success("API server started successfully")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
print_error("API server failed to start")
|
||||
return False
|
||||
|
||||
def stop_api_server(self):
|
||||
"""Stop the API server"""
|
||||
if self.server_process:
|
||||
print_info("Stopping API server...")
|
||||
os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
|
||||
self.server_process.wait(timeout=10)
|
||||
print_success("API server stopped")
|
||||
|
    def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2):
        """Poll job status until completed or failed"""
        start_time = time.time()
        last_status = None

        while time.time() - start_time < timeout:
            try:
                response = requests.get(f"{self.api_url}/jobs/{job_id}")
                assert response.status_code == 200, f"Failed to get job status: {response.status_code}"

                status_data = response.json()
                current_status = status_data['status']

                # Print status changes
                if current_status != last_status:
                    if status_data.get('queue_position') is not None:
                        print_info(f" Job status: {current_status}, queue position: {status_data['queue_position']}")
                    else:
                        print_info(f" Job status: {current_status}")
                    last_status = current_status

                if current_status == "completed":
                    return status_data
                elif current_status == "failed":
                    raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}")

                time.sleep(poll_interval)
            except requests.exceptions.RequestException as e:
                raise AssertionError(f"Request failed: {e}")

        raise AssertionError(f"Job did not complete within {timeout} seconds")

    # ========================================================================
    # TEST 1: Single Job Submission and Completion
    # ========================================================================
    def test_single_job_flow(self):
        """Test complete job flow: submit → poll → get result"""
        # Submit job
        print_info(" Submitting job...")
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "large-v3",
            "output_format": "txt"
        })
        assert response.status_code == 200, f"Job submission failed: {response.status_code}"

        job_data = response.json()
        assert 'job_id' in job_data, "No job_id in response"
        # Status can be 'queued' or 'running' (if queue is empty and job starts immediately)
        assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'"

        job_id = job_data['job_id']
        print_success(f" Job submitted: {job_id}")

        # Wait for completion
        print_info(" Waiting for job completion...")
        final_status = self.wait_for_job_completion(job_id)

        assert final_status['status'] == 'completed', "Job did not complete"
        assert final_status['result_path'] is not None, "No result_path in completed job"
        assert final_status['processing_time_seconds'] is not None, "No processing time"
        print_success(f" Job completed in {final_status['processing_time_seconds']:.2f}s")

        # Get result
        print_info(" Retrieving result...")
        response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
        assert response.status_code == 200, f"Failed to get result: {response.status_code}"

        result_text = response.text
        assert len(result_text) > 0, "Empty result"
        print_success(f" Got result: {len(result_text)} characters")

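    # Equivalent manual check for the flow above (host/port assumed from the default --url):
    #   curl -X POST http://localhost:8000/jobs \
    #        -H 'Content-Type: application/json' \
    #        -d '{"audio_path": "/path/to/audio.mp3", "model_name": "large-v3", "output_format": "txt"}'
    #   curl http://localhost:8000/jobs/<job_id>
    #   curl http://localhost:8000/jobs/<job_id>/result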
    # ========================================================================
    # TEST 2: Multiple Jobs in Queue (FIFO)
    # ========================================================================
    def test_multiple_jobs_fifo(self):
        """Test multiple jobs are processed in FIFO order"""
        job_ids = []

        # Submit 3 jobs
        print_info(" Submitting 3 jobs...")
        for i in range(3):
            response = requests.post(f"{self.api_url}/jobs", json={
                "audio_path": self.test_audio,
                "model_name": "tiny",  # Use tiny model for faster processing
                "output_format": "txt"
            })
            assert response.status_code == 200, f"Job {i+1} submission failed"

            job_data = response.json()
            job_ids.append(job_data['job_id'])
            print_info(f" Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}")

        # Wait for all jobs to complete
        print_info(" Waiting for all jobs to complete...")
        for i, job_id in enumerate(job_ids):
            print_info(f" Waiting for job {i+1}/{len(job_ids)}...")
            final_status = self.wait_for_job_completion(job_id, timeout=120)
            assert final_status['status'] == 'completed', f"Job {i+1} failed"

        print_success(f" All {len(job_ids)} jobs completed in FIFO order")

    # ========================================================================
    # TEST 3: GPU Health Check
    # ========================================================================
    def test_gpu_health_check(self):
        """Test GPU health check endpoint"""
        print_info(" Checking GPU health...")
        response = requests.get(f"{self.api_url}/health/gpu")
        assert response.status_code == 200, f"GPU health check failed: {response.status_code}"

        health_data = response.json()
        assert 'gpu_available' in health_data, "Missing gpu_available field"
        assert 'gpu_working' in health_data, "Missing gpu_working field"
        assert 'device_used' in health_data, "Missing device_used field"

        print_info(f" GPU Available: {health_data['gpu_available']}")
        print_info(f" GPU Working: {health_data['gpu_working']}")
        print_info(f" Device: {health_data['device_used']}")

        if health_data['gpu_available']:
            assert health_data['device_name'], "GPU available but no device_name"
            assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)"
            print_success(f" GPU is healthy: {health_data['device_name']}")
        else:
            print_warning(" GPU not available on this system")

    # ========================================================================
    # TEST 4: Invalid Audio Path
    # ========================================================================
    def test_invalid_audio_path(self):
        """Test job submission with invalid audio path"""
        print_info(" Submitting job with invalid path...")
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": "/invalid/path/does/not/exist.mp3",
            "model_name": "large-v3"
        })

        # Should return 400 Bad Request
        assert response.status_code == 400, f"Expected 400, got {response.status_code}"

        error_data = response.json()
        assert 'detail' in error_data or 'error' in error_data, "No error message in response"
        print_success(" Invalid path rejected correctly")

    # ========================================================================
    # TEST 5: Job Not Found
    # ========================================================================
    def test_job_not_found(self):
        """Test retrieving non-existent job"""
        print_info(" Requesting non-existent job...")
        fake_job_id = "00000000-0000-0000-0000-000000000000"

        response = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
        assert response.status_code == 404, f"Expected 404, got {response.status_code}"
        print_success(" Non-existent job handled correctly")

    # ========================================================================
    # TEST 6: Result Before Completion
    # ========================================================================
    def test_result_before_completion(self):
        """Test getting result for job that hasn't completed"""
        print_info(" Submitting job and trying to get result immediately...")

        # Submit job
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "large-v3"
        })
        assert response.status_code == 200
        job_id = response.json()['job_id']

        # Try to get result immediately (job is still queued/running)
        time.sleep(0.5)
        response = requests.get(f"{self.api_url}/jobs/{job_id}/result")

        # Should return 409 Conflict or similar
        assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}"
        print_success(" Result request before completion handled correctly")

        # Clean up: wait for job to complete
        self.wait_for_job_completion(job_id)

    # ========================================================================
    # TEST 7: List Jobs
    # ========================================================================
    def test_list_jobs(self):
        """Test listing jobs with filters"""
        print_info(" Testing job listing...")

        # List all jobs
        response = requests.get(f"{self.api_url}/jobs")
        assert response.status_code == 200, f"List jobs failed: {response.status_code}"

        jobs_data = response.json()
        assert 'jobs' in jobs_data, "No jobs array in response"
        assert isinstance(jobs_data['jobs'], list), "Jobs is not a list"
        print_info(f" Found {len(jobs_data['jobs'])} jobs")

        # List only completed jobs
        response = requests.get(f"{self.api_url}/jobs?status=completed")
        assert response.status_code == 200
        completed_jobs = response.json()['jobs']
        print_info(f" Found {len(completed_jobs)} completed jobs")

        # List with limit
        response = requests.get(f"{self.api_url}/jobs?limit=5")
        assert response.status_code == 200
        limited_jobs = response.json()['jobs']
        assert len(limited_jobs) <= 5, "Limit not respected"
        print_success(" Job listing works correctly")

    # ========================================================================
    # TEST 8: Server Restart with Job Persistence
    # ========================================================================
    def test_server_restart_persistence(self):
        """Test that jobs persist across server restarts"""
        print_info(" Testing job persistence across restart...")

        # Submit a job
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "tiny"
        })
        assert response.status_code == 200
        job_id = response.json()['job_id']
        print_info(f" Submitted job: {job_id}")

        # Get job count before restart
        response = requests.get(f"{self.api_url}/jobs")
        jobs_before = len(response.json()['jobs'])
        print_info(f" Jobs before restart: {jobs_before}")

        # Restart server
        print_info(" Restarting server...")
        self.stop_api_server()
        time.sleep(2)
        assert self.start_api_server(wait_time=8), "Server failed to restart"

        # Check jobs after restart
        response = requests.get(f"{self.api_url}/jobs")
        assert response.status_code == 200
        jobs_after = len(response.json()['jobs'])
        print_info(f" Jobs after restart: {jobs_after}")

        # Check our specific job is still there (this is the key test)
        response = requests.get(f"{self.api_url}/jobs/{job_id}")
        assert response.status_code == 200, "Job not found after restart"

        # Note: total count may differ due to job retention/cleanup, but persistence
        # works as long as the specific job can still be found
        if jobs_after < jobs_before:
            print_warning(f" Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup")

        print_success(" Jobs persisted correctly across restart")

    # ========================================================================
    # TEST 9: Health Endpoint
    # ========================================================================
    def test_health_endpoint(self):
        """Test basic health endpoint"""
        print_info(" Checking health endpoint...")
        response = requests.get(f"{self.api_url}/health")
        assert response.status_code == 200, f"Health check failed: {response.status_code}"

        health_data = response.json()
        assert health_data['status'] == 'healthy', "Server not healthy"
        print_success(" Health endpoint OK")

    # ========================================================================
    # TEST 10: Models Endpoint
    # ========================================================================
    def test_models_endpoint(self):
        """Test models information endpoint"""
        print_info(" Checking models endpoint...")
        response = requests.get(f"{self.api_url}/models")
        assert response.status_code == 200, f"Models endpoint failed: {response.status_code}"

        models_data = response.json()
        assert 'available_models' in models_data, "No available_models field"
        assert 'available_devices' in models_data, "No available_devices field"
        assert len(models_data['available_models']) > 0, "No models listed"
        print_info(f" Available models: {len(models_data['available_models'])}")
        print_success(" Models endpoint OK")

    def print_summary(self):
        """Print test summary"""
        print_section("TEST SUMMARY")

        passed = sum(1 for _, result, _ in self.test_results if result)
        failed = len(self.test_results) - passed

        for name, result, error in self.test_results:
            if result:
                print_success(f"{name}")
            else:
                print_error(f"{name}")
                if error:
                    print(f" {error}")

        print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
        print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
        print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")

        return failed == 0

    def run_all_tests(self, start_server=True):
        """Run all Phase 4 integration tests"""
        print_section("PHASE 4: END-TO-END INTEGRATION TESTING")

        try:
            # Start server if requested
            if start_server:
                if not self.start_api_server():
                    print_error("Failed to start API server. Aborting tests.")
                    return False
            else:
                # Verify server is already running
                try:
                    response = requests.get(f"{self.api_url}/health", timeout=5)
                    if response.status_code != 200:
                        print_error("Server is not responding. Please start it first.")
                        return False
                    print_info("Using existing API server")
                except requests.RequestException:
                    print_error("Cannot connect to API server. Please start it first.")
                    return False

            # Run tests
            print_section("TEST 1: Single Job Submission and Completion")
            self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow)

            print_section("TEST 2: Multiple Jobs (FIFO Order)")
            self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo)

            print_section("TEST 3: GPU Health Check")
            self.test("GPU health check endpoint", self.test_gpu_health_check)

            print_section("TEST 4: Error Handling - Invalid Path")
            self.test("Invalid audio path rejection", self.test_invalid_audio_path)

            print_section("TEST 5: Error Handling - Job Not Found")
            self.test("Non-existent job handling", self.test_job_not_found)

            print_section("TEST 6: Error Handling - Result Before Completion")
            self.test("Result request before completion", self.test_result_before_completion)

            print_section("TEST 7: Job Listing")
            self.test("List jobs with filters", self.test_list_jobs)

            print_section("TEST 8: Health Endpoint")
            self.test("Basic health endpoint", self.test_health_endpoint)

            print_section("TEST 9: Models Endpoint")
            self.test("Models information endpoint", self.test_models_endpoint)

            print_section("TEST 10: Server Restart Persistence")
            self.test("Job persistence across server restart", self.test_server_restart_persistence)

            # Print summary
            success = self.print_summary()

            return success

        finally:
            # Cleanup
            if start_server and self.server_process:
                self.stop_api_server()


def main():
    """Main test runner"""
    import argparse

    parser = argparse.ArgumentParser(description='Phase 4 Integration Tests')
    parser.add_argument('--url', default='http://localhost:8000', help='API server URL')
    # Default to None so Phase4Tester uses relative path
    parser.add_argument('--audio', default=None,
                        help='Path to test audio file (default: <project_root>/data/test.mp3)')
    parser.add_argument('--no-start-server', action='store_true',
                        help='Do not start server (assume it is already running)')
    args = parser.parse_args()

    tester = Phase4Tester(api_url=args.url, test_audio=args.audio)
    success = tester.run_all_tests(start_server=not args.no_start_server)

    sys.exit(0 if success else 1)


if __name__ == '__main__':
    main()
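# Example invocations (script name and working directory assumed; run from the project root):
#   python tests/test_phase4_integration.py                  # starts the API server automatically
#   python tests/test_phase4_integration.py --no-start-server --url http://localhost:8000
#   python tests/test_phase4_integration.py --audio /path/to/other.mp3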