Compare commits
4 Commits
7c9a8d8378
...
5fb742a312
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fb742a312 | ||
|
|
40555592e6 | ||
|
|
1292f0f09b | ||
|
|
e7a457e602 |
334
CLAUDE.md
334
CLAUDE.md
@@ -4,13 +4,21 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
## Overview
|
||||
|
||||
This is a Whisper-based speech recognition service that provides high-performance audio transcription using Faster Whisper. The service can run as either:
|
||||
This is a Whisper-based speech recognition service that provides high-performance audio transcription using Faster Whisper. The service runs as either:
|
||||
|
||||
1. **MCP Server** - For integration with Claude Desktop and other MCP clients
|
||||
2. **REST API Server** - For HTTP-based integrations
|
||||
2. **REST API Server** - For HTTP-based integrations with async job queue support
|
||||
|
||||
Both servers share the same core transcription logic and can run independently or simultaneously on different ports.
|
||||
|
||||
**Key Features:**
|
||||
- Async job queue system for long-running transcriptions (prevents HTTP timeouts)
|
||||
- GPU health monitoring with strict failure detection (prevents silent CPU fallback)
|
||||
- **Automatic GPU driver reset** on CUDA errors with cooldown protection (handles sleep/wake issues)
|
||||
- Dual-server architecture (MCP + REST API)
|
||||
- Model caching for fast repeated transcriptions
|
||||
- Automatic batch size optimization based on GPU memory
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Environment Setup
|
||||
@@ -88,46 +96,93 @@ docker run --gpus all -v /path/to/models:/models -v /path/to/outputs:/outputs wh
|
||||
|
||||
## Architecture
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
.
|
||||
├── src/ # Source code directory
|
||||
│ ├── servers/ # Server implementations
|
||||
│ │ ├── whisper_server.py # MCP server entry point
|
||||
│ │ └── api_server.py # REST API server (async job queue)
|
||||
│ ├── core/ # Core business logic
|
||||
│ │ ├── transcriber.py # Transcription logic (single & batch)
|
||||
│ │ ├── model_manager.py # Model lifecycle & caching
|
||||
│ │ ├── job_queue.py # Async job queue manager
|
||||
│ │ └── gpu_health.py # GPU health monitoring
|
||||
│ └── utils/ # Utility modules
|
||||
│ ├── audio_processor.py # Audio validation & preprocessing
|
||||
│ ├── formatters.py # Output format conversion
|
||||
│ └── test_audio_generator.py # Test audio generation for GPU checks
|
||||
├── run_mcp_server.sh # MCP server startup script
|
||||
├── run_api_server.sh # API server startup script
|
||||
├── reset_gpu.sh # GPU driver reset script
|
||||
├── DEV_PLAN.md # Development plan for async features
|
||||
├── requirements.txt # Python dependencies
|
||||
└── pyproject.toml # Project configuration
|
||||
```
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **whisper_server.py** - MCP server entry point
|
||||
- Uses FastMCP framework to expose three MCP tools
|
||||
- Delegates to transcriber.py for actual processing
|
||||
1. **src/servers/whisper_server.py** - MCP server entry point
|
||||
- Uses FastMCP framework to expose MCP tools
|
||||
- Three main tools: `get_model_info_api()`, `transcribe()`, `batch_transcribe_audio()`
|
||||
- Server initialization at line 19
|
||||
|
||||
2. **api_server.py** - REST API server entry point
|
||||
- Uses FastAPI framework to expose HTTP endpoints
|
||||
- Provides 5 REST endpoints: `/`, `/health`, `/models`, `/transcribe`, `/batch-transcribe`, `/upload-transcribe`
|
||||
- Shares the same core transcription logic with MCP server
|
||||
- Includes file upload support via multipart/form-data
|
||||
2. **src/servers/api_server.py** - REST API server entry point
|
||||
- Uses FastAPI framework for HTTP endpoints
|
||||
- Provides REST endpoints: `/`, `/health`, `/models`, `/transcribe`, `/batch-transcribe`, `/upload-transcribe`
|
||||
- Shares core transcription logic with MCP server
|
||||
- File upload support via multipart/form-data
|
||||
|
||||
3. **transcriber.py** - Core transcription logic (shared by both servers)
|
||||
- `transcribe_audio()` (line 38) - Single file transcription with environment variable support
|
||||
- `batch_transcribe()` (line 208) - Batch processing with progress reporting
|
||||
- All parameters support environment variable defaults
|
||||
- Handles output formatting delegation to formatters.py
|
||||
3. **src/core/transcriber.py** - Core transcription logic (shared by both servers)
|
||||
- `transcribe_audio()`:39 - Single file transcription with environment variable support
|
||||
- `batch_transcribe()`:209 - Batch processing with progress reporting
|
||||
- All parameters support environment variable defaults (lines 21-37)
|
||||
- Delegates output formatting to utils.formatters
|
||||
|
||||
4. **model_manager.py** - Whisper model lifecycle management
|
||||
- `get_whisper_model()` (line 44) - Returns cached model instances or loads new ones
|
||||
- `test_gpu_driver()` (line 20) - GPU validation before model loading
|
||||
4. **src/core/model_manager.py** - Whisper model lifecycle management
|
||||
- `get_whisper_model()`:44 - Returns cached model instances or loads new ones
|
||||
- `test_gpu_driver()`:20 - GPU validation before model loading
|
||||
- **CRITICAL**: GPU-only mode enforced at lines 64-90 (no CPU fallback)
|
||||
- Global `model_instances` dict caches loaded models to prevent reloading
|
||||
- Automatically determines batch size based on available GPU memory (lines 113-134)
|
||||
- Automatic batch size optimization based on GPU memory (lines 134-147)
|
||||
|
||||
5. **audio_processor.py** - Audio file validation and preprocessing
|
||||
- `validate_audio_file()` (line 15) - Checks file existence, format, and size
|
||||
- `process_audio()` (line 50) - Decodes audio using faster_whisper's decode_audio
|
||||
5. **src/core/job_queue.py** - Async job queue manager
|
||||
- `JobQueue` class manages FIFO queue with background worker thread
|
||||
- `submit_job()` - Validates audio, checks GPU health, adds to queue
|
||||
- `get_job_status()` - Returns current job status and queue position
|
||||
- `get_job_result()` - Returns transcription result for completed jobs
|
||||
- Jobs persist to disk as JSON files for crash recovery
|
||||
- Single worker thread processes jobs sequentially (prevents GPU contention)
|
||||
|
||||
6. **formatters.py** - Output format conversion
|
||||
6. **src/core/gpu_health.py** - GPU health monitoring
|
||||
- `check_gpu_health()`:39 - Real GPU test using tiny model + test audio
|
||||
- `GPUHealthStatus` dataclass contains detailed GPU metrics
|
||||
- **CRITICAL**: Raises RuntimeError if device=cuda but GPU fails (lines 99-135)
|
||||
- Prevents silent CPU fallback that would cause 10-100x slowdown
|
||||
- `HealthMonitor` class for periodic background monitoring
|
||||
|
||||
7. **src/utils/audio_processor.py** - Audio file validation and preprocessing
|
||||
- `validate_audio_file()`:15 - Checks file existence, format, and size
|
||||
- `process_audio()`:50 - Decodes audio using faster_whisper's decode_audio
|
||||
|
||||
8. **src/utils/formatters.py** - Output format conversion
|
||||
- `format_vtt()`, `format_srt()`, `format_txt()`, `format_json()` - Convert segments to various formats
|
||||
- All formatters accept segment lists from Whisper output
|
||||
|
||||
9. **src/utils/test_audio_generator.py** - Test audio generation
|
||||
- `generate_test_audio()` - Creates synthetic 1-second audio for GPU health checks
|
||||
- Uses numpy to generate sine wave, no external audio files needed
|
||||
|
||||
### Key Architecture Patterns
|
||||
|
||||
- **Dual Server Architecture**: Both MCP and REST API servers import and use the same core modules (transcriber.py, model_manager.py, audio_processor.py, formatters.py), ensuring consistent behavior
|
||||
- **Model Caching**: Models are cached in `model_instances` dictionary with key format `{model_name}_{device}_{compute_type}` (model_manager.py:84). This cache is shared if both servers run in the same process
|
||||
- **Batch Processing**: CUDA devices automatically use BatchedInferencePipeline for performance (model_manager.py:109-134)
|
||||
- **Environment Variable Configuration**: All transcription parameters support env var defaults (transcriber.py:19-36)
|
||||
- **Device Auto-Detection**: `device="auto"` automatically selects CUDA if available, otherwise CPU (model_manager.py:64-66)
|
||||
- **Dual Server Architecture**: Both MCP and REST API servers import and use the same core modules (core.transcriber, core.model_manager, utils.audio_processor, utils.formatters), ensuring consistent behavior
|
||||
- **Model Caching**: Models are cached in `model_instances` dictionary with key format `{model_name}_{device}_{compute_type}` (src/core/model_manager.py:104). This cache is shared if both servers run in the same process
|
||||
- **Batch Processing**: CUDA devices automatically use BatchedInferencePipeline for performance (src/core/model_manager.py:132-160)
|
||||
- **Environment Variable Configuration**: All transcription parameters support env var defaults (src/core/transcriber.py:21-37)
|
||||
- **GPU-Only Mode**: Service is configured for GPU-only operation. `device="auto"` requires CUDA, `device="cpu"` is rejected (src/core/model_manager.py:64-90)
|
||||
- **Async Job Queue**: Long-running transcriptions use async queue pattern to prevent HTTP timeouts. Jobs return immediately with job_id for polling
|
||||
- **GPU Health Monitoring**: Real GPU tests with tiny model prevent silent CPU fallback. Jobs are rejected immediately if GPU fails rather than running 10-100x slower on CPU
|
||||
|
||||
## Environment Variables
|
||||
|
||||
@@ -137,6 +192,22 @@ All configuration can be set via environment variables in run_mcp_server.sh and
|
||||
- `API_HOST` - API server host (default: 0.0.0.0)
|
||||
- `API_PORT` - API server port (default: 8000)
|
||||
|
||||
**Job Queue Configuration (if using async features):**
|
||||
- `JOB_QUEUE_MAX_SIZE` - Maximum queue size (default: 100)
|
||||
- `JOB_METADATA_DIR` - Directory for job metadata JSON files
|
||||
- `JOB_RETENTION_DAYS` - Auto-cleanup old jobs (0=disabled)
|
||||
|
||||
**GPU Health Monitoring:**
|
||||
- `GPU_HEALTH_CHECK_ENABLED` - Enable periodic GPU monitoring (true/false)
|
||||
- `GPU_HEALTH_CHECK_INTERVAL_MINUTES` - Monitoring interval (default: 10)
|
||||
- `GPU_HEALTH_TEST_MODEL` - Model for health checks (default: tiny)
|
||||
|
||||
**GPU Auto-Reset Configuration:**
|
||||
- `GPU_RESET_COOLDOWN_MINUTES` - Minimum time between GPU reset attempts (default: 5 minutes)
|
||||
- Prevents reset loops while allowing recovery from sleep/wake cycles
|
||||
- Auto-reset is **enabled by default**
|
||||
- Service terminates if GPU unavailable after reset attempt
|
||||
|
||||
**Transcription Configuration (shared by both servers):**
|
||||
|
||||
- `CUDA_VISIBLE_DEVICES` - GPU device selection
|
||||
@@ -144,7 +215,7 @@ All configuration can be set via environment variables in run_mcp_server.sh and
|
||||
- `TRANSCRIPTION_OUTPUT_DIR` - Default output directory for single transcriptions
|
||||
- `TRANSCRIPTION_BATCH_OUTPUT_DIR` - Default output directory for batch processing
|
||||
- `TRANSCRIPTION_MODEL` - Model size (tiny, base, small, medium, large-v1, large-v2, large-v3)
|
||||
- `TRANSCRIPTION_DEVICE` - Execution device (cpu, cuda, auto)
|
||||
- `TRANSCRIPTION_DEVICE` - Execution device (cuda, auto) - **NOTE: cpu is rejected in GPU-only mode**
|
||||
- `TRANSCRIPTION_COMPUTE_TYPE` - Computation type (float16, int8, auto)
|
||||
- `TRANSCRIPTION_OUTPUT_FORMAT` - Output format (vtt, srt, txt, json)
|
||||
- `TRANSCRIPTION_BEAM_SIZE` - Beam search size (default: 5)
|
||||
@@ -242,7 +313,7 @@ Upload an audio file and transcribe it immediately. Returns the transcription fi
|
||||
# Get model information
|
||||
curl http://localhost:8000/models
|
||||
|
||||
# Transcribe existing file
|
||||
# Transcribe existing file (synchronous)
|
||||
curl -X POST http://localhost:8000/transcribe \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"audio_path": "/path/to/audio.mp3", "output_format": "txt"}'
|
||||
@@ -252,14 +323,207 @@ curl -X POST http://localhost:8000/upload-transcribe \
|
||||
-F "file=@audio.mp3" \
|
||||
-F "output_format=txt" \
|
||||
-F "model_name=large-v3"
|
||||
|
||||
# Async job queue (if enabled)
|
||||
# Submit job
|
||||
curl -X POST http://localhost:8000/jobs \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"audio_path": "/path/to/audio.mp3"}'
|
||||
# Returns: {"job_id": "abc-123", "status": "queued", "queue_position": 1}
|
||||
|
||||
# Check status
|
||||
curl http://localhost:8000/jobs/abc-123
|
||||
# Returns: {"status": "running", ...}
|
||||
|
||||
# Get result (when completed)
|
||||
curl http://localhost:8000/jobs/abc-123/result
|
||||
# Returns: transcription text
|
||||
|
||||
# Check GPU health
|
||||
curl http://localhost:8000/health/gpu
|
||||
# Returns: {"gpu_available": true, "gpu_working": true, ...}
|
||||
```
|
||||
|
||||
## GPU Auto-Reset Configuration
|
||||
|
||||
### Overview
|
||||
This service features automatic GPU driver reset on CUDA errors, which is especially useful for recovering from sleep/wake cycles. The reset functionality is **enabled by default** and includes cooldown protection to prevent reset loops.
|
||||
|
||||
### Passwordless Sudo Setup (Required)
|
||||
|
||||
For automatic GPU reset to work, you must configure passwordless sudo for NVIDIA commands. Create a sudoers configuration file:
|
||||
|
||||
```bash
|
||||
sudo visudo -f /etc/sudoers.d/whisper-gpu-reset
|
||||
```
|
||||
|
||||
Add the following (replace `your_username` with your actual username):
|
||||
|
||||
```
|
||||
# Whisper GPU Auto-Reset Permissions
|
||||
your_username ALL=(ALL) NOPASSWD: /bin/systemctl stop nvidia-persistenced
|
||||
your_username ALL=(ALL) NOPASSWD: /bin/systemctl start nvidia-persistenced
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_uvm
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_drm
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_modeset
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_modeset
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_uvm
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_drm
|
||||
```
|
||||
|
||||
**Security Note:** These permissions are limited to specific NVIDIA driver commands only. The reset script (`reset_gpu.sh`) is executed with sudo but is part of the codebase and can be audited.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Startup Check**: When the service starts, it performs a GPU health check
|
||||
- If CUDA errors detected → automatic reset attempt → retry
|
||||
- If retry fails → service terminates
|
||||
|
||||
2. **Runtime Check**: Before job submission and model loading
|
||||
- If CUDA errors detected → automatic reset attempt → retry
|
||||
- If retry fails → job rejected, service continues
|
||||
|
||||
3. **Cooldown Protection**: Prevents reset loops
|
||||
- Minimum 5 minutes between reset attempts (configurable via `GPU_RESET_COOLDOWN_MINUTES`)
|
||||
- Cooldown persists across restarts (stored in `/tmp/whisper-gpu-last-reset`)
|
||||
- If reset needed but cooldown active → service/job fails immediately
|
||||
|
||||
### Manual GPU Reset
|
||||
|
||||
You can manually reset the GPU anytime:
|
||||
|
||||
```bash
|
||||
./reset_gpu.sh
|
||||
```
|
||||
|
||||
Or clear the cooldown to allow immediate reset:
|
||||
|
||||
```python
|
||||
from core.gpu_reset import clear_reset_cooldown
|
||||
clear_reset_cooldown()
|
||||
```
|
||||
|
||||
### Behavior Examples
|
||||
|
||||
**After sleep/wake with GPU issue:**
|
||||
```
|
||||
Service starts → GPU check fails (CUDA error)
|
||||
→ Cooldown OK → Reset drivers → Wait 3s → Retry
|
||||
→ Success → Service continues
|
||||
```
|
||||
|
||||
**Multiple failures (hardware issue):**
|
||||
```
|
||||
First failure → Reset → Retry fails → Job fails
|
||||
Second failure within 5 min → Cooldown active → Fail immediately
|
||||
(Prevents reset loop)
|
||||
```
|
||||
|
||||
**Normal operation:**
|
||||
```
|
||||
No CUDA errors → No resets → Normal performance
|
||||
Reset only happens on actual CUDA failures
|
||||
```
|
||||
|
||||
## Important Implementation Details
|
||||
|
||||
- GPU memory is checked before loading models (model_manager.py:115-127)
|
||||
### GPU-Only Architecture
|
||||
- **CRITICAL**: Service enforces GPU-only mode. CPU device is explicitly rejected (src/core/model_manager.py:84-90)
|
||||
- `device="auto"` requires CUDA to be available, raises RuntimeError if not (src/core/model_manager.py:64-73)
|
||||
- GPU health checks use real model loading + transcription, not just torch.cuda.is_available()
|
||||
- If GPU health check fails, jobs are rejected immediately rather than silently falling back to CPU
|
||||
- **GPU Auto-Reset**: Automatic driver reset on CUDA errors with 5-minute cooldown (handles sleep/wake issues)
|
||||
|
||||
### Model Management
|
||||
- GPU memory is checked before loading models (src/core/model_manager.py:115-127)
|
||||
- Batch size dynamically adjusts: 32 (>16GB), 16 (>12GB), 8 (>8GB), 4 (>4GB), 2 (otherwise)
|
||||
- VAD (Voice Activity Detection) is enabled by default for better long-audio accuracy (transcriber.py:101)
|
||||
- Word timestamps are enabled by default (transcriber.py:106)
|
||||
- Model loading includes GPU driver test to fail fast if GPU is unavailable (model_manager.py:92)
|
||||
- Files over 1GB generate warnings about processing time (audio_processor.py:42)
|
||||
- Models are cached globally in `model_instances` dict, shared across requests
|
||||
- Model loading includes GPU driver test to fail fast if GPU is unavailable (src/core/model_manager.py:112-114)
|
||||
|
||||
### Transcription Settings
|
||||
- VAD (Voice Activity Detection) is enabled by default for better long-audio accuracy (src/core/transcriber.py:102)
|
||||
- Word timestamps are enabled by default (src/core/transcriber.py:107)
|
||||
- Files over 1GB generate warnings about processing time (src/utils/audio_processor.py:42)
|
||||
- Default output format is "txt" for REST API, configured via environment variables for MCP server
|
||||
|
||||
### Async Job Queue (if enabled)
|
||||
- Single worker thread processes jobs sequentially (prevents GPU memory contention)
|
||||
- Jobs persist to disk as JSON files in JOB_METADATA_DIR
|
||||
- Queue has max size limit (default 100), returns 503 when full
|
||||
- Job status polling recommended every 5-10 seconds for LLM agents
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Testing GPU Health
|
||||
```python
|
||||
# Test GPU health check manually
|
||||
from src.core.gpu_health import check_gpu_health
|
||||
|
||||
status = check_gpu_health(expected_device="cuda")
|
||||
print(f"GPU Working: {status.gpu_working}")
|
||||
print(f"Device: {status.device_used}")
|
||||
print(f"Test Duration: {status.test_duration_seconds}s")
|
||||
# Expected: <1s for GPU, 3-10s for CPU
|
||||
```
|
||||
|
||||
### Testing Job Queue
|
||||
```python
|
||||
# Test job queue manually
|
||||
from src.core.job_queue import JobQueue
|
||||
|
||||
queue = JobQueue(max_queue_size=100, metadata_dir="/tmp/jobs")
|
||||
queue.start()
|
||||
|
||||
# Submit job
|
||||
job_info = queue.submit_job(
|
||||
audio_path="/path/to/test.mp3",
|
||||
model_name="large-v3",
|
||||
device="cuda"
|
||||
)
|
||||
print(f"Job ID: {job_info['job_id']}")
|
||||
|
||||
# Poll status
|
||||
status = queue.get_job_status(job_info['job_id'])
|
||||
print(f"Status: {status['status']}")
|
||||
|
||||
# Get result when completed
|
||||
result = queue.get_job_result(job_info['job_id'])
|
||||
```
|
||||
|
||||
### Common Debugging
|
||||
|
||||
**Model loading issues:**
|
||||
- Check `WHISPER_MODEL_DIR` is set correctly
|
||||
- Verify GPU memory with `nvidia-smi`
|
||||
- Check logs for GPU driver test failures at model_manager.py:112-114
|
||||
|
||||
**GPU not detected:**
|
||||
- Verify `CUDA_VISIBLE_DEVICES` is set correctly
|
||||
- Check `torch.cuda.is_available()` returns True
|
||||
- Run GPU health check to see detailed error
|
||||
|
||||
**Silent failures:**
|
||||
- Check that service is NOT silently falling back to CPU
|
||||
- GPU health check should RAISE errors, not log warnings
|
||||
- If device=cuda fails, the job should be rejected, not processed on CPU
|
||||
|
||||
**Job queue issues:**
|
||||
- Check `JOB_METADATA_DIR` exists and is writable
|
||||
- Verify background worker thread is running (check logs)
|
||||
- Job metadata files are in {JOB_METADATA_DIR}/{job_id}.json
|
||||
|
||||
### File Locations
|
||||
|
||||
- **Logs**: `mcp.logs` (MCP server), `api.logs` (API server)
|
||||
- **Models**: `$WHISPER_MODEL_DIR` or HuggingFace cache
|
||||
- **Outputs**: `$TRANSCRIPTION_OUTPUT_DIR` or `$TRANSCRIPTION_BATCH_OUTPUT_DIR`
|
||||
- **Job Metadata**: `$JOB_METADATA_DIR/{job_id}.json`
|
||||
|
||||
### Important Development Notes
|
||||
|
||||
- See `DEV_PLAN.md` for detailed architecture and implementation plan for async job queue features
|
||||
- The service is designed for GPU-only operation - CPU fallback is intentionally disabled to prevent silent performance degradation
|
||||
- When modifying model_manager.py, maintain the strict GPU-only enforcement
|
||||
- When adding new endpoints, follow the async pattern if transcription time >30 seconds
|
||||
|
||||
11
Dockerfile
11
Dockerfile
@@ -25,7 +25,7 @@ RUN python -m pip install --upgrade pip
|
||||
WORKDIR /app
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY fast-whisper-mcp-server/requirements.txt .
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies with CUDA support
|
||||
RUN pip install --no-cache-dir \
|
||||
@@ -35,11 +35,16 @@ RUN pip install --no-cache-dir \
|
||||
mcp[cli]
|
||||
|
||||
# Copy application code
|
||||
COPY fast-whisper-mcp-server/ .
|
||||
COPY src/ ./src/
|
||||
COPY pyproject.toml .
|
||||
COPY README.md .
|
||||
|
||||
# Create directories for models and outputs
|
||||
RUN mkdir -p /models /outputs
|
||||
|
||||
# Set Python path
|
||||
ENV PYTHONPATH=/app/src
|
||||
|
||||
# Set environment variables for GPU
|
||||
ENV WHISPER_MODEL_DIR=/models
|
||||
ENV TRANSCRIPTION_OUTPUT_DIR=/outputs
|
||||
@@ -48,4 +53,4 @@ ENV TRANSCRIPTION_DEVICE=cuda
|
||||
ENV TRANSCRIPTION_COMPUTE_TYPE=float16
|
||||
|
||||
# Run the server
|
||||
CMD ["python", "whisper_server.py"]
|
||||
CMD ["python", "src/servers/whisper_server.py"]
|
||||
163
README.md
163
README.md
@@ -1,163 +0,0 @@
|
||||
# Whisper Speech Recognition MCP Server
|
||||
---
|
||||
[中文文档](README-CN.md)
|
||||
---
|
||||
A high-performance speech recognition MCP server based on Faster Whisper, providing efficient audio transcription capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- Integrated with Faster Whisper for efficient speech recognition
|
||||
- Batch processing acceleration for improved transcription speed
|
||||
- Automatic CUDA acceleration (if available)
|
||||
- Support for multiple model sizes (tiny to large-v3)
|
||||
- Output formats include VTT subtitles, SRT, and JSON
|
||||
- Support for batch transcription of audio files in a folder
|
||||
- Model instance caching to avoid repeated loading
|
||||
- Dynamic batch size adjustment based on GPU memory
|
||||
|
||||
## Installation
|
||||
|
||||
### Dependencies
|
||||
|
||||
- Python 3.10+
|
||||
- faster-whisper>=0.9.0
|
||||
- torch==2.6.0+cu126
|
||||
- torchaudio==2.6.0+cu126
|
||||
- mcp[cli]>=1.2.0
|
||||
|
||||
### Installation Steps
|
||||
|
||||
1. Clone or download this repository
|
||||
2. Create and activate a virtual environment (recommended)
|
||||
3. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### PyTorch Installation Guide
|
||||
|
||||
Install the appropriate version of PyTorch based on your CUDA version:
|
||||
|
||||
- CUDA 12.6:
|
||||
```bash
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
```
|
||||
|
||||
- CUDA 12.1:
|
||||
```bash
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
- CPU version:
|
||||
```bash
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
You can check your CUDA version with `nvcc --version` or `nvidia-smi`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Starting the Server
|
||||
|
||||
On Windows, simply run `start_server.bat`.
|
||||
|
||||
On other platforms, run:
|
||||
|
||||
```bash
|
||||
python whisper_server.py
|
||||
```
|
||||
|
||||
### Configuring Claude Desktop
|
||||
|
||||
1. Open the Claude Desktop configuration file:
|
||||
- Windows: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
- macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
|
||||
2. Add the Whisper server configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"whisper": {
|
||||
"command": "python",
|
||||
"args": ["D:/path/to/whisper_server.py"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. Restart Claude Desktop
|
||||
|
||||
### Available Tools
|
||||
|
||||
The server provides the following tools:
|
||||
|
||||
1. **get_model_info** - Get information about available Whisper models
|
||||
2. **transcribe** - Transcribe a single audio file
|
||||
3. **batch_transcribe** - Batch transcribe audio files in a folder
|
||||
|
||||
## Performance Optimization Tips
|
||||
|
||||
- Using CUDA acceleration significantly improves transcription speed
|
||||
- Batch processing mode is more efficient for large numbers of short audio files
|
||||
- Batch size is automatically adjusted based on GPU memory size
|
||||
- Using VAD (Voice Activity Detection) filtering improves accuracy for long audio
|
||||
- Specifying the correct language can improve transcription quality
|
||||
|
||||
## Local Testing Methods
|
||||
|
||||
1. Use MCP Inspector for quick testing:
|
||||
|
||||
```bash
|
||||
mcp dev whisper_server.py
|
||||
```
|
||||
|
||||
2. Use Claude Desktop for integration testing
|
||||
|
||||
3. Use command line direct invocation (requires mcp[cli]):
|
||||
|
||||
```bash
|
||||
mcp run whisper_server.py
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The server implements the following error handling mechanisms:
|
||||
|
||||
- Audio file existence check
|
||||
- Model loading failure handling
|
||||
- Transcription process exception catching
|
||||
- GPU memory management
|
||||
- Batch processing parameter adaptive adjustment
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `whisper_server.py`: Main server code
|
||||
- `model_manager.py`: Whisper model loading and caching
|
||||
- `audio_processor.py`: Audio file validation and preprocessing
|
||||
- `formatters.py`: Output formatting (VTT, SRT, JSON)
|
||||
- `transcriber.py`: Core transcription logic
|
||||
- `start_server.bat`: Windows startup script
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
This project was developed with the assistance of these amazing AI tools and models:
|
||||
|
||||
- [GitHub Copilot](https://github.com/features/copilot) - AI pair programmer
|
||||
- [Trae](https://trae.ai/) - Agentic AI coding assistant
|
||||
- [Cline](https://cline.ai/) - AI-powered terminal
|
||||
- [DeepSeek](https://www.deepseek.com/) - Advanced AI model
|
||||
- [Claude-3.7-Sonnet](https://www.anthropic.com/claude) - Anthropic's powerful AI assistant
|
||||
- [Gemini-2.0-Flash](https://ai.google/gemini/) - Google's multimodal AI model
|
||||
- [VS Code](https://code.visualstudio.com/) - Powerful code editor
|
||||
- [Whisper](https://github.com/openai/whisper) - OpenAI's speech recognition model
|
||||
- [Faster Whisper](https://github.com/guillaumekln/faster-whisper) - Optimized Whisper implementation
|
||||
|
||||
Special thanks to these incredible tools and the teams behind them.
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
语音识别MCP服务模块
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
85
api.logs
Normal file
85
api.logs
Normal file
@@ -0,0 +1,85 @@
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:PERFORMING STARTUP GPU HEALTH CHECK
|
||||
INFO:__main__:======================================================================
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 1.04s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:STARTUP GPU CHECK SUCCESSFUL
|
||||
INFO:__main__:GPU Device: NVIDIA GeForce RTX 3060
|
||||
INFO:__main__:Memory Available: 11.66 GB
|
||||
INFO:__main__:Test Duration: 1.04s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:Starting Whisper REST API server on 0.0.0.0:8000
|
||||
INFO: Started server process [69821]
|
||||
INFO: Waiting for application startup.
|
||||
INFO:__main__:Starting job queue and health monitor...
|
||||
INFO:core.job_queue:Starting job queue (max size: 100)
|
||||
INFO:core.job_queue:Loading jobs from /media/raid/agents/tools/mcp-transcriptor/outputs/jobs
|
||||
INFO:core.job_queue:Loaded 8 jobs from disk
|
||||
INFO:core.job_queue:Job queue worker loop started
|
||||
INFO:core.job_queue:Job queue worker started
|
||||
INFO:__main__:Job queue started (max_size=100, metadata_dir=/media/raid/agents/tools/mcp-transcriptor/outputs/jobs)
|
||||
INFO:core.gpu_health:Starting GPU health monitor (interval: 10.0 minutes)
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.37s
|
||||
INFO:__main__:GPU health monitor started (interval=10 minutes)
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
|
||||
INFO: 127.0.0.1:48092 - "GET /jobs HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:60874 - "GET /jobs?status=completed&limit=3 HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:60876 - "GET /jobs?status=failed&limit=10 HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Running GPU health check before job submission
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
|
||||
INFO:core.job_queue:GPU health check passed
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 submitted: /tmp/whisper_test_voice_1s.mp3 (queue position: 1)
|
||||
INFO: 127.0.0.1:58376 - "POST /jobs HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 started processing
|
||||
INFO:core.model_manager:Running GPU health check with auto-reset before model loading
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.54s
|
||||
INFO:core.model_manager:Loading Whisper model: tiny device: cuda compute type: float16
|
||||
INFO:core.model_manager:Available GPU memory: 12.52 GB
|
||||
INFO:core.model_manager:Enabling batch processing acceleration, batch size: 16
|
||||
INFO:core.transcriber:Starting transcription of file: whisper_test_voice_1s.mp3
|
||||
INFO:utils.audio_processor:Successfully preprocessed audio: whisper_test_voice_1s.mp3
|
||||
INFO:core.transcriber:Using batch acceleration for transcription...
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:VAD filter removed 00:00.000 of audio
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.transcriber:Transcription completed, time used: 0.16 seconds, detected language: en, audio length: 1.51 seconds
|
||||
INFO:core.transcriber:Transcription results saved to: /media/raid/agents/tools/mcp-transcriptor/outputs/whisper_test_voice_1s.txt
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 completed successfully: /media/raid/agents/tools/mcp-transcriptor/outputs/whisper_test_voice_1s.txt
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 finished: status=completed, duration=1.1s
|
||||
INFO: 127.0.0.1:41646 - "GET /jobs/6be8e49a-bdc1-4508-af99-280bef033cb0 HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:34046 - "GET /jobs/6be8e49a-bdc1-4508-af99-280bef033cb0/result HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Running GPU health check before job submission
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
|
||||
INFO:core.job_queue:GPU health check passed
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a submitted: /home/uad/agents/tools/mcp-transcriptor/data/test.mp3 (queue position: 1)
|
||||
INFO: 127.0.0.1:44576 - "POST /jobs HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a started processing
|
||||
INFO:core.model_manager:Running GPU health check with auto-reset before model loading
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
|
||||
INFO:core.model_manager:Loading Whisper model: large-v3 device: cuda compute type: float16
|
||||
INFO:core.model_manager:Available GPU memory: 12.52 GB
|
||||
INFO:core.model_manager:Enabling batch processing acceleration, batch size: 16
|
||||
INFO:core.transcriber:Starting transcription of file: test.mp3
|
||||
INFO:utils.audio_processor:Successfully preprocessed audio: test.mp3
|
||||
INFO:core.transcriber:Using batch acceleration for transcription...
|
||||
INFO:faster_whisper:Processing audio with duration 00:06.955
|
||||
INFO:faster_whisper:VAD filter removed 00:00.299 of audio
|
||||
INFO:core.transcriber:Transcription completed, time used: 0.52 seconds, detected language: en, audio length: 6.95 seconds
|
||||
INFO:core.transcriber:Transcription results saved to: /media/raid/agents/tools/mcp-transcriptor/outputs/test.txt
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a completed successfully: /media/raid/agents/tools/mcp-transcriptor/outputs/test.txt
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a finished: status=completed, duration=23.3s
|
||||
INFO: 127.0.0.1:59120 - "GET /jobs/41ce74c0-8929-457b-96b3-1b8e4a720a7a HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:53806 - "GET /jobs/41ce74c0-8929-457b-96b3-1b8e4a720a7a/result HTTP/1.1" 200 OK
|
||||
286
api_server.py
286
api_server.py
@@ -1,286 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FastAPI REST API Server for Whisper Transcription
|
||||
Provides HTTP REST endpoints for audio transcription
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
|
||||
from model_manager import get_model_info
|
||||
from transcriber import transcribe_audio, batch_transcribe
|
||||
|
||||
# Logging configuration
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Whisper Speech Recognition API",
|
||||
description="High-performance audio transcription API based on Faster Whisper",
|
||||
version="0.1.1"
|
||||
)
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class TranscribeRequest(BaseModel):
|
||||
audio_path: str = Field(..., description="Path to the audio file on the server")
|
||||
model_name: str = Field("large-v3", description="Whisper model name")
|
||||
device: str = Field("auto", description="Execution device (cpu, cuda, auto)")
|
||||
compute_type: str = Field("auto", description="Computation type (float16, int8, auto)")
|
||||
language: Optional[str] = Field(None, description="Language code (zh, en, ja, etc.)")
|
||||
output_format: str = Field("txt", description="Output format (vtt, srt, json, txt)")
|
||||
beam_size: int = Field(5, description="Beam search size")
|
||||
temperature: float = Field(0.0, description="Sampling temperature")
|
||||
initial_prompt: Optional[str] = Field(None, description="Initial prompt text")
|
||||
output_directory: Optional[str] = Field(None, description="Output directory path")
|
||||
|
||||
|
||||
class BatchTranscribeRequest(BaseModel):
|
||||
audio_folder: str = Field(..., description="Path to folder containing audio files")
|
||||
output_folder: Optional[str] = Field(None, description="Output folder path")
|
||||
model_name: str = Field("large-v3", description="Whisper model name")
|
||||
device: str = Field("auto", description="Execution device (cpu, cuda, auto)")
|
||||
compute_type: str = Field("auto", description="Computation type (float16, int8, auto)")
|
||||
language: Optional[str] = Field(None, description="Language code (zh, en, ja, etc.)")
|
||||
output_format: str = Field("txt", description="Output format (vtt, srt, json, txt)")
|
||||
beam_size: int = Field(5, description="Beam search size")
|
||||
temperature: float = Field(0.0, description="Sampling temperature")
|
||||
initial_prompt: Optional[str] = Field(None, description="Initial prompt text")
|
||||
parallel_files: int = Field(1, description="Number of files to process in parallel")
|
||||
|
||||
|
||||
class TranscribeResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
output_path: Optional[str] = None
|
||||
|
||||
|
||||
class BatchTranscribeResponse(BaseModel):
|
||||
success: bool
|
||||
summary: str
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information"""
|
||||
return {
|
||||
"name": "Whisper Speech Recognition API",
|
||||
"version": "0.1.1",
|
||||
"endpoints": {
|
||||
"GET /health": "Health check",
|
||||
"GET /models": "Get available models information",
|
||||
"POST /transcribe": "Transcribe a single audio file",
|
||||
"POST /batch-transcribe": "Batch transcribe audio files",
|
||||
"POST /upload-transcribe": "Upload and transcribe audio file"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "service": "whisper-transcription"}
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
async def get_models():
|
||||
"""Get available Whisper models and configuration information"""
|
||||
try:
|
||||
model_info = get_model_info()
|
||||
return JSONResponse(content=json.loads(model_info))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model info: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/transcribe", response_model=TranscribeResponse)
|
||||
async def transcribe(request: TranscribeRequest):
|
||||
"""
|
||||
Transcribe a single audio file
|
||||
|
||||
The audio file must already exist on the server at the specified path.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Received transcription request for: {request.audio_path}")
|
||||
|
||||
result = transcribe_audio(
|
||||
audio_path=request.audio_path,
|
||||
model_name=request.model_name,
|
||||
device=request.device,
|
||||
compute_type=request.compute_type,
|
||||
language=request.language,
|
||||
output_format=request.output_format,
|
||||
beam_size=request.beam_size,
|
||||
temperature=request.temperature,
|
||||
initial_prompt=request.initial_prompt,
|
||||
output_directory=request.output_directory
|
||||
)
|
||||
|
||||
# Parse result to determine success
|
||||
if result.startswith("Error") or "failed" in result.lower():
|
||||
return TranscribeResponse(
|
||||
success=False,
|
||||
message=result,
|
||||
output_path=None
|
||||
)
|
||||
|
||||
# Extract output path from success message
|
||||
output_path = None
|
||||
if "saved to:" in result:
|
||||
output_path = result.split("saved to:")[1].strip()
|
||||
|
||||
return TranscribeResponse(
|
||||
success=True,
|
||||
message=result,
|
||||
output_path=output_path
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/batch-transcribe", response_model=BatchTranscribeResponse)
|
||||
async def batch_transcribe_endpoint(request: BatchTranscribeRequest):
|
||||
"""
|
||||
Batch transcribe all audio files in a folder
|
||||
|
||||
Processes all supported audio files in the specified folder.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Received batch transcription request for: {request.audio_folder}")
|
||||
|
||||
result = batch_transcribe(
|
||||
audio_folder=request.audio_folder,
|
||||
output_folder=request.output_folder,
|
||||
model_name=request.model_name,
|
||||
device=request.device,
|
||||
compute_type=request.compute_type,
|
||||
language=request.language,
|
||||
output_format=request.output_format,
|
||||
beam_size=request.beam_size,
|
||||
temperature=request.temperature,
|
||||
initial_prompt=request.initial_prompt,
|
||||
parallel_files=request.parallel_files
|
||||
)
|
||||
|
||||
# Check if there were errors
|
||||
success = not result.startswith("Error")
|
||||
|
||||
return BatchTranscribeResponse(
|
||||
success=success,
|
||||
summary=result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch transcription failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Batch transcription failed: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/upload-transcribe")
|
||||
async def upload_and_transcribe(
|
||||
file: UploadFile = File(...),
|
||||
model_name: str = Form("large-v3"),
|
||||
device: str = Form("auto"),
|
||||
compute_type: str = Form("auto"),
|
||||
language: Optional[str] = Form(None),
|
||||
output_format: str = Form("txt"),
|
||||
beam_size: int = Form(5),
|
||||
temperature: float = Form(0.0),
|
||||
initial_prompt: Optional[str] = Form(None)
|
||||
):
|
||||
"""
|
||||
Upload an audio file and transcribe it
|
||||
|
||||
This endpoint accepts file uploads via multipart/form-data.
|
||||
"""
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
try:
|
||||
# Create temporary directory for upload
|
||||
temp_dir = tempfile.mkdtemp(prefix="whisper_upload_")
|
||||
|
||||
# Save uploaded file
|
||||
file_ext = os.path.splitext(file.filename)[1]
|
||||
temp_audio_path = os.path.join(temp_dir, f"upload{file_ext}")
|
||||
|
||||
with open(temp_audio_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
logger.info(f"Uploaded file saved to: {temp_audio_path}")
|
||||
|
||||
# Transcribe the uploaded file
|
||||
result = transcribe_audio(
|
||||
audio_path=temp_audio_path,
|
||||
model_name=model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=language,
|
||||
output_format=output_format,
|
||||
beam_size=beam_size,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
output_directory=temp_dir
|
||||
)
|
||||
|
||||
# Parse result
|
||||
if result.startswith("Error") or "failed" in result.lower():
|
||||
# Clean up temp files
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail=result)
|
||||
|
||||
# Extract output path
|
||||
output_path = None
|
||||
if "saved to:" in result:
|
||||
output_path = result.split("saved to:")[1].strip()
|
||||
|
||||
# Return the transcription file
|
||||
if output_path and os.path.exists(output_path):
|
||||
return FileResponse(
|
||||
output_path,
|
||||
media_type="text/plain",
|
||||
filename=os.path.basename(output_path),
|
||||
background=None # Don't delete yet, we'll clean up after
|
||||
)
|
||||
else:
|
||||
# Clean up temp files
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
return JSONResponse(content={
|
||||
"success": True,
|
||||
"message": result
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Upload and transcribe failed: {str(e)}")
|
||||
# Clean up temp files on error
|
||||
if 'temp_dir' in locals():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail=f"Upload and transcribe failed: {str(e)}")
|
||||
finally:
|
||||
await file.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Get configuration from environment variables
|
||||
host = os.getenv("API_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("API_PORT", "8000"))
|
||||
|
||||
logger.info(f"Starting Whisper REST API server on {host}:{port}")
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
@@ -1,67 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Audio Processing Module
|
||||
Responsible for audio file validation and preprocessing
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Any
|
||||
from faster_whisper import decode_audio
|
||||
|
||||
# Log configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def validate_audio_file(audio_path: str) -> str:
|
||||
"""
|
||||
Validate if an audio file is valid
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
|
||||
Returns:
|
||||
str: Validation result, "ok" indicates validation passed, otherwise returns error message
|
||||
"""
|
||||
# Validate parameters
|
||||
if not os.path.exists(audio_path):
|
||||
return f"Error: Audio file does not exist: {audio_path}"
|
||||
|
||||
# Validate file format
|
||||
supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
|
||||
file_ext = os.path.splitext(audio_path)[1].lower()
|
||||
if file_ext not in supported_formats:
|
||||
return f"Error: Unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}"
|
||||
|
||||
# Validate file size
|
||||
try:
|
||||
file_size = os.path.getsize(audio_path)
|
||||
if file_size == 0:
|
||||
return f"Error: Audio file is empty: {audio_path}"
|
||||
|
||||
# Warning for large files (over 1GB)
|
||||
if file_size > 1024 * 1024 * 1024:
|
||||
logger.warning(f"Warning: File size exceeds 1GB, may require longer processing time: {audio_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check file size: {str(e)}")
|
||||
return f"Error: Failed to check file size: {str(e)}"
|
||||
|
||||
return "ok"
|
||||
|
||||
def process_audio(audio_path: str) -> Union[str, Any]:
|
||||
"""
|
||||
Process audio file, perform decoding and preprocessing
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
|
||||
Returns:
|
||||
Union[str, Any]: Processed audio data or original file path
|
||||
"""
|
||||
# Try to preprocess audio using decode_audio to handle more formats
|
||||
try:
|
||||
audio_data = decode_audio(audio_path)
|
||||
logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
|
||||
return audio_data
|
||||
except Exception as audio_error:
|
||||
logger.warning(f"Audio preprocessing failed, will use file path directly: {str(audio_error)}")
|
||||
return audio_path
|
||||
25
mcp.logs
Normal file
25
mcp.logs
Normal file
@@ -0,0 +1,25 @@
|
||||
starting mcp server for whisper stt transcriptor
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:PERFORMING STARTUP GPU HEALTH CHECK
|
||||
INFO:__main__:======================================================================
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.93s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:STARTUP GPU CHECK SUCCESSFUL
|
||||
INFO:__main__:GPU Device: NVIDIA GeForce RTX 3060
|
||||
INFO:__main__:Memory Available: 11.66 GB
|
||||
INFO:__main__:Test Duration: 0.93s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:Initializing job queue...
|
||||
INFO:core.job_queue:Starting job queue (max size: 100)
|
||||
INFO:core.job_queue:Loading jobs from /media/raid/agents/tools/mcp-transcriptor/outputs/jobs
|
||||
INFO:core.job_queue:Loaded 5 jobs from disk
|
||||
INFO:core.job_queue:Job queue worker loop started
|
||||
INFO:core.job_queue:Job queue worker started
|
||||
INFO:__main__:Job queue started (max_size=100, metadata_dir=/media/raid/agents/tools/mcp-transcriptor/outputs/jobs)
|
||||
INFO:core.gpu_health:Starting GPU health monitor (interval: 10.0 minutes)
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.38s
|
||||
INFO:__main__:GPU health monitor started (interval=10 minutes)
|
||||
@@ -1,9 +1,12 @@
|
||||
[project]
|
||||
name = "fast-whisper-mcp-server"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
version = "0.1.1"
|
||||
description = "High-performance speech recognition service with MCP and REST API servers"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"faster-whisper>=1.1.1",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["src"]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu126
|
||||
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu124
|
||||
faster-whisper
|
||||
torch #==2.6.0+cu126
|
||||
torchaudio #==2.6.0+cu126
|
||||
torch #==2.6.0+cu124
|
||||
torchaudio #==2.6.0+cu124
|
||||
|
||||
# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
|
||||
# pip install faster-whisper>=0.9.0
|
||||
# pip install mcp[cli]>=1.2.0
|
||||
mcp[cli]
|
||||
@@ -13,12 +13,21 @@ fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.32.0
|
||||
python-multipart>=0.0.9
|
||||
|
||||
# Test audio generation dependencies
|
||||
gTTS>=2.3.0
|
||||
pyttsx3>=2.90
|
||||
scipy>=1.10.0
|
||||
numpy>=1.24.0
|
||||
|
||||
# PyTorch Installation Guide:
|
||||
# Please install the appropriate version of PyTorch based on your CUDA version:
|
||||
#
|
||||
# • CUDA 12.6:
|
||||
# pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
#
|
||||
# • CUDA 12.4:
|
||||
# pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
|
||||
#
|
||||
# • CUDA 12.1:
|
||||
# pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
||||
#
|
||||
|
||||
70
reset_gpu.sh
Executable file
70
reset_gpu.sh
Executable file
@@ -0,0 +1,70 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to reset NVIDIA GPU drivers without rebooting
|
||||
# This reloads kernel modules and restarts nvidia-persistenced service
|
||||
|
||||
echo "============================================================"
|
||||
echo "NVIDIA GPU Driver Reset Script"
|
||||
echo "============================================================"
|
||||
echo ""
|
||||
|
||||
# Stop nvidia-persistenced service
|
||||
echo "Stopping nvidia-persistenced service..."
|
||||
sudo systemctl stop nvidia-persistenced
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ nvidia-persistenced stopped"
|
||||
else
|
||||
echo "✗ Failed to stop nvidia-persistenced"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Unload NVIDIA kernel modules (in correct order)
|
||||
echo "Unloading NVIDIA kernel modules..."
|
||||
sudo rmmod nvidia_uvm 2>/dev/null && echo "✓ nvidia_uvm unloaded" || echo " nvidia_uvm not loaded or failed to unload"
|
||||
sudo rmmod nvidia_drm 2>/dev/null && echo "✓ nvidia_drm unloaded" || echo " nvidia_drm not loaded or failed to unload"
|
||||
sudo rmmod nvidia_modeset 2>/dev/null && echo "✓ nvidia_modeset unloaded" || echo " nvidia_modeset not loaded or failed to unload"
|
||||
sudo rmmod nvidia 2>/dev/null && echo "✓ nvidia unloaded" || echo " nvidia not loaded or failed to unload"
|
||||
echo ""
|
||||
|
||||
# Small delay to ensure clean unload
|
||||
sleep 1
|
||||
|
||||
# Reload NVIDIA kernel modules (in correct order)
|
||||
echo "Loading NVIDIA kernel modules..."
|
||||
sudo modprobe nvidia && echo "✓ nvidia loaded" || { echo "✗ Failed to load nvidia"; exit 1; }
|
||||
sudo modprobe nvidia_modeset && echo "✓ nvidia_modeset loaded" || { echo "✗ Failed to load nvidia_modeset"; exit 1; }
|
||||
sudo modprobe nvidia_uvm && echo "✓ nvidia_uvm loaded" || { echo "✗ Failed to load nvidia_uvm"; exit 1; }
|
||||
sudo modprobe nvidia_drm && echo "✓ nvidia_drm loaded" || { echo "✗ Failed to load nvidia_drm"; exit 1; }
|
||||
echo ""
|
||||
|
||||
# Restart nvidia-persistenced service
|
||||
echo "Starting nvidia-persistenced service..."
|
||||
sudo systemctl start nvidia-persistenced
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ nvidia-persistenced started"
|
||||
else
|
||||
echo "✗ Failed to start nvidia-persistenced"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Verify GPU is accessible
|
||||
echo "Verifying GPU accessibility..."
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ GPU reset successful"
|
||||
else
|
||||
echo "✗ GPU not accessible"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✗ nvidia-smi not found"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "============================================================"
|
||||
echo "GPU driver reset completed successfully"
|
||||
echo "============================================================"
|
||||
@@ -5,6 +5,12 @@ datetime_prefix() {
date "+[%Y-%m-%d %H:%M:%S]"
}

# Set Python path to include src directory
export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"

# Set CUDA library path
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH

# Set environment variables
export CUDA_VISIBLE_DEVICES=1
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
@@ -23,6 +29,19 @@ export TRANSCRIPTION_FILENAME_PREFIX=""
export API_HOST="0.0.0.0"
export API_PORT="8000"

# GPU Auto-Reset Configuration
export GPU_RESET_COOLDOWN_MINUTES=5  # Minimum time between GPU reset attempts

# Job Queue Configuration
export JOB_QUEUE_MAX_SIZE=100
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
export JOB_RETENTION_DAYS=7

# GPU Health Monitoring
export GPU_HEALTH_CHECK_ENABLED=true
export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10
export GPU_HEALTH_TEST_MODEL="tiny"

# Log start of the script
echo "$(datetime_prefix) Starting Whisper REST API server..."
echo "$(datetime_prefix) Model directory: $WHISPER_MODEL_DIR"
@@ -37,6 +56,7 @@ fi
# Ensure output directories exist
mkdir -p "$TRANSCRIPTION_OUTPUT_DIR"
mkdir -p "$TRANSCRIPTION_BATCH_OUTPUT_DIR"
mkdir -p "$JOB_METADATA_DIR"

# Run the API server
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/api_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/api.logs
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/src/servers/api_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/api.logs
@@ -9,6 +9,12 @@ datetime_prefix() {
USER_ID=$(id -u)
GROUP_ID=$(id -g)

# Set Python path to include src directory
export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"

# Set CUDA library path
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH

# Set environment variables
export CUDA_VISIBLE_DEVICES=1
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
@@ -23,6 +29,19 @@ export TRANSCRIPTION_TEMPERATURE="0.0"
export TRANSCRIPTION_USE_TIMESTAMP="false"
export TRANSCRIPTION_FILENAME_PREFIX="test_"

# GPU Auto-Reset Configuration
export GPU_RESET_COOLDOWN_MINUTES=5  # Minimum time between GPU reset attempts

# Job Queue Configuration
export JOB_QUEUE_MAX_SIZE=100
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
export JOB_RETENTION_DAYS=7

# GPU Health Monitoring
export GPU_HEALTH_CHECK_ENABLED=true
export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10
export GPU_HEALTH_TEST_MODEL="tiny"

# Log start of the script
echo "$(datetime_prefix) Starting whisper server script..."
echo "test: $WHISPER_MODEL_DIR"
@@ -33,7 +52,10 @@ if [ ! -d "$WHISPER_MODEL_DIR" ]; then
    exit 1
fi

# Ensure job metadata directory exists
mkdir -p "$JOB_METADATA_DIR"

# Run the Python script with the defined environment variables
#/home/uad/agents/tools/mcp-transcriptor/venv/bin/python /home/uad/agents/tools/mcp-transcriptor/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/src/servers/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs
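Both startup scripts expose the new settings purely through environment variables. A small sketch of how service code might read them, with the same values as defaults (variable names come from the scripts above; treating those values as the defaults is an assumption, since the scripts simply export them):

```python
import os

# Defaults mirror the values exported by run_api_server.sh / run_mcp_server.sh.
GPU_RESET_COOLDOWN_MINUTES = int(os.getenv("GPU_RESET_COOLDOWN_MINUTES", "5"))
JOB_QUEUE_MAX_SIZE = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
JOB_METADATA_DIR = os.getenv("JOB_METADATA_DIR", "/outputs/jobs")
JOB_RETENTION_DAYS = int(os.getenv("JOB_RETENTION_DAYS", "7"))
GPU_HEALTH_CHECK_ENABLED = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
GPU_HEALTH_CHECK_INTERVAL_MINUTES = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
GPU_HEALTH_TEST_MODEL = os.getenv("GPU_HEALTH_TEST_MODEL", "tiny")
```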
9
src/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""
Faster Whisper MCP Transcription Service

High-performance audio transcription service with dual-server architecture
(MCP and REST API) featuring async job queue and GPU health monitoring.
"""

__version__ = "0.2.0"
__author__ = "Whisper MCP Team"
6
src/core/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Core modules for Whisper transcription service.

Includes model management, transcription logic, job queue, GPU health monitoring,
and GPU reset functionality.
"""
467
src/core/gpu_health.py
Normal file
@@ -0,0 +1,467 @@
"""
GPU health monitoring for Whisper transcription service.

Performs real GPU health checks using actual model loading and transcription,
with strict failure handling to prevent silent CPU fallbacks.
Includes circuit breaker pattern to prevent repeated failed checks.
"""

import time
import logging
import threading
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional, List
import torch

from utils.test_audio_generator import generate_test_audio
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen

logger = logging.getLogger(__name__)

# Global circuit breaker for GPU health checks
_gpu_health_circuit_breaker = CircuitBreaker(
    name="gpu_health_check",
    failure_threshold=3,   # Open after 3 consecutive failures
    success_threshold=2,   # Close after 2 consecutive successes
    timeout_seconds=60,    # Try again after 60 seconds
    half_open_max_calls=1  # Only 1 test call in half-open state
)

# Import reset functionality (after logger initialization)
try:
    from core.gpu_reset import (
        reset_gpu_drivers,
        can_attempt_reset,
        record_reset_attempt
    )
    GPU_RESET_AVAILABLE = True
except ImportError as e:
    logger.warning(f"GPU reset functionality not available: {e}")
    GPU_RESET_AVAILABLE = False


@dataclass
class GPUHealthStatus:
    """Data class for GPU health check results."""
    gpu_available: bool           # torch.cuda.is_available()
    gpu_working: bool             # Model actually loaded and ran on GPU
    device_used: str              # "cuda" or "cpu"
    device_name: str              # GPU name if available
    memory_total_gb: float        # Total GPU memory
    memory_available_gb: float    # Available GPU memory
    test_duration_seconds: float  # How long test took
    timestamp: str                # ISO timestamp
    error: Optional[str] = None   # Error message if any

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return asdict(self)


def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus:
    """
    Comprehensive GPU health check using real model + transcription.

    Args:
        expected_device: Expected device ("auto", "cuda", "cpu")

    Returns:
        GPUHealthStatus object

    Raises:
        RuntimeError: If expected_device="cuda" but GPU test fails

    Implementation:
        1. Generate test audio (1 second)
        2. Load tiny model with requested device
        3. Transcribe test audio
        4. Time the operation
        5. Verify model actually ran on GPU (check torch.cuda.memory_allocated)
        6. CRITICAL: If expected_device="cuda" but used="cpu" → raise RuntimeError
        7. Return detailed status

    Performance Expectations:
        - GPU (tiny model): 0.3-1.0 seconds
        - CPU (tiny model): 3-10 seconds
        - If GPU test takes >2 seconds, likely running on CPU
    """
    timestamp = datetime.utcnow().isoformat() + "Z"
    start_time = time.time()

    # Initialize status with defaults
    gpu_available = torch.cuda.is_available()
    device_name = ""
    memory_total_gb = 0.0
    memory_available_gb = 0.0
    actual_device = "cpu"
    gpu_working = False
    error_msg = None

    # Get GPU info if available
    if gpu_available:
        try:
            device_name = torch.cuda.get_device_name(0)
            memory_total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            memory_available_gb = (torch.cuda.get_device_properties(0).total_memory -
                                   torch.cuda.memory_allocated(0)) / (1024**3)
        except Exception as e:
            logger.warning(f"Failed to get GPU info: {e}")

    try:
        # Generate test audio
        test_audio_path = generate_test_audio(duration_seconds=1.0)

        # Import here to avoid circular dependencies
        from faster_whisper import WhisperModel

        # Determine device to use
        test_device = "cpu"
        test_compute_type = "int8"

        if expected_device == "cuda" or (expected_device == "auto" and gpu_available):
            test_device = "cuda"
            test_compute_type = "float16"

        # Record GPU memory before loading model
        gpu_memory_before = 0
        if torch.cuda.is_available():
            gpu_memory_before = torch.cuda.memory_allocated(0)

        # Load tiny model and transcribe
        try:
            model = WhisperModel(
                "tiny",
                device=test_device,
                compute_type=test_compute_type
            )

            # Transcribe test audio
            segments, info = model.transcribe(test_audio_path, beam_size=1)

            # Consume segments (needed to actually run inference)
            list(segments)

            # Check if GPU was actually used
            # faster-whisper uses CTranslate2 which manages GPU memory separately
            # from PyTorch, so we check model.model.device instead of torch memory
            actual_device = model.model.device
            if actual_device == "cuda":
                gpu_working = True
            else:
                gpu_working = False

        except Exception as e:
            error_msg = f"Model loading/inference failed: {str(e)}"
            actual_device = "cpu"
            gpu_working = False
            logger.error(f"GPU health check failed: {error_msg}")

    except Exception as e:
        error_msg = f"Health check setup failed: {str(e)}"
        logger.error(f"GPU health check error: {error_msg}")

    # Calculate test duration
    test_duration = time.time() - start_time

    # Create status object
    status = GPUHealthStatus(
        gpu_available=gpu_available,
        gpu_working=gpu_working,
        device_used=actual_device,
        device_name=device_name,
        memory_total_gb=round(memory_total_gb, 2),
        memory_available_gb=round(memory_available_gb, 2),
        test_duration_seconds=round(test_duration, 2),
        timestamp=timestamp,
        error=error_msg
    )

    # CRITICAL: Reject if expected CUDA but got CPU
    # This service is GPU-only, so also reject device="auto" that falls back to CPU
    if expected_device == "cuda" and actual_device == "cpu":
        error_message = (
            f"GPU device requested but model loaded on CPU. "
            f"This indicates GPU driver issues or insufficient memory. "
            f"Transcription would be 10-100x slower than expected. "
            f"Please check CUDA installation and GPU availability. "
            f"Details: {error_msg or 'Unknown cause'}"
        )
        logger.error(error_message)
        raise RuntimeError(error_message)

    # CRITICAL: Service requires GPU - reject CPU fallback even with device="auto"
    if expected_device == "auto" and actual_device == "cpu":
        error_message = (
            f"GPU required but not available. This service is configured for GPU-only operation. "
            f"CUDA is not available or GPU health check failed. "
            f"Please ensure CUDA is properly installed and GPU is accessible. "
            f"Details: {error_msg or 'CUDA not available'}"
        )
        logger.error(error_message)
        raise RuntimeError(error_message)

    # Log health check result
    if gpu_working:
        logger.info(
            f"GPU health check passed: {device_name}, "
            f"test duration: {test_duration:.2f}s"
        )
    elif gpu_available and not gpu_working:
        logger.warning(
            f"GPU available but health check failed. "
            f"Test duration: {test_duration:.2f}s, Error: {error_msg}"
        )
    else:
        logger.info(f"GPU not available, using CPU. Test duration: {test_duration:.2f}s")

    return status


def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus:
    """
    GPU health check with optional circuit breaker protection.

    This is the main entry point for GPU health checks. By default, it uses
    circuit breaker pattern to prevent repeated failed checks.

    Args:
        expected_device: Expected device ("auto", "cuda", "cpu")
        use_circuit_breaker: Enable circuit breaker protection (default: True)

    Returns:
        GPUHealthStatus object

    Raises:
        RuntimeError: If expected_device="cuda" but GPU test fails
        CircuitBreakerOpen: If circuit breaker is open (too many recent failures)
    """
    if use_circuit_breaker:
        try:
            return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device)
        except CircuitBreakerOpen as e:
            # Circuit is open, fail fast without attempting check
            logger.warning(f"GPU health check circuit breaker is OPEN: {e}")
            raise RuntimeError(f"GPU health check unavailable: {str(e)}")
    else:
        return _check_gpu_health_internal(expected_device)


def get_circuit_breaker_stats() -> dict:
    """
    Get current circuit breaker statistics.

    Returns:
        Dictionary with circuit state and failure/success counts
    """
    return _gpu_health_circuit_breaker.get_stats()


def reset_circuit_breaker():
    """
    Manually reset the GPU health check circuit breaker.

    Useful for:
    - Testing
    - Manual intervention after fixing GPU issues
    - Clearing persistent error state
    """
    _gpu_health_circuit_breaker.reset()
    logger.info("GPU health check circuit breaker manually reset")


def check_gpu_health_with_reset(
    expected_device: str = "cuda",
    auto_reset: bool = True
) -> GPUHealthStatus:
    """
    Check GPU health with automatic reset on failure.

    This function wraps check_gpu_health() and adds automatic GPU driver
    reset capability with cooldown protection to prevent reset loops.

    Behavior:
    1. Attempt GPU health check
    2. If fails and auto_reset=True:
       a. Check cooldown (prevents reset loops)
       b. If cooldown active → raise error immediately
       c. If cooldown OK → reset drivers → wait 3s → retry
    3. If retry fails → raise error (terminate service)

    Args:
        expected_device: Expected device ("auto", "cuda", "cpu")
        auto_reset: Enable automatic reset on failure (default: True)

    Returns:
        GPUHealthStatus object

    Raises:
        RuntimeError: If GPU check fails and reset not available/allowed,
                      or if retry after reset fails
    """
    try:
        # First attempt
        return check_gpu_health(expected_device)

    except (RuntimeError, Exception) as first_error:
        # GPU health check failed

        if not auto_reset:
            logger.error(f"GPU health check failed (auto-reset disabled): {first_error}")
            raise

        if not GPU_RESET_AVAILABLE:
            logger.error(
                f"GPU health check failed but reset not available: {first_error}"
            )
            raise RuntimeError(
                f"GPU unavailable and reset functionality not available: {first_error}"
            )

        # Check if reset is allowed (cooldown protection)
        if not can_attempt_reset():
            error_msg = (
                f"GPU health check failed but reset cooldown is active. "
                f"This prevents reset loops. Last error: {first_error}. "
                f"Service will terminate. If this happens after sleep/wake, "
                f"the cooldown should expire soon and next restart will work."
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        # Attempt GPU reset
        logger.warning("=" * 70)
        logger.warning(f"GPU HEALTH CHECK FAILED: {first_error}")
        logger.warning("Attempting automatic GPU driver reset...")
        logger.warning("=" * 70)

        try:
            reset_gpu_drivers()
            record_reset_attempt()

            logger.info("GPU reset completed, waiting 3 seconds for stabilization...")
            time.sleep(3)

            # Retry GPU health check
            logger.info("Retrying GPU health check after reset...")
            status = check_gpu_health(expected_device)

            logger.warning("=" * 70)
            logger.warning("GPU HEALTH CHECK SUCCESS AFTER RESET")
            logger.warning(f"Device: {status.device_name}")
            logger.warning(f"Memory: {status.memory_available_gb:.2f} GB available")
            logger.warning("=" * 70)

            return status

        except Exception as reset_error:
            error_msg = (
                f"GPU health check failed after reset attempt. "
                f"Original error: {first_error}. "
                f"Reset error: {reset_error}. "
                f"Service will terminate."
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)


class HealthMonitor:
    """Background thread for periodic GPU health monitoring."""

    def __init__(self, check_interval_minutes: int = 10):
        """
        Initialize health monitor.

        Args:
            check_interval_minutes: Interval between health checks (default: 10)
        """
        self._check_interval_seconds = check_interval_minutes * 60
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._latest_status: Optional[GPUHealthStatus] = None
        self._history: List[GPUHealthStatus] = []
        self._max_history = 100
        self._lock = threading.Lock()

    def start(self):
        """Start background monitoring thread."""
        if self._thread is not None and self._thread.is_alive():
            logger.warning("Health monitor already running")
            return

        logger.info(
            f"Starting GPU health monitor "
            f"(interval: {self._check_interval_seconds / 60:.1f} minutes)"
        )

        # Run initial health check
        try:
            status = check_gpu_health(expected_device="auto")
            with self._lock:
                self._latest_status = status
                self._history.append(status)
        except Exception as e:
            logger.error(f"Initial GPU health check failed: {e}")

        # Start background thread
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Stop background monitoring thread."""
        if self._thread is None:
            return

        logger.info("Stopping GPU health monitor")
        self._stop_event.set()
        self._thread.join(timeout=5.0)
        self._thread = None

    def get_latest_status(self) -> Optional[GPUHealthStatus]:
        """Get most recent health check result."""
        with self._lock:
            return self._latest_status

    def get_health_history(self, limit: int = 10) -> List[GPUHealthStatus]:
        """
        Get recent health check history.

        Args:
            limit: Maximum number of results to return

        Returns:
            List of GPUHealthStatus objects (most recent first)
        """
        with self._lock:
            return list(reversed(self._history[-limit:]))

    def _monitor_loop(self):
        """Background monitoring loop."""
        while not self._stop_event.wait(timeout=self._check_interval_seconds):
            try:
                logger.debug("Running periodic GPU health check")
                status = check_gpu_health(expected_device="auto")

                with self._lock:
                    self._latest_status = status
                    self._history.append(status)

                    # Trim history if too long
                    if len(self._history) > self._max_history:
                        self._history = self._history[-self._max_history:]

                # Log warning if performance degraded
                if status.gpu_available and not status.gpu_working:
                    logger.warning(
                        f"GPU health degraded: {status.error}. "
                        f"Test duration: {status.test_duration_seconds}s"
                    )
                elif status.gpu_working and status.test_duration_seconds > 2.0:
                    logger.warning(
                        f"GPU health check slow: {status.test_duration_seconds}s "
                        f"(expected <1s). May indicate performance issues."
                    )

            except Exception as e:
                logger.error(f"Periodic GPU health check failed: {e}")
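For orientation, a minimal usage sketch for this module, assuming it is imported with `src/` on `PYTHONPATH` as the startup scripts configure:

```python
from core.gpu_health import check_gpu_health_with_reset, HealthMonitor

# One-shot check: raises RuntimeError if the model cannot actually run on CUDA,
# attempting a driver reset first (subject to the cooldown).
status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
print(status.device_name, status.test_duration_seconds)

# Periodic background monitoring, as used by the API server.
monitor = HealthMonitor(check_interval_minutes=10)
monitor.start()
# ... serve requests ...
monitor.stop()
```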
220
src/core/gpu_reset.py
Normal file
@@ -0,0 +1,220 @@
"""
GPU driver reset functionality with cooldown protection.

Provides automatic GPU driver reset on CUDA errors with safeguards
to prevent reset loops while allowing recovery from sleep/wake cycles.
"""

import os
import time
import logging
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

# Cooldown file location
RESET_TIMESTAMP_FILE = "/tmp/whisper-gpu-last-reset"

# Default cooldown period (minutes)
DEFAULT_COOLDOWN_MINUTES = 5


def get_cooldown_minutes() -> int:
    """
    Get cooldown period from environment variable.

    Returns:
        Cooldown period in minutes (default: 5)
    """
    try:
        return int(os.getenv("GPU_RESET_COOLDOWN_MINUTES", str(DEFAULT_COOLDOWN_MINUTES)))
    except ValueError:
        logger.warning(
            f"Invalid GPU_RESET_COOLDOWN_MINUTES value, using default: {DEFAULT_COOLDOWN_MINUTES}"
        )
        return DEFAULT_COOLDOWN_MINUTES


def get_last_reset_time() -> Optional[datetime]:
    """
    Read timestamp of last GPU reset attempt.

    Returns:
        datetime object of last reset, or None if no previous reset
    """
    try:
        if not os.path.exists(RESET_TIMESTAMP_FILE):
            return None

        with open(RESET_TIMESTAMP_FILE, 'r') as f:
            timestamp_str = f.read().strip()

        return datetime.fromisoformat(timestamp_str)

    except Exception as e:
        logger.warning(f"Failed to read last reset timestamp: {e}")
        return None


def record_reset_attempt() -> None:
    """
    Record current time as last GPU reset attempt.

    Creates/updates timestamp file with current UTC time.
    """
    try:
        timestamp = datetime.utcnow().isoformat()

        with open(RESET_TIMESTAMP_FILE, 'w') as f:
            f.write(timestamp)

        logger.info(f"Recorded GPU reset timestamp: {timestamp}")

    except Exception as e:
        logger.error(f"Failed to record reset timestamp: {e}")


def can_attempt_reset() -> bool:
    """
    Check if GPU reset can be attempted based on cooldown period.

    Returns:
        True if reset is allowed (no recent reset or cooldown expired),
        False if cooldown is still active
    """
    last_reset = get_last_reset_time()

    if last_reset is None:
        # No previous reset recorded
        logger.debug("No previous GPU reset found, reset allowed")
        return True

    cooldown_minutes = get_cooldown_minutes()
    cooldown_period = timedelta(minutes=cooldown_minutes)
    time_since_reset = datetime.utcnow() - last_reset

    if time_since_reset < cooldown_period:
        remaining = cooldown_period - time_since_reset
        logger.warning(
            f"GPU reset cooldown active. "
            f"Last reset: {last_reset.isoformat()}, "
            f"Cooldown: {cooldown_minutes} min, "
            f"Remaining: {remaining.total_seconds():.0f}s"
        )
        return False

    logger.info(
        f"GPU reset cooldown expired. "
        f"Last reset: {last_reset.isoformat()}, "
        f"Time since: {time_since_reset.total_seconds():.0f}s"
    )
    return True


def reset_gpu_drivers() -> None:
    """
    Execute GPU driver reset using reset_gpu.sh script.

    This function:
    1. Locates reset_gpu.sh in project root
    2. Executes it with sudo
    3. Waits for completion
    4. Raises exception if reset fails

    Raises:
        RuntimeError: If reset script fails or is not found
        PermissionError: If sudo permissions not configured

    Note:
        Requires passwordless sudo for nvidia commands.
        See CLAUDE.md for setup instructions.
    """
    logger.warning("=" * 60)
    logger.warning("INITIATING GPU DRIVER RESET")
    logger.warning("=" * 60)

    # Find reset_gpu.sh script
    script_path = Path(__file__).parent.parent.parent / "reset_gpu.sh"

    if not script_path.exists():
        error_msg = f"GPU reset script not found: {script_path}"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    if not os.access(script_path, os.X_OK):
        error_msg = f"GPU reset script not executable: {script_path}"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    logger.info(f"Executing GPU reset script: {script_path}")
    logger.warning("This will temporarily interrupt all GPU operations")

    try:
        # Execute reset script with sudo
        result = subprocess.run(
            ['sudo', str(script_path)],
            capture_output=True,
            text=True,
            timeout=30  # 30 second timeout
        )

        # Log script output
        if result.stdout:
            logger.info(f"Reset script output:\n{result.stdout}")

        if result.stderr:
            logger.warning(f"Reset script stderr:\n{result.stderr}")

        # Check exit code
        if result.returncode != 0:
            error_msg = (
                f"GPU reset script failed with exit code {result.returncode}. "
                f"This may indicate:\n"
                f"  1. Sudo permissions not configured (see CLAUDE.md)\n"
                f"  2. GPU hardware issue\n"
                f"  3. Driver installation problem\n"
                f"Output: {result.stderr or result.stdout}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        logger.warning("=" * 60)
        logger.warning("GPU DRIVER RESET COMPLETED SUCCESSFULLY")
        logger.warning("=" * 60)

    except subprocess.TimeoutExpired:
        error_msg = "GPU reset script timed out after 30 seconds"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    except FileNotFoundError:
        error_msg = (
            "sudo command not found. This service requires sudo access "
            "for GPU driver reset. Please ensure sudo is installed."
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    except Exception as e:
        error_msg = f"Unexpected error during GPU reset: {e}"
        logger.error(error_msg)
        raise RuntimeError(error_msg)


def clear_reset_cooldown() -> None:
    """
    Clear reset cooldown by removing timestamp file.

    Useful for testing or manual intervention.
    """
    try:
        if os.path.exists(RESET_TIMESTAMP_FILE):
            os.remove(RESET_TIMESTAMP_FILE)
            logger.info("GPU reset cooldown cleared")
        else:
            logger.debug("No cooldown to clear")
    except Exception as e:
        logger.error(f"Failed to clear reset cooldown: {e}")
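The reset path depends on passwordless sudo for the reset script. A small, hedged sketch of how a deployment check might verify that up front (`sudo -n` fails instead of prompting when a password would be required; the helper name is illustrative):

```python
import subprocess

def passwordless_sudo_available() -> bool:
    """Return True if sudo can run without prompting for a password."""
    probe = subprocess.run(["sudo", "-n", "true"], capture_output=True)
    return probe.returncode == 0
```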
545
src/core/job_queue.py
Normal file
@@ -0,0 +1,545 @@
"""
Job queue manager for asynchronous transcription processing.

Provides FIFO job queue with background worker thread and disk persistence.
"""

import os
import json
import time
import queue
import logging
import threading
import uuid
from enum import Enum
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional, List, Dict

from core.gpu_health import check_gpu_health_with_reset
from core.transcriber import transcribe_audio
from utils.audio_processor import validate_audio_file

logger = logging.getLogger(__name__)


class JobStatus(Enum):
    """Job status enumeration."""
    QUEUED = "queued"        # In queue, waiting
    RUNNING = "running"      # Currently processing
    COMPLETED = "completed"  # Successfully finished
    FAILED = "failed"        # Error occurred


@dataclass
class Job:
    """Represents a transcription job."""
    job_id: str
    status: JobStatus
    created_at: datetime
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
    queue_position: int

    # Request parameters
    audio_path: str
    model_name: str
    device: str
    compute_type: str
    language: Optional[str]
    output_format: str
    beam_size: int
    temperature: float
    initial_prompt: Optional[str]
    output_directory: Optional[str]

    # Results
    result_path: Optional[str]
    error: Optional[str]
    processing_time_seconds: Optional[float]

    def to_dict(self) -> dict:
        """Serialize to dictionary for JSON storage."""
        return {
            "job_id": self.job_id,
            "status": self.status.value,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "queue_position": self.queue_position,
            "request_params": {
                "audio_path": self.audio_path,
                "model_name": self.model_name,
                "device": self.device,
                "compute_type": self.compute_type,
                "language": self.language,
                "output_format": self.output_format,
                "beam_size": self.beam_size,
                "temperature": self.temperature,
                "initial_prompt": self.initial_prompt,
                "output_directory": self.output_directory,
            },
            "result_path": self.result_path,
            "error": self.error,
            "processing_time_seconds": self.processing_time_seconds,
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Job':
        """Deserialize from dictionary."""
        params = data.get("request_params", {})
        return cls(
            job_id=data["job_id"],
            status=JobStatus(data["status"]),
            created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.utcnow(),
            started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None,
            completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None,
            queue_position=data.get("queue_position", 0),
            audio_path=params["audio_path"],
            model_name=params.get("model_name", "large-v3"),
            device=params.get("device", "auto"),
            compute_type=params.get("compute_type", "auto"),
            language=params.get("language"),
            output_format=params.get("output_format", "txt"),
            beam_size=params.get("beam_size", 5),
            temperature=params.get("temperature", 0.0),
            initial_prompt=params.get("initial_prompt"),
            output_directory=params.get("output_directory"),
            result_path=data.get("result_path"),
            error=data.get("error"),
            processing_time_seconds=data.get("processing_time_seconds"),
        )

    def save_to_disk(self, metadata_dir: str):
        """Save job metadata to {metadata_dir}/{job_id}.json"""
        os.makedirs(metadata_dir, exist_ok=True)
        filepath = os.path.join(metadata_dir, f"{self.job_id}.json")

        try:
            with open(filepath, 'w') as f:
                json.dump(self.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save job {self.job_id} to disk: {e}")


class JobQueue:
    """Manages job queue with background worker."""

    def __init__(self,
                 max_queue_size: int = 100,
                 metadata_dir: str = "/outputs/jobs"):
        """
        Initialize job queue.

        Args:
            max_queue_size: Maximum number of jobs in queue
            metadata_dir: Directory to store job metadata JSON files
        """
        self._queue = queue.Queue(maxsize=max_queue_size)
        self._jobs: Dict[str, Job] = {}
        self._metadata_dir = metadata_dir
        self._worker_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._current_job_id: Optional[str] = None
        self._lock = threading.RLock()  # Use RLock to allow re-entrant locking
        self._max_queue_size = max_queue_size

        # Create metadata directory
        os.makedirs(metadata_dir, exist_ok=True)

    def start(self):
        """
        Start background worker thread.
        Load existing jobs from disk on startup.
        """
        if self._worker_thread is not None and self._worker_thread.is_alive():
            logger.warning("Job queue worker already running")
            return

        logger.info(f"Starting job queue (max size: {self._max_queue_size})")

        # Load existing jobs from disk
        self._load_jobs_from_disk()

        # Start background worker thread
        self._stop_event.clear()
        self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker_thread.start()

        logger.info("Job queue worker started")

    def stop(self, wait_for_current: bool = True):
        """
        Stop background worker.

        Args:
            wait_for_current: If True, wait for current job to complete
        """
        if self._worker_thread is None:
            return

        logger.info(f"Stopping job queue (wait_for_current={wait_for_current})")
        self._stop_event.set()

        if wait_for_current:
            self._worker_thread.join(timeout=30.0)
        else:
            self._worker_thread.join(timeout=1.0)

        self._worker_thread = None
        logger.info("Job queue worker stopped")

    def submit_job(self,
                   audio_path: str,
                   model_name: str = "large-v3",
                   device: str = "auto",
                   compute_type: str = "auto",
                   language: Optional[str] = None,
                   output_format: str = "txt",
                   beam_size: int = 5,
                   temperature: float = 0.0,
                   initial_prompt: Optional[str] = None,
                   output_directory: Optional[str] = None) -> dict:
        """
        Submit a new transcription job.

        Returns:
            dict: {
                "job_id": str,
                "status": str,
                "queue_position": int,
                "created_at": str
            }

        Raises:
            queue.Full: If queue is at max capacity
            RuntimeError: If GPU health check fails (when device="cuda")
            FileNotFoundError: If audio file doesn't exist
        """
        # 1. Validate audio file exists
        try:
            validate_audio_file(audio_path)
        except Exception as e:
            raise FileNotFoundError(f"Audio file validation failed: {e}")

        # 2. Check GPU health (GPU required for all devices since this is GPU-only service)
        # Both device="cuda" and device="auto" require GPU
        if device == "cuda" or device == "auto":
            try:
                logger.info("Running GPU health check before job submission")
                # Use expected_device to match what user requested
                expected = "cuda" if device == "cuda" else "auto"
                health_status = check_gpu_health_with_reset(expected_device=expected, auto_reset=True)
                if not health_status.gpu_working:
                    raise RuntimeError(
                        f"GPU device required but not available. "
                        f"GPU check failed: {health_status.error}. "
                        f"This service is configured for GPU-only operation."
                    )
                logger.info("GPU health check passed")
            except RuntimeError as e:
                # Re-raise GPU health errors
                logger.error(f"Job rejected due to GPU health check failure: {e}")
                raise RuntimeError(f"Job rejected: {e}")
        elif device == "cpu":
            # Reject CPU device explicitly
            raise ValueError(
                "CPU device requested but this service is configured for GPU-only operation. "
                "Please use device='cuda' or device='auto' with a GPU available."
            )

        # 3. Generate job_id
        job_id = str(uuid.uuid4())

        # 4. Create Job object
        job = Job(
            job_id=job_id,
            status=JobStatus.QUEUED,
            created_at=datetime.utcnow(),
            started_at=None,
            completed_at=None,
            queue_position=0,  # Will be calculated
            audio_path=audio_path,
            model_name=model_name,
            device=device,
            compute_type=compute_type,
            language=language,
            output_format=output_format,
            beam_size=beam_size,
            temperature=temperature,
            initial_prompt=initial_prompt,
            output_directory=output_directory,
            result_path=None,
            error=None,
            processing_time_seconds=None,
        )

        # 5. Add to queue (raises queue.Full if full)
        try:
            self._queue.put_nowait(job)
        except queue.Full:
            raise queue.Full(
                f"Job queue is full (max size: {self._max_queue_size}). "
                f"Please try again later."
            )

        # 6. Add to jobs dict and save to disk
        with self._lock:
            self._jobs[job_id] = job
            self._calculate_queue_positions()
            job.save_to_disk(self._metadata_dir)

        logger.info(
            f"Job {job_id} submitted: {audio_path} "
            f"(queue position: {job.queue_position})"
        )

        # 7. Return job info
        return {
            "job_id": job_id,
            "status": job.status.value,
            "queue_position": job.queue_position,
            "created_at": job.created_at.isoformat() + "Z"
        }

    def get_job_status(self, job_id: str) -> dict:
        """
        Get current status of a job.

        Returns:
            dict: Job status information

        Raises:
            KeyError: If job_id not found
        """
        with self._lock:
            if job_id not in self._jobs:
                raise KeyError(f"Job {job_id} not found")

            job = self._jobs[job_id]
            return {
                "job_id": job.job_id,
                "status": job.status.value,
                "queue_position": job.queue_position if job.status == JobStatus.QUEUED else None,
                "created_at": job.created_at.isoformat() + "Z",
                "started_at": job.started_at.isoformat() + "Z" if job.started_at else None,
                "completed_at": job.completed_at.isoformat() + "Z" if job.completed_at else None,
                "result_path": job.result_path,
                "error": job.error,
                "processing_time_seconds": job.processing_time_seconds,
            }

    def get_job_result(self, job_id: str) -> str:
        """
        Get transcription result text for completed job.

        Returns:
            str: Content of transcription file

        Raises:
            KeyError: If job_id not found
            ValueError: If job not completed
            FileNotFoundError: If result file missing
        """
        with self._lock:
            if job_id not in self._jobs:
                raise KeyError(f"Job {job_id} not found")

            job = self._jobs[job_id]

            if job.status != JobStatus.COMPLETED:
                raise ValueError(
                    f"Job {job_id} is not completed yet. "
                    f"Current status: {job.status.value}"
                )

            if not job.result_path:
                raise FileNotFoundError(f"Job {job_id} has no result path")

        # Read result file (outside lock to avoid blocking)
        if not os.path.exists(job.result_path):
            raise FileNotFoundError(
                f"Result file not found: {job.result_path}"
            )

        with open(job.result_path, 'r', encoding='utf-8') as f:
            return f.read()

    def list_jobs(self,
                  status_filter: Optional[JobStatus] = None,
                  limit: int = 100) -> List[dict]:
        """
        List jobs with optional status filter.

        Args:
            status_filter: Only return jobs with this status
            limit: Maximum number of jobs to return

        Returns:
            List of job status dictionaries
        """
        with self._lock:
            jobs = list(self._jobs.values())

            # Filter by status
            if status_filter:
                jobs = [j for j in jobs if j.status == status_filter]

            # Sort by created_at (newest first)
            jobs.sort(key=lambda j: j.created_at, reverse=True)

            # Limit results
            jobs = jobs[:limit]

            # Convert to dict
            return [self.get_job_status(j.job_id) for j in jobs]

    def _worker_loop(self):
        """
        Background worker thread function.
        Processes jobs from queue in FIFO order.
        """
        logger.info("Job queue worker loop started")

        while not self._stop_event.is_set():
            try:
                # Get job from queue (with timeout to check stop_event)
                try:
                    job = self._queue.get(timeout=1.0)
                except queue.Empty:
                    continue

                # Update job status to running
                with self._lock:
                    self._current_job_id = job.job_id
                    job.status = JobStatus.RUNNING
                    job.started_at = datetime.utcnow()
                    job.queue_position = 0
                    self._calculate_queue_positions()
                    job.save_to_disk(self._metadata_dir)

                logger.info(f"Job {job.job_id} started processing")

                # Process job
                start_time = time.time()
                try:
                    result = transcribe_audio(
                        audio_path=job.audio_path,
                        model_name=job.model_name,
                        device=job.device,
                        compute_type=job.compute_type,
                        language=job.language,
                        output_format=job.output_format,
                        beam_size=job.beam_size,
                        temperature=job.temperature,
                        initial_prompt=job.initial_prompt,
                        output_directory=job.output_directory
                    )

                    # Parse result
                    if "saved to:" in result:
                        job.result_path = result.split("saved to:")[1].strip()
                        job.status = JobStatus.COMPLETED
                        logger.info(
                            f"Job {job.job_id} completed successfully: {job.result_path}"
                        )
                    else:
                        job.status = JobStatus.FAILED
                        job.error = result
                        logger.error(f"Job {job.job_id} failed: {result}")

                except Exception as e:
                    job.status = JobStatus.FAILED
                    job.error = str(e)
                    logger.error(f"Job {job.job_id} failed with exception: {e}")

                finally:
                    # Update job completion info
                    job.completed_at = datetime.utcnow()
                    job.processing_time_seconds = time.time() - start_time
                    job.save_to_disk(self._metadata_dir)

                    with self._lock:
                        self._current_job_id = None
                        self._calculate_queue_positions()

                    self._queue.task_done()

                    logger.info(
                        f"Job {job.job_id} finished: "
                        f"status={job.status.value}, "
                        f"duration={job.processing_time_seconds:.1f}s"
                    )

            except Exception as e:
                logger.error(f"Unexpected error in worker loop: {e}")

        logger.info("Job queue worker loop stopped")

    def _load_jobs_from_disk(self):
        """Load existing job metadata from disk on startup."""
        if not os.path.exists(self._metadata_dir):
            logger.info("No existing job metadata directory found")
            return

        logger.info(f"Loading jobs from {self._metadata_dir}")

        loaded_count = 0
        for filename in os.listdir(self._metadata_dir):
            if not filename.endswith(".json"):
                continue

            filepath = os.path.join(self._metadata_dir, filename)
            try:
                with open(filepath, 'r') as f:
                    data = json.load(f)

                job = Job.from_dict(data)

                # Handle jobs that were running when server stopped
                if job.status == JobStatus.RUNNING:
                    job.status = JobStatus.FAILED
                    job.error = "Server restarted while job was running"
                    job.completed_at = datetime.utcnow()
                    job.save_to_disk(self._metadata_dir)
                    logger.warning(
                        f"Job {job.job_id} was running during shutdown, "
                        f"marking as failed"
                    )

                # Re-queue queued jobs
                elif job.status == JobStatus.QUEUED:
                    try:
                        self._queue.put_nowait(job)
                        logger.info(f"Re-queued job {job.job_id} from disk")
                    except queue.Full:
                        job.status = JobStatus.FAILED
                        job.error = "Queue full on server restart"
                        job.completed_at = datetime.utcnow()
                        job.save_to_disk(self._metadata_dir)
                        logger.warning(
                            f"Job {job.job_id} could not be re-queued (queue full)"
                        )

                self._jobs[job.job_id] = job
                loaded_count += 1

            except Exception as e:
                logger.error(f"Failed to load job from {filepath}: {e}")

        logger.info(f"Loaded {loaded_count} jobs from disk")
        self._calculate_queue_positions()

    def _calculate_queue_positions(self):
        """Update queue_position for all queued jobs."""
        queued_jobs = [
            j for j in self._jobs.values()
            if j.status == JobStatus.QUEUED
        ]

        # Sort by created_at (FIFO)
        queued_jobs.sort(key=lambda j: j.created_at)

        # Update positions
        for i, job in enumerate(queued_jobs, start=1):
            job.queue_position = i
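Combined with the REST endpoints added later in this diff, a client submits a job and polls until it finishes. A rough sketch using `requests` (host and port taken from run_api_server.sh; the audio path is a placeholder, and the response fields follow the dictionaries returned by the code above):

```python
import time
import requests

BASE = "http://localhost:8000"

# Submit a job; the server responds immediately with a job_id.
job = requests.post(f"{BASE}/jobs", json={"audio_path": "/path/to/audio.mp3"}).json()
job_id = job["job_id"]

# Poll until the job leaves the queue/running states.
while True:
    status = requests.get(f"{BASE}/jobs/{job_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)

# Fetch the transcription text once completed.
if status["status"] == "completed":
    text = requests.get(f"{BASE}/jobs/{job_id}/result").text
```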
@@ -4,7 +4,7 @@ Model Management Module
Responsible for loading, caching, and managing Whisper models
"""

import os; print(os.environ.get("WHISPER_MODEL_DIR"))
import os
import time
import logging
from typing import Dict, Any
@@ -14,6 +14,14 @@ from faster_whisper import WhisperModel, BatchedInferencePipeline
# Log configuration
logger = logging.getLogger(__name__)

# Import GPU health check with reset capability
try:
    from core.gpu_health import check_gpu_health_with_reset
    GPU_HEALTH_CHECK_AVAILABLE = True
except ImportError:
    logger.warning("GPU health check with reset not available")
    GPU_HEALTH_CHECK_AVAILABLE = False

# Global model instance cache
model_instances = {}

@@ -60,15 +68,35 @@ def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[s
    if model_name not in valid_models:
        raise ValueError(f"Invalid model name: {model_name}. Valid models: {', '.join(valid_models)}")

    # Auto-detect device
    # Auto-detect device - GPU REQUIRED (no CPU fallback)
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if not torch.cuda.is_available():
            error_msg = (
                "GPU required but CUDA is not available. "
                "This service is configured for GPU-only operation. "
                "Please ensure CUDA is properly installed and GPU is accessible."
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        device = "cuda"

    # Auto-detect compute type
    if compute_type == "auto":
        compute_type = "float16" if device == "cuda" else "int8"

    # Validate device and compute type
    if device not in ["cpu", "cuda"]:
        raise ValueError(f"Invalid device: {device}. Valid devices: cpu, cuda")

    # CRITICAL: Reject CPU device - this service is GPU-only
    if device == "cpu":
        error_msg = (
            "CPU device requested but this service is configured for GPU-only operation. "
            "Please use device='cuda' or device='auto' with a GPU available."
        )
        logger.error(error_msg)
        raise ValueError(error_msg)

    if device == "cuda" and not torch.cuda.is_available():
        logger.error("CUDA requested but not available")
        raise RuntimeError("CUDA not available but explicitly requested")
@@ -90,7 +118,18 @@ def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[s

    # Test GPU driver before loading model and clean
    if device == "cuda":
        test_gpu_driver()
        # Use GPU health check with reset capability if available
        if GPU_HEALTH_CHECK_AVAILABLE:
            try:
                logger.info("Running GPU health check with auto-reset before model loading")
                check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
            except Exception as e:
                logger.error(f"GPU health check failed: {e}")
                raise RuntimeError(f"GPU not available for model loading: {e}")
        else:
            # Fallback to simple GPU test
            test_gpu_driver()

        torch.cuda.empty_cache()

    # Instantiate model
@@ -163,8 +202,14 @@ def get_model_info() -> str:
    models = [
        "tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
    ]
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    compute_types = ["float16", "int8"] if torch.cuda.is_available() else ["int8"]
    # GPU-only service - CUDA required
    if not torch.cuda.is_available():
        raise RuntimeError(
            "GPU required but CUDA is not available. "
            "This service is configured for GPU-only operation."
        )
    devices = ["cuda"]  # CPU not supported in GPU-only mode
    compute_types = ["float16", "int8"]

    # Supported language list
    languages = {
@@ -181,20 +226,21 @@ def get_model_info() -> str:
        "available_models": models,
        "default_model": "large-v3",
        "available_devices": devices,
        "default_device": "cuda" if torch.cuda.is_available() else "cpu",
        "default_device": "cuda",  # GPU-only service
        "available_compute_types": compute_types,
        "default_compute_type": "float16" if torch.cuda.is_available() else "int8",
        "cuda_available": torch.cuda.is_available(),
        "default_compute_type": "float16",
        "cuda_available": True,  # Required for this service
        "gpu_only_mode": True,  # Indicate this is GPU-only
        "supported_languages": languages,
        "supported_audio_formats": audio_formats,
        "version": "0.1.1"
    }

    if torch.cuda.is_available():
        info["gpu_info"] = {
            "name": torch.cuda.get_device_name(0),
            "memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
            "memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
        }
    # GPU info (always present in GPU-only mode)
    info["gpu_info"] = {
        "name": torch.cuda.get_device_name(0),
        "memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
        "memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
    }

    return json.dumps(info, indent=2)
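With these changes, model loading is strictly GPU-only, so callers should handle the new failure mode rather than expect a silent CPU fallback. A minimal sketch, assuming the function signature used above:

```python
from core.model_manager import get_whisper_model

try:
    model = get_whisper_model("large-v3", device="auto", compute_type="auto")
except (RuntimeError, ValueError) as e:
    # Raised when CUDA is unavailable or CPU was explicitly requested.
    print(f"GPU-only service refused to load the model: {e}")
```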
@@ -9,9 +9,9 @@ import time
import logging
from typing import Dict, Any, Tuple, List, Optional, Union

from model_manager import get_whisper_model
from audio_processor import validate_audio_file, process_audio
from formatters import format_vtt, format_srt, format_json, format_txt, format_time
from core.model_manager import get_whisper_model
from utils.audio_processor import validate_audio_file, process_audio
from utils.formatters import format_vtt, format_srt, format_json, format_txt, format_time

# Logging configuration
logger = logging.getLogger(__name__)
@@ -76,9 +76,10 @@ def transcribe_audio(
    temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE

    # Validate audio file
    validation_result = validate_audio_file(audio_path)
    if validation_result != "ok":
        return validation_result
    try:
        validate_audio_file(audio_path)
    except (FileNotFoundError, ValueError, OSError) as e:
        return f"Error: {str(e)}"

    try:
        # Get model instance
5
src/servers/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Server implementations for Whisper transcription service.

Includes MCP server (whisper_server.py) and REST API server (api_server.py).
"""
466
src/servers/api_server.py
Normal file
@@ -0,0 +1,466 @@
#!/usr/bin/env python3
"""
FastAPI REST API Server for Whisper Transcription
Provides HTTP REST endpoints for audio transcription
"""

import os
import sys
import logging
import queue as queue_module
import tempfile
from pathlib import Path
from contextlib import asynccontextmanager
from typing import Optional, List
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel, Field
import json

from core.model_manager import get_model_info
from core.job_queue import JobQueue, JobStatus
from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
from utils.startup import startup_sequence, cleanup_on_shutdown

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global instances
job_queue: Optional[JobQueue] = None
health_monitor: Optional[HealthMonitor] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context manager for startup/shutdown"""
    global job_queue, health_monitor

    # Startup - use common startup logic (without GPU check, handled in main)
    logger.info("Initializing job queue and health monitor...")

    from utils.startup import initialize_job_queue, initialize_health_monitor

    job_queue = initialize_job_queue()
    health_monitor = initialize_health_monitor()

    yield

    # Shutdown - use common cleanup logic
    cleanup_on_shutdown(
        job_queue=job_queue,
        health_monitor=health_monitor,
        wait_for_current_job=True
    )


# Create FastAPI app
app = FastAPI(
    title="Whisper Speech Recognition API",
    description="High-performance audio transcription API based on Faster Whisper with async job queue",
    version="0.2.0",
    lifespan=lifespan
)


# Request/Response Models
class SubmitJobRequest(BaseModel):
    audio_path: str = Field(..., description="Path to the audio file on the server")
    model_name: str = Field("large-v3", description="Whisper model name")
    device: str = Field("auto", description="Execution device (cuda, auto)")
    compute_type: str = Field("auto", description="Computation type (float16, int8, auto)")
    language: Optional[str] = Field(None, description="Language code (zh, en, ja, etc.)")
    output_format: str = Field("txt", description="Output format (vtt, srt, json, txt)")
    beam_size: int = Field(5, description="Beam search size")
    temperature: float = Field(0.0, description="Sampling temperature")
    initial_prompt: Optional[str] = Field(None, description="Initial prompt text")
    output_directory: Optional[str] = Field(None, description="Output directory path")


# API Endpoints

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "name": "Whisper Speech Recognition API",
        "version": "0.2.0",
        "description": "Async job queue-based transcription service",
        "endpoints": {
            "GET /": "API information",
            "GET /health": "Health check",
            "GET /health/gpu": "GPU health check",
            "GET /health/circuit-breaker": "Get circuit breaker stats",
            "POST /health/circuit-breaker/reset": "Reset circuit breaker",
            "GET /models": "Get available models information",
            "POST /jobs": "Submit transcription job (async)",
            "GET /jobs/{job_id}": "Get job status",
            "GET /jobs/{job_id}/result": "Get job result",
            "GET /jobs": "List jobs with optional filtering"
        },
        "workflow": {
            "1": "Submit job via POST /jobs → receive job_id",
            "2": "Poll status via GET /jobs/{job_id} → wait for status='completed'",
            "3": "Get result via GET /jobs/{job_id}/result → retrieve transcription"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "service": "whisper-transcription"}


@app.get("/models")
async def get_models():
    """Get available Whisper models and configuration information"""
    try:
        model_info = get_model_info()
        return JSONResponse(content=json.loads(model_info))
    except Exception as e:
        logger.error(f"Failed to get model info: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")


@app.post("/jobs")
async def submit_job(request: SubmitJobRequest):
    """
    Submit a transcription job for async processing.

    Returns immediately with job_id. Poll GET /jobs/{job_id} for status.
    """
    try:
        job_info = job_queue.submit_job(
            audio_path=request.audio_path,
            model_name=request.model_name,
            device=request.device,
            compute_type=request.compute_type,
            language=request.language,
            output_format=request.output_format,
            beam_size=request.beam_size,
            temperature=request.temperature,
            initial_prompt=request.initial_prompt,
            output_directory=request.output_directory
        )

        return JSONResponse(
            status_code=200,
            content={
                **job_info,
                "message": f"Job submitted successfully. Poll /jobs/{job_info['job_id']} for status."
            }
        )

    except queue_module.Full:
        # Queue is full
        logger.warning("Job queue is full, rejecting request")
        raise HTTPException(
            status_code=503,
            detail={
                "error": "Queue full",
                "message": f"Job queue is full ({job_queue._max_queue_size}/{job_queue._max_queue_size}). Please try again later or contact administrator.",
                "queue_size": job_queue._max_queue_size,
                "max_queue_size": job_queue._max_queue_size
            }
        )

    except FileNotFoundError as e:
        # Invalid audio file
        logger.error(f"Invalid audio file: {e}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Invalid audio file",
                "message": str(e),
                "audio_path": request.audio_path
            }
        )

    except ValueError as e:
        # CPU device rejected
        logger.error(f"Invalid device parameter: {e}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Invalid device",
                "message": str(e)
            }
        )

    except RuntimeError as e:
        # GPU health check failed
        logger.error(f"GPU health check failed: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "GPU unavailable",
                "message": f"Job rejected: {str(e)}"
            }
        )

    except Exception as e:
        logger.error(f"Failed to submit job: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Internal error",
                "message": f"Failed to submit job: {str(e)}"
            }
        )


@app.get("/jobs/{job_id}")
async def get_job_status_endpoint(job_id: str):
    """
    Get the current status of a job.

    Returns job status including queue position, timestamps, and result path when completed.
    """
    try:
        status = job_queue.get_job_status(job_id)
        return JSONResponse(status_code=200, content=status)

    except KeyError:
        raise HTTPException(
            status_code=404,
            detail={
                "error": "Job not found",
                "message": f"Job ID '{job_id}' does not exist or has been cleaned up",
                "job_id": job_id
            }
        )

    except Exception as e:
        logger.error(f"Failed to get job status: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Internal error",
                "message": f"Failed to get job status: {str(e)}"
            }
        )


@app.get("/jobs/{job_id}/result")
async def get_job_result_endpoint(job_id: str):
||||
"""
|
||||
Get the transcription result for a completed job.
|
||||
|
||||
Returns the transcription text. Only works for jobs with status='completed'.
|
||||
"""
|
||||
try:
|
||||
result_text = job_queue.get_job_result(job_id)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"job_id": job_id, "result": result_text}
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"error": "Job not found",
|
||||
"message": f"Job ID '{job_id}' does not exist or has been cleaned up",
|
||||
"job_id": job_id
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# Job not completed yet
|
||||
# Extract current status from error message
|
||||
status_match = str(e).split("Current status: ")
|
||||
current_status = status_match[1] if len(status_match) > 1 else "unknown"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"error": "Job not completed",
|
||||
"message": f"Job is not completed yet. Current status: {current_status}. Please wait and poll again.",
|
||||
"job_id": job_id,
|
||||
"current_status": current_status
|
||||
}
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# Result file missing
|
||||
logger.error(f"Result file not found for job {job_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Result file not found",
|
||||
"message": str(e),
|
||||
"job_id": job_id
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get job result: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Internal error",
|
||||
"message": f"Failed to get job result: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/jobs")
|
||||
async def list_jobs_endpoint(
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100
|
||||
):
|
||||
"""
|
||||
List jobs with optional filtering.
|
||||
|
||||
Query parameters:
|
||||
- status: Filter by status (queued, running, completed, failed)
|
||||
- limit: Maximum number of results (default: 100)
|
||||
"""
|
||||
try:
|
||||
# Parse status filter
|
||||
status_filter = None
|
||||
if status:
|
||||
try:
|
||||
status_filter = JobStatus(status)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid status",
|
||||
"message": f"Invalid status value: {status}. Must be one of: queued, running, completed, failed"
|
||||
}
|
||||
)
|
||||
|
||||
jobs = job_queue.list_jobs(status_filter=status_filter, limit=limit)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"jobs": jobs,
|
||||
"total": len(jobs),
|
||||
"filters": {
|
||||
"status": status,
|
||||
"limit": limit
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list jobs: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Internal error",
|
||||
"message": f"Failed to list jobs: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health/gpu")
|
||||
async def gpu_health_check_endpoint():
|
||||
"""
|
||||
Check GPU health by running a quick transcription test.
|
||||
|
||||
Returns detailed GPU status including device name, memory, and test duration.
|
||||
"""
|
||||
try:
|
||||
status = check_gpu_health(expected_device="auto")
|
||||
|
||||
# Add interpretation
|
||||
interpretation = "GPU is healthy and working correctly"
|
||||
if not status.gpu_available:
|
||||
interpretation = "GPU not available on this system"
|
||||
elif not status.gpu_working:
|
||||
interpretation = f"GPU available but not working correctly: {status.error}"
|
||||
elif status.test_duration_seconds > 2.0:
|
||||
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
**status.to_dict(),
|
||||
"interpretation": interpretation
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GPU health check failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=200, # Still return 200 but with error details
|
||||
content={
|
||||
"gpu_available": False,
|
||||
"gpu_working": False,
|
||||
"device_used": "unknown",
|
||||
"device_name": "",
|
||||
"memory_total_gb": 0.0,
|
||||
"memory_available_gb": 0.0,
|
||||
"test_duration_seconds": 0.0,
|
||||
"timestamp": "",
|
||||
"error": str(e),
|
||||
"interpretation": f"GPU health check failed: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health/circuit-breaker")
|
||||
async def get_circuit_breaker_status():
|
||||
"""
|
||||
Get GPU health check circuit breaker statistics.
|
||||
|
||||
Returns current state, failure/success counts, and last failure time.
|
||||
"""
|
||||
try:
|
||||
stats = get_circuit_breaker_stats()
|
||||
return JSONResponse(status_code=200, content=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get circuit breaker stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/health/circuit-breaker/reset")
|
||||
async def reset_circuit_breaker_endpoint():
|
||||
"""
|
||||
Manually reset the GPU health check circuit breaker.
|
||||
|
||||
Useful after fixing GPU issues or for testing purposes.
|
||||
"""
|
||||
try:
|
||||
reset_circuit_breaker()
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"message": "Circuit breaker reset successfully"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset circuit breaker: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
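
A small monitoring sketch for the health endpoints above (illustrative; assumes the API on localhost:8000 and uses only the `gpu_working`, `interpretation`, and circuit-breaker `state` fields returned by these handlers):

```python
import requests

BASE = "http://localhost:8000"

gpu = requests.get(f"{BASE}/health/gpu").json()
print(gpu["interpretation"])

if not gpu["gpu_working"]:
    # Inspect the breaker, then reset it once the underlying GPU issue is fixed
    stats = requests.get(f"{BASE}/health/circuit-breaker").json()
    print("Circuit breaker state:", stats.get("state"))
    requests.post(f"{BASE}/health/circuit-breaker/reset")
```
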
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Perform startup GPU health check
|
||||
from utils.startup import perform_startup_gpu_check
|
||||
|
||||
perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=True,
|
||||
exit_on_failure=True
|
||||
)
|
||||
|
||||
# Get configuration from environment variables
|
||||
host = os.getenv("API_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("API_PORT", "8000"))
|
||||
|
||||
logger.info(f"Starting Whisper REST API server on {host}:{port}")
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
)
|
||||
334
src/servers/whisper_server.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Faster Whisper-based Speech Recognition MCP Service
|
||||
Provides high-performance audio transcription with batch processing acceleration and multiple output formats
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from core.model_manager import get_model_info
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health as run_gpu_health_check  # aliased so the MCP tool of the same name below does not shadow it
|
||||
from utils.startup import startup_sequence, cleanup_on_shutdown
|
||||
|
||||
# Log configuration
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global instances
|
||||
job_queue: Optional[JobQueue] = None
|
||||
health_monitor: Optional[HealthMonitor] = None
|
||||
|
||||
# Create FastMCP server instance
|
||||
mcp = FastMCP(
|
||||
name="fast-whisper-mcp-server",
|
||||
version="0.2.0",
|
||||
dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126", "numpy>=1.20.0"]
|
||||
)
|
||||
|
||||
@mcp.tool()
|
||||
def get_model_info_api() -> str:
|
||||
"""
|
||||
Get available Whisper model information and system configuration.
|
||||
|
||||
Returns available models, devices, languages, and GPU information.
|
||||
"""
|
||||
return get_model_info()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def transcribe_async(
|
||||
audio_path: str,
|
||||
model_name: str = "large-v3",
|
||||
device: str = "auto",
|
||||
compute_type: str = "auto",
|
||||
language: Optional[str] = None,
|
||||
output_format: str = "txt",
|
||||
beam_size: int = 5,
|
||||
temperature: float = 0.0,
|
||||
initial_prompt: Optional[str] = None,
|
||||
output_directory: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Submit an audio file for asynchronous transcription.
|
||||
|
||||
IMPORTANT: This tool returns immediately with a job_id. Use get_job_status()
|
||||
to check progress and get_job_result() to retrieve the transcription.
|
||||
|
||||
WORKFLOW FOR LLM AGENTS:
|
||||
1. Call this tool to submit the job
|
||||
2. You will receive a job_id and queue_position
|
||||
3. Poll get_job_status(job_id) every 5-10 seconds to check progress
|
||||
4. When status="completed", call get_job_result(job_id) to get transcription
|
||||
|
||||
For long audio files (>10 minutes), expect processing to take several minutes.
|
||||
You can check queue_position to estimate wait time (each job ~2-5 minutes).
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file on server
|
||||
model_name: Whisper model (tiny, base, small, medium, large-v3)
|
||||
device: Execution device (cuda, auto) - cpu is rejected
|
||||
compute_type: Computation type (float16, int8, auto)
|
||||
language: Language code (en, zh, ja, etc.) or auto-detect
|
||||
output_format: Output format (txt, vtt, srt, json)
|
||||
beam_size: Beam search size (larger=better quality, slower)
|
||||
temperature: Sampling temperature (0.0=greedy)
|
||||
initial_prompt: Optional prompt to guide transcription
|
||||
output_directory: Where to save result (uses default if not specified)
|
||||
|
||||
Returns:
|
||||
JSON string with job_id, status, queue_position, and instructions
|
||||
"""
|
||||
try:
|
||||
job_info = job_queue.submit_job(
|
||||
audio_path=audio_path,
|
||||
model_name=model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=language,
|
||||
output_format=output_format,
|
||||
beam_size=beam_size,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
output_directory=output_directory
|
||||
)
|
||||
return json.dumps({
|
||||
**job_info,
|
||||
"message": f"Job submitted successfully. Poll get_job_status('{job_info['job_id']}') for updates."
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
if "Full" in error_type or "queue is full" in str(e).lower():
|
||||
error_code = "QUEUE_FULL"
|
||||
message = f"Job queue is full. Please try again in a few minutes. Error: {str(e)}"
|
||||
elif "FileNotFoundError" in error_type or "not found" in str(e).lower():
|
||||
error_code = "INVALID_AUDIO_FILE"
|
||||
message = f"Audio file not found or invalid. Error: {str(e)}"
|
||||
elif "RuntimeError" in error_type or "GPU" in str(e):
|
||||
error_code = "GPU_UNAVAILABLE"
|
||||
message = f"GPU unavailable. Error: {str(e)}"
|
||||
elif "ValueError" in error_type or "CPU" in str(e):
|
||||
error_code = "INVALID_DEVICE"
|
||||
message = f"Invalid device parameter. Error: {str(e)}"
|
||||
else:
|
||||
error_code = "INTERNAL_ERROR"
|
||||
message = f"Failed to submit job. Error: {str(e)}"
|
||||
|
||||
return json.dumps({
|
||||
"error": error_code,
|
||||
"message": message
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_job_status(job_id: str) -> str:
|
||||
"""
|
||||
Check the status of a transcription job.
|
||||
|
||||
Status values:
|
||||
- "queued": Job is waiting in queue. Check queue_position.
|
||||
- "running": Job is currently being processed.
|
||||
- "completed": Transcription finished. Call get_job_result() to retrieve.
|
||||
- "failed": Job failed. Check error field for details.
|
||||
|
||||
Args:
|
||||
job_id: Job ID from transcribe_async()
|
||||
|
||||
Returns:
|
||||
JSON string with detailed job status including:
|
||||
- status, queue_position, timestamps, error (if any)
|
||||
"""
|
||||
try:
|
||||
status = job_queue.get_job_status(job_id)
|
||||
return json.dumps(status, indent=2)
|
||||
|
||||
except KeyError:
|
||||
return json.dumps({
|
||||
"error": "JOB_NOT_FOUND",
|
||||
"message": f"Job ID '{job_id}' does not exist. Please check the job_id."
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "INTERNAL_ERROR",
|
||||
"message": f"Failed to get job status. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_job_result(job_id: str) -> str:
|
||||
"""
|
||||
Retrieve the transcription result for a completed job.
|
||||
|
||||
IMPORTANT: Only call this when get_job_status() returns status="completed".
|
||||
If the job is not completed, this will return an error.
|
||||
|
||||
Args:
|
||||
job_id: Job ID from transcribe_async()
|
||||
|
||||
Returns:
|
||||
Transcription text as a string
|
||||
|
||||
Errors:
|
||||
- "Job not found" if invalid job_id
|
||||
- "Job not completed yet" if status is not "completed"
|
||||
- "Result file not found" if transcription file is missing
|
||||
"""
|
||||
try:
|
||||
result_text = job_queue.get_job_result(job_id)
|
||||
return result_text # Return raw text, not JSON
|
||||
|
||||
except KeyError:
|
||||
return json.dumps({
|
||||
"error": "JOB_NOT_FOUND",
|
||||
"message": f"Job ID '{job_id}' does not exist."
|
||||
}, indent=2)
|
||||
|
||||
except ValueError as e:
|
||||
# Extract status from error message
|
||||
status_match = str(e).split("Current status: ")
|
||||
current_status = status_match[1] if len(status_match) > 1 else "unknown"
|
||||
return json.dumps({
|
||||
"error": "JOB_NOT_COMPLETED",
|
||||
"message": f"Job is not completed yet. Current status: {current_status}. Please wait and check again.",
|
||||
"current_status": current_status
|
||||
}, indent=2)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return json.dumps({
|
||||
"error": "RESULT_FILE_NOT_FOUND",
|
||||
"message": f"Result file not found. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "INTERNAL_ERROR",
|
||||
"message": f"Failed to get job result. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def list_transcription_jobs(
|
||||
status_filter: Optional[str] = None,
|
||||
limit: int = 20
|
||||
) -> str:
|
||||
"""
|
||||
List transcription jobs with optional filtering.
|
||||
|
||||
Useful for:
|
||||
- Checking all your submitted jobs
|
||||
- Finding completed jobs
|
||||
- Monitoring queue status
|
||||
|
||||
Args:
|
||||
status_filter: Filter by status (queued, running, completed, failed)
|
||||
limit: Maximum number of jobs to return (default: 20)
|
||||
|
||||
Returns:
|
||||
JSON string with list of jobs
|
||||
"""
|
||||
try:
|
||||
# Parse status filter
|
||||
status_obj = None
|
||||
if status_filter:
|
||||
try:
|
||||
status_obj = JobStatus(status_filter)
|
||||
except ValueError:
|
||||
return json.dumps({
|
||||
"error": "INVALID_STATUS",
|
||||
"message": f"Invalid status: {status_filter}. Must be one of: queued, running, completed, failed"
|
||||
}, indent=2)
|
||||
|
||||
jobs = job_queue.list_jobs(status_filter=status_obj, limit=limit)
|
||||
|
||||
return json.dumps({
|
||||
"jobs": jobs,
|
||||
"total": len(jobs),
|
||||
"filters": {
|
||||
"status": status_filter,
|
||||
"limit": limit
|
||||
}
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "INTERNAL_ERROR",
|
||||
"message": f"Failed to list jobs. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def check_gpu_health() -> str:
|
||||
"""
|
||||
Test GPU availability and performance by running a quick transcription.
|
||||
|
||||
This tool loads the tiny model and transcribes a 1-second test audio file
|
||||
to verify the GPU is working correctly.
|
||||
|
||||
Use this when:
|
||||
- You want to verify GPU is available before submitting large jobs
|
||||
- You suspect GPU performance issues
|
||||
- For monitoring/debugging purposes
|
||||
|
||||
Returns:
|
||||
JSON string with detailed GPU status including:
|
||||
- gpu_available, gpu_working, device_name, memory_info
|
||||
- test_duration_seconds (GPU: <1s, CPU: 5-10s)
|
||||
- interpretation message
|
||||
|
||||
Note: If this returns gpu_working=false, transcriptions will be very slow.
|
||||
"""
|
||||
try:
|
||||
status = run_gpu_health_check(expected_device="auto")
|
||||
|
||||
# Add interpretation
|
||||
interpretation = "GPU is healthy and working correctly"
|
||||
if not status.gpu_available:
|
||||
interpretation = "GPU not available on this system"
|
||||
elif not status.gpu_working:
|
||||
interpretation = f"GPU available but not working correctly: {status.error}"
|
||||
elif status.test_duration_seconds > 2.0:
|
||||
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
|
||||
|
||||
result = status.to_dict()
|
||||
result["interpretation"] = interpretation
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "GPU_CHECK_FAILED",
|
||||
"message": f"GPU health check failed. Error: {str(e)}",
|
||||
"gpu_available": False,
|
||||
"gpu_working": False
|
||||
}, indent=2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("starting mcp server for whisper stt transcriptor")
|
||||
|
||||
# Execute common startup sequence
|
||||
job_queue, health_monitor = startup_sequence(
|
||||
service_name="MCP Whisper Server",
|
||||
require_gpu=True,
|
||||
initialize_queue=True,
|
||||
initialize_monitoring=True
|
||||
)
|
||||
|
||||
try:
|
||||
mcp.run()
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
cleanup_on_shutdown(
|
||||
job_queue=job_queue,
|
||||
health_monitor=health_monitor,
|
||||
wait_for_current_job=True
|
||||
)
|
||||
6
src/utils/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Utility modules for Whisper transcription service.

Includes audio processing, formatters, test audio generation, input validation,
circuit breaker, and startup logic.
"""
68
src/utils/audio_processor.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Audio Processing Module
|
||||
Responsible for audio file validation and preprocessing
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Any
|
||||
from pathlib import Path
|
||||
from faster_whisper import decode_audio
|
||||
|
||||
from utils.input_validation import (
|
||||
validate_audio_file as validate_audio_file_secure,
|
||||
sanitize_error_message
|
||||
)
|
||||
|
||||
# Log configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None:
|
||||
"""
|
||||
Validate if an audio file is valid (with security checks).
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If audio file doesn't exist
|
||||
ValueError: If audio file format is unsupported or file is empty
|
||||
OSError: If file size cannot be checked
|
||||
|
||||
Returns:
|
||||
None: If validation passes
|
||||
"""
|
||||
try:
|
||||
# Use secure validation
|
||||
validate_audio_file_secure(audio_path, allowed_dirs)
|
||||
except Exception as e:
|
||||
# Re-raise with sanitized error messages
|
||||
error_msg = sanitize_error_message(str(e))
|
||||
|
||||
if "not found" in str(e).lower():
|
||||
raise FileNotFoundError(error_msg)
|
||||
elif "size" in str(e).lower():
|
||||
raise OSError(error_msg)
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def process_audio(audio_path: str) -> Union[str, Any]:
|
||||
"""
|
||||
Process audio file, perform decoding and preprocessing
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
|
||||
Returns:
|
||||
Union[str, Any]: Processed audio data or original file path
|
||||
"""
|
||||
# Try to preprocess audio using decode_audio to handle more formats
|
||||
try:
|
||||
audio_data = decode_audio(audio_path)
|
||||
logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
|
||||
return audio_data
|
||||
except Exception as audio_error:
|
||||
logger.warning(f"Audio preprocessing failed, will use file path directly: {str(audio_error)}")
|
||||
return audio_path
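
A brief usage sketch for the helpers in this module (the audio path is a placeholder):

```python
from utils.audio_processor import validate_audio_file, process_audio

path = "/path/to/audio.mp3"  # placeholder path
validate_audio_file(path)    # raises FileNotFoundError / ValueError / OSError on invalid input
audio = process_audio(path)  # decoded audio data, or the original path if decoding fails
```
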
|
||||
291
src/utils/circuit_breaker.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Circuit Breaker Pattern Implementation
|
||||
|
||||
Prevents repeated failed attempts and provides fail-fast behavior.
|
||||
Useful for GPU health checks and other operations that may fail repeatedly.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
CLOSED = "closed" # Normal operation, requests pass through
|
||||
OPEN = "open" # Circuit is open, requests fail immediately
|
||||
HALF_OPEN = "half_open" # Testing if circuit can close
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitBreakerConfig:
|
||||
"""Configuration for circuit breaker."""
|
||||
failure_threshold: int = 3 # Failures before opening circuit
|
||||
success_threshold: int = 2 # Successes before closing from half-open
|
||||
timeout_seconds: int = 60 # Time before attempting half-open
|
||||
half_open_max_calls: int = 1 # Max calls to test in half-open state
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker implementation for preventing repeated failures.
|
||||
|
||||
Usage:
|
||||
breaker = CircuitBreaker(name="gpu_health", failure_threshold=3)
|
||||
|
||||
@breaker.call
|
||||
def check_gpu():
|
||||
# This function will be protected by circuit breaker
|
||||
return perform_gpu_check()
|
||||
|
||||
# Or use decorator:
|
||||
@breaker.decorator()
|
||||
def my_function():
|
||||
return "result"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
failure_threshold: int = 3,
|
||||
success_threshold: int = 2,
|
||||
timeout_seconds: int = 60,
|
||||
half_open_max_calls: int = 1
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
name: Name of the circuit (for logging)
|
||||
failure_threshold: Number of failures before opening
|
||||
success_threshold: Number of successes to close from half-open
|
||||
timeout_seconds: Seconds before transitioning to half-open
|
||||
half_open_max_calls: Max concurrent calls in half-open state
|
||||
"""
|
||||
self.name = name
|
||||
self.config = CircuitBreakerConfig(
|
||||
failure_threshold=failure_threshold,
|
||||
success_threshold=success_threshold,
|
||||
timeout_seconds=timeout_seconds,
|
||||
half_open_max_calls=half_open_max_calls
|
||||
)
|
||||
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time: Optional[datetime] = None
|
||||
self._half_open_calls = 0
|
||||
self._lock = threading.RLock()
|
||||
|
||||
logger.info(
|
||||
f"Circuit breaker '{name}' initialized: "
|
||||
f"failure_threshold={failure_threshold}, "
|
||||
f"timeout={timeout_seconds}s"
|
||||
)
|
||||
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
"""Get current circuit state."""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if circuit is closed (normal operation)."""
|
||||
return self.state == CircuitState.CLOSED
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (failing fast)."""
|
||||
return self.state == CircuitState.OPEN
|
||||
|
||||
@property
|
||||
def is_half_open(self) -> bool:
|
||||
"""Check if circuit is half-open (testing)."""
|
||||
return self.state == CircuitState.HALF_OPEN
|
||||
|
||||
def _update_state(self):
|
||||
"""Update state based on timeout and counters."""
|
||||
if self._state == CircuitState.OPEN:
|
||||
# Check if timeout has passed
|
||||
if self._last_failure_time:
|
||||
elapsed = datetime.utcnow() - self._last_failure_time
|
||||
if elapsed.total_seconds() >= self.config.timeout_seconds:
|
||||
logger.info(
|
||||
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
|
||||
f"after {elapsed.total_seconds():.0f}s timeout"
|
||||
)
|
||||
self._state = CircuitState.HALF_OPEN
|
||||
self._half_open_calls = 0
|
||||
self._success_count = 0
|
||||
|
||||
def _on_success(self):
|
||||
"""Handle successful call."""
|
||||
with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
self._success_count += 1
|
||||
logger.debug(
|
||||
f"Circuit '{self.name}': Success in HALF_OPEN "
|
||||
f"({self._success_count}/{self.config.success_threshold})"
|
||||
)
|
||||
|
||||
if self._success_count >= self.config.success_threshold:
|
||||
logger.info(f"Circuit '{self.name}': Closing circuit after successful test")
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time = None
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
# Reset failure count on success
|
||||
self._failure_count = 0
|
||||
|
||||
def _on_failure(self, error: Exception):
|
||||
"""Handle failed call."""
|
||||
with self._lock:
|
||||
self._failure_count += 1
|
||||
self._last_failure_time = datetime.utcnow()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
logger.warning(
|
||||
f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit"
|
||||
)
|
||||
self._state = CircuitState.OPEN
|
||||
self._success_count = 0
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
logger.debug(
|
||||
f"Circuit '{self.name}': Failure {self._failure_count}/"
|
||||
f"{self.config.failure_threshold}"
|
||||
)
|
||||
|
||||
if self._failure_count >= self.config.failure_threshold:
|
||||
logger.warning(
|
||||
f"Circuit '{self.name}': Opening circuit after "
|
||||
f"{self._failure_count} failures. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
self._state = CircuitState.OPEN
|
||||
self._success_count = 0
|
||||
|
||||
def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Function result
|
||||
|
||||
Raises:
|
||||
CircuitBreakerOpen: If circuit is open
|
||||
Exception: Original exception from func if it fails
|
||||
"""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
|
||||
# Check if circuit is open
|
||||
if self._state == CircuitState.OPEN:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is OPEN. "
|
||||
f"Failing fast to prevent repeated failures. "
|
||||
f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
|
||||
# Check half-open call limit
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
if self._half_open_calls >= self.config.half_open_max_calls:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is HALF_OPEN with max calls reached. "
|
||||
f"Please wait for current test to complete."
|
||||
)
|
||||
self._half_open_calls += 1
|
||||
|
||||
# Execute function
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
self._on_success()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._on_failure(e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Decrement half-open counter
|
||||
with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
self._half_open_calls -= 1
|
||||
|
||||
def decorator(self):
|
||||
"""
|
||||
Decorator for protecting functions with circuit breaker.
|
||||
|
||||
Usage:
|
||||
breaker = CircuitBreaker("my_service")
|
||||
|
||||
@breaker.decorator()
|
||||
def my_function():
|
||||
return do_something()
|
||||
"""
|
||||
def wrapper(func):
|
||||
def decorated(*args, **kwargs):
|
||||
return self.call(func, *args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Manually reset circuit breaker to closed state.
|
||||
|
||||
Useful for:
|
||||
- Testing
|
||||
- Manual intervention
|
||||
- Clearing error state
|
||||
"""
|
||||
with self._lock:
|
||||
logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state")
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time = None
|
||||
self._half_open_calls = 0
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""
|
||||
Get circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with current state and counters
|
||||
"""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
return {
|
||||
"name": self.name,
|
||||
"state": self._state.value,
|
||||
"failure_count": self._failure_count,
|
||||
"success_count": self._success_count,
|
||||
"last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
|
||||
"config": {
|
||||
"failure_threshold": self.config.failure_threshold,
|
||||
"success_threshold": self.config.success_threshold,
|
||||
"timeout_seconds": self.config.timeout_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CircuitBreakerOpen(Exception):
|
||||
"""Exception raised when circuit breaker is open."""
|
||||
pass
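
A short usage sketch for the breaker defined above; `flaky_gpu_check` is a stand-in for any operation that may raise:

```python
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen

breaker = CircuitBreaker(name="gpu_health", failure_threshold=3, timeout_seconds=60)

def flaky_gpu_check() -> str:
    # Stand-in for a real check that may raise on failure
    return "ok"

try:
    result = breaker.call(flaky_gpu_check)
except CircuitBreakerOpen as exc:
    # Circuit is open: fail fast instead of hammering a broken GPU
    print(f"Skipping check: {exc}")
except Exception as exc:
    print(f"Check failed: {exc}")

print(breaker.get_stats())
```
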
|
||||
411
src/utils/input_validation.py
Normal file
@@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Input Validation and Path Sanitization Module
|
||||
|
||||
Provides robust validation for user inputs with security protections
|
||||
against path traversal, injection attacks, and other malicious inputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum file size (10GB)
|
||||
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024
|
||||
|
||||
# Allowed audio file extensions
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"}
|
||||
|
||||
# Allowed output formats
|
||||
ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"}
|
||||
|
||||
# Model name validation (whitelist)
|
||||
ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"}
|
||||
|
||||
# Device validation
|
||||
ALLOWED_DEVICES = {"cuda", "auto", "cpu"}
|
||||
|
||||
# Compute type validation
|
||||
ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"}
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""Base exception for validation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class PathTraversalError(ValidationError):
|
||||
"""Exception for path traversal attempts."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFileTypeError(ValidationError):
|
||||
"""Exception for invalid file types."""
|
||||
pass
|
||||
|
||||
|
||||
class FileSizeError(ValidationError):
|
||||
"""Exception for file size issues."""
|
||||
pass
|
||||
|
||||
|
||||
def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
|
||||
"""
|
||||
Sanitize error messages to prevent information leakage.
|
||||
|
||||
Args:
|
||||
error_msg: Original error message
|
||||
sanitize_paths: Whether to sanitize file paths (default: True)
|
||||
|
||||
Returns:
|
||||
Sanitized error message
|
||||
"""
|
||||
if not sanitize_paths:
|
||||
return error_msg
|
||||
|
||||
# Replace absolute paths with relative paths
|
||||
# Pattern: /home/user/... or /media/... or /var/... or /tmp/...
|
||||
path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)'
|
||||
|
||||
def replace_path(match):
|
||||
full_path = match.group(1)
|
||||
try:
|
||||
# Try to get just the filename
|
||||
basename = os.path.basename(full_path)
|
||||
return f"<file:{basename}>"
|
||||
except Exception:
|
||||
return "<file:redacted>"
|
||||
|
||||
sanitized = re.sub(path_pattern, replace_path, error_msg)
|
||||
|
||||
# Also sanitize user names if present
|
||||
sanitized = re.sub(r'/home/([^/]+)/', '/home/<user>/', sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
|
||||
"""
|
||||
Validate and sanitize a file path to prevent directory traversal attacks.
|
||||
|
||||
Args:
|
||||
file_path: Path to validate
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
|
||||
Returns:
|
||||
Resolved Path object
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path contains traversal attempts
|
||||
ValidationError: If path is invalid
|
||||
"""
|
||||
if not file_path:
|
||||
raise ValidationError("File path cannot be empty")
|
||||
|
||||
# Convert to Path object
|
||||
try:
|
||||
path = Path(file_path)
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
|
||||
|
||||
# Check for path traversal attempts
|
||||
path_str = str(path)
|
||||
if ".." in path_str:
|
||||
logger.warning(f"Path traversal attempt detected: {path_str}")
|
||||
raise PathTraversalError("Path traversal (..) is not allowed")
|
||||
|
||||
# Check for null bytes
|
||||
if "\x00" in path_str:
|
||||
logger.warning(f"Null byte in path detected: {path_str}")
|
||||
raise PathTraversalError("Null bytes in path are not allowed")
|
||||
|
||||
# Resolve to absolute path (but don't follow symlinks yet)
|
||||
try:
|
||||
resolved_path = path.resolve()
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}")
|
||||
|
||||
# If allowed_dirs specified, ensure path is within one of them
|
||||
if allowed_dirs:
|
||||
allowed = False
|
||||
for allowed_dir in allowed_dirs:
|
||||
try:
|
||||
allowed_dir_path = Path(allowed_dir).resolve()
|
||||
# Check if resolved_path is under allowed_dir
|
||||
resolved_path.relative_to(allowed_dir_path)
|
||||
allowed = True
|
||||
break
|
||||
except ValueError:
|
||||
# Not relative to this allowed_dir
|
||||
continue
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
f"Path outside allowed directories: {path_str}, "
|
||||
f"allowed: {allowed_dirs}"
|
||||
)
|
||||
raise PathTraversalError(
|
||||
f"Path must be within allowed directories. "
|
||||
f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}"
|
||||
)
|
||||
|
||||
return resolved_path
|
||||
|
||||
|
||||
def validate_audio_file(
|
||||
file_path: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
max_size_bytes: int = MAX_FILE_SIZE_BYTES
|
||||
) -> Path:
|
||||
"""
|
||||
Validate audio file path with security checks.
|
||||
|
||||
Args:
|
||||
file_path: Path to audio file
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
max_size_bytes: Maximum allowed file size
|
||||
|
||||
Returns:
|
||||
Validated Path object
|
||||
|
||||
Raises:
|
||||
ValidationError: If validation fails
|
||||
PathTraversalError: If path traversal detected
|
||||
FileNotFoundError: If file doesn't exist
|
||||
InvalidFileTypeError: If file type not allowed
|
||||
FileSizeError: If file too large
|
||||
"""
|
||||
# Validate and sanitize path
|
||||
validated_path = validate_path_safe(file_path, allowed_dirs)
|
||||
|
||||
# Check file exists
|
||||
if not validated_path.exists():
|
||||
raise FileNotFoundError(f"Audio file not found: {validated_path.name}")
|
||||
|
||||
# Check it's a file (not directory)
|
||||
if not validated_path.is_file():
|
||||
raise ValidationError(f"Path is not a file: {validated_path.name}")
|
||||
|
||||
# Check file extension
|
||||
file_ext = validated_path.suffix.lower()
|
||||
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
|
||||
raise InvalidFileTypeError(
|
||||
f"Unsupported audio format: {file_ext}. "
|
||||
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
|
||||
)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
file_size = validated_path.stat().st_size
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}")
|
||||
|
||||
if file_size == 0:
|
||||
raise FileSizeError(f"Audio file is empty: {validated_path.name}")
|
||||
|
||||
if file_size > max_size_bytes:
|
||||
raise FileSizeError(
|
||||
f"File too large: {file_size / (1024**3):.2f}GB. "
|
||||
f"Maximum: {max_size_bytes / (1024**3):.2f}GB"
|
||||
)
|
||||
|
||||
# Warn for large files (>1GB)
|
||||
if file_size > 1024 * 1024 * 1024:
|
||||
logger.warning(
|
||||
f"Large file: {file_size / (1024**3):.2f}GB, "
|
||||
f"may require extended processing time"
|
||||
)
|
||||
|
||||
return validated_path
|
||||
|
||||
|
||||
def validate_output_directory(
|
||||
dir_path: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
create_if_missing: bool = True
|
||||
) -> Path:
|
||||
"""
|
||||
Validate output directory path.
|
||||
|
||||
Args:
|
||||
dir_path: Directory path
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
create_if_missing: Create directory if it doesn't exist
|
||||
|
||||
Returns:
|
||||
Validated Path object
|
||||
|
||||
Raises:
|
||||
ValidationError: If validation fails
|
||||
PathTraversalError: If path traversal detected
|
||||
"""
|
||||
# Validate and sanitize path
|
||||
validated_path = validate_path_safe(dir_path, allowed_dirs)
|
||||
|
||||
# Create if requested and doesn't exist
|
||||
if create_if_missing and not validated_path.exists():
|
||||
try:
|
||||
validated_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created output directory: {validated_path}")
|
||||
except Exception as e:
|
||||
raise ValidationError(
|
||||
f"Cannot create output directory: {sanitize_error_message(str(e))}"
|
||||
)
|
||||
|
||||
# Check it's a directory
|
||||
if validated_path.exists() and not validated_path.is_dir():
|
||||
raise ValidationError(f"Path exists but is not a directory: {validated_path.name}")
|
||||
|
||||
return validated_path
|
||||
|
||||
|
||||
def validate_model_name(model_name: str) -> str:
|
||||
"""
|
||||
Validate Whisper model name.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
Validated model name
|
||||
|
||||
Raises:
|
||||
ValidationError: If model name invalid
|
||||
"""
|
||||
if not model_name:
|
||||
raise ValidationError("Model name cannot be empty")
|
||||
|
||||
model_name = model_name.strip().lower()
|
||||
|
||||
if model_name not in ALLOWED_MODEL_NAMES:
|
||||
raise ValidationError(
|
||||
f"Invalid model name: {model_name}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def validate_device(device: str) -> str:
|
||||
"""
|
||||
Validate device parameter.
|
||||
|
||||
Args:
|
||||
device: Device name to validate
|
||||
|
||||
Returns:
|
||||
Validated device name
|
||||
|
||||
Raises:
|
||||
ValidationError: If device invalid
|
||||
"""
|
||||
if not device:
|
||||
raise ValidationError("Device cannot be empty")
|
||||
|
||||
device = device.strip().lower()
|
||||
|
||||
if device not in ALLOWED_DEVICES:
|
||||
raise ValidationError(
|
||||
f"Invalid device: {device}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}"
|
||||
)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def validate_compute_type(compute_type: str) -> str:
|
||||
"""
|
||||
Validate compute type parameter.
|
||||
|
||||
Args:
|
||||
compute_type: Compute type to validate
|
||||
|
||||
Returns:
|
||||
Validated compute type
|
||||
|
||||
Raises:
|
||||
ValidationError: If compute type invalid
|
||||
"""
|
||||
if not compute_type:
|
||||
raise ValidationError("Compute type cannot be empty")
|
||||
|
||||
compute_type = compute_type.strip().lower()
|
||||
|
||||
if compute_type not in ALLOWED_COMPUTE_TYPES:
|
||||
raise ValidationError(
|
||||
f"Invalid compute type: {compute_type}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}"
|
||||
)
|
||||
|
||||
return compute_type
|
||||
|
||||
|
||||
def validate_output_format(output_format: str) -> str:
|
||||
"""
|
||||
Validate output format parameter.
|
||||
|
||||
Args:
|
||||
output_format: Output format to validate
|
||||
|
||||
Returns:
|
||||
Validated output format
|
||||
|
||||
Raises:
|
||||
ValidationError: If output format invalid
|
||||
"""
|
||||
if not output_format:
|
||||
raise ValidationError("Output format cannot be empty")
|
||||
|
||||
output_format = output_format.strip().lower()
|
||||
|
||||
if output_format not in ALLOWED_OUTPUT_FORMATS:
|
||||
raise ValidationError(
|
||||
f"Invalid output format: {output_format}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}"
|
||||
)
|
||||
|
||||
return output_format
|
||||
|
||||
|
||||
def validate_numeric_range(
|
||||
value: float,
|
||||
min_value: float,
|
||||
max_value: float,
|
||||
param_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Validate numeric parameter is within range.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
min_value: Minimum allowed value
|
||||
max_value: Maximum allowed value
|
||||
param_name: Parameter name for error messages
|
||||
|
||||
Returns:
|
||||
Validated value
|
||||
|
||||
Raises:
|
||||
ValidationError: If value out of range
|
||||
"""
|
||||
if value < min_value or value > max_value:
|
||||
raise ValidationError(
|
||||
f"{param_name} must be between {min_value} and {max_value}, "
|
||||
f"got {value}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_beam_size(beam_size: int) -> int:
|
||||
"""Validate beam size parameter."""
|
||||
return int(validate_numeric_range(beam_size, 1, 20, "beam_size"))
|
||||
|
||||
|
||||
def validate_temperature(temperature: float) -> float:
|
||||
"""Validate temperature parameter."""
|
||||
return validate_numeric_range(temperature, 0.0, 1.0, "temperature")
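
A pre-flight validation sketch using the helpers in this module (parameter values are examples only):

```python
from utils.input_validation import (
    ValidationError,
    validate_model_name,
    validate_device,
    validate_compute_type,
    validate_output_format,
    validate_beam_size,
    validate_temperature,
)

try:
    params = {
        "model_name": validate_model_name("large-v3"),
        "device": validate_device("auto"),
        "compute_type": validate_compute_type("float16"),
        "output_format": validate_output_format("txt"),
        "beam_size": validate_beam_size(5),
        "temperature": validate_temperature(0.0),
    }
except ValidationError as exc:
    print(f"Rejected parameters: {exc}")
else:
    print("Validated:", params)
```
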
|
||||
237
src/utils/startup.py
Normal file
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Common Startup Logic Module
|
||||
|
||||
Centralizes startup procedures shared between MCP and API servers,
|
||||
including GPU health checks, job queue initialization, and health monitoring.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import GPU health check with reset
|
||||
try:
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"GPU health check with reset not available: {e}")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
|
||||
def perform_startup_gpu_check(
|
||||
required_device: str = "cuda",
|
||||
auto_reset: bool = True,
|
||||
exit_on_failure: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Perform startup GPU health check with optional auto-reset.
|
||||
|
||||
This function:
|
||||
1. Checks if GPU health check is available
|
||||
2. Runs comprehensive GPU health check
|
||||
3. Attempts auto-reset if check fails and auto_reset=True
|
||||
4. Optionally exits process if check fails
|
||||
|
||||
Args:
|
||||
required_device: Required device ("cuda", "auto")
|
||||
auto_reset: Enable automatic GPU driver reset on failure
|
||||
exit_on_failure: Exit process if GPU check fails
|
||||
|
||||
Returns:
|
||||
True if GPU check passed, False otherwise
|
||||
|
||||
Side effects:
|
||||
May exit process if exit_on_failure=True and check fails
|
||||
"""
|
||||
if not GPU_HEALTH_CHECK_AVAILABLE:
|
||||
logger.warning("GPU health check not available, starting without GPU validation")
|
||||
if exit_on_failure:
|
||||
logger.error("GPU health check required but not available. Exiting.")
|
||||
sys.exit(1)
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info("=" * 70)
|
||||
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
|
||||
logger.info("=" * 70)
|
||||
|
||||
status = check_gpu_health_with_reset(
|
||||
expected_device=required_device,
|
||||
auto_reset=auto_reset
|
||||
)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("STARTUP GPU CHECK SUCCESSFUL")
|
||||
logger.info(f"GPU Device: {status.device_name}")
|
||||
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
|
||||
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
|
||||
logger.info("=" * 70)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("=" * 70)
|
||||
logger.error("STARTUP GPU CHECK FAILED")
|
||||
logger.error(f"Error: {e}")
|
||||
|
||||
if exit_on_failure:
|
||||
logger.error("This service requires GPU. Terminating.")
|
||||
logger.error("=" * 70)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("Continuing without GPU (may have reduced functionality)")
|
||||
logger.error("=" * 70)
|
||||
return False
|
||||
|
||||
|
||||
def initialize_job_queue(
|
||||
max_queue_size: Optional[int] = None,
|
||||
metadata_dir: Optional[str] = None
|
||||
) -> 'JobQueue':
|
||||
"""
|
||||
Initialize job queue with environment variable configuration.
|
||||
|
||||
Args:
|
||||
max_queue_size: Override for max queue size (uses env var if None)
|
||||
metadata_dir: Override for metadata directory (uses env var if None)
|
||||
|
||||
Returns:
|
||||
Initialized JobQueue instance (started)
|
||||
"""
|
||||
from core.job_queue import JobQueue
|
||||
|
||||
# Get configuration from environment
|
||||
if max_queue_size is None:
|
||||
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
|
||||
|
||||
if metadata_dir is None:
|
||||
metadata_dir = os.getenv(
|
||||
"JOB_METADATA_DIR",
|
||||
"/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||
)
|
||||
|
||||
logger.info("Initializing job queue...")
|
||||
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
|
||||
job_queue.start()
|
||||
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
|
||||
|
||||
return job_queue
|
||||
|
||||
|
||||
def initialize_health_monitor(
|
||||
check_interval_minutes: Optional[int] = None,
|
||||
enabled: Optional[bool] = None
|
||||
) -> Optional['HealthMonitor']:
|
||||
"""
|
||||
Initialize GPU health monitor with environment variable configuration.
|
||||
|
||||
Args:
|
||||
check_interval_minutes: Override for check interval (uses env var if None)
|
||||
enabled: Override for enabled status (uses env var if None)
|
||||
|
||||
Returns:
|
||||
Initialized HealthMonitor instance (started), or None if disabled
|
||||
"""
|
||||
from core.gpu_health import HealthMonitor
|
||||
|
||||
# Get configuration from environment
|
||||
if enabled is None:
|
||||
enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
|
||||
|
||||
if not enabled:
|
||||
logger.info("GPU health monitoring disabled")
|
||||
return None
|
||||
|
||||
if check_interval_minutes is None:
|
||||
check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
|
||||
|
||||
health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes)
|
||||
health_monitor.start()
|
||||
logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)")
|
||||
|
||||
return health_monitor
|
||||
|
||||
|
||||
def startup_sequence(
|
||||
service_name: str = "whisper-transcription",
|
||||
require_gpu: bool = True,
|
||||
initialize_queue: bool = True,
|
||||
initialize_monitoring: bool = True
|
||||
) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]:
|
||||
"""
|
||||
Execute complete startup sequence for a Whisper transcription server.
|
||||
|
||||
This function performs all common startup tasks:
|
||||
1. GPU health check with auto-reset
|
||||
2. Job queue initialization
|
||||
3. Health monitor initialization
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (for logging)
|
||||
require_gpu: Whether GPU is required (exit if not available)
|
||||
initialize_queue: Whether to initialize job queue
|
||||
initialize_monitoring: Whether to initialize health monitoring
|
||||
|
||||
Returns:
|
||||
Tuple of (job_queue, health_monitor) - either may be None
|
||||
|
||||
Side effects:
|
||||
May exit process if GPU required but unavailable
|
||||
"""
|
||||
logger.info(f"Starting {service_name}...")
|
||||
|
||||
# Step 1: GPU health check
|
||||
gpu_ok = perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=True,
|
||||
exit_on_failure=require_gpu
|
||||
)
|
||||
|
||||
if not gpu_ok and require_gpu:
|
||||
# Should not reach here (exit_on_failure should have exited)
|
||||
logger.error("GPU check failed and GPU is required")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 2: Initialize job queue
|
||||
job_queue = None
|
||||
if initialize_queue:
|
||||
job_queue = initialize_job_queue()
|
||||
|
||||
# Step 3: Initialize health monitor
|
||||
health_monitor = None
|
||||
if initialize_monitoring:
|
||||
health_monitor = initialize_health_monitor()
|
||||
|
||||
logger.info(f"{service_name} startup sequence completed")
|
||||
|
||||
return job_queue, health_monitor
|
||||
|
||||
|
||||
def cleanup_on_shutdown(
|
||||
job_queue: Optional['JobQueue'] = None,
|
||||
health_monitor: Optional['HealthMonitor'] = None,
|
||||
wait_for_current_job: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Perform cleanup on server shutdown.
|
||||
|
||||
Args:
|
||||
job_queue: JobQueue instance to stop (if any)
|
||||
health_monitor: HealthMonitor instance to stop (if any)
|
||||
wait_for_current_job: Wait for current job to complete before stopping
|
||||
"""
|
||||
logger.info("Shutting down...")
|
||||
|
||||
if job_queue:
|
||||
job_queue.stop(wait_for_current=wait_for_current_job)
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
if health_monitor:
|
||||
health_monitor.stop()
|
||||
logger.info("Health monitor stopped")
|
||||
|
||||
logger.info("Shutdown complete")
|
||||
96
src/utils/test_audio_generator.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Test audio generator for GPU health checks.
|
||||
|
||||
Generates realistic test audio with speech using TTS (text-to-speech).
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
def generate_test_audio(duration_seconds: float = 3.0, frequency: int = 440) -> str:
|
||||
"""
|
||||
Generate a test audio file with real speech for GPU health checks.
|
||||
|
||||
Args:
|
||||
duration_seconds: Duration of audio in seconds (default: 3.0)
|
||||
frequency: Legacy parameter, ignored (kept for backward compatibility)
|
||||
|
||||
Returns:
|
||||
str: Path to temporary audio file
|
||||
|
||||
Implementation:
|
||||
- Generate real speech using gTTS (Google Text-to-Speech)
|
||||
- Fallback to pyttsx3 if gTTS fails or is unavailable
|
||||
- Raises RuntimeError if both TTS engines fail
|
||||
- Save as MP3 format
|
||||
- Store in system temp directory
|
||||
- Reuse same file if exists (cache)
|
||||
"""
|
||||
# Use a consistent filename in temp directory for caching
|
||||
temp_dir = tempfile.gettempdir()
|
||||
audio_path = os.path.join(temp_dir, f"whisper_test_voice_{int(duration_seconds)}s.mp3")
|
||||
|
||||
# Return cached file if it exists and is valid
|
||||
if os.path.exists(audio_path):
|
||||
try:
|
||||
# Verify file is readable and not empty
|
||||
if os.path.getsize(audio_path) > 0:
|
||||
return audio_path
|
||||
except Exception:
|
||||
# If file is corrupted, regenerate it
|
||||
pass
|
||||
|
||||
# Generate speech with different text based on duration
|
||||
if duration_seconds >= 3:
|
||||
text = "This is a test of the Whisper speech recognition system. Testing one, two, three."
|
||||
elif duration_seconds >= 2:
|
||||
text = "This is a test of the Whisper system."
|
||||
else:
|
||||
text = "Testing Whisper."
|
||||
|
||||
# Try gTTS first (better quality, requires internet)
|
||||
try:
|
||||
from gtts import gTTS
|
||||
tts = gTTS(text=text, lang='en', slow=False)
|
||||
tts.save(audio_path)
|
||||
return audio_path
|
||||
except Exception as e:
|
||||
print(f"gTTS failed ({e}), trying pyttsx3...")
|
||||
|
||||
# Fallback to pyttsx3 (offline, lower quality)
|
||||
try:
|
||||
import pyttsx3
|
||||
engine = pyttsx3.init()
|
||||
engine.save_to_file(text, audio_path)
|
||||
engine.runAndWait()
|
||||
|
||||
# Verify the file was created; otherwise fall through to the error below
if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
return audio_path
raise RuntimeError("pyttsx3 completed but produced an empty or missing audio file")
except Exception as e:
raise RuntimeError(
f"Failed to generate test audio. Both gTTS and pyttsx3 failed. "
f"pyttsx3 error: {e}. Please ensure TTS dependencies are installed: "
f"pip install gTTS pyttsx3"
)
|
||||
|
||||
|
||||
def cleanup_test_audio() -> None:
|
||||
"""
|
||||
Remove all cached test audio files from temp directory.
|
||||
|
||||
Useful for cleanup after testing or to force regeneration.
|
||||
"""
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
# Find all test audio files
|
||||
for filename in os.listdir(temp_dir):
|
||||
if (filename.startswith("whisper_test_") and
|
||||
(filename.endswith(".wav") or filename.endswith(".mp3"))):
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
try:
|
||||
os.remove(filepath)
|
||||
except Exception:
|
||||
# Ignore errors during cleanup
|
||||
pass
|
||||
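A quick usage sketch for the generator above (assumes `src` is on `PYTHONPATH`, as the test scripts and supervisor config arrange, and that gTTS or pyttsx3 is installed):

```python
# Usage sketch: generate (or reuse) the cached test clip, then clean up.
from utils.test_audio_generator import generate_test_audio, cleanup_test_audio

path = generate_test_audio(duration_seconds=3.0)
print(path)  # e.g. /tmp/whisper_test_voice_3s.mp3 (exact temp dir varies by OS)
cleanup_test_audio()  # remove cached whisper_test_* clips to force regeneration
```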
24 supervisor/transcriptor-api.conf (Normal file)
@@ -0,0 +1,24 @@
[program:whisper-api-server]
command=/home/uad/agents/tools/mcp-transcriptor/venv/bin/python /home/uad/agents/tools/mcp-transcriptor/src/servers/api_server.py
directory=/home/uad/agents/tools/mcp-transcriptor
user=uad
autostart=true
autorestart=true
redirect_stderr=true
stdout_logfile=/home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log
stdout_logfile_maxbytes=50MB
stdout_logfile_backups=10
environment=
    PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src",
    CUDA_VISIBLE_DEVICES="0",
    API_HOST="0.0.0.0",
    API_PORT="8000",
    WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/models",
    TRANSCRIPTION_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/outputs",
    TRANSCRIPTION_BATCH_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/outputs/batch",
    TRANSCRIPTION_MODEL="large-v3",
    TRANSCRIPTION_DEVICE="auto",
    TRANSCRIPTION_COMPUTE_TYPE="auto",
    TRANSCRIPTION_OUTPUT_FORMAT="txt"
stopwaitsecs=10
stopsignal=TERM
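The `environment=` block above is how the API server receives its configuration under supervisord. A minimal sketch of consuming those variables from Python follows; the variable names come from the conf file, but the actual config-loading code is not part of this diff, so the defaults here are illustrative only.

```python
# Hypothetical sketch of reading the environment set by supervisord above.
import os

API_HOST = os.getenv("API_HOST", "0.0.0.0")
API_PORT = int(os.getenv("API_PORT", "8000"))
MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", "/models")          # default is illustrative
OUTPUT_DIR = os.getenv("TRANSCRIPTION_OUTPUT_DIR", "/outputs")  # default is illustrative
```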
537 tests/test_async_api_integration.py (Executable file)
@@ -0,0 +1,537 @@
#!/usr/bin/env python3
"""
Test Phase 2: Async Job Queue Integration

Tests the async job queue system for both API and MCP servers.
Validates all new endpoints and error handling.
"""

import os
import sys
import time
import json
import logging
import requests
from pathlib import Path

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

# Color codes for terminal output
class Colors:
    GREEN = '\033[92m'
    RED = '\033[91m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    END = '\033[0m'
    BOLD = '\033[1m'

def print_success(msg):
    print(f"{Colors.GREEN}✓ {msg}{Colors.END}")

def print_error(msg):
    print(f"{Colors.RED}✗ {msg}{Colors.END}")

def print_info(msg):
    print(f"{Colors.BLUE}ℹ {msg}{Colors.END}")

def print_section(msg):
    print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")


class Phase2Tester:
    def __init__(self, api_url="http://localhost:8000"):
        self.api_url = api_url
        self.test_results = []

    def test(self, name, func):
        """Run a test and record result"""
        try:
            logger.info(f"Testing: {name}")
            print_info(f"Testing: {name}")
            func()
            logger.info(f"PASSED: {name}")
            print_success(f"PASSED: {name}")
            self.test_results.append((name, True, None))
            return True
        except AssertionError as e:
            logger.error(f"FAILED: {name} - {str(e)}")
            print_error(f"FAILED: {name}")
            print_error(f"  Reason: {str(e)}")
            self.test_results.append((name, False, str(e)))
            return False
        except Exception as e:
            logger.error(f"ERROR: {name} - {str(e)}")
            print_error(f"ERROR: {name}")
            print_error(f"  Exception: {str(e)}")
            self.test_results.append((name, False, f"Exception: {str(e)}"))
            return False

    def print_summary(self):
        """Print test summary"""
        print_section("TEST SUMMARY")

        passed = sum(1 for _, result, _ in self.test_results if result)
        failed = len(self.test_results) - passed

        for name, result, error in self.test_results:
            if result:
                print_success(f"{name}")
            else:
                print_error(f"{name}")
                if error:
                    print(f"    {error}")

        print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
        print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
        print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")

        return failed == 0

    # ========================================================================
    # API Server Tests
    # ========================================================================

    def test_api_root_endpoint(self):
        """Test GET / returns new API information"""
        logger.info(f"GET {self.api_url}/")
        resp = requests.get(f"{self.api_url}/")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        logger.info(f"Response data: {json.dumps(data, indent=2)}")
        assert data["version"] == "0.2.0", "Version should be 0.2.0"
        assert "POST /jobs" in str(data["endpoints"]), "Should have POST /jobs endpoint"
        assert "workflow" in data, "Should have workflow documentation"

    def test_api_health_endpoint(self):
        """Test GET /health still works"""
        logger.info(f"GET {self.api_url}/health")
        resp = requests.get(f"{self.api_url}/health")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        logger.info(f"Response data: {data}")
        assert data["status"] == "healthy", "Health check should return healthy"

    def test_api_models_endpoint(self):
        """Test GET /models still works"""
        logger.info(f"GET {self.api_url}/models")
        resp = requests.get(f"{self.api_url}/models")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        logger.info(f"Available models: {data.get('available_models', [])}")
        assert "available_models" in data, "Should return available models"

    def test_api_gpu_health_endpoint(self):
        """Test GET /health/gpu returns GPU status"""
        logger.info(f"GET {self.api_url}/health/gpu")
        resp = requests.get(f"{self.api_url}/health/gpu")
        logger.info(f"Response status: {resp.status_code}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        logger.info(f"GPU health: {json.dumps(data, indent=2)}")
        assert "gpu_available" in data, "Should have gpu_available field"
        assert "gpu_working" in data, "Should have gpu_working field"
        assert "interpretation" in data, "Should have interpretation field"

        print_info(f"  GPU Status: {data.get('interpretation', 'unknown')}")

    def test_api_submit_job_invalid_audio(self):
        """Test POST /jobs with invalid audio path returns 400"""
        payload = {
            "audio_path": "/nonexistent/file.mp3",
            "model_name": "tiny",
            "output_format": "txt"
        }

        logger.info(f"POST {self.api_url}/jobs with invalid audio path")
        logger.info(f"Payload: {json.dumps(payload, indent=2)}")
        resp = requests.post(f"{self.api_url}/jobs", json=payload)
        logger.info(f"Response status: {resp.status_code}")
        logger.info(f"Response: {resp.json()}")
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"

        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert data["detail"]["error"] == "Invalid audio file", f"Wrong error type: {data['detail']['error']}"

        print_info(f"  Error message: {data['detail']['message'][:50]}...")

    def test_api_submit_job_cpu_device_rejected(self):
        """Test POST /jobs with device=cpu is rejected (400)"""
        # Create a test audio file first
        logger.info("Creating test audio file...")
        test_audio = self._create_test_audio_file()
        logger.info(f"Test audio created at: {test_audio}")

        payload = {
            "audio_path": test_audio,
            "model_name": "tiny",
            "device": "cpu",
            "output_format": "txt"
        }

        logger.info(f"POST {self.api_url}/jobs with device=cpu")
        logger.info(f"Payload: {json.dumps(payload, indent=2)}")
        resp = requests.post(f"{self.api_url}/jobs", json=payload)
        logger.info(f"Response status: {resp.status_code}")
        logger.info(f"Response: {resp.json()}")
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"

        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert "Invalid device" in data["detail"]["error"] or "CPU" in data["detail"]["message"], \
            "Should reject CPU device"

    def test_api_submit_job_success(self):
        """Test POST /jobs with valid audio returns job_id"""
        logger.info("Creating test audio file...")
        test_audio = self._create_test_audio_file()
        logger.info(f"Test audio created at: {test_audio}")

        payload = {
            "audio_path": test_audio,
            "model_name": "tiny",
            "device": "auto",
            "output_format": "txt"
        }

        logger.info(f"POST {self.api_url}/jobs with valid audio")
        logger.info(f"Payload: {json.dumps(payload, indent=2)}")
        resp = requests.post(f"{self.api_url}/jobs", json=payload)
        logger.info(f"Response status: {resp.status_code}")
        logger.info(f"Response: {json.dumps(resp.json(), indent=2)}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        assert "job_id" in data, "Should return job_id"
        assert "status" in data, "Should return status"
        assert data["status"] == "queued", f"Status should be queued, got {data['status']}"
        assert "queue_position" in data, "Should return queue_position"
        assert "message" in data, "Should return message"

        logger.info(f"Job submitted successfully: {data['job_id']}")
        print_info(f"  Job ID: {data['job_id']}")
        print_info(f"  Queue position: {data['queue_position']}")

        # Store job_id for later tests
        self.test_job_id = data["job_id"]

    def test_api_get_job_status(self):
        """Test GET /jobs/{job_id} returns job status"""
        if not hasattr(self, 'test_job_id'):
            print_info("  Skipping (no test_job_id from previous test)")
            return

        resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        assert "job_id" in data, "Should return job_id"
        assert "status" in data, "Should return status"
        assert data["status"] in ["queued", "running", "completed", "failed"], \
            f"Invalid status: {data['status']}"

        print_info(f"  Status: {data['status']}")

    def test_api_get_job_status_not_found(self):
        """Test GET /jobs/{job_id} with invalid ID returns 404"""
        fake_job_id = "00000000-0000-0000-0000-000000000000"

        resp = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
        assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"

        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert data["detail"]["error"] == "Job not found", f"Wrong error: {data['detail']['error']}"

    def test_api_get_job_result_not_completed(self):
        """Test GET /jobs/{job_id}/result when job not completed returns 409"""
        if not hasattr(self, 'test_job_id'):
            print_info("  Skipping (no test_job_id from previous test)")
            return

        # Check current status
        status_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
        current_status = status_resp.json()["status"]

        if current_status == "completed":
            print_info("  Skipping (job already completed)")
            return

        resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
        assert resp.status_code == 409, f"Expected 409, got {resp.status_code}"

        data = resp.json()
        assert "error" in data["detail"], "Should have error field"
        assert data["detail"]["error"] == "Job not completed", f"Wrong error: {data['detail']['error']}"
        assert "current_status" in data["detail"], "Should include current_status"

    def test_api_list_jobs(self):
        """Test GET /jobs returns job list"""
        resp = requests.get(f"{self.api_url}/jobs")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        assert "jobs" in data, "Should have jobs field"
        assert "total" in data, "Should have total field"
        assert isinstance(data["jobs"], list), "Jobs should be a list"

        print_info(f"  Total jobs: {data['total']}")

    def test_api_list_jobs_with_filter(self):
        """Test GET /jobs?status=queued filters by status"""
        resp = requests.get(f"{self.api_url}/jobs?status=queued&limit=10")
        assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"

        data = resp.json()
        assert "jobs" in data, "Should have jobs field"
        assert "filters" in data, "Should have filters field"
        assert data["filters"]["status"] == "queued", "Filter should be applied"

        # All returned jobs should be queued
        for job in data["jobs"]:
            assert job["status"] == "queued", f"Job {job['job_id']} has wrong status: {job['status']}"

    def test_api_wait_for_job_completion(self):
        """Test waiting for job to complete and retrieving result"""
        if not hasattr(self, 'test_job_id'):
            logger.warning("Skipping - no test_job_id from previous test")
            print_info("  Skipping (no test_job_id from previous test)")
            return

        logger.info(f"Waiting for job {self.test_job_id} to complete (max 60s)...")
        print_info("  Waiting for job to complete (max 60s)...")

        max_wait = 60
        start_time = time.time()

        while time.time() - start_time < max_wait:
            resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
            data = resp.json()
            status = data["status"]
            elapsed = int(time.time() - start_time)

            logger.info(f"Job status: {status} (elapsed: {elapsed}s)")
            print_info(f"  Status: {status} (elapsed: {elapsed}s)")

            if status == "completed":
                logger.info("Job completed successfully!")
                print_success("  Job completed!")

                # Now get the result
                logger.info("Fetching job result...")
                result_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
                logger.info(f"Result response status: {result_resp.status_code}")
                assert result_resp.status_code == 200, f"Expected 200, got {result_resp.status_code}"

                result_data = result_resp.json()
                logger.info(f"Result data keys: {result_data.keys()}")
                assert "result" in result_data, "Should have result field"
                assert len(result_data["result"]) > 0, "Result should not be empty"

                actual_text = result_data["result"].strip()
                logger.info(f"Transcription result: '{actual_text}'")
                print_info(f"  Transcription: '{actual_text}'")
                return

            elif status == "failed":
                error_msg = f"Job failed: {data.get('error', 'unknown error')}"
                logger.error(error_msg)
                raise AssertionError(error_msg)

            time.sleep(2)

        error_msg = f"Job did not complete within {max_wait}s"
        logger.error(error_msg)
        raise AssertionError(error_msg)

    # ========================================================================
    # MCP Server Tests (Import-based)
    # ========================================================================

    def test_mcp_imports(self):
        """Test MCP server modules can be imported"""
        try:
            logger.info("Importing MCP server module...")
            from servers import whisper_server

            logger.info("Checking for new async tools...")
            assert hasattr(whisper_server, 'transcribe_async'), "Should have transcribe_async tool"
            assert hasattr(whisper_server, 'get_job_status'), "Should have get_job_status tool"
            assert hasattr(whisper_server, 'get_job_result'), "Should have get_job_result tool"
            assert hasattr(whisper_server, 'list_transcription_jobs'), "Should have list_transcription_jobs tool"
            assert hasattr(whisper_server, 'check_gpu_health'), "Should have check_gpu_health tool"
            assert hasattr(whisper_server, 'get_model_info_api'), "Should have get_model_info_api tool"
            logger.info("All new tools found!")

            # Verify old tools are removed
            logger.info("Verifying old tools are removed...")
            assert not hasattr(whisper_server, 'transcribe'), "Old transcribe tool should be removed"
            assert not hasattr(whisper_server, 'batch_transcribe_audio'), "Old batch_transcribe_audio tool should be removed"
            logger.info("Old tools successfully removed!")

        except ImportError as e:
            logger.error(f"Failed to import MCP server: {e}")
            raise AssertionError(f"Failed to import MCP server: {e}")

    def test_job_queue_integration(self):
        """Test JobQueue integration is working"""
        from core.job_queue import JobQueue, JobStatus

        # Create a test queue
        test_queue = JobQueue(max_queue_size=5, metadata_dir="/tmp/test_job_queue")

        try:
            # Verify it can be started
            test_queue.start()
            assert test_queue._worker_thread is not None, "Worker thread should be created"
            assert test_queue._worker_thread.is_alive(), "Worker thread should be running"

        finally:
            # Clean up
            test_queue.stop(wait_for_current=False)

    def test_health_monitor_integration(self):
        """Test HealthMonitor integration is working"""
        from core.gpu_health import HealthMonitor

        # Create a test monitor
        test_monitor = HealthMonitor(check_interval_minutes=60)  # Long interval

        try:
            # Verify it can be started
            test_monitor.start()
            assert test_monitor._thread is not None, "Monitor thread should be created"
            assert test_monitor._thread.is_alive(), "Monitor thread should be running"

            # Check we can get status
            status = test_monitor.get_latest_status()
            assert status is not None, "Should have initial status"

        finally:
            # Clean up
            test_monitor.stop()

    # ========================================================================
    # Helper Methods
    # ========================================================================

    def _create_test_audio_file(self):
        """Get the path to the test audio file"""
        # Use relative path from project root
        project_root = Path(__file__).parent.parent
        test_audio_path = str(project_root / "data" / "test.mp3")
        if not os.path.exists(test_audio_path):
            raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
        return test_audio_path


def main():
    print_section("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
    logger.info("=" * 70)
    logger.info("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
    logger.info("=" * 70)

    # Check if API server is running
    api_url = os.getenv("API_URL", "http://localhost:8000")
    logger.info(f"Testing API server at: {api_url}")
    print_info(f"Testing API server at: {api_url}")

    try:
        logger.info("Checking API server health...")
        resp = requests.get(f"{api_url}/health", timeout=2)
        logger.info(f"Health check status: {resp.status_code}")
        if resp.status_code != 200:
            logger.error(f"API server not responding correctly at {api_url}")
            print_error(f"API server not responding correctly at {api_url}")
            print_error("Please start the API server with: ./run_api_server.sh")
            return 1
    except requests.exceptions.RequestException as e:
        logger.error(f"Cannot connect to API server: {e}")
        print_error(f"Cannot connect to API server at {api_url}")
        print_error("Please start the API server with: ./run_api_server.sh")
        return 1

    logger.info(f"API server is running at {api_url}")
    print_success(f"API server is running at {api_url}")

    # Create tester
    tester = Phase2Tester(api_url=api_url)

    # ========================================================================
    # Run API Tests
    # ========================================================================
    print_section("API SERVER TESTS")
    logger.info("Starting API server tests...")

    tester.test("API Root Endpoint", tester.test_api_root_endpoint)
    tester.test("API Health Endpoint", tester.test_api_health_endpoint)
    tester.test("API Models Endpoint", tester.test_api_models_endpoint)
    tester.test("API GPU Health Endpoint", tester.test_api_gpu_health_endpoint)

    print_section("API JOB SUBMISSION TESTS")

    tester.test("Submit Job - Invalid Audio (400)", tester.test_api_submit_job_invalid_audio)
    tester.test("Submit Job - CPU Device Rejected (400)", tester.test_api_submit_job_cpu_device_rejected)
    tester.test("Submit Job - Success (200)", tester.test_api_submit_job_success)

    print_section("API JOB STATUS TESTS")

    tester.test("Get Job Status - Success", tester.test_api_get_job_status)
    tester.test("Get Job Status - Not Found (404)", tester.test_api_get_job_status_not_found)
    tester.test("Get Job Result - Not Completed (409)", tester.test_api_get_job_result_not_completed)

    print_section("API JOB LISTING TESTS")

    tester.test("List Jobs", tester.test_api_list_jobs)
    tester.test("List Jobs - With Filter", tester.test_api_list_jobs_with_filter)

    print_section("API JOB COMPLETION TEST")

    tester.test("Wait for Job Completion & Get Result", tester.test_api_wait_for_job_completion)

    # ========================================================================
    # Run MCP Tests
    # ========================================================================
    print_section("MCP SERVER TESTS")
    logger.info("Starting MCP server tests...")

    tester.test("MCP Module Imports", tester.test_mcp_imports)
    tester.test("JobQueue Integration", tester.test_job_queue_integration)
    tester.test("HealthMonitor Integration", tester.test_health_monitor_integration)

    # ========================================================================
    # Print Summary
    # ========================================================================
    logger.info("All tests completed, generating summary...")
    success = tester.print_summary()

    if success:
        logger.info("ALL TESTS PASSED!")
        print_section("ALL TESTS PASSED! ✓")
        return 0
    else:
        logger.error("SOME TESTS FAILED!")
        print_section("SOME TESTS FAILED! ✗")
        return 1


if __name__ == "__main__":
    sys.exit(main())
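Distilled from the endpoints these tests exercise (POST /jobs, GET /jobs/{id}, GET /jobs/{id}/result), a minimal client workflow might look like the sketch below. The field names follow the assertions above; the audio path is a hypothetical placeholder.

```python
# Minimal async-job client sketch: submit -> poll -> fetch result.
import time
import requests

API = "http://localhost:8000"

job = requests.post(f"{API}/jobs", json={
    "audio_path": "/path/to/audio.mp3",  # hypothetical input file
    "model_name": "large-v3",
    "output_format": "txt",
}).json()

job_id = job["job_id"]
while True:
    status = requests.get(f"{API}/jobs/{job_id}").json()["status"]
    if status in ("completed", "failed"):
        break
    time.sleep(2)

if status == "completed":
    print(requests.get(f"{API}/jobs/{job_id}/result").text)
```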
287 tests/test_core_components.py (Executable file)
@@ -0,0 +1,287 @@
#!/usr/bin/env python3

"""
Test script for Phase 1 components.

Tests:
1. Test audio file validation
2. GPU health check
3. Job queue operations

IMPORTANT: This service requires GPU. Tests will fail if GPU is not available.
"""

import sys
import os
import time
import torch
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S'
)

# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from core.gpu_health import check_gpu_health, HealthMonitor
from core.job_queue import JobQueue, JobStatus


def check_gpu_available():
    """
    Check if GPU is available. Exit if not.
    This service requires GPU and will not run on CPU.
    """
    print("\n" + "="*60)
    print("GPU REQUIREMENT CHECK")
    print("="*60)

    if not torch.cuda.is_available():
        print("✗ CUDA not available - GPU is required for this service")
        print("  This service is configured for GPU-only operation")
        print("  Please ensure CUDA is properly installed and GPU is accessible")
        print("="*60)
        sys.exit(1)

    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"✓ GPU available: {gpu_name}")
    print(f"✓ GPU memory: {gpu_memory:.2f} GB")
    print("="*60)


def test_audio_file():
    """Test audio file existence and validity."""
    print("\n" + "="*60)
    print("TEST 1: Test Audio File")
    print("="*60)

    try:
        # Use the actual test audio file (relative to project root)
        project_root = os.path.join(os.path.dirname(__file__), '..')
        audio_path = os.path.join(project_root, "data/test.mp3")

        # Verify file exists
        assert os.path.exists(audio_path), "Audio file not found"
        print(f"✓ Test audio file exists: {audio_path}")

        # Verify file is not empty
        file_size = os.path.getsize(audio_path)
        assert file_size > 0, "Audio file is empty"
        print(f"✓ Audio file size: {file_size} bytes")

        return True

    except Exception as e:
        print(f"✗ Audio file test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_gpu_health():
    """Test GPU health check."""
    print("\n" + "="*60)
    print("TEST 2: GPU Health Check")
    print("="*60)

    try:
        # Test with cuda device (enforcing GPU requirement)
        print("\nRunning health check with device='cuda'...")
        logging.info("Starting GPU health check...")
        status = check_gpu_health(expected_device="cuda")
        logging.info("GPU health check completed")

        print(f"✓ Health check completed")
        print(f"  - GPU available: {status.gpu_available}")
        print(f"  - GPU working: {status.gpu_working}")
        print(f"  - Device used: {status.device_used}")
        print(f"  - Device name: {status.device_name}")
        print(f"  - Memory total: {status.memory_total_gb:.2f} GB")
        print(f"  - Memory available: {status.memory_available_gb:.2f} GB")
        print(f"  - Test duration: {status.test_duration_seconds:.2f}s")
        print(f"  - Error: {status.error}")

        # Test health monitor
        print("\nTesting HealthMonitor...")
        monitor = HealthMonitor(check_interval_minutes=1)
        monitor.start()
        print("✓ Health monitor started")

        time.sleep(1)

        latest = monitor.get_latest_status()
        assert latest is not None, "No status available from monitor"
        print(f"✓ Latest status retrieved: {latest.device_used}")

        monitor.stop()
        print("✓ Health monitor stopped")

        return True

    except Exception as e:
        print(f"✗ GPU health test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_job_queue():
    """Test job queue operations."""
    print("\n" + "="*60)
    print("TEST 3: Job Queue")
    print("="*60)

    # Create temp directory for testing
    import tempfile
    temp_dir = tempfile.mkdtemp(prefix="job_queue_test_")
    print(f"Using temp directory: {temp_dir}")

    try:
        # Initialize job queue
        print("\nInitializing job queue...")
        job_queue = JobQueue(max_queue_size=10, metadata_dir=temp_dir)
        job_queue.start()
        print("✓ Job queue started")

        # Use the actual test audio file (relative to project root)
        project_root = os.path.join(os.path.dirname(__file__), '..')
        audio_path = os.path.join(project_root, "data/test.mp3")

        # Test job submission
        print("\nSubmitting test job...")
        logging.info("Submitting transcription job to queue...")
        job_info = job_queue.submit_job(
            audio_path=audio_path,
            model_name="tiny",
            device="cuda",  # Enforcing GPU requirement
            output_format="txt"
        )
        job_id = job_info["job_id"]
        logging.info(f"Job submitted: {job_id}")
        print(f"✓ Job submitted: {job_id}")
        print(f"  - Status: {job_info['status']}")
        print(f"  - Queue position: {job_info['queue_position']}")

        # Test job status retrieval
        print("\nRetrieving job status...")
        logging.info("About to call get_job_status()...")
        status = job_queue.get_job_status(job_id)
        logging.info(f"get_job_status() returned: {status['status']}")
        print(f"✓ Job status retrieved")
        print(f"  - Status: {status['status']}")
        print(f"  - Queue position: {status['queue_position']}")

        # Wait for job to process
        print("\nWaiting for job to process (max 30 seconds)...", flush=True)
        logging.info("Waiting for transcription to complete...")
        max_wait = 30
        start = time.time()
        while time.time() - start < max_wait:
            logging.info("Calling get_job_status()...")
            status = job_queue.get_job_status(job_id)
            print(f"  Status: {status['status']}", flush=True)
            logging.info(f"Job status: {status['status']}")

            if status['status'] in ['completed', 'failed']:
                logging.info("Job completed or failed, breaking out of loop")
                break

            logging.info("Job still running, sleeping 2 seconds...")
            time.sleep(2)

        final_status = job_queue.get_job_status(job_id)
        print(f"\nFinal job status: {final_status['status']}")

        if final_status['status'] == 'completed':
            print(f"✓ Job completed successfully")
            print(f"  - Result path: {final_status['result_path']}")
            print(f"  - Processing time: {final_status['processing_time_seconds']:.2f}s")

            # Test result retrieval
            print("\nRetrieving job result...")
            logging.info("Calling get_job_result()...")
            result = job_queue.get_job_result(job_id)
            logging.info(f"Result retrieved: {len(result)} characters")
            print(f"✓ Result retrieved ({len(result)} characters)")
            print(f"  Preview: {result[:100]}...")

        elif final_status['status'] == 'failed':
            print(f"✗ Job failed: {final_status['error']}")

        # Test persistence by stopping and restarting
        print("\nTesting persistence...")
        logging.info("Stopping job queue...")
        job_queue.stop(wait_for_current=False)
        print("✓ Job queue stopped")
        logging.info("Job queue stopped")

        logging.info("Restarting job queue...")
        job_queue2 = JobQueue(max_queue_size=10, metadata_dir=temp_dir)
        job_queue2.start()
        print("✓ Job queue restarted")
        logging.info("Job queue restarted")

        logging.info("Checking job status after restart...")
        status_after_restart = job_queue2.get_job_status(job_id)
        print(f"✓ Job still exists after restart: {status_after_restart['status']}")
        logging.info(f"Job status after restart: {status_after_restart['status']}")

        logging.info("Stopping job queue 2...")
        job_queue2.stop()
        logging.info("Job queue 2 stopped")

        # Cleanup
        import shutil
        shutil.rmtree(temp_dir)
        print(f"✓ Cleaned up temp directory")

        return final_status['status'] == 'completed'

    except Exception as e:
        print(f"✗ Job queue test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run all tests."""
    print("\n" + "="*60)
    print("PHASE 1 COMPONENT TESTS")
    print("="*60)

    # Check GPU availability first - exit if no GPU
    check_gpu_available()

    results = {
        "Test Audio File": test_audio_file(),
        "GPU Health Check": test_gpu_health(),
        "Job Queue": test_job_queue(),
    }

    print("\n" + "="*60)
    print("TEST SUMMARY")
    print("="*60)

    for test_name, passed in results.items():
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{test_name:.<40} {status}")

    all_passed = all(results.values())
    print("\n" + "="*60)
    if all_passed:
        print("ALL TESTS PASSED ✓")
    else:
        print("SOME TESTS FAILED ✗")
    print("="*60)

    return 0 if all_passed else 1


if __name__ == "__main__":
    sys.exit(main())
523 tests/test_e2e_integration.py (Executable file)
@@ -0,0 +1,523 @@
#!/usr/bin/env python3
"""
Test Phase 4: End-to-End Integration Testing

Comprehensive integration tests for the async job queue system.
Tests all scenarios from the DEV_PLAN.md Phase 4 checklist.
"""

import os
import sys
import time
import json
import logging
import requests
import subprocess
import signal
from pathlib import Path
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

# Color codes for terminal output
class Colors:
    GREEN = '\033[92m'
    RED = '\033[91m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    CYAN = '\033[96m'
    END = '\033[0m'
    BOLD = '\033[1m'

def print_success(msg):
    print(f"{Colors.GREEN}✓ {msg}{Colors.END}")

def print_error(msg):
    print(f"{Colors.RED}✗ {msg}{Colors.END}")

def print_info(msg):
    print(f"{Colors.BLUE}ℹ {msg}{Colors.END}")

def print_warning(msg):
    print(f"{Colors.YELLOW}⚠ {msg}{Colors.END}")

def print_section(msg):
    print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")


class Phase4Tester:
    def __init__(self, api_url="http://localhost:8000", test_audio=None):
        self.api_url = api_url
        # Use relative path from project root if not provided
        if test_audio is None:
            project_root = Path(__file__).parent.parent
            test_audio = str(project_root / "data" / "test.mp3")
        self.test_audio = test_audio
        self.test_results = []
        self.server_process = None

        # Verify test audio exists
        if not os.path.exists(test_audio):
            raise FileNotFoundError(f"Test audio file not found: {test_audio}")

    def test(self, name, func):
        """Run a test and record result"""
        try:
            logger.info(f"Testing: {name}")
            print_info(f"Testing: {name}")
            func()
            logger.info(f"PASSED: {name}")
            print_success(f"PASSED: {name}")
            self.test_results.append((name, True, None))
            return True
        except AssertionError as e:
            logger.error(f"FAILED: {name} - {str(e)}")
            print_error(f"FAILED: {name}")
            print_error(f"  Reason: {str(e)}")
            self.test_results.append((name, False, str(e)))
            return False
        except Exception as e:
            logger.error(f"ERROR: {name} - {str(e)}")
            print_error(f"ERROR: {name}")
            print_error(f"  Exception: {str(e)}")
            self.test_results.append((name, False, f"Exception: {str(e)}"))
            return False

    def start_api_server(self, wait_time=5):
        """Start the API server in background"""
        print_info("Starting API server...")
        # Script is in project root, one level up from tests/
        script_path = Path(__file__).parent.parent / "run_api_server.sh"

        # Start server in background
        self.server_process = subprocess.Popen(
            [str(script_path)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            preexec_fn=os.setsid
        )

        # Wait for server to start
        time.sleep(wait_time)

        # Verify server is running
        try:
            response = requests.get(f"{self.api_url}/health", timeout=5)
            if response.status_code == 200:
                print_success("API server started successfully")
                return True
        except:
            pass

        print_error("API server failed to start")
        return False

    def stop_api_server(self):
        """Stop the API server"""
        if self.server_process:
            print_info("Stopping API server...")
            os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
            self.server_process.wait(timeout=10)
            print_success("API server stopped")

    def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2):
        """Poll job status until completed or failed"""
        start_time = time.time()
        last_status = None

        while time.time() - start_time < timeout:
            try:
                response = requests.get(f"{self.api_url}/jobs/{job_id}")
                assert response.status_code == 200, f"Failed to get job status: {response.status_code}"

                status_data = response.json()
                current_status = status_data['status']

                # Print status changes
                if current_status != last_status:
                    if status_data.get('queue_position') is not None:
                        print_info(f"  Job status: {current_status}, queue position: {status_data['queue_position']}")
                    else:
                        print_info(f"  Job status: {current_status}")
                    last_status = current_status

                if current_status == "completed":
                    return status_data
                elif current_status == "failed":
                    raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}")

                time.sleep(poll_interval)
            except requests.exceptions.RequestException as e:
                raise AssertionError(f"Request failed: {e}")

        raise AssertionError(f"Job did not complete within {timeout} seconds")

    # ========================================================================
    # TEST 1: Single Job Submission and Completion
    # ========================================================================
    def test_single_job_flow(self):
        """Test complete job flow: submit → poll → get result"""
        # Submit job
        print_info("  Submitting job...")
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "large-v3",
            "output_format": "txt"
        })
        assert response.status_code == 200, f"Job submission failed: {response.status_code}"

        job_data = response.json()
        assert 'job_id' in job_data, "No job_id in response"
        # Status can be 'queued' or 'running' (if queue is empty and job starts immediately)
        assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'"

        job_id = job_data['job_id']
        print_success(f"  Job submitted: {job_id}")

        # Wait for completion
        print_info("  Waiting for job completion...")
        final_status = self.wait_for_job_completion(job_id)

        assert final_status['status'] == 'completed', "Job did not complete"
        assert final_status['result_path'] is not None, "No result_path in completed job"
        assert final_status['processing_time_seconds'] is not None, "No processing time"
        print_success(f"  Job completed in {final_status['processing_time_seconds']:.2f}s")

        # Get result
        print_info("  Retrieving result...")
        response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
        assert response.status_code == 200, f"Failed to get result: {response.status_code}"

        result_text = response.text
        assert len(result_text) > 0, "Empty result"
        print_success(f"  Got result: {len(result_text)} characters")

    # ========================================================================
    # TEST 2: Multiple Jobs in Queue (FIFO)
    # ========================================================================
    def test_multiple_jobs_fifo(self):
        """Test multiple jobs are processed in FIFO order"""
        job_ids = []

        # Submit 3 jobs
        print_info("  Submitting 3 jobs...")
        for i in range(3):
            response = requests.post(f"{self.api_url}/jobs", json={
                "audio_path": self.test_audio,
                "model_name": "tiny",  # Use tiny model for faster processing
                "output_format": "txt"
            })
            assert response.status_code == 200, f"Job {i+1} submission failed"

            job_data = response.json()
            job_ids.append(job_data['job_id'])
            print_info(f"  Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}")

        # Wait for all jobs to complete
        print_info("  Waiting for all jobs to complete...")
        for i, job_id in enumerate(job_ids):
            print_info(f"  Waiting for job {i+1}/{len(job_ids)}...")
            final_status = self.wait_for_job_completion(job_id, timeout=120)
            assert final_status['status'] == 'completed', f"Job {i+1} failed"

        print_success(f"  All {len(job_ids)} jobs completed in FIFO order")

    # ========================================================================
    # TEST 3: GPU Health Check
    # ========================================================================
    def test_gpu_health_check(self):
        """Test GPU health check endpoint"""
        print_info("  Checking GPU health...")
        response = requests.get(f"{self.api_url}/health/gpu")
        assert response.status_code == 200, f"GPU health check failed: {response.status_code}"

        health_data = response.json()
        assert 'gpu_available' in health_data, "Missing gpu_available field"
        assert 'gpu_working' in health_data, "Missing gpu_working field"
        assert 'device_used' in health_data, "Missing device_used field"

        print_info(f"  GPU Available: {health_data['gpu_available']}")
        print_info(f"  GPU Working: {health_data['gpu_working']}")
        print_info(f"  Device: {health_data['device_used']}")

        if health_data['gpu_available']:
            assert health_data['device_name'], "GPU available but no device_name"
            assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)"
            print_success(f"  GPU is healthy: {health_data['device_name']}")
        else:
            print_warning("  GPU not available on this system")

    # ========================================================================
    # TEST 4: Invalid Audio Path
    # ========================================================================
    def test_invalid_audio_path(self):
        """Test job submission with invalid audio path"""
        print_info("  Submitting job with invalid path...")
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": "/invalid/path/does/not/exist.mp3",
            "model_name": "large-v3"
        })

        # Should return 400 Bad Request
        assert response.status_code == 400, f"Expected 400, got {response.status_code}"

        error_data = response.json()
        assert 'detail' in error_data or 'error' in error_data, "No error message in response"
        print_success("  Invalid path rejected correctly")

    # ========================================================================
    # TEST 5: Job Not Found
    # ========================================================================
    def test_job_not_found(self):
        """Test retrieving non-existent job"""
        print_info("  Requesting non-existent job...")
        fake_job_id = "00000000-0000-0000-0000-000000000000"

        response = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
        assert response.status_code == 404, f"Expected 404, got {response.status_code}"
        print_success("  Non-existent job handled correctly")

    # ========================================================================
    # TEST 6: Result Before Completion
    # ========================================================================
    def test_result_before_completion(self):
        """Test getting result for job that hasn't completed"""
        print_info("  Submitting job and trying to get result immediately...")

        # Submit job
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "large-v3"
        })
        assert response.status_code == 200
        job_id = response.json()['job_id']

        # Try to get result immediately (job is still queued/running)
        time.sleep(0.5)
        response = requests.get(f"{self.api_url}/jobs/{job_id}/result")

        # Should return 409 Conflict or similar
        assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}"
        print_success("  Result request before completion handled correctly")

        # Clean up: wait for job to complete
        self.wait_for_job_completion(job_id)

    # ========================================================================
    # TEST 7: List Jobs
    # ========================================================================
    def test_list_jobs(self):
        """Test listing jobs with filters"""
        print_info("  Testing job listing...")

        # List all jobs
        response = requests.get(f"{self.api_url}/jobs")
        assert response.status_code == 200, f"List jobs failed: {response.status_code}"

        jobs_data = response.json()
        assert 'jobs' in jobs_data, "No jobs array in response"
        assert isinstance(jobs_data['jobs'], list), "Jobs is not a list"
        print_info(f"  Found {len(jobs_data['jobs'])} jobs")

        # List only completed jobs
        response = requests.get(f"{self.api_url}/jobs?status=completed")
        assert response.status_code == 200
        completed_jobs = response.json()['jobs']
        print_info(f"  Found {len(completed_jobs)} completed jobs")

        # List with limit
        response = requests.get(f"{self.api_url}/jobs?limit=5")
        assert response.status_code == 200
        limited_jobs = response.json()['jobs']
        assert len(limited_jobs) <= 5, "Limit not respected"
        print_success("  Job listing works correctly")

    # ========================================================================
    # TEST 8: Server Restart with Job Persistence
    # ========================================================================
    def test_server_restart_persistence(self):
        """Test that jobs persist across server restarts"""
        print_info("  Testing job persistence across restart...")

        # Submit a job
        response = requests.post(f"{self.api_url}/jobs", json={
            "audio_path": self.test_audio,
            "model_name": "tiny"
        })
        assert response.status_code == 200
        job_id = response.json()['job_id']
        print_info(f"  Submitted job: {job_id}")

        # Get job count before restart
        response = requests.get(f"{self.api_url}/jobs")
        jobs_before = len(response.json()['jobs'])
        print_info(f"  Jobs before restart: {jobs_before}")

        # Restart server
        print_info("  Restarting server...")
        self.stop_api_server()
        time.sleep(2)
        assert self.start_api_server(wait_time=8), "Server failed to restart"

        # Check jobs after restart
        response = requests.get(f"{self.api_url}/jobs")
        assert response.status_code == 200
        jobs_after = len(response.json()['jobs'])
        print_info(f"  Jobs after restart: {jobs_after}")

        # Check our specific job is still there (this is the key test)
        response = requests.get(f"{self.api_url}/jobs/{job_id}")
        assert response.status_code == 200, "Job not found after restart"

        # Note: Total count may differ due to job retention/cleanup, but persistence works if we can find the job
        if jobs_after < jobs_before:
            print_warning(f"  Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup")

        print_success("  Jobs persisted correctly across restart")

    # ========================================================================
    # TEST 9: Health Endpoint
    # ========================================================================
    def test_health_endpoint(self):
        """Test basic health endpoint"""
        print_info("  Checking health endpoint...")
        response = requests.get(f"{self.api_url}/health")
        assert response.status_code == 200, f"Health check failed: {response.status_code}"

        health_data = response.json()
        assert health_data['status'] == 'healthy', "Server not healthy"
        print_success("  Health endpoint OK")

    # ========================================================================
    # TEST 10: Models Endpoint
    # ========================================================================
    def test_models_endpoint(self):
        """Test models information endpoint"""
        print_info("  Checking models endpoint...")
        response = requests.get(f"{self.api_url}/models")
        assert response.status_code == 200, f"Models endpoint failed: {response.status_code}"

        models_data = response.json()
        assert 'available_models' in models_data, "No available_models field"
        assert 'available_devices' in models_data, "No available_devices field"
        assert len(models_data['available_models']) > 0, "No models listed"
        print_info(f"  Available models: {len(models_data['available_models'])}")
        print_success("  Models endpoint OK")

    def print_summary(self):
        """Print test summary"""
        print_section("TEST SUMMARY")

        passed = sum(1 for _, result, _ in self.test_results if result)
        failed = len(self.test_results) - passed

        for name, result, error in self.test_results:
            if result:
                print_success(f"{name}")
            else:
                print_error(f"{name}")
                if error:
                    print(f"    {error}")

        print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
        print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
        print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")

        return failed == 0

    def run_all_tests(self, start_server=True):
        """Run all Phase 4 integration tests"""
        print_section("PHASE 4: END-TO-END INTEGRATION TESTING")

        try:
            # Start server if requested
            if start_server:
                if not self.start_api_server():
                    print_error("Failed to start API server. Aborting tests.")
                    return False
            else:
                # Verify server is already running
                try:
                    response = requests.get(f"{self.api_url}/health", timeout=5)
                    if response.status_code != 200:
                        print_error("Server is not responding. Please start it first.")
                        return False
                    print_info("Using existing API server")
                except:
                    print_error("Cannot connect to API server. Please start it first.")
                    return False

            # Run tests
            print_section("TEST 1: Single Job Submission and Completion")
            self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow)

            print_section("TEST 2: Multiple Jobs (FIFO Order)")
            self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo)

            print_section("TEST 3: GPU Health Check")
            self.test("GPU health check endpoint", self.test_gpu_health_check)

            print_section("TEST 4: Error Handling - Invalid Path")
            self.test("Invalid audio path rejection", self.test_invalid_audio_path)

            print_section("TEST 5: Error Handling - Job Not Found")
            self.test("Non-existent job handling", self.test_job_not_found)

            print_section("TEST 6: Error Handling - Result Before Completion")
            self.test("Result request before completion", self.test_result_before_completion)

            print_section("TEST 7: Job Listing")
            self.test("List jobs with filters", self.test_list_jobs)

            print_section("TEST 8: Health Endpoint")
            self.test("Basic health endpoint", self.test_health_endpoint)

            print_section("TEST 9: Models Endpoint")
            self.test("Models information endpoint", self.test_models_endpoint)

            print_section("TEST 10: Server Restart Persistence")
            self.test("Job persistence across server restart", self.test_server_restart_persistence)

            # Print summary
            success = self.print_summary()

            return success

        finally:
            # Cleanup
            if start_server and self.server_process:
                self.stop_api_server()


def main():
    """Main test runner"""
    import argparse

    parser = argparse.ArgumentParser(description='Phase 4 Integration Tests')
    parser.add_argument('--url', default='http://localhost:8000', help='API server URL')
    # Default to None so Phase4Tester uses relative path
    parser.add_argument('--audio', default=None,
                        help='Path to test audio file (default: <project_root>/data/test.mp3)')
    parser.add_argument('--no-start-server', action='store_true',
                        help='Do not start server (assume it is already running)')
    args = parser.parse_args()

    tester = Phase4Tester(api_url=args.url, test_audio=args.audio)
    success = tester.run_all_tests(start_server=not args.no_start_server)

    sys.exit(0 if success else 1)


if __name__ == '__main__':
    main()
@@ -1,108 +0,0 @@
#!/usr/bin/env python3
"""
Faster Whisper-based Speech Recognition MCP Service
Provides high-performance audio transcription with batch processing acceleration and multiple output formats
"""

import os
import logging
from mcp.server.fastmcp import FastMCP

from model_manager import get_model_info
from transcriber import transcribe_audio, batch_transcribe

# Log configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create FastMCP server instance
mcp = FastMCP(
    name="fast-whisper-mcp-server",
    version="0.1.1",
    dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126", "numpy>=1.20.0"]
)

@mcp.tool()
def get_model_info_api() -> str:
    """
    Get available Whisper model information
    """
    return get_model_info()

@mcp.tool()
def transcribe(audio_path: str, model_name: str = "large-v3", device: str = "auto",
               compute_type: str = "auto", language: str = None, output_format: str = "vtt",
               beam_size: int = 5, temperature: float = 0.0, initial_prompt: str = None,
               output_directory: str = None) -> str:
    """
    Transcribe audio files using Faster Whisper

    Args:
        audio_path: Path to the audio file
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)
        language: Language code (such as zh, en, ja, etc., auto-detect by default)
        output_format: Output format (vtt, srt, json or txt)
        beam_size: Beam search size, larger values may improve accuracy but reduce speed
        temperature: Sampling temperature, 0 means greedy decoding
        initial_prompt: Initial prompt text, can help the model better understand context
        output_directory: Output directory path, defaults to the audio file's directory

    Returns:
        str: Transcription result in the requested format (VTT, SRT, JSON, or plain text)
    """
    return transcribe_audio(
        audio_path=audio_path,
        model_name=model_name,
        device=device,
        compute_type=compute_type,
        language=language,
        output_format=output_format,
        beam_size=beam_size,
        temperature=temperature,
        initial_prompt=initial_prompt,
        output_directory=output_directory
    )

@mcp.tool()
def batch_transcribe_audio(audio_folder: str, output_folder: str = None, model_name: str = "large-v3",
                           device: str = "auto", compute_type: str = "auto", language: str = None,
                           output_format: str = "vtt", beam_size: int = 5, temperature: float = 0.0,
                           initial_prompt: str = None, parallel_files: int = 1) -> str:
    """
    Batch transcribe audio files in a folder

    Args:
        audio_folder: Path to the folder containing audio files
        output_folder: Output folder path, defaults to a 'transcript' subfolder in audio_folder
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)
        language: Language code (such as zh, en, ja, etc., auto-detect by default)
        output_format: Output format (vtt, srt, json or txt)
        beam_size: Beam search size, larger values may improve accuracy but reduce speed
        temperature: Sampling temperature, 0 means greedy decoding
        initial_prompt: Initial prompt text, can help the model better understand context
        parallel_files: Number of files to process in parallel (only effective in CPU mode)

    Returns:
        str: Batch processing summary, including processing time and success rate
    """
    return batch_transcribe(
        audio_folder=audio_folder,
        output_folder=output_folder,
        model_name=model_name,
        device=device,
        compute_type=compute_type,
        language=language,
        output_format=output_format,
        beam_size=beam_size,
        temperature=temperature,
        initial_prompt=initial_prompt,
        parallel_files=parallel_files
    )

if __name__ == "__main__":
    print("Starting MCP server for Whisper STT transcriber")
    mcp.run()
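The MCP tools in the deleted server above are thin wrappers around the core transcription functions, so the same work can be exercised without an MCP client. Below is a minimal sketch of calling `transcribe_audio()` directly with the same keyword arguments the removed `transcribe()` tool forwards; the flat import matches the deleted file's layout, and the audio path is a hypothetical placeholder.

```python
# Minimal sketch: call the core transcription function directly, bypassing MCP.
# Keyword arguments mirror what the deleted transcribe() tool forwards;
# the audio path is a hypothetical placeholder.
from transcriber import transcribe_audio

result = transcribe_audio(
    audio_path="/path/to/audio.mp3",  # hypothetical input file
    model_name="large-v3",
    device="auto",
    compute_type="auto",
    language=None,            # auto-detect
    output_format="vtt",
    beam_size=5,
    temperature=0.0,          # 0 means greedy decoding
    initial_prompt=None,
    output_directory=None,    # defaults to the audio file's directory
)
print(result)
```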