Compare commits


13 Commits

Author SHA1 Message Date
Alihan
5fb742a312 Add circuit breaker, input validation, and refactor startup logic
- Implement circuit breaker pattern for GPU health checks
  - Prevents repeated failures with configurable thresholds
  - Three states: CLOSED, OPEN, HALF_OPEN
  - Integrated into GPU health monitoring

- Add comprehensive input validation and path sanitization
  - Path traversal attack prevention
  - Whitelist-based validation for models, devices, formats
  - Error message sanitization to prevent information leakage
  - File size limits and security checks

- Centralize startup logic across servers
  - Extract common startup procedures to utils/startup.py
  - Deduplicate GPU health checks and initialization code
  - Simplify both MCP and API server startup sequences

- Add proper Python package structure
  - Add __init__.py files to all modules
  - Improve package organization

- Add circuit breaker status API endpoints
  - GET /health/circuit-breaker - View circuit breaker stats
  - POST /health/circuit-breaker/reset - Reset circuit breaker

- Reorganize test files into tests/ directory
  - Rename and restructure test files for better organization
2025-10-10 01:03:55 +03:00
Alihan
40555592e6 Fix deadlock in job queue and refactor Phase 2 tests
- Fix: Change threading.Lock to threading.RLock in JobQueue to prevent deadlock
  - Issue: list_jobs() acquired the lock, then called get_job_status(), which tried to acquire the same lock
  - Solution: Use re-entrant lock (RLock) to allow nested lock acquisition (src/core/job_queue.py:144)

- Refactor: Update test_phase2.py to use real test.mp3 file
  - Changed _create_test_audio_file() to return /home/uad/agents/tools/mcp-transcriptor/data/test.mp3
  - Removed specific text assertion, now just verifies transcription is not empty
  - Tests use tiny model for speed while processing real 6.95s audio file

- Update: Improve audio validation error handling in transcriber.py
  - Changed validate_audio_file() to use exception-based validation
  - Better error messages for API responses

- Add: Job queue configuration to startup scripts
  - Added JOB_QUEUE_MAX_SIZE, JOB_METADATA_DIR, JOB_RETENTION_DAYS env vars
  - Added GPU health monitoring configuration
  - Create job metadata directory on startup
2025-10-10 00:11:36 +03:00
Alihan
1292f0f09b Add GPU auto-reset, job queue, health monitoring, and test infrastructure
Major features:
- GPU auto-reset on CUDA errors with cooldown protection (handles sleep/wake)
- Async job queue system for long-running transcriptions
- Comprehensive GPU health monitoring with real model tests
- Phase 1 component testing with detailed logging

New modules:
- src/core/gpu_reset.py: GPU driver reset with 5-min cooldown
- src/core/gpu_health.py: Real GPU health checks using model inference
- src/core/job_queue.py: FIFO queue with background worker and persistence
- src/utils/test_audio_generator.py: Test audio generation for GPU checks
- test_phase1.py: Component tests with logging
- reset_gpu.sh: GPU driver reset script

Updates:
- CLAUDE.md: Added GPU auto-reset docs and passwordless sudo setup
- requirements.txt: Updated to PyTorch CUDA 12.4
- Model manager: Integrated GPU health check with reset
- Both servers: Added startup GPU validation with auto-reset
- Startup scripts: Added GPU_RESET_COOLDOWN_MINUTES env var
2025-10-09 23:13:11 +03:00
Alihan
e7a457e602 Refactor codebase structure with organized src/ directory
- Reorganize source code into src/ directory with logical subdirectories:
  - src/servers/: MCP and REST API server implementations
  - src/core/: Core business logic (transcriber, model_manager)
  - src/utils/: Utility modules (audio_processor, formatters)

- Update all import statements to use proper module paths
- Configure PYTHONPATH in startup scripts and Dockerfile
- Update documentation with new structure and paths
- Update pyproject.toml with package configuration
- Keep DevOps files (scripts, Dockerfile, configs) at root level

All functionality validated and working correctly.
2025-10-07 12:28:03 +03:00
Alihan
7c9a8d8378 Merge branch 'alihan-specific' of https://gitea.umutalihandikel.com/alihan/Fast-Whisper-MCP-Server into alihan-specific 2025-10-07 11:20:34 +03:00
Alihan
2cc9f298a5 separate mcp & api servers 2025-10-07 11:20:03 +03:00
ALIHAN DIKEL
56ccc0e1d7 . 2025-07-05 14:35:47 +03:00
ALIHAN DIKEL
53af30619f . 2025-07-05 14:34:26 +03:00
Alihan
046204d555 transcription flow polishing, bugfixes 2025-06-15 17:50:05 +03:00
Alihan
9c020f947b resolve 2025-06-14 18:59:35 +03:00
Alihan
4936684db4 . 2025-06-14 18:58:57 +03:00
ALIHAN DIKEL
8e30a8812c read dockerfile 2025-06-14 16:12:09 +03:00
ALIHAN DIKEL
37935066ad made alihan-specific 2025-06-14 15:59:16 +03:00
40 changed files with 6115 additions and 1078 deletions

4
.gitignore vendored
View File

@@ -14,4 +14,6 @@ venv/
# Cython
*.pyd
logs/**
User/**
data/**

529
CLAUDE.md Normal file
View File

@@ -0,0 +1,529 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Overview
This is a Whisper-based speech recognition service that provides high-performance audio transcription using Faster Whisper. The service runs as either:
1. **MCP Server** - For integration with Claude Desktop and other MCP clients
2. **REST API Server** - For HTTP-based integrations with async job queue support
Both servers share the same core transcription logic and can run independently or simultaneously on different ports.
**Key Features:**
- Async job queue system for long-running transcriptions (prevents HTTP timeouts)
- GPU health monitoring with strict failure detection (prevents silent CPU fallback)
- **Automatic GPU driver reset** on CUDA errors with cooldown protection (handles sleep/wake issues)
- Dual-server architecture (MCP + REST API)
- Model caching for fast repeated transcriptions
- Automatic batch size optimization based on GPU memory
## Development Commands
### Environment Setup
```bash
# Create and activate virtual environment
python3.12 -m venv venv
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Install PyTorch with CUDA 12.6 support
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
# For CUDA 12.1
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
# For CPU-only
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
```
### Running the Servers
#### MCP Server (for Claude Desktop)
```bash
# Using the startup script (recommended - sets all env vars)
./run_mcp_server.sh
# Direct Python execution
python whisper_server.py
# Using MCP CLI for development testing
mcp dev whisper_server.py
# Run server with MCP CLI
mcp run whisper_server.py
```
#### REST API Server (for HTTP clients)
```bash
# Using the startup script (recommended - sets all env vars)
./run_api_server.sh
# Direct Python execution with uvicorn
python api_server.py
# Or using uvicorn directly
uvicorn api_server:app --host 0.0.0.0 --port 8000
# Development mode with auto-reload
uvicorn api_server:app --reload --host 0.0.0.0 --port 8000
```
#### Running Both Simultaneously
```bash
# Terminal 1: Start MCP server
./run_mcp_server.sh
# Terminal 2: Start REST API server
./run_api_server.sh
```
### Docker
```bash
# Build Docker image
docker build -t whisper-mcp-server .
# Run with GPU support
docker run --gpus all -v /path/to/models:/models -v /path/to/outputs:/outputs whisper-mcp-server
```
## Architecture
### Directory Structure
```
.
├── src/                            # Source code directory
│   ├── servers/                    # Server implementations
│   │   ├── whisper_server.py       # MCP server entry point
│   │   └── api_server.py           # REST API server (async job queue)
│   ├── core/                       # Core business logic
│   │   ├── transcriber.py          # Transcription logic (single & batch)
│   │   ├── model_manager.py        # Model lifecycle & caching
│   │   ├── job_queue.py            # Async job queue manager
│   │   └── gpu_health.py           # GPU health monitoring
│   └── utils/                      # Utility modules
│       ├── audio_processor.py      # Audio validation & preprocessing
│       ├── formatters.py           # Output format conversion
│       └── test_audio_generator.py # Test audio generation for GPU checks
├── run_mcp_server.sh               # MCP server startup script
├── run_api_server.sh               # API server startup script
├── reset_gpu.sh                    # GPU driver reset script
├── DEV_PLAN.md                     # Development plan for async features
├── requirements.txt                # Python dependencies
└── pyproject.toml                  # Project configuration
```
### Core Components
1. **src/servers/whisper_server.py** - MCP server entry point
- Uses FastMCP framework to expose MCP tools
- Three main tools: `get_model_info_api()`, `transcribe()`, `batch_transcribe_audio()`
- Server initialization at line 19
2. **src/servers/api_server.py** - REST API server entry point
- Uses FastAPI framework for HTTP endpoints
- Provides REST endpoints: `/`, `/health`, `/models`, `/transcribe`, `/batch-transcribe`, `/upload-transcribe`
- Shares core transcription logic with MCP server
- File upload support via multipart/form-data
3. **src/core/transcriber.py** - Core transcription logic (shared by both servers)
- `transcribe_audio()`:39 - Single file transcription with environment variable support
- `batch_transcribe()`:209 - Batch processing with progress reporting
- All parameters support environment variable defaults (lines 21-37)
- Delegates output formatting to utils.formatters
4. **src/core/model_manager.py** - Whisper model lifecycle management
- `get_whisper_model()`:44 - Returns cached model instances or loads new ones
- `test_gpu_driver()`:20 - GPU validation before model loading
- **CRITICAL**: GPU-only mode enforced at lines 64-90 (no CPU fallback)
- Global `model_instances` dict caches loaded models to prevent reloading
- Automatic batch size optimization based on GPU memory (lines 134-147)
5. **src/core/job_queue.py** - Async job queue manager
- `JobQueue` class manages FIFO queue with background worker thread
- `submit_job()` - Validates audio, checks GPU health, adds to queue
- `get_job_status()` - Returns current job status and queue position
- `get_job_result()` - Returns transcription result for completed jobs
- Jobs persist to disk as JSON files for crash recovery
- Single worker thread processes jobs sequentially (prevents GPU contention)
6. **src/core/gpu_health.py** - GPU health monitoring
- `check_gpu_health()`:39 - Real GPU test using tiny model + test audio
- `GPUHealthStatus` dataclass contains detailed GPU metrics
- **CRITICAL**: Raises RuntimeError if device=cuda but GPU fails (lines 99-135)
- Prevents silent CPU fallback that would cause 10-100x slowdown
- `HealthMonitor` class for periodic background monitoring
7. **src/utils/audio_processor.py** - Audio file validation and preprocessing
- `validate_audio_file()`:15 - Checks file existence, format, and size
- `process_audio()`:50 - Decodes audio using faster_whisper's decode_audio
8. **src/utils/formatters.py** - Output format conversion
- `format_vtt()`, `format_srt()`, `format_txt()`, `format_json()` - Convert segments to various formats
- All formatters accept segment lists from Whisper output
9. **src/utils/test_audio_generator.py** - Test audio generation
- `generate_test_audio()` - Creates synthetic 1-second audio for GPU health checks
- Uses numpy to generate sine wave, no external audio files needed
### Key Architecture Patterns
- **Dual Server Architecture**: Both MCP and REST API servers import and use the same core modules (core.transcriber, core.model_manager, utils.audio_processor, utils.formatters), ensuring consistent behavior
- **Model Caching**: Models are cached in the `model_instances` dictionary with the key format `{model_name}_{device}_{compute_type}` (src/core/model_manager.py:104). This cache is shared if both servers run in the same process (see the sketch after this list)
- **Batch Processing**: CUDA devices automatically use BatchedInferencePipeline for performance (src/core/model_manager.py:132-160)
- **Environment Variable Configuration**: All transcription parameters support env var defaults (src/core/transcriber.py:21-37)
- **GPU-Only Mode**: Service is configured for GPU-only operation. `device="auto"` requires CUDA, `device="cpu"` is rejected (src/core/model_manager.py:64-90)
- **Async Job Queue**: Long-running transcriptions use async queue pattern to prevent HTTP timeouts. Jobs return immediately with job_id for polling
- **GPU Health Monitoring**: Real GPU tests with tiny model prevent silent CPU fallback. Jobs are rejected immediately if GPU fails rather than running 10-100x slower on CPU
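For illustration, a minimal sketch of the caching pattern referenced above (simplified; the real src/core/model_manager.py also handles GPU checks, batch-size selection, and the batched pipeline):
```python
import os
from faster_whisper import WhisperModel

# Process-wide cache, keyed as "{model_name}_{device}_{compute_type}"
model_instances: dict[str, WhisperModel] = {}

def get_cached_model(model_name: str, device: str, compute_type: str) -> WhisperModel:
    key = f"{model_name}_{device}_{compute_type}"  # e.g. "large-v3_cuda_float16"
    if key not in model_instances:
        model_instances[key] = WhisperModel(
            model_name,
            device=device,
            compute_type=compute_type,
            download_root=os.environ.get("WHISPER_MODEL_DIR"),  # optional custom model dir
        )
    return model_instances[key]
```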
## Environment Variables
All configuration can be set via environment variables in run_mcp_server.sh and run_api_server.sh:
**API Server Specific:**
- `API_HOST` - API server host (default: 0.0.0.0)
- `API_PORT` - API server port (default: 8000)
**Job Queue Configuration (if using async features):**
- `JOB_QUEUE_MAX_SIZE` - Maximum queue size (default: 100)
- `JOB_METADATA_DIR` - Directory for job metadata JSON files
- `JOB_RETENTION_DAYS` - Auto-cleanup old jobs (0=disabled)
**GPU Health Monitoring:**
- `GPU_HEALTH_CHECK_ENABLED` - Enable periodic GPU monitoring (true/false)
- `GPU_HEALTH_CHECK_INTERVAL_MINUTES` - Monitoring interval (default: 10)
- `GPU_HEALTH_TEST_MODEL` - Model for health checks (default: tiny)
**GPU Auto-Reset Configuration:**
- `GPU_RESET_COOLDOWN_MINUTES` - Minimum time between GPU reset attempts (default: 5 minutes)
- Prevents reset loops while allowing recovery from sleep/wake cycles
- Auto-reset is **enabled by default**
- Service terminates if GPU unavailable after reset attempt
**Transcription Configuration (shared by both servers):**
- `CUDA_VISIBLE_DEVICES` - GPU device selection
- `WHISPER_MODEL_DIR` - Model storage location (defaults to None for HuggingFace cache)
- `TRANSCRIPTION_OUTPUT_DIR` - Default output directory for single transcriptions
- `TRANSCRIPTION_BATCH_OUTPUT_DIR` - Default output directory for batch processing
- `TRANSCRIPTION_MODEL` - Model size (tiny, base, small, medium, large-v1, large-v2, large-v3)
- `TRANSCRIPTION_DEVICE` - Execution device (cuda, auto) - **NOTE: cpu is rejected in GPU-only mode**
- `TRANSCRIPTION_COMPUTE_TYPE` - Computation type (float16, int8, auto)
- `TRANSCRIPTION_OUTPUT_FORMAT` - Output format (vtt, srt, txt, json)
- `TRANSCRIPTION_BEAM_SIZE` - Beam search size (default: 5)
- `TRANSCRIPTION_TEMPERATURE` - Sampling temperature (default: 0.0)
- `TRANSCRIPTION_USE_TIMESTAMP` - Add timestamp to filenames (true/false)
- `TRANSCRIPTION_FILENAME_PREFIX` - Prefix for output filenames
- `TRANSCRIPTION_FILENAME_SUFFIX` - Suffix for output filenames
- `TRANSCRIPTION_LANGUAGE` - Language code (zh, en, ja, etc., auto-detect if not set)
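As a minimal sketch of the environment-variable default pattern (the variable names match the list above; the constant names and exact fallbacks below are illustrative, not the literal code in src/core/transcriber.py):
```python
import os

# Each transcription parameter falls back to an env var, then to a hard-coded default.
DEFAULT_MODEL = os.getenv("TRANSCRIPTION_MODEL", "large-v3")
DEFAULT_DEVICE = os.getenv("TRANSCRIPTION_DEVICE", "auto")
DEFAULT_COMPUTE_TYPE = os.getenv("TRANSCRIPTION_COMPUTE_TYPE", "auto")
DEFAULT_OUTPUT_FORMAT = os.getenv("TRANSCRIPTION_OUTPUT_FORMAT", "txt")
DEFAULT_BEAM_SIZE = int(os.getenv("TRANSCRIPTION_BEAM_SIZE", "5"))
DEFAULT_TEMPERATURE = float(os.getenv("TRANSCRIPTION_TEMPERATURE", "0.0"))
DEFAULT_LANGUAGE = os.getenv("TRANSCRIPTION_LANGUAGE") or None  # None -> auto-detect
```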
## Supported Configurations
- **Models**: tiny, base, small, medium, large-v1, large-v2, large-v3
- **Audio formats**: .mp3, .wav, .m4a, .flac, .ogg, .aac
- **Output formats**: vtt, srt, json, txt
- **Languages**: zh (Chinese), en (English), ja (Japanese), ko (Korean), de (German), fr (French), es (Spanish), ru (Russian), it (Italian), pt (Portuguese), nl (Dutch), ar (Arabic), hi (Hindi), tr (Turkish), vi (Vietnamese), th (Thai), id (Indonesian)
## REST API Endpoints
The REST API server provides the following HTTP endpoints:
### GET /
Returns API information and available endpoints.
### GET /health
Health check endpoint. Returns `{"status": "healthy", "service": "whisper-transcription"}`.
### GET /models
Returns available Whisper models, devices, languages, and system information (GPU details if CUDA available).
### POST /transcribe
Transcribe a single audio file that exists on the server.
**Request Body:**
```json
{
"audio_path": "/path/to/audio.mp3",
"model_name": "large-v3",
"device": "auto",
"compute_type": "auto",
"language": "en",
"output_format": "txt",
"beam_size": 5,
"temperature": 0.0,
"initial_prompt": null,
"output_directory": null
}
```
**Response:**
```json
{
"success": true,
"message": "Transcription successful, results saved to: /path/to/output.txt",
"output_path": "/path/to/output.txt"
}
```
### POST /batch-transcribe
Batch transcribe all audio files in a folder.
**Request Body:**
```json
{
"audio_folder": "/path/to/audio/folder",
"output_folder": "/path/to/output",
"model_name": "large-v3",
"output_format": "txt",
...
}
```
**Response:**
```json
{
"success": true,
"summary": "Batch processing completed, total transcription time: 00:05:23 | Success: 10/10 | Failed: 0/10"
}
```
### POST /upload-transcribe
Upload an audio file and transcribe it immediately. Returns the transcription file as a download.
**Form Data:**
- `file`: Audio file (multipart/form-data)
- `model_name`: Model name (default: "large-v3")
- `device`: Device (default: "auto")
- `output_format`: Output format (default: "txt")
- ... (other transcription parameters)
**Response:** Returns the transcription file for download.
### API Usage Examples
```bash
# Get model information
curl http://localhost:8000/models
# Transcribe existing file (synchronous)
curl -X POST http://localhost:8000/transcribe \
-H "Content-Type: application/json" \
-d '{"audio_path": "/path/to/audio.mp3", "output_format": "txt"}'
# Upload and transcribe
curl -X POST http://localhost:8000/upload-transcribe \
-F "file=@audio.mp3" \
-F "output_format=txt" \
-F "model_name=large-v3"
# Async job queue (if enabled)
# Submit job
curl -X POST http://localhost:8000/jobs \
-H "Content-Type: application/json" \
-d '{"audio_path": "/path/to/audio.mp3"}'
# Returns: {"job_id": "abc-123", "status": "queued", "queue_position": 1}
# Check status
curl http://localhost:8000/jobs/abc-123
# Returns: {"status": "running", ...}
# Get result (when completed)
curl http://localhost:8000/jobs/abc-123/result
# Returns: transcription text
# Check GPU health
curl http://localhost:8000/health/gpu
# Returns: {"gpu_available": true, "gpu_working": true, ...}
```
## GPU Auto-Reset Configuration
### Overview
This service features automatic GPU driver reset on CUDA errors, which is especially useful for recovering from sleep/wake cycles. The reset functionality is **enabled by default** and includes cooldown protection to prevent reset loops.
### Passwordless Sudo Setup (Required)
For automatic GPU reset to work, you must configure passwordless sudo for NVIDIA commands. Create a sudoers configuration file:
```bash
sudo visudo -f /etc/sudoers.d/whisper-gpu-reset
```
Add the following (replace `your_username` with your actual username):
```
# Whisper GPU Auto-Reset Permissions
your_username ALL=(ALL) NOPASSWD: /bin/systemctl stop nvidia-persistenced
your_username ALL=(ALL) NOPASSWD: /bin/systemctl start nvidia-persistenced
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_uvm
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_drm
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia_modeset
your_username ALL=(ALL) NOPASSWD: /sbin/rmmod nvidia
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_modeset
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_uvm
your_username ALL=(ALL) NOPASSWD: /sbin/modprobe nvidia_drm
```
**Security Note:** These permissions are limited to specific NVIDIA driver commands only. The reset script (`reset_gpu.sh`) is executed with sudo but is part of the codebase and can be audited.
### How It Works
1. **Startup Check**: When the service starts, it performs a GPU health check
- If CUDA errors detected → automatic reset attempt → retry
- If retry fails → service terminates
2. **Runtime Check**: Before job submission and model loading
- If CUDA errors detected → automatic reset attempt → retry
- If retry fails → job rejected, service continues
3. **Cooldown Protection**: Prevents reset loops
- Minimum 5 minutes between reset attempts (configurable via `GPU_RESET_COOLDOWN_MINUTES`)
- Cooldown persists across restarts (stored in `/tmp/whisper-gpu-last-reset`)
- If reset needed but cooldown active → service/job fails immediately
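A minimal sketch of how this cooldown can work (the names `can_attempt_reset()` and `record_reset_attempt()` do exist in core.gpu_reset, but the timestamp-file format shown here is an assumption, not the actual implementation):
```python
import os
import time

COOLDOWN_FILE = "/tmp/whisper-gpu-last-reset"
COOLDOWN_SECONDS = int(os.getenv("GPU_RESET_COOLDOWN_MINUTES", "5")) * 60

def can_attempt_reset() -> bool:
    """True if no reset was recorded within the cooldown window."""
    try:
        last_reset = float(open(COOLDOWN_FILE).read().strip())
    except (FileNotFoundError, ValueError):
        return True  # no prior reset recorded
    return (time.time() - last_reset) >= COOLDOWN_SECONDS

def record_reset_attempt() -> None:
    """Persist the reset time so the cooldown survives service restarts."""
    with open(COOLDOWN_FILE, "w") as f:
        f.write(str(time.time()))
```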
### Manual GPU Reset
You can manually reset the GPU anytime:
```bash
./reset_gpu.sh
```
Or clear the cooldown to allow immediate reset:
```python
from core.gpu_reset import clear_reset_cooldown
clear_reset_cooldown()
```
### Behavior Examples
**After sleep/wake with GPU issue:**
```
Service starts → GPU check fails (CUDA error)
→ Cooldown OK → Reset drivers → Wait 3s → Retry
→ Success → Service continues
```
**Multiple failures (hardware issue):**
```
First failure → Reset → Retry fails → Job fails
Second failure within 5 min → Cooldown active → Fail immediately
(Prevents reset loop)
```
**Normal operation:**
```
No CUDA errors → No resets → Normal performance
Reset only happens on actual CUDA failures
```
## Important Implementation Details
### GPU-Only Architecture
- **CRITICAL**: Service enforces GPU-only mode. CPU device is explicitly rejected (src/core/model_manager.py:84-90)
- `device="auto"` requires CUDA to be available, raises RuntimeError if not (src/core/model_manager.py:64-73)
- GPU health checks use real model loading + transcription, not just torch.cuda.is_available()
- If GPU health check fails, jobs are rejected immediately rather than silently falling back to CPU
- **GPU Auto-Reset**: Automatic driver reset on CUDA errors with 5-minute cooldown (handles sleep/wake issues)
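A condensed, illustrative sketch of this device policy (not the literal code in src/core/model_manager.py):
```python
import torch

def resolve_device(device: str) -> str:
    """GPU-only policy: 'cpu' is rejected and 'auto' must resolve to CUDA."""
    if device == "cpu":
        raise RuntimeError("CPU device is rejected in GPU-only mode")
    if device == "auto":
        if not torch.cuda.is_available():
            raise RuntimeError("device='auto' requires CUDA, but no GPU is available")
        return "cuda"
    if device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA requested but torch.cuda.is_available() is False")
    return device
```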
### Model Management
- GPU memory is checked before loading models (src/core/model_manager.py:115-127)
- Batch size dynamically adjusts: 32 (>16GB), 16 (>12GB), 8 (>8GB), 4 (>4GB), 2 (otherwise)
- Models are cached globally in `model_instances` dict, shared across requests
- Model loading includes GPU driver test to fail fast if GPU is unavailable (src/core/model_manager.py:112-114)
### Transcription Settings
- VAD (Voice Activity Detection) is enabled by default for better long-audio accuracy (src/core/transcriber.py:102)
- Word timestamps are enabled by default (src/core/transcriber.py:107)
- Files over 1GB generate warnings about processing time (src/utils/audio_processor.py:42)
- Default output format is "txt" for REST API, configured via environment variables for MCP server
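For reference, a minimal transcription call with these defaults (a simplified sketch of what src/core/transcriber.py does; the model size and audio path are placeholders):
```python
from faster_whisper import WhisperModel

model = WhisperModel("tiny", device="cuda", compute_type="float16")
segments, info = model.transcribe(
    "/path/to/audio.mp3",
    beam_size=5,
    temperature=0.0,
    vad_filter=True,       # enabled by default for better long-audio accuracy
    word_timestamps=True,  # enabled by default
)
print(f"Detected language: {info.language}")
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```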
### Async Job Queue (if enabled)
- Single worker thread processes jobs sequentially (prevents GPU memory contention)
- Jobs persist to disk as JSON files in JOB_METADATA_DIR
- Queue has max size limit (default 100), returns 503 when full
- Job status polling recommended every 5-10 seconds for LLM agents
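A minimal polling client for the job endpoints shown in the API usage examples above (assumes the third-party `requests` package; adjust the host and port to your deployment):
```python
import time
import requests

BASE = "http://localhost:8000"

job = requests.post(f"{BASE}/jobs", json={"audio_path": "/path/to/audio.mp3"}).json()
job_id = job["job_id"]

# Poll every 5-10 seconds until the job finishes.
while True:
    status = requests.get(f"{BASE}/jobs/{job_id}").json()["status"]
    if status in ("completed", "failed"):
        break
    time.sleep(5)

if status == "completed":
    print(requests.get(f"{BASE}/jobs/{job_id}/result").text)
```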
## Development Workflow
### Testing GPU Health
```python
# Test GPU health check manually
from core.gpu_health import check_gpu_health
status = check_gpu_health(expected_device="cuda")
print(f"GPU Working: {status.gpu_working}")
print(f"Device: {status.device_used}")
print(f"Test Duration: {status.test_duration_seconds}s")
# Expected: <1s for GPU, 3-10s for CPU
```
### Testing Job Queue
```python
# Test job queue manually
from core.job_queue import JobQueue
queue = JobQueue(max_queue_size=100, metadata_dir="/tmp/jobs")
queue.start()
# Submit job
job_info = queue.submit_job(
    audio_path="/path/to/test.mp3",
    model_name="large-v3",
    device="cuda"
)
print(f"Job ID: {job_info['job_id']}")
# Poll status
status = queue.get_job_status(job_info['job_id'])
print(f"Status: {status['status']}")
# Get result when completed
result = queue.get_job_result(job_info['job_id'])
```
### Common Debugging
**Model loading issues:**
- Check `WHISPER_MODEL_DIR` is set correctly
- Verify GPU memory with `nvidia-smi`
- Check logs for GPU driver test failures at model_manager.py:112-114
**GPU not detected:**
- Verify `CUDA_VISIBLE_DEVICES` is set correctly
- Check `torch.cuda.is_available()` returns True
- Run GPU health check to see detailed error
**Silent failures:**
- Check that service is NOT silently falling back to CPU
- GPU health check should RAISE errors, not log warnings
- If device=cuda fails, the job should be rejected, not processed on CPU
**Job queue issues:**
- Check `JOB_METADATA_DIR` exists and is writable
- Verify background worker thread is running (check logs)
- Job metadata files are in {JOB_METADATA_DIR}/{job_id}.json
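A quick diagnostic snippet covering the GPU and job-queue checks above (run inside the service's virtualenv):
```python
import os
import torch

# GPU checks
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.cuda.is_available():", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

# Job queue checks
metadata_dir = os.environ.get("JOB_METADATA_DIR", "")
print("JOB_METADATA_DIR:", metadata_dir or "<not set>")
if metadata_dir:
    print("exists:", os.path.isdir(metadata_dir), "writable:", os.access(metadata_dir, os.W_OK))
```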
### File Locations
- **Logs**: `mcp.logs` (MCP server), `api.logs` (API server)
- **Models**: `$WHISPER_MODEL_DIR` or HuggingFace cache
- **Outputs**: `$TRANSCRIPTION_OUTPUT_DIR` or `$TRANSCRIPTION_BATCH_OUTPUT_DIR`
- **Job Metadata**: `$JOB_METADATA_DIR/{job_id}.json`
### Important Development Notes
- See `DEV_PLAN.md` for detailed architecture and implementation plan for async job queue features
- The service is designed for GPU-only operation - CPU fallback is intentionally disabled to prevent silent performance degradation
- When modifying model_manager.py, maintain the strict GPU-only enforcement
- When adding new endpoints, follow the async pattern if transcription time >30 seconds

56
Dockerfile Normal file
View File

@@ -0,0 +1,56 @@
# Use NVIDIA CUDA base image with Python
FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04
# Install Python 3.12
RUN apt-get update && apt-get install -y \
software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update && apt-get install -y \
python3.12 \
python3.12-venv \
python3.12-dev \
python3-pip \
ffmpeg \
git \
&& rm -rf /var/lib/apt/lists/*
# Make python3.12 the default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1
# Upgrade pip
RUN python -m pip install --upgrade pip
# Set working directory
WORKDIR /app
# Copy requirements first for better caching
COPY requirements.txt .
# Install Python dependencies with CUDA support
RUN pip install --no-cache-dir \
faster-whisper \
torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121 \
torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121 \
mcp[cli]
# Copy application code
COPY src/ ./src/
COPY pyproject.toml .
COPY README.md .
# Create directories for models and outputs
RUN mkdir -p /models /outputs
# Set Python path
ENV PYTHONPATH=/app/src
# Set environment variables for GPU
ENV WHISPER_MODEL_DIR=/models
ENV TRANSCRIPTION_OUTPUT_DIR=/outputs
ENV TRANSCRIPTION_MODEL=large-v3
ENV TRANSCRIPTION_DEVICE=cuda
ENV TRANSCRIPTION_COMPUTE_TYPE=float16
# Run the server
CMD ["python", "src/servers/whisper_server.py"]

README-CN.md
View File

@@ -1,184 +0,0 @@
# Whisper 语音识别 MCP 服务器
基于 Faster Whisper 的语音识别 MCP 服务器,提供高性能的音频转录功能。
## 功能特点
- 集成 Faster Whisper 进行高效语音识别
- 支持批处理加速,提高转录速度
- 自动使用 CUDA 加速(如果可用)
- 支持多种模型大小tiny 到 large-v3
- 输出格式支持 VTT 字幕和 JSON
- 支持批量转录文件夹中的音频文件
- 模型实例缓存,避免重复加载
## 安装
### 依赖项
- Python 3.10+
- faster-whisper>=0.9.0
- torch==2.6.0+cu126
- torchaudio==2.6.0+cu126
- mcp[cli]>=1.2.0
### 安装步骤
1. 克隆或下载此仓库
2. 创建并激活虚拟环境(推荐)
3. 安装依赖项:
```bash
pip install -r requirements.txt
```
## 使用方法
### 启动服务器
在 Windows 上,直接运行 `start_server.bat`
在其他平台上,运行:
```bash
python whisper_server.py
```
### 配置 Claude Desktop
1. 打开 Claude Desktop 配置文件:
- Windows: `%APPDATA%\Claude\claude_desktop_config.json`
- macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
2. 添加 Whisper 服务器配置:
```json
{
"mcpServers": {
"whisper": {
"command": "python",
"args": ["D:/path/to/whisper_server.py"],
"env": {}
}
}
}
```
3. 重启 Claude Desktop
### 可用工具
服务器提供以下工具:
1. **get_model_info** - 获取可用的 Whisper 模型信息
2. **transcribe** - 转录单个音频文件
3. **batch_transcribe** - 批量转录文件夹中的音频文件
## 性能优化建议
- 使用 CUDA 加速可显著提高转录速度
- 对于大量短音频,批处理模式效率更高
- 根据 GPU 显存大小自动调整批处理大小
- 对于长音频,使用 VAD 过滤可提高准确性
- 指定正确的语言可提高转录质量
## 本地测试方案
1. 使用 MCP Inspector 进行快速测试:
```bash
mcp dev whisper_server.py
```
2. 使用 Claude Desktop 进行集成测试
3. 使用命令行直接调用(需要安装 mcp[cli]
```bash
mcp run whisper_server.py
```
## 错误处理
服务器实现了以下错误处理机制:
- 音频文件不存在检查
- 模型加载失败处理
- 转录过程异常捕获
- GPU 内存管理
- 批处理参数自适应调整
## 许可证
MIT
## 致谢
本项目在开发过程中得到了以下优秀AI工具和模型的帮助
- [GitHub Copilot](https://github.com/features/copilot) - AI结对编程助手
- [Trae](https://trae.ai/) - 智能AI编码助手
- [Cline](https://cline.ai/) - AI驱动的终端
- [DeepSeek](https://www.deepseek.com/) - 先进的AI模型
- [Claude-3.7-Sonnet](https://www.anthropic.com/claude) - Anthropic强大的AI助手
- [Gemini-2.0-Flash](https://ai.google/gemini/) - Google的多模态AI模型
- [VS Code](https://code.visualstudio.com/) - 强大的代码编辑器
- [Whisper](https://github.com/openai/whisper) - OpenAI的语音识别模型
- [Faster Whisper](https://github.com/guillaumekln/faster-whisper) - 优化的Whisper实现
特别感谢这些出色的工具和背后的团队。
---
# Whisper 语音识别 MCP 服务器cline claude sonnet 3.7 完成所有任务后的说明)
我已经成功创建了基于Python的语音识别MCP服务器该服务器集成了Faster Whisper并支持批处理加速。服务器具有以下特点
## 核心功能
- **集成Faster Whisper**:使用最新版本(>=0.9.0)进行高效语音识别
- **CUDA加速**自动检测并使用CUDA加速默认使用torch==2.6.0+cu126
- **批处理优化**使用BatchedInferencePipeline实现批处理根据GPU显存自动调整批大小
- **模型缓存**:实现模型实例缓存,避免重复加载模型
- **多格式输出**支持VTT字幕和JSON格式的转录结果
- **批量处理**:支持批量转录整个文件夹的音频文件
## 主要工具
服务器提供了三个主要工具:
1. **get_model_info**获取可用的Whisper模型信息和系统配置
2. **transcribe**:转录单个音频文件,支持多种参数配置
3. **batch_transcribe**:批量转录文件夹中的音频文件
## 错误处理机制
- 音频文件存在性验证
- 模型加载异常捕获和日志记录
- 转录过程异常处理
- GPU内存管理和清理
- 批处理参数自适应调整
## 性能优化
- 根据GPU显存大小动态调整批处理大小(4-32)
- 使用VAD(语音活动检测)过滤提高准确性
- 模型实例缓存避免重复加载
- 自动选择最佳设备和计算类型
## 本地测试方案
提供了多种测试方法:
- 使用MCP Inspector进行快速测试`mcp dev whisper_server.py`
- 使用Claude Desktop进行集成测试
- 使用命令行直接调用:`mcp run whisper_server.py`
所有文件已准备就绪,包括:
- whisper_server.py主服务器代码
- requirements.txt依赖项列表
- start_server.batWindows启动脚本
- README.md详细文档
您可以通过运行start_server.bat或直接执行`python whisper_server.py`来启动服务器。

163
README.md
View File

@@ -1,163 +0,0 @@
# Whisper Speech Recognition MCP Server
---
[中文文档](README-CN.md)
---
A high-performance speech recognition MCP server based on Faster Whisper, providing efficient audio transcription capabilities.
## Features
- Integrated with Faster Whisper for efficient speech recognition
- Batch processing acceleration for improved transcription speed
- Automatic CUDA acceleration (if available)
- Support for multiple model sizes (tiny to large-v3)
- Output formats include VTT subtitles, SRT, and JSON
- Support for batch transcription of audio files in a folder
- Model instance caching to avoid repeated loading
- Dynamic batch size adjustment based on GPU memory
## Installation
### Dependencies
- Python 3.10+
- faster-whisper>=0.9.0
- torch==2.6.0+cu126
- torchaudio==2.6.0+cu126
- mcp[cli]>=1.2.0
### Installation Steps
1. Clone or download this repository
2. Create and activate a virtual environment (recommended)
3. Install dependencies:
```bash
pip install -r requirements.txt
```
### PyTorch Installation Guide
Install the appropriate version of PyTorch based on your CUDA version:
- CUDA 12.6:
```bash
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
```
- CUDA 12.1:
```bash
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
```
- CPU version:
```bash
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
```
You can check your CUDA version with `nvcc --version` or `nvidia-smi`.
## Usage
### Starting the Server
On Windows, simply run `start_server.bat`.
On other platforms, run:
```bash
python whisper_server.py
```
### Configuring Claude Desktop
1. Open the Claude Desktop configuration file:
- Windows: `%APPDATA%\Claude\claude_desktop_config.json`
- macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
2. Add the Whisper server configuration:
```json
{
"mcpServers": {
"whisper": {
"command": "python",
"args": ["D:/path/to/whisper_server.py"],
"env": {}
}
}
}
```
3. Restart Claude Desktop
### Available Tools
The server provides the following tools:
1. **get_model_info** - Get information about available Whisper models
2. **transcribe** - Transcribe a single audio file
3. **batch_transcribe** - Batch transcribe audio files in a folder
## Performance Optimization Tips
- Using CUDA acceleration significantly improves transcription speed
- Batch processing mode is more efficient for large numbers of short audio files
- Batch size is automatically adjusted based on GPU memory size
- Using VAD (Voice Activity Detection) filtering improves accuracy for long audio
- Specifying the correct language can improve transcription quality
## Local Testing Methods
1. Use MCP Inspector for quick testing:
```bash
mcp dev whisper_server.py
```
2. Use Claude Desktop for integration testing
3. Use command line direct invocation (requires mcp[cli]):
```bash
mcp run whisper_server.py
```
## Error Handling
The server implements the following error handling mechanisms:
- Audio file existence check
- Model loading failure handling
- Transcription process exception catching
- GPU memory management
- Batch processing parameter adaptive adjustment
## Project Structure
- `whisper_server.py`: Main server code
- `model_manager.py`: Whisper model loading and caching
- `audio_processor.py`: Audio file validation and preprocessing
- `formatters.py`: Output formatting (VTT, SRT, JSON)
- `transcriber.py`: Core transcription logic
- `start_server.bat`: Windows startup script
## License
MIT
## Acknowledgements
This project was developed with the assistance of these amazing AI tools and models:
- [GitHub Copilot](https://github.com/features/copilot) - AI pair programmer
- [Trae](https://trae.ai/) - Agentic AI coding assistant
- [Cline](https://cline.ai/) - AI-powered terminal
- [DeepSeek](https://www.deepseek.com/) - Advanced AI model
- [Claude-3.7-Sonnet](https://www.anthropic.com/claude) - Anthropic's powerful AI assistant
- [Gemini-2.0-Flash](https://ai.google/gemini/) - Google's multimodal AI model
- [VS Code](https://code.visualstudio.com/) - Powerful code editor
- [Whisper](https://github.com/openai/whisper) - OpenAI's speech recognition model
- [Faster Whisper](https://github.com/guillaumekln/faster-whisper) - Optimized Whisper implementation
Special thanks to these incredible tools and the teams behind them.

__init__.py
View File

@@ -1,5 +0,0 @@
"""
Speech recognition MCP service module
"""
__version__ = "0.1.0"

85
api.logs Normal file
View File

@@ -0,0 +1,85 @@
INFO:__main__:======================================================================
INFO:__main__:PERFORMING STARTUP GPU HEALTH CHECK
INFO:__main__:======================================================================
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 1.04s
INFO:__main__:======================================================================
INFO:__main__:STARTUP GPU CHECK SUCCESSFUL
INFO:__main__:GPU Device: NVIDIA GeForce RTX 3060
INFO:__main__:Memory Available: 11.66 GB
INFO:__main__:Test Duration: 1.04s
INFO:__main__:======================================================================
INFO:__main__:Starting Whisper REST API server on 0.0.0.0:8000
INFO: Started server process [69821]
INFO: Waiting for application startup.
INFO:__main__:Starting job queue and health monitor...
INFO:core.job_queue:Starting job queue (max size: 100)
INFO:core.job_queue:Loading jobs from /media/raid/agents/tools/mcp-transcriptor/outputs/jobs
INFO:core.job_queue:Loaded 8 jobs from disk
INFO:core.job_queue:Job queue worker loop started
INFO:core.job_queue:Job queue worker started
INFO:__main__:Job queue started (max_size=100, metadata_dir=/media/raid/agents/tools/mcp-transcriptor/outputs/jobs)
INFO:core.gpu_health:Starting GPU health monitor (interval: 10.0 minutes)
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.37s
INFO:__main__:GPU health monitor started (interval=10 minutes)
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO: 127.0.0.1:48092 - "GET /jobs HTTP/1.1" 200 OK
INFO: 127.0.0.1:60874 - "GET /jobs?status=completed&limit=3 HTTP/1.1" 200 OK
INFO: 127.0.0.1:60876 - "GET /jobs?status=failed&limit=10 HTTP/1.1" 200 OK
INFO:core.job_queue:Running GPU health check before job submission
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
INFO:core.job_queue:GPU health check passed
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 submitted: /tmp/whisper_test_voice_1s.mp3 (queue position: 1)
INFO: 127.0.0.1:58376 - "POST /jobs HTTP/1.1" 200 OK
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 started processing
INFO:core.model_manager:Running GPU health check with auto-reset before model loading
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.54s
INFO:core.model_manager:Loading Whisper model: tiny device: cuda compute type: float16
INFO:core.model_manager:Available GPU memory: 12.52 GB
INFO:core.model_manager:Enabling batch processing acceleration, batch size: 16
INFO:core.transcriber:Starting transcription of file: whisper_test_voice_1s.mp3
INFO:utils.audio_processor:Successfully preprocessed audio: whisper_test_voice_1s.mp3
INFO:core.transcriber:Using batch acceleration for transcription...
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:VAD filter removed 00:00.000 of audio
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.transcriber:Transcription completed, time used: 0.16 seconds, detected language: en, audio length: 1.51 seconds
INFO:core.transcriber:Transcription results saved to: /media/raid/agents/tools/mcp-transcriptor/outputs/whisper_test_voice_1s.txt
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 completed successfully: /media/raid/agents/tools/mcp-transcriptor/outputs/whisper_test_voice_1s.txt
INFO:core.job_queue:Job 6be8e49a-bdc1-4508-af99-280bef033cb0 finished: status=completed, duration=1.1s
INFO: 127.0.0.1:41646 - "GET /jobs/6be8e49a-bdc1-4508-af99-280bef033cb0 HTTP/1.1" 200 OK
INFO: 127.0.0.1:34046 - "GET /jobs/6be8e49a-bdc1-4508-af99-280bef033cb0/result HTTP/1.1" 200 OK
INFO:core.job_queue:Running GPU health check before job submission
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
INFO:core.job_queue:GPU health check passed
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a submitted: /home/uad/agents/tools/mcp-transcriptor/data/test.mp3 (queue position: 1)
INFO: 127.0.0.1:44576 - "POST /jobs HTTP/1.1" 200 OK
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a started processing
INFO:core.model_manager:Running GPU health check with auto-reset before model loading
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.39s
INFO:core.model_manager:Loading Whisper model: large-v3 device: cuda compute type: float16
INFO:core.model_manager:Available GPU memory: 12.52 GB
INFO:core.model_manager:Enabling batch processing acceleration, batch size: 16
INFO:core.transcriber:Starting transcription of file: test.mp3
INFO:utils.audio_processor:Successfully preprocessed audio: test.mp3
INFO:core.transcriber:Using batch acceleration for transcription...
INFO:faster_whisper:Processing audio with duration 00:06.955
INFO:faster_whisper:VAD filter removed 00:00.299 of audio
INFO:core.transcriber:Transcription completed, time used: 0.52 seconds, detected language: en, audio length: 6.95 seconds
INFO:core.transcriber:Transcription results saved to: /media/raid/agents/tools/mcp-transcriptor/outputs/test.txt
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a completed successfully: /media/raid/agents/tools/mcp-transcriptor/outputs/test.txt
INFO:core.job_queue:Job 41ce74c0-8929-457b-96b3-1b8e4a720a7a finished: status=completed, duration=23.3s
INFO: 127.0.0.1:59120 - "GET /jobs/41ce74c0-8929-457b-96b3-1b8e4a720a7a HTTP/1.1" 200 OK
INFO: 127.0.0.1:53806 - "GET /jobs/41ce74c0-8929-457b-96b3-1b8e4a720a7a/result HTTP/1.1" 200 OK

audio_processor.py
View File

@@ -1,67 +0,0 @@
#!/usr/bin/env python3
"""
Audio processing module
Responsible for validating and preprocessing audio files
"""
import os
import logging
from typing import Union, Any
from faster_whisper import decode_audio

# Logging configuration
logger = logging.getLogger(__name__)

def validate_audio_file(audio_path: str) -> str:
    """
    Validate whether an audio file is usable

    Args:
        audio_path: Path to the audio file

    Returns:
        str: Validation result; "ok" means the check passed, otherwise an error message is returned
    """
    # Validate arguments
    if not os.path.exists(audio_path):
        return f"Error: audio file does not exist: {audio_path}"

    # Validate file format
    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
    file_ext = os.path.splitext(audio_path)[1].lower()
    if file_ext not in supported_formats:
        return f"Error: unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}"

    # Validate file size
    try:
        file_size = os.path.getsize(audio_path)
        if file_size == 0:
            return f"Error: audio file is empty: {audio_path}"
        # Warn about large files (over 1GB)
        if file_size > 1024 * 1024 * 1024:
            logger.warning(f"Warning: file is larger than 1GB and may take a long time to process: {audio_path}")
    except Exception as e:
        logger.error(f"Failed to check file size: {str(e)}")
        return f"Error: failed to check file size: {str(e)}"

    return "ok"

def process_audio(audio_path: str) -> Union[str, Any]:
    """
    Process an audio file: decode and preprocess it

    Args:
        audio_path: Path to the audio file

    Returns:
        Union[str, Any]: The preprocessed audio data, or the original file path
    """
    # Try to preprocess the audio with decode_audio to handle more formats
    try:
        audio_data = decode_audio(audio_path)
        logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
        return audio_data
    except Exception as audio_error:
        logger.warning(f"Audio preprocessing failed; the file path will be used directly: {str(audio_error)}")
        return audio_path

1
data Symbolic link
View File

@@ -0,0 +1 @@
/media/raid/agents/tools/mcp-transcriptor

25
mcp.logs Normal file
View File

@@ -0,0 +1,25 @@
starting mcp server for whisper stt transcriptor
INFO:__main__:======================================================================
INFO:__main__:PERFORMING STARTUP GPU HEALTH CHECK
INFO:__main__:======================================================================
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.93s
INFO:__main__:======================================================================
INFO:__main__:STARTUP GPU CHECK SUCCESSFUL
INFO:__main__:GPU Device: NVIDIA GeForce RTX 3060
INFO:__main__:Memory Available: 11.66 GB
INFO:__main__:Test Duration: 0.93s
INFO:__main__:======================================================================
INFO:__main__:Initializing job queue...
INFO:core.job_queue:Starting job queue (max size: 100)
INFO:core.job_queue:Loading jobs from /media/raid/agents/tools/mcp-transcriptor/outputs/jobs
INFO:core.job_queue:Loaded 5 jobs from disk
INFO:core.job_queue:Job queue worker loop started
INFO:core.job_queue:Job queue worker started
INFO:__main__:Job queue started (max_size=100, metadata_dir=/media/raid/agents/tools/mcp-transcriptor/outputs/jobs)
INFO:core.gpu_health:Starting GPU health monitor (interval: 10.0 minutes)
INFO:faster_whisper:Processing audio with duration 00:01.512
INFO:faster_whisper:Detected language 'en' with probability 0.95
INFO:core.gpu_health:GPU health check passed: NVIDIA GeForce RTX 3060, test duration: 0.38s
INFO:__main__:GPU health monitor started (interval=10 minutes)

model_manager.py
View File

@@ -1,176 +0,0 @@
#!/usr/bin/env python3
"""
Model management module
Responsible for loading, caching, and managing Whisper models
"""
import os
import time
import logging
from typing import Dict, Any
import torch
from faster_whisper import WhisperModel, BatchedInferencePipeline

# Logging configuration
logger = logging.getLogger(__name__)

# Global model instance cache
model_instances = {}

def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
    """
    Get or create a Whisper model instance

    Args:
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)

    Returns:
        dict: Dictionary containing the model instance and its configuration
    """
    global model_instances

    # Validate model name
    valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
    if model_name not in valid_models:
        raise ValueError(f"Invalid model name: {model_name}. Valid models: {', '.join(valid_models)}")

    # Auto-detect device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

    # Validate device and compute type
    if device not in ["cpu", "cuda"]:
        raise ValueError(f"Invalid device: {device}. Valid devices: cpu, cuda")
    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA is not available, automatically switching to CPU")
        device = "cpu"
        compute_type = "int8"
    if compute_type not in ["float16", "int8"]:
        raise ValueError(f"Invalid compute type: {compute_type}. Valid compute types: float16, int8")
    if device == "cpu" and compute_type == "float16":
        logger.warning("The CPU device does not support the float16 compute type, automatically switching to int8")
        compute_type = "int8"

    # Build the model cache key
    model_key = f"{model_name}_{device}_{compute_type}"

    # If the model is already instantiated, return it directly
    if model_key in model_instances:
        logger.info(f"Using cached model instance: {model_key}")
        return model_instances[model_key]

    # Clear GPU memory (if using CUDA)
    if device == "cuda":
        torch.cuda.empty_cache()

    # Instantiate the model
    try:
        logger.info(f"Loading Whisper model: {model_name} device: {device} compute type: {compute_type}")
        # Base model
        model = WhisperModel(
            model_name,
            device=device,
            compute_type=compute_type,
            download_root=os.environ.get("WHISPER_MODEL_DIR", None)  # supports a custom model directory
        )
        # Batch processing setup - batching is enabled by default to improve speed
        batched_model = None
        batch_size = 0
        if device == "cuda":  # only use batching on CUDA devices
            # Determine a suitable batch size based on GPU memory
            if torch.cuda.is_available():
                gpu_mem = torch.cuda.get_device_properties(0).total_memory
                free_mem = gpu_mem - torch.cuda.memory_allocated()
                # Dynamically adjust the batch size based on GPU memory
                if free_mem > 16e9:  # >16GB
                    batch_size = 32
                elif free_mem > 12e9:  # >12GB
                    batch_size = 16
                elif free_mem > 8e9:  # >8GB
                    batch_size = 8
                elif free_mem > 4e9:  # >4GB
                    batch_size = 4
                else:  # smaller GPU memory
                    batch_size = 2
                logger.info(f"Available GPU memory: {free_mem / 1e9:.2f} GB")
            else:
                batch_size = 8  # default value
            logger.info(f"Enabling batch processing acceleration, batch size: {batch_size}")
            batched_model = BatchedInferencePipeline(model=model)
        # Build the result object
        result = {
            'model': model,
            'device': device,
            'compute_type': compute_type,
            'batched_model': batched_model,
            'batch_size': batch_size,
            'load_time': time.time()
        }
        # Cache the instance
        model_instances[model_key] = result
        return result
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise

def get_model_info() -> str:
    """
    Get information about the available Whisper models

    Returns:
        str: Model information as a JSON string
    """
    import json
    models = [
        "tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
    ]
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    compute_types = ["float16", "int8"] if torch.cuda.is_available() else ["int8"]
    # Supported languages
    languages = {
        "zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
        "fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
        "pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
        "tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
    }
    # Supported audio formats
    audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
    info = {
        "available_models": models,
        "default_model": "large-v3",
        "available_devices": devices,
        "default_device": "cuda" if torch.cuda.is_available() else "cpu",
        "available_compute_types": compute_types,
        "default_compute_type": "float16" if torch.cuda.is_available() else "int8",
        "cuda_available": torch.cuda.is_available(),
        "supported_languages": languages,
        "supported_audio_formats": audio_formats,
        "version": "0.1.1"
    }
    if torch.cuda.is_available():
        info["gpu_info"] = {
            "name": torch.cuda.get_device_name(0),
            "memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
            "memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
        }
    return json.dumps(info, indent=2)

pyproject.toml
View File

@@ -1,9 +1,12 @@
[project]
name = "fast-whisper-mcp-server"
version = "0.1.0"
description = "Add your description here"
version = "0.1.1"
description = "High-performance speech recognition service with MCP and REST API servers"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"faster-whisper>=1.1.1",
]
[tool.setuptools]
packages = ["src"]

requirements.txt
View File

@@ -1,22 +1,36 @@
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu126
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu124
faster-whisper
torch==2.6.0+cu126
torchaudio==2.6.0+cu126
torch #==2.6.0+cu124
torchaudio #==2.6.0+cu124
# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# pip install faster-whisper>=0.9.0
# pip install mcp[cli]>=1.2.0
mcp[cli]
# PyTorch安装指南:
# 请根据您的CUDA版本安装适当版本的PyTorch:
# REST API dependencies
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
python-multipart>=0.0.9
# Test audio generation dependencies
gTTS>=2.3.0
pyttsx3>=2.90
scipy>=1.10.0
numpy>=1.24.0
# PyTorch Installation Guide:
# Please install the appropriate version of PyTorch based on your CUDA version:
#
# • CUDA 12.6:
# pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
#
# • CUDA 12.4:
# pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
#
# • CUDA 12.1:
# pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
#
# • CPU版本:
# • CPU version:
# pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
#
# 可用命令`nvcc --version`或`nvidia-smi`查看CUDA版本

70
reset_gpu.sh Executable file
View File

@@ -0,0 +1,70 @@
#!/bin/bash
# Script to reset NVIDIA GPU drivers without rebooting
# This reloads kernel modules and restarts nvidia-persistenced service
echo "============================================================"
echo "NVIDIA GPU Driver Reset Script"
echo "============================================================"
echo ""
# Stop nvidia-persistenced service
echo "Stopping nvidia-persistenced service..."
sudo systemctl stop nvidia-persistenced
if [ $? -eq 0 ]; then
echo "✓ nvidia-persistenced stopped"
else
echo "✗ Failed to stop nvidia-persistenced"
exit 1
fi
echo ""
# Unload NVIDIA kernel modules (in correct order)
echo "Unloading NVIDIA kernel modules..."
sudo rmmod nvidia_uvm 2>/dev/null && echo "✓ nvidia_uvm unloaded" || echo " nvidia_uvm not loaded or failed to unload"
sudo rmmod nvidia_drm 2>/dev/null && echo "✓ nvidia_drm unloaded" || echo " nvidia_drm not loaded or failed to unload"
sudo rmmod nvidia_modeset 2>/dev/null && echo "✓ nvidia_modeset unloaded" || echo " nvidia_modeset not loaded or failed to unload"
sudo rmmod nvidia 2>/dev/null && echo "✓ nvidia unloaded" || echo " nvidia not loaded or failed to unload"
echo ""
# Small delay to ensure clean unload
sleep 1
# Reload NVIDIA kernel modules (in correct order)
echo "Loading NVIDIA kernel modules..."
sudo modprobe nvidia && echo "✓ nvidia loaded" || { echo "✗ Failed to load nvidia"; exit 1; }
sudo modprobe nvidia_modeset && echo "✓ nvidia_modeset loaded" || { echo "✗ Failed to load nvidia_modeset"; exit 1; }
sudo modprobe nvidia_uvm && echo "✓ nvidia_uvm loaded" || { echo "✗ Failed to load nvidia_uvm"; exit 1; }
sudo modprobe nvidia_drm && echo "✓ nvidia_drm loaded" || { echo "✗ Failed to load nvidia_drm"; exit 1; }
echo ""
# Restart nvidia-persistenced service
echo "Starting nvidia-persistenced service..."
sudo systemctl start nvidia-persistenced
if [ $? -eq 0 ]; then
echo "✓ nvidia-persistenced started"
else
echo "✗ Failed to start nvidia-persistenced"
exit 1
fi
echo ""
# Verify GPU is accessible
echo "Verifying GPU accessibility..."
if command -v nvidia-smi &> /dev/null; then
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
if [ $? -eq 0 ]; then
echo "✓ GPU reset successful"
else
echo "✗ GPU not accessible"
exit 1
fi
else
echo "✗ nvidia-smi not found"
exit 1
fi
echo ""
echo "============================================================"
echo "GPU driver reset completed successfully"
echo "============================================================"

62
run_api_server.sh Executable file
View File

@@ -0,0 +1,62 @@
#!/bin/bash
set -e
datetime_prefix() {
date "+[%Y-%m-%d %H:%M:%S]"
}
# Set Python path to include src directory
export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"
# Set CUDA library path
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
# Set environment variables
export CUDA_VISIBLE_DEVICES=1
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
export TRANSCRIPTION_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs"
export TRANSCRIPTION_BATCH_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/batch"
export TRANSCRIPTION_MODEL="large-v3"
export TRANSCRIPTION_DEVICE="cuda"
export TRANSCRIPTION_COMPUTE_TYPE="float16"
export TRANSCRIPTION_OUTPUT_FORMAT="txt"
export TRANSCRIPTION_BEAM_SIZE="5"
export TRANSCRIPTION_TEMPERATURE="0.0"
export TRANSCRIPTION_USE_TIMESTAMP="false"
export TRANSCRIPTION_FILENAME_PREFIX=""
# API server configuration
export API_HOST="0.0.0.0"
export API_PORT="8000"
# GPU Auto-Reset Configuration
export GPU_RESET_COOLDOWN_MINUTES=5 # Minimum time between GPU reset attempts
# Job Queue Configuration
export JOB_QUEUE_MAX_SIZE=100
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
export JOB_RETENTION_DAYS=7
# GPU Health Monitoring
export GPU_HEALTH_CHECK_ENABLED=true
export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10
export GPU_HEALTH_TEST_MODEL="tiny"
# Log start of the script
echo "$(datetime_prefix) Starting Whisper REST API server..."
echo "$(datetime_prefix) Model directory: $WHISPER_MODEL_DIR"
echo "$(datetime_prefix) API server: http://$API_HOST:$API_PORT"
# Optional: Verify required directories exist
if [ ! -d "$WHISPER_MODEL_DIR" ]; then
echo "$(datetime_prefix) Warning: Whisper model directory does not exist: $WHISPER_MODEL_DIR"
echo "$(datetime_prefix) Models will be downloaded to default cache directory"
fi
# Ensure output directories exist
mkdir -p "$TRANSCRIPTION_OUTPUT_DIR"
mkdir -p "$TRANSCRIPTION_BATCH_OUTPUT_DIR"
mkdir -p "$JOB_METADATA_DIR"
# Run the API server
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/src/servers/api_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/api.logs

61
run_mcp_server.sh Executable file
View File

@@ -0,0 +1,61 @@
#!/bin/bash
set -e
datetime_prefix() {
date "+[%Y-%m-%d %H:%M:%S]"
}
# Get current user ID to avoid permission issues
USER_ID=$(id -u)
GROUP_ID=$(id -g)
# Set Python path to include src directory
export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"
# Set CUDA library path
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
# Set environment variables
export CUDA_VISIBLE_DEVICES=1
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
export TRANSCRIPTION_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs"
export TRANSCRIPTION_BATCH_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/batch"
export TRANSCRIPTION_MODEL="large-v3"
export TRANSCRIPTION_DEVICE="cuda"
export TRANSCRIPTION_COMPUTE_TYPE="float16"
export TRANSCRIPTION_OUTPUT_FORMAT="txt"
export TRANSCRIPTION_BEAM_SIZE="2"
export TRANSCRIPTION_TEMPERATURE="0.0"
export TRANSCRIPTION_USE_TIMESTAMP="false"
export TRANSCRIPTION_FILENAME_PREFIX="test_"
# GPU Auto-Reset Configuration
export GPU_RESET_COOLDOWN_MINUTES=5 # Minimum time between GPU reset attempts
# Job Queue Configuration
export JOB_QUEUE_MAX_SIZE=100
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
export JOB_RETENTION_DAYS=7
# GPU Health Monitoring
export GPU_HEALTH_CHECK_ENABLED=true
export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10
export GPU_HEALTH_TEST_MODEL="tiny"
# Log start of the script
echo "$(datetime_prefix) Starting whisper server script..."
echo "test: $WHISPER_MODEL_DIR"
# Optional: Verify required directories exist
if [ ! -d "$WHISPER_MODEL_DIR" ]; then
echo "$(datetime_prefix) Error: Whisper model directory does not exist: $WHISPER_MODEL_DIR"
exit 1
fi
# Ensure job metadata directory exists
mkdir -p "$JOB_METADATA_DIR"
# Run the Python script with the defined environment variables
#/home/uad/agents/tools/mcp-transcriptor/venv/bin/python /home/uad/agents/tools/mcp-transcriptor/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/src/servers/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs

9
src/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
"""
Faster Whisper MCP Transcription Service
High-performance audio transcription service with dual-server architecture
(MCP and REST API) featuring async job queue and GPU health monitoring.
"""
__version__ = "0.2.0"
__author__ = "Whisper MCP Team"

src/core/__init__.py Normal file

@@ -0,0 +1,6 @@
"""
Core modules for Whisper transcription service.
Includes model management, transcription logic, job queue, GPU health monitoring,
and GPU reset functionality.
"""

src/core/gpu_health.py Normal file

@@ -0,0 +1,467 @@
"""
GPU health monitoring for Whisper transcription service.
Performs real GPU health checks using actual model loading and transcription,
with strict failure handling to prevent silent CPU fallbacks.
Includes circuit breaker pattern to prevent repeated failed checks.
"""
import time
import logging
import threading
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional, List
import torch
from utils.test_audio_generator import generate_test_audio
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen
logger = logging.getLogger(__name__)
# Global circuit breaker for GPU health checks
_gpu_health_circuit_breaker = CircuitBreaker(
name="gpu_health_check",
failure_threshold=3, # Open after 3 consecutive failures
success_threshold=2, # Close after 2 consecutive successes
timeout_seconds=60, # Try again after 60 seconds
half_open_max_calls=1 # Only 1 test call in half-open state
)
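# Illustrative state flow for the thresholds above (a sketch, not the full
# implementation in utils/circuit_breaker.py):
#   CLOSED    --(3 consecutive failures)-->   OPEN
#   OPEN      --(after the 60s timeout)-->    HALF_OPEN (max 1 trial call)
#   HALF_OPEN --(2 consecutive successes)-->  CLOSED
#   HALF_OPEN --(any failure)-->              OPEN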
# Import reset functionality (after logger initialization)
try:
from core.gpu_reset import (
reset_gpu_drivers,
can_attempt_reset,
record_reset_attempt
)
GPU_RESET_AVAILABLE = True
except ImportError as e:
logger.warning(f"GPU reset functionality not available: {e}")
GPU_RESET_AVAILABLE = False
@dataclass
class GPUHealthStatus:
"""Data class for GPU health check results."""
gpu_available: bool # torch.cuda.is_available()
gpu_working: bool # Model actually loaded and ran on GPU
device_used: str # "cuda" or "cpu"
device_name: str # GPU name if available
memory_total_gb: float # Total GPU memory
memory_available_gb: float # Available GPU memory
test_duration_seconds: float # How long test took
timestamp: str # ISO timestamp
error: Optional[str] = None # Error message if any
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus:
"""
Comprehensive GPU health check using real model + transcription.
Args:
expected_device: Expected device ("auto", "cuda", "cpu")
Returns:
GPUHealthStatus object
Raises:
RuntimeError: If expected_device="cuda" but GPU test fails
Implementation:
1. Generate test audio (1 second)
2. Load tiny model with requested device
3. Transcribe test audio
4. Time the operation
5. Verify model actually ran on GPU (check torch.cuda.memory_allocated)
6. CRITICAL: If expected_device="cuda" but used="cpu" → raise RuntimeError
7. Return detailed status
Performance Expectations:
- GPU (tiny model): 0.3-1.0 seconds
- CPU (tiny model): 3-10 seconds
- If GPU test takes >2 seconds, likely running on CPU
"""
timestamp = datetime.utcnow().isoformat() + "Z"
start_time = time.time()
# Initialize status with defaults
gpu_available = torch.cuda.is_available()
device_name = ""
memory_total_gb = 0.0
memory_available_gb = 0.0
actual_device = "cpu"
gpu_working = False
error_msg = None
# Get GPU info if available
if gpu_available:
try:
device_name = torch.cuda.get_device_name(0)
memory_total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
memory_available_gb = (torch.cuda.get_device_properties(0).total_memory -
torch.cuda.memory_allocated(0)) / (1024**3)
except Exception as e:
logger.warning(f"Failed to get GPU info: {e}")
try:
# Generate test audio
test_audio_path = generate_test_audio(duration_seconds=1.0)
# Import here to avoid circular dependencies
from faster_whisper import WhisperModel
# Determine device to use
test_device = "cpu"
test_compute_type = "int8"
if expected_device == "cuda" or (expected_device == "auto" and gpu_available):
test_device = "cuda"
test_compute_type = "float16"
# Record GPU memory before loading model
gpu_memory_before = 0
if torch.cuda.is_available():
gpu_memory_before = torch.cuda.memory_allocated(0)
# Load tiny model and transcribe
try:
model = WhisperModel(
"tiny",
device=test_device,
compute_type=test_compute_type
)
# Transcribe test audio
segments, info = model.transcribe(test_audio_path, beam_size=1)
# Consume segments (needed to actually run inference)
list(segments)
# Check if GPU was actually used
# faster-whisper uses CTranslate2 which manages GPU memory separately
# from PyTorch, so we check model.model.device instead of torch memory
actual_device = model.model.device
if actual_device == "cuda":
gpu_working = True
else:
gpu_working = False
except Exception as e:
error_msg = f"Model loading/inference failed: {str(e)}"
actual_device = "cpu"
gpu_working = False
logger.error(f"GPU health check failed: {error_msg}")
except Exception as e:
error_msg = f"Health check setup failed: {str(e)}"
logger.error(f"GPU health check error: {error_msg}")
# Calculate test duration
test_duration = time.time() - start_time
# Create status object
status = GPUHealthStatus(
gpu_available=gpu_available,
gpu_working=gpu_working,
device_used=actual_device,
device_name=device_name,
memory_total_gb=round(memory_total_gb, 2),
memory_available_gb=round(memory_available_gb, 2),
test_duration_seconds=round(test_duration, 2),
timestamp=timestamp,
error=error_msg
)
# CRITICAL: Reject if expected CUDA but got CPU
# This service is GPU-only, so also reject device="auto" that falls back to CPU
if expected_device == "cuda" and actual_device == "cpu":
error_message = (
f"GPU device requested but model loaded on CPU. "
f"This indicates GPU driver issues or insufficient memory. "
f"Transcription would be 10-100x slower than expected. "
f"Please check CUDA installation and GPU availability. "
f"Details: {error_msg or 'Unknown cause'}"
)
logger.error(error_message)
raise RuntimeError(error_message)
# CRITICAL: Service requires GPU - reject CPU fallback even with device="auto"
if expected_device == "auto" and actual_device == "cpu":
error_message = (
f"GPU required but not available. This service is configured for GPU-only operation. "
f"CUDA is not available or GPU health check failed. "
f"Please ensure CUDA is properly installed and GPU is accessible. "
f"Details: {error_msg or 'CUDA not available'}"
)
logger.error(error_message)
raise RuntimeError(error_message)
# Log health check result
if gpu_working:
logger.info(
f"GPU health check passed: {device_name}, "
f"test duration: {test_duration:.2f}s"
)
elif gpu_available and not gpu_working:
logger.warning(
f"GPU available but health check failed. "
f"Test duration: {test_duration:.2f}s, Error: {error_msg}"
)
else:
logger.info(f"GPU not available, using CPU. Test duration: {test_duration:.2f}s")
return status
def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus:
"""
GPU health check with optional circuit breaker protection.
This is the main entry point for GPU health checks. By default, it uses
circuit breaker pattern to prevent repeated failed checks.
Args:
expected_device: Expected device ("auto", "cuda", "cpu")
use_circuit_breaker: Enable circuit breaker protection (default: True)
Returns:
GPUHealthStatus object
Raises:
RuntimeError: If expected_device="cuda" but GPU test fails
CircuitBreakerOpen: If circuit breaker is open (too many recent failures)
"""
if use_circuit_breaker:
try:
return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device)
except CircuitBreakerOpen as e:
# Circuit is open, fail fast without attempting check
logger.warning(f"GPU health check circuit breaker is OPEN: {e}")
raise RuntimeError(f"GPU health check unavailable: {str(e)}")
else:
return _check_gpu_health_internal(expected_device)
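# Example usage (illustrative sketch only; in this codebase callers typically go
# through check_gpu_health_with_reset() below or the HealthMonitor thread):
#
#     try:
#         status = check_gpu_health(expected_device="cuda")
#         print(status.to_dict())
#     except RuntimeError as e:
#         # Raised when CUDA is required but the model ran on CPU, or when the
#         # circuit breaker is open. handle_gpu_outage() is a hypothetical handler.
#         handle_gpu_outage(e)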
def get_circuit_breaker_stats() -> dict:
"""
Get current circuit breaker statistics.
Returns:
Dictionary with circuit state and failure/success counts
"""
return _gpu_health_circuit_breaker.get_stats()
def reset_circuit_breaker():
"""
Manually reset the GPU health check circuit breaker.
Useful for:
- Testing
- Manual intervention after fixing GPU issues
- Clearing persistent error state
"""
_gpu_health_circuit_breaker.reset()
logger.info("GPU health check circuit breaker manually reset")
def check_gpu_health_with_reset(
expected_device: str = "cuda",
auto_reset: bool = True
) -> GPUHealthStatus:
"""
Check GPU health with automatic reset on failure.
This function wraps check_gpu_health() and adds automatic GPU driver
reset capability with cooldown protection to prevent reset loops.
Behavior:
1. Attempt GPU health check
2. If fails and auto_reset=True:
a. Check cooldown (prevents reset loops)
b. If cooldown active → raise error immediately
c. If cooldown OK → reset drivers → wait 3s → retry
3. If retry fails → raise error (terminate service)
Args:
expected_device: Expected device ("auto", "cuda", "cpu")
auto_reset: Enable automatic reset on failure (default: True)
Returns:
GPUHealthStatus object
Raises:
RuntimeError: If GPU check fails and reset not available/allowed,
or if retry after reset fails
"""
try:
# First attempt
return check_gpu_health(expected_device)
except Exception as first_error:
# GPU health check failed
if not auto_reset:
logger.error(f"GPU health check failed (auto-reset disabled): {first_error}")
raise
if not GPU_RESET_AVAILABLE:
logger.error(
f"GPU health check failed but reset not available: {first_error}"
)
raise RuntimeError(
f"GPU unavailable and reset functionality not available: {first_error}"
)
# Check if reset is allowed (cooldown protection)
if not can_attempt_reset():
error_msg = (
f"GPU health check failed but reset cooldown is active. "
f"This prevents reset loops. Last error: {first_error}. "
f"Service will terminate. If this happens after sleep/wake, "
f"the cooldown should expire soon and next restart will work."
)
logger.error(error_msg)
raise RuntimeError(error_msg)
# Attempt GPU reset
logger.warning("=" * 70)
logger.warning(f"GPU HEALTH CHECK FAILED: {first_error}")
logger.warning("Attempting automatic GPU driver reset...")
logger.warning("=" * 70)
try:
reset_gpu_drivers()
record_reset_attempt()
logger.info("GPU reset completed, waiting 3 seconds for stabilization...")
time.sleep(3)
# Retry GPU health check
logger.info("Retrying GPU health check after reset...")
status = check_gpu_health(expected_device)
logger.warning("=" * 70)
logger.warning("GPU HEALTH CHECK SUCCESS AFTER RESET")
logger.warning(f"Device: {status.device_name}")
logger.warning(f"Memory: {status.memory_available_gb:.2f} GB available")
logger.warning("=" * 70)
return status
except Exception as reset_error:
error_msg = (
f"GPU health check failed after reset attempt. "
f"Original error: {first_error}. "
f"Reset error: {reset_error}. "
f"Service will terminate."
)
logger.error(error_msg)
raise RuntimeError(error_msg)
class HealthMonitor:
"""Background thread for periodic GPU health monitoring."""
def __init__(self, check_interval_minutes: int = 10):
"""
Initialize health monitor.
Args:
check_interval_minutes: Interval between health checks (default: 10)
"""
self._check_interval_seconds = check_interval_minutes * 60
self._thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._latest_status: Optional[GPUHealthStatus] = None
self._history: List[GPUHealthStatus] = []
self._max_history = 100
self._lock = threading.Lock()
def start(self):
"""Start background monitoring thread."""
if self._thread is not None and self._thread.is_alive():
logger.warning("Health monitor already running")
return
logger.info(
f"Starting GPU health monitor "
f"(interval: {self._check_interval_seconds / 60:.1f} minutes)"
)
# Run initial health check
try:
status = check_gpu_health(expected_device="auto")
with self._lock:
self._latest_status = status
self._history.append(status)
except Exception as e:
logger.error(f"Initial GPU health check failed: {e}")
# Start background thread
self._stop_event.clear()
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._thread.start()
def stop(self):
"""Stop background monitoring thread."""
if self._thread is None:
return
logger.info("Stopping GPU health monitor")
self._stop_event.set()
self._thread.join(timeout=5.0)
self._thread = None
def get_latest_status(self) -> Optional[GPUHealthStatus]:
"""Get most recent health check result."""
with self._lock:
return self._latest_status
def get_health_history(self, limit: int = 10) -> List[GPUHealthStatus]:
"""
Get recent health check history.
Args:
limit: Maximum number of results to return
Returns:
List of GPUHealthStatus objects (most recent first)
"""
with self._lock:
return list(reversed(self._history[-limit:]))
def _monitor_loop(self):
"""Background monitoring loop."""
while not self._stop_event.wait(timeout=self._check_interval_seconds):
try:
logger.debug("Running periodic GPU health check")
status = check_gpu_health(expected_device="auto")
with self._lock:
self._latest_status = status
self._history.append(status)
# Trim history if too long
if len(self._history) > self._max_history:
self._history = self._history[-self._max_history:]
# Log warning if performance degraded
if status.gpu_available and not status.gpu_working:
logger.warning(
f"GPU health degraded: {status.error}. "
f"Test duration: {status.test_duration_seconds}s"
)
elif status.gpu_working and status.test_duration_seconds > 2.0:
logger.warning(
f"GPU health check slow: {status.test_duration_seconds}s "
f"(expected <1s). May indicate performance issues."
)
except Exception as e:
logger.error(f"Periodic GPU health check failed: {e}")

src/core/gpu_reset.py Normal file

@@ -0,0 +1,220 @@
"""
GPU driver reset functionality with cooldown protection.
Provides automatic GPU driver reset on CUDA errors with safeguards
to prevent reset loops while allowing recovery from sleep/wake cycles.
"""
import os
import time
import logging
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
# Cooldown file location
RESET_TIMESTAMP_FILE = "/tmp/whisper-gpu-last-reset"
# Default cooldown period (minutes)
DEFAULT_COOLDOWN_MINUTES = 5
def get_cooldown_minutes() -> int:
"""
Get cooldown period from environment variable.
Returns:
Cooldown period in minutes (default: 5)
"""
try:
return int(os.getenv("GPU_RESET_COOLDOWN_MINUTES", str(DEFAULT_COOLDOWN_MINUTES)))
except ValueError:
logger.warning(
f"Invalid GPU_RESET_COOLDOWN_MINUTES value, using default: {DEFAULT_COOLDOWN_MINUTES}"
)
return DEFAULT_COOLDOWN_MINUTES
def get_last_reset_time() -> Optional[datetime]:
"""
Read timestamp of last GPU reset attempt.
Returns:
datetime object of last reset, or None if no previous reset
"""
try:
if not os.path.exists(RESET_TIMESTAMP_FILE):
return None
with open(RESET_TIMESTAMP_FILE, 'r') as f:
timestamp_str = f.read().strip()
return datetime.fromisoformat(timestamp_str)
except Exception as e:
logger.warning(f"Failed to read last reset timestamp: {e}")
return None
def record_reset_attempt() -> None:
"""
Record current time as last GPU reset attempt.
Creates/updates timestamp file with current UTC time.
"""
try:
timestamp = datetime.utcnow().isoformat()
with open(RESET_TIMESTAMP_FILE, 'w') as f:
f.write(timestamp)
logger.info(f"Recorded GPU reset timestamp: {timestamp}")
except Exception as e:
logger.error(f"Failed to record reset timestamp: {e}")
def can_attempt_reset() -> bool:
"""
Check if GPU reset can be attempted based on cooldown period.
Returns:
True if reset is allowed (no recent reset or cooldown expired),
False if cooldown is still active
"""
last_reset = get_last_reset_time()
if last_reset is None:
# No previous reset recorded
logger.debug("No previous GPU reset found, reset allowed")
return True
cooldown_minutes = get_cooldown_minutes()
cooldown_period = timedelta(minutes=cooldown_minutes)
time_since_reset = datetime.utcnow() - last_reset
if time_since_reset < cooldown_period:
remaining = cooldown_period - time_since_reset
logger.warning(
f"GPU reset cooldown active. "
f"Last reset: {last_reset.isoformat()}, "
f"Cooldown: {cooldown_minutes} min, "
f"Remaining: {remaining.total_seconds():.0f}s"
)
return False
logger.info(
f"GPU reset cooldown expired. "
f"Last reset: {last_reset.isoformat()}, "
f"Time since: {time_since_reset.total_seconds():.0f}s"
)
return True
def reset_gpu_drivers() -> None:
"""
Execute GPU driver reset using reset_gpu.sh script.
This function:
1. Locates reset_gpu.sh in project root
2. Executes it with sudo
3. Waits for completion
4. Raises exception if reset fails
Raises:
RuntimeError: If reset script fails or is not found
PermissionError: If sudo permissions not configured
Note:
Requires passwordless sudo for nvidia commands.
See CLAUDE.md for setup instructions.
"""
logger.warning("=" * 60)
logger.warning("INITIATING GPU DRIVER RESET")
logger.warning("=" * 60)
# Find reset_gpu.sh script
script_path = Path(__file__).parent.parent.parent / "reset_gpu.sh"
if not script_path.exists():
error_msg = f"GPU reset script not found: {script_path}"
logger.error(error_msg)
raise RuntimeError(error_msg)
if not os.access(script_path, os.X_OK):
error_msg = f"GPU reset script not executable: {script_path}"
logger.error(error_msg)
raise RuntimeError(error_msg)
logger.info(f"Executing GPU reset script: {script_path}")
logger.warning("This will temporarily interrupt all GPU operations")
try:
# Execute reset script with sudo
result = subprocess.run(
['sudo', str(script_path)],
capture_output=True,
text=True,
timeout=30 # 30 second timeout
)
# Log script output
if result.stdout:
logger.info(f"Reset script output:\n{result.stdout}")
if result.stderr:
logger.warning(f"Reset script stderr:\n{result.stderr}")
# Check exit code
if result.returncode != 0:
error_msg = (
f"GPU reset script failed with exit code {result.returncode}. "
f"This may indicate:\n"
f" 1. Sudo permissions not configured (see CLAUDE.md)\n"
f" 2. GPU hardware issue\n"
f" 3. Driver installation problem\n"
f"Output: {result.stderr or result.stdout}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
logger.warning("=" * 60)
logger.warning("GPU DRIVER RESET COMPLETED SUCCESSFULLY")
logger.warning("=" * 60)
except subprocess.TimeoutExpired:
error_msg = "GPU reset script timed out after 30 seconds"
logger.error(error_msg)
raise RuntimeError(error_msg)
except FileNotFoundError:
error_msg = (
"sudo command not found. This service requires sudo access "
"for GPU driver reset. Please ensure sudo is installed."
)
logger.error(error_msg)
raise RuntimeError(error_msg)
except Exception as e:
error_msg = f"Unexpected error during GPU reset: {e}"
logger.error(error_msg)
raise RuntimeError(error_msg)
def clear_reset_cooldown() -> None:
"""
Clear reset cooldown by removing timestamp file.
Useful for testing or manual intervention.
"""
try:
if os.path.exists(RESET_TIMESTAMP_FILE):
os.remove(RESET_TIMESTAMP_FILE)
logger.info("GPU reset cooldown cleared")
else:
logger.debug("No cooldown to clear")
except Exception as e:
logger.error(f"Failed to clear reset cooldown: {e}")

src/core/job_queue.py Normal file

@@ -0,0 +1,545 @@
"""
Job queue manager for asynchronous transcription processing.
Provides FIFO job queue with background worker thread and disk persistence.
"""
import os
import json
import time
import queue
import logging
import threading
import uuid
from enum import Enum
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional, List, Dict
from core.gpu_health import check_gpu_health_with_reset
from core.transcriber import transcribe_audio
from utils.audio_processor import validate_audio_file
logger = logging.getLogger(__name__)
class JobStatus(Enum):
"""Job status enumeration."""
QUEUED = "queued" # In queue, waiting
RUNNING = "running" # Currently processing
COMPLETED = "completed" # Successfully finished
FAILED = "failed" # Error occurred
@dataclass
class Job:
"""Represents a transcription job."""
job_id: str
status: JobStatus
created_at: datetime
started_at: Optional[datetime]
completed_at: Optional[datetime]
queue_position: int
# Request parameters
audio_path: str
model_name: str
device: str
compute_type: str
language: Optional[str]
output_format: str
beam_size: int
temperature: float
initial_prompt: Optional[str]
output_directory: Optional[str]
# Results
result_path: Optional[str]
error: Optional[str]
processing_time_seconds: Optional[float]
def to_dict(self) -> dict:
"""Serialize to dictionary for JSON storage."""
return {
"job_id": self.job_id,
"status": self.status.value,
"created_at": self.created_at.isoformat() if self.created_at else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"queue_position": self.queue_position,
"request_params": {
"audio_path": self.audio_path,
"model_name": self.model_name,
"device": self.device,
"compute_type": self.compute_type,
"language": self.language,
"output_format": self.output_format,
"beam_size": self.beam_size,
"temperature": self.temperature,
"initial_prompt": self.initial_prompt,
"output_directory": self.output_directory,
},
"result_path": self.result_path,
"error": self.error,
"processing_time_seconds": self.processing_time_seconds,
}
@classmethod
def from_dict(cls, data: dict) -> 'Job':
"""Deserialize from dictionary."""
params = data.get("request_params", {})
return cls(
job_id=data["job_id"],
status=JobStatus(data["status"]),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.utcnow(),
started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None,
completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None,
queue_position=data.get("queue_position", 0),
audio_path=params["audio_path"],
model_name=params.get("model_name", "large-v3"),
device=params.get("device", "auto"),
compute_type=params.get("compute_type", "auto"),
language=params.get("language"),
output_format=params.get("output_format", "txt"),
beam_size=params.get("beam_size", 5),
temperature=params.get("temperature", 0.0),
initial_prompt=params.get("initial_prompt"),
output_directory=params.get("output_directory"),
result_path=data.get("result_path"),
error=data.get("error"),
processing_time_seconds=data.get("processing_time_seconds"),
)
def save_to_disk(self, metadata_dir: str):
"""Save job metadata to {metadata_dir}/{job_id}.json"""
os.makedirs(metadata_dir, exist_ok=True)
filepath = os.path.join(metadata_dir, f"{self.job_id}.json")
try:
with open(filepath, 'w') as f:
json.dump(self.to_dict(), f, indent=2)
except Exception as e:
logger.error(f"Failed to save job {self.job_id} to disk: {e}")
class JobQueue:
"""Manages job queue with background worker."""
def __init__(self,
max_queue_size: int = 100,
metadata_dir: str = "/outputs/jobs"):
"""
Initialize job queue.
Args:
max_queue_size: Maximum number of jobs in queue
metadata_dir: Directory to store job metadata JSON files
"""
self._queue = queue.Queue(maxsize=max_queue_size)
self._jobs: Dict[str, Job] = {}
self._metadata_dir = metadata_dir
self._worker_thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._current_job_id: Optional[str] = None
self._lock = threading.RLock() # Use RLock to allow re-entrant locking
self._max_queue_size = max_queue_size
# Create metadata directory
os.makedirs(metadata_dir, exist_ok=True)
def start(self):
"""
Start background worker thread.
Load existing jobs from disk on startup.
"""
if self._worker_thread is not None and self._worker_thread.is_alive():
logger.warning("Job queue worker already running")
return
logger.info(f"Starting job queue (max size: {self._max_queue_size})")
# Load existing jobs from disk
self._load_jobs_from_disk()
# Start background worker thread
self._stop_event.clear()
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
self._worker_thread.start()
logger.info("Job queue worker started")
def stop(self, wait_for_current: bool = True):
"""
Stop background worker.
Args:
wait_for_current: If True, wait for current job to complete
"""
if self._worker_thread is None:
return
logger.info(f"Stopping job queue (wait_for_current={wait_for_current})")
self._stop_event.set()
if wait_for_current:
self._worker_thread.join(timeout=30.0)
else:
self._worker_thread.join(timeout=1.0)
self._worker_thread = None
logger.info("Job queue worker stopped")
def submit_job(self,
audio_path: str,
model_name: str = "large-v3",
device: str = "auto",
compute_type: str = "auto",
language: Optional[str] = None,
output_format: str = "txt",
beam_size: int = 5,
temperature: float = 0.0,
initial_prompt: Optional[str] = None,
output_directory: Optional[str] = None) -> dict:
"""
Submit a new transcription job.
Returns:
dict: {
"job_id": str,
"status": str,
"queue_position": int,
"created_at": str
}
Raises:
queue.Full: If queue is at max capacity
RuntimeError: If GPU health check fails (when device="cuda")
FileNotFoundError: If audio file doesn't exist
"""
# 1. Validate audio file exists
try:
validate_audio_file(audio_path)
except Exception as e:
raise FileNotFoundError(f"Audio file validation failed: {e}")
# 2. Check GPU health (GPU required for all devices since this is GPU-only service)
# Both device="cuda" and device="auto" require GPU
if device == "cuda" or device == "auto":
try:
logger.info("Running GPU health check before job submission")
# Use expected_device to match what user requested
expected = "cuda" if device == "cuda" else "auto"
health_status = check_gpu_health_with_reset(expected_device=expected, auto_reset=True)
if not health_status.gpu_working:
raise RuntimeError(
f"GPU device required but not available. "
f"GPU check failed: {health_status.error}. "
f"This service is configured for GPU-only operation."
)
logger.info("GPU health check passed")
except RuntimeError as e:
# Re-raise GPU health errors
logger.error(f"Job rejected due to GPU health check failure: {e}")
raise RuntimeError(f"Job rejected: {e}")
elif device == "cpu":
# Reject CPU device explicitly
raise ValueError(
"CPU device requested but this service is configured for GPU-only operation. "
"Please use device='cuda' or device='auto' with a GPU available."
)
# 3. Generate job_id
job_id = str(uuid.uuid4())
# 4. Create Job object
job = Job(
job_id=job_id,
status=JobStatus.QUEUED,
created_at=datetime.utcnow(),
started_at=None,
completed_at=None,
queue_position=0, # Will be calculated
audio_path=audio_path,
model_name=model_name,
device=device,
compute_type=compute_type,
language=language,
output_format=output_format,
beam_size=beam_size,
temperature=temperature,
initial_prompt=initial_prompt,
output_directory=output_directory,
result_path=None,
error=None,
processing_time_seconds=None,
)
# 5. Add to queue (raises queue.Full if full)
try:
self._queue.put_nowait(job)
except queue.Full:
raise queue.Full(
f"Job queue is full (max size: {self._max_queue_size}). "
f"Please try again later."
)
# 6. Add to jobs dict and save to disk
with self._lock:
self._jobs[job_id] = job
self._calculate_queue_positions()
job.save_to_disk(self._metadata_dir)
logger.info(
f"Job {job_id} submitted: {audio_path} "
f"(queue position: {job.queue_position})"
)
# 7. Return job info
return {
"job_id": job_id,
"status": job.status.value,
"queue_position": job.queue_position,
"created_at": job.created_at.isoformat() + "Z"
}
def get_job_status(self, job_id: str) -> dict:
"""
Get current status of a job.
Returns:
dict: Job status information
Raises:
KeyError: If job_id not found
"""
with self._lock:
if job_id not in self._jobs:
raise KeyError(f"Job {job_id} not found")
job = self._jobs[job_id]
return {
"job_id": job.job_id,
"status": job.status.value,
"queue_position": job.queue_position if job.status == JobStatus.QUEUED else None,
"created_at": job.created_at.isoformat() + "Z",
"started_at": job.started_at.isoformat() + "Z" if job.started_at else None,
"completed_at": job.completed_at.isoformat() + "Z" if job.completed_at else None,
"result_path": job.result_path,
"error": job.error,
"processing_time_seconds": job.processing_time_seconds,
}
def get_job_result(self, job_id: str) -> str:
"""
Get transcription result text for completed job.
Returns:
str: Content of transcription file
Raises:
KeyError: If job_id not found
ValueError: If job not completed
FileNotFoundError: If result file missing
"""
with self._lock:
if job_id not in self._jobs:
raise KeyError(f"Job {job_id} not found")
job = self._jobs[job_id]
if job.status != JobStatus.COMPLETED:
raise ValueError(
f"Job {job_id} is not completed yet. "
f"Current status: {job.status.value}"
)
if not job.result_path:
raise FileNotFoundError(f"Job {job_id} has no result path")
# Read result file (outside lock to avoid blocking)
if not os.path.exists(job.result_path):
raise FileNotFoundError(
f"Result file not found: {job.result_path}"
)
with open(job.result_path, 'r', encoding='utf-8') as f:
return f.read()
def list_jobs(self,
status_filter: Optional[JobStatus] = None,
limit: int = 100) -> List[dict]:
"""
List jobs with optional status filter.
Args:
status_filter: Only return jobs with this status
limit: Maximum number of jobs to return
Returns:
List of job status dictionaries
"""
with self._lock:
jobs = list(self._jobs.values())
# Filter by status
if status_filter:
jobs = [j for j in jobs if j.status == status_filter]
# Sort by created_at (newest first)
jobs.sort(key=lambda j: j.created_at, reverse=True)
# Limit results
jobs = jobs[:limit]
# Convert to dict
return [self.get_job_status(j.job_id) for j in jobs]
def _worker_loop(self):
"""
Background worker thread function.
Processes jobs from queue in FIFO order.
"""
logger.info("Job queue worker loop started")
while not self._stop_event.is_set():
try:
# Get job from queue (with timeout to check stop_event)
try:
job = self._queue.get(timeout=1.0)
except queue.Empty:
continue
# Update job status to running
with self._lock:
self._current_job_id = job.job_id
job.status = JobStatus.RUNNING
job.started_at = datetime.utcnow()
job.queue_position = 0
self._calculate_queue_positions()
job.save_to_disk(self._metadata_dir)
logger.info(f"Job {job.job_id} started processing")
# Process job
start_time = time.time()
try:
result = transcribe_audio(
audio_path=job.audio_path,
model_name=job.model_name,
device=job.device,
compute_type=job.compute_type,
language=job.language,
output_format=job.output_format,
beam_size=job.beam_size,
temperature=job.temperature,
initial_prompt=job.initial_prompt,
output_directory=job.output_directory
)
# Parse result
if "saved to:" in result:
job.result_path = result.split("saved to:")[1].strip()
job.status = JobStatus.COMPLETED
logger.info(
f"Job {job.job_id} completed successfully: {job.result_path}"
)
else:
job.status = JobStatus.FAILED
job.error = result
logger.error(f"Job {job.job_id} failed: {result}")
except Exception as e:
job.status = JobStatus.FAILED
job.error = str(e)
logger.error(f"Job {job.job_id} failed with exception: {e}")
finally:
# Update job completion info
job.completed_at = datetime.utcnow()
job.processing_time_seconds = time.time() - start_time
job.save_to_disk(self._metadata_dir)
with self._lock:
self._current_job_id = None
self._calculate_queue_positions()
self._queue.task_done()
logger.info(
f"Job {job.job_id} finished: "
f"status={job.status.value}, "
f"duration={job.processing_time_seconds:.1f}s"
)
except Exception as e:
logger.error(f"Unexpected error in worker loop: {e}")
logger.info("Job queue worker loop stopped")
def _load_jobs_from_disk(self):
"""Load existing job metadata from disk on startup."""
if not os.path.exists(self._metadata_dir):
logger.info("No existing job metadata directory found")
return
logger.info(f"Loading jobs from {self._metadata_dir}")
loaded_count = 0
for filename in os.listdir(self._metadata_dir):
if not filename.endswith(".json"):
continue
filepath = os.path.join(self._metadata_dir, filename)
try:
with open(filepath, 'r') as f:
data = json.load(f)
job = Job.from_dict(data)
# Handle jobs that were running when server stopped
if job.status == JobStatus.RUNNING:
job.status = JobStatus.FAILED
job.error = "Server restarted while job was running"
job.completed_at = datetime.utcnow()
job.save_to_disk(self._metadata_dir)
logger.warning(
f"Job {job.job_id} was running during shutdown, "
f"marking as failed"
)
# Re-queue queued jobs
elif job.status == JobStatus.QUEUED:
try:
self._queue.put_nowait(job)
logger.info(f"Re-queued job {job.job_id} from disk")
except queue.Full:
job.status = JobStatus.FAILED
job.error = "Queue full on server restart"
job.completed_at = datetime.utcnow()
job.save_to_disk(self._metadata_dir)
logger.warning(
f"Job {job.job_id} could not be re-queued (queue full)"
)
self._jobs[job.job_id] = job
loaded_count += 1
except Exception as e:
logger.error(f"Failed to load job from {filepath}: {e}")
logger.info(f"Loaded {loaded_count} jobs from disk")
self._calculate_queue_positions()
def _calculate_queue_positions(self):
"""Update queue_position for all queued jobs."""
queued_jobs = [
j for j in self._jobs.values()
if j.status == JobStatus.QUEUED
]
# Sort by created_at (FIFO)
queued_jobs.sort(key=lambda j: j.created_at)
# Update positions
for i, job in enumerate(queued_jobs, start=1):
job.queue_position = i
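# Example usage (illustrative sketch; the servers wire this up through
# utils/startup.py with env-driven configuration, and the paths below are placeholders):
#
#     jq = JobQueue(max_queue_size=100, metadata_dir="/tmp/jobs")
#     jq.start()
#     info = jq.submit_job(audio_path="/data/test.mp3", model_name="tiny")
#     status = jq.get_job_status(info["job_id"])   # poll until status == "completed"
#     text = jq.get_job_result(info["job_id"])     # raises ValueError until completed
#     jq.stop(wait_for_current=True)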

src/core/model_manager.py Normal file

@@ -0,0 +1,246 @@
#!/usr/bin/env python3
"""
Model Management Module
Responsible for loading, caching, and managing Whisper models
"""
import os
import time
import logging
from typing import Dict, Any
import torch
from faster_whisper import WhisperModel, BatchedInferencePipeline
# Log configuration
logger = logging.getLogger(__name__)
# Import GPU health check with reset capability
try:
from core.gpu_health import check_gpu_health_with_reset
GPU_HEALTH_CHECK_AVAILABLE = True
except ImportError:
logger.warning("GPU health check with reset not available")
GPU_HEALTH_CHECK_AVAILABLE = False
# Global model instance cache
model_instances = {}
def test_gpu_driver():
"""Simple GPU driver test"""
try:
if not torch.cuda.is_available():
logger.error("CUDA not available in PyTorch")
raise RuntimeError("CUDA not available")
gpu_count = torch.cuda.device_count()
if gpu_count == 0:
logger.error("No CUDA devices found")
raise RuntimeError("No CUDA devices")
# Quick GPU test
test_tensor = torch.randn(10, 10).cuda()
_ = test_tensor @ test_tensor.T
device_name = torch.cuda.get_device_name(0)
memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
logger.info(f"GPU test passed: {device_name} ({memory_gb:.1f}GB)")
except Exception as e:
logger.error(f"GPU test failed: {e}")
raise RuntimeError(f"GPU initialization failed: {e}")
def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
"""
Get or create Whisper model instance
Args:
model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
device: Running device (cpu, cuda, auto)
compute_type: Computation type (float16, int8, auto)
Returns:
dict: Dictionary containing model instance and configuration
"""
global model_instances
# Validate model name
valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
if model_name not in valid_models:
raise ValueError(f"Invalid model name: {model_name}. Valid models: {', '.join(valid_models)}")
# Auto-detect device - GPU REQUIRED (no CPU fallback)
if device == "auto":
if not torch.cuda.is_available():
error_msg = (
"GPU required but CUDA is not available. "
"This service is configured for GPU-only operation. "
"Please ensure CUDA is properly installed and GPU is accessible."
)
logger.error(error_msg)
raise RuntimeError(error_msg)
device = "cuda"
# Auto-detect compute type
if compute_type == "auto":
compute_type = "float16" if device == "cuda" else "int8"
# Validate device and compute type
if device not in ["cpu", "cuda"]:
raise ValueError(f"Invalid device: {device}. Valid devices: cpu, cuda")
# CRITICAL: Reject CPU device - this service is GPU-only
if device == "cpu":
error_msg = (
"CPU device requested but this service is configured for GPU-only operation. "
"Please use device='cuda' or device='auto' with a GPU available."
)
logger.error(error_msg)
raise ValueError(error_msg)
if device == "cuda" and not torch.cuda.is_available():
logger.error("CUDA requested but not available")
raise RuntimeError("CUDA not available but explicitly requested")
if compute_type not in ["float16", "int8"]:
raise ValueError(f"Invalid compute type: {compute_type}. Valid compute types: float16, int8")
if device == "cpu" and compute_type == "float16":
logger.warning("CPU device does not support float16 computation type, automatically switching to int8")
compute_type = "int8"
# Generate model key
model_key = f"{model_name}_{device}_{compute_type}"
# If model is already instantiated, return directly
if model_key in model_instances:
logger.info(f"Using cached model instance: {model_key}")
return model_instances[model_key]
# Test GPU driver before loading model and clean
if device == "cuda":
# Use GPU health check with reset capability if available
if GPU_HEALTH_CHECK_AVAILABLE:
try:
logger.info("Running GPU health check with auto-reset before model loading")
check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
except Exception as e:
logger.error(f"GPU health check failed: {e}")
raise RuntimeError(f"GPU not available for model loading: {e}")
else:
# Fallback to simple GPU test
test_gpu_driver()
torch.cuda.empty_cache()
# Instantiate model
try:
logger.info(f"Loading Whisper model: {model_name} device: {device} compute type: {compute_type}")
# Base model
model = WhisperModel(
model_name,
device=device,
compute_type=compute_type,
download_root=os.environ.get("WHISPER_MODEL_DIR", None) # Support custom model directory
)
# Batch processing settings - batch processing enabled by default to improve speed
batched_model = None
batch_size = 0
if device == "cuda": # Only use batch processing on CUDA devices
# Determine appropriate batch size based on available memory
if torch.cuda.is_available():
gpu_mem = torch.cuda.get_device_properties(0).total_memory
free_mem = gpu_mem - torch.cuda.memory_allocated()
# Dynamically adjust batch size based on GPU memory
if free_mem > 16e9: # >16GB
batch_size = 32
elif free_mem > 12e9: # >12GB
batch_size = 16
elif free_mem > 8e9: # >8GB
batch_size = 8
elif free_mem > 4e9: # >4GB
batch_size = 4
else: # Smaller memory
batch_size = 2
logger.info(f"Available GPU memory: {free_mem / 1e9:.2f} GB")
else:
batch_size = 8 # Default value
logger.info(f"Enabling batch processing acceleration, batch size: {batch_size}")
batched_model = BatchedInferencePipeline(model=model)
# Create result object
result = {
'model': model,
'device': device,
'compute_type': compute_type,
'batched_model': batched_model,
'batch_size': batch_size,
'load_time': time.time()
}
# Cache instance
model_instances[model_key] = result
return result
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
raise
def get_model_info() -> str:
"""
Get available Whisper model information
Returns:
str: JSON string of model information
"""
import json
models = [
"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
]
# GPU-only service - CUDA required
if not torch.cuda.is_available():
raise RuntimeError(
"GPU required but CUDA is not available. "
"This service is configured for GPU-only operation."
)
devices = ["cuda"] # CPU not supported in GPU-only mode
compute_types = ["float16", "int8"]
# Supported language list
languages = {
"zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
"fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
"pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
"tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
}
# Supported audio formats
audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
info = {
"available_models": models,
"default_model": "large-v3",
"available_devices": devices,
"default_device": "cuda", # GPU-only service
"available_compute_types": compute_types,
"default_compute_type": "float16",
"cuda_available": True, # Required for this service
"gpu_only_mode": True, # Indicate this is GPU-only
"supported_languages": languages,
"supported_audio_formats": audio_formats,
"version": "0.1.1"
}
# GPU info (always present in GPU-only mode)
info["gpu_info"] = {
"name": torch.cuda.get_device_name(0),
"memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
"memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
}
return json.dumps(info, indent=2)

src/core/transcriber.py Normal file

@@ -0,0 +1,361 @@
#!/usr/bin/env python3
"""
Transcription Core Module with Environment Variable Support
Contains core logic for audio transcription
"""
import os
import time
import logging
from typing import Dict, Any, Tuple, List, Optional, Union
from core.model_manager import get_whisper_model
from utils.audio_processor import validate_audio_file, process_audio
from utils.formatters import format_vtt, format_srt, format_json, format_txt, format_time
# Logging configuration
logger = logging.getLogger(__name__)
# Environment variable defaults
DEFAULT_OUTPUT_DIR = os.getenv('TRANSCRIPTION_OUTPUT_DIR')
DEFAULT_BATCH_OUTPUT_DIR = os.getenv('TRANSCRIPTION_BATCH_OUTPUT_DIR')
DEFAULT_MODEL = os.getenv('TRANSCRIPTION_MODEL', 'large-v3')
DEFAULT_DEVICE = os.getenv('TRANSCRIPTION_DEVICE', 'auto')
DEFAULT_COMPUTE_TYPE = os.getenv('TRANSCRIPTION_COMPUTE_TYPE', 'auto')
DEFAULT_LANGUAGE = os.getenv('TRANSCRIPTION_LANGUAGE', None)
DEFAULT_OUTPUT_FORMAT = os.getenv('TRANSCRIPTION_OUTPUT_FORMAT', 'txt')
DEFAULT_BEAM_SIZE = int(os.getenv('TRANSCRIPTION_BEAM_SIZE', '5'))
DEFAULT_TEMPERATURE = float(os.getenv('TRANSCRIPTION_TEMPERATURE', '0.0'))
# Model storage configuration
WHISPER_MODEL_DIR = os.getenv('WHISPER_MODEL_DIR', None)
# File naming configuration
USE_TIMESTAMP = os.getenv('TRANSCRIPTION_USE_TIMESTAMP', 'false').lower() == 'true'
FILENAME_PREFIX = os.getenv('TRANSCRIPTION_FILENAME_PREFIX', '')
FILENAME_SUFFIX = os.getenv('TRANSCRIPTION_FILENAME_SUFFIX', '')
def transcribe_audio(
audio_path: str,
model_name: str = None,
device: str = None,
compute_type: str = None,
language: str = None,
output_format: str = None,
beam_size: int = None,
temperature: float = None,
initial_prompt: str = None,
output_directory: str = None
) -> str:
"""
Transcribe audio file using Faster Whisper with ENV VAR support
Args:
audio_path: Path to audio file
model_name: Model name (defaults to TRANSCRIPTION_MODEL env var or "large-v3")
device: Execution device (defaults to TRANSCRIPTION_DEVICE env var or "auto")
compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var or "auto")
language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var or auto-detect)
output_format: Output format (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var or "txt")
beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var or 5)
temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var or 0.0)
initial_prompt: Initial prompt text
output_directory: Output directory (defaults to TRANSCRIPTION_OUTPUT_DIR env var or audio file directory)
Returns:
str: Transcription result path or error message
"""
# Apply environment variable defaults
model_name = model_name or DEFAULT_MODEL
device = device or DEFAULT_DEVICE
compute_type = compute_type or DEFAULT_COMPUTE_TYPE
language = language or DEFAULT_LANGUAGE
output_format = output_format or DEFAULT_OUTPUT_FORMAT
beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE
# Validate audio file
try:
validate_audio_file(audio_path)
except (FileNotFoundError, ValueError, OSError) as e:
return f"Error: {str(e)}"
try:
# Get model instance
model_instance = get_whisper_model(model_name, device, compute_type)
# Validate language code
supported_languages = {
"zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
"fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
"pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
"tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
}
if language is not None and language not in supported_languages:
logger.warning(f"Unknown language code: {language}, will use auto-detection")
language = None
# Set transcription parameters
options = {
"language": language,
"vad_filter": True,
"vad_parameters": {"min_silence_duration_ms": 500},
"beam_size": beam_size,
"temperature": temperature,
"initial_prompt": initial_prompt,
"word_timestamps": True,
"suppress_tokens": [-1],
"condition_on_previous_text": True,
"compression_ratio_threshold": 2.4,
}
start_time = time.time()
logger.info(f"Starting transcription of file: {os.path.basename(audio_path)}")
# Process audio
audio_source = process_audio(audio_path)
# Execute transcription
if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
logger.info("Using batch acceleration for transcription...")
segments, info = model_instance['batched_model'].transcribe(
audio_source,
batch_size=model_instance['batch_size'],
**options
)
else:
logger.info("Using standard model for transcription...")
segments, info = model_instance['model'].transcribe(audio_source, **options)
# Convert generator to list
segment_list = list(segments)
if not segment_list:
return "Transcription failed, no results obtained"
# Record transcription information
elapsed_time = time.time() - start_time
logger.info(f"Transcription completed, time used: {elapsed_time:.2f} seconds, detected language: {info.language}, audio length: {info.duration:.2f} seconds")
# Format transcription results based on output format
output_format_lower = output_format.lower()
if output_format_lower == "vtt":
transcription_result = format_vtt(segment_list)
elif output_format_lower == "srt":
transcription_result = format_srt(segment_list)
elif output_format_lower == "txt":
transcription_result = format_txt(segment_list)
elif output_format_lower == "json":
transcription_result = format_json(segment_list, info)
else:
raise ValueError(f"Unsupported output format: {output_format}. Supported formats: vtt, srt, txt, json")
# Determine output directory
audio_dir = os.path.dirname(audio_path)
audio_filename = os.path.splitext(os.path.basename(audio_path))[0]
# Priority: parameter > env var > audio directory
if output_directory is not None:
output_dir = output_directory
elif DEFAULT_OUTPUT_DIR is not None:
output_dir = DEFAULT_OUTPUT_DIR
else:
output_dir = audio_dir
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Generate filename with customizable format
filename_parts = []
# Add prefix if specified
if FILENAME_PREFIX:
filename_parts.append(FILENAME_PREFIX)
# Add base filename
filename_parts.append(audio_filename)
# Add suffix if specified
if FILENAME_SUFFIX:
filename_parts.append(FILENAME_SUFFIX)
# Add timestamp if enabled
if USE_TIMESTAMP:
timestamp = time.strftime("%Y%m%d%H%M%S")
filename_parts.append(timestamp)
# Join parts and add extension
base_name = "_".join(filename_parts)
output_filename = f"{base_name}.{output_format_lower}"
output_path = os.path.join(output_dir, output_filename)
# Write transcription results to file
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(transcription_result)
logger.info(f"Transcription results saved to: {output_path}")
return f"Transcription successful, results saved to: {output_path}"
except Exception as e:
logger.error(f"Failed to save transcription results: {str(e)}")
return f"Transcription successful, but failed to save results: {str(e)}"
except Exception as e:
logger.error(f"Transcription failed: {str(e)}")
return f"Error occurred during transcription: {str(e)}"
def batch_transcribe(
audio_folder: str,
output_folder: str = None,
model_name: str = None,
device: str = None,
compute_type: str = None,
language: str = None,
output_format: str = None,
beam_size: int = None,
temperature: float = None,
initial_prompt: str = None,
parallel_files: int = 1
) -> str:
"""
Batch transcribe audio files with ENV VAR support
Args:
audio_folder: Path to folder containing audio files
output_folder: Output folder (defaults to TRANSCRIPTION_BATCH_OUTPUT_DIR env var or "transcript" subfolder)
model_name: Model name (defaults to TRANSCRIPTION_MODEL env var or "large-v3")
device: Execution device (defaults to TRANSCRIPTION_DEVICE env var or "auto")
compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var or "auto")
language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var or auto-detect)
output_format: Output format (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var or "txt")
beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var or 5)
temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var or 0.0)
initial_prompt: Initial prompt text
parallel_files: Number of files to process in parallel (not implemented yet)
Returns:
str: Batch processing summary
"""
# Apply environment variable defaults
model_name = model_name or DEFAULT_MODEL
device = device or DEFAULT_DEVICE
compute_type = compute_type or DEFAULT_COMPUTE_TYPE
language = language or DEFAULT_LANGUAGE
output_format = output_format or DEFAULT_OUTPUT_FORMAT
beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE
if not os.path.isdir(audio_folder):
return f"Error: Folder does not exist: {audio_folder}"
# Determine output folder with environment variable support
if output_folder is not None:
# Use provided parameter
pass
elif DEFAULT_BATCH_OUTPUT_DIR is not None:
# Use environment variable
output_folder = DEFAULT_BATCH_OUTPUT_DIR
else:
# Use default subfolder
output_folder = os.path.join(audio_folder, "transcript")
# Ensure output directory exists
os.makedirs(output_folder, exist_ok=True)
# Validate output format
valid_formats = ["txt", "vtt", "srt", "json"]
if output_format.lower() not in valid_formats:
return f"Error: Unsupported output format: {output_format}. Supported formats: {', '.join(valid_formats)}"
# Get all audio files
audio_files = []
supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
for filename in os.listdir(audio_folder):
file_ext = os.path.splitext(filename)[1].lower()
if file_ext in supported_formats:
audio_files.append(os.path.join(audio_folder, filename))
if not audio_files:
return f"No supported audio files found in {audio_folder}. Supported formats: {', '.join(supported_formats)}"
# Record start time
start_time = time.time()
total_files = len(audio_files)
logger.info(f"Starting batch transcription of {total_files} files, output format: {output_format}")
# Preload model
try:
get_whisper_model(model_name, device, compute_type)
logger.info(f"Model preloaded: {model_name}")
except Exception as e:
logger.error(f"Failed to preload model: {str(e)}")
return f"Batch processing failed: Cannot load model {model_name}: {str(e)}"
# Process files
results = []
success_count = 0
error_count = 0
for i, audio_path in enumerate(audio_files):
file_name = os.path.basename(audio_path)
elapsed = time.time() - start_time
# Report progress
progress_msg = report_progress(i, total_files, elapsed)
logger.info(f"{progress_msg} | Currently processing: {file_name}")
# Execute transcription
try:
result = transcribe_audio(
audio_path=audio_path,
model_name=model_name,
device=device,
compute_type=compute_type,
language=language,
output_format=output_format,
beam_size=beam_size,
temperature=temperature,
initial_prompt=initial_prompt,
output_directory=output_folder
)
if result.startswith("Error:") or result.startswith("Error occurred during transcription:"):
logger.error(f"Transcription failed: {file_name} - {result}")
results.append(f"❌ Failed: {file_name} - {result}")
error_count += 1
else:
output_path = result.split(": ")[1] if ": " in result else "Unknown path"
success_count += 1
results.append(f"✅ Success: {file_name} -> {os.path.basename(output_path)}")
except Exception as e:
logger.error(f"Error occurred during transcription process: {file_name} - {str(e)}")
results.append(f"❌ Failed: {file_name} - {str(e)}")
error_count += 1
# Calculate total transcription time
total_transcription_time = time.time() - start_time
# Generate summary
summary = f"Batch processing completed, total transcription time: {format_time(total_transcription_time)}"
summary += f" | Success: {success_count}/{total_files}"
summary += f" | Failed: {error_count}/{total_files}"
# Output results
for result in results:
logger.info(result)
return summary
def report_progress(current: int, total: int, elapsed_time: float) -> str:
"""Generate progress report"""
progress = current / total * 100
eta = (elapsed_time / current) * (total - current) if current > 0 else 0
return (f"Progress: {current}/{total} ({progress:.1f}%)" +
f" | Time used: {format_time(elapsed_time)}" +
f" | Estimated remaining: {format_time(eta)}")

src/servers/__init__.py Normal file

@@ -0,0 +1,5 @@
"""
Server implementations for Whisper transcription service.
Includes MCP server (whisper_server.py) and REST API server (api_server.py).
"""

src/servers/api_server.py Normal file

@@ -0,0 +1,466 @@
#!/usr/bin/env python3
"""
FastAPI REST API Server for Whisper Transcription
Provides HTTP REST endpoints for audio transcription
"""
import os
import sys
import logging
import queue as queue_module
import tempfile
from pathlib import Path
from contextlib import asynccontextmanager
from typing import Optional, List
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel, Field
import json
from core.model_manager import get_model_info
from core.job_queue import JobQueue, JobStatus
from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
from utils.startup import startup_sequence, cleanup_on_shutdown
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global instances
job_queue: Optional[JobQueue] = None
health_monitor: Optional[HealthMonitor] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""FastAPI lifespan context manager for startup/shutdown"""
global job_queue, health_monitor
# Startup - use common startup logic (without GPU check, handled in main)
logger.info("Initializing job queue and health monitor...")
from utils.startup import initialize_job_queue, initialize_health_monitor
job_queue = initialize_job_queue()
health_monitor = initialize_health_monitor()
yield
# Shutdown - use common cleanup logic
cleanup_on_shutdown(
job_queue=job_queue,
health_monitor=health_monitor,
wait_for_current_job=True
)
# Create FastAPI app
app = FastAPI(
title="Whisper Speech Recognition API",
description="High-performance audio transcription API based on Faster Whisper with async job queue",
version="0.2.0",
lifespan=lifespan
)
# Request/Response Models
class SubmitJobRequest(BaseModel):
audio_path: str = Field(..., description="Path to the audio file on the server")
model_name: str = Field("large-v3", description="Whisper model name")
device: str = Field("auto", description="Execution device (cuda, auto)")
compute_type: str = Field("auto", description="Computation type (float16, int8, auto)")
language: Optional[str] = Field(None, description="Language code (zh, en, ja, etc.)")
output_format: str = Field("txt", description="Output format (vtt, srt, json, txt)")
beam_size: int = Field(5, description="Beam search size")
temperature: float = Field(0.0, description="Sampling temperature")
initial_prompt: Optional[str] = Field(None, description="Initial prompt text")
output_directory: Optional[str] = Field(None, description="Output directory path")
# API Endpoints
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"name": "Whisper Speech Recognition API",
"version": "0.2.0",
"description": "Async job queue-based transcription service",
"endpoints": {
"GET /": "API information",
"GET /health": "Health check",
"GET /health/gpu": "GPU health check",
"GET /health/circuit-breaker": "Get circuit breaker stats",
"POST /health/circuit-breaker/reset": "Reset circuit breaker",
"GET /models": "Get available models information",
"POST /jobs": "Submit transcription job (async)",
"GET /jobs/{job_id}": "Get job status",
"GET /jobs/{job_id}/result": "Get job result",
"GET /jobs": "List jobs with optional filtering"
},
"workflow": {
"1": "Submit job via POST /jobs → receive job_id",
"2": "Poll status via GET /jobs/{job_id} → wait for status='completed'",
"3": "Get result via GET /jobs/{job_id}/result → retrieve transcription"
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "service": "whisper-transcription"}
@app.get("/models")
async def get_models():
"""Get available Whisper models and configuration information"""
try:
model_info = get_model_info()
return JSONResponse(content=json.loads(model_info))
except Exception as e:
logger.error(f"Failed to get model info: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
@app.post("/jobs")
async def submit_job(request: SubmitJobRequest):
"""
Submit a transcription job for async processing.
Returns immediately with job_id. Poll GET /jobs/{job_id} for status.
"""
try:
job_info = job_queue.submit_job(
audio_path=request.audio_path,
model_name=request.model_name,
device=request.device,
compute_type=request.compute_type,
language=request.language,
output_format=request.output_format,
beam_size=request.beam_size,
temperature=request.temperature,
initial_prompt=request.initial_prompt,
output_directory=request.output_directory
)
return JSONResponse(
status_code=200,
content={
**job_info,
"message": f"Job submitted successfully. Poll /jobs/{job_info['job_id']} for status."
}
)
except queue_module.Full:
# Queue is full
logger.warning("Job queue is full, rejecting request")
raise HTTPException(
status_code=503,
detail={
"error": "Queue full",
"message": f"Job queue is full ({job_queue._max_queue_size}/{job_queue._max_queue_size}). Please try again later or contact administrator.",
"queue_size": job_queue._max_queue_size,
"max_queue_size": job_queue._max_queue_size
}
)
except FileNotFoundError as e:
# Invalid audio file
logger.error(f"Invalid audio file: {e}")
raise HTTPException(
status_code=400,
detail={
"error": "Invalid audio file",
"message": str(e),
"audio_path": request.audio_path
}
)
except ValueError as e:
# CPU device rejected
logger.error(f"Invalid device parameter: {e}")
raise HTTPException(
status_code=400,
detail={
"error": "Invalid device",
"message": str(e)
}
)
except RuntimeError as e:
# GPU health check failed
logger.error(f"GPU health check failed: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "GPU unavailable",
"message": f"Job rejected: {str(e)}"
}
)
except Exception as e:
logger.error(f"Failed to submit job: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Internal error",
"message": f"Failed to submit job: {str(e)}"
}
)
@app.get("/jobs/{job_id}")
async def get_job_status_endpoint(job_id: str):
"""
Get the current status of a job.
Returns job status including queue position, timestamps, and result path when completed.
"""
try:
status = job_queue.get_job_status(job_id)
return JSONResponse(status_code=200, content=status)
except KeyError:
raise HTTPException(
status_code=404,
detail={
"error": "Job not found",
"message": f"Job ID '{job_id}' does not exist or has been cleaned up",
"job_id": job_id
}
)
except Exception as e:
logger.error(f"Failed to get job status: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Internal error",
"message": f"Failed to get job status: {str(e)}"
}
)
@app.get("/jobs/{job_id}/result")
async def get_job_result_endpoint(job_id: str):
"""
Get the transcription result for a completed job.
Returns the transcription text. Only works for jobs with status='completed'.
"""
try:
result_text = job_queue.get_job_result(job_id)
return JSONResponse(
status_code=200,
content={"job_id": job_id, "result": result_text}
)
except KeyError:
raise HTTPException(
status_code=404,
detail={
"error": "Job not found",
"message": f"Job ID '{job_id}' does not exist or has been cleaned up",
"job_id": job_id
}
)
except ValueError as e:
# Job not completed yet
# Extract current status from error message
status_match = str(e).split("Current status: ")
current_status = status_match[1] if len(status_match) > 1 else "unknown"
raise HTTPException(
status_code=409,
detail={
"error": "Job not completed",
"message": f"Job is not completed yet. Current status: {current_status}. Please wait and poll again.",
"job_id": job_id,
"current_status": current_status
}
)
except FileNotFoundError as e:
# Result file missing
logger.error(f"Result file not found for job {job_id}: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Result file not found",
"message": str(e),
"job_id": job_id
}
)
except Exception as e:
logger.error(f"Failed to get job result: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Internal error",
"message": f"Failed to get job result: {str(e)}"
}
)
@app.get("/jobs")
async def list_jobs_endpoint(
status: Optional[str] = None,
limit: int = 100
):
"""
List jobs with optional filtering.
Query parameters:
- status: Filter by status (queued, running, completed, failed)
- limit: Maximum number of results (default: 100)
"""
try:
# Parse status filter
status_filter = None
if status:
try:
status_filter = JobStatus(status)
except ValueError:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid status",
"message": f"Invalid status value: {status}. Must be one of: queued, running, completed, failed"
}
)
jobs = job_queue.list_jobs(status_filter=status_filter, limit=limit)
return JSONResponse(
status_code=200,
content={
"jobs": jobs,
"total": len(jobs),
"filters": {
"status": status,
"limit": limit
}
}
)
except Exception as e:
logger.error(f"Failed to list jobs: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Internal error",
"message": f"Failed to list jobs: {str(e)}"
}
)
@app.get("/health/gpu")
async def gpu_health_check_endpoint():
"""
Check GPU health by running a quick transcription test.
Returns detailed GPU status including device name, memory, and test duration.
"""
try:
status = check_gpu_health(expected_device="auto")
# Add interpretation
interpretation = "GPU is healthy and working correctly"
if not status.gpu_available:
interpretation = "GPU not available on this system"
elif not status.gpu_working:
interpretation = f"GPU available but not working correctly: {status.error}"
elif status.test_duration_seconds > 2.0:
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
return JSONResponse(
status_code=200,
content={
**status.to_dict(),
"interpretation": interpretation
}
)
except Exception as e:
logger.error(f"GPU health check failed: {e}")
return JSONResponse(
status_code=200, # Still return 200 but with error details
content={
"gpu_available": False,
"gpu_working": False,
"device_used": "unknown",
"device_name": "",
"memory_total_gb": 0.0,
"memory_available_gb": 0.0,
"test_duration_seconds": 0.0,
"timestamp": "",
"error": str(e),
"interpretation": f"GPU health check failed: {str(e)}"
}
)
@app.get("/health/circuit-breaker")
async def get_circuit_breaker_status():
"""
Get GPU health check circuit breaker statistics.
Returns current state, failure/success counts, and last failure time.
"""
try:
stats = get_circuit_breaker_stats()
return JSONResponse(status_code=200, content=stats)
except Exception as e:
logger.error(f"Failed to get circuit breaker stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/health/circuit-breaker/reset")
async def reset_circuit_breaker_endpoint():
"""
Manually reset the GPU health check circuit breaker.
Useful after fixing GPU issues or for testing purposes.
"""
try:
reset_circuit_breaker()
return JSONResponse(
status_code=200,
content={"message": "Circuit breaker reset successfully"}
)
except Exception as e:
logger.error(f"Failed to reset circuit breaker: {e}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
# Perform startup GPU health check
from utils.startup import perform_startup_gpu_check
perform_startup_gpu_check(
required_device="cuda",
auto_reset=True,
exit_on_failure=True
)
# Get configuration from environment variables
host = os.getenv("API_HOST", "0.0.0.0")
port = int(os.getenv("API_PORT", "8000"))
logger.info(f"Starting Whisper REST API server on {host}:{port}")
uvicorn.run(
app,
host=host,
port=port,
log_level="info"
)
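
For reference, a minimal client-side sketch of the submit, poll, and fetch workflow described by the root endpoint above. It assumes the API is reachable at http://localhost:8000 and that /data/test.mp3 is a valid audio path on the server host; both values are illustrative, not part of this change:

import time
import requests

BASE = "http://localhost:8000"

# 1. Submit the job and keep the returned job_id
submit = requests.post(f"{BASE}/jobs", json={
    "audio_path": "/data/test.mp3",   # path on the server, not the client
    "model_name": "tiny",
    "device": "auto",
    "output_format": "txt",
})
submit.raise_for_status()
job_id = submit.json()["job_id"]

# 2. Poll until the job leaves the queued/running states
while True:
    status = requests.get(f"{BASE}/jobs/{job_id}").json()["status"]
    if status in ("completed", "failed"):
        break
    time.sleep(5)

# 3. Fetch the transcription once completed (the endpoint returns 409 if polled too early)
if status == "completed":
    result = requests.get(f"{BASE}/jobs/{job_id}/result").json()
    print(result["result"])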

334
src/servers/whisper_server.py Normal file
View File

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
"""
Faster Whisper-based Speech Recognition MCP Service
Provides high-performance audio transcription with batch processing acceleration and multiple output formats
"""
import os
import sys
import logging
import json
import base64
import tempfile
from pathlib import Path
from typing import Optional
from mcp.server.fastmcp import FastMCP
from core.model_manager import get_model_info
from core.job_queue import JobQueue, JobStatus
from core.gpu_health import HealthMonitor, check_gpu_health as run_gpu_health_check  # aliased so the MCP tool below does not shadow it
from utils.startup import startup_sequence, cleanup_on_shutdown
# Log configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global instances
job_queue: Optional[JobQueue] = None
health_monitor: Optional[HealthMonitor] = None
# Create FastMCP server instance
mcp = FastMCP(
name="fast-whisper-mcp-server",
version="0.2.0",
dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126", "numpy>=1.20.0"]
)
@mcp.tool()
def get_model_info_api() -> str:
"""
Get available Whisper model information and system configuration.
Returns available models, devices, languages, and GPU information.
"""
return get_model_info()
@mcp.tool()
def transcribe_async(
audio_path: str,
model_name: str = "large-v3",
device: str = "auto",
compute_type: str = "auto",
language: Optional[str] = None,
output_format: str = "txt",
beam_size: int = 5,
temperature: float = 0.0,
initial_prompt: Optional[str] = None,
output_directory: Optional[str] = None
) -> str:
"""
Submit an audio file for asynchronous transcription.
IMPORTANT: This tool returns immediately with a job_id. Use get_job_status()
to check progress and get_job_result() to retrieve the transcription.
WORKFLOW FOR LLM AGENTS:
1. Call this tool to submit the job
2. You will receive a job_id and queue_position
3. Poll get_job_status(job_id) every 5-10 seconds to check progress
4. When status="completed", call get_job_result(job_id) to get transcription
For long audio files (>10 minutes), expect processing to take several minutes.
You can check queue_position to estimate wait time (each job ~2-5 minutes).
Args:
audio_path: Path to audio file on server
model_name: Whisper model (tiny, base, small, medium, large-v3)
device: Execution device (cuda, auto) - cpu is rejected
compute_type: Computation type (float16, int8, auto)
language: Language code (en, zh, ja, etc.) or auto-detect
output_format: Output format (txt, vtt, srt, json)
beam_size: Beam search size (larger=better quality, slower)
temperature: Sampling temperature (0.0=greedy)
initial_prompt: Optional prompt to guide transcription
output_directory: Where to save result (uses default if not specified)
Returns:
JSON string with job_id, status, queue_position, and instructions
"""
try:
job_info = job_queue.submit_job(
audio_path=audio_path,
model_name=model_name,
device=device,
compute_type=compute_type,
language=language,
output_format=output_format,
beam_size=beam_size,
temperature=temperature,
initial_prompt=initial_prompt,
output_directory=output_directory
)
return json.dumps({
**job_info,
"message": f"Job submitted successfully. Poll get_job_status('{job_info['job_id']}') for updates."
}, indent=2)
except Exception as e:
error_type = type(e).__name__
if "Full" in error_type or "queue is full" in str(e).lower():
error_code = "QUEUE_FULL"
message = f"Job queue is full. Please try again in a few minutes. Error: {str(e)}"
elif "FileNotFoundError" in error_type or "not found" in str(e).lower():
error_code = "INVALID_AUDIO_FILE"
message = f"Audio file not found or invalid. Error: {str(e)}"
elif "RuntimeError" in error_type or "GPU" in str(e):
error_code = "GPU_UNAVAILABLE"
message = f"GPU unavailable. Error: {str(e)}"
elif "ValueError" in error_type or "CPU" in str(e):
error_code = "INVALID_DEVICE"
message = f"Invalid device parameter. Error: {str(e)}"
else:
error_code = "INTERNAL_ERROR"
message = f"Failed to submit job. Error: {str(e)}"
return json.dumps({
"error": error_code,
"message": message
}, indent=2)
@mcp.tool()
def get_job_status(job_id: str) -> str:
"""
Check the status of a transcription job.
Status values:
- "queued": Job is waiting in queue. Check queue_position.
- "running": Job is currently being processed.
- "completed": Transcription finished. Call get_job_result() to retrieve.
- "failed": Job failed. Check error field for details.
Args:
job_id: Job ID from transcribe_async()
Returns:
JSON string with detailed job status including:
- status, queue_position, timestamps, error (if any)
"""
try:
status = job_queue.get_job_status(job_id)
return json.dumps(status, indent=2)
except KeyError:
return json.dumps({
"error": "JOB_NOT_FOUND",
"message": f"Job ID '{job_id}' does not exist. Please check the job_id."
}, indent=2)
except Exception as e:
return json.dumps({
"error": "INTERNAL_ERROR",
"message": f"Failed to get job status. Error: {str(e)}"
}, indent=2)
@mcp.tool()
def get_job_result(job_id: str) -> str:
"""
Retrieve the transcription result for a completed job.
IMPORTANT: Only call this when get_job_status() returns status="completed".
If the job is not completed, this will return an error.
Args:
job_id: Job ID from transcribe_async()
Returns:
Transcription text as a string
Errors:
- "Job not found" if invalid job_id
- "Job not completed yet" if status is not "completed"
- "Result file not found" if transcription file is missing
"""
try:
result_text = job_queue.get_job_result(job_id)
return result_text # Return raw text, not JSON
except KeyError:
return json.dumps({
"error": "JOB_NOT_FOUND",
"message": f"Job ID '{job_id}' does not exist."
}, indent=2)
except ValueError as e:
# Extract status from error message
status_match = str(e).split("Current status: ")
current_status = status_match[1] if len(status_match) > 1 else "unknown"
return json.dumps({
"error": "JOB_NOT_COMPLETED",
"message": f"Job is not completed yet. Current status: {current_status}. Please wait and check again.",
"current_status": current_status
}, indent=2)
except FileNotFoundError as e:
return json.dumps({
"error": "RESULT_FILE_NOT_FOUND",
"message": f"Result file not found. Error: {str(e)}"
}, indent=2)
except Exception as e:
return json.dumps({
"error": "INTERNAL_ERROR",
"message": f"Failed to get job result. Error: {str(e)}"
}, indent=2)
@mcp.tool()
def list_transcription_jobs(
status_filter: Optional[str] = None,
limit: int = 20
) -> str:
"""
List transcription jobs with optional filtering.
Useful for:
- Checking all your submitted jobs
- Finding completed jobs
- Monitoring queue status
Args:
status_filter: Filter by status (queued, running, completed, failed)
limit: Maximum number of jobs to return (default: 20)
Returns:
JSON string with list of jobs
"""
try:
# Parse status filter
status_obj = None
if status_filter:
try:
status_obj = JobStatus(status_filter)
except ValueError:
return json.dumps({
"error": "INVALID_STATUS",
"message": f"Invalid status: {status_filter}. Must be one of: queued, running, completed, failed"
}, indent=2)
jobs = job_queue.list_jobs(status_filter=status_obj, limit=limit)
return json.dumps({
"jobs": jobs,
"total": len(jobs),
"filters": {
"status": status_filter,
"limit": limit
}
}, indent=2)
except Exception as e:
return json.dumps({
"error": "INTERNAL_ERROR",
"message": f"Failed to list jobs. Error: {str(e)}"
}, indent=2)
@mcp.tool()
def check_gpu_health() -> str:
"""
Test GPU availability and performance by running a quick transcription.
This tool loads the tiny model and transcribes a 1-second test audio file
to verify the GPU is working correctly.
Use this when:
- You want to verify GPU is available before submitting large jobs
- You suspect GPU performance issues
- For monitoring/debugging purposes
Returns:
JSON string with detailed GPU status including:
- gpu_available, gpu_working, device_name, memory_info
- test_duration_seconds (GPU: <1s, CPU: 5-10s)
- interpretation message
Note: If this returns gpu_working=false, transcriptions will be very slow.
"""
try:
        status = run_gpu_health_check(expected_device="auto")  # call the core helper, not this tool, to avoid recursing into itself
# Add interpretation
interpretation = "GPU is healthy and working correctly"
if not status.gpu_available:
interpretation = "GPU not available on this system"
elif not status.gpu_working:
interpretation = f"GPU available but not working correctly: {status.error}"
elif status.test_duration_seconds > 2.0:
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
result = status.to_dict()
result["interpretation"] = interpretation
return json.dumps(result, indent=2)
except Exception as e:
return json.dumps({
"error": "GPU_CHECK_FAILED",
"message": f"GPU health check failed. Error: {str(e)}",
"gpu_available": False,
"gpu_working": False
}, indent=2)
if __name__ == "__main__":
print("starting mcp server for whisper stt transcriptor")
# Execute common startup sequence
job_queue, health_monitor = startup_sequence(
service_name="MCP Whisper Server",
require_gpu=True,
initialize_queue=True,
initialize_monitoring=True
)
try:
mcp.run()
finally:
# Cleanup on shutdown
cleanup_on_shutdown(
job_queue=job_queue,
health_monitor=health_monitor,
wait_for_current_job=True
)
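
To make the agent workflow from the transcribe_async docstring concrete, here is an illustrative polling loop written as plain Python over the tool functions above. A real MCP client would invoke these tools over the protocol rather than calling them directly; the audio path is hypothetical and the server is assumed to have been initialized via startup_sequence():

import json
import time

submission = json.loads(transcribe_async("/data/meeting.mp3", model_name="large-v3"))
job_id = submission["job_id"]

while True:
    status = json.loads(get_job_status(job_id))
    if status.get("status") in ("completed", "failed") or "error" in status:
        break
    time.sleep(10)  # the docstring above suggests polling every 5-10 seconds

if status.get("status") == "completed":
    transcript = get_job_result(job_id)  # raw transcription text, not JSON
    print(transcript)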

6
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""
Utility modules for Whisper transcription service.
Includes audio processing, formatters, test audio generation, input validation,
circuit breaker, and startup logic.
"""

View File

@@ -0,0 +1,68 @@
#!/usr/bin/env python3
"""
Audio Processing Module
Responsible for audio file validation and preprocessing
"""
import os
import logging
from typing import Union, Any
from pathlib import Path
from faster_whisper import decode_audio
from utils.input_validation import (
validate_audio_file as validate_audio_file_secure,
sanitize_error_message
)
# Log configuration
logger = logging.getLogger(__name__)
def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None:
"""
Validate if an audio file is valid (with security checks).
Args:
audio_path: Path to the audio file
allowed_dirs: Optional list of allowed base directories
Raises:
FileNotFoundError: If audio file doesn't exist
ValueError: If audio file format is unsupported or file is empty
OSError: If file size cannot be checked
Returns:
None: If validation passes
"""
try:
# Use secure validation
validate_audio_file_secure(audio_path, allowed_dirs)
except Exception as e:
# Re-raise with sanitized error messages
error_msg = sanitize_error_message(str(e))
if "not found" in str(e).lower():
raise FileNotFoundError(error_msg)
elif "size" in str(e).lower():
raise OSError(error_msg)
else:
raise ValueError(error_msg)
def process_audio(audio_path: str) -> Union[str, Any]:
"""
Process audio file, perform decoding and preprocessing
Args:
audio_path: Path to the audio file
Returns:
Union[str, Any]: Processed audio data or original file path
"""
# Try to preprocess audio using decode_audio to handle more formats
try:
audio_data = decode_audio(audio_path)
logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
return audio_data
except Exception as audio_error:
logger.warning(f"Audio preprocessing failed, will use file path directly: {str(audio_error)}")
return audio_path
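
A small usage sketch for the two helpers above: validate first, then pass whatever process_audio returns straight to the transcriber. The upload path and allowed directory are illustrative values:

try:
    validate_audio_file("/data/uploads/interview.mp3", allowed_dirs=["/data/uploads"])
except (FileNotFoundError, ValueError, OSError) as e:
    print(f"Rejected: {e}")
else:
    audio = process_audio("/data/uploads/interview.mp3")
    # `audio` is the decoded waveform when decode_audio succeeds,
    # otherwise the original path string is returned and used as-is downstream.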

View File

@@ -0,0 +1,291 @@
#!/usr/bin/env python3
"""
Circuit Breaker Pattern Implementation
Prevents repeated failed attempts and provides fail-fast behavior.
Useful for GPU health checks and other operations that may fail repeatedly.
"""
import time
import logging
import threading
from enum import Enum
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Callable, Any, Optional
logger = logging.getLogger(__name__)
class CircuitState(Enum):
"""Circuit breaker states."""
CLOSED = "closed" # Normal operation, requests pass through
OPEN = "open" # Circuit is open, requests fail immediately
HALF_OPEN = "half_open" # Testing if circuit can close
@dataclass
class CircuitBreakerConfig:
"""Configuration for circuit breaker."""
failure_threshold: int = 3 # Failures before opening circuit
success_threshold: int = 2 # Successes before closing from half-open
timeout_seconds: int = 60 # Time before attempting half-open
half_open_max_calls: int = 1 # Max calls to test in half-open state
class CircuitBreaker:
"""
Circuit breaker implementation for preventing repeated failures.
Usage:
breaker = CircuitBreaker(name="gpu_health", failure_threshold=3)
@breaker.call
def check_gpu():
# This function will be protected by circuit breaker
return perform_gpu_check()
# Or use decorator:
@breaker.decorator()
def my_function():
return "result"
"""
def __init__(
self,
name: str,
failure_threshold: int = 3,
success_threshold: int = 2,
timeout_seconds: int = 60,
half_open_max_calls: int = 1
):
"""
Initialize circuit breaker.
Args:
name: Name of the circuit (for logging)
failure_threshold: Number of failures before opening
success_threshold: Number of successes to close from half-open
timeout_seconds: Seconds before transitioning to half-open
half_open_max_calls: Max concurrent calls in half-open state
"""
self.name = name
self.config = CircuitBreakerConfig(
failure_threshold=failure_threshold,
success_threshold=success_threshold,
timeout_seconds=timeout_seconds,
half_open_max_calls=half_open_max_calls
)
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time: Optional[datetime] = None
self._half_open_calls = 0
self._lock = threading.RLock()
logger.info(
f"Circuit breaker '{name}' initialized: "
f"failure_threshold={failure_threshold}, "
f"timeout={timeout_seconds}s"
)
@property
def state(self) -> CircuitState:
"""Get current circuit state."""
with self._lock:
self._update_state()
return self._state
@property
def is_closed(self) -> bool:
"""Check if circuit is closed (normal operation)."""
return self.state == CircuitState.CLOSED
@property
def is_open(self) -> bool:
"""Check if circuit is open (failing fast)."""
return self.state == CircuitState.OPEN
@property
def is_half_open(self) -> bool:
"""Check if circuit is half-open (testing)."""
return self.state == CircuitState.HALF_OPEN
def _update_state(self):
"""Update state based on timeout and counters."""
if self._state == CircuitState.OPEN:
# Check if timeout has passed
if self._last_failure_time:
elapsed = datetime.utcnow() - self._last_failure_time
if elapsed.total_seconds() >= self.config.timeout_seconds:
logger.info(
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
f"after {elapsed.total_seconds():.0f}s timeout"
)
self._state = CircuitState.HALF_OPEN
self._half_open_calls = 0
self._success_count = 0
def _on_success(self):
"""Handle successful call."""
with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._success_count += 1
logger.debug(
f"Circuit '{self.name}': Success in HALF_OPEN "
f"({self._success_count}/{self.config.success_threshold})"
)
if self._success_count >= self.config.success_threshold:
logger.info(f"Circuit '{self.name}': Closing circuit after successful test")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time = None
elif self._state == CircuitState.CLOSED:
# Reset failure count on success
self._failure_count = 0
def _on_failure(self, error: Exception):
"""Handle failed call."""
with self._lock:
self._failure_count += 1
self._last_failure_time = datetime.utcnow()
if self._state == CircuitState.HALF_OPEN:
logger.warning(
f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit"
)
self._state = CircuitState.OPEN
self._success_count = 0
elif self._state == CircuitState.CLOSED:
logger.debug(
f"Circuit '{self.name}': Failure {self._failure_count}/"
f"{self.config.failure_threshold}"
)
if self._failure_count >= self.config.failure_threshold:
logger.warning(
f"Circuit '{self.name}': Opening circuit after "
f"{self._failure_count} failures. "
f"Will retry in {self.config.timeout_seconds}s"
)
self._state = CircuitState.OPEN
self._success_count = 0
def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Function to execute
*args: Positional arguments
**kwargs: Keyword arguments
Returns:
Function result
Raises:
CircuitBreakerOpen: If circuit is open
Exception: Original exception from func if it fails
"""
with self._lock:
self._update_state()
# Check if circuit is open
if self._state == CircuitState.OPEN:
raise CircuitBreakerOpen(
f"Circuit '{self.name}' is OPEN. "
f"Failing fast to prevent repeated failures. "
f"Last failure: {self._last_failure_time.isoformat() if self._last_failure_time else 'unknown'}. "
f"Will retry in {self.config.timeout_seconds}s"
)
# Check half-open call limit
if self._state == CircuitState.HALF_OPEN:
if self._half_open_calls >= self.config.half_open_max_calls:
raise CircuitBreakerOpen(
f"Circuit '{self.name}' is HALF_OPEN with max calls reached. "
f"Please wait for current test to complete."
)
self._half_open_calls += 1
# Execute function
try:
result = func(*args, **kwargs)
self._on_success()
return result
except Exception as e:
self._on_failure(e)
raise
finally:
# Decrement half-open counter
with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._half_open_calls -= 1
def decorator(self):
"""
Decorator for protecting functions with circuit breaker.
Usage:
breaker = CircuitBreaker("my_service")
@breaker.decorator()
def my_function():
return do_something()
"""
def wrapper(func):
def decorated(*args, **kwargs):
return self.call(func, *args, **kwargs)
return decorated
return wrapper
def reset(self):
"""
Manually reset circuit breaker to closed state.
Useful for:
- Testing
- Manual intervention
- Clearing error state
"""
with self._lock:
logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
self._last_failure_time = None
self._half_open_calls = 0
def get_stats(self) -> dict:
"""
Get circuit breaker statistics.
Returns:
Dictionary with current state and counters
"""
with self._lock:
self._update_state()
return {
"name": self.name,
"state": self._state.value,
"failure_count": self._failure_count,
"success_count": self._success_count,
"last_failure_time": self._last_failure_time.isoformat() if self._last_failure_time else None,
"config": {
"failure_threshold": self.config.failure_threshold,
"success_threshold": self.config.success_threshold,
"timeout_seconds": self.config.timeout_seconds,
}
}
class CircuitBreakerOpen(Exception):
"""Exception raised when circuit breaker is open."""
pass
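
For illustration, a short driver for the breaker above. The probe_gpu function is a hypothetical stand-in for the real GPU health check and fails randomly to exercise the state transitions:

import random

breaker = CircuitBreaker(name="gpu_health", failure_threshold=3, timeout_seconds=60)

def probe_gpu() -> str:
    if random.random() < 0.5:          # simulate intermittent GPU failures
        raise RuntimeError("CUDA error")
    return "ok"

for attempt in range(10):
    try:
        print(attempt, breaker.call(probe_gpu))
    except CircuitBreakerOpen as e:
        # After failure_threshold consecutive failures the breaker opens and fails fast
        # until timeout_seconds elapse and it transitions to HALF_OPEN.
        print(attempt, "skipped:", e)
    except RuntimeError as e:
        print(attempt, "probe failed:", e)

print(breaker.get_stats())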

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
格式化输出模块
负责将转录结果格式化为不同的输出格式VTTSRTJSON
Formatting Output Module
Responsible for formatting transcription results into different output formats (VTT, SRT, JSON, TXT)
"""
import json
@@ -9,13 +9,13 @@ from typing import List, Dict, Any
def format_vtt(segments: List) -> str:
"""
将转录结果格式化为VTT
Format transcription results to VTT
Args:
segments: 转录段落列表
segments: List of transcription segments
Returns:
str: VTT格式的字幕内容
str: Subtitle content in VTT format
"""
vtt_content = "WEBVTT\n\n"
@@ -31,13 +31,13 @@ def format_vtt(segments: List) -> str:
def format_srt(segments: List) -> str:
"""
将转录结果格式化为SRT
Format transcription results to SRT
Args:
segments: 转录段落列表
segments: List of transcription segments
Returns:
str: SRT格式的字幕内容
str: Subtitle content in SRT format
"""
srt_content = ""
index = 1
@@ -53,16 +53,51 @@ def format_srt(segments: List) -> str:
return srt_content
def format_json(segments: List, info: Any) -> str:
def format_txt(segments: List) -> str:
"""
将转录结果格式化为JSON
Format transcription results to plain text
Args:
segments: 转录段落列表
info: 转录信息对象
segments: List of transcription segments
Returns:
str: JSON格式的转录结果
str: Plain text transcription content
"""
text_content = ""
for segment in segments:
text = segment.text.strip()
if text:
# Add the text content
text_content += text
            # Separate segments with a single space; trailing punctuation needs no special case
            text_content += " "
# Clean up any trailing whitespace and ensure single line breaks
text_content = text_content.strip()
# Replace multiple spaces with single spaces
while " " in text_content:
text_content = text_content.replace(" ", " ")
return text_content
def format_json(segments: List, info: Any) -> str:
"""
Format transcription results to JSON
Args:
segments: List of transcription segments
info: Transcription information object
Returns:
str: Transcription results in JSON format
"""
result = {
"segments": [{
@@ -86,13 +121,13 @@ def format_json(segments: List, info: Any) -> str:
def format_timestamp(seconds: float) -> str:
"""
格式化时间戳为VTT格式
Format timestamp for VTT format
Args:
seconds: 秒数
seconds: Number of seconds
Returns:
str: 格式化的时间戳 (HH:MM:SS.mmm)
str: Formatted timestamp (HH:MM:SS.mmm)
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
@@ -101,13 +136,13 @@ def format_timestamp(seconds: float) -> str:
def format_timestamp_srt(seconds: float) -> str:
"""
格式化时间戳为SRT格式
Format timestamp for SRT format
Args:
seconds: 秒数
seconds: Number of seconds
Returns:
str: 格式化的时间戳 (HH:MM:SS,mmm)
str: Formatted timestamp (HH:MM:SS,mmm)
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
@@ -117,13 +152,13 @@ def format_timestamp_srt(seconds: float) -> str:
def format_time(seconds: float) -> str:
"""
格式化时间为可读格式
Format time into readable format
Args:
seconds: 秒数
seconds: Number of seconds
Returns:
str: 格式化的时间 (HH:MM:SS)
str: Formatted time (HH:MM:SS)
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)

411
src/utils/input_validation.py Normal file
View File

@@ -0,0 +1,411 @@
#!/usr/bin/env python3
"""
Input Validation and Path Sanitization Module
Provides robust validation for user inputs with security protections
against path traversal, injection attacks, and other malicious inputs.
"""
import os
import re
import logging
from pathlib import Path
from typing import Optional, List
logger = logging.getLogger(__name__)
# Maximum file size (10GB)
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024
# Allowed audio file extensions
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"}
# Allowed output formats
ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"}
# Model name validation (whitelist)
ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"}
# Device validation
ALLOWED_DEVICES = {"cuda", "auto", "cpu"}
# Compute type validation
ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"}
class ValidationError(Exception):
"""Base exception for validation errors."""
pass
class PathTraversalError(ValidationError):
"""Exception for path traversal attempts."""
pass
class InvalidFileTypeError(ValidationError):
"""Exception for invalid file types."""
pass
class FileSizeError(ValidationError):
"""Exception for file size issues."""
pass
def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
"""
Sanitize error messages to prevent information leakage.
Args:
error_msg: Original error message
sanitize_paths: Whether to sanitize file paths (default: True)
Returns:
Sanitized error message
"""
if not sanitize_paths:
return error_msg
# Replace absolute paths with relative paths
# Pattern: /home/user/... or /media/... or /var/... or /tmp/...
path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)'
def replace_path(match):
full_path = match.group(1)
try:
# Try to get just the filename
basename = os.path.basename(full_path)
return f"<file:{basename}>"
except:
return "<file:redacted>"
sanitized = re.sub(path_pattern, replace_path, error_msg)
# Also sanitize user names if present
sanitized = re.sub(r'/home/([^/]+)/', '/home/<user>/', sanitized)
return sanitized
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
"""
Validate and sanitize a file path to prevent directory traversal attacks.
Args:
file_path: Path to validate
allowed_dirs: Optional list of allowed base directories
Returns:
Resolved Path object
Raises:
PathTraversalError: If path contains traversal attempts
ValidationError: If path is invalid
"""
if not file_path:
raise ValidationError("File path cannot be empty")
# Convert to Path object
try:
path = Path(file_path)
except Exception as e:
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
# Check for path traversal attempts
path_str = str(path)
if ".." in path_str:
logger.warning(f"Path traversal attempt detected: {path_str}")
raise PathTraversalError("Path traversal (..) is not allowed")
# Check for null bytes
if "\x00" in path_str:
logger.warning(f"Null byte in path detected: {path_str}")
raise PathTraversalError("Null bytes in path are not allowed")
# Resolve to absolute path (but don't follow symlinks yet)
try:
resolved_path = path.resolve()
except Exception as e:
raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}")
# If allowed_dirs specified, ensure path is within one of them
if allowed_dirs:
allowed = False
for allowed_dir in allowed_dirs:
try:
allowed_dir_path = Path(allowed_dir).resolve()
# Check if resolved_path is under allowed_dir
resolved_path.relative_to(allowed_dir_path)
allowed = True
break
except ValueError:
# Not relative to this allowed_dir
continue
if not allowed:
logger.warning(
f"Path outside allowed directories: {path_str}, "
f"allowed: {allowed_dirs}"
)
raise PathTraversalError(
f"Path must be within allowed directories. "
f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}"
)
return resolved_path
def validate_audio_file(
file_path: str,
allowed_dirs: Optional[List[str]] = None,
max_size_bytes: int = MAX_FILE_SIZE_BYTES
) -> Path:
"""
Validate audio file path with security checks.
Args:
file_path: Path to audio file
allowed_dirs: Optional list of allowed base directories
max_size_bytes: Maximum allowed file size
Returns:
Validated Path object
Raises:
ValidationError: If validation fails
PathTraversalError: If path traversal detected
FileNotFoundError: If file doesn't exist
InvalidFileTypeError: If file type not allowed
FileSizeError: If file too large
"""
# Validate and sanitize path
validated_path = validate_path_safe(file_path, allowed_dirs)
# Check file exists
if not validated_path.exists():
raise FileNotFoundError(f"Audio file not found: {validated_path.name}")
# Check it's a file (not directory)
if not validated_path.is_file():
raise ValidationError(f"Path is not a file: {validated_path.name}")
# Check file extension
file_ext = validated_path.suffix.lower()
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
raise InvalidFileTypeError(
f"Unsupported audio format: {file_ext}. "
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
)
# Check file size
try:
file_size = validated_path.stat().st_size
except Exception as e:
raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}")
if file_size == 0:
raise FileSizeError(f"Audio file is empty: {validated_path.name}")
if file_size > max_size_bytes:
raise FileSizeError(
f"File too large: {file_size / (1024**3):.2f}GB. "
f"Maximum: {max_size_bytes / (1024**3):.2f}GB"
)
# Warn for large files (>1GB)
if file_size > 1024 * 1024 * 1024:
logger.warning(
f"Large file: {file_size / (1024**3):.2f}GB, "
f"may require extended processing time"
)
return validated_path
def validate_output_directory(
dir_path: str,
allowed_dirs: Optional[List[str]] = None,
create_if_missing: bool = True
) -> Path:
"""
Validate output directory path.
Args:
dir_path: Directory path
allowed_dirs: Optional list of allowed base directories
create_if_missing: Create directory if it doesn't exist
Returns:
Validated Path object
Raises:
ValidationError: If validation fails
PathTraversalError: If path traversal detected
"""
# Validate and sanitize path
validated_path = validate_path_safe(dir_path, allowed_dirs)
# Create if requested and doesn't exist
if create_if_missing and not validated_path.exists():
try:
validated_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created output directory: {validated_path}")
except Exception as e:
raise ValidationError(
f"Cannot create output directory: {sanitize_error_message(str(e))}"
)
# Check it's a directory
if validated_path.exists() and not validated_path.is_dir():
raise ValidationError(f"Path exists but is not a directory: {validated_path.name}")
return validated_path
def validate_model_name(model_name: str) -> str:
"""
Validate Whisper model name.
Args:
model_name: Model name to validate
Returns:
Validated model name
Raises:
ValidationError: If model name invalid
"""
if not model_name:
raise ValidationError("Model name cannot be empty")
model_name = model_name.strip().lower()
if model_name not in ALLOWED_MODEL_NAMES:
raise ValidationError(
f"Invalid model name: {model_name}. "
f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}"
)
return model_name
def validate_device(device: str) -> str:
"""
Validate device parameter.
Args:
device: Device name to validate
Returns:
Validated device name
Raises:
ValidationError: If device invalid
"""
if not device:
raise ValidationError("Device cannot be empty")
device = device.strip().lower()
if device not in ALLOWED_DEVICES:
raise ValidationError(
f"Invalid device: {device}. "
f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}"
)
return device
def validate_compute_type(compute_type: str) -> str:
"""
Validate compute type parameter.
Args:
compute_type: Compute type to validate
Returns:
Validated compute type
Raises:
ValidationError: If compute type invalid
"""
if not compute_type:
raise ValidationError("Compute type cannot be empty")
compute_type = compute_type.strip().lower()
if compute_type not in ALLOWED_COMPUTE_TYPES:
raise ValidationError(
f"Invalid compute type: {compute_type}. "
f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}"
)
return compute_type
def validate_output_format(output_format: str) -> str:
"""
Validate output format parameter.
Args:
output_format: Output format to validate
Returns:
Validated output format
Raises:
ValidationError: If output format invalid
"""
if not output_format:
raise ValidationError("Output format cannot be empty")
output_format = output_format.strip().lower()
if output_format not in ALLOWED_OUTPUT_FORMATS:
raise ValidationError(
f"Invalid output format: {output_format}. "
f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}"
)
return output_format
def validate_numeric_range(
value: float,
min_value: float,
max_value: float,
param_name: str
) -> float:
"""
Validate numeric parameter is within range.
Args:
value: Value to validate
min_value: Minimum allowed value
max_value: Maximum allowed value
param_name: Parameter name for error messages
Returns:
Validated value
Raises:
ValidationError: If value out of range
"""
if value < min_value or value > max_value:
raise ValidationError(
f"{param_name} must be between {min_value} and {max_value}, "
f"got {value}"
)
return value
def validate_beam_size(beam_size: int) -> int:
"""Validate beam size parameter."""
return int(validate_numeric_range(beam_size, 1, 20, "beam_size"))
def validate_temperature(temperature: float) -> float:
"""Validate temperature parameter."""
return validate_numeric_range(temperature, 0.0, 1.0, "temperature")
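
A brief sketch of these validators guarding a job submission. The paths are illustrative, and the allowed-directory list would normally come from server configuration:

ALLOWED_DIRS = ["/data/uploads"]   # illustrative value

try:
    audio = validate_audio_file("/data/uploads/../etc/passwd", allowed_dirs=ALLOWED_DIRS)
except PathTraversalError as e:
    print("blocked:", e)           # ".." is rejected before the filesystem is touched

try:
    model = validate_model_name("Large-V3")     # normalized to "large-v3"
    device = validate_device("cpu")             # passes this whitelist; the job queue rejects cpu separately
    fmt = validate_output_format("txt")
    beam = validate_beam_size(5)
    temp = validate_temperature(0.0)
except ValidationError as e:
    print("invalid parameter:", e)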

237
src/utils/startup.py Normal file
View File

@@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""
Common Startup Logic Module
Centralizes startup procedures shared between MCP and API servers,
including GPU health checks, job queue initialization, and health monitoring.
"""
import os
import sys
import logging
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
# Import GPU health check with reset
try:
from core.gpu_health import check_gpu_health_with_reset
GPU_HEALTH_CHECK_AVAILABLE = True
except ImportError as e:
logger.warning(f"GPU health check with reset not available: {e}")
GPU_HEALTH_CHECK_AVAILABLE = False
def perform_startup_gpu_check(
required_device: str = "cuda",
auto_reset: bool = True,
exit_on_failure: bool = True
) -> bool:
"""
Perform startup GPU health check with optional auto-reset.
This function:
1. Checks if GPU health check is available
2. Runs comprehensive GPU health check
3. Attempts auto-reset if check fails and auto_reset=True
4. Optionally exits process if check fails
Args:
required_device: Required device ("cuda", "auto")
auto_reset: Enable automatic GPU driver reset on failure
exit_on_failure: Exit process if GPU check fails
Returns:
True if GPU check passed, False otherwise
Side effects:
May exit process if exit_on_failure=True and check fails
"""
if not GPU_HEALTH_CHECK_AVAILABLE:
logger.warning("GPU health check not available, starting without GPU validation")
if exit_on_failure:
logger.error("GPU health check required but not available. Exiting.")
sys.exit(1)
return False
try:
logger.info("=" * 70)
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
logger.info("=" * 70)
status = check_gpu_health_with_reset(
expected_device=required_device,
auto_reset=auto_reset
)
logger.info("=" * 70)
logger.info("STARTUP GPU CHECK SUCCESSFUL")
logger.info(f"GPU Device: {status.device_name}")
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
logger.info("=" * 70)
return True
except Exception as e:
logger.error("=" * 70)
logger.error("STARTUP GPU CHECK FAILED")
logger.error(f"Error: {e}")
if exit_on_failure:
logger.error("This service requires GPU. Terminating.")
logger.error("=" * 70)
sys.exit(1)
else:
logger.error("Continuing without GPU (may have reduced functionality)")
logger.error("=" * 70)
return False
def initialize_job_queue(
max_queue_size: Optional[int] = None,
metadata_dir: Optional[str] = None
) -> 'JobQueue':
"""
Initialize job queue with environment variable configuration.
Args:
max_queue_size: Override for max queue size (uses env var if None)
metadata_dir: Override for metadata directory (uses env var if None)
Returns:
Initialized JobQueue instance (started)
"""
from core.job_queue import JobQueue
# Get configuration from environment
if max_queue_size is None:
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
if metadata_dir is None:
metadata_dir = os.getenv(
"JOB_METADATA_DIR",
"/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
)
logger.info("Initializing job queue...")
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
job_queue.start()
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
return job_queue
def initialize_health_monitor(
check_interval_minutes: Optional[int] = None,
enabled: Optional[bool] = None
) -> Optional['HealthMonitor']:
"""
Initialize GPU health monitor with environment variable configuration.
Args:
check_interval_minutes: Override for check interval (uses env var if None)
enabled: Override for enabled status (uses env var if None)
Returns:
Initialized HealthMonitor instance (started), or None if disabled
"""
from core.gpu_health import HealthMonitor
# Get configuration from environment
if enabled is None:
enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
if not enabled:
logger.info("GPU health monitoring disabled")
return None
if check_interval_minutes is None:
check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes)
health_monitor.start()
logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)")
return health_monitor
def startup_sequence(
service_name: str = "whisper-transcription",
require_gpu: bool = True,
initialize_queue: bool = True,
initialize_monitoring: bool = True
) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]:
"""
Execute complete startup sequence for a Whisper transcription server.
This function performs all common startup tasks:
1. GPU health check with auto-reset
2. Job queue initialization
3. Health monitor initialization
Args:
service_name: Name of the service (for logging)
require_gpu: Whether GPU is required (exit if not available)
initialize_queue: Whether to initialize job queue
initialize_monitoring: Whether to initialize health monitoring
Returns:
Tuple of (job_queue, health_monitor) - either may be None
Side effects:
May exit process if GPU required but unavailable
"""
logger.info(f"Starting {service_name}...")
# Step 1: GPU health check
gpu_ok = perform_startup_gpu_check(
required_device="cuda",
auto_reset=True,
exit_on_failure=require_gpu
)
if not gpu_ok and require_gpu:
# Should not reach here (exit_on_failure should have exited)
logger.error("GPU check failed and GPU is required")
sys.exit(1)
# Step 2: Initialize job queue
job_queue = None
if initialize_queue:
job_queue = initialize_job_queue()
# Step 3: Initialize health monitor
health_monitor = None
if initialize_monitoring:
health_monitor = initialize_health_monitor()
logger.info(f"{service_name} startup sequence completed")
return job_queue, health_monitor
def cleanup_on_shutdown(
job_queue: Optional['JobQueue'] = None,
health_monitor: Optional['HealthMonitor'] = None,
wait_for_current_job: bool = True
) -> None:
"""
Perform cleanup on server shutdown.
Args:
job_queue: JobQueue instance to stop (if any)
health_monitor: HealthMonitor instance to stop (if any)
wait_for_current_job: Wait for current job to complete before stopping
"""
logger.info("Shutting down...")
if job_queue:
job_queue.stop(wait_for_current=wait_for_current_job)
logger.info("Job queue stopped")
if health_monitor:
health_monitor.stop()
logger.info("Health monitor stopped")
logger.info("Shutdown complete")

View File

@@ -0,0 +1,96 @@
"""
Test audio generator for GPU health checks.
Generates realistic test audio with speech using TTS (text-to-speech).
"""
import os
import tempfile
def generate_test_audio(duration_seconds: float = 3.0, frequency: int = 440) -> str:
"""
Generate a test audio file with real speech for GPU health checks.
Args:
duration_seconds: Duration of audio in seconds (default: 3.0)
frequency: Legacy parameter, ignored (kept for backward compatibility)
Returns:
str: Path to temporary audio file
Implementation:
- Generate real speech using gTTS (Google Text-to-Speech)
- Fallback to pyttsx3 if gTTS fails or is unavailable
- Raises RuntimeError if both TTS engines fail
- Save as MP3 format
- Store in system temp directory
- Reuse same file if exists (cache)
"""
# Use a consistent filename in temp directory for caching
temp_dir = tempfile.gettempdir()
audio_path = os.path.join(temp_dir, f"whisper_test_voice_{int(duration_seconds)}s.mp3")
# Return cached file if it exists and is valid
if os.path.exists(audio_path):
try:
# Verify file is readable and not empty
if os.path.getsize(audio_path) > 0:
return audio_path
except Exception:
# If file is corrupted, regenerate it
pass
# Generate speech with different text based on duration
if duration_seconds >= 3:
text = "This is a test of the Whisper speech recognition system. Testing one, two, three."
elif duration_seconds >= 2:
text = "This is a test of the Whisper system."
else:
text = "Testing Whisper."
# Try gTTS first (better quality, requires internet)
try:
from gtts import gTTS
tts = gTTS(text=text, lang='en', slow=False)
tts.save(audio_path)
return audio_path
except Exception as e:
print(f"gTTS failed ({e}), trying pyttsx3...")
# Fallback to pyttsx3 (offline, lower quality)
    try:
        import pyttsx3
        engine = pyttsx3.init()
        engine.save_to_file(text, audio_path)
        engine.runAndWait()
    except Exception as e:
        raise RuntimeError(
            f"Failed to generate test audio. Both gTTS and pyttsx3 failed. "
            f"pyttsx3 error: {e}. Please ensure TTS dependencies are installed: "
            f"pip install gTTS pyttsx3"
        )
    # Verify the fallback actually produced a usable file instead of silently returning None
    if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
        return audio_path
    raise RuntimeError("pyttsx3 completed but did not produce a valid test audio file")
def cleanup_test_audio() -> None:
"""
Remove all cached test audio files from temp directory.
Useful for cleanup after testing or to force regeneration.
"""
temp_dir = tempfile.gettempdir()
# Find all test audio files
for filename in os.listdir(temp_dir):
if (filename.startswith("whisper_test_") and
(filename.endswith(".wav") or filename.endswith(".mp3"))):
filepath = os.path.join(temp_dir, filename)
try:
os.remove(filepath)
except Exception:
# Ignore errors during cleanup
pass
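
Quick usage sketch: generate (or reuse) the cached clip, hand it to a health check, then clean up. The transcription step is only indicated by a comment here:

import os

path = generate_test_audio(duration_seconds=3.0)
print(f"Test clip at {path} ({os.path.getsize(path)} bytes)")

# ... run a tiny-model transcription against `path`, as the GPU health check does ...

cleanup_test_audio()   # removes cached whisper_test_* files from the temp directory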

View File

@@ -1,16 +0,0 @@
@echo off
echo 启动Whisper语音识别MCP服务器...
:: 激活虚拟环境(如果存在)
if exist "..\venv\Scripts\activate.bat" (
call ..\venv\Scripts\activate.bat
)
:: 运行MCP服务器
python whisper_server.py
:: 如果出错,暂停以查看错误信息
if %ERRORLEVEL% neq 0 (
echo 服务器启动失败,错误代码: %ERRORLEVEL%
pause
)

View File

@@ -0,0 +1,24 @@
[program:whisper-api-server]
command=/home/uad/agents/tools/mcp-transcriptor/venv/bin/python /home/uad/agents/tools/mcp-transcriptor/src/servers/api_server.py
directory=/home/uad/agents/tools/mcp-transcriptor
user=uad
autostart=true
autorestart=true
redirect_stderr=true
stdout_logfile=/home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log
stdout_logfile_maxbytes=50MB
stdout_logfile_backups=10
environment=
PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src",
CUDA_VISIBLE_DEVICES="0",
API_HOST="0.0.0.0",
API_PORT="8000",
WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/models",
TRANSCRIPTION_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/outputs",
TRANSCRIPTION_BATCH_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/outputs/batch",
TRANSCRIPTION_MODEL="large-v3",
TRANSCRIPTION_DEVICE="auto",
TRANSCRIPTION_COMPUTE_TYPE="auto",
TRANSCRIPTION_OUTPUT_FORMAT="txt"
stopwaitsecs=10
stopsignal=TERM

View File

@@ -0,0 +1,537 @@
#!/usr/bin/env python3
"""
Test Phase 2: Async Job Queue Integration
Tests the async job queue system for both API and MCP servers.
Validates all new endpoints and error handling.
"""
import os
import sys
import time
import json
import logging
import requests
from pathlib import Path
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Color codes for terminal output
class Colors:
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
END = '\033[0m'
BOLD = '\033[1m'
def print_success(msg):
print(f"{Colors.GREEN}{msg}{Colors.END}")
def print_error(msg):
print(f"{Colors.RED}{msg}{Colors.END}")
def print_info(msg):
print(f"{Colors.BLUE} {msg}{Colors.END}")
def print_section(msg):
print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
class Phase2Tester:
def __init__(self, api_url="http://localhost:8000"):
self.api_url = api_url
self.test_results = []
def test(self, name, func):
"""Run a test and record result"""
try:
logger.info(f"Testing: {name}")
print_info(f"Testing: {name}")
func()
logger.info(f"PASSED: {name}")
print_success(f"PASSED: {name}")
self.test_results.append((name, True, None))
return True
except AssertionError as e:
logger.error(f"FAILED: {name} - {str(e)}")
print_error(f"FAILED: {name}")
print_error(f" Reason: {str(e)}")
self.test_results.append((name, False, str(e)))
return False
except Exception as e:
logger.error(f"ERROR: {name} - {str(e)}")
print_error(f"ERROR: {name}")
print_error(f" Exception: {str(e)}")
self.test_results.append((name, False, f"Exception: {str(e)}"))
return False
def print_summary(self):
"""Print test summary"""
print_section("TEST SUMMARY")
passed = sum(1 for _, result, _ in self.test_results if result)
failed = len(self.test_results) - passed
for name, result, error in self.test_results:
if result:
print_success(f"{name}")
else:
print_error(f"{name}")
if error:
print(f" {error}")
print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")
return failed == 0
# ========================================================================
# API Server Tests
# ========================================================================
def test_api_root_endpoint(self):
"""Test GET / returns new API information"""
logger.info(f"GET {self.api_url}/")
resp = requests.get(f"{self.api_url}/")
logger.info(f"Response status: {resp.status_code}")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
logger.info(f"Response data: {json.dumps(data, indent=2)}")
assert data["version"] == "0.2.0", "Version should be 0.2.0"
assert "POST /jobs" in str(data["endpoints"]), "Should have POST /jobs endpoint"
assert "workflow" in data, "Should have workflow documentation"
def test_api_health_endpoint(self):
"""Test GET /health still works"""
logger.info(f"GET {self.api_url}/health")
resp = requests.get(f"{self.api_url}/health")
logger.info(f"Response status: {resp.status_code}")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
logger.info(f"Response data: {data}")
assert data["status"] == "healthy", "Health check should return healthy"
def test_api_models_endpoint(self):
"""Test GET /models still works"""
logger.info(f"GET {self.api_url}/models")
resp = requests.get(f"{self.api_url}/models")
logger.info(f"Response status: {resp.status_code}")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
logger.info(f"Available models: {data.get('available_models', [])}")
assert "available_models" in data, "Should return available models"
def test_api_gpu_health_endpoint(self):
"""Test GET /health/gpu returns GPU status"""
logger.info(f"GET {self.api_url}/health/gpu")
resp = requests.get(f"{self.api_url}/health/gpu")
logger.info(f"Response status: {resp.status_code}")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
logger.info(f"GPU health: {json.dumps(data, indent=2)}")
assert "gpu_available" in data, "Should have gpu_available field"
assert "gpu_working" in data, "Should have gpu_working field"
assert "interpretation" in data, "Should have interpretation field"
print_info(f" GPU Status: {data.get('interpretation', 'unknown')}")
def test_api_submit_job_invalid_audio(self):
"""Test POST /jobs with invalid audio path returns 400"""
payload = {
"audio_path": "/nonexistent/file.mp3",
"model_name": "tiny",
"output_format": "txt"
}
logger.info(f"POST {self.api_url}/jobs with invalid audio path")
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
resp = requests.post(f"{self.api_url}/jobs", json=payload)
logger.info(f"Response status: {resp.status_code}")
logger.info(f"Response: {resp.json()}")
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
data = resp.json()
assert "error" in data["detail"], "Should have error field"
assert data["detail"]["error"] == "Invalid audio file", f"Wrong error type: {data['detail']['error']}"
print_info(f" Error message: {data['detail']['message'][:50]}...")
def test_api_submit_job_cpu_device_rejected(self):
"""Test POST /jobs with device=cpu is rejected (400)"""
# Create a test audio file first
logger.info("Creating test audio file...")
test_audio = self._create_test_audio_file()
logger.info(f"Test audio created at: {test_audio}")
payload = {
"audio_path": test_audio,
"model_name": "tiny",
"device": "cpu",
"output_format": "txt"
}
logger.info(f"POST {self.api_url}/jobs with device=cpu")
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
resp = requests.post(f"{self.api_url}/jobs", json=payload)
logger.info(f"Response status: {resp.status_code}")
logger.info(f"Response: {resp.json()}")
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
data = resp.json()
assert "error" in data["detail"], "Should have error field"
assert "Invalid device" in data["detail"]["error"] or "CPU" in data["detail"]["message"], \
"Should reject CPU device"
def test_api_submit_job_success(self):
"""Test POST /jobs with valid audio returns job_id"""
logger.info("Creating test audio file...")
test_audio = self._create_test_audio_file()
logger.info(f"Test audio created at: {test_audio}")
payload = {
"audio_path": test_audio,
"model_name": "tiny",
"device": "auto",
"output_format": "txt"
}
logger.info(f"POST {self.api_url}/jobs with valid audio")
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
resp = requests.post(f"{self.api_url}/jobs", json=payload)
logger.info(f"Response status: {resp.status_code}")
logger.info(f"Response: {json.dumps(resp.json(), indent=2)}")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
assert "job_id" in data, "Should return job_id"
assert "status" in data, "Should return status"
assert data["status"] == "queued", f"Status should be queued, got {data['status']}"
assert "queue_position" in data, "Should return queue_position"
assert "message" in data, "Should return message"
logger.info(f"Job submitted successfully: {data['job_id']}")
print_info(f" Job ID: {data['job_id']}")
print_info(f" Queue position: {data['queue_position']}")
# Store job_id for later tests
self.test_job_id = data["job_id"]
def test_api_get_job_status(self):
"""Test GET /jobs/{job_id} returns job status"""
if not hasattr(self, 'test_job_id'):
print_info(" Skipping (no test_job_id from previous test)")
return
resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
assert "job_id" in data, "Should return job_id"
assert "status" in data, "Should return status"
assert data["status"] in ["queued", "running", "completed", "failed"], \
f"Invalid status: {data['status']}"
print_info(f" Status: {data['status']}")
def test_api_get_job_status_not_found(self):
"""Test GET /jobs/{job_id} with invalid ID returns 404"""
fake_job_id = "00000000-0000-0000-0000-000000000000"
resp = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
data = resp.json()
assert "error" in data["detail"], "Should have error field"
assert data["detail"]["error"] == "Job not found", f"Wrong error: {data['detail']['error']}"
def test_api_get_job_result_not_completed(self):
"""Test GET /jobs/{job_id}/result when job not completed returns 409"""
if not hasattr(self, 'test_job_id'):
print_info(" Skipping (no test_job_id from previous test)")
return
# Check current status
status_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
current_status = status_resp.json()["status"]
if current_status == "completed":
print_info(" Skipping (job already completed)")
return
resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
assert resp.status_code == 409, f"Expected 409, got {resp.status_code}"
data = resp.json()
assert "error" in data["detail"], "Should have error field"
assert data["detail"]["error"] == "Job not completed", f"Wrong error: {data['detail']['error']}"
assert "current_status" in data["detail"], "Should include current_status"
def test_api_list_jobs(self):
"""Test GET /jobs returns job list"""
resp = requests.get(f"{self.api_url}/jobs")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
assert "jobs" in data, "Should have jobs field"
assert "total" in data, "Should have total field"
assert isinstance(data["jobs"], list), "Jobs should be a list"
print_info(f" Total jobs: {data['total']}")
def test_api_list_jobs_with_filter(self):
"""Test GET /jobs?status=queued filters by status"""
resp = requests.get(f"{self.api_url}/jobs?status=queued&limit=10")
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
data = resp.json()
assert "jobs" in data, "Should have jobs field"
assert "filters" in data, "Should have filters field"
assert data["filters"]["status"] == "queued", "Filter should be applied"
# All returned jobs should be queued
for job in data["jobs"]:
assert job["status"] == "queued", f"Job {job['job_id']} has wrong status: {job['status']}"
def test_api_wait_for_job_completion(self):
"""Test waiting for job to complete and retrieving result"""
if not hasattr(self, 'test_job_id'):
logger.warning("Skipping - no test_job_id from previous test")
print_info(" Skipping (no test_job_id from previous test)")
return
logger.info(f"Waiting for job {self.test_job_id} to complete (max 60s)...")
print_info(" Waiting for job to complete (max 60s)...")
max_wait = 60
start_time = time.time()
while time.time() - start_time < max_wait:
resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
data = resp.json()
status = data["status"]
elapsed = int(time.time() - start_time)
logger.info(f"Job status: {status} (elapsed: {elapsed}s)")
print_info(f" Status: {status} (elapsed: {elapsed}s)")
if status == "completed":
logger.info("Job completed successfully!")
print_success(" Job completed!")
# Now get the result
logger.info("Fetching job result...")
result_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
logger.info(f"Result response status: {result_resp.status_code}")
assert result_resp.status_code == 200, f"Expected 200, got {result_resp.status_code}"
result_data = result_resp.json()
logger.info(f"Result data keys: {result_data.keys()}")
assert "result" in result_data, "Should have result field"
assert len(result_data["result"]) > 0, "Result should not be empty"
actual_text = result_data["result"].strip()
logger.info(f"Transcription result: '{actual_text}'")
print_info(f" Transcription: '{actual_text}'")
return
elif status == "failed":
error_msg = f"Job failed: {data.get('error', 'unknown error')}"
logger.error(error_msg)
raise AssertionError(error_msg)
time.sleep(2)
error_msg = f"Job did not complete within {max_wait}s"
logger.error(error_msg)
raise AssertionError(error_msg)
# ========================================================================
# MCP Server Tests (Import-based)
# ========================================================================
def test_mcp_imports(self):
"""Test MCP server modules can be imported"""
try:
logger.info("Importing MCP server module...")
from servers import whisper_server
logger.info("Checking for new async tools...")
assert hasattr(whisper_server, 'transcribe_async'), "Should have transcribe_async tool"
assert hasattr(whisper_server, 'get_job_status'), "Should have get_job_status tool"
assert hasattr(whisper_server, 'get_job_result'), "Should have get_job_result tool"
assert hasattr(whisper_server, 'list_transcription_jobs'), "Should have list_transcription_jobs tool"
assert hasattr(whisper_server, 'check_gpu_health'), "Should have check_gpu_health tool"
assert hasattr(whisper_server, 'get_model_info_api'), "Should have get_model_info_api tool"
logger.info("All new tools found!")
# Verify old tools are removed
logger.info("Verifying old tools are removed...")
assert not hasattr(whisper_server, 'transcribe'), "Old transcribe tool should be removed"
assert not hasattr(whisper_server, 'batch_transcribe_audio'), "Old batch_transcribe_audio tool should be removed"
logger.info("Old tools successfully removed!")
except ImportError as e:
logger.error(f"Failed to import MCP server: {e}")
raise AssertionError(f"Failed to import MCP server: {e}")
def test_job_queue_integration(self):
"""Test JobQueue integration is working"""
from core.job_queue import JobQueue, JobStatus
# Create a test queue
test_queue = JobQueue(max_queue_size=5, metadata_dir="/tmp/test_job_queue")
try:
# Verify it can be started
test_queue.start()
assert test_queue._worker_thread is not None, "Worker thread should be created"
assert test_queue._worker_thread.is_alive(), "Worker thread should be running"
finally:
# Clean up
test_queue.stop(wait_for_current=False)
def test_health_monitor_integration(self):
"""Test HealthMonitor integration is working"""
from core.gpu_health import HealthMonitor
# Create a test monitor
test_monitor = HealthMonitor(check_interval_minutes=60) # Long interval
try:
# Verify it can be started
test_monitor.start()
assert test_monitor._thread is not None, "Monitor thread should be created"
assert test_monitor._thread.is_alive(), "Monitor thread should be running"
# Check we can get status
status = test_monitor.get_latest_status()
assert status is not None, "Should have initial status"
finally:
# Clean up
test_monitor.stop()
# ========================================================================
# Helper Methods
# ========================================================================
def _create_test_audio_file(self):
"""Get the path to the test audio file"""
# Use relative path from project root
project_root = Path(__file__).parent.parent
test_audio_path = str(project_root / "data" / "test.mp3")
if not os.path.exists(test_audio_path):
raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
return test_audio_path
def main():
print_section("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
logger.info("=" * 70)
logger.info("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
logger.info("=" * 70)
# Check if API server is running
api_url = os.getenv("API_URL", "http://localhost:8000")
logger.info(f"Testing API server at: {api_url}")
print_info(f"Testing API server at: {api_url}")
try:
logger.info("Checking API server health...")
resp = requests.get(f"{api_url}/health", timeout=2)
logger.info(f"Health check status: {resp.status_code}")
if resp.status_code != 200:
logger.error(f"API server not responding correctly at {api_url}")
print_error(f"API server not responding correctly at {api_url}")
print_error("Please start the API server with: ./run_api_server.sh")
return 1
except requests.exceptions.RequestException as e:
logger.error(f"Cannot connect to API server: {e}")
print_error(f"Cannot connect to API server at {api_url}")
print_error("Please start the API server with: ./run_api_server.sh")
return 1
logger.info(f"API server is running at {api_url}")
print_success(f"API server is running at {api_url}")
# Create tester
tester = Phase2Tester(api_url=api_url)
# ========================================================================
# Run API Tests
# ========================================================================
print_section("API SERVER TESTS")
logger.info("Starting API server tests...")
tester.test("API Root Endpoint", tester.test_api_root_endpoint)
tester.test("API Health Endpoint", tester.test_api_health_endpoint)
tester.test("API Models Endpoint", tester.test_api_models_endpoint)
tester.test("API GPU Health Endpoint", tester.test_api_gpu_health_endpoint)
print_section("API JOB SUBMISSION TESTS")
tester.test("Submit Job - Invalid Audio (400)", tester.test_api_submit_job_invalid_audio)
tester.test("Submit Job - CPU Device Rejected (400)", tester.test_api_submit_job_cpu_device_rejected)
tester.test("Submit Job - Success (200)", tester.test_api_submit_job_success)
print_section("API JOB STATUS TESTS")
tester.test("Get Job Status - Success", tester.test_api_get_job_status)
tester.test("Get Job Status - Not Found (404)", tester.test_api_get_job_status_not_found)
tester.test("Get Job Result - Not Completed (409)", tester.test_api_get_job_result_not_completed)
print_section("API JOB LISTING TESTS")
tester.test("List Jobs", tester.test_api_list_jobs)
tester.test("List Jobs - With Filter", tester.test_api_list_jobs_with_filter)
print_section("API JOB COMPLETION TEST")
tester.test("Wait for Job Completion & Get Result", tester.test_api_wait_for_job_completion)
# ========================================================================
# Run MCP Tests
# ========================================================================
print_section("MCP SERVER TESTS")
logger.info("Starting MCP server tests...")
tester.test("MCP Module Imports", tester.test_mcp_imports)
tester.test("JobQueue Integration", tester.test_job_queue_integration)
tester.test("HealthMonitor Integration", tester.test_health_monitor_integration)
# ========================================================================
# Print Summary
# ========================================================================
logger.info("All tests completed, generating summary...")
success = tester.print_summary()
if success:
logger.info("ALL TESTS PASSED!")
print_section("ALL TESTS PASSED! ✓")
return 0
else:
logger.error("SOME TESTS FAILED!")
print_section("SOME TESTS FAILED! ✗")
return 1
if __name__ == "__main__":
sys.exit(main())
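The submit → poll → fetch flow exercised by the tests above condenses to a small REST client. A minimal sketch, assuming the server from run_api_server.sh is reachable at http://localhost:8000 and "/path/to/audio.mp3" is a placeholder for a real file; note the Phase 2 tests read the transcription from the JSON "result" field, while the Phase 4 suite below reads the raw response body, so adjust to whichever the server actually returns.

# Minimal client sketch for the async job API exercised by these tests.
import time
import requests

API_URL = "http://localhost:8000"

def transcribe_via_jobs_api(audio_path: str) -> str:
    # Submit the job (POST /jobs) with the same payload shape used in the tests above.
    resp = requests.post(f"{API_URL}/jobs", json={
        "audio_path": audio_path,
        "model_name": "tiny",
        "device": "auto",
        "output_format": "txt",
    })
    resp.raise_for_status()
    job_id = resp.json()["job_id"]

    # Poll GET /jobs/{job_id} until the job leaves the queued/running states.
    while True:
        status = requests.get(f"{API_URL}/jobs/{job_id}").json()["status"]
        if status == "completed":
            break
        if status == "failed":
            raise RuntimeError("transcription job failed")
        time.sleep(2)

    # Fetch the finished transcription (GET /jobs/{job_id}/result), JSON form as in the Phase 2 test.
    result = requests.get(f"{API_URL}/jobs/{job_id}/result")
    result.raise_for_status()
    return result.json()["result"]

if __name__ == "__main__":
    print(transcribe_via_jobs_api("/path/to/audio.mp3"))  # placeholder path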

287
tests/test_core_components.py Executable file
View File

@@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""
Test script for Phase 1 components.
Tests:
1. Test audio file validation
2. GPU health check
3. Job queue operations
IMPORTANT: This service requires GPU. Tests will fail if GPU is not available.
"""
import sys
import os
import time
import torch
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%H:%M:%S'
)
# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from core.gpu_health import check_gpu_health, HealthMonitor
from core.job_queue import JobQueue, JobStatus
def check_gpu_available():
"""
Check if GPU is available. Exit if not.
This service requires GPU and will not run on CPU.
"""
print("\n" + "="*60)
print("GPU REQUIREMENT CHECK")
print("="*60)
if not torch.cuda.is_available():
print("✗ CUDA not available - GPU is required for this service")
print(" This service is configured for GPU-only operation")
print(" Please ensure CUDA is properly installed and GPU is accessible")
print("="*60)
sys.exit(1)
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"✓ GPU available: {gpu_name}")
print(f"✓ GPU memory: {gpu_memory:.2f} GB")
print("="*60)
def test_audio_file():
"""Test audio file existence and validity."""
print("\n" + "="*60)
print("TEST 1: Test Audio File")
print("="*60)
try:
# Use the actual test audio file (relative to project root)
project_root = os.path.join(os.path.dirname(__file__), '..')
audio_path = os.path.join(project_root, "data/test.mp3")
# Verify file exists
assert os.path.exists(audio_path), "Audio file not found"
print(f"✓ Test audio file exists: {audio_path}")
# Verify file is not empty
file_size = os.path.getsize(audio_path)
assert file_size > 0, "Audio file is empty"
print(f"✓ Audio file size: {file_size} bytes")
return True
except Exception as e:
print(f"✗ Audio file test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_gpu_health():
"""Test GPU health check."""
print("\n" + "="*60)
print("TEST 2: GPU Health Check")
print("="*60)
try:
# Test with cuda device (enforcing GPU requirement)
print("\nRunning health check with device='cuda'...")
logging.info("Starting GPU health check...")
status = check_gpu_health(expected_device="cuda")
logging.info("GPU health check completed")
print(f"✓ Health check completed")
print(f" - GPU available: {status.gpu_available}")
print(f" - GPU working: {status.gpu_working}")
print(f" - Device used: {status.device_used}")
print(f" - Device name: {status.device_name}")
print(f" - Memory total: {status.memory_total_gb:.2f} GB")
print(f" - Memory available: {status.memory_available_gb:.2f} GB")
print(f" - Test duration: {status.test_duration_seconds:.2f}s")
print(f" - Error: {status.error}")
# Test health monitor
print("\nTesting HealthMonitor...")
monitor = HealthMonitor(check_interval_minutes=1)
monitor.start()
print("✓ Health monitor started")
time.sleep(1)
latest = monitor.get_latest_status()
assert latest is not None, "No status available from monitor"
print(f"✓ Latest status retrieved: {latest.device_used}")
monitor.stop()
print("✓ Health monitor stopped")
return True
except Exception as e:
print(f"✗ GPU health test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_job_queue():
"""Test job queue operations."""
print("\n" + "="*60)
print("TEST 3: Job Queue")
print("="*60)
# Create temp directory for testing
import tempfile
temp_dir = tempfile.mkdtemp(prefix="job_queue_test_")
print(f"Using temp directory: {temp_dir}")
try:
# Initialize job queue
print("\nInitializing job queue...")
job_queue = JobQueue(max_queue_size=10, metadata_dir=temp_dir)
job_queue.start()
print("✓ Job queue started")
# Use the actual test audio file (relative to project root)
project_root = os.path.join(os.path.dirname(__file__), '..')
audio_path = os.path.join(project_root, "data/test.mp3")
# Test job submission
print("\nSubmitting test job...")
logging.info("Submitting transcription job to queue...")
job_info = job_queue.submit_job(
audio_path=audio_path,
model_name="tiny",
device="cuda", # Enforcing GPU requirement
output_format="txt"
)
job_id = job_info["job_id"]
logging.info(f"Job submitted: {job_id}")
print(f"✓ Job submitted: {job_id}")
print(f" - Status: {job_info['status']}")
print(f" - Queue position: {job_info['queue_position']}")
# Test job status retrieval
print("\nRetrieving job status...")
logging.info("About to call get_job_status()...")
status = job_queue.get_job_status(job_id)
logging.info(f"get_job_status() returned: {status['status']}")
print(f"✓ Job status retrieved")
print(f" - Status: {status['status']}")
print(f" - Queue position: {status['queue_position']}")
# Wait for job to process
print("\nWaiting for job to process (max 30 seconds)...", flush=True)
logging.info("Waiting for transcription to complete...")
max_wait = 30
start = time.time()
while time.time() - start < max_wait:
logging.info("Calling get_job_status()...")
status = job_queue.get_job_status(job_id)
print(f" Status: {status['status']}", flush=True)
logging.info(f"Job status: {status['status']}")
if status['status'] in ['completed', 'failed']:
logging.info("Job completed or failed, breaking out of loop")
break
logging.info("Job still running, sleeping 2 seconds...")
time.sleep(2)
final_status = job_queue.get_job_status(job_id)
print(f"\nFinal job status: {final_status['status']}")
if final_status['status'] == 'completed':
print(f"✓ Job completed successfully")
print(f" - Result path: {final_status['result_path']}")
print(f" - Processing time: {final_status['processing_time_seconds']:.2f}s")
# Test result retrieval
print("\nRetrieving job result...")
logging.info("Calling get_job_result()...")
result = job_queue.get_job_result(job_id)
logging.info(f"Result retrieved: {len(result)} characters")
print(f"✓ Result retrieved ({len(result)} characters)")
print(f" Preview: {result[:100]}...")
elif final_status['status'] == 'failed':
print(f"✗ Job failed: {final_status['error']}")
# Test persistence by stopping and restarting
print("\nTesting persistence...")
logging.info("Stopping job queue...")
job_queue.stop(wait_for_current=False)
print("✓ Job queue stopped")
logging.info("Job queue stopped")
logging.info("Restarting job queue...")
job_queue2 = JobQueue(max_queue_size=10, metadata_dir=temp_dir)
job_queue2.start()
print("✓ Job queue restarted")
logging.info("Job queue restarted")
logging.info("Checking job status after restart...")
status_after_restart = job_queue2.get_job_status(job_id)
print(f"✓ Job still exists after restart: {status_after_restart['status']}")
logging.info(f"Job status after restart: {status_after_restart['status']}")
logging.info("Stopping job queue 2...")
job_queue2.stop()
logging.info("Job queue 2 stopped")
# Cleanup
import shutil
shutil.rmtree(temp_dir)
print(f"✓ Cleaned up temp directory")
return final_status['status'] == 'completed'
except Exception as e:
print(f"✗ Job queue test failed: {e}")
import traceback
traceback.print_exc()
return False
def main():
"""Run all tests."""
print("\n" + "="*60)
print("PHASE 1 COMPONENT TESTS")
print("="*60)
# Check GPU availability first - exit if no GPU
check_gpu_available()
results = {
"Test Audio File": test_audio_file(),
"GPU Health Check": test_gpu_health(),
"Job Queue": test_job_queue(),
}
print("\n" + "="*60)
print("TEST SUMMARY")
print("="*60)
for test_name, passed in results.items():
status = "✓ PASSED" if passed else "✗ FAILED"
print(f"{test_name:.<40} {status}")
all_passed = all(results.values())
print("\n" + "="*60)
if all_passed:
print("ALL TESTS PASSED ✓")
else:
print("SOME TESTS FAILED ✗")
print("="*60)
return 0 if all_passed else 1
if __name__ == "__main__":
sys.exit(main())
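For reference, the in-process queue usage that TEST 3 above walks through reduces to a few calls. A minimal sketch, assuming a CUDA-capable machine, that the snippet lives next to these tests (so the src/ path insertion works), and that "/path/to/audio.mp3" stands in for a real file.

# Direct JobQueue usage mirroring TEST 3 above; paths and sizes are illustrative.
import os
import sys
import time
import tempfile

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from core.job_queue import JobQueue

queue = JobQueue(max_queue_size=10, metadata_dir=tempfile.mkdtemp(prefix="jobs_"))
queue.start()
try:
    info = queue.submit_job(
        audio_path="/path/to/audio.mp3",  # placeholder path
        model_name="tiny",
        device="cuda",
        output_format="txt",
    )
    job_id = info["job_id"]
    # Poll until the background worker finishes the job.
    while queue.get_job_status(job_id)["status"] not in ("completed", "failed"):
        time.sleep(2)
    if queue.get_job_status(job_id)["status"] == "completed":
        print(queue.get_job_result(job_id))
finally:
    queue.stop(wait_for_current=False)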

523
tests/test_e2e_integration.py Executable file
View File

@@ -0,0 +1,523 @@
#!/usr/bin/env python3
"""
Test Phase 4: End-to-End Integration Testing
Comprehensive integration tests for the async job queue system.
Tests all scenarios from the DEV_PLAN.md Phase 4 checklist.
"""
import os
import sys
import time
import json
import logging
import requests
import subprocess
import signal
from pathlib import Path
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
# Color codes for terminal output
class Colors:
GREEN = '\033[92m'
RED = '\033[91m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
CYAN = '\033[96m'
END = '\033[0m'
BOLD = '\033[1m'
def print_success(msg):
print(f"{Colors.GREEN}{msg}{Colors.END}")
def print_error(msg):
print(f"{Colors.RED}{msg}{Colors.END}")
def print_info(msg):
print(f"{Colors.BLUE} {msg}{Colors.END}")
def print_warning(msg):
print(f"{Colors.YELLOW}{msg}{Colors.END}")
def print_section(msg):
print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
class Phase4Tester:
def __init__(self, api_url="http://localhost:8000", test_audio=None):
self.api_url = api_url
# Use relative path from project root if not provided
if test_audio is None:
project_root = Path(__file__).parent.parent
test_audio = str(project_root / "data" / "test.mp3")
self.test_audio = test_audio
self.test_results = []
self.server_process = None
# Verify test audio exists
if not os.path.exists(test_audio):
raise FileNotFoundError(f"Test audio file not found: {test_audio}")
def test(self, name, func):
"""Run a test and record result"""
try:
logger.info(f"Testing: {name}")
print_info(f"Testing: {name}")
func()
logger.info(f"PASSED: {name}")
print_success(f"PASSED: {name}")
self.test_results.append((name, True, None))
return True
except AssertionError as e:
logger.error(f"FAILED: {name} - {str(e)}")
print_error(f"FAILED: {name}")
print_error(f" Reason: {str(e)}")
self.test_results.append((name, False, str(e)))
return False
except Exception as e:
logger.error(f"ERROR: {name} - {str(e)}")
print_error(f"ERROR: {name}")
print_error(f" Exception: {str(e)}")
self.test_results.append((name, False, f"Exception: {str(e)}"))
return False
def start_api_server(self, wait_time=5):
"""Start the API server in background"""
print_info("Starting API server...")
# Script is in project root, one level up from tests/
script_path = Path(__file__).parent.parent / "run_api_server.sh"
# Start server in background
self.server_process = subprocess.Popen(
[str(script_path)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
preexec_fn=os.setsid
)
# Wait for server to start
time.sleep(wait_time)
# Verify server is running
try:
response = requests.get(f"{self.api_url}/health", timeout=5)
if response.status_code == 200:
print_success("API server started successfully")
return True
except requests.exceptions.RequestException:
pass
print_error("API server failed to start")
return False
def stop_api_server(self):
"""Stop the API server"""
if self.server_process:
print_info("Stopping API server...")
os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
self.server_process.wait(timeout=10)
print_success("API server stopped")
def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2):
"""Poll job status until completed or failed"""
start_time = time.time()
last_status = None
while time.time() - start_time < timeout:
try:
response = requests.get(f"{self.api_url}/jobs/{job_id}")
assert response.status_code == 200, f"Failed to get job status: {response.status_code}"
status_data = response.json()
current_status = status_data['status']
# Print status changes
if current_status != last_status:
if status_data.get('queue_position') is not None:
print_info(f" Job status: {current_status}, queue position: {status_data['queue_position']}")
else:
print_info(f" Job status: {current_status}")
last_status = current_status
if current_status == "completed":
return status_data
elif current_status == "failed":
raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}")
time.sleep(poll_interval)
except requests.exceptions.RequestException as e:
raise AssertionError(f"Request failed: {e}")
raise AssertionError(f"Job did not complete within {timeout} seconds")
# ========================================================================
# TEST 1: Single Job Submission and Completion
# ========================================================================
def test_single_job_flow(self):
"""Test complete job flow: submit → poll → get result"""
# Submit job
print_info(" Submitting job...")
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "large-v3",
"output_format": "txt"
})
assert response.status_code == 200, f"Job submission failed: {response.status_code}"
job_data = response.json()
assert 'job_id' in job_data, "No job_id in response"
# Status can be 'queued' or 'running' (if queue is empty and job starts immediately)
assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'"
job_id = job_data['job_id']
print_success(f" Job submitted: {job_id}")
# Wait for completion
print_info(" Waiting for job completion...")
final_status = self.wait_for_job_completion(job_id)
assert final_status['status'] == 'completed', "Job did not complete"
assert final_status['result_path'] is not None, "No result_path in completed job"
assert final_status['processing_time_seconds'] is not None, "No processing time"
print_success(f" Job completed in {final_status['processing_time_seconds']:.2f}s")
# Get result
print_info(" Retrieving result...")
response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
assert response.status_code == 200, f"Failed to get result: {response.status_code}"
result_text = response.text
assert len(result_text) > 0, "Empty result"
print_success(f" Got result: {len(result_text)} characters")
# ========================================================================
# TEST 2: Multiple Jobs in Queue (FIFO)
# ========================================================================
def test_multiple_jobs_fifo(self):
"""Test multiple jobs are processed in FIFO order"""
job_ids = []
# Submit 3 jobs
print_info(" Submitting 3 jobs...")
for i in range(3):
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "tiny", # Use tiny model for faster processing
"output_format": "txt"
})
assert response.status_code == 200, f"Job {i+1} submission failed"
job_data = response.json()
job_ids.append(job_data['job_id'])
print_info(f" Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}")
# Wait for all jobs to complete
print_info(" Waiting for all jobs to complete...")
for i, job_id in enumerate(job_ids):
print_info(f" Waiting for job {i+1}/{len(job_ids)}...")
final_status = self.wait_for_job_completion(job_id, timeout=120)
assert final_status['status'] == 'completed', f"Job {i+1} failed"
print_success(f" All {len(job_ids)} jobs completed in FIFO order")
# ========================================================================
# TEST 3: GPU Health Check
# ========================================================================
def test_gpu_health_check(self):
"""Test GPU health check endpoint"""
print_info(" Checking GPU health...")
response = requests.get(f"{self.api_url}/health/gpu")
assert response.status_code == 200, f"GPU health check failed: {response.status_code}"
health_data = response.json()
assert 'gpu_available' in health_data, "Missing gpu_available field"
assert 'gpu_working' in health_data, "Missing gpu_working field"
assert 'device_used' in health_data, "Missing device_used field"
print_info(f" GPU Available: {health_data['gpu_available']}")
print_info(f" GPU Working: {health_data['gpu_working']}")
print_info(f" Device: {health_data['device_used']}")
if health_data['gpu_available']:
assert health_data['device_name'], "GPU available but no device_name"
assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)"
print_success(f" GPU is healthy: {health_data['device_name']}")
else:
print_warning(" GPU not available on this system")
# ========================================================================
# TEST 4: Invalid Audio Path
# ========================================================================
def test_invalid_audio_path(self):
"""Test job submission with invalid audio path"""
print_info(" Submitting job with invalid path...")
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": "/invalid/path/does/not/exist.mp3",
"model_name": "large-v3"
})
# Should return 400 Bad Request
assert response.status_code == 400, f"Expected 400, got {response.status_code}"
error_data = response.json()
assert 'detail' in error_data or 'error' in error_data, "No error message in response"
print_success(" Invalid path rejected correctly")
# ========================================================================
# TEST 5: Job Not Found
# ========================================================================
def test_job_not_found(self):
"""Test retrieving non-existent job"""
print_info(" Requesting non-existent job...")
fake_job_id = "00000000-0000-0000-0000-000000000000"
response = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
assert response.status_code == 404, f"Expected 404, got {response.status_code}"
print_success(" Non-existent job handled correctly")
# ========================================================================
# TEST 6: Result Before Completion
# ========================================================================
def test_result_before_completion(self):
"""Test getting result for job that hasn't completed"""
print_info(" Submitting job and trying to get result immediately...")
# Submit job
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "large-v3"
})
assert response.status_code == 200
job_id = response.json()['job_id']
# Try to get result immediately (job is still queued/running)
time.sleep(0.5)
response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
# Should return 409 Conflict or similar
assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}"
print_success(" Result request before completion handled correctly")
# Clean up: wait for job to complete
self.wait_for_job_completion(job_id)
# ========================================================================
# TEST 7: List Jobs
# ========================================================================
def test_list_jobs(self):
"""Test listing jobs with filters"""
print_info(" Testing job listing...")
# List all jobs
response = requests.get(f"{self.api_url}/jobs")
assert response.status_code == 200, f"List jobs failed: {response.status_code}"
jobs_data = response.json()
assert 'jobs' in jobs_data, "No jobs array in response"
assert isinstance(jobs_data['jobs'], list), "Jobs is not a list"
print_info(f" Found {len(jobs_data['jobs'])} jobs")
# List only completed jobs
response = requests.get(f"{self.api_url}/jobs?status=completed")
assert response.status_code == 200
completed_jobs = response.json()['jobs']
print_info(f" Found {len(completed_jobs)} completed jobs")
# List with limit
response = requests.get(f"{self.api_url}/jobs?limit=5")
assert response.status_code == 200
limited_jobs = response.json()['jobs']
assert len(limited_jobs) <= 5, "Limit not respected"
print_success(" Job listing works correctly")
# ========================================================================
# TEST 8: Server Restart with Job Persistence
# ========================================================================
def test_server_restart_persistence(self):
"""Test that jobs persist across server restarts"""
print_info(" Testing job persistence across restart...")
# Submit a job
response = requests.post(f"{self.api_url}/jobs", json={
"audio_path": self.test_audio,
"model_name": "tiny"
})
assert response.status_code == 200
job_id = response.json()['job_id']
print_info(f" Submitted job: {job_id}")
# Get job count before restart
response = requests.get(f"{self.api_url}/jobs")
jobs_before = len(response.json()['jobs'])
print_info(f" Jobs before restart: {jobs_before}")
# Restart server
print_info(" Restarting server...")
self.stop_api_server()
time.sleep(2)
assert self.start_api_server(wait_time=8), "Server failed to restart"
# Check jobs after restart
response = requests.get(f"{self.api_url}/jobs")
assert response.status_code == 200
jobs_after = len(response.json()['jobs'])
print_info(f" Jobs after restart: {jobs_after}")
# Check our specific job is still there (this is the key test)
response = requests.get(f"{self.api_url}/jobs/{job_id}")
assert response.status_code == 200, "Job not found after restart"
# Note: Total count may differ due to job retention/cleanup, but persistence works if we can find the job
if jobs_after < jobs_before:
print_warning(f" Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup")
print_success(" Jobs persisted correctly across restart")
# ========================================================================
# TEST 9: Health Endpoint
# ========================================================================
def test_health_endpoint(self):
"""Test basic health endpoint"""
print_info(" Checking health endpoint...")
response = requests.get(f"{self.api_url}/health")
assert response.status_code == 200, f"Health check failed: {response.status_code}"
health_data = response.json()
assert health_data['status'] == 'healthy', "Server not healthy"
print_success(" Health endpoint OK")
# ========================================================================
# TEST 10: Models Endpoint
# ========================================================================
def test_models_endpoint(self):
"""Test models information endpoint"""
print_info(" Checking models endpoint...")
response = requests.get(f"{self.api_url}/models")
assert response.status_code == 200, f"Models endpoint failed: {response.status_code}"
models_data = response.json()
assert 'available_models' in models_data, "No available_models field"
assert 'available_devices' in models_data, "No available_devices field"
assert len(models_data['available_models']) > 0, "No models listed"
print_info(f" Available models: {len(models_data['available_models'])}")
print_success(" Models endpoint OK")
def print_summary(self):
"""Print test summary"""
print_section("TEST SUMMARY")
passed = sum(1 for _, result, _ in self.test_results if result)
failed = len(self.test_results) - passed
for name, result, error in self.test_results:
if result:
print_success(f"{name}")
else:
print_error(f"{name}")
if error:
print(f" {error}")
print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")
return failed == 0
def run_all_tests(self, start_server=True):
"""Run all Phase 4 integration tests"""
print_section("PHASE 4: END-TO-END INTEGRATION TESTING")
try:
# Start server if requested
if start_server:
if not self.start_api_server():
print_error("Failed to start API server. Aborting tests.")
return False
else:
# Verify server is already running
try:
response = requests.get(f"{self.api_url}/health", timeout=5)
if response.status_code != 200:
print_error("Server is not responding. Please start it first.")
return False
print_info("Using existing API server")
except requests.exceptions.RequestException:
print_error("Cannot connect to API server. Please start it first.")
return False
# Run tests
print_section("TEST 1: Single Job Submission and Completion")
self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow)
print_section("TEST 2: Multiple Jobs (FIFO Order)")
self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo)
print_section("TEST 3: GPU Health Check")
self.test("GPU health check endpoint", self.test_gpu_health_check)
print_section("TEST 4: Error Handling - Invalid Path")
self.test("Invalid audio path rejection", self.test_invalid_audio_path)
print_section("TEST 5: Error Handling - Job Not Found")
self.test("Non-existent job handling", self.test_job_not_found)
print_section("TEST 6: Error Handling - Result Before Completion")
self.test("Result request before completion", self.test_result_before_completion)
print_section("TEST 7: Job Listing")
self.test("List jobs with filters", self.test_list_jobs)
print_section("TEST 8: Health Endpoint")
self.test("Basic health endpoint", self.test_health_endpoint)
print_section("TEST 9: Models Endpoint")
self.test("Models information endpoint", self.test_models_endpoint)
print_section("TEST 10: Server Restart Persistence")
self.test("Job persistence across server restart", self.test_server_restart_persistence)
# Print summary
success = self.print_summary()
return success
finally:
# Cleanup
if start_server and self.server_process:
self.stop_api_server()
def main():
"""Main test runner"""
import argparse
parser = argparse.ArgumentParser(description='Phase 4 Integration Tests')
parser.add_argument('--url', default='http://localhost:8000', help='API server URL')
# Default to None so Phase4Tester uses relative path
parser.add_argument('--audio', default=None,
help='Path to test audio file (default: <project_root>/data/test.mp3)')
parser.add_argument('--no-start-server', action='store_true',
help='Do not start server (assume it is already running)')
args = parser.parse_args()
tester = Phase4Tester(api_url=args.url, test_audio=args.audio)
success = tester.run_all_tests(start_server=not args.no_start_server)
sys.exit(0 if success else 1)
if __name__ == '__main__':
main()
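The same tester can also be pointed at a server that is already running, which is what the --no-start-server flag does. A short sketch of the programmatic equivalent, assuming it is run from the tests/ directory so the module imports directly and that the API server is already listening at localhost:8000.

# Programmatic equivalent of: python tests/test_e2e_integration.py --no-start-server
import sys
from test_e2e_integration import Phase4Tester

tester = Phase4Tester(api_url="http://localhost:8000")  # defaults to data/test.mp3
ok = tester.run_all_tests(start_server=False)
sys.exit(0 if ok else 1)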

View File

@@ -1,326 +0,0 @@
#!/usr/bin/env python3
"""
Transcription core module
Contains the core logic for audio transcription
"""
import os
import time
import logging
from typing import Dict, Any, Tuple, List, Optional, Union
from model_manager import get_whisper_model
from audio_processor import validate_audio_file, process_audio
from formatters import format_vtt, format_srt, format_json, format_time
# Logging configuration
logger = logging.getLogger(__name__)
def transcribe_audio(
audio_path: str,
model_name: str = "large-v3",
device: str = "auto",
compute_type: str = "auto",
language: str = None,
output_format: str = "vtt",
beam_size: int = 5,
temperature: float = 0.0,
initial_prompt: str = None,
output_directory: str = None
) -> str:
"""
Transcribe an audio file with Faster Whisper
Args:
audio_path: Path to the audio file
model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
device: Execution device (cpu, cuda, auto)
compute_type: Compute type (float16, int8, auto)
language: Language code (e.g. zh, en, ja; auto-detected by default)
output_format: Output format (vtt, srt, or json)
beam_size: Beam search size; larger values may improve accuracy but reduce speed
temperature: Sampling temperature (0 means greedy decoding)
initial_prompt: Initial prompt text that can help the model better understand the context
output_directory: Output directory path; defaults to the directory of the audio file
Returns:
str: Transcription result, formatted as VTT subtitles or JSON
"""
# Validate the audio file
validation_result = validate_audio_file(audio_path)
if validation_result != "ok":
return validation_result
try:
# Get the model instance
model_instance = get_whisper_model(model_name, device, compute_type)
# Validate the language code
supported_languages = {
"zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
"fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
"pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
"tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
}
if language is not None and language not in supported_languages:
logger.warning(f"Unknown language code: {language}, falling back to automatic detection")
language = None
# Set transcription parameters
options = {
"language": language,
"vad_filter": True, # Use voice activity detection
"vad_parameters": {"min_silence_duration_ms": 500}, # Tuned VAD parameters
"beam_size": beam_size,
"temperature": temperature,
"initial_prompt": initial_prompt,
"word_timestamps": True, # 启用单词级时间戳
"suppress_tokens": [-1], # 抑制特殊标记
"condition_on_previous_text": True, # 基于前文进行条件生成
"compression_ratio_threshold": 2.4 # 压缩比阈值,用于过滤重复内容
}
start_time = time.time()
logger.info(f"开始转录文件: {os.path.basename(audio_path)}")
# 处理音频
audio_source = process_audio(audio_path)
# 执行转录 - 优先使用批处理模型
if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
logger.info("使用批处理加速进行转录...")
# 批处理模型需要单独设置batch_size参数
segments, info = model_instance['batched_model'].transcribe(
audio_source,
batch_size=model_instance['batch_size'],
**options
)
else:
logger.info("使用标准模型进行转录...")
segments, info = model_instance['model'].transcribe(audio_source, **options)
# Convert the generator to a list
segment_list = list(segments)
if not segment_list:
return "转录失败,未获得结果"
# 记录转录信息
elapsed_time = time.time() - start_time
logger.info(f"转录完成,用时: {elapsed_time:.2f}秒,检测语言: {info.language},音频长度: {info.duration:.2f}")
# 格式化转录结果
if output_format.lower() == "vtt":
transcription_result = format_vtt(segment_list)
elif output_format.lower() == "srt":
transcription_result = format_srt(segment_list)
else:
transcription_result = format_json(segment_list, info)
# Get the directory and file name of the audio file
audio_dir = os.path.dirname(audio_path)
audio_filename = os.path.splitext(os.path.basename(audio_path))[0]
# Set the output directory
if output_directory is None:
output_dir = audio_dir
else:
output_dir = output_directory
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)
# Generate a timestamped file name
timestamp = time.strftime("%Y%m%d%H%M%S")
output_filename = f"{audio_filename}_{timestamp}.{output_format.lower()}"
output_path = os.path.join(output_dir, output_filename)
# Write the transcription result to a file
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(transcription_result)
logger.info(f"转录结果已保存到: {output_path}")
return f"转录成功,结果已保存到: {output_path}"
except Exception as e:
logger.error(f"保存转录结果失败: {str(e)}")
return f"转录成功,但保存结果失败: {str(e)}"
except Exception as e:
logger.error(f"转录失败: {str(e)}")
return f"转录过程中发生错误: {str(e)}"
def report_progress(current: int, total: int, elapsed_time: float) -> str:
"""
Generate a progress report
Args:
current: Number of items processed so far
total: Total number of items
elapsed_time: Elapsed time (seconds)
Returns:
str: Formatted progress report
"""
progress = current / total * 100
eta = (elapsed_time / current) * (total - current) if current > 0 else 0
return (f"Progress: {current}/{total} ({progress:.1f}%)" +
f" | Elapsed: {format_time(elapsed_time)}" +
f" | Estimated remaining: {format_time(eta)}")
def batch_transcribe(
audio_folder: str,
output_folder: str = None,
model_name: str = "large-v3",
device: str = "auto",
compute_type: str = "auto",
language: str = None,
output_format: str = "vtt",
beam_size: int = 5,
temperature: float = 0.0,
initial_prompt: str = None,
parallel_files: int = 1
) -> str:
"""
Batch-transcribe audio files in a folder
Args:
audio_folder: Path to the folder containing audio files
output_folder: Output folder path; defaults to the "transcript" subfolder under audio_folder
model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
device: Execution device (cpu, cuda, auto)
compute_type: Compute type (float16, int8, auto)
language: Language code (e.g. zh, en, ja; auto-detected by default)
output_format: Output format (vtt, srt, or json)
beam_size: Beam search size; larger values may improve accuracy but reduce speed
temperature: Sampling temperature; 0 means greedy decoding
initial_prompt: Initial prompt text that can help the model better understand the context
parallel_files: Number of files to process in parallel (only effective in CPU mode)
Returns:
str: Batch processing summary, including processing time and success rate
"""
if not os.path.isdir(audio_folder):
return f"错误: 文件夹不存在: {audio_folder}"
# 设置输出文件夹
if output_folder is None:
output_folder = os.path.join(audio_folder, "transcript")
# 确保输出目录存在
os.makedirs(output_folder, exist_ok=True)
# 验证输出格式
valid_formats = ["vtt", "srt", "json"]
if output_format.lower() not in valid_formats:
return f"错误: 不支持的输出格式: {output_format}。支持的格式: {', '.join(valid_formats)}"
# 获取所有音频文件
audio_files = []
supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
for filename in os.listdir(audio_folder):
file_ext = os.path.splitext(filename)[1].lower()
if file_ext in supported_formats:
audio_files.append(os.path.join(audio_folder, filename))
if not audio_files:
return f"{audio_folder} 中未找到支持的音频文件。支持的格式: {', '.join(supported_formats)}"
# 记录开始时间
start_time = time.time()
total_files = len(audio_files)
logger.info(f"开始批量转录 {total_files} 个文件,输出格式: {output_format}")
# Preload the model to avoid repeated loading
try:
get_whisper_model(model_name, device, compute_type)
logger.info(f"Preloaded model: {model_name}")
except Exception as e:
logger.error(f"Failed to preload model: {str(e)}")
return f"Batch processing failed: unable to load model {model_name}: {str(e)}"
# Counters and results
results = []
success_count = 0
error_count = 0
total_audio_duration = 0
# Process each file
for i, audio_path in enumerate(audio_files):
file_name = os.path.basename(audio_path)
elapsed = time.time() - start_time
# Report progress
progress_msg = report_progress(i, total_files, elapsed)
logger.info(f"{progress_msg} | Currently processing: {file_name}")
# Run transcription
try:
result = transcribe_audio(
audio_path=audio_path,
model_name=model_name,
device=device,
compute_type=compute_type,
language=language,
output_format=output_format,
beam_size=beam_size,
temperature=temperature,
initial_prompt=initial_prompt,
output_directory=output_folder
)
# Check whether the result contains an error message
if result.startswith("Error:") or result.startswith("Error during transcription:"):
logger.error(f"Transcription failed: {file_name} - {result}")
results.append(f"❌ Failed: {file_name} - {result}")
error_count += 1
continue
# If transcription succeeded, extract the output path
if result.startswith("Transcription succeeded"):
# Extract the output path from the returned message
output_path = result.split(": ")[1] if ": " in result else "unknown path"
success_count += 1
results.append(f"✅ Success: {file_name} -> {os.path.basename(output_path)}")
# Extract the audio duration
audio_duration = 0
if output_format.lower() == "json":
# Try to parse the audio duration from the output file
try:
import json
# Read the JSON content from the output file
with open(output_path, "r", encoding="utf-8") as json_file:
json_content = json_file.read()
json_data = json.loads(json_content)
audio_duration = json_data.get("duration", 0)
except Exception as e:
logger.warning(f"Could not extract audio duration from the JSON file: {str(e)}")
audio_duration = 0
else:
# Try to extract audio information from the file name
try:
# We cannot access the info object directly here because it is scoped to transcribe_audio
# Use a conservative estimate or extract information from the result string
audio_duration = 0 # Defaults to 0
except Exception as e:
logger.warning(f"Could not extract audio duration from the file name: {str(e)}")
audio_duration = 0
# Accumulate the audio duration
total_audio_duration += audio_duration
except Exception as e:
logger.error(f"转录过程中发生错误: {file_name} - {str(e)}")
results.append(f"❌ 失败: {file_name} - {str(e)}")
error_count += 1
# 计算总转录时间
total_transcription_time = time.time() - start_time
# 生成批处理结果摘要
summary = f"批处理完成,总转录时间: {format_time(total_transcription_time)}"
summary += f" | 成功: {success_count}/{total_files}"
summary += f" | 失败: {error_count}/{total_files}"
# 输出结果
for result in results:
logger.info(result)
return summary

View File

@@ -1,108 +0,0 @@
#!/usr/bin/env python3
"""
Faster Whisper-based speech recognition MCP service
Provides high-performance audio transcription with batched acceleration and multiple output formats
"""
import os
import logging
from mcp.server.fastmcp import FastMCP
from model_manager import get_model_info
from transcriber import transcribe_audio, batch_transcribe
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the FastMCP server instance
mcp = FastMCP(
name="fast-whisper-mcp-server",
version="0.1.1",
dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126", "numpy>=1.20.0"]
)
@mcp.tool()
def get_model_info_api() -> str:
"""
Get information about the available Whisper models
"""
return get_model_info()
@mcp.tool()
def transcribe(audio_path: str, model_name: str = "large-v3", device: str = "auto",
compute_type: str = "auto", language: str = None, output_format: str = "vtt",
beam_size: int = 5, temperature: float = 0.0, initial_prompt: str = None,
output_directory: str = None) -> str:
"""
Transcribe an audio file with Faster Whisper
Args:
audio_path: Path to the audio file
model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
device: Execution device (cpu, cuda, auto)
compute_type: Compute type (float16, int8, auto)
language: Language code (e.g. zh, en, ja; auto-detected by default)
output_format: Output format (vtt, srt, or json)
beam_size: Beam search size; larger values may improve accuracy but reduce speed
temperature: Sampling temperature (0 means greedy decoding)
initial_prompt: Initial prompt text that can help the model better understand the context
output_directory: Output directory path; defaults to the directory of the audio file
Returns:
str: Transcription result, formatted as VTT subtitles or JSON
"""
return transcribe_audio(
audio_path=audio_path,
model_name=model_name,
device=device,
compute_type=compute_type,
language=language,
output_format=output_format,
beam_size=beam_size,
temperature=temperature,
initial_prompt=initial_prompt,
output_directory=output_directory
)
@mcp.tool()
def batch_transcribe_audio(audio_folder: str, output_folder: str = None, model_name: str = "large-v3",
device: str = "auto", compute_type: str = "auto", language: str = None,
output_format: str = "vtt", beam_size: int = 5, temperature: float = 0.0,
initial_prompt: str = None, parallel_files: int = 1) -> str:
"""
Batch-transcribe audio files in a folder
Args:
audio_folder: Path to the folder containing audio files
output_folder: Output folder path; defaults to the "transcript" subfolder under audio_folder
model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
device: Execution device (cpu, cuda, auto)
compute_type: Compute type (float16, int8, auto)
language: Language code (e.g. zh, en, ja; auto-detected by default)
output_format: Output format (vtt, srt, or json)
beam_size: Beam search size; larger values may improve accuracy but reduce speed
temperature: Sampling temperature; 0 means greedy decoding
initial_prompt: Initial prompt text that can help the model better understand the context
parallel_files: Number of files to process in parallel (only effective in CPU mode)
Returns:
str: Batch processing summary, including processing time and success rate
"""
return batch_transcribe(
audio_folder=audio_folder,
output_folder=output_folder,
model_name=model_name,
device=device,
compute_type=compute_type,
language=language,
output_format=output_format,
beam_size=beam_size,
temperature=temperature,
initial_prompt=initial_prompt,
parallel_files=parallel_files
)
if __name__ == "__main__":
# 运行服务器
mcp.run()