Files
Fast-Whisper-MCP-Server/tests/test_async_api_integration.py
Alihan 5fb742a312 Add circuit breaker, input validation, and refactor startup logic
- Implement circuit breaker pattern for GPU health checks
  - Prevents repeated failures with configurable thresholds
  - Three states: CLOSED, OPEN, HALF_OPEN
  - Integrated into GPU health monitoring

- Add comprehensive input validation and path sanitization
  - Path traversal attack prevention
  - Whitelist-based validation for models, devices, formats
  - Error message sanitization to prevent information leakage
  - File size limits and security checks

- Centralize startup logic across servers
  - Extract common startup procedures to utils/startup.py
  - Deduplicate GPU health checks and initialization code
  - Simplify both MCP and API server startup sequences

- Add proper Python package structure
  - Add __init__.py files to all modules
  - Improve package organization

- Add circuit breaker status API endpoints
  - GET /health/circuit-breaker - View circuit breaker stats
  - POST /health/circuit-breaker/reset - Reset circuit breaker

- Reorganize test files into tests/ directory
  - Rename and restructure test files for better organization
2025-10-10 01:03:55 +03:00

538 lines
22 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Test Phase 2: Async Job Queue Integration
Tests the async job queue system for both API and MCP servers.
Validates all new endpoints and error handling.
"""
import os
import sys
import time
import json
import logging
import requests
from pathlib import Path
# --- Logging configuration -------------------------------------------------
# Timestamped console logging so the test run can be correlated with server logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# --- Import path -----------------------------------------------------------
# Tests live in tests/; the package sources live in <repo root>/src, so add
# that directory to sys.path (one level up from this file's directory).
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
class Colors:
    """ANSI escape sequences used to colorize terminal output."""
    GREEN = '\033[92m'
    RED = '\033[91m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    END = '\033[0m'
    BOLD = '\033[1m'


def print_success(msg):
    """Print *msg* in green (a passing check)."""
    print(f"{Colors.GREEN}{msg}{Colors.END}")


def print_error(msg):
    """Print *msg* in red (a failing check or error)."""
    print(f"{Colors.RED}{msg}{Colors.END}")


def print_info(msg):
    """Print *msg* in blue, preceded by a single space, for progress notes."""
    print(f"{Colors.BLUE} {msg}{Colors.END}")


def print_section(msg):
    """Print *msg* as a bold yellow banner framed by '=' ruler lines."""
    print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
class Phase2Tester:
    """Drives the Phase 2 integration suite against a running API server.

    Each check is run through :meth:`test`, which records a
    ``(name, passed, error)`` tuple in ``test_results`` so that
    :meth:`print_summary` can report the overall outcome at the end.
    """

    def __init__(self, api_url="http://localhost:8000"):
        # Base URL of the API server under test.
        self.api_url = api_url
        # Accumulated (name, passed, error) tuples, in execution order.
        self.test_results = []

    def test(self, name, func):
        """Run a test and record result"""
        try:
            logger.info(f"Testing: {name}")
            print_info(f"Testing: {name}")
            func()
        except AssertionError as e:
            # A failed assertion inside the test body: record as FAILED.
            logger.error(f"FAILED: {name} - {str(e)}")
            print_error(f"FAILED: {name}")
            print_error(f" Reason: {str(e)}")
            self.test_results.append((name, False, str(e)))
            return False
        except Exception as e:
            # Any other exception (connection error, KeyError, ...): record as ERROR.
            logger.error(f"ERROR: {name} - {str(e)}")
            print_error(f"ERROR: {name}")
            print_error(f" Exception: {str(e)}")
            self.test_results.append((name, False, f"Exception: {str(e)}"))
            return False
        # No exception raised: the test passed.
        logger.info(f"PASSED: {name}")
        print_success(f"PASSED: {name}")
        self.test_results.append((name, True, None))
        return True

    def print_summary(self):
        """Print test summary"""
        print_section("TEST SUMMARY")
        passed = sum(1 for _, ok, _ in self.test_results if ok)
        failed = len(self.test_results) - passed
        for name, ok, error in self.test_results:
            if not ok:
                print_error(f"{name}")
                if error:
                    print(f" {error}")
            else:
                print_success(f"{name}")
        print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
        print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
        print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")
        return failed == 0

    # ========================================================================
    # API Server Tests
    # ========================================================================

    def test_api_root_endpoint(self):
        """Test GET / returns new API information"""
        logger.info(f"GET {self.api_url}/")
        resp = requests.get(f"{self.api_url}/")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        logger.info(f"Response data: {json.dumps(data, indent=2)}")
        assert data["version"] == "0.2.0", "Version should be 0.2.0"
        assert "POST /jobs" in str(data["endpoints"]), "Should have POST /jobs endpoint"
        assert "workflow" in data, "Should have workflow documentation"

    def test_api_health_endpoint(self):
        """Test GET /health still works"""
        logger.info(f"GET {self.api_url}/health")
        resp = requests.get(f"{self.api_url}/health")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        logger.info(f"Response data: {data}")
        assert data["status"] == "healthy", "Health check should return healthy"

    def test_api_models_endpoint(self):
        """Test GET /models still works"""
        logger.info(f"GET {self.api_url}/models")
        resp = requests.get(f"{self.api_url}/models")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        logger.info(f"Available models: {data.get('available_models', [])}")
        assert "available_models" in data, "Should return available models"

    def test_api_gpu_health_endpoint(self):
        """Test GET /health/gpu returns GPU status"""
        logger.info(f"GET {self.api_url}/health/gpu")
        resp = requests.get(f"{self.api_url}/health/gpu")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        logger.info(f"GPU health: {json.dumps(data, indent=2)}")
        assert "gpu_available" in data, "Should have gpu_available field"
        assert "gpu_working" in data, "Should have gpu_working field"
        assert "interpretation" in data, "Should have interpretation field"
        print_info(f" GPU Status: {data.get('interpretation', 'unknown')}")

    def test_api_submit_job_invalid_audio(self):
        """Test POST /jobs with invalid audio path returns 400"""
        payload = {
            "audio_path": "/nonexistent/file.mp3",
            "model_name": "tiny",
            "output_format": "txt",
        }
        logger.info(f"POST {self.api_url}/jobs with invalid audio path")
        logger.info(f"Payload: {json.dumps(payload, indent=2)}")
        resp = requests.post(f"{self.api_url}/jobs", json=payload)
        logger.info(f"Response status: {resp.status_code}")
        logger.info(f"Response: {resp.json()}")
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert data["detail"]["error"] == "Invalid audio file", f"Wrong error type: {data['detail']['error']}"
        print_info(f" Error message: {data['detail']['message'][:50]}...")

    def test_api_submit_job_cpu_device_rejected(self):
        """Test POST /jobs with device=cpu is rejected (400)"""
        # A real audio file is needed so validation reaches the device check.
        logger.info("Creating test audio file...")
        test_audio = self._create_test_audio_file()
        logger.info(f"Test audio created at: {test_audio}")
        payload = {
            "audio_path": test_audio,
            "model_name": "tiny",
            "device": "cpu",
            "output_format": "txt",
        }
        logger.info(f"POST {self.api_url}/jobs with device=cpu")
        logger.info(f"Payload: {json.dumps(payload, indent=2)}")
        resp = requests.post(f"{self.api_url}/jobs", json=payload)
        logger.info(f"Response status: {resp.status_code}")
        logger.info(f"Response: {resp.json()}")
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        # Accept either error phrasing from the server.
        assert "Invalid device" in data["detail"]["error"] or "CPU" in data["detail"]["message"], \
            "Should reject CPU device"

    def test_api_submit_job_success(self):
        """Test POST /jobs with valid audio returns job_id"""
        logger.info("Creating test audio file...")
        test_audio = self._create_test_audio_file()
        logger.info(f"Test audio created at: {test_audio}")
        payload = {
            "audio_path": test_audio,
            "model_name": "tiny",
            "device": "auto",
            "output_format": "txt",
        }
        logger.info(f"POST {self.api_url}/jobs with valid audio")
        logger.info(f"Payload: {json.dumps(payload, indent=2)}")
        resp = requests.post(f"{self.api_url}/jobs", json=payload)
        logger.info(f"Response status: {resp.status_code}")
        logger.info(f"Response: {json.dumps(resp.json(), indent=2)}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        assert "job_id" in data, "Should return job_id"
        assert "status" in data, "Should return status"
        assert data["status"] == "queued", f"Status should be queued, got {data['status']}"
        assert "queue_position" in data, "Should return queue_position"
        assert "message" in data, "Should return message"
        logger.info(f"Job submitted successfully: {data['job_id']}")
        print_info(f" Job ID: {data['job_id']}")
        print_info(f" Queue position: {data['queue_position']}")
        # Remember the id so the status/result tests below can reuse it.
        self.test_job_id = data["job_id"]

    def test_api_get_job_status(self):
        """Test GET /jobs/{job_id} returns job status"""
        if not hasattr(self, 'test_job_id'):
            print_info(" Skipping (no test_job_id from previous test)")
            return
        resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        assert "job_id" in data, "Should return job_id"
        assert "status" in data, "Should return status"
        assert data["status"] in ["queued", "running", "completed", "failed"], \
            f"Invalid status: {data['status']}"
        print_info(f" Status: {data['status']}")

    def test_api_get_job_status_not_found(self):
        """Test GET /jobs/{job_id} with invalid ID returns 404"""
        fake_job_id = "00000000-0000-0000-0000-000000000000"
        resp = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
        assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert data["detail"]["error"] == "Job not found", f"Wrong error: {data['detail']['error']}"

    def test_api_get_job_result_not_completed(self):
        """Test GET /jobs/{job_id}/result when job not completed returns 409"""
        if not hasattr(self, 'test_job_id'):
            print_info(" Skipping (no test_job_id from previous test)")
            return
        # The job may already have finished; only the not-completed case is testable.
        status_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
        current_status = status_resp.json()["status"]
        if current_status == "completed":
            print_info(" Skipping (job already completed)")
            return
        resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
        assert resp.status_code == 409, f"Expected 409, got {resp.status_code}"
        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert data["detail"]["error"] == "Job not completed", f"Wrong error: {data['detail']['error']}"
        assert "current_status" in data["detail"], "Should include current_status"

    def test_api_list_jobs(self):
        """Test GET /jobs returns job list"""
        resp = requests.get(f"{self.api_url}/jobs")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        assert "jobs" in data, "Should have jobs field"
        assert "total" in data, "Should have total field"
        assert isinstance(data["jobs"], list), "Jobs should be a list"
        print_info(f" Total jobs: {data['total']}")

    def test_api_list_jobs_with_filter(self):
        """Test GET /jobs?status=queued filters by status"""
        resp = requests.get(f"{self.api_url}/jobs?status=queued&limit=10")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
        data = resp.json()
        assert "jobs" in data, "Should have jobs field"
        assert "filters" in data, "Should have filters field"
        assert data["filters"]["status"] == "queued", "Filter should be applied"
        # Every job returned under the filter must actually be queued.
        for job in data["jobs"]:
            assert job["status"] == "queued", f"Job {job['job_id']} has wrong status: {job['status']}"

    def test_api_wait_for_job_completion(self):
        """Test waiting for job to complete and retrieving result"""
        if not hasattr(self, 'test_job_id'):
            logger.warning("Skipping - no test_job_id from previous test")
            print_info(" Skipping (no test_job_id from previous test)")
            return
        logger.info(f"Waiting for job {self.test_job_id} to complete (max 60s)...")
        print_info(" Waiting for job to complete (max 60s)...")
        max_wait = 60
        start_time = time.time()
        # Poll the status endpoint every 2s until completion, failure, or timeout.
        while time.time() - start_time < max_wait:
            resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
            data = resp.json()
            status = data["status"]
            elapsed = int(time.time() - start_time)
            logger.info(f"Job status: {status} (elapsed: {elapsed}s)")
            print_info(f" Status: {status} (elapsed: {elapsed}s)")
            if status == "completed":
                logger.info("Job completed successfully!")
                print_success(" Job completed!")
                # The job finished: fetch and validate the transcription result.
                logger.info("Fetching job result...")
                result_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
                logger.info(f"Result response status: {result_resp.status_code}")
                assert result_resp.status_code == 200, f"Expected 200, got {result_resp.status_code}"
                result_data = result_resp.json()
                logger.info(f"Result data keys: {result_data.keys()}")
                assert "result" in result_data, "Should have result field"
                assert len(result_data["result"]) > 0, "Result should not be empty"
                actual_text = result_data["result"].strip()
                logger.info(f"Transcription result: '{actual_text}'")
                print_info(f" Transcription: '{actual_text}'")
                return
            elif status == "failed":
                error_msg = f"Job failed: {data.get('error', 'unknown error')}"
                logger.error(error_msg)
                raise AssertionError(error_msg)
            time.sleep(2)
        error_msg = f"Job did not complete within {max_wait}s"
        logger.error(error_msg)
        raise AssertionError(error_msg)

    # ========================================================================
    # MCP Server Tests (Import-based)
    # ========================================================================

    def test_mcp_imports(self):
        """Test MCP server modules can be imported"""
        try:
            logger.info("Importing MCP server module...")
            from servers import whisper_server
            logger.info("Checking for new async tools...")
            assert hasattr(whisper_server, 'transcribe_async'), "Should have transcribe_async tool"
            assert hasattr(whisper_server, 'get_job_status'), "Should have get_job_status tool"
            assert hasattr(whisper_server, 'get_job_result'), "Should have get_job_result tool"
            assert hasattr(whisper_server, 'list_transcription_jobs'), "Should have list_transcription_jobs tool"
            assert hasattr(whisper_server, 'check_gpu_health'), "Should have check_gpu_health tool"
            assert hasattr(whisper_server, 'get_model_info_api'), "Should have get_model_info_api tool"
            logger.info("All new tools found!")
            # Verify old tools are removed
            logger.info("Verifying old tools are removed...")
            assert not hasattr(whisper_server, 'transcribe'), "Old transcribe tool should be removed"
            assert not hasattr(whisper_server, 'batch_transcribe_audio'), "Old batch_transcribe_audio tool should be removed"
            logger.info("Old tools successfully removed!")
        except ImportError as e:
            logger.error(f"Failed to import MCP server: {e}")
            raise AssertionError(f"Failed to import MCP server: {e}")

    def test_job_queue_integration(self):
        """Test JobQueue integration is working"""
        from core.job_queue import JobQueue, JobStatus
        # Use an isolated queue with its own metadata dir so the real one is untouched.
        test_queue = JobQueue(max_queue_size=5, metadata_dir="/tmp/test_job_queue")
        try:
            # Verify it can be started
            test_queue.start()
            assert test_queue._worker_thread is not None, "Worker thread should be created"
            assert test_queue._worker_thread.is_alive(), "Worker thread should be running"
        finally:
            # Clean up
            test_queue.stop(wait_for_current=False)

    def test_health_monitor_integration(self):
        """Test HealthMonitor integration is working"""
        from core.gpu_health import HealthMonitor
        # Long interval so the monitor does not re-check during the test.
        test_monitor = HealthMonitor(check_interval_minutes=60)
        try:
            # Verify it can be started
            test_monitor.start()
            assert test_monitor._thread is not None, "Monitor thread should be created"
            assert test_monitor._thread.is_alive(), "Monitor thread should be running"
            # Check we can get status
            status = test_monitor.get_latest_status()
            assert status is not None, "Should have initial status"
        finally:
            # Clean up
            test_monitor.stop()

    # ========================================================================
    # Helper Methods
    # ========================================================================

    def _create_test_audio_file(self):
        """Get the path to the test audio file"""
        # Use relative path from project root
        project_root = Path(__file__).parent.parent
        test_audio_path = str(project_root / "data" / "test.mp3")
        if not os.path.exists(test_audio_path):
            raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
        return test_audio_path
def main():
    """Entry point for the Phase 2 suite.

    Verifies the API server is reachable, then runs the API and MCP test
    groups and prints a summary. Returns a process exit code: 0 when every
    test passed, 1 on connection failure or any test failure.
    """
    print_section("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
    logger.info("=" * 70)
    logger.info("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
    logger.info("=" * 70)

    # The target server is configurable via the API_URL environment variable.
    api_url = os.getenv("API_URL", "http://localhost:8000")
    logger.info(f"Testing API server at: {api_url}")
    print_info(f"Testing API server at: {api_url}")

    # Bail out early (exit code 1) if the server is unreachable or unhealthy.
    try:
        logger.info("Checking API server health...")
        resp = requests.get(f"{api_url}/health", timeout=2)
        logger.info(f"Health check status: {resp.status_code}")
        if resp.status_code != 200:
            logger.error(f"API server not responding correctly at {api_url}")
            print_error(f"API server not responding correctly at {api_url}")
            print_error("Please start the API server with: ./run_api_server.sh")
            return 1
    except requests.exceptions.RequestException as e:
        logger.error(f"Cannot connect to API server: {e}")
        print_error(f"Cannot connect to API server at {api_url}")
        print_error("Please start the API server with: ./run_api_server.sh")
        return 1

    logger.info(f"API server is running at {api_url}")
    print_success(f"API server is running at {api_url}")

    tester = Phase2Tester(api_url=api_url)

    # --- API server tests ---------------------------------------------------
    print_section("API SERVER TESTS")
    logger.info("Starting API server tests...")
    tester.test("API Root Endpoint", tester.test_api_root_endpoint)
    tester.test("API Health Endpoint", tester.test_api_health_endpoint)
    tester.test("API Models Endpoint", tester.test_api_models_endpoint)
    tester.test("API GPU Health Endpoint", tester.test_api_gpu_health_endpoint)

    print_section("API JOB SUBMISSION TESTS")
    tester.test("Submit Job - Invalid Audio (400)", tester.test_api_submit_job_invalid_audio)
    tester.test("Submit Job - CPU Device Rejected (400)", tester.test_api_submit_job_cpu_device_rejected)
    tester.test("Submit Job - Success (200)", tester.test_api_submit_job_success)

    print_section("API JOB STATUS TESTS")
    tester.test("Get Job Status - Success", tester.test_api_get_job_status)
    tester.test("Get Job Status - Not Found (404)", tester.test_api_get_job_status_not_found)
    tester.test("Get Job Result - Not Completed (409)", tester.test_api_get_job_result_not_completed)

    print_section("API JOB LISTING TESTS")
    tester.test("List Jobs", tester.test_api_list_jobs)
    tester.test("List Jobs - With Filter", tester.test_api_list_jobs_with_filter)

    print_section("API JOB COMPLETION TEST")
    tester.test("Wait for Job Completion & Get Result", tester.test_api_wait_for_job_completion)

    # --- MCP server tests ---------------------------------------------------
    print_section("MCP SERVER TESTS")
    logger.info("Starting MCP server tests...")
    tester.test("MCP Module Imports", tester.test_mcp_imports)
    tester.test("JobQueue Integration", tester.test_job_queue_integration)
    tester.test("HealthMonitor Integration", tester.test_health_monitor_integration)

    # --- Summary ------------------------------------------------------------
    logger.info("All tests completed, generating summary...")
    success = tester.print_summary()
    if success:
        logger.info("ALL TESTS PASSED!")
        print_section("ALL TESTS PASSED! ✓")
        return 0
    logger.error("SOME TESTS FAILED!")
    print_section("SOME TESTS FAILED! ✗")
    return 1
if __name__ == "__main__":
    # Run the full integration suite; process exits 0 on success, 1 on failure.
    sys.exit(main())