Compare commits
8 Commits
5fb742a312
...
alihan-spe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
990fa28668 | ||
|
|
fb1e5dceba | ||
|
|
f6777b1488 | ||
|
|
3c0f79645c | ||
|
|
c6462e2bbe | ||
|
|
d47c2843c3 | ||
|
|
06b8bc1304 | ||
|
|
66b36e71e8 |
60
.dockerignore
Normal file
60
.dockerignore
Normal file
@@ -0,0 +1,60 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv
|
||||
|
||||
# Project specific
|
||||
logs/
|
||||
outputs/
|
||||
models/
|
||||
*.log
|
||||
*.logs
|
||||
mcp.logs
|
||||
api.logs
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
.github/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Docker
|
||||
.dockerignore
|
||||
docker-compose.yml
|
||||
docker-compose.*.yml
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Documentation (optional - uncomment if you want to exclude)
|
||||
# README.md
|
||||
# CLAUDE.md
|
||||
# IMPLEMENTATION_PLAN.md
|
||||
|
||||
# Scripts (already in container)
|
||||
# reset_gpu.sh - NEEDED for GPU health checks
|
||||
run_api_server.sh
|
||||
run_mcp_server.sh
|
||||
|
||||
# Supervisor config (not needed in container)
|
||||
supervisor/
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -17,3 +17,8 @@ venv/
|
||||
logs/**
|
||||
User/**
|
||||
data/**
|
||||
models/*
|
||||
outputs/*
|
||||
api.logs
|
||||
|
||||
IMPLEMENTATION_PLAN.md
|
||||
|
||||
529
CLAUDE.md
529
CLAUDE.md
@@ -1,529 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Overview
|
||||
|
||||
This is a Whisper-based speech recognition service that provides high-performance audio transcription using Faster Whisper. The service runs as either:
|
||||
|
||||
1. **MCP Server** - For integration with Claude Desktop and other MCP clients
|
||||
2. **REST API Server** - For HTTP-based integrations with async job queue support
|
||||
|
||||
Both servers share the same core transcription logic and can run independently or simultaneously on different ports.
|
||||
|
||||
**Key Features:**
|
||||
- Async job queue system for long-running transcriptions (prevents HTTP timeouts)
|
||||
- GPU health monitoring with strict failure detection (prevents silent CPU fallback)
|
||||
- **Automatic GPU driver reset** on CUDA errors with cooldown protection (handles sleep/wake issues)
|
||||
- Dual-server architecture (MCP + REST API)
|
||||
- Model caching for fast repeated transcriptions
|
||||
- Automatic batch size optimization based on GPU memory
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Environment Setup
|
||||
```bash
|
||||
# Create and activate virtual environment
|
||||
python3.12 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install PyTorch with CUDA 12.6 support
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
|
||||
# For CUDA 12.1
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CPU-only
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
### Running the Servers
|
||||
|
||||
#### MCP Server (for Claude Desktop)
|
||||
|
||||
```bash
|
||||
# Using the startup script (recommended - sets all env vars)
|
||||
./run_mcp_server.sh
|
||||
|
||||
# Direct Python execution
|
||||
python whisper_server.py
|
||||
|
||||
# Using MCP CLI for development testing
|
||||
mcp dev whisper_server.py
|
||||
|
||||
# Run server with MCP CLI
|
||||
mcp run whisper_server.py
|
||||
```
|
||||
|
||||
#### REST API Server (for HTTP clients)
|
||||
|
||||
```bash
|
||||
# Using the startup script (recommended - sets all env vars)
|
||||
./run_api_server.sh
|
||||
|
||||
# Direct Python execution with uvicorn
|
||||
python api_server.py
|
||||
|
||||
# Or using uvicorn directly
|
||||
uvicorn api_server:app --host 0.0.0.0 --port 8000
|
||||
|
||||
# Development mode with auto-reload
|
||||
uvicorn api_server:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
#### Running Both Simultaneously
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start MCP server
|
||||
./run_mcp_server.sh
|
||||
|
||||
# Terminal 2: Start REST API server
|
||||
./run_api_server.sh
|
||||
```
|
||||
|
||||
### Docker
|
||||
|
||||
```bash
|
||||
# Build Docker image
|
||||
docker build -t whisper-mcp-server .
|
||||
|
||||
# Run with GPU support
|
||||
docker run --gpus all -v /path/to/models:/models -v /path/to/outputs:/outputs whisper-mcp-server
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
.
|
||||
├── src/ # Source code directory
|
||||
│ ├── servers/ # Server implementations
|
||||
│ │ ├── whisper_server.py # MCP server entry point
|
||||
│ │ └── api_server.py # REST API server (async job queue)
|
||||
│ ├── core/ # Core business logic
|
||||
│ │ ├── transcriber.py # Transcription logic (single & batch)
|
||||
│ │ ├── model_manager.py # Model lifecycle & caching
|
||||
│ │ ├── job_queue.py # Async job queue manager
|
||||
│ │ └── gpu_health.py # GPU health monitoring
|
||||
│ └── utils/ # Utility modules
|
||||
│ ├── audio_processor.py # Audio validation & preprocessing
|
||||
│ ├── formatters.py # Output format conversion
|
||||
│ └── test_audio_generator.py # Test audio generation for GPU checks
|
||||
├── run_mcp_server.sh # MCP server startup script
|
||||
├── run_api_server.sh # API server startup script
|
||||
├── reset_gpu.sh # GPU driver reset script
|
||||
├── DEV_PLAN.md # Development plan for async features
|
||||
├── requirements.txt # Python dependencies
|
||||
└── pyproject.toml # Project configuration
|
||||
```
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **src/servers/whisper_server.py** - MCP server entry point
|
||||
- Uses FastMCP framework to expose MCP tools
|
||||
- Three main tools: `get_model_info_api()`, `transcribe()`, `batch_transcribe_audio()`
|
||||
- Server initialization at line 19
|
||||
|
||||
2. **src/servers/api_server.py** - REST API server entry point
|
||||
- Uses FastAPI framework for HTTP endpoints
|
||||
- Provides REST endpoints: `/`, `/health`, `/models`, `/transcribe`, `/batch-transcribe`, `/upload-transcribe`
|
||||
- Shares core transcription logic with MCP server
|
||||
- File upload support via multipart/form-data
|
||||
|
||||
3. **src/core/transcriber.py** - Core transcription logic (shared by both servers)
|
||||
- `transcribe_audio()`:39 - Single file transcription with environment variable support
|
||||
- `batch_transcribe()`:209 - Batch processing with progress reporting
|
||||
- All parameters support environment variable defaults (lines 21-37)
|
||||
- Delegates output formatting to utils.formatters
|
||||
|
||||
4. **src/core/model_manager.py** - Whisper model lifecycle management
|
||||
- `get_whisper_model()`:44 - Returns cached model instances or loads new ones
|
||||
- `test_gpu_driver()`:20 - GPU validation before model loading
|
||||
- **CRITICAL**: GPU-only mode enforced at lines 64-90 (no CPU fallback)
|
||||
- Global `model_instances` dict caches loaded models to prevent reloading
|
||||
- Automatic batch size optimization based on GPU memory (lines 134-147)
|
||||
|
||||
5. **src/core/job_queue.py** - Async job queue manager
|
||||
- `JobQueue` class manages FIFO queue with background worker thread
|
||||
- `submit_job()` - Validates audio, checks GPU health, adds to queue
|
||||
- `get_job_status()` - Returns current job status and queue position
|
||||
- `get_job_result()` - Returns transcription result for completed jobs
|
||||
- Jobs persist to disk as JSON files for crash recovery
|
||||
- Single worker thread processes jobs sequentially (prevents GPU contention)
|
||||
|
||||
6. **src/core/gpu_health.py** - GPU health monitoring
|
||||
- `check_gpu_health()`:39 - Real GPU test using tiny model + test audio
|
||||
- `GPUHealthStatus` dataclass contains detailed GPU metrics
|
||||
- **CRITICAL**: Raises RuntimeError if device=cuda but GPU fails (lines 99-135)
|
||||
- Prevents silent CPU fallback that would cause 10-100x slowdown
|
||||
- `HealthMonitor` class for periodic background monitoring
|
||||
|
||||
7. **src/utils/audio_processor.py** - Audio file validation and preprocessing
|
||||
- `validate_audio_file()`:15 - Checks file existence, format, and size
|
||||
- `process_audio()`:50 - Decodes audio using faster_whisper's decode_audio
|
||||
|
||||
8. **src/utils/formatters.py** - Output format conversion
|
||||
- `format_vtt()`, `format_srt()`, `format_txt()`, `format_json()` - Convert segments to various formats
|
||||
- All formatters accept segment lists from Whisper output
|
||||
|
||||
9. **src/utils/test_audio_generator.py** - Test audio generation
|
||||
- `generate_test_audio()` - Creates synthetic 1-second audio for GPU health checks
|
||||
- Uses numpy to generate sine wave, no external audio files needed
|
||||
|
||||
### Key Architecture Patterns
|
||||
|
||||
- **Dual Server Architecture**: Both MCP and REST API servers import and use the same core modules (core.transcriber, core.model_manager, utils.audio_processor, utils.formatters), ensuring consistent behavior
|
||||
- **Model Caching**: Models are cached in `model_instances` dictionary with key format `{model_name}_{device}_{compute_type}` (src/core/model_manager.py:104). This cache is shared if both servers run in the same process
|
||||
- **Batch Processing**: CUDA devices automatically use BatchedInferencePipeline for performance (src/core/model_manager.py:132-160)
|
||||
- **Environment Variable Configuration**: All transcription parameters support env var defaults (src/core/transcriber.py:21-37)
|
||||
- **GPU-Only Mode**: Service is configured for GPU-only operation. `device="auto"` requires CUDA, `device="cpu"` is rejected (src/core/model_manager.py:64-90)
|
||||
- **Async Job Queue**: Long-running transcriptions use async queue pattern to prevent HTTP timeouts. Jobs return immediately with job_id for polling
|
||||
- **GPU Health Monitoring**: Real GPU tests with tiny model prevent silent CPU fallback. Jobs are rejected immediately if GPU fails rather than running 10-100x slower on CPU
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All configuration can be set via environment variables in run_mcp_server.sh and run_api_server.sh:
|
||||
|
||||
**API Server Specific:**
|
||||
- `API_HOST` - API server host (default: 0.0.0.0)
|
||||
- `API_PORT` - API server port (default: 8000)
|
||||
|
||||
**Job Queue Configuration (if using async features):**
|
||||
- `JOB_QUEUE_MAX_SIZE` - Maximum queue size (default: 100)
|
||||
- `JOB_METADATA_DIR` - Directory for job metadata JSON files
|
||||
- `JOB_RETENTION_DAYS` - Auto-cleanup old jobs (0=disabled)
|
||||
|
||||
**GPU Health Monitoring:**
|
||||
- `GPU_HEALTH_CHECK_ENABLED` - Enable periodic GPU monitoring (true/false)
|
||||
- `GPU_HEALTH_CHECK_INTERVAL_MINUTES` - Monitoring interval (default: 10)
|
||||
- `GPU_HEALTH_TEST_MODEL` - Model for health checks (default: tiny)
|
||||
|
||||
**GPU Auto-Reset Configuration:**
|
||||
- `GPU_RESET_COOLDOWN_MINUTES` - Minimum time between GPU reset attempts (default: 5 minutes)
|
||||
- Prevents reset loops while allowing recovery from sleep/wake cycles
|
||||
- Auto-reset is **enabled by default**
|
||||
- Service terminates if GPU unavailable after reset attempt
|
||||
|
||||
**Transcription Configuration (shared by both servers):**
|
||||
|
||||
- `CUDA_VISIBLE_DEVICES` - GPU device selection
|
||||
- `WHISPER_MODEL_DIR` - Model storage location (defaults to None for HuggingFace cache)
|
||||
- `TRANSCRIPTION_OUTPUT_DIR` - Default output directory for single transcriptions
|
||||
- `TRANSCRIPTION_BATCH_OUTPUT_DIR` - Default output directory for batch processing
|
||||
- `TRANSCRIPTION_MODEL` - Model size (tiny, base, small, medium, large-v1, large-v2, large-v3)
|
||||
- `TRANSCRIPTION_DEVICE` - Execution device (cuda, auto) - **NOTE: cpu is rejected in GPU-only mode**
|
||||
- `TRANSCRIPTION_COMPUTE_TYPE` - Computation type (float16, int8, auto)
|
||||
- `TRANSCRIPTION_OUTPUT_FORMAT` - Output format (vtt, srt, txt, json)
|
||||
- `TRANSCRIPTION_BEAM_SIZE` - Beam search size (default: 5)
|
||||
- `TRANSCRIPTION_TEMPERATURE` - Sampling temperature (default: 0.0)
|
||||
- `TRANSCRIPTION_USE_TIMESTAMP` - Add timestamp to filenames (true/false)
|
||||
- `TRANSCRIPTION_FILENAME_PREFIX` - Prefix for output filenames
|
||||
- `TRANSCRIPTION_FILENAME_SUFFIX` - Suffix for output filenames
|
||||
- `TRANSCRIPTION_LANGUAGE` - Language code (zh, en, ja, etc., auto-detect if not set)
|
||||
|
||||
## Supported Configurations
|
||||
|
||||
- **Models**: tiny, base, small, medium, large-v1, large-v2, large-v3
|
||||
- **Audio formats**: .mp3, .wav, .m4a, .flac, .ogg, .aac
|
||||
- **Output formats**: vtt, srt, json, txt
|
||||
- **Languages**: zh (Chinese), en (English), ja (Japanese), ko (Korean), de (German), fr (French), es (Spanish), ru (Russian), it (Italian), pt (Portuguese), nl (Dutch), ar (Arabic), hi (Hindi), tr (Turkish), vi (Vietnamese), th (Thai), id (Indonesian)
|
||||
|
||||
## REST API Endpoints
|
||||
|
||||
The REST API server provides the following HTTP endpoints:
|
||||
|
||||
### GET /
|
||||
Returns API information and available endpoints.
|
||||
|
||||
### GET /health
|
||||
Health check endpoint. Returns `{"status": "healthy", "service": "whisper-transcription"}`.
|
||||
|
||||
### GET /models
|
||||
Returns available Whisper models, devices, languages, and system information (GPU details if CUDA available).
|
||||
|
||||
### POST /transcribe
|
||||
Transcribe a single audio file that exists on the server.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"audio_path": "/path/to/audio.mp3",
|
||||
"model_name": "large-v3",
|
||||
"device": "auto",
|
||||
"compute_type": "auto",
|
||||
"language": "en",
|
||||
"output_format": "txt",
|
||||
"beam_size": 5,
|
||||
"temperature": 0.0,
|
||||
"initial_prompt": null,
|
||||
"output_directory": null
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"message": "Transcription successful, results saved to: /path/to/output.txt",
|
||||
"output_path": "/path/to/output.txt"
|
||||
}
|
||||
```
|
||||
|
||||
### POST /batch-transcribe
|
||||
Batch transcribe all audio files in a folder.
|
||||
|
||||
**Request Body:**
|
||||
```json
|
||||
{
|
||||
"audio_folder": "/path/to/audio/folder",
|
||||
"output_folder": "/path/to/output",
|
||||
"model_name": "large-v3",
|
||||
"output_format": "txt",
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"summary": "Batch processing completed, total transcription time: 00:05:23 | Success: 10/10 | Failed: 0/10"
|
||||
}
|
||||
```
|
||||
|
||||
### POST /upload-transcribe
|
||||
Upload an audio file and transcribe it immediately. Returns the transcription file as a download.
|
||||
|
||||
**Form Data:**
|
||||
- `file`: Audio file (multipart/form-data)
|
||||
- `model_name`: Model name (default: "large-v3")
|
||||
- `device`: Device (default: "auto")
|
||||
- `output_format`: Output format (default: "txt")
|
||||
- ... (other transcription parameters)
|
||||
|
||||
**Response:** Returns the transcription file for download.
|
||||
|
||||
### API Usage Examples
|
||||
|
||||
```bash
|
||||
# Get model information
|
||||
curl http://localhost:8000/models
|
||||
|
||||
# Transcribe existing file (synchronous)
|
||||
curl -X POST http://localhost:8000/transcribe \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"audio_path": "/path/to/audio.mp3", "output_format": "txt"}'
|
||||
|
||||
# Upload and transcribe
|
||||
curl -X POST http://localhost:8000/upload-transcribe \
|
||||
-F "file=@audio.mp3" \
|
||||
-F "output_format=txt" \
|
||||
-F "model_name=large-v3"
|
||||
|
||||
# Async job queue (if enabled)
|
||||
# Submit job
|
||||
curl -X POST http://localhost:8000/jobs \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"audio_path": "/path/to/audio.mp3"}'
|
||||
# Returns: {"job_id": "abc-123", "status": "queued", "queue_position": 1}
|
||||
|
||||
# Check status
|
||||
curl http://localhost:8000/jobs/abc-123
|
||||
# Returns: {"status": "running", ...}
|
||||
|
||||
# Get result (when completed)
|
||||
curl http://localhost:8000/jobs/abc-123/result
|
||||
# Returns: transcription text
|
||||
|
||||
# Check GPU health
|
||||
curl http://localhost:8000/health/gpu
|
||||
# Returns: {"gpu_available": true, "gpu_working": true, ...}
|
||||
```
|
||||
|
||||
## GPU Auto-Reset Configuration
|
||||
|
||||
### Overview
|
||||
This service features automatic GPU driver reset on CUDA errors, which is especially useful for recovering from sleep/wake cycles. The reset functionality is **enabled by default** and includes cooldown protection to prevent reset loops.
|
||||
|
||||
### Passwordless Sudo Setup (Required)
|
||||
|
||||
For automatic GPU reset to work, you must configure passwordless sudo for NVIDIA commands. Create a sudoers configuration file:
|
||||
|
||||
```bash
|
||||
sudo visudo -f /etc/sudoers.d/whisper-gpu-reset
|
||||
```
|
||||
|
||||
Add the following (replace `your_username` with your actual username):
|
||||
|
||||
```
|
||||
# Whisper GPU Auto-Reset Permissions
|
||||
your_username ALL=(ALL) NOPASSWD: /bin/systemctl stop nvidia-persistenced
|
||||
your_username ALL=(ALL) NOPASSWD: /bin/systemctl start nvidia-persistenced
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_uvm
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_drm
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_modeset
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_modeset
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_uvm
|
||||
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_drm
|
||||
```
|
||||
|
||||
**Security Note:** These permissions are limited to specific NVIDIA driver commands only. The reset script (`reset_gpu.sh`) is executed with sudo but is part of the codebase and can be audited.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Startup Check**: When the service starts, it performs a GPU health check
|
||||
- If CUDA errors detected → automatic reset attempt → retry
|
||||
- If retry fails → service terminates
|
||||
|
||||
2. **Runtime Check**: Before job submission and model loading
|
||||
- If CUDA errors detected → automatic reset attempt → retry
|
||||
- If retry fails → job rejected, service continues
|
||||
|
||||
3. **Cooldown Protection**: Prevents reset loops
|
||||
- Minimum 5 minutes between reset attempts (configurable via `GPU_RESET_COOLDOWN_MINUTES`)
|
||||
- Cooldown persists across restarts (stored in `/tmp/whisper-gpu-last-reset`)
|
||||
- If reset needed but cooldown active → service/job fails immediately
|
||||
|
||||
### Manual GPU Reset
|
||||
|
||||
You can manually reset the GPU anytime:
|
||||
|
||||
```bash
|
||||
./reset_gpu.sh
|
||||
```
|
||||
|
||||
Or clear the cooldown to allow immediate reset:
|
||||
|
||||
```python
|
||||
from core.gpu_reset import clear_reset_cooldown
|
||||
clear_reset_cooldown()
|
||||
```
|
||||
|
||||
### Behavior Examples
|
||||
|
||||
**After sleep/wake with GPU issue:**
|
||||
```
|
||||
Service starts → GPU check fails (CUDA error)
|
||||
→ Cooldown OK → Reset drivers → Wait 3s → Retry
|
||||
→ Success → Service continues
|
||||
```
|
||||
|
||||
**Multiple failures (hardware issue):**
|
||||
```
|
||||
First failure → Reset → Retry fails → Job fails
|
||||
Second failure within 5 min → Cooldown active → Fail immediately
|
||||
(Prevents reset loop)
|
||||
```
|
||||
|
||||
**Normal operation:**
|
||||
```
|
||||
No CUDA errors → No resets → Normal performance
|
||||
Reset only happens on actual CUDA failures
|
||||
```
|
||||
|
||||
## Important Implementation Details
|
||||
|
||||
### GPU-Only Architecture
|
||||
- **CRITICAL**: Service enforces GPU-only mode. CPU device is explicitly rejected (src/core/model_manager.py:84-90)
|
||||
- `device="auto"` requires CUDA to be available, raises RuntimeError if not (src/core/model_manager.py:64-73)
|
||||
- GPU health checks use real model loading + transcription, not just torch.cuda.is_available()
|
||||
- If GPU health check fails, jobs are rejected immediately rather than silently falling back to CPU
|
||||
- **GPU Auto-Reset**: Automatic driver reset on CUDA errors with 5-minute cooldown (handles sleep/wake issues)
|
||||
|
||||
### Model Management
|
||||
- GPU memory is checked before loading models (src/core/model_manager.py:115-127)
|
||||
- Batch size dynamically adjusts: 32 (>16GB), 16 (>12GB), 8 (>8GB), 4 (>4GB), 2 (otherwise)
|
||||
- Models are cached globally in `model_instances` dict, shared across requests
|
||||
- Model loading includes GPU driver test to fail fast if GPU is unavailable (src/core/model_manager.py:112-114)
|
||||
|
||||
### Transcription Settings
|
||||
- VAD (Voice Activity Detection) is enabled by default for better long-audio accuracy (src/core/transcriber.py:102)
|
||||
- Word timestamps are enabled by default (src/core/transcriber.py:107)
|
||||
- Files over 1GB generate warnings about processing time (src/utils/audio_processor.py:42)
|
||||
- Default output format is "txt" for REST API, configured via environment variables for MCP server
|
||||
|
||||
### Async Job Queue (if enabled)
|
||||
- Single worker thread processes jobs sequentially (prevents GPU memory contention)
|
||||
- Jobs persist to disk as JSON files in JOB_METADATA_DIR
|
||||
- Queue has max size limit (default 100), returns 503 when full
|
||||
- Job status polling recommended every 5-10 seconds for LLM agents
|
||||
|
||||
## Development Workflow
|
||||
|
||||
### Testing GPU Health
|
||||
```python
|
||||
# Test GPU health check manually
|
||||
from src.core.gpu_health import check_gpu_health
|
||||
|
||||
status = check_gpu_health(expected_device="cuda")
|
||||
print(f"GPU Working: {status.gpu_working}")
|
||||
print(f"Device: {status.device_used}")
|
||||
print(f"Test Duration: {status.test_duration_seconds}s")
|
||||
# Expected: <1s for GPU, 3-10s for CPU
|
||||
```
|
||||
|
||||
### Testing Job Queue
|
||||
```python
|
||||
# Test job queue manually
|
||||
from src.core.job_queue import JobQueue
|
||||
|
||||
queue = JobQueue(max_queue_size=100, metadata_dir="/tmp/jobs")
|
||||
queue.start()
|
||||
|
||||
# Submit job
|
||||
job_info = queue.submit_job(
|
||||
audio_path="/path/to/test.mp3",
|
||||
model_name="large-v3",
|
||||
device="cuda"
|
||||
)
|
||||
print(f"Job ID: {job_info['job_id']}")
|
||||
|
||||
# Poll status
|
||||
status = queue.get_job_status(job_info['job_id'])
|
||||
print(f"Status: {status['status']}")
|
||||
|
||||
# Get result when completed
|
||||
result = queue.get_job_result(job_info['job_id'])
|
||||
```
|
||||
|
||||
### Common Debugging
|
||||
|
||||
**Model loading issues:**
|
||||
- Check `WHISPER_MODEL_DIR` is set correctly
|
||||
- Verify GPU memory with `nvidia-smi`
|
||||
- Check logs for GPU driver test failures at model_manager.py:112-114
|
||||
|
||||
**GPU not detected:**
|
||||
- Verify `CUDA_VISIBLE_DEVICES` is set correctly
|
||||
- Check `torch.cuda.is_available()` returns True
|
||||
- Run GPU health check to see detailed error
|
||||
|
||||
**Silent failures:**
|
||||
- Check that service is NOT silently falling back to CPU
|
||||
- GPU health check should RAISE errors, not log warnings
|
||||
- If device=cuda fails, the job should be rejected, not processed on CPU
|
||||
|
||||
**Job queue issues:**
|
||||
- Check `JOB_METADATA_DIR` exists and is writable
|
||||
- Verify background worker thread is running (check logs)
|
||||
- Job metadata files are in {JOB_METADATA_DIR}/{job_id}.json
|
||||
|
||||
### File Locations
|
||||
|
||||
- **Logs**: `mcp.logs` (MCP server), `api.logs` (API server)
|
||||
- **Models**: `$WHISPER_MODEL_DIR` or HuggingFace cache
|
||||
- **Outputs**: `$TRANSCRIPTION_OUTPUT_DIR` or `$TRANSCRIPTION_BATCH_OUTPUT_DIR`
|
||||
- **Job Metadata**: `$JOB_METADATA_DIR/{job_id}.json`
|
||||
|
||||
### Important Development Notes
|
||||
|
||||
- See `DEV_PLAN.md` for detailed architecture and implementation plan for async job queue features
|
||||
- The service is designed for GPU-only operation - CPU fallback is intentionally disabled to prevent silent performance degradation
|
||||
- When modifying model_manager.py, maintain the strict GPU-only enforcement
|
||||
- When adding new endpoints, follow the async pattern if transcription time >30 seconds
|
||||
93
Dockerfile
93
Dockerfile
@@ -1,25 +1,34 @@
|
||||
# Use NVIDIA CUDA base image with Python
|
||||
FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04
|
||||
# Multi-purpose Whisper Transcriptor Docker Image
|
||||
# Supports both MCP Server and REST API Server modes
|
||||
# Use SERVER_MODE environment variable to select: "mcp" or "api"
|
||||
|
||||
# Install Python 3.12
|
||||
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
||||
|
||||
# Prevent interactive prompts during installation
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
software-properties-common \
|
||||
curl \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update && apt-get install -y \
|
||||
python3.12 \
|
||||
python3.12-venv \
|
||||
python3.12-dev \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
git \
|
||||
nginx \
|
||||
supervisor \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Make python3.12 the default
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1
|
||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
|
||||
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
|
||||
|
||||
# Upgrade pip
|
||||
RUN python -m pip install --upgrade pip
|
||||
# Install pip using ensurepip (Python 3.12+ doesn't have distutils)
|
||||
RUN python -m ensurepip --upgrade && \
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
@@ -27,30 +36,68 @@ WORKDIR /app
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies with CUDA support
|
||||
# Install Python dependencies with CUDA 12.4 support
|
||||
RUN pip install --no-cache-dir \
|
||||
torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124 && \
|
||||
pip install --no-cache-dir \
|
||||
torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 && \
|
||||
pip install --no-cache-dir \
|
||||
faster-whisper \
|
||||
torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121 \
|
||||
torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121 \
|
||||
mcp[cli]
|
||||
fastapi>=0.115.0 \
|
||||
uvicorn[standard]>=0.32.0 \
|
||||
python-multipart>=0.0.9 \
|
||||
aiofiles>=23.0.0 \
|
||||
mcp[cli]>=1.2.0 \
|
||||
gTTS>=2.3.0 \
|
||||
pyttsx3>=2.90 \
|
||||
scipy>=1.10.0 \
|
||||
numpy>=1.24.0
|
||||
|
||||
# Copy application code
|
||||
COPY src/ ./src/
|
||||
COPY pyproject.toml .
|
||||
COPY README.md .
|
||||
|
||||
# Create directories for models and outputs
|
||||
RUN mkdir -p /models /outputs
|
||||
# Copy test audio file for GPU health checks
|
||||
COPY test.mp3 .
|
||||
|
||||
# Copy nginx configuration
|
||||
COPY nginx/transcriptor.conf /etc/nginx/sites-available/transcriptor.conf
|
||||
|
||||
# Copy entrypoint script and GPU reset script
|
||||
COPY docker-entrypoint.sh /docker-entrypoint.sh
|
||||
COPY reset_gpu.sh /app/reset_gpu.sh
|
||||
RUN chmod +x /docker-entrypoint.sh /app/reset_gpu.sh
|
||||
|
||||
# Create directories for models, outputs, and logs
|
||||
RUN mkdir -p /models /outputs /logs /app/outputs/uploads /app/outputs/batch /app/outputs/jobs
|
||||
|
||||
# Set Python path
|
||||
ENV PYTHONPATH=/app/src
|
||||
|
||||
# Set environment variables for GPU
|
||||
ENV WHISPER_MODEL_DIR=/models
|
||||
ENV TRANSCRIPTION_OUTPUT_DIR=/outputs
|
||||
ENV TRANSCRIPTION_MODEL=large-v3
|
||||
ENV TRANSCRIPTION_DEVICE=cuda
|
||||
ENV TRANSCRIPTION_COMPUTE_TYPE=float16
|
||||
# Default environment variables (can be overridden)
|
||||
ENV WHISPER_MODEL_DIR=/models \
|
||||
TRANSCRIPTION_OUTPUT_DIR=/outputs \
|
||||
TRANSCRIPTION_BATCH_OUTPUT_DIR=/outputs/batch \
|
||||
TRANSCRIPTION_MODEL=large-v3 \
|
||||
TRANSCRIPTION_DEVICE=auto \
|
||||
TRANSCRIPTION_COMPUTE_TYPE=auto \
|
||||
TRANSCRIPTION_OUTPUT_FORMAT=txt \
|
||||
TRANSCRIPTION_BEAM_SIZE=5 \
|
||||
TRANSCRIPTION_TEMPERATURE=0.0 \
|
||||
API_HOST=127.0.0.1 \
|
||||
API_PORT=33767 \
|
||||
JOB_QUEUE_MAX_SIZE=5 \
|
||||
JOB_METADATA_DIR=/outputs/jobs \
|
||||
JOB_RETENTION_DAYS=7 \
|
||||
GPU_HEALTH_CHECK_ENABLED=true \
|
||||
GPU_HEALTH_CHECK_INTERVAL_MINUTES=10 \
|
||||
GPU_HEALTH_TEST_MODEL=tiny \
|
||||
GPU_HEALTH_TEST_AUDIO=/test-audio/test.mp3 \
|
||||
GPU_RESET_COOLDOWN_MINUTES=5 \
|
||||
SERVER_MODE=api
|
||||
|
||||
# Run the server
|
||||
CMD ["python", "src/servers/whisper_server.py"]
|
||||
# Expose port 80 for nginx (API mode only)
|
||||
EXPOSE 80
|
||||
|
||||
# Use entrypoint script to handle different server modes
|
||||
ENTRYPOINT ["/docker-entrypoint.sh"]
|
||||
|
||||
403
TRANSCRIPTOR_API_FIX.md
Normal file
403
TRANSCRIPTOR_API_FIX.md
Normal file
@@ -0,0 +1,403 @@
|
||||
# Transcriptor API - Filename Validation Bug Fix
|
||||
|
||||
## Issue Summary
|
||||
|
||||
The transcriptor API is rejecting valid audio files due to overly strict path validation. Files with `..` (double periods) anywhere in the filename are being rejected as potential path traversal attacks, even when they appear naturally in legitimate filenames.
|
||||
|
||||
## Current Behavior
|
||||
|
||||
### Error Observed
|
||||
```json
|
||||
{
|
||||
"detail": {
|
||||
"error": "Upload failed",
|
||||
"message": "Audio file validation failed: Path traversal (..) is not allowed"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### HTTP Response
|
||||
- **Status Code**: 500
|
||||
- **Endpoint**: `POST /transcribe`
|
||||
- **Request**: File upload with filename containing `..`
|
||||
|
||||
### Example Failing Filename
|
||||
```
|
||||
This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a
|
||||
^^^
|
||||
(Three dots, which contain the substring "..")
|
||||
```
|
||||
|
||||
## Root Cause Analysis
|
||||
|
||||
### Current Validation Logic (Problematic)
|
||||
The API is likely checking for `..` anywhere in the filename string, which creates false positives:
|
||||
|
||||
```python
|
||||
# CURRENT (WRONG)
|
||||
if ".." in filename:
|
||||
raise ValidationError("Path traversal (..) is not allowed")
|
||||
```
|
||||
|
||||
This rejects legitimate filenames like:
|
||||
- `"video...mp4"` (ellipsis in title)
|
||||
- `"Part 1... Part 2.m4a"` (ellipsis separator)
|
||||
- `"Wait... what.mp4"` (dramatic pause)
|
||||
|
||||
### Actual Security Concern
|
||||
Path traversal attacks use `..` as a **complete path component** (a reference to the parent directory) to navigate up the filesystem:
|
||||
- `../../etc/passwd` (DANGEROUS)
|
||||
- `../../../secrets.txt` (DANGEROUS)
|
||||
- `video...mp4` (SAFE - just a filename)
|
||||
|
||||
## Recommended Fix
|
||||
|
||||
### Option 1: Path Component Validation (Recommended)
|
||||
|
||||
Check for `..` only when it appears as a **complete path component**, not as part of the filename text.
|
||||
|
||||
```python
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def validate_filename(filename: str) -> bool:
    """
    Validate an uploaded filename for path traversal attacks.

    Dots that are merely part of the name (for example the ellipsis in
    "Wait... what.mp4") are accepted. A name is rejected only when it could
    refer to something other than a plain file in the target directory.

    Args:
        filename: Client-supplied filename; expected to be a bare basename.

    Returns:
        True if the filename is safe.

    Raises:
        ValidationError: If the name is empty, longer than 255 characters,
            contains a NUL byte or a path separator ("/" or "\\"), is an
            absolute path, or is "." / "..".
    """
    # Empty names and embedded NUL bytes can never be valid filenames
    if not filename or "\0" in filename:
        raise ValidationError(f"Invalid filename: {filename!r}")

    # Limit length to 255 characters (typical limit for a single file name)
    if len(filename) > 255:
        raise ValidationError(
            f"Filename too long ({len(filename)} > 255 characters)"
        )

    # Normalize the path
    normalized = os.path.normpath(filename)

    # normpath rewrites names containing "." / ".." components or redundant
    # separators, so any change means the input was not a plain file name
    if normalized != filename:
        raise ValidationError(f"Path traversal detected: {filename}")

    # Check for absolute paths
    if os.path.isabs(filename):
        raise ValidationError(f"Absolute paths not allowed: {filename}")

    # Split into components and check for parent directory references
    parts = Path(filename).parts
    if ".." in parts:
        raise ValidationError(f"Parent directory references not allowed: {filename}")

    # Check for any path separators (should be basename only). Both styles are
    # tested explicitly: on POSIX os.altsep is None, so a Windows-style path
    # such as "C:\\Windows\\file.txt" would otherwise be accepted.
    if "/" in filename or "\\" in filename:
        raise ValidationError(f"Path separators not allowed: {filename}")

    # "." is unchanged by normpath and has no path parts, so reject it here
    if filename == ".":
        raise ValidationError(f"Directory references not allowed: {filename}")

    return True
|
||||
|
||||
# Examples:
|
||||
validate_filename("video.mp4") # ✓ PASS
|
||||
validate_filename("video...mp4") # ✓ PASS (ellipsis)
|
||||
validate_filename("This is... a video.m4a") # ✓ PASS
|
||||
validate_filename("../../../etc/passwd") # ✗ FAIL (traversal)
|
||||
validate_filename("dir/../file.mp4") # ✗ FAIL (traversal)
|
||||
validate_filename("/etc/passwd") # ✗ FAIL (absolute)
|
||||
```
|
||||
|
||||
### Option 2: Basename-Only Validation (Simpler)
|
||||
|
||||
Only accept basenames (no directory components at all):
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
def validate_filename(filename: str) -> bool:
    """
    Ensure filename is a plain basename with no path components.

    Args:
        filename: Client-supplied filename.

    Returns:
        True if the filename is a bare file name.

    Raises:
        ValidationError: If the name is empty, is "." or "..", or contains
            any directory component or path separator.
    """
    # "", "." and ".." come back unchanged from os.path.basename(), so the
    # comparison below would accept them; reject them explicitly first.
    if filename in ("", ".", ".."):
        raise ValidationError(
            f"Filename must not be empty or a directory reference: {filename!r}"
        )

    # Extract basename
    basename = os.path.basename(filename)

    # Must match original (no path components)
    if basename != filename:
        raise ValidationError(f"Filename must not contain path components: {filename}")

    # Additional check: no path separators (covers "\\" on POSIX, where
    # os.path.basename() does not treat it as a separator)
    if "/" in filename or "\\" in filename:
        raise ValidationError(f"Path separators not allowed: {filename}")

    return True
|
||||
|
||||
# Examples:
|
||||
validate_filename("video.mp4") # ✓ PASS
|
||||
validate_filename("video...mp4") # ✓ PASS
|
||||
validate_filename("../file.mp4") # ✗ FAIL
|
||||
validate_filename("dir/file.mp4") # ✗ FAIL
|
||||
```
|
||||
|
||||
### Option 3: Regex Pattern Matching (Most Strict)
|
||||
|
||||
Use a whitelist approach for allowed characters:
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
def validate_filename(filename: str) -> bool:
    """
    Validate filename using whitelist of safe characters.

    Args:
        filename: Client-supplied filename.

    Returns:
        True if the filename matches the whitelist.

    Raises:
        ValidationError: If the name is longer than 255 characters, contains
            a character outside the whitelist, lacks a 2-10 character
            alphanumeric extension, or starts/ends with a dot.
    """
    # Length: at most 255 characters in total (name plus extension)
    if len(filename) > 255:
        raise ValidationError(f"Invalid filename format: {filename}")

    # Allow: letters, numbers, spaces, dots, hyphens, underscores,
    # followed by a dot and a 2-10 character alphanumeric extension
    pattern = r'[a-zA-Z0-9 .\-_]{1,255}\.[a-zA-Z0-9]{2,10}'

    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which would let "video.mp4\n" through
    if not re.fullmatch(pattern, filename):
        raise ValidationError(f"Invalid filename format: {filename}")

    # Additional safety: reject if starts/ends with dot
    if filename.startswith('.') or filename.endswith('.'):
        raise ValidationError(f"Filename cannot start or end with dot: {filename}")

    return True
|
||||
|
||||
# Examples:
|
||||
validate_filename("video.mp4") # ✓ PASS
|
||||
validate_filename("video...mp4") # ✓ PASS
|
||||
validate_filename("My Video... Part 2.m4a") # ✓ PASS
|
||||
validate_filename("../file.mp4") # ✗ FAIL ("/" is not an allowed character)
|
||||
validate_filename("file<>.mp4") # ✗ FAIL (invalid chars)
|
||||
```
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### 1. Locate Current Validation Code
|
||||
|
||||
Search for files containing the validation logic:
|
||||
|
||||
```bash
|
||||
grep -r "Path traversal" /path/to/transcriptor-api
# -F matches the literal string ".."; without it each dot is a regex
# wildcard and any quoted two-character string would match
grep -rF '".."' /path/to/transcriptor-api
grep -r "normpath\|basename" /path/to/transcriptor-api
|
||||
```
|
||||
|
||||
### 2. Update Validation Function
|
||||
|
||||
Replace the current naive check with one of the recommended solutions above.
|
||||
|
||||
**Priority Order:**
|
||||
1. **Option 1** (Path Component Validation) - Best security/usability balance
|
||||
2. **Option 2** (Basename-Only) - Simplest, very secure
|
||||
3. **Option 3** (Regex) - Most restrictive, may reject valid files
|
||||
|
||||
### 3. Test Cases
|
||||
|
||||
Create comprehensive test suite:
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
def test_valid_filenames():
    """Every legitimate filename below must pass validation."""
    accepted = (
        "video.mp4",
        "audio.m4a",
        "This is... a test.mp4",
        "Part 1... Part 2.wav",
        "video...multiple...dots.mp3",
        "My-Video_2024.mp4",
        "song (remix).m4a",
    )

    for candidate in accepted:
        # validate_filename returns True on success and raises otherwise
        is_ok = validate_filename(candidate)
        assert is_ok, f"Should accept: {candidate}"
|
||||
|
||||
def test_dangerous_filenames():
    """Each path-like name below must raise ValidationError."""
    rejected = (
        "../../../etc/passwd",
        "../../secrets.txt",
        "../file.mp4",
        "/etc/passwd",
        "C:\\Windows\\System32\\file.txt",
        "dir/../file.mp4",
        "file/../../etc/passwd",
    )

    for candidate in rejected:
        with pytest.raises(ValidationError):
            validate_filename(candidate)
|
||||
|
||||
def test_edge_cases():
    """Check boundary inputs: dot-only names, hidden files, empty and overlong names."""
    # (filename, expected to be accepted)
    cases = (
        (".", False),  # Current directory
        ("..", False),  # Parent directory
        ("...", True),  # Just dots (valid)
        ("....", True),  # Multiple dots (valid)
        (".hidden.mp4", True),  # Hidden file (valid on Unix)
        ("", False),  # Empty string
        ("a" * 256, False),  # Too long
    )

    for name, expect_ok in cases:
        if not expect_ok:
            with pytest.raises(ValidationError):
                validate_filename(name)
            continue
        assert validate_filename(name)
|
||||
```
|
||||
|
||||
### 4. Update Error Response
|
||||
|
||||
Provide clearer error messages:
|
||||
|
||||
```python
|
||||
# BAD (current)
|
||||
{"detail": {"error": "Upload failed", "message": "Audio file validation failed: Path traversal (..) is not allowed"}}
|
||||
|
||||
# GOOD (improved)
|
||||
{
|
||||
"detail": {
|
||||
"error": "Invalid filename",
|
||||
"message": "Filename contains path traversal characters. Please use only the filename without directory paths.",
|
||||
"filename": "../../etc/passwd",
|
||||
"suggestion": "Use: passwd.txt"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing the Fix
|
||||
|
||||
### Manual Testing
|
||||
|
||||
1. **Test with problematic filename from bug report:**
|
||||
```bash
|
||||
curl -X POST http://192.168.1.210:33767/transcribe \
|
||||
-F "file=@/path/to/This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a" \
|
||||
-F "model=medium"
|
||||
```
|
||||
Expected: HTTP 200 (success)
|
||||
|
||||
2. **Test with actual path traversal:**
|
||||
```bash
|
||||
curl -X POST http://192.168.1.210:33767/transcribe \
|
||||
-F "file=@/tmp/test.m4a;filename=../../etc/passwd" \
|
||||
-F "model=medium"
|
||||
```
|
||||
Expected: HTTP 400 (validation error)
|
||||
|
||||
3. **Test with various ellipsis patterns:**
|
||||
- `"video...mp4"` → Should pass
|
||||
- `"Part 1... Part 2.m4a"` → Should pass
|
||||
- `"Wait... what!.mp4"` → Should pass
|
||||
|
||||
### Automated Testing
|
||||
|
||||
```python
|
||||
# integration_test.py
|
||||
import requests
|
||||
|
||||
def test_ellipsis_filenames():
    """Test files with ellipsis in names.

    Uploads the same local sample (test_audio.m4a) under each ellipsis
    filename and expects the API to answer with HTTP 200.
    """
    test_cases = [
        "video...mp4",
        "This is... a test.m4a",
        "Wait... what.mp3",
    ]

    for filename in test_cases:
        # Open the sample in a context manager so the handle is closed after
        # every request, even when the request or the assertion fails
        with open("test_audio.m4a", "rb") as audio:
            response = requests.post(
                "http://192.168.1.210:33767/transcribe",
                files={"file": (filename, audio)},
                data={"model": "medium"}
            )
        assert response.status_code == 200, f"Failed for: {filename}"
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### What We're Protecting Against
|
||||
|
||||
1. **Path Traversal**: `../../../sensitive/file`
|
||||
2. **Absolute Paths**: `/etc/passwd` or `C:\Windows\System32\`
|
||||
3. **Relative paths with directory components**: `./.git/config`
|
||||
|
||||
### What We're NOT Breaking
|
||||
|
||||
1. **Ellipsis in titles**: `"Wait... what.mp4"`
|
||||
2. **Multiple extensions**: `"file.tar.gz"`
|
||||
3. **Special characters**: `"My Video (2024).mp4"`
|
||||
|
||||
### Additional Hardening (Optional)
|
||||
|
||||
```python
|
||||
def sanitize_and_validate_filename(filename: str) -> str:
    """
    Sanitize filename and validate for safety.

    Strips NUL bytes and any directory components, truncates the result to
    255 characters while keeping the extension, then runs validate_filename()
    on it.

    Args:
        filename: Client-supplied filename, possibly including a path.

    Returns:
        The cleaned filename.

    Raises:
        ValidationError: Propagated from validate_filename() when the cleaned
            name is still not acceptable.
    """
    import os  # local import: this snippet has no import block of its own

    # Remove null bytes
    filename = filename.replace("\0", "")

    # Extract basename (strips any path components). Backslashes are turned
    # into "/" first because os.path.basename() on POSIX does not treat "\\"
    # as a separator and would keep a Windows-style path intact.
    filename = os.path.basename(filename.replace("\\", "/"))

    # Limit length
    max_length = 255
    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        # Reserve room for the extension; clamp at 0 so an overlong extension
        # cannot produce a negative slice bound, then enforce the hard limit
        keep = max(max_length - len(ext), 0)
        filename = (name[:keep] + ext)[:max_length]

    # Validate
    validate_filename(filename)

    return filename
|
||||
```
|
||||
|
||||
## Deployment Checklist
|
||||
|
||||
- [ ] Update validation function with recommended fix
|
||||
- [ ] Add comprehensive test suite
|
||||
- [ ] Test with real-world filenames (including bug report case)
|
||||
- [ ] Test security: attempt path traversal attacks
|
||||
- [ ] Update API documentation
|
||||
- [ ] Review error messages for clarity
|
||||
- [ ] Deploy to staging environment
|
||||
- [ ] Run integration tests
|
||||
- [ ] Monitor logs for validation failures
|
||||
- [ ] Deploy to production
|
||||
- [ ] Verify bug reporter's file now works
|
||||
|
||||
## Contact & Context
|
||||
|
||||
**Bug Report Date**: 2025-10-26
|
||||
**Affected Endpoint**: `POST /transcribe`
|
||||
**Error Code**: HTTP 500
|
||||
**Client Application**: yt-dlp-webui v3
|
||||
|
||||
**Example Failing Request:**
|
||||
```
|
||||
POST http://192.168.1.210:33767/transcribe
|
||||
Content-Type: multipart/form-data
|
||||
|
||||
file: "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
|
||||
model: "medium"
|
||||
```
|
||||
|
||||
**Current Behavior**: Returns 500 error with path traversal message
|
||||
**Expected Behavior**: Accepts file and processes transcription
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Files to Check
|
||||
- `/path/to/api/validators.py` or similar
|
||||
- `/path/to/api/upload_handler.py`
|
||||
- `/path/to/api/routes/transcribe.py`
|
||||
|
||||
### Search Commands
|
||||
```bash
|
||||
# Find validation code
|
||||
rg "Path traversal" --type py
|
||||
rg '"\.\."' --type py
|
||||
rg "ValidationError.*filename" --type py
|
||||
|
||||
# Find upload handlers
|
||||
rg "def.*upload|def.*transcribe" --type py
|
||||
```
|
||||
|
||||
### Priority Fix
|
||||
Use **Option 1 (Path Component Validation)** - it provides the best balance of security and usability.
|
||||
85
api.logs
85
api.logs
@@ -1,85 +0,0 @@
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:PERFORMING STARTUP GPU HEALTH CHECK
|
||||
INFO:__main__:======================================================================
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 1.04s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:STARTUP GPU CHECK SUCCESSFUL
|
||||
INFO:__main__:GPU Device: NVIDIA GeForce RTX 3060
|
||||
INFO:__main__:Memory Available: 11.66 GB
|
||||
INFO:__main__:Test Duration: 1.04s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:Starting Whisper REST API server on 0.0.0.0:8000
|
||||
INFO: Started server process [69821]
|
||||
INFO: Waiting for application startup.
|
||||
INFO:__main__:Starting job queue and health monitor...
|
||||
INFO:core.job_queue:Starting job queue (max size: 100)
|
||||
INFO:core.job_queue:Loading jobs from /media/raid/agents/tools/mcp-transcriptor/outputs/jobs
|
||||
INFO:core.job_queue:Loaded 8 jobs from disk
|
||||
INFO:core.job_queue:Job queue worker loop started
|
||||
INFO:core.job_queue:Job queue worker started
|
||||
INFO:__main__:Job queue started (max_size=100, metadata_dir=/media/raid/agents/tools/mcp-transcriptor/outputs/jobs)
|
||||
INFO:core.gpu_health:Starting GPU health monitor (interval: 10.0 minutes)
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.37s
|
||||
INFO:__main__:GPU health monitor started (interval=10 minutes)
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
|
||||
INFO: 127.0.0.1:48092 - "GET /jobs HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:60874 - "GET /jobs?status=completed&limit=3 HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:60876 - "GET /jobs?status=failed&limit=10 HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Running GPU health check before job submission
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
|
||||
INFO:core.job_queue:GPU health check passed
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 submitted: /tmp/whisper_test_voice_1s.mp3 (queue position: 1)
|
||||
INFO: 127.0.0.1:58376 - "POST /jobs HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 started processing
|
||||
INFO:core.model_manager:Running GPU health check with auto-reset before model loading
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.54s
|
||||
INFO:core.model_manager:Loading Whisper model: tiny device: cuda compute type: float16
|
||||
INFO:core.model_manager:Available GPU memory: 12.52 GB
|
||||
INFO:core.model_manager:Enabling batch processing acceleration, batch size: 16
|
||||
INFO:core.transcriber:Starting transcription of file: whisper_test_voice_1s.mp3
|
||||
INFO:utils.audio_processor:Successfully preprocessed audio: whisper_test_voice_1s.mp3
|
||||
INFO:core.transcriber:Using batch acceleration for transcription...
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:VAD filter removed 00:00.000 of audio
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.transcriber:Transcription completed, time used: 0.16 seconds, detected language: en, audio length: 1.51 seconds
|
||||
INFO:core.transcriber:Transcription results saved to: /media/raid/agents/tools/mcp-transcriptor/outputs/whisper_test_voice_1s.txt
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 completed successfully: /media/raid/agents/tools/mcp-transcriptor/outputs/whisper_test_voice_1s.txt
|
||||
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 finished: status=completed, duration=1.1s
|
||||
INFO: 127.0.0.1:41646 - "GET /jobs/6be8e49a-bdc1-4508-af99-280bef033cb0 HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:34046 - "GET /jobs/6be8e49a-bdc1-4508-af99-280bef033cb0/result HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Running GPU health check before job submission
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
|
||||
INFO:core.job_queue:GPU health check passed
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a submitted: /home/uad/agents/tools/mcp-transcriptor/data/test.mp3 (queue position: 1)
|
||||
INFO: 127.0.0.1:44576 - "POST /jobs HTTP/1.1" 200 OK
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a started processing
|
||||
INFO:core.model_manager:Running GPU health check with auto-reset before model loading
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
|
||||
INFO:core.model_manager:Loading Whisper model: large-v3 device: cuda compute type: float16
|
||||
INFO:core.model_manager:Available GPU memory: 12.52 GB
|
||||
INFO:core.model_manager:Enabling batch processing acceleration, batch size: 16
|
||||
INFO:core.transcriber:Starting transcription of file: test.mp3
|
||||
INFO:utils.audio_processor:Successfully preprocessed audio: test.mp3
|
||||
INFO:core.transcriber:Using batch acceleration for transcription...
|
||||
INFO:faster_whisper:Processing audio with duration 00:06.955
|
||||
INFO:faster_whisper:VAD filter removed 00:00.299 of audio
|
||||
INFO:core.transcriber:Transcription completed, time used: 0.52 seconds, detected language: en, audio length: 6.95 seconds
|
||||
INFO:core.transcriber:Transcription results saved to: /media/raid/agents/tools/mcp-transcriptor/outputs/test.txt
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a completed successfully: /media/raid/agents/tools/mcp-transcriptor/outputs/test.txt
|
||||
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a finished: status=completed, duration=23.3s
|
||||
INFO: 127.0.0.1:59120 - "GET /jobs/41ce74c0-8929-457b-96b3-1b8e4a720a7a HTTP/1.1" 200 OK
|
||||
INFO: 127.0.0.1:53806 - "GET /jobs/41ce74c0-8929-457b-96b3-1b8e4a720a7a/result HTTP/1.1" 200 OK
|
||||
19
docker-build.sh
Executable file
19
docker-build.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Building Whisper Transcriptor Docker image..."
|
||||
|
||||
# Build the Docker image
|
||||
docker build -t transcriptor-apimcp:latest .
|
||||
|
||||
echo "$(datetime_prefix) Build complete!"
|
||||
echo "$(datetime_prefix) Image: transcriptor-apimcp:latest"
|
||||
echo ""
|
||||
echo "Usage:"
|
||||
echo " API mode: ./docker-run-api.sh"
|
||||
echo " MCP mode: ./docker-run-mcp.sh"
|
||||
echo " Or use: docker-compose up transcriptor-api"
|
||||
106
docker-compose.yml
Normal file
106
docker-compose.yml
Normal file
@@ -0,0 +1,106 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# API Server mode with nginx reverse proxy
|
||||
transcriptor-api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: transcriptor-apimcp:latest
|
||||
container_name: transcriptor-api
|
||||
runtime: nvidia
|
||||
environment:
|
||||
NVIDIA_VISIBLE_DEVICES: "0"
|
||||
NVIDIA_DRIVER_CAPABILITIES: compute,utility
|
||||
SERVER_MODE: api
|
||||
API_HOST: 127.0.0.1
|
||||
API_PORT: 33767
|
||||
WHISPER_MODEL_DIR: /models
|
||||
TRANSCRIPTION_OUTPUT_DIR: /outputs
|
||||
TRANSCRIPTION_BATCH_OUTPUT_DIR: /outputs/batch
|
||||
TRANSCRIPTION_MODEL: large-v3
|
||||
TRANSCRIPTION_DEVICE: auto
|
||||
TRANSCRIPTION_COMPUTE_TYPE: auto
|
||||
TRANSCRIPTION_OUTPUT_FORMAT: txt
|
||||
TRANSCRIPTION_BEAM_SIZE: 5
|
||||
TRANSCRIPTION_TEMPERATURE: 0.0
|
||||
JOB_QUEUE_MAX_SIZE: 5
|
||||
JOB_METADATA_DIR: /outputs/jobs
|
||||
JOB_RETENTION_DAYS: 7
|
||||
GPU_HEALTH_CHECK_ENABLED: "true"
|
||||
GPU_HEALTH_CHECK_INTERVAL_MINUTES: 10
|
||||
GPU_HEALTH_TEST_MODEL: tiny
|
||||
GPU_HEALTH_TEST_AUDIO: /test-audio/test.mp3
|
||||
GPU_RESET_COOLDOWN_MINUTES: 5
|
||||
# Optional proxy settings (uncomment if needed)
|
||||
# HTTP_PROXY: http://192.168.1.212:8080
|
||||
# HTTPS_PROXY: http://192.168.1.212:8080
|
||||
ports:
|
||||
- "33767:80" # Map host:33767 to container nginx:80
|
||||
volumes:
|
||||
- /home/uad/agents/tools/mcp-transcriptor/models:/models
|
||||
- /home/uad/agents/tools/mcp-transcriptor/outputs:/outputs
|
||||
- /home/uad/agents/tools/mcp-transcriptor/logs:/logs
|
||||
- /home/uad/agents/tools/mcp-transcriptor/data/test.mp3:/test-audio/test.mp3:ro
|
||||
- /etc/localtime:/etc/localtime:ro # Sync container time with host
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- transcriptor-network
|
||||
|
||||
# MCP Server mode (stdio based)
|
||||
transcriptor-mcp:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: transcriptor-apimcp:latest
|
||||
container_name: transcriptor-mcp
|
||||
environment:
|
||||
SERVER_MODE: mcp
|
||||
WHISPER_MODEL_DIR: /models
|
||||
TRANSCRIPTION_OUTPUT_DIR: /outputs
|
||||
TRANSCRIPTION_BATCH_OUTPUT_DIR: /outputs/batch
|
||||
TRANSCRIPTION_MODEL: large-v3
|
||||
TRANSCRIPTION_DEVICE: auto
|
||||
TRANSCRIPTION_COMPUTE_TYPE: auto
|
||||
TRANSCRIPTION_OUTPUT_FORMAT: txt
|
||||
TRANSCRIPTION_BEAM_SIZE: 5
|
||||
TRANSCRIPTION_TEMPERATURE: 0.0
|
||||
JOB_QUEUE_MAX_SIZE: 100
|
||||
JOB_METADATA_DIR: /outputs/jobs
|
||||
JOB_RETENTION_DAYS: 7
|
||||
GPU_HEALTH_CHECK_ENABLED: "true"
|
||||
GPU_HEALTH_CHECK_INTERVAL_MINUTES: 10
|
||||
GPU_HEALTH_TEST_MODEL: tiny
|
||||
GPU_RESET_COOLDOWN_MINUTES: 5
|
||||
# Optional proxy settings (uncomment if needed)
|
||||
# HTTP_PROXY: http://192.168.1.212:8080
|
||||
# HTTPS_PROXY: http://192.168.1.212:8080
|
||||
volumes:
|
||||
- /home/uad/agents/tools/mcp-transcriptor/models:/models
|
||||
- /home/uad/agents/tools/mcp-transcriptor/outputs:/outputs
|
||||
- /home/uad/agents/tools/mcp-transcriptor/logs:/logs
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
stdin_open: true # Enable stdin for MCP stdio mode
|
||||
tty: true
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- transcriptor-network
|
||||
profiles:
|
||||
- mcp # Only start when explicitly requested
|
||||
|
||||
networks:
|
||||
transcriptor-network:
|
||||
driver: bridge
|
||||
67
docker-entrypoint.sh
Executable file
67
docker-entrypoint.sh
Executable file
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Docker Entrypoint Script for Whisper Transcriptor
|
||||
# Supports both MCP and API server modes
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Starting Whisper Transcriptor in ${SERVER_MODE} mode..."
|
||||
|
||||
# Ensure required directories exist
|
||||
mkdir -p "$WHISPER_MODEL_DIR"
|
||||
mkdir -p "$TRANSCRIPTION_OUTPUT_DIR"
|
||||
mkdir -p "$TRANSCRIPTION_BATCH_OUTPUT_DIR"
|
||||
mkdir -p "$JOB_METADATA_DIR"
|
||||
mkdir -p /app/outputs/uploads
|
||||
|
||||
# Display GPU information
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
echo "$(datetime_prefix) GPU Information:"
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
else
|
||||
echo "$(datetime_prefix) Warning: nvidia-smi not found. GPU may not be available."
|
||||
fi
|
||||
|
||||
# Check server mode and start appropriate service
|
||||
case "${SERVER_MODE}" in
|
||||
"api")
|
||||
echo "$(datetime_prefix) Starting API Server mode with nginx reverse proxy"
|
||||
|
||||
# Update nginx configuration to use correct backend
|
||||
sed -i "s/server 127.0.0.1:33767;/server ${API_HOST}:${API_PORT};/" /etc/nginx/sites-available/transcriptor.conf
|
||||
|
||||
# Enable nginx site
|
||||
ln -sf /etc/nginx/sites-available/transcriptor.conf /etc/nginx/sites-enabled/
|
||||
rm -f /etc/nginx/sites-enabled/default
|
||||
|
||||
# Test nginx configuration
|
||||
echo "$(datetime_prefix) Testing nginx configuration..."
|
||||
nginx -t
|
||||
|
||||
# Start nginx in background
|
||||
echo "$(datetime_prefix) Starting nginx..."
|
||||
nginx
|
||||
|
||||
# Start API server (foreground - this keeps container running)
|
||||
echo "$(datetime_prefix) Starting API server on ${API_HOST}:${API_PORT}"
|
||||
echo "$(datetime_prefix) API accessible via nginx on port 80"
|
||||
exec python -u /app/src/servers/api_server.py
|
||||
;;
|
||||
|
||||
"mcp")
|
||||
echo "$(datetime_prefix) Starting MCP Server mode (stdio)"
|
||||
echo "$(datetime_prefix) Model directory: $WHISPER_MODEL_DIR"
|
||||
|
||||
# Start MCP server in stdio mode
|
||||
exec python -u /app/src/servers/whisper_server.py
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "$(datetime_prefix) ERROR: Invalid SERVER_MODE: ${SERVER_MODE}"
|
||||
echo "$(datetime_prefix) Valid modes: 'api' or 'mcp'"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
62
docker-run-api.sh
Executable file
62
docker-run-api.sh
Executable file
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Starting Whisper Transcriptor in API mode with nginx..."
|
||||
|
||||
# Check if image exists
|
||||
if ! docker image inspect transcriptor-apimcp:latest &> /dev/null; then
|
||||
echo "$(datetime_prefix) Image not found. Building first..."
|
||||
./docker-build.sh
|
||||
fi
|
||||
|
||||
# Stop and remove existing container if running
|
||||
if docker ps -a --format '{{.Names}}' | grep -q '^transcriptor-api$'; then
|
||||
echo "$(datetime_prefix) Stopping existing container..."
|
||||
docker stop transcriptor-api || true
|
||||
docker rm transcriptor-api || true
|
||||
fi
|
||||
|
||||
# Run the container in API mode
|
||||
docker run -d \
|
||||
--name transcriptor-api \
|
||||
--gpus all \
|
||||
-p 33767:80 \
|
||||
-e SERVER_MODE=api \
|
||||
-e API_HOST=127.0.0.1 \
|
||||
-e API_PORT=33767 \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-e TRANSCRIPTION_MODEL=large-v3 \
|
||||
-e TRANSCRIPTION_DEVICE=auto \
|
||||
-e TRANSCRIPTION_COMPUTE_TYPE=auto \
|
||||
-e JOB_QUEUE_MAX_SIZE=5 \
|
||||
-v "$(pwd)/models:/models" \
|
||||
-v "$(pwd)/outputs:/outputs" \
|
||||
-v "$(pwd)/logs:/logs" \
|
||||
--restart unless-stopped \
|
||||
transcriptor-apimcp:latest
|
||||
|
||||
echo "$(datetime_prefix) Container started!"
|
||||
echo ""
|
||||
echo "API Server running at: http://localhost:33767"
|
||||
echo ""
|
||||
echo "Useful commands:"
|
||||
echo " Check logs: docker logs -f transcriptor-api"
|
||||
echo " Check status: docker ps | grep transcriptor-api"
|
||||
echo " Test health: curl http://localhost:33767/health"
|
||||
echo " Test GPU: curl http://localhost:33767/health/gpu"
|
||||
echo " Stop container: docker stop transcriptor-api"
|
||||
echo " Restart: docker restart transcriptor-api"
|
||||
echo ""
|
||||
echo "$(datetime_prefix) Waiting for service to start..."
|
||||
sleep 5
|
||||
|
||||
# Test health endpoint
|
||||
if curl -s http://localhost:33767/health > /dev/null 2>&1; then
|
||||
echo "$(datetime_prefix) ✓ Service is healthy!"
|
||||
else
|
||||
echo "$(datetime_prefix) ⚠ Service not responding yet. Check logs with: docker logs transcriptor-api"
|
||||
fi
|
||||
40
docker-run-mcp.sh
Executable file
40
docker-run-mcp.sh
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Starting Whisper Transcriptor in MCP mode..."
|
||||
|
||||
# Check if image exists
|
||||
if ! docker image inspect transcriptor-apimcp:latest &> /dev/null; then
|
||||
echo "$(datetime_prefix) Image not found. Building first..."
|
||||
./docker-build.sh
|
||||
fi
|
||||
|
||||
# Stop and remove existing container if running
|
||||
if docker ps -a --format '{{.Names}}' | grep -q '^transcriptor-mcp$'; then
|
||||
echo "$(datetime_prefix) Stopping existing container..."
|
||||
docker stop transcriptor-mcp || true
|
||||
docker rm transcriptor-mcp || true
|
||||
fi
|
||||
|
||||
# Run the container in MCP mode (interactive stdio)
|
||||
echo "$(datetime_prefix) Starting MCP server in stdio mode..."
|
||||
echo "$(datetime_prefix) Press Ctrl+C to stop"
|
||||
echo ""
|
||||
|
||||
docker run -it --rm \
|
||||
--name transcriptor-mcp \
|
||||
--gpus all \
|
||||
-e SERVER_MODE=mcp \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-e TRANSCRIPTION_MODEL=large-v3 \
|
||||
-e TRANSCRIPTION_DEVICE=auto \
|
||||
-e TRANSCRIPTION_COMPUTE_TYPE=auto \
|
||||
-e JOB_QUEUE_MAX_SIZE=100 \
|
||||
-v "$(pwd)/models:/models" \
|
||||
-v "$(pwd)/outputs:/outputs" \
|
||||
-v "$(pwd)/logs:/logs" \
|
||||
transcriptor-apimcp:latest
|
||||
25
mcp.logs
25
mcp.logs
@@ -1,25 +0,0 @@
|
||||
starting mcp server for whisper stt transcriptor
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:PERFORMING STARTUP GPU HEALTH CHECK
|
||||
INFO:__main__:======================================================================
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.93s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:STARTUP GPU CHECK SUCCESSFUL
|
||||
INFO:__main__:GPU Device: NVIDIA GeForce RTX 3060
|
||||
INFO:__main__:Memory Available: 11.66 GB
|
||||
INFO:__main__:Test Duration: 0.93s
|
||||
INFO:__main__:======================================================================
|
||||
INFO:__main__:Initializing job queue...
|
||||
INFO:core.job_queue:Starting job queue (max size: 100)
|
||||
INFO:core.job_queue:Loading jobs from /media/raid/agents/tools/mcp-transcriptor/outputs/jobs
|
||||
INFO:core.job_queue:Loaded 5 jobs from disk
|
||||
INFO:core.job_queue:Job queue worker loop started
|
||||
INFO:core.job_queue:Job queue worker started
|
||||
INFO:__main__:Job queue started (max_size=100, metadata_dir=/media/raid/agents/tools/mcp-transcriptor/outputs/jobs)
|
||||
INFO:core.gpu_health:Starting GPU health monitor (interval: 10.0 minutes)
|
||||
INFO:faster_whisper:Processing audio with duration 00:01.512
|
||||
INFO:faster_whisper:Detected language 'en' with probability 0.95
|
||||
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.38s
|
||||
INFO:__main__:GPU health monitor started (interval=10 minutes)
|
||||
132
nginx/README.md
Normal file
132
nginx/README.md
Normal file
@@ -0,0 +1,132 @@
|
||||
# Nginx Configuration for Transcriptor API
|
||||
|
||||
This directory contains an nginx reverse proxy configuration that prevents 504 Gateway Timeout errors on long-running transcription requests.
|
||||
|
||||
## Problem
|
||||
|
||||
The transcriptor API can take a long time (10+ minutes) to process large audio files with the `large-v3` model. Without proper timeout configuration, requests will fail with 504 Gateway Timeout.
|
||||
|
||||
## Solution
|
||||
|
||||
The provided `transcriptor.conf` file configures nginx with appropriate timeouts:
|
||||
|
||||
- **proxy_connect_timeout**: 600s (10 minutes; note that the nginx documentation states this timeout usually cannot exceed 75 seconds)
|
||||
- **proxy_send_timeout**: 600s (10 minutes)
|
||||
- **proxy_read_timeout**: 3600s (1 hour)
|
||||
- **client_max_body_size**: 500M (for large audio files)
|
||||
|
||||
## Installation
|
||||
|
||||
### Option 1: Deploy nginx configuration (if using nginx)
|
||||
|
||||
```bash
|
||||
# Copy configuration to nginx
|
||||
sudo cp transcriptor.conf /etc/nginx/sites-available/
|
||||
|
||||
# Create symlink to enable it
|
||||
sudo ln -s /etc/nginx/sites-available/transcriptor.conf /etc/nginx/sites-enabled/
|
||||
|
||||
# Test configuration
|
||||
sudo nginx -t
|
||||
|
||||
# Reload nginx
|
||||
sudo systemctl reload nginx
|
||||
```
|
||||
|
||||
### Option 2: Run API server directly (current setup)
|
||||
|
||||
The API server at `src/servers/api_server.py` has been updated with:
|
||||
- `timeout_keep_alive=3600` (1 hour)
|
||||
- `timeout_graceful_shutdown=60`
|
||||
|
||||
No additional nginx configuration is needed if you're running the API directly.
|
||||
|
||||
## Restart Service
|
||||
|
||||
After making changes, restart the transcriptor service:
|
||||
|
||||
```bash
|
||||
# If using supervisor
|
||||
sudo supervisorctl restart transcriptor-api
|
||||
|
||||
# If using systemd
|
||||
sudo systemctl restart transcriptor-api
|
||||
|
||||
# If using docker
|
||||
docker restart <container-name>
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Test that the API is working:
|
||||
|
||||
```bash
|
||||
# Health check (should return 200)
|
||||
curl http://192.168.1.210:33767/health
|
||||
|
||||
# Submit a transcription with the large-v3 model to exercise the timeout configuration
|
||||
curl -X POST http://192.168.1.210:33767/transcribe \
|
||||
-F "file=@test_audio.mp3" \
|
||||
-F "model=large-v3" \
|
||||
-F "output_format=txt"
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
Check logs for timeout warnings:
|
||||
|
||||
```bash
|
||||
# Supervisor logs
|
||||
tail -f /home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log
|
||||
|
||||
# Look for messages like:
|
||||
# - "Job {job_id} is taking longer than expected: 610.5s elapsed (threshold: 600s)"
|
||||
# - "Job {job_id} exceeded maximum timeout: 3610.2s elapsed (max: 3600s)"
|
||||
```
|
||||
|
||||
## Configuration Environment Variables
|
||||
|
||||
You can also configure timeouts via environment variables in `supervisor/transcriptor-api.conf`:
|
||||
|
||||
```ini
|
||||
environment=
|
||||
...
|
||||
JOB_TIMEOUT_WARNING_SECONDS="600", # Warn after 10 minutes
|
||||
JOB_TIMEOUT_MAX_SECONDS="3600", # Fail after 1 hour
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Still getting 504 errors?
|
||||
|
||||
1. **Check service is running**:
|
||||
```bash
|
||||
sudo supervisorctl status transcriptor-api
|
||||
```
|
||||
|
||||
2. **Check port is listening**:
|
||||
```bash
|
||||
sudo netstat -tlnp | grep 33767
|
||||
```
|
||||
|
||||
3. **Check logs for errors**:
|
||||
```bash
|
||||
tail -100 /home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log
|
||||
```
|
||||
|
||||
4. **Test direct connection** (bypass nginx):
|
||||
```bash
|
||||
curl http://localhost:33767/health
|
||||
```
|
||||
|
||||
5. **Verify GPU is working**:
|
||||
```bash
|
||||
curl http://192.168.1.210:33767/health/gpu
|
||||
```
|
||||
|
||||
### Job takes too long?
|
||||
|
||||
Consider:
|
||||
- Using a smaller model (e.g., `medium` instead of `large-v3`)
|
||||
- Splitting large audio files into smaller chunks
|
||||
- Increasing `JOB_TIMEOUT_MAX_SECONDS` for very long audio files
|
||||
85
nginx/transcriptor.conf
Normal file
85
nginx/transcriptor.conf
Normal file
@@ -0,0 +1,85 @@
|
||||
# Nginx reverse proxy configuration for Whisper Transcriptor API
|
||||
# Place this file in /etc/nginx/sites-available/ and symlink to sites-enabled/
|
||||
|
||||
upstream transcriptor_backend {
|
||||
# Backend transcriptor API server
|
||||
server 127.0.0.1:33767;
|
||||
|
||||
# Connection pooling
|
||||
keepalive 32;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80;
|
||||
server_name transcriptor.local; # Change to your domain
|
||||
|
||||
# Increase client body size for large audio uploads (up to 500MB)
|
||||
client_max_body_size 500M;
|
||||
|
||||
# Timeouts for long-running transcription jobs
|
||||
proxy_connect_timeout 600s; # 10 minutes to establish connection
|
||||
proxy_send_timeout 600s; # 10 minutes to send request
|
||||
proxy_read_timeout 3600s; # 1 hour to read response (transcription can be slow)
|
||||
|
||||
# Buffer settings for large responses
|
||||
proxy_buffering on;
|
||||
proxy_buffer_size 4k;
|
||||
proxy_buffers 8 4k;
|
||||
proxy_busy_buffers_size 8k;
|
||||
|
||||
# API endpoints
|
||||
location / {
|
||||
proxy_pass http://transcriptor_backend;
|
||||
|
||||
# Forward client info
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# HTTP/1.1 for keepalive
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection "";
|
||||
|
||||
# Disable request buffering so uploaded request bodies are passed to the backend as they arrive
|
||||
proxy_request_buffering off;
|
||||
}
|
||||
|
||||
# Health check endpoint with shorter timeout
|
||||
location /health {
|
||||
proxy_pass http://transcriptor_backend;
|
||||
proxy_read_timeout 10s;
|
||||
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
}
|
||||
|
||||
# Access and error logs
|
||||
access_log /var/log/nginx/transcriptor_access.log;
|
||||
error_log /var/log/nginx/transcriptor_error.log warn;
|
||||
}
|
||||
|
||||
# HTTPS configuration (optional, recommended for production)
|
||||
# server {
|
||||
# listen 443 ssl http2;
|
||||
# server_name transcriptor.local;
|
||||
#
|
||||
# ssl_certificate /etc/ssl/certs/transcriptor.crt;
|
||||
# ssl_certificate_key /etc/ssl/private/transcriptor.key;
|
||||
#
|
||||
# # SSL settings
|
||||
# ssl_protocols TLSv1.2 TLSv1.3;
|
||||
# ssl_ciphers HIGH:!aNULL:!MD5;
|
||||
# ssl_prefer_server_ciphers on;
|
||||
#
|
||||
# # Same settings as HTTP above
|
||||
# client_max_body_size 500M;
|
||||
# proxy_connect_timeout 600s;
|
||||
# proxy_send_timeout 600s;
|
||||
# proxy_read_timeout 3600s;
|
||||
#
|
||||
# location / {
|
||||
# proxy_pass http://transcriptor_backend;
|
||||
# # ... (same proxy settings as above)
|
||||
# }
|
||||
# }
|
||||
@@ -12,6 +12,7 @@ mcp[cli]
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.32.0
|
||||
python-multipart>=0.0.9
|
||||
aiofiles>=23.0.0 # Async file I/O
|
||||
|
||||
# Test audio generation dependencies
|
||||
gTTS>=2.3.0
|
||||
|
||||
28
reset_gpu.sh
28
reset_gpu.sh
@@ -2,12 +2,28 @@
|
||||
|
||||
# Script to reset NVIDIA GPU drivers without rebooting
|
||||
# This reloads kernel modules and restarts nvidia-persistenced service
|
||||
# Also handles stopping/starting Ollama to release GPU resources
|
||||
|
||||
echo "============================================================"
|
||||
echo "NVIDIA GPU Driver Reset Script"
|
||||
echo "============================================================"
|
||||
echo ""
|
||||
|
||||
# Stop Ollama via supervisorctl
|
||||
echo "Stopping Ollama service..."
|
||||
sudo supervisorctl stop ollama 2>/dev/null
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ Ollama stopped via supervisorctl"
|
||||
OLLAMA_WAS_RUNNING=true
|
||||
else
|
||||
echo " Ollama not running or supervisorctl not available"
|
||||
OLLAMA_WAS_RUNNING=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Give Ollama time to release GPU resources
|
||||
sleep 2
|
||||
|
||||
# Stop nvidia-persistenced service
|
||||
echo "Stopping nvidia-persistenced service..."
|
||||
sudo systemctl stop nvidia-persistenced
|
||||
@@ -65,6 +81,18 @@ else
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Restart Ollama if it was running
|
||||
if [ "$OLLAMA_WAS_RUNNING" = true ]; then
|
||||
echo "Restarting Ollama service..."
|
||||
sudo supervisorctl start ollama
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ Ollama restarted"
|
||||
else
|
||||
echo "✗ Failed to restart Ollama"
|
||||
fi
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo "============================================================"
|
||||
echo "GPU driver reset completed successfully"
|
||||
echo "============================================================"
|
||||
|
||||
@@ -11,6 +11,10 @@ export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"
|
||||
# Set CUDA library path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Set proxy for model downloads
|
||||
export HTTP_PROXY=http://192.168.1.212:8080
|
||||
export HTTPS_PROXY=http://192.168.1.212:8080
|
||||
|
||||
# Set environment variables
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
|
||||
@@ -27,13 +31,13 @@ export TRANSCRIPTION_FILENAME_PREFIX=""
|
||||
|
||||
# API server configuration
|
||||
export API_HOST="0.0.0.0"
|
||||
export API_PORT="8000"
|
||||
export API_PORT="33767"
|
||||
|
||||
# GPU Auto-Reset Configuration
|
||||
export GPU_RESET_COOLDOWN_MINUTES=5 # Minimum time between GPU reset attempts
|
||||
|
||||
# Job Queue Configuration
|
||||
export JOB_QUEUE_MAX_SIZE=100
|
||||
export JOB_QUEUE_MAX_SIZE=5
|
||||
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||
export JOB_RETENTION_DAYS=7
|
||||
|
||||
|
||||
@@ -15,6 +15,10 @@ export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"
|
||||
# Set CUDA library path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Set proxy for model downloads
|
||||
export HTTP_PROXY=http://192.168.1.212:8080
|
||||
export HTTPS_PROXY=http://192.168.1.212:8080
|
||||
|
||||
# Set environment variables
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
|
||||
|
||||
@@ -6,6 +6,7 @@ with strict failure handling to prevent silent CPU fallbacks.
|
||||
Includes circuit breaker pattern to prevent repeated failed checks.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
@@ -14,7 +15,6 @@ from datetime import datetime
|
||||
from typing import Optional, List
|
||||
import torch
|
||||
|
||||
from utils.test_audio_generator import generate_test_audio
|
||||
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -109,8 +109,18 @@ def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus
|
||||
logger.warning(f"Failed to get GPU info: {e}")
|
||||
|
||||
try:
|
||||
# Generate test audio
|
||||
test_audio_path = generate_test_audio(duration_seconds=1.0)
|
||||
# Get test audio path from environment variable
|
||||
test_audio_path = os.getenv("GPU_HEALTH_TEST_AUDIO")
|
||||
|
||||
if not test_audio_path:
|
||||
raise ValueError("GPU_HEALTH_TEST_AUDIO environment variable not set")
|
||||
|
||||
# Verify test audio file exists
|
||||
if not os.path.exists(test_audio_path):
|
||||
raise FileNotFoundError(
|
||||
f"Test audio file not found: {test_audio_path}. "
|
||||
f"Please ensure test audio exists before running GPU health checks."
|
||||
)
|
||||
|
||||
# Import here to avoid circular dependencies
|
||||
from faster_whisper import WhisperModel
|
||||
@@ -129,6 +139,7 @@ def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus
|
||||
gpu_memory_before = torch.cuda.memory_allocated(0)
|
||||
|
||||
# Load tiny model and transcribe
|
||||
model = None
|
||||
try:
|
||||
model = WhisperModel(
|
||||
"tiny",
|
||||
@@ -140,7 +151,7 @@ def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus
|
||||
segments, info = model.transcribe(test_audio_path, beam_size=1)
|
||||
|
||||
# Consume segments (needed to actually run inference)
|
||||
list(segments)
|
||||
segments_list = list(segments)
|
||||
|
||||
# Check if GPU was actually used
|
||||
# faster-whisper uses CTranslate2 which manages GPU memory separately
|
||||
@@ -156,6 +167,19 @@ def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus
|
||||
actual_device = "cpu"
|
||||
gpu_working = False
|
||||
logger.error(f"GPU health check failed: {error_msg}")
|
||||
finally:
|
||||
# Clean up model resources to prevent GPU memory leak
|
||||
if model is not None:
|
||||
try:
|
||||
del model
|
||||
segments_list = None
|
||||
# Force garbage collection and empty CUDA cache
|
||||
import gc
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error cleaning up GPU health check model: {cleanup_error}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Health check setup failed: {str(e)}"
|
||||
|
||||
@@ -15,12 +15,17 @@ from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cooldown file location
|
||||
# Cooldown file location (stores monotonic timestamp for drift protection)
|
||||
RESET_TIMESTAMP_FILE = "/tmp/whisper-gpu-last-reset"
|
||||
|
||||
# Default cooldown period (minutes)
|
||||
DEFAULT_COOLDOWN_MINUTES = 5
|
||||
|
||||
# Cooldown period in seconds (for monotonic comparison)
|
||||
def get_cooldown_seconds() -> float:
|
||||
"""Get cooldown period in seconds."""
|
||||
return get_cooldown_minutes() * 60.0
|
||||
|
||||
|
||||
def get_cooldown_minutes() -> int:
|
||||
"""
|
||||
@@ -38,12 +43,12 @@ def get_cooldown_minutes() -> int:
|
||||
return DEFAULT_COOLDOWN_MINUTES
|
||||
|
||||
|
||||
def get_last_reset_time() -> Optional[datetime]:
|
||||
def get_last_reset_time() -> Optional[float]:
|
||||
"""
|
||||
Read timestamp of last GPU reset attempt.
|
||||
Read monotonic timestamp of last GPU reset attempt.
|
||||
|
||||
Returns:
|
||||
datetime object of last reset, or None if no previous reset
|
||||
Monotonic timestamp of last reset, or None if no previous reset
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(RESET_TIMESTAMP_FILE):
|
||||
@@ -52,7 +57,7 @@ def get_last_reset_time() -> Optional[datetime]:
|
||||
with open(RESET_TIMESTAMP_FILE, 'r') as f:
|
||||
timestamp_str = f.read().strip()
|
||||
|
||||
return datetime.fromisoformat(timestamp_str)
|
||||
return float(timestamp_str)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read last reset timestamp: {e}")
|
||||
@@ -63,15 +68,17 @@ def record_reset_attempt() -> None:
|
||||
"""
|
||||
Record current time as last GPU reset attempt.
|
||||
|
||||
Creates/updates timestamp file with current UTC time.
|
||||
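Creates/updates timestamp file with monotonic time (drift-protected). NOTE(review): time.monotonic() has an undefined reference point, so a value read back from the file may not be comparable after a reboot — TODO confirm.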
|
||||
"""
|
||||
try:
|
||||
timestamp = datetime.utcnow().isoformat()
|
||||
# Use monotonic time to prevent NTP drift issues
|
||||
timestamp_monotonic = time.monotonic()
|
||||
timestamp_iso = datetime.utcnow().isoformat() # For logging only
|
||||
|
||||
with open(RESET_TIMESTAMP_FILE, 'w') as f:
|
||||
f.write(timestamp)
|
||||
f.write(str(timestamp_monotonic))
|
||||
|
||||
logger.info(f"Recorded GPU reset timestamp: {timestamp}")
|
||||
logger.info(f"Recorded GPU reset timestamp: {timestamp_iso} (monotonic: {timestamp_monotonic:.2f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record reset timestamp: {e}")
|
||||
@@ -81,35 +88,36 @@ def can_attempt_reset() -> bool:
|
||||
"""
|
||||
Check if GPU reset can be attempted based on cooldown period.
|
||||
|
||||
Uses monotonic time to prevent NTP drift issues.
|
||||
|
||||
Returns:
|
||||
True if reset is allowed (no recent reset or cooldown expired),
|
||||
False if cooldown is still active
|
||||
"""
|
||||
last_reset = get_last_reset_time()
|
||||
last_reset_monotonic = get_last_reset_time()
|
||||
|
||||
if last_reset is None:
|
||||
if last_reset_monotonic is None:
|
||||
# No previous reset recorded
|
||||
logger.debug("No previous GPU reset found, reset allowed")
|
||||
return True
|
||||
|
||||
cooldown_minutes = get_cooldown_minutes()
|
||||
cooldown_period = timedelta(minutes=cooldown_minutes)
|
||||
time_since_reset = datetime.utcnow() - last_reset
|
||||
# Use monotonic time for drift-safe comparison
|
||||
current_monotonic = time.monotonic()
|
||||
time_since_reset_seconds = current_monotonic - last_reset_monotonic
|
||||
cooldown_seconds = get_cooldown_seconds()
|
||||
|
||||
if time_since_reset < cooldown_period:
|
||||
remaining = cooldown_period - time_since_reset
|
||||
if time_since_reset_seconds < cooldown_seconds:
|
||||
remaining_seconds = cooldown_seconds - time_since_reset_seconds
|
||||
logger.warning(
|
||||
f"GPU reset cooldown active. "
|
||||
f"Last reset: {last_reset.isoformat()}, "
|
||||
f"Cooldown: {cooldown_minutes} min, "
|
||||
f"Remaining: {remaining.total_seconds():.0f}s"
|
||||
f"Cooldown: {get_cooldown_minutes()} min, "
|
||||
f"Remaining: {remaining_seconds:.0f}s"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f"GPU reset cooldown expired. "
|
||||
f"Last reset: {last_reset.isoformat()}, "
|
||||
f"Time since: {time_since_reset.total_seconds():.0f}s"
|
||||
f"Time since last reset: {time_since_reset_seconds:.0f}s"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -149,16 +157,29 @@ def reset_gpu_drivers() -> None:
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
logger.info(f"Executing GPU reset script: {script_path}")
|
||||
# Resolve to absolute path and validate it's the expected script
|
||||
# This prevents path injection if script_path was somehow manipulated
|
||||
resolved_path = script_path.resolve()
|
||||
|
||||
# Security check: Ensure resolved path is still in expected location
|
||||
expected_parent = Path(__file__).parent.parent.parent.resolve()
|
||||
if resolved_path.parent != expected_parent:
|
||||
error_msg = f"Security check failed: Script path outside expected directory"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
logger.info(f"Executing GPU reset script: {resolved_path}")
|
||||
logger.warning("This will temporarily interrupt all GPU operations")
|
||||
|
||||
try:
|
||||
# Execute reset script with sudo
|
||||
# Using list form (not shell=True) prevents shell injection
|
||||
result = subprocess.run(
|
||||
['sudo', str(script_path)],
|
||||
['sudo', str(resolved_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30 # 30 second timeout
|
||||
timeout=30, # 30 second timeout
|
||||
shell=False # Explicitly disable shell to prevent injection
|
||||
)
|
||||
|
||||
# Log script output
|
||||
|
||||
@@ -14,14 +14,23 @@ import uuid
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
from typing import Optional, List, Dict, Deque
|
||||
from collections import deque
|
||||
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
from core.transcriber import transcribe_audio
|
||||
from core.job_repository import JobRepository
|
||||
from utils.audio_processor import validate_audio_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
DEFAULT_JOB_TTL_HOURS = 24 # How long to keep completed jobs in memory
|
||||
GPU_HEALTH_CACHE_TTL_SECONDS = 30 # Cache GPU health check results
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # Run TTL cleanup every hour (1 hour)
|
||||
JOB_TIMEOUT_WARNING_SECONDS = 600 # Warn if job takes > 10 minutes
|
||||
JOB_TIMEOUT_MAX_SECONDS = 3600 # Maximum 1 hour per job
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
"""Job status enumeration."""
|
||||
@@ -84,6 +93,10 @@ class Job:
|
||||
"processing_time_seconds": self.processing_time_seconds,
|
||||
}
|
||||
|
||||
def mark_for_persistence(self, repository):
|
||||
"""Mark job as dirty for write-behind persistence."""
|
||||
repository.mark_dirty(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> 'Job':
|
||||
"""Deserialize from dictionary."""
|
||||
@@ -110,42 +123,68 @@ class Job:
|
||||
processing_time_seconds=data.get("processing_time_seconds"),
|
||||
)
|
||||
|
||||
def save_to_disk(self, metadata_dir: str):
|
||||
"""Save job metadata to {metadata_dir}/{job_id}.json"""
|
||||
os.makedirs(metadata_dir, exist_ok=True)
|
||||
filepath = os.path.join(metadata_dir, f"{self.job_id}.json")
|
||||
|
||||
try:
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(self.to_dict(), f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save job {self.job_id} to disk: {e}")
|
||||
|
||||
|
||||
class JobQueue:
|
||||
"""Manages job queue with background worker."""
|
||||
"""
|
||||
Manages job queue with background worker.
|
||||
|
||||
THREAD SAFETY & LOCK ORDERING
|
||||
==============================
|
||||
This class uses multiple locks to protect shared state. To prevent deadlocks,
|
||||
all code MUST follow this strict lock ordering:
|
||||
|
||||
LOCK HIERARCHY (acquire in this order):
|
||||
1. _jobs_lock - Protects _jobs dict and _current_job_id
|
||||
2. _queue_positions_lock - Protects _queued_job_ids deque
|
||||
|
||||
RULES:
|
||||
- NEVER acquire _jobs_lock while holding _queue_positions_lock
|
||||
- Always release locks in reverse order of acquisition
|
||||
- Keep lock hold time minimal - release before I/O operations
|
||||
- Use snapshot/copy pattern when data must cross lock boundaries
|
||||
|
||||
CRITICAL METHODS:
|
||||
- _calculate_queue_positions(): Uses snapshot pattern to avoid nested locks
|
||||
- submit_job(): Acquires locks separately, never nested
|
||||
- _worker_loop(): Acquires locks separately in correct order
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_queue_size: int = 100,
|
||||
metadata_dir: str = "/outputs/jobs"):
|
||||
metadata_dir: str = "/outputs/jobs",
|
||||
job_ttl_hours: int = 24):
|
||||
"""
|
||||
Initialize job queue.
|
||||
|
||||
Args:
|
||||
max_queue_size: Maximum number of jobs in queue
|
||||
metadata_dir: Directory to store job metadata JSON files
|
||||
job_ttl_hours: Hours to keep completed/failed jobs before cleanup
|
||||
"""
|
||||
self._queue = queue.Queue(maxsize=max_queue_size)
|
||||
self._jobs: Dict[str, Job] = {}
|
||||
self._metadata_dir = metadata_dir
|
||||
self._repository = JobRepository(
|
||||
metadata_dir=metadata_dir,
|
||||
job_ttl_hours=job_ttl_hours
|
||||
)
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
self._current_job_id: Optional[str] = None
|
||||
self._lock = threading.RLock() # Use RLock to allow re-entrant locking
|
||||
self._jobs_lock = threading.Lock() # Lock for _jobs dict
|
||||
self._queue_positions_lock = threading.Lock() # Lock for position tracking
|
||||
self._max_queue_size = max_queue_size
|
||||
|
||||
# Create metadata directory
|
||||
os.makedirs(metadata_dir, exist_ok=True)
|
||||
# Maintain ordered queue for O(1) position lookups
|
||||
# Deque of job_ids in queue order (FIFO)
|
||||
self._queued_job_ids: Deque[str] = deque()
|
||||
|
||||
# TTL cleanup tracking
|
||||
self._last_cleanup_time = datetime.utcnow()
|
||||
|
||||
# GPU health check caching
|
||||
self._gpu_health_cache: Optional[any] = None
|
||||
self._gpu_health_cache_time: Optional[datetime] = None
|
||||
self._gpu_health_cache_ttl_seconds = GPU_HEALTH_CACHE_TTL_SECONDS
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -158,6 +197,9 @@ class JobQueue:
|
||||
|
||||
logger.info(f"Starting job queue (max size: {self._max_queue_size})")
|
||||
|
||||
# Start repository flush thread
|
||||
self._repository.start()
|
||||
|
||||
# Load existing jobs from disk
|
||||
self._load_jobs_from_disk()
|
||||
|
||||
@@ -187,6 +229,10 @@ class JobQueue:
|
||||
self._worker_thread.join(timeout=1.0)
|
||||
|
||||
self._worker_thread = None
|
||||
|
||||
# Stop repository and flush pending writes
|
||||
self._repository.stop(flush_pending=True)
|
||||
|
||||
logger.info("Job queue worker stopped")
|
||||
|
||||
def submit_job(self,
|
||||
@@ -224,19 +270,42 @@ class JobQueue:
|
||||
|
||||
# 2. Check GPU health (GPU required for all devices since this is GPU-only service)
|
||||
# Both device="cuda" and device="auto" require GPU
|
||||
# Use cached health check result if available (30s TTL)
|
||||
if device == "cuda" or device == "auto":
|
||||
try:
|
||||
logger.info("Running GPU health check before job submission")
|
||||
# Use expected_device to match what user requested
|
||||
expected = "cuda" if device == "cuda" else "auto"
|
||||
health_status = check_gpu_health_with_reset(expected_device=expected, auto_reset=True)
|
||||
# Check cache first
|
||||
now = datetime.utcnow()
|
||||
cache_valid = (
|
||||
self._gpu_health_cache is not None and
|
||||
self._gpu_health_cache_time is not None and
|
||||
(now - self._gpu_health_cache_time).total_seconds() < self._gpu_health_cache_ttl_seconds
|
||||
)
|
||||
|
||||
if cache_valid:
|
||||
logger.debug("Using cached GPU health check result")
|
||||
health_status = self._gpu_health_cache
|
||||
else:
|
||||
logger.info("Running GPU health check before job submission")
|
||||
# Use expected_device to match what user requested
|
||||
expected = "cuda" if device == "cuda" else "auto"
|
||||
health_status = check_gpu_health_with_reset(expected_device=expected, auto_reset=True)
|
||||
|
||||
# Cache the result
|
||||
self._gpu_health_cache = health_status
|
||||
self._gpu_health_cache_time = now
|
||||
logger.info("GPU health check passed and cached")
|
||||
|
||||
if not health_status.gpu_working:
|
||||
# Invalidate cache on failure
|
||||
self._gpu_health_cache = None
|
||||
self._gpu_health_cache_time = None
|
||||
|
||||
raise RuntimeError(
|
||||
f"GPU device required but not available. "
|
||||
f"GPU check failed: {health_status.error}. "
|
||||
f"This service is configured for GPU-only operation."
|
||||
)
|
||||
logger.info("GPU health check passed")
|
||||
|
||||
except RuntimeError as e:
|
||||
# Re-raise GPU health errors
|
||||
logger.error(f"Job rejected due to GPU health check failure: {e}")
|
||||
@@ -250,6 +319,7 @@ class JobQueue:
|
||||
|
||||
# 3. Generate job_id
|
||||
job_id = str(uuid.uuid4())
|
||||
logger.debug(f"Generated job_id: {job_id}")
|
||||
|
||||
# 4. Create Job object
|
||||
job = Job(
|
||||
@@ -273,34 +343,88 @@ class JobQueue:
|
||||
error=None,
|
||||
processing_time_seconds=None,
|
||||
)
|
||||
logger.debug(f"Created Job object for {job_id}")
|
||||
|
||||
# 5. Add to queue (raises queue.Full if full)
|
||||
logger.debug(f"Attempting to add job {job_id} to queue (current size: {self._queue.qsize()})")
|
||||
try:
|
||||
self._queue.put_nowait(job)
|
||||
logger.debug(f"Successfully added job {job_id} to queue")
|
||||
except queue.Full:
|
||||
raise queue.Full(
|
||||
f"Job queue is full (max size: {self._max_queue_size}). "
|
||||
f"Please try again later."
|
||||
)
|
||||
|
||||
# 6. Add to jobs dict and save to disk
|
||||
with self._lock:
|
||||
# 6. Add to jobs dict and update queue tracking
|
||||
# LOCK ORDERING: Always acquire _jobs_lock before _queue_positions_lock
|
||||
logger.debug(f"Acquiring _jobs_lock for job {job_id}")
|
||||
with self._jobs_lock:
|
||||
self._jobs[job_id] = job
|
||||
logger.debug(f"Added job {job_id} to _jobs dict")
|
||||
|
||||
# Update queue positions (separate lock to avoid deadlock)
|
||||
logger.debug(f"Acquiring _queue_positions_lock for job {job_id}")
|
||||
with self._queue_positions_lock:
|
||||
# Add to ordered queue for O(1) position tracking
|
||||
self._queued_job_ids.append(job_id)
|
||||
logger.debug(f"Added job {job_id} to _queued_job_ids, calling _calculate_queue_positions()")
|
||||
# Calculate positions - this will briefly acquire _jobs_lock internally
|
||||
self._calculate_queue_positions()
|
||||
job.save_to_disk(self._metadata_dir)
|
||||
logger.debug(f"Finished _calculate_queue_positions() for job {job_id}")
|
||||
|
||||
# Capture return data (need to re-acquire lock after position calculation)
|
||||
logger.debug(f"Re-acquiring _jobs_lock to capture return data for job {job_id}")
|
||||
with self._jobs_lock:
|
||||
return_data = {
|
||||
"job_id": job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position,
|
||||
"created_at": job.created_at.isoformat() + "Z"
|
||||
}
|
||||
queue_position = job.queue_position
|
||||
logger.debug(f"Captured return data for job {job_id}, queue_position={queue_position}")
|
||||
|
||||
# Mark for async persistence (outside lock to avoid blocking)
|
||||
logger.debug(f"Marking job {job_id} for persistence")
|
||||
job.mark_for_persistence(self._repository)
|
||||
logger.debug(f"Job {job_id} marked for persistence successfully")
|
||||
|
||||
logger.info(
|
||||
f"Job {job_id} submitted: {audio_path} "
|
||||
f"(queue position: {job.queue_position})"
|
||||
f"(queue position: {queue_position})"
|
||||
)
|
||||
|
||||
# Run periodic TTL cleanup (rate-limited to at most once per CLEANUP_INTERVAL_SECONDS)
|
||||
self._maybe_cleanup_old_jobs()
|
||||
|
||||
# 7. Return job info
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position,
|
||||
"created_at": job.created_at.isoformat() + "Z"
|
||||
}
|
||||
return return_data
|
||||
|
||||
def _maybe_cleanup_old_jobs(self):
|
||||
"""Periodically cleanup old completed/failed jobs based on TTL."""
|
||||
# Only run cleanup every hour
|
||||
now = datetime.utcnow()
|
||||
if (now - self._last_cleanup_time).total_seconds() < CLEANUP_INTERVAL_SECONDS:
|
||||
return
|
||||
|
||||
self._last_cleanup_time = now
|
||||
|
||||
# Get jobs snapshot
|
||||
with self._jobs_lock:
|
||||
jobs_snapshot = dict(self._jobs)
|
||||
|
||||
# Run cleanup (removes from disk)
|
||||
deleted_job_ids = self._repository.cleanup_old_jobs(jobs_snapshot)
|
||||
|
||||
# Remove from in-memory dict
|
||||
if deleted_job_ids:
|
||||
with self._jobs_lock:
|
||||
for job_id in deleted_job_ids:
|
||||
if job_id in self._jobs:
|
||||
del self._jobs[job_id]
|
||||
|
||||
logger.info(f"TTL cleanup removed {len(deleted_job_ids)} old jobs")
|
||||
|
||||
def get_job_status(self, job_id: str) -> dict:
|
||||
"""
|
||||
@@ -312,23 +436,38 @@ class JobQueue:
|
||||
Raises:
|
||||
KeyError: If job_id not found
|
||||
"""
|
||||
with self._lock:
|
||||
# Copy job data inside lock, release before building response
|
||||
with self._jobs_lock:
|
||||
if job_id not in self._jobs:
|
||||
raise KeyError(f"Job {job_id} not found")
|
||||
|
||||
job = self._jobs[job_id]
|
||||
return {
|
||||
# Copy all fields we need while holding lock
|
||||
job_data = {
|
||||
"job_id": job.job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position if job.status == JobStatus.QUEUED else None,
|
||||
"created_at": job.created_at.isoformat() + "Z",
|
||||
"started_at": job.started_at.isoformat() + "Z" if job.started_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + "Z" if job.completed_at else None,
|
||||
"created_at": job.created_at,
|
||||
"started_at": job.started_at,
|
||||
"completed_at": job.completed_at,
|
||||
"result_path": job.result_path,
|
||||
"error": job.error,
|
||||
"processing_time_seconds": job.processing_time_seconds,
|
||||
}
|
||||
|
||||
# Format response outside lock
|
||||
return {
|
||||
"job_id": job_data["job_id"],
|
||||
"status": job_data["status"],
|
||||
"queue_position": job_data["queue_position"],
|
||||
"created_at": job_data["created_at"].isoformat() + "Z",
|
||||
"started_at": job_data["started_at"].isoformat() + "Z" if job_data["started_at"] else None,
|
||||
"completed_at": job_data["completed_at"].isoformat() + "Z" if job_data["completed_at"] else None,
|
||||
"result_path": job_data["result_path"],
|
||||
"error": job_data["error"],
|
||||
"processing_time_seconds": job_data["processing_time_seconds"],
|
||||
}
|
||||
|
||||
def get_job_result(self, job_id: str) -> str:
|
||||
"""
|
||||
Get transcription result text for completed job.
|
||||
@@ -341,7 +480,8 @@ class JobQueue:
|
||||
ValueError: If job not completed
|
||||
FileNotFoundError: If result file missing
|
||||
"""
|
||||
with self._lock:
|
||||
# Copy necessary data inside lock, then release before file I/O
|
||||
with self._jobs_lock:
|
||||
if job_id not in self._jobs:
|
||||
raise KeyError(f"Job {job_id} not found")
|
||||
|
||||
@@ -356,13 +496,16 @@ class JobQueue:
|
||||
if not job.result_path:
|
||||
raise FileNotFoundError(f"Job {job_id} has no result path")
|
||||
|
||||
# Copy result_path while holding lock
|
||||
result_path = job.result_path
|
||||
|
||||
# Read result file (outside lock to avoid blocking)
|
||||
if not os.path.exists(job.result_path):
|
||||
if not os.path.exists(result_path):
|
||||
raise FileNotFoundError(
|
||||
f"Result file not found: {job.result_path}"
|
||||
f"Result file not found: {result_path}"
|
||||
)
|
||||
|
||||
with open(job.result_path, 'r', encoding='utf-8') as f:
|
||||
with open(result_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def list_jobs(self,
|
||||
@@ -378,7 +521,7 @@ class JobQueue:
|
||||
Returns:
|
||||
List of job status dictionaries
|
||||
"""
|
||||
with self._lock:
|
||||
with self._jobs_lock:
|
||||
jobs = list(self._jobs.values())
|
||||
|
||||
# Filter by status
|
||||
@@ -391,8 +534,23 @@ class JobQueue:
|
||||
# Limit results
|
||||
jobs = jobs[:limit]
|
||||
|
||||
# Convert to dict
|
||||
return [self.get_job_status(j.job_id) for j in jobs]
|
||||
# Convert to dict directly (avoid N+1 by building response in single pass)
|
||||
# This eliminates the need to call get_job_status() for each job
|
||||
result = []
|
||||
for job in jobs:
|
||||
result.append({
|
||||
"job_id": job.job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position if job.status == JobStatus.QUEUED else None,
|
||||
"created_at": job.created_at.isoformat() + "Z",
|
||||
"started_at": job.started_at.isoformat() + "Z" if job.started_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + "Z" if job.completed_at else None,
|
||||
"result_path": job.result_path,
|
||||
"error": job.error,
|
||||
"processing_time_seconds": job.processing_time_seconds,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def _worker_loop(self):
|
||||
"""
|
||||
@@ -410,34 +568,83 @@ class JobQueue:
|
||||
continue
|
||||
|
||||
# Update job status to running
|
||||
with self._lock:
|
||||
with self._jobs_lock:
|
||||
self._current_job_id = job.job_id
|
||||
job.status = JobStatus.RUNNING
|
||||
job.started_at = datetime.utcnow()
|
||||
job.queue_position = 0
|
||||
self._calculate_queue_positions()
|
||||
job.save_to_disk(self._metadata_dir)
|
||||
|
||||
with self._queue_positions_lock:
|
||||
# Remove from ordered queue when starting processing
|
||||
if job.job_id in self._queued_job_ids:
|
||||
self._queued_job_ids.remove(job.job_id)
|
||||
# Recalculate positions since we removed a job
|
||||
self._calculate_queue_positions()
|
||||
|
||||
# Mark for async persistence (outside lock)
|
||||
job.mark_for_persistence(self._repository)
|
||||
|
||||
logger.info(f"Job {job.job_id} started processing")
|
||||
|
||||
# Process job
|
||||
# Process job with timeout tracking
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = transcribe_audio(
|
||||
audio_path=job.audio_path,
|
||||
model_name=job.model_name,
|
||||
device=job.device,
|
||||
compute_type=job.compute_type,
|
||||
language=job.language,
|
||||
output_format=job.output_format,
|
||||
beam_size=job.beam_size,
|
||||
temperature=job.temperature,
|
||||
initial_prompt=job.initial_prompt,
|
||||
output_directory=job.output_directory
|
||||
)
|
||||
# Start a monitoring thread for timeout warnings
|
||||
timeout_event = threading.Event()
|
||||
|
||||
def timeout_monitor():
|
||||
"""Monitor job execution time and emit warnings."""
|
||||
# Wait for warning threshold
|
||||
if timeout_event.wait(JOB_TIMEOUT_WARNING_SECONDS):
|
||||
return # Job completed before warning threshold
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.warning(
|
||||
f"Job {job.job_id} is taking longer than expected: "
|
||||
f"{elapsed:.1f}s elapsed (threshold: {JOB_TIMEOUT_WARNING_SECONDS}s)"
|
||||
)
|
||||
|
||||
# Wait for max timeout
|
||||
remaining = JOB_TIMEOUT_MAX_SECONDS - elapsed
|
||||
if remaining > 0:
|
||||
if timeout_event.wait(remaining):
|
||||
return # Job completed before max timeout
|
||||
|
||||
# Job exceeded max timeout
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(
|
||||
f"Job {job.job_id} exceeded maximum timeout: "
|
||||
f"{elapsed:.1f}s elapsed (max: {JOB_TIMEOUT_MAX_SECONDS}s)"
|
||||
)
|
||||
|
||||
monitor_thread = threading.Thread(target=timeout_monitor, daemon=True)
|
||||
monitor_thread.start()
|
||||
|
||||
try:
|
||||
result = transcribe_audio(
|
||||
audio_path=job.audio_path,
|
||||
model_name=job.model_name,
|
||||
device=job.device,
|
||||
compute_type=job.compute_type,
|
||||
language=job.language,
|
||||
output_format=job.output_format,
|
||||
beam_size=job.beam_size,
|
||||
temperature=job.temperature,
|
||||
initial_prompt=job.initial_prompt,
|
||||
output_directory=job.output_directory
|
||||
)
|
||||
finally:
|
||||
# Signal timeout monitor to stop
|
||||
timeout_event.set()
|
||||
|
||||
# Check if job exceeded hard timeout
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > JOB_TIMEOUT_MAX_SECONDS:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = f"Job exceeded maximum timeout ({JOB_TIMEOUT_MAX_SECONDS}s): {elapsed:.1f}s elapsed"
|
||||
logger.error(f"Job {job.job_id} timed out: {job.error}")
|
||||
# Parse result
|
||||
if "saved to:" in result:
|
||||
elif "saved to:" in result:
|
||||
job.result_path = result.split("saved to:")[1].strip()
|
||||
job.status = JobStatus.COMPLETED
|
||||
logger.info(
|
||||
@@ -457,11 +664,14 @@ class JobQueue:
|
||||
# Update job completion info
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.processing_time_seconds = time.time() - start_time
|
||||
job.save_to_disk(self._metadata_dir)
|
||||
|
||||
with self._lock:
|
||||
with self._jobs_lock:
|
||||
self._current_job_id = None
|
||||
self._calculate_queue_positions()
|
||||
|
||||
# No need to recalculate positions here - job already removed from queue
|
||||
|
||||
# Mark for async persistence (outside lock)
|
||||
job.mark_for_persistence(self._repository)
|
||||
|
||||
self._queue.task_done()
|
||||
|
||||
@@ -478,22 +688,17 @@ class JobQueue:
|
||||
|
||||
def _load_jobs_from_disk(self):
|
||||
"""Load existing job metadata from disk on startup."""
|
||||
if not os.path.exists(self._metadata_dir):
|
||||
logger.info("No existing job metadata directory found")
|
||||
logger.info("Loading jobs from disk...")
|
||||
|
||||
job_data_list = self._repository.load_all_jobs()
|
||||
|
||||
if not job_data_list:
|
||||
logger.info("No existing jobs found on disk")
|
||||
return
|
||||
|
||||
logger.info(f"Loading jobs from {self._metadata_dir}")
|
||||
|
||||
loaded_count = 0
|
||||
for filename in os.listdir(self._metadata_dir):
|
||||
if not filename.endswith(".json"):
|
||||
continue
|
||||
|
||||
filepath = os.path.join(self._metadata_dir, filename)
|
||||
for data in job_data_list:
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
job = Job.from_dict(data)
|
||||
|
||||
# Handle jobs that were running when server stopped
|
||||
@@ -501,7 +706,7 @@ class JobQueue:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = "Server restarted while job was running"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.save_to_disk(self._metadata_dir)
|
||||
job.mark_for_persistence(self._repository)
|
||||
logger.warning(
|
||||
f"Job {job.job_id} was running during shutdown, "
|
||||
f"marking as failed"
|
||||
@@ -511,35 +716,58 @@ class JobQueue:
|
||||
elif job.status == JobStatus.QUEUED:
|
||||
try:
|
||||
self._queue.put_nowait(job)
|
||||
# Add to ordered tracking deque
|
||||
with self._queue_positions_lock:
|
||||
self._queued_job_ids.append(job.job_id)
|
||||
logger.info(f"Re-queued job {job.job_id} from disk")
|
||||
except queue.Full:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = "Queue full on server restart"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.save_to_disk(self._metadata_dir)
|
||||
job.mark_for_persistence(self._repository)
|
||||
logger.warning(
|
||||
f"Job {job.job_id} could not be re-queued (queue full)"
|
||||
)
|
||||
|
||||
self._jobs[job.job_id] = job
|
||||
with self._jobs_lock:
|
||||
self._jobs[job.job_id] = job
|
||||
loaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load job from {filepath}: {e}")
|
||||
logger.error(f"Failed to load job: {e}")
|
||||
|
||||
logger.info(f"Loaded {loaded_count} jobs from disk")
|
||||
self._calculate_queue_positions()
|
||||
|
||||
with self._queue_positions_lock:
|
||||
self._calculate_queue_positions()
|
||||
|
||||
def _calculate_queue_positions(self):
    """
    Update queue_position for all queued jobs.

    Walks ``self._queued_job_ids`` in its existing (FIFO) order, drops
    every ID whose job is missing from ``self._jobs`` or is no longer in
    QUEUED status, and assigns 1-based positions to the survivors.

    IMPORTANT: Must be called with _queue_positions_lock held.
    This method acquires _jobs_lock itself (exactly once), so the caller
    must not already hold _jobs_lock unless that lock is reentrant.
    """
    # Filter and renumber under a single acquisition of _jobs_lock so the
    # status that was checked is the status that gets a position; taking
    # the lock separately per job would let a job leave QUEUED between
    # the check and the assignment.
    with self._jobs_lock:
        still_queued = deque(
            job_id for job_id in self._queued_job_ids
            if job_id in self._jobs
            and self._jobs[job_id].status == JobStatus.QUEUED
        )
        for position, job_id in enumerate(still_queued, start=1):
            self._jobs[job_id].queue_position = position

    # Replace the tracking deque with the pruned one (caller holds
    # _queue_positions_lock, which guards this attribute).
    self._queued_job_ids = still_queued
|
||||
|
||||
278
src/core/job_repository.py
Normal file
278
src/core/job_repository.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Job persistence layer with async I/O and write-behind caching.
|
||||
|
||||
Handles disk storage for job metadata with batched writes to reduce I/O overhead.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import logging
import os
import threading
from collections import deque
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)

# Constants
DEFAULT_BATCH_INTERVAL_SECONDS = 1.0  # Seconds between background flushes
DEFAULT_JOB_TTL_HOURS = 24  # Hours a completed/failed job is kept on disk
MAX_DIRTY_JOBS_BEFORE_FLUSH = 50  # Pending-write count that forces a flush


class JobRepository:
    """
    Manages job persistence with write-behind caching and TTL-based cleanup.

    Features:
    - Background-thread disk I/O so callers only pay for a dict insert
    - Batched writes (flush every N seconds or M jobs)
    - Atomic file replacement so a crash never leaves truncated JSON
    - TTL-based job cleanup (removes old completed/failed jobs)
    - Thread-safe operation

    Job objects are duck-typed: they must expose ``job_id`` and
    ``to_dict()``; ``cleanup_old_jobs`` additionally reads ``status.value``
    and ``completed_at``.
    """

    def __init__(
        self,
        metadata_dir: str = "/outputs/jobs",
        batch_interval_seconds: float = DEFAULT_BATCH_INTERVAL_SECONDS,
        job_ttl_hours: int = DEFAULT_JOB_TTL_HOURS,
        enable_ttl_cleanup: bool = True
    ):
        """
        Initialize job repository.

        Args:
            metadata_dir: Directory for job metadata JSON files
            batch_interval_seconds: How often to flush dirty jobs to disk
            job_ttl_hours: Hours to keep completed/failed jobs before cleanup
            enable_ttl_cleanup: Enable automatic TTL-based cleanup
        """
        self._metadata_dir = Path(metadata_dir)
        self._batch_interval = batch_interval_seconds
        self._job_ttl = timedelta(hours=job_ttl_hours)
        self._enable_ttl_cleanup = enable_ttl_cleanup

        # Dirty jobs pending flush (job_id -> Job). Guarded by _dirty_lock,
        # which is a plain (non-reentrant) Lock.
        self._dirty_jobs: Dict[str, Any] = {}
        self._dirty_lock = threading.Lock()

        # Background flush thread
        self._flush_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        # Create metadata directory
        self._metadata_dir.mkdir(parents=True, exist_ok=True)

        logger.info(
            f"JobRepository initialized: dir={metadata_dir}, "
            f"batch_interval={batch_interval_seconds}s, ttl={job_ttl_hours}h"
        )

    def start(self):
        """Start background flush thread (no-op if it is already running)."""
        if self._flush_thread is not None and self._flush_thread.is_alive():
            logger.warning("JobRepository flush thread already running")
            return

        logger.info("Starting JobRepository background flush thread")
        self._stop_event.clear()
        self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
        self._flush_thread.start()

    def stop(self, flush_pending: bool = True):
        """
        Stop background flush thread.

        Args:
            flush_pending: If True, flush all pending writes before returning
        """
        thread = self._flush_thread
        if thread is None and not flush_pending:
            return

        logger.info(f"Stopping JobRepository (flush_pending={flush_pending})")

        if thread is not None:
            self._stop_event.set()
            thread.join(timeout=5.0)
            self._flush_thread = None

        # Flush AFTER the thread has stopped so the final flush also picks
        # up jobs marked dirty while we were shutting down, and so pending
        # writes are not lost when start() was never called.
        if flush_pending:
            self.flush_dirty_jobs()

        logger.info("JobRepository stopped")

    def mark_dirty(self, job: Any):
        """
        Mark a job as dirty (needs to be written to disk).

        Args:
            job: Job object to persist
        """
        with self._dirty_lock:
            self._dirty_jobs[job.job_id] = job

            # Flush immediately if too many dirty jobs
            if len(self._dirty_jobs) >= MAX_DIRTY_JOBS_BEFORE_FLUSH:
                logger.debug(
                    f"Dirty job threshold reached ({len(self._dirty_jobs)}), "
                    f"triggering immediate flush"
                )
                self._flush_dirty_jobs_sync()

    def flush_dirty_jobs(self):
        """Flush all dirty jobs to disk (synchronous)."""
        with self._dirty_lock:
            self._flush_dirty_jobs_sync()

    def _flush_dirty_jobs_sync(self):
        """
        Internal: Flush dirty jobs to disk.

        Must be called with _dirty_lock held. Jobs whose write fails are put
        back into the dirty set so the next flush retries them.
        """
        if not self._dirty_jobs:
            return

        jobs_to_flush = list(self._dirty_jobs.values())
        self._dirty_jobs.clear()

        flush_count = 0
        for job in jobs_to_flush:
            try:
                self._write_job_to_disk(job)
                flush_count += 1
            except Exception as e:
                logger.error(f"Failed to flush job {job.job_id}: {e}")
                # Re-add for retry. The caller already holds _dirty_lock and
                # it is not reentrant, so mutate the dict directly; taking
                # the lock again here would deadlock this thread.
                self._dirty_jobs[job.job_id] = job

        if flush_count > 0:
            logger.debug(f"Flushed {flush_count} jobs to disk")

    def _write_job_to_disk(self, job: Any):
        """
        Write single job to disk atomically.

        The JSON is written to a sibling ``<job_id>.json.tmp`` file and then
        moved over the final path with os.replace, so readers never observe
        a half-written file. The temp name does not match the ``*.json``
        glob used by load_all_jobs. Writers are serialized by _dirty_lock
        (this is only called from _flush_dirty_jobs_sync), so the fixed temp
        name cannot collide.

        Raises:
            Exception: Whatever serialization or file I/O raised (after
                logging it and removing the temp file)
        """
        filepath = self._metadata_dir / f"{job.job_id}.json"
        tmp_path = self._metadata_dir / f"{job.job_id}.json.tmp"

        try:
            # Serialize first so a failing to_dict() never creates a file
            payload = json.dumps(job.to_dict(), indent=2)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(payload)
            os.replace(tmp_path, filepath)
        except Exception as e:
            logger.error(f"Failed to write job {job.job_id} to {filepath}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                # Best effort: the temp file may never have been created
                pass
            raise

    def load_job(self, job_id: str) -> Optional[Dict]:
        """
        Load job from disk.

        Args:
            job_id: Job ID to load

        Returns:
            Job dictionary or None if not found (or unreadable)
        """
        filepath = self._metadata_dir / f"{job_id}.json"

        if not filepath.exists():
            return None

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Failed to load job {job_id} from {filepath}: {e}")
            return None

    def load_all_jobs(self) -> List[Dict]:
        """
        Load all jobs from disk.

        Files that cannot be read or parsed are logged and skipped.

        Returns:
            List of job dictionaries
        """
        jobs = []

        if not self._metadata_dir.exists():
            return jobs

        for filepath in self._metadata_dir.glob("*.json"):
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    job_data = json.load(f)
                jobs.append(job_data)
            except Exception as e:
                logger.error(f"Failed to load job from {filepath}: {e}")

        logger.info(f"Loaded {len(jobs)} jobs from disk")
        return jobs

    def delete_job(self, job_id: str):
        """
        Delete job from disk.

        Also drops any pending write for the job; otherwise the next flush
        would re-create the file that was just removed. Errors are logged,
        not raised.

        Args:
            job_id: Job ID to delete
        """
        filepath = self._metadata_dir / f"{job_id}.json"

        # Hold the lock across the unlink so an in-progress flush cannot
        # write the file back between the two steps.
        with self._dirty_lock:
            self._dirty_jobs.pop(job_id, None)
            try:
                if filepath.exists():
                    filepath.unlink()
                    logger.debug(f"Deleted job {job_id} from disk")
            except Exception as e:
                logger.error(f"Failed to delete job {job_id}: {e}")

    def cleanup_old_jobs(self, jobs_dict: Dict[str, Any]) -> List[str]:
        """
        Clean up old completed/failed jobs based on TTL.

        Only the on-disk metadata is removed; the caller is expected to drop
        the returned IDs from its own in-memory structures.

        Args:
            jobs_dict: Dictionary of job_id -> Job objects to check

        Returns:
            List of job IDs whose TTL expired (empty when cleanup is disabled)
        """
        if not self._enable_ttl_cleanup:
            return []

        now = datetime.utcnow()
        expired = []  # (job_id, age) pairs

        # Iterate over a snapshot so a concurrent insert/removal in the
        # caller's dict cannot abort the scan.
        for job_id, job in list(jobs_dict.items()):
            # Only cleanup completed/failed jobs
            if job.status.value not in ["completed", "failed"]:
                continue

            # Jobs without a completion timestamp have no age to compare
            if job.completed_at is None:
                continue

            age = now - job.completed_at
            if age > self._job_ttl:
                expired.append((job_id, age))

        # delete_job handles and logs its own I/O errors
        for job_id, age in expired:
            self.delete_job(job_id)
            logger.info(
                f"Cleaned up old job {job_id} "
                f"(age: {age.total_seconds() / 3600:.1f}h)"
            )

        return [job_id for job_id, _ in expired]

    def _flush_loop(self):
        """Background thread for periodic flush."""
        logger.info("JobRepository flush loop started")

        # Event.wait returns True once stop() sets the event, ending the loop
        while not self._stop_event.wait(timeout=self._batch_interval):
            try:
                self.flush_dirty_jobs()
            except Exception as e:
                logger.error(f"Error in flush loop: {e}")

        logger.info("JobRepository flush loop stopped")
|
||||
@@ -7,7 +7,8 @@ Responsible for loading, caching, and managing Whisper models
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, OrderedDict
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
from faster_whisper import WhisperModel, BatchedInferencePipeline
|
||||
|
||||
@@ -22,8 +23,10 @@ except ImportError:
|
||||
logger.warning("GPU health check with reset not available")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
# Global model instance cache
|
||||
model_instances = {}
|
||||
# Global model instance cache with LRU eviction
|
||||
# Maximum number of models to keep in memory (prevents OOM)
|
||||
MAX_CACHED_MODELS = int(os.getenv("MAX_CACHED_MODELS", "3"))
|
||||
model_instances: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
|
||||
def test_gpu_driver():
|
||||
"""Simple GPU driver test"""
|
||||
@@ -111,9 +114,11 @@ def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[s
|
||||
# Generate model key
|
||||
model_key = f"{model_name}_{device}_{compute_type}"
|
||||
|
||||
# If model is already instantiated, return directly
|
||||
# If model is already instantiated, move to end (mark as recently used) and return
|
||||
if model_key in model_instances:
|
||||
logger.info(f"Using cached model instance: {model_key}")
|
||||
# Move to end for LRU
|
||||
model_instances.move_to_end(model_key)
|
||||
return model_instances[model_key]
|
||||
|
||||
# Test GPU driver before loading model and clean
|
||||
@@ -182,8 +187,33 @@ def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[s
|
||||
'load_time': time.time()
|
||||
}
|
||||
|
||||
# Cache instance
|
||||
# Implement LRU eviction before adding new model
|
||||
if len(model_instances) >= MAX_CACHED_MODELS:
|
||||
# Remove oldest (least recently used) model
|
||||
evicted_key, evicted_model = model_instances.popitem(last=False)
|
||||
logger.info(
|
||||
f"Evicting cached model (LRU): {evicted_key} "
|
||||
f"(cache limit: {MAX_CACHED_MODELS})"
|
||||
)
|
||||
|
||||
# Clean up GPU memory if it was a CUDA model
|
||||
if evicted_model['device'] == 'cuda':
|
||||
try:
|
||||
# Delete model references
|
||||
del evicted_model['model']
|
||||
if evicted_model['batched_model'] is not None:
|
||||
del evicted_model['batched_model']
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("GPU memory released for evicted model")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error cleaning up evicted model: {cleanup_error}")
|
||||
|
||||
# Cache instance (added to end of OrderedDict)
|
||||
model_instances[model_key] = result
|
||||
logger.info(
|
||||
f"Cached model: {model_key} "
|
||||
f"(cache size: {len(model_instances)}/{MAX_CACHED_MODELS})"
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -129,31 +129,10 @@ def transcribe_audio(
|
||||
logger.info("Using standard model for transcription...")
|
||||
segments, info = model_instance['model'].transcribe(audio_source, **options)
|
||||
|
||||
# Convert generator to list
|
||||
segment_list = list(segments)
|
||||
# Convert segments generator to list to release model resources
|
||||
segments = list(segments)
|
||||
|
||||
if not segment_list:
|
||||
return "Transcription failed, no results obtained"
|
||||
|
||||
# Record transcription information
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"Transcription completed, time used: {elapsed_time:.2f} seconds, detected language: {info.language}, audio length: {info.duration:.2f} seconds")
|
||||
|
||||
# Format transcription results based on output format
|
||||
output_format_lower = output_format.lower()
|
||||
|
||||
if output_format_lower == "vtt":
|
||||
transcription_result = format_vtt(segment_list)
|
||||
elif output_format_lower == "srt":
|
||||
transcription_result = format_srt(segment_list)
|
||||
elif output_format_lower == "txt":
|
||||
transcription_result = format_txt(segment_list)
|
||||
elif output_format_lower == "json":
|
||||
transcription_result = format_json(segment_list, info)
|
||||
else:
|
||||
raise ValueError(f"Unsupported output format: {output_format}. Supported formats: vtt, srt, txt, json")
|
||||
|
||||
# Determine output directory
|
||||
# Determine output directory and path early
|
||||
audio_dir = os.path.dirname(audio_path)
|
||||
audio_filename = os.path.splitext(os.path.basename(audio_path))[0]
|
||||
|
||||
@@ -170,14 +149,14 @@ def transcribe_audio(
|
||||
|
||||
# Generate filename with customizable format
|
||||
filename_parts = []
|
||||
|
||||
|
||||
# Add prefix if specified
|
||||
if FILENAME_PREFIX:
|
||||
filename_parts.append(FILENAME_PREFIX)
|
||||
|
||||
|
||||
# Add base filename
|
||||
filename_parts.append(audio_filename)
|
||||
|
||||
|
||||
# Add suffix if specified
|
||||
if FILENAME_SUFFIX:
|
||||
filename_parts.append(FILENAME_SUFFIX)
|
||||
@@ -188,24 +167,108 @@ def transcribe_audio(
|
||||
filename_parts.append(timestamp)
|
||||
|
||||
# Join parts and add extension
|
||||
output_format_lower = output_format.lower()
|
||||
base_name = "_".join(filename_parts)
|
||||
output_filename = f"{base_name}.{output_format_lower}"
|
||||
output_path = os.path.join(output_dir, output_filename)
|
||||
|
||||
# Write transcription results to file
|
||||
# Stream segments directly to file instead of loading all into memory
|
||||
# This prevents memory spikes with long audio files
|
||||
segment_count = 0
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(transcription_result)
|
||||
logger.info(f"Transcription results saved to: {output_path}")
|
||||
return f"Transcription successful, results saved to: {output_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save transcription results: {str(e)}")
|
||||
return f"Transcription successful, but failed to save results: {str(e)}"
|
||||
# Write format-specific header
|
||||
if output_format_lower == "vtt":
|
||||
f.write("WEBVTT\n\n")
|
||||
elif output_format_lower == "json":
|
||||
f.write('{"segments": [')
|
||||
|
||||
first_segment = True
|
||||
for segment in segments:
|
||||
segment_count += 1
|
||||
|
||||
# Format and write each segment immediately
|
||||
if output_format_lower == "vtt":
|
||||
start_time = format_time(segment.start)
|
||||
end_time = format_time(segment.end)
|
||||
f.write(f"{start_time} --> {end_time}\n{segment.text.strip()}\n\n")
|
||||
elif output_format_lower == "srt":
|
||||
start_time = format_time(segment.start).replace('.', ',')
|
||||
end_time = format_time(segment.end).replace('.', ',')
|
||||
f.write(f"{segment_count}\n{start_time} --> {end_time}\n{segment.text.strip()}\n\n")
|
||||
elif output_format_lower == "txt":
|
||||
f.write(segment.text.strip() + "\n")
|
||||
elif output_format_lower == "json":
|
||||
if not first_segment:
|
||||
f.write(',')
|
||||
import json as json_module
|
||||
segment_dict = {
|
||||
"start": segment.start,
|
||||
"end": segment.end,
|
||||
"text": segment.text.strip()
|
||||
}
|
||||
f.write(json_module.dumps(segment_dict))
|
||||
first_segment = False
|
||||
else:
|
||||
raise ValueError(f"Unsupported output format: {output_format}. Supported formats: vtt, srt, txt, json")
|
||||
|
||||
# Write format-specific footer
|
||||
if output_format_lower == "json":
|
||||
# Add metadata
|
||||
f.write(f'], "language": "{info.language}", "duration": {info.duration}}}')
|
||||
|
||||
except Exception as write_error:
|
||||
logger.error(f"Failed to write transcription during streaming: {str(write_error)}")
|
||||
# File handle automatically closed by context manager
|
||||
# Clean up partial file to prevent corrupted output
|
||||
if os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.info(f"Cleaned up partial file: {output_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup partial file {output_path}: {cleanup_error}")
|
||||
raise
|
||||
|
||||
if segment_count == 0:
|
||||
if info.duration < 1.0:
|
||||
logger.warning(f"No segments: audio too short ({info.duration:.2f}s)")
|
||||
return "Transcription failed: Audio too short (< 1 second)"
|
||||
else:
|
||||
logger.warning(
|
||||
f"No segments generated: duration={info.duration:.2f}s, "
|
||||
f"language={info.language}, vad_enabled=True"
|
||||
)
|
||||
return "Transcription failed: No speech detected (VAD filtered all segments)"
|
||||
|
||||
# Record transcription information
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Transcription completed, time used: {elapsed_time:.2f} seconds, "
|
||||
f"detected language: {info.language}, audio length: {info.duration:.2f} seconds, "
|
||||
f"segments: {segment_count}"
|
||||
)
|
||||
|
||||
# File already written via streaming above
|
||||
logger.info(f"Transcription results saved to: {output_path}")
|
||||
return f"Transcription successful, results saved to: {output_path}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {str(e)}")
|
||||
return f"Error occurred during transcription: {str(e)}"
|
||||
|
||||
finally:
|
||||
# Force GPU memory cleanup after transcription to prevent accumulation
|
||||
if device == "cuda":
|
||||
import torch
|
||||
import gc
|
||||
# Clear segments list to free memory
|
||||
segments = None
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
# Empty CUDA cache
|
||||
torch.cuda.empty_cache()
|
||||
logger.debug("GPU memory cleaned up after transcription")
|
||||
|
||||
|
||||
def batch_transcribe(
|
||||
audio_folder: str,
|
||||
|
||||
@@ -9,28 +9,74 @@ import sys
|
||||
import logging
|
||||
import queue as queue_module
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, List
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import json
|
||||
import aiofiles # Async file I/O
|
||||
|
||||
from core.model_manager import get_model_info
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
|
||||
from utils.startup import startup_sequence, cleanup_on_shutdown
|
||||
from utils.input_validation import (
|
||||
ValidationError,
|
||||
PathTraversalError,
|
||||
InvalidFileTypeError,
|
||||
FileSizeError,
|
||||
validate_beam_size,
|
||||
validate_temperature,
|
||||
validate_model_name,
|
||||
validate_device,
|
||||
validate_compute_type,
|
||||
validate_output_format,
|
||||
validate_filename_safe
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
UPLOAD_CHUNK_SIZE_BYTES = 8192 # 8KB chunks for streaming uploads
|
||||
GPU_TEST_SLOW_THRESHOLD_SECONDS = 2.0 # GPU health check performance threshold
|
||||
DISK_SPACE_BUFFER_PERCENT = 0.10 # Require 10% extra free space as buffer
|
||||
|
||||
# Global instances
|
||||
job_queue: Optional[JobQueue] = None
|
||||
health_monitor: Optional[HealthMonitor] = None
|
||||
|
||||
|
||||
def check_disk_space(path: str, required_bytes: int) -> None:
    """
    Check if sufficient disk space is available.

    The requirement is padded by DISK_SPACE_BUFFER_PERCENT before being
    compared with the free space reported for ``path``. If the free space
    cannot be determined at all (missing path, permission error, ...), the
    check is skipped with a warning rather than failing the request.

    Args:
        path: Path to check disk space for
        required_bytes: Required bytes

    Raises:
        IOError: If insufficient disk space
    """
    # Probe separately from the threshold check: IOError is an alias of
    # OSError, so catching/re-raising IOError around disk_usage() would also
    # propagate probe failures as if they were "insufficient space".
    try:
        usage = shutil.disk_usage(path)
    except Exception as e:
        logger.warning(f"Failed to check disk space: {e}")
        return

    required_with_buffer = required_bytes * (1.0 + DISK_SPACE_BUFFER_PERCENT)

    if usage.free < required_with_buffer:
        raise IOError(
            f"Insufficient disk space: {usage.free / 1e9:.1f}GB available, "
            f"need {required_with_buffer / 1e9:.1f}GB (including {DISK_SPACE_BUFFER_PERCENT*100:.0f}% buffer)"
        )
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""FastAPI lifespan context manager for startup/shutdown"""
|
||||
@@ -76,6 +122,36 @@ class SubmitJobRequest(BaseModel):
|
||||
initial_prompt: Optional[str] = Field(None, description="Initial prompt text")
|
||||
output_directory: Optional[str] = Field(None, description="Output directory path")
|
||||
|
||||
@field_validator('beam_size')
|
||||
@classmethod
|
||||
def check_beam_size(cls, v):
|
||||
return validate_beam_size(v)
|
||||
|
||||
@field_validator('temperature')
|
||||
@classmethod
|
||||
def check_temperature(cls, v):
|
||||
return validate_temperature(v)
|
||||
|
||||
@field_validator('model_name')
|
||||
@classmethod
|
||||
def check_model_name(cls, v):
|
||||
return validate_model_name(v)
|
||||
|
||||
@field_validator('device')
|
||||
@classmethod
|
||||
def check_device(cls, v):
|
||||
return validate_device(v)
|
||||
|
||||
@field_validator('compute_type')
|
||||
@classmethod
|
||||
def check_compute_type(cls, v):
|
||||
return validate_compute_type(v)
|
||||
|
||||
@field_validator('output_format')
|
||||
@classmethod
|
||||
def check_output_format(cls, v):
|
||||
return validate_output_format(v)
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@@ -93,6 +169,7 @@ async def root():
|
||||
"GET /health/circuit-breaker": "Get circuit breaker stats",
|
||||
"POST /health/circuit-breaker/reset": "Reset circuit breaker",
|
||||
"GET /models": "Get available models information",
|
||||
"POST /transcribe": "Upload audio file and submit transcription job",
|
||||
"POST /jobs": "Submit transcription job (async)",
|
||||
"GET /jobs/{job_id}": "Get job status",
|
||||
"GET /jobs/{job_id}/result": "Get job result",
|
||||
@@ -123,6 +200,137 @@ async def get_models():
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/transcribe")
async def transcribe_upload(
    file: UploadFile = File(...),
    model: str = Form("medium"),
    language: Optional[str] = Form(None),
    output_format: str = Form("txt"),
    beam_size: int = Form(5),
    temperature: float = Form(0.0),
    initial_prompt: Optional[str] = Form(None)
):
    """
    Upload audio file and submit transcription job in one request.

    Returns immediately with job_id. Poll GET /jobs/{job_id} for status.

    Args:
        file: Uploaded audio file; its filename must pass validate_filename_safe()
        model: Model name, checked with validate_model_name()
        language: Optional language passed through to the job
        output_format: Output format, checked with validate_output_format()
        beam_size: Beam size, checked with validate_beam_size()
        temperature: Temperature, checked with validate_temperature()
        initial_prompt: Optional prompt text passed through to the job

    Returns:
        JSONResponse (200) with the fields returned by job_queue.submit_job()
        plus a "message" entry.

    Raises:
        HTTPException: 400 for invalid parameters or filename, 503 when the
            job queue is full, 507 when the disk space check fails, and 500
            for any other failure.
    """
    temp_file_path = None
    try:
        # Validate form parameters early
        try:
            # Validate filename for security (basename-only, no path traversal)
            validate_filename_safe(file.filename)
            model = validate_model_name(model)
            output_format = validate_output_format(output_format)
            beam_size = validate_beam_size(beam_size)
            temperature = validate_temperature(temperature)
        except (ValidationError, PathTraversalError, InvalidFileTypeError) as ve:
            # Same error family that submit_job() maps to 400
            raise HTTPException(
                status_code=400,
                detail={
                    "error_code": "VALIDATION_ERROR",
                    "error_type": type(ve).__name__,
                    "message": str(ve)
                }
            )

        # Early queue capacity check (backpressure)
        if job_queue._queue.qsize() >= job_queue._max_queue_size:
            logger.warning("Job queue is full, rejecting upload before file transfer")
            raise HTTPException(
                status_code=503,
                detail={
                    "error": "Queue full",
                    "message": "Job queue is full. Please try again later.",
                    "queue_size": job_queue._queue.qsize(),
                    "max_queue_size": job_queue._max_queue_size
                }
            )

        # Save uploaded file to temp directory
        upload_dir = Path(os.getenv("TRANSCRIPTION_OUTPUT_DIR", "/tmp")) / "uploads"
        upload_dir.mkdir(parents=True, exist_ok=True)

        # Check disk space before accepting upload (estimate: file size * 2 for temp + output)
        if file.size:
            try:
                check_disk_space(str(upload_dir), file.size * 2)
            except IOError as disk_error:
                logger.error(f"Disk space check failed: {disk_error}")
                raise HTTPException(
                    status_code=507,  # Insufficient Storage
                    detail={
                        "error": "Insufficient disk space",
                        "message": str(disk_error)
                    }
                )

        # Create temp file with original filename
        # NOTE(review): two uploads with the same filename resolve to the same
        # path here - confirm whether concurrent uploads need unique names.
        temp_file_path = upload_dir / file.filename

        logger.info(f"Receiving upload: {file.filename} ({file.content_type})")

        # Save uploaded file using async I/O to avoid blocking event loop
        async with aiofiles.open(temp_file_path, "wb") as f:
            # Read file in chunks to handle large files efficiently
            while chunk := await file.read(UPLOAD_CHUNK_SIZE_BYTES):
                await f.write(chunk)

        logger.info(f"Saved upload to: {temp_file_path}")

        # Submit transcription job
        job_info = job_queue.submit_job(
            audio_path=str(temp_file_path),
            model_name=model,
            device="auto",
            compute_type="auto",
            language=language,
            output_format=output_format,
            beam_size=beam_size,
            temperature=temperature,
            initial_prompt=initial_prompt,
            output_directory=None
        )

        return JSONResponse(
            status_code=200,
            content={
                **job_info,
                "message": f"File uploaded and job submitted. Poll /jobs/{job_info['job_id']} for status."
            }
        )

    except HTTPException:
        # Deliberate 400/503/507 responses raised above must reach the client
        # unchanged; without this clause the generic handler below would
        # rewrite them as 500 "Upload failed".
        if temp_file_path is not None and temp_file_path.exists():
            temp_file_path.unlink()
        raise

    except queue_module.Full:
        # Clean up temp file if queue is full
        if temp_file_path is not None and temp_file_path.exists():
            temp_file_path.unlink()

        logger.warning("Job queue is full, rejecting upload")
        raise HTTPException(
            status_code=503,
            detail={
                "error": "Queue full",
                "message": "Job queue is full. Please try again later.",
                "queue_size": job_queue._max_queue_size,
                "max_queue_size": job_queue._max_queue_size
            }
        )

    except Exception as e:
        # Clean up temp file on error
        if temp_file_path is not None and temp_file_path.exists():
            temp_file_path.unlink()

        logger.error(f"Failed to process upload: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Upload failed",
                "message": str(e)
            }
        )
|
||||
|
||||
|
||||
@app.post("/jobs")
|
||||
async def submit_job(request: SubmitJobRequest):
|
||||
"""
|
||||
@@ -152,6 +360,18 @@ async def submit_job(request: SubmitJobRequest):
|
||||
}
|
||||
)
|
||||
|
||||
except (ValidationError, PathTraversalError, InvalidFileTypeError, FileSizeError) as ve:
|
||||
# Input validation errors
|
||||
logger.error(f"Validation error: {ve}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"error_type": type(ve).__name__,
|
||||
"message": str(ve)
|
||||
}
|
||||
)
|
||||
|
||||
except queue_module.Full:
|
||||
# Queue is full
|
||||
logger.warning("Job queue is full, rejecting request")
|
||||
@@ -373,7 +593,7 @@ async def gpu_health_check_endpoint():
|
||||
interpretation = "GPU not available on this system"
|
||||
elif not status.gpu_working:
|
||||
interpretation = f"GPU available but not working correctly: {status.error}"
|
||||
elif status.test_duration_seconds > 2.0:
|
||||
elif status.test_duration_seconds > GPU_TEST_SLOW_THRESHOLD_SECONDS:
|
||||
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
|
||||
|
||||
return JSONResponse(
|
||||
@@ -446,9 +666,11 @@ if __name__ == "__main__":
|
||||
# Perform startup GPU health check
|
||||
from utils.startup import perform_startup_gpu_check
|
||||
|
||||
# Disable auto_reset in Docker (sudo not available, GPU reset won't work)
|
||||
in_docker = os.path.exists('/.dockerenv')
|
||||
perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=True,
|
||||
auto_reset=not in_docker,
|
||||
exit_on_failure=True
|
||||
)
|
||||
|
||||
@@ -462,5 +684,9 @@ if __name__ == "__main__":
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info"
|
||||
log_level="info",
|
||||
timeout_keep_alive=3600, # 1 hour - for long transcription jobs
|
||||
timeout_graceful_shutdown=60,
|
||||
limit_concurrency=10, # Limit concurrent connections
|
||||
backlog=100 # Queue up to 100 pending connections
|
||||
)
|
||||
|
||||
@@ -80,9 +80,11 @@ class CircuitBreaker:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time: Optional[datetime] = None
|
||||
# Use monotonic clock for time drift protection
|
||||
self._last_failure_time_monotonic: Optional[float] = None
|
||||
self._last_failure_time_iso: Optional[str] = None # For logging only
|
||||
self._half_open_calls = 0
|
||||
self._lock = threading.RLock()
|
||||
self._lock = threading.RLock() # RLock needed: properties call self.state which acquires lock
|
||||
|
||||
logger.info(
|
||||
f"Circuit breaker '{name}' initialized: "
|
||||
@@ -113,15 +115,20 @@ class CircuitBreaker:
|
||||
return self.state == CircuitState.HALF_OPEN
|
||||
|
||||
def _update_state(self):
|
||||
"""Update state based on timeout and counters."""
|
||||
"""
|
||||
Update state based on timeout and counters.
|
||||
|
||||
Uses monotonic clock to prevent issues with system time changes
|
||||
(e.g., NTP adjustments, daylight saving time, manual clock changes).
|
||||
"""
|
||||
if self._state == CircuitState.OPEN:
|
||||
# Check if timeout has passed
|
||||
if self._last_failure_time:
|
||||
elapsed = datetime.utcnow() - self._last_failure_time
|
||||
if elapsed.total_seconds() >= self.config.timeout_seconds:
|
||||
# Check if timeout has passed using monotonic clock
|
||||
if self._last_failure_time_monotonic is not None:
|
||||
elapsed = time.monotonic() - self._last_failure_time_monotonic
|
||||
if elapsed >= self.config.timeout_seconds:
|
||||
logger.info(
|
||||
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
|
||||
f"after {elapsed.total_seconds():.0f}s timeout"
|
||||
f"after {elapsed:.0f}s timeout"
|
||||
)
|
||||
self._state = CircuitState.HALF_OPEN
|
||||
self._half_open_calls = 0
|
||||
@@ -142,7 +149,8 @@ class CircuitBreaker:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time = None
|
||||
self._last_failure_time_monotonic = None
|
||||
self._last_failure_time_iso = None
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
# Reset failure count on success
|
||||
@@ -152,7 +160,9 @@ class CircuitBreaker:
|
||||
"""Handle failed call."""
|
||||
with self._lock:
|
||||
self._failure_count += 1
|
||||
self._last_failure_time = datetime.utcnow()
|
||||
# Record failure time using monotonic clock for accuracy
|
||||
self._last_failure_time_monotonic = time.monotonic()
|
||||
self._last_failure_time_iso = datetime.utcnow().isoformat()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
logger.warning(
|
||||
@@ -200,11 +210,12 @@ class CircuitBreaker:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is OPEN. "
|
||||
f"Failing fast to prevent repeated failures. "
|
||||
f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. "
|
||||
f"Last failure: {self._last_failure_time_iso or 'unknown'}. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
|
||||
# Check half-open call limit
|
||||
half_open_incremented = False
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
if self._half_open_calls >= self.config.half_open_max_calls:
|
||||
raise CircuitBreakerOpen(
|
||||
@@ -212,6 +223,7 @@ class CircuitBreaker:
|
||||
f"Please wait for current test to complete."
|
||||
)
|
||||
self._half_open_calls += 1
|
||||
half_open_incremented = True
|
||||
|
||||
# Execute function
|
||||
try:
|
||||
@@ -224,9 +236,9 @@ class CircuitBreaker:
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Decrement half-open counter
|
||||
with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# Decrement half-open counter only if we incremented it
|
||||
if half_open_incremented:
|
||||
with self._lock:
|
||||
self._half_open_calls -= 1
|
||||
|
||||
def decorator(self):
|
||||
@@ -260,7 +272,8 @@ class CircuitBreaker:
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time = None
|
||||
self._last_failure_time_monotonic = None
|
||||
self._last_failure_time_iso = None
|
||||
self._half_open_calls = 0
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
@@ -277,7 +290,7 @@ class CircuitBreaker:
|
||||
"state": self._state.value,
|
||||
"failure_count": self._failure_count,
|
||||
"success_count": self._success_count,
|
||||
"last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
|
||||
"last_failure_time": self._last_failure_time_iso,
|
||||
"config": {
|
||||
"failure_threshold": self.config.failure_threshold,
|
||||
"success_threshold": self.config.success_threshold,
|
||||
|
||||
@@ -88,6 +88,72 @@ def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
|
||||
return sanitized
|
||||
|
||||
|
||||
def validate_filename_safe(filename: str) -> str:
    """
    Check that an uploaded filename is a bare basename with an allowed audio extension.

    The checks run in this order, and the first one that fails raises:
      1. the name is non-empty
      2. it contains no null byte
      3. it equals its own os.path.basename() (no directory components)
      4. it contains neither "/" nor "\\"
      5. it has a suffix, and the lower-cased suffix is in ALLOWED_AUDIO_EXTENSIONS

    Dots inside the name are not treated as traversal, so "audio...mp3" and
    "Wait... what.m4a" are accepted, while "../../../etc/passwd",
    "dir/file.mp4" and "/etc/passwd" are rejected.

    Args:
        filename: Filename to validate (basename only, not a full path)

    Returns:
        The filename, unchanged, when every check passes

    Raises:
        ValidationError: If filename is empty
        PathTraversalError: If filename contains a null byte, path components
            or path separators
        InvalidFileTypeError: If the extension is missing or not allowed
    """
    if not filename:
        raise ValidationError("Filename cannot be empty")

    # Null bytes are rejected before any path handling
    if "\x00" in filename:
        logger.warning(f"Null byte in filename detected: {filename}")
        raise PathTraversalError("Null bytes in filename are not allowed")

    # A name that is not its own basename carries directory components
    basename = os.path.basename(filename)
    if filename != basename:
        logger.warning(f"Filename contains path components: {filename}")
        raise PathTraversalError(
            "Filename must not contain path components. "
            f"Use only the filename: {basename}"
        )

    # Explicit separator scan on the raw string, independent of os.path rules
    if any(separator in filename for separator in ("/", "\\")):
        logger.warning(f"Path separators in filename: {filename}")
        raise PathTraversalError("Path separators (/ or \\) are not allowed in filename")

    # Extension check is case-insensitive
    file_ext = Path(filename).suffix.lower()
    if file_ext == "":
        raise InvalidFileTypeError("Filename must have a file extension")

    if file_ext in ALLOWED_AUDIO_EXTENSIONS:
        return filename

    raise InvalidFileTypeError(
        f"Unsupported audio format: {file_ext}. "
        f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
    )
|
||||
|
||||
|
||||
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
|
||||
"""
|
||||
Validate and sanitize a file path to prevent directory traversal attacks.
|
||||
@@ -112,10 +178,13 @@ def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None)
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
|
||||
|
||||
# Check for path traversal attempts
|
||||
# Check for path traversal attempts in path components
|
||||
# This allows filenames with ellipsis (e.g., "Wait...mp3", "file...audio.m4a")
|
||||
# while blocking actual path traversal (e.g., "../../../etc/passwd")
|
||||
path_str = str(path)
|
||||
if ".." in path_str:
|
||||
logger.warning(f"Path traversal attempt detected: {path_str}")
|
||||
path_parts = path.parts
|
||||
if any(part == ".." for part in path_parts):
|
||||
logger.warning(f"Path traversal attempt detected in components: {path_str}")
|
||||
raise PathTraversalError("Path traversal (..) is not allowed")
|
||||
|
||||
# Check for null bytes
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Test audio generator for GPU health checks.
|
||||
|
||||
Generates realistic test audio with speech using TTS (text-to-speech).
|
||||
Returns path to existing test audio file - NO GENERATION, NO INTERNET.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -10,70 +10,35 @@ import tempfile
|
||||
|
||||
def generate_test_audio(duration_seconds: float = 3.0, frequency: int = 440) -> str:
|
||||
"""
|
||||
Generate a test audio file with real speech for GPU health checks.
|
||||
Return path to existing test audio file for GPU health checks.
|
||||
|
||||
NO AUDIO GENERATION - just returns path to pre-existing test file.
|
||||
NO INTERNET CONNECTION REQUIRED.
|
||||
|
||||
Args:
|
||||
duration_seconds: Duration of audio in seconds (default: 3.0)
|
||||
frequency: Legacy parameter, ignored (kept for backward compatibility)
|
||||
duration_seconds: Duration hint (default: 3.0) - used for cache lookup
|
||||
frequency: Legacy parameter, ignored
|
||||
|
||||
Returns:
|
||||
str: Path to temporary audio file
|
||||
str: Path to test audio file
|
||||
|
||||
Implementation:
|
||||
- Generate real speech using gTTS (Google Text-to-Speech)
|
||||
- Fallback to pyttsx3 if gTTS fails or is unavailable
|
||||
- Raises RuntimeError if both TTS engines fail
|
||||
- Save as MP3 format
|
||||
- Store in system temp directory
|
||||
- Reuse same file if exists (cache)
|
||||
Raises:
|
||||
RuntimeError: If test audio file doesn't exist
|
||||
"""
|
||||
# Use a consistent filename in temp directory for caching
|
||||
# Check for existing test audio in temp directory
|
||||
temp_dir = tempfile.gettempdir()
|
||||
audio_path = os.path.join(temp_dir, f"whisper_test_voice_{int(duration_seconds)}s.mp3")
|
||||
|
||||
# Return cached file if it exists and is valid
|
||||
if os.path.exists(audio_path):
|
||||
try:
|
||||
# Verify file is readable and not empty
|
||||
if os.path.getsize(audio_path) > 0:
|
||||
return audio_path
|
||||
except Exception:
|
||||
# If file is corrupted, regenerate it
|
||||
pass
|
||||
|
||||
# Generate speech with different text based on duration
|
||||
if duration_seconds >= 3:
|
||||
text = "This is a test of the Whisper speech recognition system. Testing one, two, three."
|
||||
elif duration_seconds >= 2:
|
||||
text = "This is a test of the Whisper system."
|
||||
else:
|
||||
text = "Testing Whisper."
|
||||
|
||||
# Try gTTS first (better quality, requires internet)
|
||||
try:
|
||||
from gtts import gTTS
|
||||
tts = gTTS(text=text, lang='en', slow=False)
|
||||
tts.save(audio_path)
|
||||
if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
|
||||
return audio_path
|
||||
except Exception as e:
|
||||
print(f"gTTS failed ({e}), trying pyttsx3...")
|
||||
|
||||
# Fallback to pyttsx3 (offline, lower quality)
|
||||
try:
|
||||
import pyttsx3
|
||||
engine = pyttsx3.init()
|
||||
engine.save_to_file(text, audio_path)
|
||||
engine.runAndWait()
|
||||
|
||||
# Verify file was created
|
||||
if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
|
||||
return audio_path
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to generate test audio. Both gTTS and pyttsx3 failed. "
|
||||
f"gTTS error: {e}. Please ensure TTS dependencies are installed: "
|
||||
f"pip install gTTS pyttsx3"
|
||||
)
|
||||
# If no cached file, raise error - we don't generate anything
|
||||
raise RuntimeError(
|
||||
f"Test audio file not found: {audio_path}. "
|
||||
f"Please ensure test audio exists before running GPU health checks. "
|
||||
f"Expected file location: {audio_path}"
|
||||
)
|
||||
|
||||
|
||||
def cleanup_test_audio() -> None:
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
[program:whisper-api-server]
|
||||
command=/home/uad/agents/tools/mcp-transcriptor/venv/bin/python /home/uad/agents/tools/mcp-transcriptor/src/servers/api_server.py
|
||||
directory=/home/uad/agents/tools/mcp-transcriptor
|
||||
user=uad
|
||||
autostart=true
|
||||
autorestart=true
|
||||
redirect_stderr=true
|
||||
stdout_logfile=/home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log
|
||||
stdout_logfile_maxbytes=50MB
|
||||
stdout_logfile_backups=10
|
||||
environment=
|
||||
PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src",
|
||||
CUDA_VISIBLE_DEVICES="0",
|
||||
API_HOST="0.0.0.0",
|
||||
API_PORT="8000",
|
||||
WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/models",
|
||||
TRANSCRIPTION_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/outputs",
|
||||
TRANSCRIPTION_BATCH_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/outputs/batch",
|
||||
TRANSCRIPTION_MODEL="large-v3",
|
||||
TRANSCRIPTION_DEVICE="auto",
|
||||
TRANSCRIPTION_COMPUTE_TYPE="auto",
|
||||
TRANSCRIPTION_OUTPUT_FORMAT="txt"
|
||||
stopwaitsecs=10
|
||||
stopsignal=TERM
|
||||
60
test_filename_fix.py
Normal file
60
test_filename_fix.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
"""
Quick manual test to verify the filename validation fix.
Tests the exact case from the bug report.

Run it as a script (``python test_filename_fix.py``). The checks live in
main() behind a ``__main__`` guard so that importing this module - for
example when a test runner collects ``test_*.py`` files - neither runs them
nor calls sys.exit().
"""

import sys

# Bug report case
BUG_REPORT_FILENAME = "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"

# Names that must still be rejected as path traversal
DANGEROUS_FILENAMES = [
    "../../../etc/passwd",
    "../../secrets.txt",
    "dir/file.m4a",
]


def main() -> int:
    """Run the manual checks and return the process exit code (0 = all passed, 1 = failure)."""
    # Resolved relative to the current working directory, as in the original script
    sys.path.insert(0, 'src')
    from utils.input_validation import validate_filename_safe, PathTraversalError

    print("\n" + "="*70)
    print("FILENAME VALIDATION FIX - MANUAL TEST")
    print("="*70 + "\n")

    print("Testing bug report filename:")
    print(f"  '{BUG_REPORT_FILENAME}'")
    print()

    try:
        result = validate_filename_safe(BUG_REPORT_FILENAME)
        print("✅ SUCCESS: Filename accepted!")
        print(f"   Returned: '{result}'")
    except PathTraversalError as e:
        print(f"❌ FAILED: {e}")
        return 1
    except Exception as e:
        print(f"❌ ERROR: {e}")
        return 1

    print()

    # Test that security still works
    print("Verifying security (path traversal should still be blocked):")

    for dangerous in DANGEROUS_FILENAMES:
        try:
            validate_filename_safe(dangerous)
            print(f"❌ SECURITY ISSUE: '{dangerous}' was accepted (should be blocked!)")
            return 1
        except PathTraversalError:
            print(f"✅ '{dangerous}' correctly blocked")

    print()
    print("="*70)
    print("ALL TESTS PASSED! ✅")
    print("="*70)
    print()
    print("The fix is working correctly:")
    print("  ✓ Filenames with ellipsis (...) are now accepted")
    print("  ✓ Path traversal attacks are still blocked")
    print()
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
281
tests/test_input_validation.py
Normal file
281
tests/test_input_validation.py
Normal file
@@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for input validation module, specifically filename validation.
|
||||
|
||||
Tests the security-critical validate_filename_safe() function to ensure
|
||||
it correctly blocks path traversal attacks while allowing legitimate filenames.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from src.utils.input_validation import (
|
||||
validate_filename_safe,
|
||||
ValidationError,
|
||||
PathTraversalError,
|
||||
InvalidFileTypeError,
|
||||
ALLOWED_AUDIO_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
class TestValidFilenameSafe:
    """Exercise validate_filename_safe() with accepted and rejected filenames."""

    def test_simple_valid_filenames(self):
        """Plain names with an allowed audio extension come back unchanged."""
        for filename in (
            "audio.m4a",
            "song.wav",
            "podcast.mp3",
            "recording.flac",
            "music.ogg",
            "voice.aac",
        ):
            assert validate_filename_safe(filename) == filename, f"Should accept: {filename}"

    def test_filenames_with_ellipsis(self):
        """Runs of dots inside a name are accepted (the bug report scenario)."""
        for filename in (
            "audio...mp3",
            "This is... a test.m4a",
            "Part 1... Part 2.wav",
            "Wait... what.m4a",
            "video...multiple...dots.mp3",
            # Bug report case
            "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a",
        ):
            assert validate_filename_safe(filename) == filename, \
                f"Should accept filename with ellipsis: {filename}"

    def test_filenames_with_special_chars(self):
        """Hyphens, underscores, brackets, parentheses and spaces are accepted."""
        for filename in (
            "My-Video_2024.m4a",
            "song (remix).m4a",
            "audio [final].wav",
            "test file with spaces.mp3",
            "file-name_with-symbols.flac",
        ):
            assert validate_filename_safe(filename) == filename, f"Should accept: {filename}"

    def test_multiple_extensions(self):
        """Only the final suffix decides whether the name is accepted."""
        for filename in (
            "backup.tar.gz.mp3",  # .mp3 is valid
            "file.old.wav",       # .wav is valid
            "audio.2024.m4a",     # .m4a is valid
        ):
            assert validate_filename_safe(filename) == filename, f"Should accept: {filename}"

    def test_path_traversal_attempts(self):
        """Names containing ".." directory components raise PathTraversalError."""
        for filename in (
            "../../../etc/passwd",
            "../../secrets.txt",
            "../file.mp4",
            "dir/../file.mp4",
            "file/../../etc/passwd",
        ):
            with pytest.raises(PathTraversalError) as raised:
                validate_filename_safe(filename)
            message = str(raised.value).lower()
            assert "path" in message, f"Should reject path traversal: {filename}"

    def test_absolute_paths(self):
        """POSIX, Windows and UNC absolute paths raise PathTraversalError."""
        for filename in (
            "/etc/passwd",
            "/tmp/file.mp4",
            "/home/user/audio.wav",
            "C:\\Windows\\System32\\file.mp3",  # Windows path
            "\\\\server\\share\\file.m4a",      # UNC path
        ):
            with pytest.raises(PathTraversalError) as raised:
                validate_filename_safe(filename)
            message = str(raised.value).lower()
            assert "path" in message, f"Should reject absolute path: {filename}"

    def test_path_separators(self):
        """Any "/" or "\\" in the name raises PathTraversalError."""
        for filename in (
            "dir/file.mp4",
            "folder\\file.wav",
            "path/to/audio.m4a",
            "a/b/c/d.mp3",
        ):
            with pytest.raises(PathTraversalError) as raised:
                validate_filename_safe(filename)
            message = str(raised.value).lower()
            assert "separator" in message or "path" in message, \
                f"Should reject path with separators: {filename}"

    def test_null_bytes(self):
        """A null byte anywhere in the name raises PathTraversalError."""
        for filename in (
            "file\x00.mp4",
            "\x00malicious.wav",
            "audio\x00evil.m4a",
        ):
            with pytest.raises(PathTraversalError) as raised:
                validate_filename_safe(filename)
            message = str(raised.value).lower()
            assert "null" in message, f"Should reject null bytes: {repr(filename)}"

    def test_empty_filename(self):
        """The empty string raises ValidationError."""
        with pytest.raises(ValidationError) as raised:
            validate_filename_safe("")
        assert "empty" in str(raised.value).lower()

    def test_no_extension(self):
        """Names without a suffix raise InvalidFileTypeError."""
        for filename in ("filename", "noextension"):
            with pytest.raises(InvalidFileTypeError) as raised:
                validate_filename_safe(filename)
            message = str(raised.value).lower()
            assert "extension" in message, f"Should reject no extension: {filename}"

    def test_invalid_extensions(self):
        """Suffixes used here that are not audio formats raise InvalidFileTypeError."""
        for filename in (
            "document.pdf",
            "image.png",
            "video.avi",
            "script.sh",
            "executable.exe",
            "text.txt",
        ):
            with pytest.raises(InvalidFileTypeError) as raised:
                validate_filename_safe(filename)
            message = str(raised.value).lower()
            assert "unsupported" in message or "format" in message, \
                f"Should reject invalid extension: {filename}"

    def test_case_insensitive_extensions(self):
        """Upper-case suffixes are accepted and the name is returned as given."""
        for filename in (
            "audio.MP3",
            "sound.WAV",
            "music.M4A",
            "podcast.FLAC",
            "voice.AAC",
        ):
            # Should not raise exception
            assert validate_filename_safe(filename) == filename, \
                f"Should accept case variation: {filename}"

    def test_edge_cases(self):
        """Dot-only stems, hidden-file names and a very long name are accepted."""
        edge_names = (
            "...mp3",             # just dots, but with a valid extension
            "....wav",
            ".hidden.m4a",        # leading dot (hidden file on Unix)
            "a" * 200 + ".mp3",   # very long but valid
        )
        for filename in edge_names:
            assert validate_filename_safe(filename) == filename

    def test_allowed_extensions_comprehensive(self):
        """Every entry of ALLOWED_AUDIO_EXTENSIONS is accepted as a suffix."""
        for ext in ALLOWED_AUDIO_EXTENSIONS:
            filename = f"test{ext}"
            assert validate_filename_safe(filename) == filename, \
                f"Should accept allowed extension: {ext}"
|
||||
|
||||
|
||||
class TestBugReportCase:
    """Regression tests for the filename reported in the bug."""

    def test_bug_report_filename(self):
        """
        The reported filename must be accepted and returned unchanged.

        Bug: "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
        was being rejected due to "..." being parsed as ".."
        """
        reported = "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"

        # Must not raise
        assert validate_filename_safe(reported) == reported

    def test_various_ellipsis_patterns(self):
        """Several dot-run patterns are accepted once they carry a valid extension."""
        known_suffixes = tuple(f"{ext}" for ext in ALLOWED_AUDIO_EXTENSIONS)
        raw_patterns = (
            "...",    # Three dots
            "....",   # Four dots
            ".....",  # Five dots
            "file...end.mp3",
            "start...middle...end.wav",
        )

        for raw in raw_patterns:
            # Patterns without an allowed suffix get ".mp3" appended
            candidate = raw if raw.endswith(known_suffixes) else raw + ".mp3"
            assert validate_filename_safe(candidate) == candidate
|
||||
|
||||
|
||||
class TestSecurityBoundary:
    """Checks where the validator draws the line between safe and dangerous names."""

    def test_just_two_dots_vs_path_separator(self):
        """
        Distinguish dots inside a filename from '..' used as a path component.

        - "file..mp3" (two dots in filename) = SAFE
        - "../file.mp3" (two dots as path component) = DANGEROUS
        """
        # Dots that are part of the filename itself: must be returned unchanged.
        dotted_but_safe = (
            "file..mp3",
            "..file.mp3",
            "file...mp3",
            "f..i..l..e.mp3",
        )
        for name in dotted_but_safe:
            assert validate_filename_safe(name) == name, f"Should be safe: {name}"

        # '..' acting as a directory reference: must raise PathTraversalError.
        traversal_attempts = (
            "../file.mp3",
            "../../file.mp3",
            "dir/../file.mp3",
        )
        for name in traversal_attempts:
            with pytest.raises(PathTraversalError):
                validate_filename_safe(name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # pytest.main() returns the run's exit status rather than exiting itself.
    # Propagate it so that running this file as a script reports failures
    # through the process exit code instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
208
tests/test_path_traversal_fix.py
Normal file
208
tests/test_path_traversal_fix.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test path traversal detection with ellipsis support.
|
||||
|
||||
Tests the fix for false positives where filenames containing ellipsis (...)
|
||||
were incorrectly flagged as path traversal attempts.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from utils.input_validation import (
|
||||
validate_path_safe,
|
||||
validate_audio_file,
|
||||
PathTraversalError,
|
||||
ValidationError,
|
||||
InvalidFileTypeError
|
||||
)
|
||||
|
||||
|
||||
class TestPathTraversalWithEllipsis:
    """Ellipsis in filenames must be accepted while real traversal is rejected."""

    def test_filename_with_ellipsis_allowed(self, tmp_path):
        """Files whose names contain runs of dots validate without a traversal error."""
        root = str(tmp_path)
        dotted_names = (
            "Wait... what.mp3",
            "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a",
            "file...mp3",
            "test....audio.wav",
            "a..b..c.mp3",
            "dots.........everywhere.m4a",
        )

        for name in dotted_names:
            # The file has to exist on disk for the existence assertion below.
            target = tmp_path / name
            target.write_text("fake audio data")

            # A PathTraversalError here is a false positive and fails the test.
            try:
                resolved = validate_path_safe(str(target), [root])
                assert resolved.exists(), f"File should exist: {name}"
                print(f"✓ PASS: {name}")
            except PathTraversalError as err:
                pytest.fail(f"False positive for filename: {name}. Error: {err}")

    def test_actual_path_traversal_blocked(self, tmp_path):
        """Paths that use '..' to climb out of the allowed directory raise PathTraversalError."""
        root = str(tmp_path)
        traversal_attempts = (
            "../../../etc/passwd",
            "..\\..\\..\\windows\\system32",
            "legitimate/../../../etc/passwd",
            "dir/../../secret",
            "../",
            "..",
            "subdir/../../../etc/hosts",
        )

        for attempt in traversal_attempts:
            with pytest.raises(PathTraversalError):
                validate_path_safe(attempt, [root])
                # Only reached when the validator did NOT raise.
                print(f"✗ FAIL: Should have blocked: {attempt}")
            print(f"✓ PASS: Blocked attack: {attempt}")

    def test_ellipsis_in_full_path_allowed(self, tmp_path):
        """A full path inside an allowed directory may carry an ellipsis in its filename."""
        # Place the file one directory level below the allowed root.
        upload_dir = tmp_path / "uploads"
        upload_dir.mkdir()

        audio_path = upload_dir / "Wait... what.mp3"
        audio_path.write_text("fake audio data")

        # The allowed directory is the parent of the upload directory.
        resolved = validate_path_safe(str(audio_path), [str(tmp_path)])
        assert resolved.exists()
        print(f"✓ PASS: Full path with ellipsis: {audio_path}")

    def test_mixed_dots_edge_cases(self, tmp_path):
        """Various dot patterns are either allowed or blocked as listed per case."""
        root = str(tmp_path)
        # (name, expected to be allowed)
        dot_patterns = (
            ("single.dot.mp3", True),  # Normal filename
            ("..two.dots.mp3", True),  # Starts with two dots (filename)
            ("three...dots.mp3", True),  # Three consecutive dots
            ("many.....dots.mp3", True),  # Many consecutive dots
            (".", False),  # Current directory (should fail)
            ("..", False),  # Parent directory (should fail)
        )

        for name, expect_allowed in dot_patterns:
            if not expect_allowed:
                # Bare directory references are passed as-is, not created on disk.
                with pytest.raises((PathTraversalError, ValidationError)):
                    validate_path_safe(name, [root])
                print(f"✓ PASS: Blocked: {name}")
                continue

            # Allowed names are created as real files before validation.
            target = tmp_path / name
            target.write_text("fake audio data")

            try:
                resolved = validate_path_safe(str(target), [root])
                assert resolved.exists(), f"File should exist: {name}"
                print(f"✓ PASS: Allowed: {name}")
            except PathTraversalError:
                pytest.fail(f"Should have allowed: {name}")
|
||||
|
||||
|
||||
class TestAudioFileValidationWithEllipsis:
    """Exercises validate_audio_file with ellipsis filenames and traversal paths."""

    def test_audio_file_with_ellipsis(self, tmp_path):
        """The bug-report filename passes validate_audio_file when the file exists."""
        name = "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
        audio_path = tmp_path / name
        # Write a non-empty payload.
        audio_path.write_bytes(b"fake audio data" * 100)

        validated = validate_audio_file(str(audio_path), [str(tmp_path)])
        assert validated.exists()
        print(f"✓ PASS: Audio validation with ellipsis: {name}")

    def test_audio_file_traversal_attack_blocked(self, tmp_path):
        """validate_audio_file raises PathTraversalError for a '..'-based path."""
        hostile_path = "../../../etc/passwd"

        with pytest.raises(PathTraversalError):
            validate_audio_file(hostile_path, [str(tmp_path)])
        print(f"✓ PASS: Audio validation blocked attack: {hostile_path}")
|
||||
|
||||
|
||||
class TestComponentBasedDetection:
    """Detection should look at path components rather than raw substrings."""

    def test_component_analysis(self, tmp_path):
        """Paths with an ellipsis in the filename have no '..' component and validate."""
        root = str(tmp_path)
        # The dots live inside the final filename component in both cases.
        ellipsis_paths = (
            tmp_path / "file...mp3",
            tmp_path / "subdir" / "Wait...what.m4a",
        )

        for candidate in ellipsis_paths:
            candidate.parent.mkdir(parents=True, exist_ok=True)
            candidate.write_text("data")

            # Sanity check on the fixture itself: no component is exactly "..".
            assert ".." not in Path(candidate).parts, \
                f"Should not have '..' as a component: {candidate}"

            # The validator must accept the path.
            resolved = validate_path_safe(str(candidate), [root])
            assert resolved.exists()
            print(f"✓ PASS: Component analysis correct: {candidate}")

    def test_component_attack_detection(self):
        """Attack paths contain '..' as a genuine path component.

        Only pathlib's component split is checked here; the validator itself
        is not called.
        """
        traversal_attempts = (
            "../etc/passwd",
            "dir/../secret",
            "../../file.mp3",
        )

        for attempt in traversal_attempts:
            # ".." must show up as a component of its own.
            assert ".." in Path(attempt).parts, \
                f"Should have '..' as a component: {attempt}"
            print(f"✓ PASS: Attack has '..' component: {attempt}")
|
||||
|
||||
|
||||
def run_tests():
    """Run this file's tests through pytest with banner output.

    Returns:
        The value returned by pytest.main(); 0 is treated as success.
    """
    separator = "=" * 70

    print(separator)
    print("Running Path Traversal Detection Tests")
    print(separator)

    # Verbose run, short tracebacks, warnings plugin disabled.
    pytest_args = [__file__, "-v", "--tb=short", "-p", "no:warnings"]
    exit_code = pytest.main(pytest_args)

    print(separator)
    print("✓ All tests passed!" if exit_code == 0 else "✗ Some tests failed!")
    print(separator)

    return exit_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit with the status reported by run_tests().
    raise SystemExit(run_tests())
|
||||
Reference in New Issue
Block a user