Add circuit breaker, input validation, and refactor startup logic
- Implement circuit breaker pattern for GPU health checks
  - Prevents repeated failures with configurable thresholds
  - Three states: CLOSED, OPEN, HALF_OPEN
  - Integrated into GPU health monitoring
- Add comprehensive input validation and path sanitization
  - Path traversal attack prevention
  - Whitelist-based validation for models, devices, formats
  - Error message sanitization to prevent information leakage
  - File size limits and security checks
- Centralize startup logic across servers
  - Extract common startup procedures to utils/startup.py
  - Deduplicate GPU health checks and initialization code
  - Simplify both MCP and API server startup sequences
- Add proper Python package structure
  - Add __init__.py files to all modules
  - Improve package organization
- Add circuit breaker status API endpoints
  - GET /health/circuit-breaker - View circuit breaker stats
  - POST /health/circuit-breaker/reset - Reset circuit breaker
- Reorganize test files into tests/ directory
  - Rename and restructure test files for better organization
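A minimal sketch of how the new circuit-breaker endpoints could be exercised once the API server is running. The base URL is an assumption for illustration (the server reads API_HOST from the environment; the port here is hypothetical); the routes and the response fields (state, failure_count, and the reset confirmation message) match the endpoint code added in this commit.

    import requests  # any HTTP client works; requests is assumed to be installed

    BASE_URL = "http://localhost:8000"  # assumed host/port for illustration

    # Inspect the GPU health-check circuit breaker
    # ("state" is "closed", "open", or "half_open")
    stats = requests.get(f"{BASE_URL}/health/circuit-breaker").json()
    print(stats["state"], stats["failure_count"])

    # After fixing a GPU issue, manually close the circuit again
    resp = requests.post(f"{BASE_URL}/health/circuit-breaker/reset")
    print(resp.json())  # {"message": "Circuit breaker reset successfully"}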
DEV_PLAN.md: 1488 lines changed (diff suppressed because it is too large)
@@ -1,5 +1,9 @@
 """
-语音识别MCP服务模块
+Faster Whisper MCP Transcription Service
+
+High-performance audio transcription service with dual-server architecture
+(MCP and REST API) featuring async job queue and GPU health monitoring.
 """
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
+__author__ = "Whisper MCP Team"
@@ -0,0 +1,6 @@
+"""
+Core modules for Whisper transcription service.
+
+Includes model management, transcription logic, job queue, GPU health monitoring,
+and GPU reset functionality.
+"""
@@ -3,6 +3,7 @@ GPU health monitoring for Whisper transcription service.
 
 Performs real GPU health checks using actual model loading and transcription,
 with strict failure handling to prevent silent CPU fallbacks.
+Includes circuit breaker pattern to prevent repeated failed checks.
 """
 
 import time
@@ -14,9 +15,19 @@ from typing import Optional, List
 import torch
 
 from utils.test_audio_generator import generate_test_audio
+from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen
 
 logger = logging.getLogger(__name__)
 
+
+# Global circuit breaker for GPU health checks
+_gpu_health_circuit_breaker = CircuitBreaker(
+    name="gpu_health_check",
+    failure_threshold=3,   # Open after 3 consecutive failures
+    success_threshold=2,   # Close after 2 consecutive successes
+    timeout_seconds=60,    # Try again after 60 seconds
+    half_open_max_calls=1  # Only 1 test call in half-open state
+)
+
 # Import reset functionality (after logger initialization)
 try:
     from core.gpu_reset import (
@@ -48,7 +59,7 @@ class GPUHealthStatus:
         return asdict(self)
 
 
-def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus:
+def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus:
     """
     Comprehensive GPU health check using real model + transcription.
 
@@ -207,6 +218,58 @@ def check_gpu_health(expected_device: str = "auto") -> GPUHealthStatus:
     return status
 
 
+def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus:
+    """
+    GPU health check with optional circuit breaker protection.
+
+    This is the main entry point for GPU health checks. By default, it uses
+    circuit breaker pattern to prevent repeated failed checks.
+
+    Args:
+        expected_device: Expected device ("auto", "cuda", "cpu")
+        use_circuit_breaker: Enable circuit breaker protection (default: True)
+
+    Returns:
+        GPUHealthStatus object
+
+    Raises:
+        RuntimeError: If expected_device="cuda" but GPU test fails
+        CircuitBreakerOpen: If circuit breaker is open (too many recent failures)
+    """
+    if use_circuit_breaker:
+        try:
+            return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device)
+        except CircuitBreakerOpen as e:
+            # Circuit is open, fail fast without attempting check
+            logger.warning(f"GPU health check circuit breaker is OPEN: {e}")
+            raise RuntimeError(f"GPU health check unavailable: {str(e)}")
+    else:
+        return _check_gpu_health_internal(expected_device)
+
+
+def get_circuit_breaker_stats() -> dict:
+    """
+    Get current circuit breaker statistics.
+
+    Returns:
+        Dictionary with circuit state and failure/success counts
+    """
+    return _gpu_health_circuit_breaker.get_stats()
+
+
+def reset_circuit_breaker():
+    """
+    Manually reset the GPU health check circuit breaker.
+
+    Useful for:
+    - Testing
+    - Manual intervention after fixing GPU issues
+    - Clearing persistent error state
+    """
+    _gpu_health_circuit_breaker.reset()
+    logger.info("GPU health check circuit breaker manually reset")
+
+
 def check_gpu_health_with_reset(
     expected_device: str = "cuda",
     auto_reset: bool = True
@@ -0,0 +1,5 @@
+"""
+Server implementations for Whisper transcription service.
+
+Includes MCP server (whisper_server.py) and REST API server (api_server.py).
+"""
@@ -8,6 +8,8 @@ import os
 import sys
 import logging
 import queue as queue_module
+import tempfile
+from pathlib import Path
 from contextlib import asynccontextmanager
 from typing import Optional, List
 from fastapi import FastAPI, HTTPException, UploadFile, File, Form
@@ -17,20 +19,13 @@ import json
 
 from core.model_manager import get_model_info
 from core.job_queue import JobQueue, JobStatus
-from core.gpu_health import HealthMonitor, check_gpu_health
+from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
+from utils.startup import startup_sequence, cleanup_on_shutdown
 
 # Logging configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Import GPU health check with reset
-try:
-    from core.gpu_health import check_gpu_health_with_reset
-    GPU_HEALTH_CHECK_AVAILABLE = True
-except ImportError as e:
-    logger.warning(f"GPU health check with reset not available: {e}")
-    GPU_HEALTH_CHECK_AVAILABLE = False
-
 # Global instances
 job_queue: Optional[JobQueue] = None
 health_monitor: Optional[HealthMonitor] = None
@@ -41,34 +36,22 @@ async def lifespan(app: FastAPI):
     """FastAPI lifespan context manager for startup/shutdown"""
     global job_queue, health_monitor
 
-    # Startup
-    logger.info("Starting job queue and health monitor...")
-
-    # Initialize job queue
-    max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
-    metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs")
-    job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
-    job_queue.start()
-    logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
-
-    # Initialize health monitor
-    health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
-    if health_check_enabled:
-        check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
-        health_monitor = HealthMonitor(check_interval_minutes=check_interval)
-        health_monitor.start()
-        logger.info(f"GPU health monitor started (interval={check_interval} minutes)")
+    # Startup - use common startup logic (without GPU check, handled in main)
+    logger.info("Initializing job queue and health monitor...")
+
+    from utils.startup import initialize_job_queue, initialize_health_monitor
+
+    job_queue = initialize_job_queue()
+    health_monitor = initialize_health_monitor()
 
     yield
 
-    # Shutdown
-    logger.info("Shutting down job queue and health monitor...")
-    if job_queue:
-        job_queue.stop(wait_for_current=True)
-        logger.info("Job queue stopped")
-    if health_monitor:
-        health_monitor.stop()
-        logger.info("Health monitor stopped")
+    # Shutdown - use common cleanup logic
+    cleanup_on_shutdown(
+        job_queue=job_queue,
+        health_monitor=health_monitor,
+        wait_for_current_job=True
+    )
 
 
 # Create FastAPI app
@@ -107,6 +90,8 @@ async def root():
         "GET /": "API information",
         "GET /health": "Health check",
         "GET /health/gpu": "GPU health check",
+        "GET /health/circuit-breaker": "Get circuit breaker stats",
+        "POST /health/circuit-breaker/reset": "Reset circuit breaker",
         "GET /models": "Get available models information",
         "POST /jobs": "Submit transcription job (async)",
         "GET /jobs/{job_id}": "Get job status",
@@ -418,6 +403,39 @@ async def gpu_health_check_endpoint():
         )
 
 
+@app.get("/health/circuit-breaker")
+async def get_circuit_breaker_status():
+    """
+    Get GPU health check circuit breaker statistics.
+
+    Returns current state, failure/success counts, and last failure time.
+    """
+    try:
+        stats = get_circuit_breaker_stats()
+        return JSONResponse(status_code=200, content=stats)
+    except Exception as e:
+        logger.error(f"Failed to get circuit breaker stats: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/health/circuit-breaker/reset")
+async def reset_circuit_breaker_endpoint():
+    """
+    Manually reset the GPU health check circuit breaker.
+
+    Useful after fixing GPU issues or for testing purposes.
+    """
+    try:
+        reset_circuit_breaker()
+        return JSONResponse(
+            status_code=200,
+            content={"message": "Circuit breaker reset successfully"}
+        )
+    except Exception as e:
+        logger.error(f"Failed to reset circuit breaker: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
@@ -425,31 +443,14 @@ async def gpu_health_check_endpoint():
 if __name__ == "__main__":
     import uvicorn
 
-    # Perform startup GPU health check with auto-reset
-    if GPU_HEALTH_CHECK_AVAILABLE:
-        try:
-            logger.info("=" * 70)
-            logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
-            logger.info("=" * 70)
-
-            status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
-
-            logger.info("=" * 70)
-            logger.info("STARTUP GPU CHECK SUCCESSFUL")
-            logger.info(f"GPU Device: {status.device_name}")
-            logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
-            logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
-            logger.info("=" * 70)
-
-        except Exception as e:
-            logger.error("=" * 70)
-            logger.error("STARTUP GPU CHECK FAILED")
-            logger.error(f"Error: {e}")
-            logger.error("This service requires GPU. Terminating.")
-            logger.error("=" * 70)
-            sys.exit(1)
-    else:
-        logger.warning("GPU health check not available, starting without GPU validation")
+    # Perform startup GPU health check
+    from utils.startup import perform_startup_gpu_check
+
+    perform_startup_gpu_check(
+        required_device="cuda",
+        auto_reset=True,
+        exit_on_failure=True
+    )
 
     # Get configuration from environment variables
     host = os.getenv("API_HOST", "0.0.0.0")
@@ -8,25 +8,21 @@ import os
 import sys
 import logging
 import json
+import base64
+import tempfile
+from pathlib import Path
 from typing import Optional
 from mcp.server.fastmcp import FastMCP
 
 from core.model_manager import get_model_info
 from core.job_queue import JobQueue, JobStatus
 from core.gpu_health import HealthMonitor, check_gpu_health
+from utils.startup import startup_sequence, cleanup_on_shutdown
 
 # Log configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Import GPU health check with reset
-try:
-    from core.gpu_health import check_gpu_health_with_reset
-    GPU_HEALTH_CHECK_AVAILABLE = True
-except ImportError as e:
-    logger.warning(f"GPU health check with reset not available: {e}")
-    GPU_HEALTH_CHECK_AVAILABLE = False
-
 # Global instances
 job_queue: Optional[JobQueue] = None
 health_monitor: Optional[HealthMonitor] = None
@@ -319,56 +315,20 @@ def check_gpu_health() -> str:
 if __name__ == "__main__":
     print("starting mcp server for whisper stt transcriptor")
 
-    # Perform startup GPU health check with auto-reset
-    if GPU_HEALTH_CHECK_AVAILABLE:
-        try:
-            logger.info("=" * 70)
-            logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
-            logger.info("=" * 70)
-
-            status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
-
-            logger.info("=" * 70)
-            logger.info("STARTUP GPU CHECK SUCCESSFUL")
-            logger.info(f"GPU Device: {status.device_name}")
-            logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
-            logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
-            logger.info("=" * 70)
-
-        except Exception as e:
-            logger.error("=" * 70)
-            logger.error("STARTUP GPU CHECK FAILED")
-            logger.error(f"Error: {e}")
-            logger.error("This service requires GPU. Terminating.")
-            logger.error("=" * 70)
-            sys.exit(1)
-    else:
-        logger.warning("GPU health check not available, starting without GPU validation")
-
-    # Initialize job queue
-    logger.info("Initializing job queue...")
-    max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
-    metadata_dir = os.getenv("JOB_METADATA_DIR", "/media/raid/agents/tools/mcp-transcriptor/outputs/jobs")
-    job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
-    job_queue.start()
-    logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
-
-    # Initialize health monitor
-    health_check_enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
-    if health_check_enabled:
-        check_interval = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
-        health_monitor = HealthMonitor(check_interval_minutes=check_interval)
-        health_monitor.start()
-        logger.info(f"GPU health monitor started (interval={check_interval} minutes)")
+    # Execute common startup sequence
+    job_queue, health_monitor = startup_sequence(
+        service_name="MCP Whisper Server",
+        require_gpu=True,
+        initialize_queue=True,
+        initialize_monitoring=True
+    )
 
     try:
         mcp.run()
     finally:
         # Cleanup on shutdown
-        logger.info("Shutting down...")
-        if job_queue:
-            job_queue.stop(wait_for_current=True)
-            logger.info("Job queue stopped")
-        if health_monitor:
-            health_monitor.stop()
-            logger.info("Health monitor stopped")
+        cleanup_on_shutdown(
+            job_queue=job_queue,
+            health_monitor=health_monitor,
+            wait_for_current_job=True
+        )
@@ -0,0 +1,6 @@
+"""
+Utility modules for Whisper transcription service.
+
+Includes audio processing, formatters, test audio generation, input validation,
+circuit breaker, and startup logic.
+"""
@@ -7,17 +7,24 @@ Responsible for audio file validation and preprocessing
 import os
 import logging
 from typing import Union, Any
+from pathlib import Path
 from faster_whisper import decode_audio
 
+from utils.input_validation import (
+    validate_audio_file as validate_audio_file_secure,
+    sanitize_error_message
+)
+
 # Log configuration
 logger = logging.getLogger(__name__)
 
-def validate_audio_file(audio_path: str) -> None:
+def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None:
     """
-    Validate if an audio file is valid
+    Validate if an audio file is valid (with security checks).
 
     Args:
         audio_path: Path to the audio file
+        allowed_dirs: Optional list of allowed base directories
 
     Raises:
         FileNotFoundError: If audio file doesn't exist
@@ -27,31 +34,19 @@ def validate_audio_file(audio_path: str) -> None:
     Returns:
         None: If validation passes
     """
-    # Validate parameters
-    if not os.path.exists(audio_path):
-        raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
-
-    # Validate file format
-    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
-    file_ext = os.path.splitext(audio_path)[1].lower()
-    if file_ext not in supported_formats:
-        raise ValueError(f"Unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}")
-
-    # Validate file size
     try:
-        file_size = os.path.getsize(audio_path)
-        if file_size == 0:
-            raise ValueError(f"Audio file is empty: {audio_path}")
-
-        # Warning for large files (over 1GB)
-        if file_size > 1024 * 1024 * 1024:
-            logger.warning(f"Warning: File size exceeds 1GB, may require longer processing time: {audio_path}")
-    except OSError as e:
-        logger.error(f"Failed to check file size: {str(e)}")
-        raise OSError(f"Failed to check file size: {str(e)}")
-
-    # Validation passed
-    return None
+        # Use secure validation
+        validate_audio_file_secure(audio_path, allowed_dirs)
+    except Exception as e:
+        # Re-raise with sanitized error messages
+        error_msg = sanitize_error_message(str(e))
+
+        if "not found" in str(e).lower():
+            raise FileNotFoundError(error_msg)
+        elif "size" in str(e).lower():
+            raise OSError(error_msg)
+        else:
+            raise ValueError(error_msg)
 
 
 def process_audio(audio_path: str) -> Union[str, Any]:
     """
src/utils/circuit_breaker.py (new file, 291 lines)
@@ -0,0 +1,291 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Circuit Breaker Pattern Implementation
|
||||||
|
|
||||||
|
Prevents repeated failed attempts and provides fail-fast behavior.
|
||||||
|
Useful for GPU health checks and other operations that may fail repeatedly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Callable, Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitState(Enum):
|
||||||
|
"""Circuit breaker states."""
|
||||||
|
CLOSED = "closed" # Normal operation, requests pass through
|
||||||
|
OPEN = "open" # Circuit is open, requests fail immediately
|
||||||
|
HALF_OPEN = "half_open" # Testing if circuit can close
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CircuitBreakerConfig:
|
||||||
|
"""Configuration for circuit breaker."""
|
||||||
|
failure_threshold: int = 3 # Failures before opening circuit
|
||||||
|
success_threshold: int = 2 # Successes before closing from half-open
|
||||||
|
timeout_seconds: int = 60 # Time before attempting half-open
|
||||||
|
half_open_max_calls: int = 1 # Max calls to test in half-open state
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreaker:
|
||||||
|
"""
|
||||||
|
Circuit breaker implementation for preventing repeated failures.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
breaker = CircuitBreaker(name="gpu_health", failure_threshold=3)
|
||||||
|
|
||||||
|
@breaker.call
|
||||||
|
def check_gpu():
|
||||||
|
# This function will be protected by circuit breaker
|
||||||
|
return perform_gpu_check()
|
||||||
|
|
||||||
|
# Or use decorator:
|
||||||
|
@breaker.decorator()
|
||||||
|
def my_function():
|
||||||
|
return "result"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
failure_threshold: int = 3,
|
||||||
|
success_threshold: int = 2,
|
||||||
|
timeout_seconds: int = 60,
|
||||||
|
half_open_max_calls: int = 1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize circuit breaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the circuit (for logging)
|
||||||
|
failure_threshold: Number of failures before opening
|
||||||
|
success_threshold: Number of successes to close from half-open
|
||||||
|
timeout_seconds: Seconds before transitioning to half-open
|
||||||
|
half_open_max_calls: Max concurrent calls in half-open state
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self.config = CircuitBreakerConfig(
|
||||||
|
failure_threshold=failure_threshold,
|
||||||
|
success_threshold=success_threshold,
|
||||||
|
timeout_seconds=timeout_seconds,
|
||||||
|
half_open_max_calls=half_open_max_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._failure_count = 0
|
||||||
|
self._success_count = 0
|
||||||
|
self._last_failure_time: Optional[datetime] = None
|
||||||
|
self._half_open_calls = 0
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Circuit breaker '{name}' initialized: "
|
||||||
|
f"failure_threshold={failure_threshold}, "
|
||||||
|
f"timeout={timeout_seconds}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> CircuitState:
|
||||||
|
"""Get current circuit state."""
|
||||||
|
with self._lock:
|
||||||
|
self._update_state()
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""Check if circuit is closed (normal operation)."""
|
||||||
|
return self.state == CircuitState.CLOSED
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
"""Check if circuit is open (failing fast)."""
|
||||||
|
return self.state == CircuitState.OPEN
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_half_open(self) -> bool:
|
||||||
|
"""Check if circuit is half-open (testing)."""
|
||||||
|
return self.state == CircuitState.HALF_OPEN
|
||||||
|
|
||||||
|
def _update_state(self):
|
||||||
|
"""Update state based on timeout and counters."""
|
||||||
|
if self._state == CircuitState.OPEN:
|
||||||
|
# Check if timeout has passed
|
||||||
|
if self._last_failure_time:
|
||||||
|
elapsed = datetime.utcnow() - self._last_failure_time
|
||||||
|
if elapsed.total_seconds() >= self.config.timeout_seconds:
|
||||||
|
logger.info(
|
||||||
|
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
|
||||||
|
f"after {elapsed.total_seconds():.0f}s timeout"
|
||||||
|
)
|
||||||
|
self._state = CircuitState.HALF_OPEN
|
||||||
|
self._half_open_calls = 0
|
||||||
|
self._success_count = 0
|
||||||
|
|
||||||
|
def _on_success(self):
|
||||||
|
"""Handle successful call."""
|
||||||
|
with self._lock:
|
||||||
|
if self._state == CircuitState.HALF_OPEN:
|
||||||
|
self._success_count += 1
|
||||||
|
logger.debug(
|
||||||
|
f"Circuit '{self.name}': Success in HALF_OPEN "
|
||||||
|
f"({self._success_count}/{self.config.success_threshold})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._success_count >= self.config.success_threshold:
|
||||||
|
logger.info(f"Circuit '{self.name}': Closing circuit after successful test")
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._failure_count = 0
|
||||||
|
self._success_count = 0
|
||||||
|
self._last_failure_time = None
|
||||||
|
|
||||||
|
elif self._state == CircuitState.CLOSED:
|
||||||
|
# Reset failure count on success
|
||||||
|
self._failure_count = 0
|
||||||
|
|
||||||
|
def _on_failure(self, error: Exception):
|
||||||
|
"""Handle failed call."""
|
||||||
|
with self._lock:
|
||||||
|
self._failure_count += 1
|
||||||
|
self._last_failure_time = datetime.utcnow()
|
||||||
|
|
||||||
|
if self._state == CircuitState.HALF_OPEN:
|
||||||
|
logger.warning(
|
||||||
|
f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit"
|
||||||
|
)
|
||||||
|
self._state = CircuitState.OPEN
|
||||||
|
self._success_count = 0
|
||||||
|
|
||||||
|
elif self._state == CircuitState.CLOSED:
|
||||||
|
logger.debug(
|
||||||
|
f"Circuit '{self.name}': Failure {self._failure_count}/"
|
||||||
|
f"{self.config.failure_threshold}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._failure_count >= self.config.failure_threshold:
|
||||||
|
logger.warning(
|
||||||
|
f"Circuit '{self.name}': Opening circuit after "
|
||||||
|
f"{self._failure_count} failures. "
|
||||||
|
f"Will retry in {self.config.timeout_seconds}s"
|
||||||
|
)
|
||||||
|
self._state = CircuitState.OPEN
|
||||||
|
self._success_count = 0
|
||||||
|
|
||||||
|
def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Execute function with circuit breaker protection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to execute
|
||||||
|
*args: Positional arguments
|
||||||
|
**kwargs: Keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CircuitBreakerOpen: If circuit is open
|
||||||
|
Exception: Original exception from func if it fails
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
self._update_state()
|
||||||
|
|
||||||
|
# Check if circuit is open
|
||||||
|
if self._state == CircuitState.OPEN:
|
||||||
|
raise CircuitBreakerOpen(
|
||||||
|
f"Circuit '{self.name}' is OPEN. "
|
||||||
|
f"Failing fast to prevent repeated failures. "
|
||||||
|
f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. "
|
||||||
|
f"Will retry in {self.config.timeout_seconds}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check half-open call limit
|
||||||
|
if self._state == CircuitState.HALF_OPEN:
|
||||||
|
if self._half_open_calls >= self.config.half_open_max_calls:
|
||||||
|
raise CircuitBreakerOpen(
|
||||||
|
f"Circuit '{self.name}' is HALF_OPEN with max calls reached. "
|
||||||
|
f"Please wait for current test to complete."
|
||||||
|
)
|
||||||
|
self._half_open_calls += 1
|
||||||
|
|
||||||
|
# Execute function
|
||||||
|
try:
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
self._on_success()
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._on_failure(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Decrement half-open counter
|
||||||
|
with self._lock:
|
||||||
|
if self._state == CircuitState.HALF_OPEN:
|
||||||
|
self._half_open_calls -= 1
|
||||||
|
|
||||||
|
def decorator(self):
|
||||||
|
"""
|
||||||
|
Decorator for protecting functions with circuit breaker.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
breaker = CircuitBreaker("my_service")
|
||||||
|
|
||||||
|
@breaker.decorator()
|
||||||
|
def my_function():
|
||||||
|
return do_something()
|
||||||
|
"""
|
||||||
|
def wrapper(func):
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
return self.call(func, *args, **kwargs)
|
||||||
|
return decorated
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Manually reset circuit breaker to closed state.
|
||||||
|
|
||||||
|
Useful for:
|
||||||
|
- Testing
|
||||||
|
- Manual intervention
|
||||||
|
- Clearing error state
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state")
|
||||||
|
self._state = CircuitState.CLOSED
|
||||||
|
self._failure_count = 0
|
||||||
|
self._success_count = 0
|
||||||
|
self._last_failure_time = None
|
||||||
|
self._half_open_calls = 0
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get circuit breaker statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with current state and counters
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
self._update_state()
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"state": self._state.value,
|
||||||
|
"failure_count": self._failure_count,
|
||||||
|
"success_count": self._success_count,
|
||||||
|
"last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
|
||||||
|
"config": {
|
||||||
|
"failure_threshold": self.config.failure_threshold,
|
||||||
|
"success_threshold": self.config.success_threshold,
|
||||||
|
"timeout_seconds": self.config.timeout_seconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerOpen(Exception):
|
||||||
|
"""Exception raised when circuit breaker is open."""
|
||||||
|
pass
|
||||||
src/utils/input_validation.py (new file, 411 lines)
@@ -0,0 +1,411 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Input Validation and Path Sanitization Module
|
||||||
|
|
||||||
|
Provides robust validation for user inputs with security protections
|
||||||
|
against path traversal, injection attacks, and other malicious inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Maximum file size (10GB)
|
||||||
|
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
# Allowed audio file extensions
|
||||||
|
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"}
|
||||||
|
|
||||||
|
# Allowed output formats
|
||||||
|
ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"}
|
||||||
|
|
||||||
|
# Model name validation (whitelist)
|
||||||
|
ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"}
|
||||||
|
|
||||||
|
# Device validation
|
||||||
|
ALLOWED_DEVICES = {"cuda", "auto", "cpu"}
|
||||||
|
|
||||||
|
# Compute type validation
|
||||||
|
ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"}
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(Exception):
|
||||||
|
"""Base exception for validation errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PathTraversalError(ValidationError):
|
||||||
|
"""Exception for path traversal attempts."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFileTypeError(ValidationError):
|
||||||
|
"""Exception for invalid file types."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FileSizeError(ValidationError):
|
||||||
|
"""Exception for file size issues."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize error messages to prevent information leakage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_msg: Original error message
|
||||||
|
sanitize_paths: Whether to sanitize file paths (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized error message
|
||||||
|
"""
|
||||||
|
if not sanitize_paths:
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
# Replace absolute paths with relative paths
|
||||||
|
# Pattern: /home/user/... or /media/... or /var/... or /tmp/...
|
||||||
|
path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)'
|
||||||
|
|
||||||
|
def replace_path(match):
|
||||||
|
full_path = match.group(1)
|
||||||
|
try:
|
||||||
|
# Try to get just the filename
|
||||||
|
basename = os.path.basename(full_path)
|
||||||
|
return f"<file:{basename}>"
|
||||||
|
except:
|
||||||
|
return "<file:redacted>"
|
||||||
|
|
||||||
|
sanitized = re.sub(path_pattern, replace_path, error_msg)
|
||||||
|
|
||||||
|
# Also sanitize user names if present
|
||||||
|
sanitized = re.sub(r'/home/([^/]+)/', '/home/<user>/', sanitized)
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
|
||||||
|
"""
|
||||||
|
Validate and sanitize a file path to prevent directory traversal attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to validate
|
||||||
|
allowed_dirs: Optional list of allowed base directories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved Path object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PathTraversalError: If path contains traversal attempts
|
||||||
|
ValidationError: If path is invalid
|
||||||
|
"""
|
||||||
|
if not file_path:
|
||||||
|
raise ValidationError("File path cannot be empty")
|
||||||
|
|
||||||
|
# Convert to Path object
|
||||||
|
try:
|
||||||
|
path = Path(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
|
||||||
|
|
||||||
|
# Check for path traversal attempts
|
||||||
|
path_str = str(path)
|
||||||
|
if ".." in path_str:
|
||||||
|
logger.warning(f"Path traversal attempt detected: {path_str}")
|
||||||
|
raise PathTraversalError("Path traversal (..) is not allowed")
|
||||||
|
|
||||||
|
# Check for null bytes
|
||||||
|
if "\x00" in path_str:
|
||||||
|
logger.warning(f"Null byte in path detected: {path_str}")
|
||||||
|
raise PathTraversalError("Null bytes in path are not allowed")
|
||||||
|
|
||||||
|
# Resolve to absolute path (but don't follow symlinks yet)
|
||||||
|
try:
|
||||||
|
resolved_path = path.resolve()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}")
|
||||||
|
|
||||||
|
# If allowed_dirs specified, ensure path is within one of them
|
||||||
|
if allowed_dirs:
|
||||||
|
allowed = False
|
||||||
|
for allowed_dir in allowed_dirs:
|
||||||
|
try:
|
||||||
|
allowed_dir_path = Path(allowed_dir).resolve()
|
||||||
|
# Check if resolved_path is under allowed_dir
|
||||||
|
resolved_path.relative_to(allowed_dir_path)
|
||||||
|
allowed = True
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
# Not relative to this allowed_dir
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
logger.warning(
|
||||||
|
f"Path outside allowed directories: {path_str}, "
|
||||||
|
f"allowed: {allowed_dirs}"
|
||||||
|
)
|
||||||
|
raise PathTraversalError(
|
||||||
|
f"Path must be within allowed directories. "
|
||||||
|
f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolved_path
|
||||||
|
|
||||||
|
|
||||||
|
def validate_audio_file(
|
||||||
|
file_path: str,
|
||||||
|
allowed_dirs: Optional[List[str]] = None,
|
||||||
|
max_size_bytes: int = MAX_FILE_SIZE_BYTES
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Validate audio file path with security checks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to audio file
|
||||||
|
allowed_dirs: Optional list of allowed base directories
|
||||||
|
max_size_bytes: Maximum allowed file size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated Path object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If validation fails
|
||||||
|
PathTraversalError: If path traversal detected
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
InvalidFileTypeError: If file type not allowed
|
||||||
|
FileSizeError: If file too large
|
||||||
|
"""
|
||||||
|
# Validate and sanitize path
|
||||||
|
validated_path = validate_path_safe(file_path, allowed_dirs)
|
||||||
|
|
||||||
|
# Check file exists
|
||||||
|
if not validated_path.exists():
|
||||||
|
raise FileNotFoundError(f"Audio file not found: {validated_path.name}")
|
||||||
|
|
||||||
|
# Check it's a file (not directory)
|
||||||
|
if not validated_path.is_file():
|
||||||
|
raise ValidationError(f"Path is not a file: {validated_path.name}")
|
||||||
|
|
||||||
|
# Check file extension
|
||||||
|
file_ext = validated_path.suffix.lower()
|
||||||
|
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
|
||||||
|
raise InvalidFileTypeError(
|
||||||
|
f"Unsupported audio format: {file_ext}. "
|
||||||
|
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
try:
|
||||||
|
file_size = validated_path.stat().st_size
|
||||||
|
except Exception as e:
|
||||||
|
raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}")
|
||||||
|
|
||||||
|
if file_size == 0:
|
||||||
|
raise FileSizeError(f"Audio file is empty: {validated_path.name}")
|
||||||
|
|
||||||
|
if file_size > max_size_bytes:
|
||||||
|
raise FileSizeError(
|
||||||
|
f"File too large: {file_size / (1024**3):.2f}GB. "
|
||||||
|
f"Maximum: {max_size_bytes / (1024**3):.2f}GB"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warn for large files (>1GB)
|
||||||
|
if file_size > 1024 * 1024 * 1024:
|
||||||
|
logger.warning(
|
||||||
|
f"Large file: {file_size / (1024**3):.2f}GB, "
|
||||||
|
f"may require extended processing time"
|
||||||
|
)
|
||||||
|
|
||||||
|
return validated_path
|
||||||
|
|
||||||
|
|
||||||
|
def validate_output_directory(
|
||||||
|
dir_path: str,
|
||||||
|
allowed_dirs: Optional[List[str]] = None,
|
||||||
|
create_if_missing: bool = True
|
||||||
|
) -> Path:
|
||||||
|
"""
|
||||||
|
Validate output directory path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dir_path: Directory path
|
||||||
|
allowed_dirs: Optional list of allowed base directories
|
||||||
|
create_if_missing: Create directory if it doesn't exist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated Path object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If validation fails
|
||||||
|
PathTraversalError: If path traversal detected
|
||||||
|
"""
|
||||||
|
# Validate and sanitize path
|
||||||
|
validated_path = validate_path_safe(dir_path, allowed_dirs)
|
||||||
|
|
||||||
|
# Create if requested and doesn't exist
|
||||||
|
if create_if_missing and not validated_path.exists():
|
||||||
|
try:
|
||||||
|
validated_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.info(f"Created output directory: {validated_path}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Cannot create output directory: {sanitize_error_message(str(e))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check it's a directory
|
||||||
|
if validated_path.exists() and not validated_path.is_dir():
|
||||||
|
raise ValidationError(f"Path exists but is not a directory: {validated_path.name}")
|
||||||
|
|
||||||
|
return validated_path
|
||||||
|
|
||||||
|
|
||||||
|
def validate_model_name(model_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate Whisper model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model name to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated model name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If model name invalid
|
||||||
|
"""
|
||||||
|
if not model_name:
|
||||||
|
raise ValidationError("Model name cannot be empty")
|
||||||
|
|
||||||
|
model_name = model_name.strip().lower()
|
||||||
|
|
||||||
|
if model_name not in ALLOWED_MODEL_NAMES:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid model name: {model_name}. "
|
||||||
|
f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
|
def validate_device(device: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate device parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: Device name to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated device name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If device invalid
|
||||||
|
"""
|
||||||
|
if not device:
|
||||||
|
raise ValidationError("Device cannot be empty")
|
||||||
|
|
||||||
|
device = device.strip().lower()
|
||||||
|
|
||||||
|
if device not in ALLOWED_DEVICES:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid device: {device}. "
|
||||||
|
f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def validate_compute_type(compute_type: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate compute type parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compute_type: Compute type to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated compute type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If compute type invalid
|
||||||
|
"""
|
||||||
|
if not compute_type:
|
||||||
|
raise ValidationError("Compute type cannot be empty")
|
||||||
|
|
||||||
|
compute_type = compute_type.strip().lower()
|
||||||
|
|
||||||
|
if compute_type not in ALLOWED_COMPUTE_TYPES:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid compute type: {compute_type}. "
|
||||||
|
f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return compute_type
|
||||||
|
|
||||||
|
|
||||||
|
def validate_output_format(output_format: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate output format parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_format: Output format to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated output format
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If output format invalid
|
||||||
|
"""
|
||||||
|
if not output_format:
|
||||||
|
raise ValidationError("Output format cannot be empty")
|
||||||
|
|
||||||
|
output_format = output_format.strip().lower()
|
||||||
|
|
||||||
|
if output_format not in ALLOWED_OUTPUT_FORMATS:
|
||||||
|
raise ValidationError(
|
||||||
|
f"Invalid output format: {output_format}. "
|
||||||
|
f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return output_format
|
||||||
|
|
||||||
|
|
||||||
|
def validate_numeric_range(
|
||||||
|
value: float,
|
||||||
|
min_value: float,
|
||||||
|
max_value: float,
|
||||||
|
param_name: str
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Validate numeric parameter is within range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to validate
|
||||||
|
min_value: Minimum allowed value
|
||||||
|
max_value: Maximum allowed value
|
||||||
|
param_name: Parameter name for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If value out of range
|
||||||
|
"""
|
||||||
|
if value < min_value or value > max_value:
|
||||||
|
raise ValidationError(
|
||||||
|
f"{param_name} must be between {min_value} and {max_value}, "
|
||||||
|
f"got {value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def validate_beam_size(beam_size: int) -> int:
|
||||||
|
"""Validate beam size parameter."""
|
||||||
|
return int(validate_numeric_range(beam_size, 1, 20, "beam_size"))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_temperature(temperature: float) -> float:
|
||||||
|
"""Validate temperature parameter."""
|
||||||
|
return validate_numeric_range(temperature, 0.0, 1.0, "temperature")
|
||||||
src/utils/startup.py (new file, 237 lines)
@@ -0,0 +1,237 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Common Startup Logic Module
|
||||||
|
|
||||||
|
Centralizes startup procedures shared between MCP and API servers,
|
||||||
|
including GPU health checks, job queue initialization, and health monitoring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Import GPU health check with reset
|
||||||
|
try:
|
||||||
|
from core.gpu_health import check_gpu_health_with_reset
|
||||||
|
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"GPU health check with reset not available: {e}")
|
||||||
|
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def perform_startup_gpu_check(
|
||||||
|
required_device: str = "cuda",
|
||||||
|
auto_reset: bool = True,
|
||||||
|
exit_on_failure: bool = True
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Perform startup GPU health check with optional auto-reset.
|
||||||
|
|
||||||
|
This function:
|
||||||
|
1. Checks if GPU health check is available
|
||||||
|
2. Runs comprehensive GPU health check
|
||||||
|
3. Attempts auto-reset if check fails and auto_reset=True
|
||||||
|
4. Optionally exits process if check fails
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required_device: Required device ("cuda", "auto")
|
||||||
|
auto_reset: Enable automatic GPU driver reset on failure
|
||||||
|
exit_on_failure: Exit process if GPU check fails
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if GPU check passed, False otherwise
|
||||||
|
|
||||||
|
Side effects:
|
||||||
|
May exit process if exit_on_failure=True and check fails
|
||||||
|
"""
|
||||||
|
if not GPU_HEALTH_CHECK_AVAILABLE:
|
||||||
|
logger.warning("GPU health check not available, starting without GPU validation")
|
||||||
|
if exit_on_failure:
|
||||||
|
logger.error("GPU health check required but not available. Exiting.")
|
||||||
|
sys.exit(1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
status = check_gpu_health_with_reset(
|
||||||
|
expected_device=required_device,
|
||||||
|
auto_reset=auto_reset
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("STARTUP GPU CHECK SUCCESSFUL")
|
||||||
|
logger.info(f"GPU Device: {status.device_name}")
|
||||||
|
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
|
||||||
|
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("=" * 70)
|
||||||
|
logger.error("STARTUP GPU CHECK FAILED")
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
|
||||||
|
if exit_on_failure:
|
||||||
|
logger.error("This service requires GPU. Terminating.")
|
||||||
|
logger.error("=" * 70)
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
logger.error("Continuing without GPU (may have reduced functionality)")
|
||||||
|
logger.error("=" * 70)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_job_queue(
|
||||||
|
max_queue_size: Optional[int] = None,
|
||||||
|
metadata_dir: Optional[str] = None
|
||||||
|
) -> 'JobQueue':
|
||||||
|
"""
|
||||||
|
Initialize job queue with environment variable configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_queue_size: Override for max queue size (uses env var if None)
|
||||||
|
metadata_dir: Override for metadata directory (uses env var if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Initialized JobQueue instance (started)
|
||||||
|
"""
|
||||||
|
from core.job_queue import JobQueue
|
||||||
|
|
||||||
|
# Get configuration from environment
|
||||||
|
if max_queue_size is None:
|
||||||
|
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
|
||||||
|
|
||||||
|
if metadata_dir is None:
|
||||||
|
metadata_dir = os.getenv(
|
||||||
|
"JOB_METADATA_DIR",
|
||||||
|
"/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Initializing job queue...")
|
||||||
|
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
|
||||||
|
job_queue.start()
|
||||||
|
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
|
||||||
|
|
||||||
|
return job_queue
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_health_monitor(
|
||||||
|
check_interval_minutes: Optional[int] = None,
|
||||||
|
enabled: Optional[bool] = None
|
||||||
|
) -> Optional['HealthMonitor']:
|
||||||
|
"""
|
||||||
|
Initialize GPU health monitor with environment variable configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
check_interval_minutes: Override for check interval (uses env var if None)
|
||||||
|
enabled: Override for enabled status (uses env var if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Initialized HealthMonitor instance (started), or None if disabled
|
||||||
|
"""
|
||||||
|
from core.gpu_health import HealthMonitor
|
||||||
|
|
||||||
|
# Get configuration from environment
|
||||||
|
if enabled is None:
|
||||||
|
enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
logger.info("GPU health monitoring disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if check_interval_minutes is None:
|
||||||
|
check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
|
||||||
|
|
||||||
|
health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes)
|
||||||
|
health_monitor.start()
|
||||||
|
logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)")
|
||||||
|
|
||||||
|
return health_monitor
|
||||||
|
|
||||||
|
|
||||||
|
def startup_sequence(
|
||||||
|
service_name: str = "whisper-transcription",
|
||||||
|
require_gpu: bool = True,
|
||||||
|
initialize_queue: bool = True,
|
||||||
|
initialize_monitoring: bool = True
|
||||||
|
) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]:
|
||||||
|
"""
|
||||||
|
Execute complete startup sequence for a Whisper transcription server.
|
||||||
|
|
||||||
|
This function performs all common startup tasks:
|
||||||
|
1. GPU health check with auto-reset
|
||||||
|
2. Job queue initialization
|
||||||
|
3. Health monitor initialization
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service_name: Name of the service (for logging)
|
||||||
|
require_gpu: Whether GPU is required (exit if not available)
|
||||||
|
initialize_queue: Whether to initialize job queue
|
||||||
|
initialize_monitoring: Whether to initialize health monitoring
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (job_queue, health_monitor) - either may be None
|
||||||
|
|
||||||
|
Side effects:
|
||||||
|
May exit process if GPU required but unavailable
|
||||||
|
"""
|
||||||
|
logger.info(f"Starting {service_name}...")
|
||||||
|
|
||||||
|
# Step 1: GPU health check
|
||||||
|
gpu_ok = perform_startup_gpu_check(
|
||||||
|
required_device="cuda",
|
||||||
|
auto_reset=True,
|
||||||
|
exit_on_failure=require_gpu
|
||||||
|
)
|
||||||
|
|
||||||
|
if not gpu_ok and require_gpu:
|
||||||
|
# Should not reach here (exit_on_failure should have exited)
|
||||||
|
logger.error("GPU check failed and GPU is required")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Step 2: Initialize job queue
|
||||||
|
job_queue = None
|
||||||
|
if initialize_queue:
|
||||||
|
job_queue = initialize_job_queue()
|
||||||
|
|
||||||
|
# Step 3: Initialize health monitor
|
||||||
|
health_monitor = None
|
||||||
|
if initialize_monitoring:
|
||||||
|
health_monitor = initialize_health_monitor()
|
||||||
|
|
||||||
|
logger.info(f"{service_name} startup sequence completed")
|
||||||
|
|
||||||
|
return job_queue, health_monitor
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_on_shutdown(
|
||||||
|
job_queue: Optional['JobQueue'] = None,
|
||||||
|
health_monitor: Optional['HealthMonitor'] = None,
|
||||||
|
wait_for_current_job: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Perform cleanup on server shutdown.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_queue: JobQueue instance to stop (if any)
|
||||||
|
health_monitor: HealthMonitor instance to stop (if any)
|
||||||
|
wait_for_current_job: Wait for current job to complete before stopping
|
||||||
|
"""
|
||||||
|
logger.info("Shutting down...")
|
||||||
|
|
||||||
|
if job_queue:
|
||||||
|
job_queue.stop(wait_for_current=wait_for_current_job)
|
||||||
|
logger.info("Job queue stopped")
|
||||||
|
|
||||||
|
if health_monitor:
|
||||||
|
health_monitor.stop()
|
||||||
|
logger.info("Health monitor stopped")
|
||||||
|
|
||||||
|
logger.info("Shutdown complete")
|
||||||
@@ -22,8 +22,8 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)

# Add src to path
# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent / "src"))
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

# Color codes for terminal output
class Colors:
@@ -435,7 +435,9 @@ class Phase2Tester:

    def _create_test_audio_file(self):
        """Get the path to the test audio file"""
        test_audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.mp3"
        # Use relative path from project root
        project_root = Path(__file__).parent.parent
        test_audio_path = str(project_root / "data" / "test.mp3")
        if not os.path.exists(test_audio_path):
            raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
        return test_audio_path
@@ -24,8 +24,8 @@ logging.basicConfig(
    datefmt='%H:%M:%S'
)

# Add src to path
# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from core.gpu_health import check_gpu_health, HealthMonitor
from core.job_queue import JobQueue, JobStatus
@@ -61,8 +61,9 @@ def test_audio_file():
    print("="*60)

    try:
        # Use the actual test audio file
        # Use the actual test audio file (relative to project root)
        audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a"
        project_root = os.path.join(os.path.dirname(__file__), '..')
        audio_path = os.path.join(project_root, "data/test.mp3")

        # Verify file exists
        assert os.path.exists(audio_path), "Audio file not found"
@@ -147,8 +148,9 @@ def test_job_queue():
    job_queue.start()
    print("✓ Job queue started")

    # Use the actual test audio file
    # Use the actual test audio file (relative to project root)
    audio_path = "/home/uad/agents/tools/mcp-transcriptor/data/test.m4a"
    project_root = os.path.join(os.path.dirname(__file__), '..')
    audio_path = os.path.join(project_root, "data/test.mp3")

    # Test job submission
    print("\nSubmitting test job...")
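The new end-to-end suite added below exercises the REST API's job lifecycle: submit a job with POST /jobs, poll GET /jobs/{job_id} until it reports completed, then fetch the transcript from GET /jobs/{job_id}/result. A minimal standalone client following the same protocol might look like this sketch; the base URL and request fields mirror the tests below, the audio path is a placeholder, and error handling is omitted:

import time
import requests

API_URL = "http://localhost:8000"

# Submit a transcription job
resp = requests.post(f"{API_URL}/jobs", json={
    "audio_path": "/path/to/audio.mp3",
    "model_name": "large-v3",
    "output_format": "txt",
})
job_id = resp.json()["job_id"]

# Poll until the job finishes (completed or failed)
while True:
    status = requests.get(f"{API_URL}/jobs/{job_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(2)

# Fetch the transcript text for a completed job
if status["status"] == "completed":
    print(requests.get(f"{API_URL}/jobs/{job_id}/result").text)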
523 tests/test_e2e_integration.py (Executable file)
@@ -0,0 +1,523 @@
#!/usr/bin/env python3
"""
Test Phase 4: End-to-End Integration Testing

Comprehensive integration tests for the async job queue system.
Tests all scenarios from the DEV_PLAN.md Phase 4 checklist.
"""

import os
import sys
import time
import json
import logging
import requests
import subprocess
import signal
from pathlib import Path
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

# Color codes for terminal output
class Colors:
    GREEN = '\033[92m'
    RED = '\033[91m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    CYAN = '\033[96m'
    END = '\033[0m'
    BOLD = '\033[1m'

def print_success(msg):
    print(f"{Colors.GREEN}✓ {msg}{Colors.END}")

def print_error(msg):
    print(f"{Colors.RED}✗ {msg}{Colors.END}")

def print_info(msg):
    print(f"{Colors.BLUE}ℹ {msg}{Colors.END}")

def print_warning(msg):
    print(f"{Colors.YELLOW}⚠ {msg}{Colors.END}")

def print_section(msg):
    print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")


class Phase4Tester:
    def __init__(self, api_url="http://localhost:8000", test_audio=None):
        self.api_url = api_url
        # Use relative path from project root if not provided
        if test_audio is None:
            project_root = Path(__file__).parent.parent
            test_audio = str(project_root / "data" / "test.mp3")
        self.test_audio = test_audio
        self.test_results = []
        self.server_process = None

        # Verify test audio exists
        if not os.path.exists(test_audio):
            raise FileNotFoundError(f"Test audio file not found: {test_audio}")

    def test(self, name, func):
        """Run a test and record result"""
        try:
            logger.info(f"Testing: {name}")
            print_info(f"Testing: {name}")
            func()
            logger.info(f"PASSED: {name}")
            print_success(f"PASSED: {name}")
            self.test_results.append((name, True, None))
            return True
        except AssertionError as e:
            logger.error(f"FAILED: {name} - {str(e)}")
            print_error(f"FAILED: {name}")
            print_error(f" Reason: {str(e)}")
            self.test_results.append((name, False, str(e)))
            return False
        except Exception as e:
            logger.error(f"ERROR: {name} - {str(e)}")
            print_error(f"ERROR: {name}")
            print_error(f" Exception: {str(e)}")
            self.test_results.append((name, False, f"Exception: {str(e)}"))
            return False

    def start_api_server(self, wait_time=5):
        """Start the API server in background"""
        print_info("Starting API server...")
        # Script is in project root, one level up from tests/
        script_path = Path(__file__).parent.parent / "run_api_server.sh"

        # Start server in background
        self.server_process = subprocess.Popen(
            [str(script_path)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            preexec_fn=os.setsid
        )

        # Wait for server to start
        time.sleep(wait_time)

        # Verify server is running
        try:
            response = requests.get(f"{self.api_url}/health", timeout=5)
            if response.status_code == 200:
                print_success("API server started successfully")
                return True
        except:
            pass

        print_error("API server failed to start")
        return False

    def stop_api_server(self):
        """Stop the API server"""
        if self.server_process:
            print_info("Stopping API server...")
            os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
            self.server_process.wait(timeout=10)
            print_success("API server stopped")

    def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2):
        """Poll job status until completed or failed"""
        start_time = time.time()
        last_status = None

        while time.time() - start_time < timeout:
            try:
                response = requests.get(f"{self.api_url}/jobs/{job_id}")
                assert response.status_code == 200, f"Failed to get job status: {response.status_code}"

                status_data = response.json()
                current_status = status_data['status']

                # Print status changes
                if current_status != last_status:
                    if status_data.get('queue_position') is not None:
                        print_info(f" Job status: {current_status}, queue position: {status_data['queue_position']}")
                    else:
                        print_info(f" Job status: {current_status}")
                    last_status = current_status

                if current_status == "completed":
                    return status_data
                elif current_status == "failed":
                    raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}")

                time.sleep(poll_interval)
            except requests.exceptions.RequestException as e:
                raise AssertionError(f"Request failed: {e}")

        raise AssertionError(f"Job did not complete within {timeout} seconds")

    # ========================================================================
    # TEST 1: Single Job Submission and Completion
    # ========================================================================
    def test_single_job_flow(self):
        """Test complete job flow: submit → poll → get result"""
        # Submit job
        print_info(" Submitting job...")
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "large-v3",
            "output_format": "txt"
        })
        assert response.status_code == 200, f"Job submission failed: {response.status_code}"

        job_data = response.json()
        assert 'job_id' in job_data, "No job_id in response"
        # Status can be 'queued' or 'running' (if queue is empty and job starts immediately)
        assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'"

        job_id = job_data['job_id']
        print_success(f" Job submitted: {job_id}")

        # Wait for completion
        print_info(" Waiting for job completion...")
        final_status = self.wait_for_job_completion(job_id)

        assert final_status['status'] == 'completed', "Job did not complete"
        assert final_status['result_path'] is not None, "No result_path in completed job"
        assert final_status['processing_time_seconds'] is not None, "No processing time"
        print_success(f" Job completed in {final_status['processing_time_seconds']:.2f}s")

        # Get result
        print_info(" Retrieving result...")
        response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
        assert response.status_code == 200, f"Failed to get result: {response.status_code}"

        result_text = response.text
        assert len(result_text) > 0, "Empty result"
        print_success(f" Got result: {len(result_text)} characters")

    # ========================================================================
    # TEST 2: Multiple Jobs in Queue (FIFO)
    # ========================================================================
    def test_multiple_jobs_fifo(self):
        """Test multiple jobs are processed in FIFO order"""
        job_ids = []

        # Submit 3 jobs
        print_info(" Submitting 3 jobs...")
        for i in range(3):
            response = requests.post(f"{self.api_url}/jobs", json={
                "audio_path": self.test_audio,
                "model_name": "tiny",  # Use tiny model for faster processing
                "output_format": "txt"
            })
            assert response.status_code == 200, f"Job {i+1} submission failed"

            job_data = response.json()
            job_ids.append(job_data['job_id'])
            print_info(f" Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}")

        # Wait for all jobs to complete
        print_info(" Waiting for all jobs to complete...")
        for i, job_id in enumerate(job_ids):
            print_info(f" Waiting for job {i+1}/{len(job_ids)}...")
            final_status = self.wait_for_job_completion(job_id, timeout=120)
            assert final_status['status'] == 'completed', f"Job {i+1} failed"

        print_success(f" All {len(job_ids)} jobs completed in FIFO order")

    # ========================================================================
    # TEST 3: GPU Health Check
    # ========================================================================
    def test_gpu_health_check(self):
        """Test GPU health check endpoint"""
        print_info(" Checking GPU health...")
        response = requests.get(f"{self.api_url}/health/gpu")
        assert response.status_code == 200, f"GPU health check failed: {response.status_code}"

        health_data = response.json()
        assert 'gpu_available' in health_data, "Missing gpu_available field"
        assert 'gpu_working' in health_data, "Missing gpu_working field"
        assert 'device_used' in health_data, "Missing device_used field"

        print_info(f" GPU Available: {health_data['gpu_available']}")
        print_info(f" GPU Working: {health_data['gpu_working']}")
        print_info(f" Device: {health_data['device_used']}")

        if health_data['gpu_available']:
            assert health_data['device_name'], "GPU available but no device_name"
            assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)"
            print_success(f" GPU is healthy: {health_data['device_name']}")
        else:
            print_warning(" GPU not available on this system")

    # ========================================================================
    # TEST 4: Invalid Audio Path
    # ========================================================================
    def test_invalid_audio_path(self):
        """Test job submission with invalid audio path"""
        print_info(" Submitting job with invalid path...")
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": "/invalid/path/does/not/exist.mp3",
            "model_name": "large-v3"
        })

        # Should return 400 Bad Request
        assert response.status_code == 400, f"Expected 400, got {response.status_code}"

        error_data = response.json()
        assert 'detail' in error_data or 'error' in error_data, "No error message in response"
        print_success(" Invalid path rejected correctly")

    # ========================================================================
    # TEST 5: Job Not Found
    # ========================================================================
    def test_job_not_found(self):
        """Test retrieving non-existent job"""
        print_info(" Requesting non-existent job...")
        fake_job_id = "00000000-0000-0000-0000-000000000000"

        response = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
        assert response.status_code == 404, f"Expected 404, got {response.status_code}"
        print_success(" Non-existent job handled correctly")

    # ========================================================================
    # TEST 6: Result Before Completion
    # ========================================================================
    def test_result_before_completion(self):
        """Test getting result for job that hasn't completed"""
        print_info(" Submitting job and trying to get result immediately...")

        # Submit job
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "large-v3"
        })
        assert response.status_code == 200
        job_id = response.json()['job_id']

        # Try to get result immediately (job is still queued/running)
        time.sleep(0.5)
        response = requests.get(f"{self.api_url}/jobs/{job_id}/result")

        # Should return 409 Conflict or similar
        assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}"
        print_success(" Result request before completion handled correctly")

        # Clean up: wait for job to complete
        self.wait_for_job_completion(job_id)

    # ========================================================================
    # TEST 7: List Jobs
    # ========================================================================
    def test_list_jobs(self):
        """Test listing jobs with filters"""
        print_info(" Testing job listing...")

        # List all jobs
        response = requests.get(f"{self.api_url}/jobs")
        assert response.status_code == 200, f"List jobs failed: {response.status_code}"

        jobs_data = response.json()
        assert 'jobs' in jobs_data, "No jobs array in response"
        assert isinstance(jobs_data['jobs'], list), "Jobs is not a list"
        print_info(f" Found {len(jobs_data['jobs'])} jobs")

        # List only completed jobs
        response = requests.get(f"{self.api_url}/jobs?status=completed")
        assert response.status_code == 200
        completed_jobs = response.json()['jobs']
        print_info(f" Found {len(completed_jobs)} completed jobs")

        # List with limit
        response = requests.get(f"{self.api_url}/jobs?limit=5")
        assert response.status_code == 200
        limited_jobs = response.json()['jobs']
        assert len(limited_jobs) <= 5, "Limit not respected"
        print_success(" Job listing works correctly")

    # ========================================================================
    # TEST 8: Server Restart with Job Persistence
    # ========================================================================
    def test_server_restart_persistence(self):
        """Test that jobs persist across server restarts"""
        print_info(" Testing job persistence across restart...")

        # Submit a job
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "tiny"
        })
        assert response.status_code == 200
        job_id = response.json()['job_id']
        print_info(f" Submitted job: {job_id}")

        # Get job count before restart
        response = requests.get(f"{self.api_url}/jobs")
        jobs_before = len(response.json()['jobs'])
        print_info(f" Jobs before restart: {jobs_before}")

        # Restart server
        print_info(" Restarting server...")
        self.stop_api_server()
        time.sleep(2)
        assert self.start_api_server(wait_time=8), "Server failed to restart"

        # Check jobs after restart
        response = requests.get(f"{self.api_url}/jobs")
        assert response.status_code == 200
        jobs_after = len(response.json()['jobs'])
        print_info(f" Jobs after restart: {jobs_after}")

        # Check our specific job is still there (this is the key test)
        response = requests.get(f"{self.api_url}/jobs/{job_id}")
        assert response.status_code == 200, "Job not found after restart"

        # Note: Total count may differ due to job retention/cleanup, but persistence works if we can find the job
        if jobs_after < jobs_before:
            print_warning(f" Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup")

        print_success(" Jobs persisted correctly across restart")

    # ========================================================================
    # TEST 9: Health Endpoint
    # ========================================================================
    def test_health_endpoint(self):
        """Test basic health endpoint"""
        print_info(" Checking health endpoint...")
        response = requests.get(f"{self.api_url}/health")
        assert response.status_code == 200, f"Health check failed: {response.status_code}"

        health_data = response.json()
        assert health_data['status'] == 'healthy', "Server not healthy"
        print_success(" Health endpoint OK")

    # ========================================================================
    # TEST 10: Models Endpoint
    # ========================================================================
    def test_models_endpoint(self):
        """Test models information endpoint"""
        print_info(" Checking models endpoint...")
        response = requests.get(f"{self.api_url}/models")
        assert response.status_code == 200, f"Models endpoint failed: {response.status_code}"

        models_data = response.json()
        assert 'available_models' in models_data, "No available_models field"
        assert 'available_devices' in models_data, "No available_devices field"
        assert len(models_data['available_models']) > 0, "No models listed"
        print_info(f" Available models: {len(models_data['available_models'])}")
        print_success(" Models endpoint OK")

    def print_summary(self):
        """Print test summary"""
        print_section("TEST SUMMARY")

        passed = sum(1 for _, result, _ in self.test_results if result)
        failed = len(self.test_results) - passed

        for name, result, error in self.test_results:
            if result:
                print_success(f"{name}")
            else:
                print_error(f"{name}")
                if error:
                    print(f" {error}")

        print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
        print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
        print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")

        return failed == 0

    def run_all_tests(self, start_server=True):
        """Run all Phase 4 integration tests"""
        print_section("PHASE 4: END-TO-END INTEGRATION TESTING")

        try:
            # Start server if requested
            if start_server:
                if not self.start_api_server():
                    print_error("Failed to start API server. Aborting tests.")
                    return False
            else:
                # Verify server is already running
                try:
                    response = requests.get(f"{self.api_url}/health", timeout=5)
                    if response.status_code != 200:
                        print_error("Server is not responding. Please start it first.")
                        return False
                    print_info("Using existing API server")
                except:
                    print_error("Cannot connect to API server. Please start it first.")
                    return False

            # Run tests
            print_section("TEST 1: Single Job Submission and Completion")
            self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow)

            print_section("TEST 2: Multiple Jobs (FIFO Order)")
            self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo)

            print_section("TEST 3: GPU Health Check")
            self.test("GPU health check endpoint", self.test_gpu_health_check)

            print_section("TEST 4: Error Handling - Invalid Path")
            self.test("Invalid audio path rejection", self.test_invalid_audio_path)

            print_section("TEST 5: Error Handling - Job Not Found")
            self.test("Non-existent job handling", self.test_job_not_found)

            print_section("TEST 6: Error Handling - Result Before Completion")
            self.test("Result request before completion", self.test_result_before_completion)

            print_section("TEST 7: Job Listing")
            self.test("List jobs with filters", self.test_list_jobs)

            print_section("TEST 8: Health Endpoint")
            self.test("Basic health endpoint", self.test_health_endpoint)

            print_section("TEST 9: Models Endpoint")
            self.test("Models information endpoint", self.test_models_endpoint)

            print_section("TEST 10: Server Restart Persistence")
            self.test("Job persistence across server restart", self.test_server_restart_persistence)

            # Print summary
            success = self.print_summary()

            return success

        finally:
            # Cleanup
            if start_server and self.server_process:
                self.stop_api_server()


def main():
    """Main test runner"""
    import argparse

    parser = argparse.ArgumentParser(description='Phase 4 Integration Tests')
    parser.add_argument('--url', default='http://localhost:8000', help='API server URL')
    # Default to None so Phase4Tester uses relative path
    parser.add_argument('--audio', default=None,
                        help='Path to test audio file (default: <project_root>/data/test.mp3)')
    parser.add_argument('--no-start-server', action='store_true',
                        help='Do not start server (assume it is already running)')
    args = parser.parse_args()

    tester = Phase4Tester(api_url=args.url, test_audio=args.audio)
    success = tester.run_all_tests(start_server=not args.no_start_server)

    sys.exit(0 if success else 1)


if __name__ == '__main__':
    main()
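The suite can also be driven from another script when the API server is already running (equivalent to passing --no-start-server on the command line); a rough sketch, assuming it is executed from the tests/ directory:

# Sketch: reuse a running API server instead of spawning one via run_api_server.sh
from test_e2e_integration import Phase4Tester

tester = Phase4Tester(api_url="http://localhost:8000")
ok = tester.run_all_tests(start_server=False)
raise SystemExit(0 if ok else 1)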