Compare commits
21 Commits
main...alihan-spe

| Author | SHA1 | Date |
|---|---|---|
|  | 990fa28668 |  |
|  | fb1e5dceba |  |
|  | f6777b1488 |  |
|  | 3c0f79645c |  |
|  | c6462e2bbe |  |
|  | d47c2843c3 |  |
|  | 06b8bc1304 |  |
|  | 66b36e71e8 |  |
|  | 5fb742a312 |  |
|  | 40555592e6 |  |
|  | 1292f0f09b |  |
|  | e7a457e602 |  |
|  | 7c9a8d8378 |  |
|  | 2cc9f298a5 |  |
|  | 56ccc0e1d7 |  |
|  | 53af30619f |  |
|  | 046204d555 |  |
|  | 9c020f947b |  |
|  | 4936684db4 |  |
|  | 8e30a8812c |  |
|  | 37935066ad |  |
.dockerignore (Normal file, 60 lines)
@@ -0,0 +1,60 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/

# Virtual environments
venv/
env/
ENV/
.venv

# Project specific
logs/
outputs/
models/
*.log
*.logs
mcp.logs
api.logs

# Git
.git/
.gitignore
.github/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Docker
.dockerignore
docker-compose.yml
docker-compose.*.yml

# Temporary files
*.tmp
*.temp
.DS_Store
Thumbs.db

# Documentation (optional - uncomment if you want to exclude)
# README.md
# CLAUDE.md
# IMPLEMENTATION_PLAN.md

# Scripts (already in container)
# reset_gpu.sh - NEEDED for GPU health checks
run_api_server.sh
run_mcp_server.sh

# Supervisor config (not needed in container)
supervisor/
.gitignore (vendored, 7 lines changed)
@@ -14,4 +14,11 @@ venv/
# Cython
*.pyd

logs/**
User/**
data/**
models/*
outputs/*
api.logs

IMPLEMENTATION_PLAN.md
Dockerfile (Normal file, 103 lines)
@@ -0,0 +1,103 @@
# Multi-purpose Whisper Transcriptor Docker Image
# Supports both MCP Server and REST API Server modes
# Use SERVER_MODE environment variable to select: "mcp" or "api"

FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04

# Prevent interactive prompts during installation
ENV DEBIAN_FRONTEND=noninteractive

# Install system dependencies
RUN apt-get update && apt-get install -y \
    software-properties-common \
    curl \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update && apt-get install -y \
    python3.12 \
    python3.12-venv \
    python3.12-dev \
    ffmpeg \
    git \
    nginx \
    supervisor \
    && rm -rf /var/lib/apt/lists/*

# Make python3.12 the default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1

# Install pip using ensurepip (Python 3.12+ doesn't have distutils)
RUN python -m ensurepip --upgrade && \
    python -m pip install --upgrade pip

# Set working directory
WORKDIR /app

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies with CUDA 12.4 support
# Note: version specifiers are quoted so the shell does not treat ">=" as an output redirect
RUN pip install --no-cache-dir \
    torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124 && \
    pip install --no-cache-dir \
    torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 && \
    pip install --no-cache-dir \
    faster-whisper \
    "fastapi>=0.115.0" \
    "uvicorn[standard]>=0.32.0" \
    "python-multipart>=0.0.9" \
    "aiofiles>=23.0.0" \
    "mcp[cli]>=1.2.0" \
    "gTTS>=2.3.0" \
    "pyttsx3>=2.90" \
    "scipy>=1.10.0" \
    "numpy>=1.24.0"

# Copy application code
COPY src/ ./src/
COPY pyproject.toml .

# Copy test audio file for GPU health checks
COPY test.mp3 .

# Copy nginx configuration
COPY nginx/transcriptor.conf /etc/nginx/sites-available/transcriptor.conf

# Copy entrypoint script and GPU reset script
COPY docker-entrypoint.sh /docker-entrypoint.sh
COPY reset_gpu.sh /app/reset_gpu.sh
RUN chmod +x /docker-entrypoint.sh /app/reset_gpu.sh

# Create directories for models, outputs, and logs
RUN mkdir -p /models /outputs /logs /app/outputs/uploads /app/outputs/batch /app/outputs/jobs

# Set Python path
ENV PYTHONPATH=/app/src

# Default environment variables (can be overridden)
ENV WHISPER_MODEL_DIR=/models \
    TRANSCRIPTION_OUTPUT_DIR=/outputs \
    TRANSCRIPTION_BATCH_OUTPUT_DIR=/outputs/batch \
    TRANSCRIPTION_MODEL=large-v3 \
    TRANSCRIPTION_DEVICE=auto \
    TRANSCRIPTION_COMPUTE_TYPE=auto \
    TRANSCRIPTION_OUTPUT_FORMAT=txt \
    TRANSCRIPTION_BEAM_SIZE=5 \
    TRANSCRIPTION_TEMPERATURE=0.0 \
    API_HOST=127.0.0.1 \
    API_PORT=33767 \
    JOB_QUEUE_MAX_SIZE=5 \
    JOB_METADATA_DIR=/outputs/jobs \
    JOB_RETENTION_DAYS=7 \
    GPU_HEALTH_CHECK_ENABLED=true \
    GPU_HEALTH_CHECK_INTERVAL_MINUTES=10 \
    GPU_HEALTH_TEST_MODEL=tiny \
    GPU_HEALTH_TEST_AUDIO=/test-audio/test.mp3 \
    GPU_RESET_COOLDOWN_MINUTES=5 \
    SERVER_MODE=api

# Expose port 80 for nginx (API mode only)
EXPOSE 80

# Use entrypoint script to handle different server modes
ENTRYPOINT ["/docker-entrypoint.sh"]
README-CN.md (184 lines)
@@ -1,184 +0,0 @@
|
||||
# Whisper 语音识别 MCP 服务器
|
||||
|
||||
基于 Faster Whisper 的语音识别 MCP 服务器,提供高性能的音频转录功能。
|
||||
|
||||
## 功能特点
|
||||
|
||||
- 集成 Faster Whisper 进行高效语音识别
|
||||
- 支持批处理加速,提高转录速度
|
||||
- 自动使用 CUDA 加速(如果可用)
|
||||
- 支持多种模型大小(tiny 到 large-v3)
|
||||
- 输出格式支持 VTT 字幕和 JSON
|
||||
- 支持批量转录文件夹中的音频文件
|
||||
- 模型实例缓存,避免重复加载
|
||||
|
||||
## 安装
|
||||
|
||||
### 依赖项
|
||||
|
||||
- Python 3.10+
|
||||
- faster-whisper>=0.9.0
|
||||
- torch==2.6.0+cu126
|
||||
- torchaudio==2.6.0+cu126
|
||||
- mcp[cli]>=1.2.0
|
||||
|
||||
### 安装步骤
|
||||
|
||||
1. 克隆或下载此仓库
|
||||
2. 创建并激活虚拟环境(推荐)
|
||||
3. 安装依赖项:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 启动服务器
|
||||
|
||||
在 Windows 上,直接运行 `start_server.bat`。
|
||||
|
||||
在其他平台上,运行:
|
||||
|
||||
```bash
|
||||
python whisper_server.py
|
||||
```
|
||||
|
||||
### 配置 Claude Desktop
|
||||
|
||||
1. 打开 Claude Desktop 配置文件:
|
||||
- Windows: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
- macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
|
||||
2. 添加 Whisper 服务器配置:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"whisper": {
|
||||
"command": "python",
|
||||
"args": ["D:/path/to/whisper_server.py"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. 重启 Claude Desktop
|
||||
|
||||
### 可用工具
|
||||
|
||||
服务器提供以下工具:
|
||||
|
||||
1. **get_model_info** - 获取可用的 Whisper 模型信息
|
||||
2. **transcribe** - 转录单个音频文件
|
||||
3. **batch_transcribe** - 批量转录文件夹中的音频文件
|
||||
|
||||
## 性能优化建议
|
||||
|
||||
- 使用 CUDA 加速可显著提高转录速度
|
||||
- 对于大量短音频,批处理模式效率更高
|
||||
- 根据 GPU 显存大小自动调整批处理大小
|
||||
- 对于长音频,使用 VAD 过滤可提高准确性
|
||||
- 指定正确的语言可提高转录质量
|
||||
|
||||
## 本地测试方案
|
||||
|
||||
1. 使用 MCP Inspector 进行快速测试:
|
||||
|
||||
```bash
|
||||
mcp dev whisper_server.py
|
||||
```
|
||||
|
||||
2. 使用 Claude Desktop 进行集成测试
|
||||
|
||||
3. 使用命令行直接调用(需要安装 mcp[cli]):
|
||||
|
||||
```bash
|
||||
mcp run whisper_server.py
|
||||
```
|
||||
|
||||
## 错误处理
|
||||
|
||||
服务器实现了以下错误处理机制:
|
||||
|
||||
- 音频文件不存在检查
|
||||
- 模型加载失败处理
|
||||
- 转录过程异常捕获
|
||||
- GPU 内存管理
|
||||
- 批处理参数自适应调整
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目在开发过程中得到了以下优秀AI工具和模型的帮助:
|
||||
|
||||
- [GitHub Copilot](https://github.com/features/copilot) - AI结对编程助手
|
||||
- [Trae](https://trae.ai/) - 智能AI编码助手
|
||||
- [Cline](https://cline.ai/) - AI驱动的终端
|
||||
- [DeepSeek](https://www.deepseek.com/) - 先进的AI模型
|
||||
- [Claude-3.7-Sonnet](https://www.anthropic.com/claude) - Anthropic强大的AI助手
|
||||
- [Gemini-2.0-Flash](https://ai.google/gemini/) - Google的多模态AI模型
|
||||
- [VS Code](https://code.visualstudio.com/) - 强大的代码编辑器
|
||||
- [Whisper](https://github.com/openai/whisper) - OpenAI的语音识别模型
|
||||
- [Faster Whisper](https://github.com/guillaumekln/faster-whisper) - 优化的Whisper实现
|
||||
|
||||
特别感谢这些出色的工具和背后的团队。
|
||||
|
||||
---
|
||||
|
||||
# Whisper 语音识别 MCP 服务器(cline claude sonnet 3.7 完成所有任务后的说明)
|
||||
|
||||
我已经成功创建了基于Python的语音识别MCP服务器,该服务器集成了Faster Whisper并支持批处理加速。服务器具有以下特点:
|
||||
|
||||
## 核心功能
|
||||
|
||||
- **集成Faster Whisper**:使用最新版本(>=0.9.0)进行高效语音识别
|
||||
- **CUDA加速**:自动检测并使用CUDA加速,默认使用torch==2.6.0+cu126
|
||||
- **批处理优化**:使用BatchedInferencePipeline实现批处理,根据GPU显存自动调整批大小
|
||||
- **模型缓存**:实现模型实例缓存,避免重复加载模型
|
||||
- **多格式输出**:支持VTT字幕和JSON格式的转录结果
|
||||
- **批量处理**:支持批量转录整个文件夹的音频文件
|
||||
|
||||
## 主要工具
|
||||
|
||||
服务器提供了三个主要工具:
|
||||
|
||||
1. **get_model_info**:获取可用的Whisper模型信息和系统配置
|
||||
2. **transcribe**:转录单个音频文件,支持多种参数配置
|
||||
3. **batch_transcribe**:批量转录文件夹中的音频文件
|
||||
|
||||
## 错误处理机制
|
||||
|
||||
- 音频文件存在性验证
|
||||
- 模型加载异常捕获和日志记录
|
||||
- 转录过程异常处理
|
||||
- GPU内存管理和清理
|
||||
- 批处理参数自适应调整
|
||||
|
||||
## 性能优化
|
||||
|
||||
- 根据GPU显存大小动态调整批处理大小(4-32)
|
||||
- 使用VAD(语音活动检测)过滤提高准确性
|
||||
- 模型实例缓存避免重复加载
|
||||
- 自动选择最佳设备和计算类型
|
||||
|
||||
## 本地测试方案
|
||||
|
||||
提供了多种测试方法:
|
||||
|
||||
- 使用MCP Inspector进行快速测试:`mcp dev whisper_server.py`
|
||||
- 使用Claude Desktop进行集成测试
|
||||
- 使用命令行直接调用:`mcp run whisper_server.py`
|
||||
|
||||
所有文件已准备就绪,包括:
|
||||
|
||||
- whisper_server.py:主服务器代码
|
||||
- requirements.txt:依赖项列表
|
||||
- start_server.bat:Windows启动脚本
|
||||
- README.md:详细文档
|
||||
|
||||
您可以通过运行start_server.bat或直接执行`python whisper_server.py`来启动服务器。
|
||||
README.md (163 lines)
@@ -1,163 +0,0 @@
|
||||
# Whisper Speech Recognition MCP Server
|
||||
---
|
||||
[中文文档](README-CN.md)
|
||||
---
|
||||
A high-performance speech recognition MCP server based on Faster Whisper, providing efficient audio transcription capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- Integrated with Faster Whisper for efficient speech recognition
|
||||
- Batch processing acceleration for improved transcription speed
|
||||
- Automatic CUDA acceleration (if available)
|
||||
- Support for multiple model sizes (tiny to large-v3)
|
||||
- Output formats include VTT subtitles, SRT, and JSON
|
||||
- Support for batch transcription of audio files in a folder
|
||||
- Model instance caching to avoid repeated loading
|
||||
- Dynamic batch size adjustment based on GPU memory
|
||||
|
||||
## Installation
|
||||
|
||||
### Dependencies
|
||||
|
||||
- Python 3.10+
|
||||
- faster-whisper>=0.9.0
|
||||
- torch==2.6.0+cu126
|
||||
- torchaudio==2.6.0+cu126
|
||||
- mcp[cli]>=1.2.0
|
||||
|
||||
### Installation Steps
|
||||
|
||||
1. Clone or download this repository
|
||||
2. Create and activate a virtual environment (recommended)
|
||||
3. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### PyTorch Installation Guide
|
||||
|
||||
Install the appropriate version of PyTorch based on your CUDA version:
|
||||
|
||||
- CUDA 12.6:
|
||||
```bash
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
```
|
||||
|
||||
- CUDA 12.1:
|
||||
```bash
|
||||
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
- CPU version:
|
||||
```bash
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
You can check your CUDA version with `nvcc --version` or `nvidia-smi`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Starting the Server
|
||||
|
||||
On Windows, simply run `start_server.bat`.
|
||||
|
||||
On other platforms, run:
|
||||
|
||||
```bash
|
||||
python whisper_server.py
|
||||
```
|
||||
|
||||
### Configuring Claude Desktop
|
||||
|
||||
1. Open the Claude Desktop configuration file:
|
||||
- Windows: `%APPDATA%\Claude\claude_desktop_config.json`
|
||||
- macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`
|
||||
|
||||
2. Add the Whisper server configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"whisper": {
|
||||
"command": "python",
|
||||
"args": ["D:/path/to/whisper_server.py"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. Restart Claude Desktop
|
||||
|
||||
### Available Tools
|
||||
|
||||
The server provides the following tools:
|
||||
|
||||
1. **get_model_info** - Get information about available Whisper models
|
||||
2. **transcribe** - Transcribe a single audio file
|
||||
3. **batch_transcribe** - Batch transcribe audio files in a folder
|
||||
|
||||
## Performance Optimization Tips
|
||||
|
||||
- Using CUDA acceleration significantly improves transcription speed
|
||||
- Batch processing mode is more efficient for large numbers of short audio files
|
||||
- Batch size is automatically adjusted based on GPU memory size
|
||||
- Using VAD (Voice Activity Detection) filtering improves accuracy for long audio
|
||||
- Specifying the correct language can improve transcription quality
|
||||
|
||||
## Local Testing Methods
|
||||
|
||||
1. Use MCP Inspector for quick testing:
|
||||
|
||||
```bash
|
||||
mcp dev whisper_server.py
|
||||
```
|
||||
|
||||
2. Use Claude Desktop for integration testing
|
||||
|
||||
3. Use command line direct invocation (requires mcp[cli]):
|
||||
|
||||
```bash
|
||||
mcp run whisper_server.py
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The server implements the following error handling mechanisms:
|
||||
|
||||
- Audio file existence check
|
||||
- Model loading failure handling
|
||||
- Transcription process exception catching
|
||||
- GPU memory management
|
||||
- Batch processing parameter adaptive adjustment
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `whisper_server.py`: Main server code
|
||||
- `model_manager.py`: Whisper model loading and caching
|
||||
- `audio_processor.py`: Audio file validation and preprocessing
|
||||
- `formatters.py`: Output formatting (VTT, SRT, JSON)
|
||||
- `transcriber.py`: Core transcription logic
|
||||
- `start_server.bat`: Windows startup script
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
This project was developed with the assistance of these amazing AI tools and models:
|
||||
|
||||
- [GitHub Copilot](https://github.com/features/copilot) - AI pair programmer
|
||||
- [Trae](https://trae.ai/) - Agentic AI coding assistant
|
||||
- [Cline](https://cline.ai/) - AI-powered terminal
|
||||
- [DeepSeek](https://www.deepseek.com/) - Advanced AI model
|
||||
- [Claude-3.7-Sonnet](https://www.anthropic.com/claude) - Anthropic's powerful AI assistant
|
||||
- [Gemini-2.0-Flash](https://ai.google/gemini/) - Google's multimodal AI model
|
||||
- [VS Code](https://code.visualstudio.com/) - Powerful code editor
|
||||
- [Whisper](https://github.com/openai/whisper) - OpenAI's speech recognition model
|
||||
- [Faster Whisper](https://github.com/guillaumekln/faster-whisper) - Optimized Whisper implementation
|
||||
|
||||
Special thanks to these incredible tools and the teams behind them.
|
||||
|
||||
TRANSCRIPTOR_API_FIX.md (Normal file, 403 lines)
@@ -0,0 +1,403 @@
# Transcriptor API - Filename Validation Bug Fix

## Issue Summary

The transcriptor API is rejecting valid audio files due to overly strict path validation. Files with `..` (double periods) anywhere in the filename are being rejected as potential path traversal attacks, even when they appear naturally in legitimate filenames.

## Current Behavior

### Error Observed
```json
{
  "detail": {
    "error": "Upload failed",
    "message": "Audio file validation failed: Path traversal (..) is not allowed"
  }
}
```

### HTTP Response
- **Status Code**: 500
- **Endpoint**: `POST /transcribe`
- **Request**: File upload with filename containing `..`

### Example Failing Filename
```
This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a
                                                    ^^^
                                   (Three dots, parsed as "..")
```

## Root Cause Analysis

### Current Validation Logic (Problematic)
The API is likely checking for `..` anywhere in the filename string, which creates false positives:

```python
# CURRENT (WRONG)
if ".." in filename:
    raise ValidationError("Path traversal (..) is not allowed")
```

This rejects legitimate filenames like:
- `"video...mp4"` (ellipsis in title)
- `"Part 1... Part 2.m4a"` (ellipsis separator)
- `"Wait... what.mp4"` (dramatic pause)

### Actual Security Concern
Path traversal attacks use `..` as **directory separators** to navigate up the filesystem:
- `../../etc/passwd` (DANGEROUS)
- `../../../secrets.txt` (DANGEROUS)
- `video...mp4` (SAFE - just a filename)
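The distinction is easy to check with the standard library: `..` is only dangerous when it survives as a separate path component. A quick illustration (plain Python, not the API's own code):

```python
from pathlib import Path

# An ellipsis inside a name stays a single component; a traversal string does not.
print(Path("video...mp4").parts)        # ('video...mp4',)              -> no '..' component
print(Path("../../etc/passwd").parts)   # ('..', '..', 'etc', 'passwd') -> contains '..'
```
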
## Recommended Fix

### Option 1: Path Component Validation (Recommended)

Check for `..` only when it appears as a **complete path component**, not as part of the filename text.

```python
import os
from pathlib import Path

def validate_filename(filename: str) -> bool:
    """
    Validate filename for path traversal attacks.

    Returns True if safe, raises ValidationError if dangerous.
    """
    # Normalize the path
    normalized = os.path.normpath(filename)

    # Check if normalization changed the path (indicates traversal)
    if normalized != filename:
        raise ValidationError(f"Path traversal detected: {filename}")

    # Check for absolute paths
    if os.path.isabs(filename):
        raise ValidationError(f"Absolute paths not allowed: {filename}")

    # Split into components and check for parent directory references
    parts = Path(filename).parts
    if ".." in parts:
        raise ValidationError(f"Parent directory references not allowed: {filename}")

    # Check for any path separators (should be basename only)
    if os.sep in filename or (os.altsep and os.altsep in filename):
        raise ValidationError(f"Path separators not allowed: {filename}")

    return True

# Examples:
validate_filename("video.mp4")               # ✓ PASS
validate_filename("video...mp4")             # ✓ PASS (ellipsis)
validate_filename("This is... a video.m4a")  # ✓ PASS
validate_filename("../../../etc/passwd")     # ✗ FAIL (traversal)
validate_filename("dir/../file.mp4")         # ✗ FAIL (traversal)
validate_filename("/etc/passwd")             # ✗ FAIL (absolute)
```

### Option 2: Basename-Only Validation (Simpler)

Only accept basenames (no directory components at all):

```python
import os

def validate_filename(filename: str) -> bool:
    """
    Ensure filename contains no path components.
    """
    # Extract basename
    basename = os.path.basename(filename)

    # Must match original (no path components)
    if basename != filename:
        raise ValidationError(f"Filename must not contain path components: {filename}")

    # Additional check: no path separators
    if "/" in filename or "\\" in filename:
        raise ValidationError(f"Path separators not allowed: {filename}")

    return True

# Examples:
validate_filename("video.mp4")      # ✓ PASS
validate_filename("video...mp4")    # ✓ PASS
validate_filename("../file.mp4")    # ✗ FAIL
validate_filename("dir/file.mp4")   # ✗ FAIL
```

### Option 3: Regex Pattern Matching (Most Strict)

Use a whitelist approach for allowed characters:

```python
import re

def validate_filename(filename: str) -> bool:
    """
    Validate filename using whitelist of safe characters.
    """
    # Allow: letters, numbers, spaces, dots, hyphens, underscores
    # Length: 1-255 characters
    pattern = r'^[a-zA-Z0-9 .\-_]{1,255}\.[a-zA-Z0-9]{2,10}$'

    if not re.match(pattern, filename):
        raise ValidationError(f"Invalid filename format: {filename}")

    # Additional safety: reject if starts/ends with dot
    if filename.startswith('.') or filename.endswith('.'):
        raise ValidationError(f"Filename cannot start or end with dot: {filename}")

    return True

# Examples:
validate_filename("video.mp4")               # ✓ PASS
validate_filename("video...mp4")             # ✓ PASS
validate_filename("My Video... Part 2.m4a")  # ✓ PASS
validate_filename("../file.mp4")             # ✗ FAIL (starts with ..)
validate_filename("file<>.mp4")              # ✗ FAIL (invalid chars)
```

## Implementation Steps

### 1. Locate Current Validation Code

Search for files containing the validation logic:

```bash
grep -r "Path traversal" /path/to/transcriptor-api
grep -r '".."' /path/to/transcriptor-api
grep -r "normpath\|basename" /path/to/transcriptor-api
```

### 2. Update Validation Function

Replace the current naive check with one of the recommended solutions above.

**Priority Order:**
1. **Option 1** (Path Component Validation) - Best security/usability balance
2. **Option 2** (Basename-Only) - Simplest, very secure
3. **Option 3** (Regex) - Most restrictive, may reject valid files
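For context, here is a minimal sketch of how the chosen validator could be wired into the upload route. The `/transcribe` path and the 400-on-invalid behaviour follow this report; the `validators` module name, form fields, and handler body are assumptions, not the actual server code:

```python
from fastapi import FastAPI, File, Form, HTTPException, UploadFile

from validators import ValidationError, validate_filename  # module name assumed; see Options 1-3 above

app = FastAPI()

@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...), model: str = Form("medium")):
    # Validate the client-supplied filename before it touches the filesystem.
    try:
        validate_filename(file.filename)
    except ValidationError as exc:
        # Reject with 400 (client error) instead of the current generic 500.
        raise HTTPException(
            status_code=400,
            detail={"error": "Invalid filename", "message": str(exc)},
        )
    # ... save the upload and enqueue the transcription job as before ...
    return {"status": "accepted", "model": model}
```
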
### 3. Test Cases

Create comprehensive test suite:

```python
import pytest

def test_valid_filenames():
    """Test filenames that should be accepted."""
    valid_names = [
        "video.mp4",
        "audio.m4a",
        "This is... a test.mp4",
        "Part 1... Part 2.wav",
        "video...multiple...dots.mp3",
        "My-Video_2024.mp4",
        "song (remix).m4a",
    ]

    for filename in valid_names:
        assert validate_filename(filename), f"Should accept: {filename}"

def test_dangerous_filenames():
    """Test filenames that should be rejected."""
    dangerous_names = [
        "../../../etc/passwd",
        "../../secrets.txt",
        "../file.mp4",
        "/etc/passwd",
        "C:\\Windows\\System32\\file.txt",
        "dir/../file.mp4",
        "file/../../etc/passwd",
    ]

    for filename in dangerous_names:
        with pytest.raises(ValidationError):
            validate_filename(filename)

def test_edge_cases():
    """Test edge cases."""
    edge_cases = [
        (".", False),            # Current directory
        ("..", False),           # Parent directory
        ("...", True),           # Just dots (valid)
        ("....", True),          # Multiple dots (valid)
        (".hidden.mp4", True),   # Hidden file (valid on Unix)
        ("", False),             # Empty string
        ("a" * 256, False),      # Too long
    ]

    for filename, should_pass in edge_cases:
        if should_pass:
            assert validate_filename(filename)
        else:
            with pytest.raises(ValidationError):
                validate_filename(filename)
```

### 4. Update Error Response

Provide clearer error messages:

```python
# BAD (current)
{"detail": {"error": "Upload failed", "message": "Audio file validation failed: Path traversal (..) is not allowed"}}

# GOOD (improved)
{
    "detail": {
        "error": "Invalid filename",
        "message": "Filename contains path traversal characters. Please use only the filename without directory paths.",
        "filename": "../../etc/passwd",
        "suggestion": "Use: passwd.txt"
    }
}
```

## Testing the Fix

### Manual Testing

1. **Test with problematic filename from bug report:**
   ```bash
   curl -X POST http://192.168.1.210:33767/transcribe \
     -F "file=@/path/to/This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a" \
     -F "model=medium"
   ```
   Expected: HTTP 200 (success)

2. **Test with actual path traversal:**
   ```bash
   curl -X POST http://192.168.1.210:33767/transcribe \
     -F "file=@/tmp/test.m4a;filename=../../etc/passwd" \
     -F "model=medium"
   ```
   Expected: HTTP 400 (validation error)

3. **Test with various ellipsis patterns:**
   - `"video...mp4"` → Should pass
   - `"Part 1... Part 2.m4a"` → Should pass
   - `"Wait... what!.mp4"` → Should pass

### Automated Testing

```python
# integration_test.py
import requests

def test_ellipsis_filenames():
    """Test files with ellipsis in names."""
    test_cases = [
        "video...mp4",
        "This is... a test.m4a",
        "Wait... what.mp3",
    ]

    for filename in test_cases:
        response = requests.post(
            "http://192.168.1.210:33767/transcribe",
            files={"file": (filename, open("test_audio.m4a", "rb"))},
            data={"model": "medium"}
        )
        assert response.status_code == 200, f"Failed for: {filename}"
```

## Security Considerations

### What We're Protecting Against

1. **Path Traversal**: `../../../sensitive/file`
2. **Absolute Paths**: `/etc/passwd` or `C:\Windows\System32\`
3. **Hidden Paths**: `./.git/config`

### What We're NOT Breaking

1. **Ellipsis in titles**: `"Wait... what.mp4"`
2. **Multiple extensions**: `"file.tar.gz"`
3. **Special characters**: `"My Video (2024).mp4"`

### Additional Hardening (Optional)

```python
import os

def sanitize_and_validate_filename(filename: str) -> str:
    """
    Sanitize filename and validate for safety.
    Returns cleaned filename or raises error.
    """
    # Remove null bytes
    filename = filename.replace("\0", "")

    # Extract basename (strips any path components)
    filename = os.path.basename(filename)

    # Limit length
    max_length = 255
    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        filename = name[:max_length - len(ext)] + ext

    # Validate
    validate_filename(filename)

    return filename
```
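A possible call site for the hardened helper, assuming uploads are written under a fixed directory (the directory path and the `save_upload` function are illustrative, not the actual upload handler):

```python
import os

UPLOAD_DIR = "/outputs/uploads"  # assumed upload directory

def save_upload(client_filename: str, data: bytes) -> str:
    """Sanitize the client-supplied name, then write the payload under UPLOAD_DIR."""
    safe_name = sanitize_and_validate_filename(client_filename)
    dest = os.path.join(UPLOAD_DIR, safe_name)
    with open(dest, "wb") as fh:
        fh.write(data)
    return dest
```
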
## Deployment Checklist

- [ ] Update validation function with recommended fix
- [ ] Add comprehensive test suite
- [ ] Test with real-world filenames (including bug report case)
- [ ] Test security: attempt path traversal attacks
- [ ] Update API documentation
- [ ] Review error messages for clarity
- [ ] Deploy to staging environment
- [ ] Run integration tests
- [ ] Monitor logs for validation failures
- [ ] Deploy to production
- [ ] Verify bug reporter's file now works

## Contact & Context

**Bug Report Date**: 2025-10-26
**Affected Endpoint**: `POST /transcribe`
**Error Code**: HTTP 500
**Client Application**: yt-dlp-webui v3

**Example Failing Request:**
```
POST http://192.168.1.210:33767/transcribe
Content-Type: multipart/form-data

file: "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
model: "medium"
```

**Current Behavior**: Returns 500 error with path traversal message
**Expected Behavior**: Accepts file and processes transcription

---

## Quick Reference

### Files to Check
- `/path/to/api/validators.py` or similar
- `/path/to/api/upload_handler.py`
- `/path/to/api/routes/transcribe.py`

### Search Commands
```bash
# Find validation code
rg "Path traversal" --type py
rg '"\.\."' --type py
rg "ValidationError.*filename" --type py

# Find upload handlers
rg "def.*upload|def.*transcribe" --type py
```

### Priority Fix
Use **Option 1 (Path Component Validation)** - it provides the best balance of security and usability.
@@ -1,5 +0,0 @@
"""
Speech recognition MCP service module
"""

__version__ = "0.1.0"
@@ -1,67 +0,0 @@
#!/usr/bin/env python3
"""
Audio processing module
Responsible for validating and preprocessing audio files
"""

import os
import logging
from typing import Union, Any
from faster_whisper import decode_audio

# Logging configuration
logger = logging.getLogger(__name__)

def validate_audio_file(audio_path: str) -> str:
    """
    Check whether an audio file is valid.

    Args:
        audio_path: Path to the audio file

    Returns:
        str: Validation result; "ok" means the file passed, otherwise an error message is returned
    """
    # Validate arguments
    if not os.path.exists(audio_path):
        return f"Error: audio file does not exist: {audio_path}"

    # Validate file format
    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
    file_ext = os.path.splitext(audio_path)[1].lower()
    if file_ext not in supported_formats:
        return f"Error: unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}"

    # Validate file size
    try:
        file_size = os.path.getsize(audio_path)
        if file_size == 0:
            return f"Error: audio file is empty: {audio_path}"

        # Warn on large files (over 1GB)
        if file_size > 1024 * 1024 * 1024:
            logger.warning(f"Warning: file is larger than 1GB and may take a long time to process: {audio_path}")
    except Exception as e:
        logger.error(f"Failed to check file size: {str(e)}")
        return f"Error: failed to check file size: {str(e)}"

    return "ok"

def process_audio(audio_path: str) -> Union[str, Any]:
    """
    Process an audio file: decode and preprocess it.

    Args:
        audio_path: Path to the audio file

    Returns:
        Union[str, Any]: The preprocessed audio data, or the original file path
    """
    # Try decode_audio for preprocessing so more formats can be handled
    try:
        audio_data = decode_audio(audio_path)
        logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
        return audio_data
    except Exception as audio_error:
        logger.warning(f"Audio preprocessing failed, falling back to the file path: {str(audio_error)}")
        return audio_path
docker-build.sh (Executable file, 19 lines)
@@ -0,0 +1,19 @@
#!/bin/bash
set -e

datetime_prefix() {
    date "+[%Y-%m-%d %H:%M:%S]"
}

echo "$(datetime_prefix) Building Whisper Transcriptor Docker image..."

# Build the Docker image
docker build -t transcriptor-apimcp:latest .

echo "$(datetime_prefix) Build complete!"
echo "$(datetime_prefix) Image: transcriptor-apimcp:latest"
echo ""
echo "Usage:"
echo "  API mode: ./docker-run-api.sh"
echo "  MCP mode: ./docker-run-mcp.sh"
echo "  Or use:   docker-compose up transcriptor-api"
docker-compose.yml (Normal file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# API Server mode with nginx reverse proxy
|
||||
transcriptor-api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: transcriptor-apimcp:latest
|
||||
container_name: transcriptor-api
|
||||
runtime: nvidia
|
||||
environment:
|
||||
NVIDIA_VISIBLE_DEVICES: "0"
|
||||
NVIDIA_DRIVER_CAPABILITIES: compute,utility
|
||||
SERVER_MODE: api
|
||||
API_HOST: 127.0.0.1
|
||||
API_PORT: 33767
|
||||
WHISPER_MODEL_DIR: /models
|
||||
TRANSCRIPTION_OUTPUT_DIR: /outputs
|
||||
TRANSCRIPTION_BATCH_OUTPUT_DIR: /outputs/batch
|
||||
TRANSCRIPTION_MODEL: large-v3
|
||||
TRANSCRIPTION_DEVICE: auto
|
||||
TRANSCRIPTION_COMPUTE_TYPE: auto
|
||||
TRANSCRIPTION_OUTPUT_FORMAT: txt
|
||||
TRANSCRIPTION_BEAM_SIZE: 5
|
||||
TRANSCRIPTION_TEMPERATURE: 0.0
|
||||
JOB_QUEUE_MAX_SIZE: 5
|
||||
JOB_METADATA_DIR: /outputs/jobs
|
||||
JOB_RETENTION_DAYS: 7
|
||||
GPU_HEALTH_CHECK_ENABLED: "true"
|
||||
GPU_HEALTH_CHECK_INTERVAL_MINUTES: 10
|
||||
GPU_HEALTH_TEST_MODEL: tiny
|
||||
GPU_HEALTH_TEST_AUDIO: /test-audio/test.mp3
|
||||
GPU_RESET_COOLDOWN_MINUTES: 5
|
||||
# Optional proxy settings (uncomment if needed)
|
||||
# HTTP_PROXY: http://192.168.1.212:8080
|
||||
# HTTPS_PROXY: http://192.168.1.212:8080
|
||||
ports:
|
||||
- "33767:80" # Map host:33767 to container nginx:80
|
||||
volumes:
|
||||
- /home/uad/agents/tools/mcp-transcriptor/models:/models
|
||||
- /home/uad/agents/tools/mcp-transcriptor/outputs:/outputs
|
||||
- /home/uad/agents/tools/mcp-transcriptor/logs:/logs
|
||||
- /home/uad/agents/tools/mcp-transcriptor/data/test.mp3:/test-audio/test.mp3:ro
|
||||
- /etc/localtime:/etc/localtime:ro # Sync container time with host
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- transcriptor-network
|
||||
|
||||
# MCP Server mode (stdio based)
|
||||
transcriptor-mcp:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: transcriptor-apimcp:latest
|
||||
container_name: transcriptor-mcp
|
||||
environment:
|
||||
SERVER_MODE: mcp
|
||||
WHISPER_MODEL_DIR: /models
|
||||
TRANSCRIPTION_OUTPUT_DIR: /outputs
|
||||
TRANSCRIPTION_BATCH_OUTPUT_DIR: /outputs/batch
|
||||
TRANSCRIPTION_MODEL: large-v3
|
||||
TRANSCRIPTION_DEVICE: auto
|
||||
TRANSCRIPTION_COMPUTE_TYPE: auto
|
||||
TRANSCRIPTION_OUTPUT_FORMAT: txt
|
||||
TRANSCRIPTION_BEAM_SIZE: 5
|
||||
TRANSCRIPTION_TEMPERATURE: 0.0
|
||||
JOB_QUEUE_MAX_SIZE: 100
|
||||
JOB_METADATA_DIR: /outputs/jobs
|
||||
JOB_RETENTION_DAYS: 7
|
||||
GPU_HEALTH_CHECK_ENABLED: "true"
|
||||
GPU_HEALTH_CHECK_INTERVAL_MINUTES: 10
|
||||
GPU_HEALTH_TEST_MODEL: tiny
|
||||
GPU_RESET_COOLDOWN_MINUTES: 5
|
||||
# Optional proxy settings (uncomment if needed)
|
||||
# HTTP_PROXY: http://192.168.1.212:8080
|
||||
# HTTPS_PROXY: http://192.168.1.212:8080
|
||||
volumes:
|
||||
- /home/uad/agents/tools/mcp-transcriptor/models:/models
|
||||
- /home/uad/agents/tools/mcp-transcriptor/outputs:/outputs
|
||||
- /home/uad/agents/tools/mcp-transcriptor/logs:/logs
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
stdin_open: true # Enable stdin for MCP stdio mode
|
||||
tty: true
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- transcriptor-network
|
||||
profiles:
|
||||
- mcp # Only start when explicitly requested
|
||||
|
||||
networks:
|
||||
transcriptor-network:
|
||||
driver: bridge
|
||||
docker-entrypoint.sh (Executable file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Docker Entrypoint Script for Whisper Transcriptor
|
||||
# Supports both MCP and API server modes
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Starting Whisper Transcriptor in ${SERVER_MODE} mode..."
|
||||
|
||||
# Ensure required directories exist
|
||||
mkdir -p "$WHISPER_MODEL_DIR"
|
||||
mkdir -p "$TRANSCRIPTION_OUTPUT_DIR"
|
||||
mkdir -p "$TRANSCRIPTION_BATCH_OUTPUT_DIR"
|
||||
mkdir -p "$JOB_METADATA_DIR"
|
||||
mkdir -p /app/outputs/uploads
|
||||
|
||||
# Display GPU information
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
echo "$(datetime_prefix) GPU Information:"
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
else
|
||||
echo "$(datetime_prefix) Warning: nvidia-smi not found. GPU may not be available."
|
||||
fi
|
||||
|
||||
# Check server mode and start appropriate service
|
||||
case "${SERVER_MODE}" in
|
||||
"api")
|
||||
echo "$(datetime_prefix) Starting API Server mode with nginx reverse proxy"
|
||||
|
||||
# Update nginx configuration to use correct backend
|
||||
sed -i "s/server 127.0.0.1:33767;/server ${API_HOST}:${API_PORT};/" /etc/nginx/sites-available/transcriptor.conf
|
||||
|
||||
# Enable nginx site
|
||||
ln -sf /etc/nginx/sites-available/transcriptor.conf /etc/nginx/sites-enabled/
|
||||
rm -f /etc/nginx/sites-enabled/default
|
||||
|
||||
# Test nginx configuration
|
||||
echo "$(datetime_prefix) Testing nginx configuration..."
|
||||
nginx -t
|
||||
|
||||
# Start nginx in background
|
||||
echo "$(datetime_prefix) Starting nginx..."
|
||||
nginx
|
||||
|
||||
# Start API server (foreground - this keeps container running)
|
||||
echo "$(datetime_prefix) Starting API server on ${API_HOST}:${API_PORT}"
|
||||
echo "$(datetime_prefix) API accessible via nginx on port 80"
|
||||
exec python -u /app/src/servers/api_server.py
|
||||
;;
|
||||
|
||||
"mcp")
|
||||
echo "$(datetime_prefix) Starting MCP Server mode (stdio)"
|
||||
echo "$(datetime_prefix) Model directory: $WHISPER_MODEL_DIR"
|
||||
|
||||
# Start MCP server in stdio mode
|
||||
exec python -u /app/src/servers/whisper_server.py
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "$(datetime_prefix) ERROR: Invalid SERVER_MODE: ${SERVER_MODE}"
|
||||
echo "$(datetime_prefix) Valid modes: 'api' or 'mcp'"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
docker-run-api.sh (Executable file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Starting Whisper Transcriptor in API mode with nginx..."
|
||||
|
||||
# Check if image exists
|
||||
if ! docker image inspect transcriptor-apimcp:latest &> /dev/null; then
|
||||
echo "$(datetime_prefix) Image not found. Building first..."
|
||||
./docker-build.sh
|
||||
fi
|
||||
|
||||
# Stop and remove existing container if running
|
||||
if docker ps -a --format '{{.Names}}' | grep -q '^transcriptor-api$'; then
|
||||
echo "$(datetime_prefix) Stopping existing container..."
|
||||
docker stop transcriptor-api || true
|
||||
docker rm transcriptor-api || true
|
||||
fi
|
||||
|
||||
# Run the container in API mode
|
||||
docker run -d \
|
||||
--name transcriptor-api \
|
||||
--gpus all \
|
||||
-p 33767:80 \
|
||||
-e SERVER_MODE=api \
|
||||
-e API_HOST=127.0.0.1 \
|
||||
-e API_PORT=33767 \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-e TRANSCRIPTION_MODEL=large-v3 \
|
||||
-e TRANSCRIPTION_DEVICE=auto \
|
||||
-e TRANSCRIPTION_COMPUTE_TYPE=auto \
|
||||
-e JOB_QUEUE_MAX_SIZE=5 \
|
||||
-v "$(pwd)/models:/models" \
|
||||
-v "$(pwd)/outputs:/outputs" \
|
||||
-v "$(pwd)/logs:/logs" \
|
||||
--restart unless-stopped \
|
||||
transcriptor-apimcp:latest
|
||||
|
||||
echo "$(datetime_prefix) Container started!"
|
||||
echo ""
|
||||
echo "API Server running at: http://localhost:33767"
|
||||
echo ""
|
||||
echo "Useful commands:"
|
||||
echo " Check logs: docker logs -f transcriptor-api"
|
||||
echo " Check status: docker ps | grep transcriptor-api"
|
||||
echo " Test health: curl http://localhost:33767/health"
|
||||
echo " Test GPU: curl http://localhost:33767/health/gpu"
|
||||
echo " Stop container: docker stop transcriptor-api"
|
||||
echo " Restart: docker restart transcriptor-api"
|
||||
echo ""
|
||||
echo "$(datetime_prefix) Waiting for service to start..."
|
||||
sleep 5
|
||||
|
||||
# Test health endpoint
|
||||
if curl -s http://localhost:33767/health > /dev/null 2>&1; then
|
||||
echo "$(datetime_prefix) ✓ Service is healthy!"
|
||||
else
|
||||
echo "$(datetime_prefix) ⚠ Service not responding yet. Check logs with: docker logs transcriptor-api"
|
||||
fi
|
||||
docker-run-mcp.sh (Executable file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
echo "$(datetime_prefix) Starting Whisper Transcriptor in MCP mode..."
|
||||
|
||||
# Check if image exists
|
||||
if ! docker image inspect transcriptor-apimcp:latest &> /dev/null; then
|
||||
echo "$(datetime_prefix) Image not found. Building first..."
|
||||
./docker-build.sh
|
||||
fi
|
||||
|
||||
# Stop and remove existing container if running
|
||||
if docker ps -a --format '{{.Names}}' | grep -q '^transcriptor-mcp$'; then
|
||||
echo "$(datetime_prefix) Stopping existing container..."
|
||||
docker stop transcriptor-mcp || true
|
||||
docker rm transcriptor-mcp || true
|
||||
fi
|
||||
|
||||
# Run the container in MCP mode (interactive stdio)
|
||||
echo "$(datetime_prefix) Starting MCP server in stdio mode..."
|
||||
echo "$(datetime_prefix) Press Ctrl+C to stop"
|
||||
echo ""
|
||||
|
||||
docker run -it --rm \
|
||||
--name transcriptor-mcp \
|
||||
--gpus all \
|
||||
-e SERVER_MODE=mcp \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
-e TRANSCRIPTION_MODEL=large-v3 \
|
||||
-e TRANSCRIPTION_DEVICE=auto \
|
||||
-e TRANSCRIPTION_COMPUTE_TYPE=auto \
|
||||
-e JOB_QUEUE_MAX_SIZE=100 \
|
||||
-v "$(pwd)/models:/models" \
|
||||
-v "$(pwd)/outputs:/outputs" \
|
||||
-v "$(pwd)/logs:/logs" \
|
||||
transcriptor-apimcp:latest
|
||||
model_manager.py (176 lines)
@@ -1,176 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
模型管理模块
|
||||
负责Whisper模型的加载、缓存和管理
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
import torch
|
||||
from faster_whisper import WhisperModel, BatchedInferencePipeline
|
||||
|
||||
# 日志配置
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局模型实例缓存
|
||||
model_instances = {}
|
||||
|
||||
def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取或创建Whisper模型实例
|
||||
|
||||
Args:
|
||||
model_name: 模型名称 (tiny, base, small, medium, large-v1, large-v2, large-v3)
|
||||
device: 运行设备 (cpu, cuda, auto)
|
||||
compute_type: 计算类型 (float16, int8, auto)
|
||||
|
||||
Returns:
|
||||
dict: 包含模型实例和配置的字典
|
||||
"""
|
||||
global model_instances
|
||||
|
||||
# 验证模型名称
|
||||
valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
|
||||
if model_name not in valid_models:
|
||||
raise ValueError(f"无效的模型名称: {model_name}。有效的模型: {', '.join(valid_models)}")
|
||||
|
||||
# 自动检测设备
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
|
||||
# 验证设备和计算类型
|
||||
if device not in ["cpu", "cuda"]:
|
||||
raise ValueError(f"无效的设备: {device}。有效的设备: cpu, cuda")
|
||||
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
logger.warning("CUDA不可用,自动切换到CPU")
|
||||
device = "cpu"
|
||||
compute_type = "int8"
|
||||
|
||||
if compute_type not in ["float16", "int8"]:
|
||||
raise ValueError(f"无效的计算类型: {compute_type}。有效的计算类型: float16, int8")
|
||||
|
||||
if device == "cpu" and compute_type == "float16":
|
||||
logger.warning("CPU设备不支持float16计算类型,自动切换到int8")
|
||||
compute_type = "int8"
|
||||
|
||||
# 生成模型键
|
||||
model_key = f"{model_name}_{device}_{compute_type}"
|
||||
|
||||
# 如果模型已实例化,直接返回
|
||||
if model_key in model_instances:
|
||||
logger.info(f"使用缓存的模型实例: {model_key}")
|
||||
return model_instances[model_key]
|
||||
|
||||
# 清理GPU内存(如果使用CUDA)
|
||||
if device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# 实例化模型
|
||||
try:
|
||||
logger.info(f"加载Whisper模型: {model_name} 设备: {device} 计算类型: {compute_type}")
|
||||
|
||||
# 基础模型
|
||||
model = WhisperModel(
|
||||
model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
download_root=os.environ.get("WHISPER_MODEL_DIR", None) # 支持自定义模型目录
|
||||
)
|
||||
|
||||
# 批处理设置 - 默认启用批处理以提高速度
|
||||
batched_model = None
|
||||
batch_size = 0
|
||||
|
||||
if device == "cuda": # 只在CUDA设备上使用批处理
|
||||
# 根据显存大小确定合适的批大小
|
||||
if torch.cuda.is_available():
|
||||
gpu_mem = torch.cuda.get_device_properties(0).total_memory
|
||||
free_mem = gpu_mem - torch.cuda.memory_allocated()
|
||||
# 根据GPU显存动态调整批大小
|
||||
if free_mem > 16e9: # >16GB
|
||||
batch_size = 32
|
||||
elif free_mem > 12e9: # >12GB
|
||||
batch_size = 16
|
||||
elif free_mem > 8e9: # >8GB
|
||||
batch_size = 8
|
||||
elif free_mem > 4e9: # >4GB
|
||||
batch_size = 4
|
||||
else: # 较小显存
|
||||
batch_size = 2
|
||||
|
||||
logger.info(f"可用GPU显存: {free_mem / 1e9:.2f} GB")
|
||||
else:
|
||||
batch_size = 8 # 默认值
|
||||
|
||||
logger.info(f"启用批处理加速,批大小: {batch_size}")
|
||||
batched_model = BatchedInferencePipeline(model=model)
|
||||
|
||||
# 创建结果对象
|
||||
result = {
|
||||
'model': model,
|
||||
'device': device,
|
||||
'compute_type': compute_type,
|
||||
'batched_model': batched_model,
|
||||
'batch_size': batch_size,
|
||||
'load_time': time.time()
|
||||
}
|
||||
|
||||
# 缓存实例
|
||||
model_instances[model_key] = result
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载模型失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_model_info() -> str:
|
||||
"""
|
||||
获取可用的Whisper模型信息
|
||||
|
||||
Returns:
|
||||
str: 模型信息的JSON字符串
|
||||
"""
|
||||
import json
|
||||
|
||||
models = [
|
||||
"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
|
||||
]
|
||||
devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
|
||||
compute_types = ["float16", "int8"] if torch.cuda.is_available() else ["int8"]
|
||||
|
||||
# 支持的语言列表
|
||||
languages = {
|
||||
"zh": "中文", "en": "英语", "ja": "日语", "ko": "韩语", "de": "德语",
|
||||
"fr": "法语", "es": "西班牙语", "ru": "俄语", "it": "意大利语",
|
||||
"pt": "葡萄牙语", "nl": "荷兰语", "ar": "阿拉伯语", "hi": "印地语",
|
||||
"tr": "土耳其语", "vi": "越南语", "th": "泰语", "id": "印尼语"
|
||||
}
|
||||
|
||||
# 支持的音频格式
|
||||
audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
|
||||
|
||||
info = {
|
||||
"available_models": models,
|
||||
"default_model": "large-v3",
|
||||
"available_devices": devices,
|
||||
"default_device": "cuda" if torch.cuda.is_available() else "cpu",
|
||||
"available_compute_types": compute_types,
|
||||
"default_compute_type": "float16" if torch.cuda.is_available() else "int8",
|
||||
"cuda_available": torch.cuda.is_available(),
|
||||
"supported_languages": languages,
|
||||
"supported_audio_formats": audio_formats,
|
||||
"version": "0.1.1"
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
info["gpu_info"] = {
|
||||
"name": torch.cuda.get_device_name(0),
|
||||
"memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
|
||||
"memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
|
||||
}
|
||||
|
||||
return json.dumps(info, indent=2)
|
||||
nginx/README.md (Normal file, 132 lines)
@@ -0,0 +1,132 @@
# Nginx Configuration for Transcriptor API

This directory contains nginx reverse proxy configuration to fix 504 Gateway Timeout errors.

## Problem

The transcriptor API can take a long time (10+ minutes) to process large audio files with the `large-v3` model. Without proper timeout configuration, requests will fail with 504 Gateway Timeout.

## Solution

The provided `transcriptor.conf` file configures nginx with appropriate timeouts:

- **proxy_connect_timeout**: 600s (10 minutes)
- **proxy_send_timeout**: 600s (10 minutes)
- **proxy_read_timeout**: 3600s (1 hour)
- **client_max_body_size**: 500M (for large audio files)

## Installation

### Option 1: Deploy nginx configuration (if using nginx)

```bash
# Copy configuration to nginx
sudo cp transcriptor.conf /etc/nginx/sites-available/

# Create symlink to enable it
sudo ln -s /etc/nginx/sites-available/transcriptor.conf /etc/nginx/sites-enabled/

# Test configuration
sudo nginx -t

# Reload nginx
sudo systemctl reload nginx
```

### Option 2: Run API server directly (current setup)

The API server at `src/servers/api_server.py` has been updated with:
- `timeout_keep_alive=3600` (1 hour)
- `timeout_graceful_shutdown=60`

No additional nginx configuration is needed if you're running the API directly.
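For reference, these uvicorn settings would be passed when the server is launched. A minimal sketch, with the timeout values taken from this README; the app import path, host, and port are assumptions:

```python
import uvicorn

from api_server import app  # import path assumed

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=33767,
        timeout_keep_alive=3600,       # keep idle connections open for up to 1 hour
        timeout_graceful_shutdown=60,  # allow 60s for in-flight requests on shutdown
    )
```
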
## Restart Service

After making changes, restart the transcriptor service:

```bash
# If using supervisor
sudo supervisorctl restart transcriptor-api

# If using systemd
sudo systemctl restart transcriptor-api

# If using docker
docker restart <container-name>
```

## Testing

Test that the API is working:

```bash
# Health check (should return 200)
curl http://192.168.1.210:33767/health

# Check timeout configuration
curl -X POST http://192.168.1.210:33767/transcribe \
  -F "file=@test_audio.mp3" \
  -F "model=large-v3" \
  -F "output_format=txt"
```

## Monitoring

Check logs for timeout warnings:

```bash
# Supervisor logs
tail -f /home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log

# Look for messages like:
# - "Job {job_id} is taking longer than expected: 610.5s elapsed (threshold: 600s)"
# - "Job {job_id} exceeded maximum timeout: 3610.2s elapsed (max: 3600s)"
```

## Configuration Environment Variables

You can also configure timeouts via environment variables in `supervisor/transcriptor-api.conf`:

```ini
environment=
    ...
    JOB_TIMEOUT_WARNING_SECONDS="600",   # Warn after 10 minutes
    JOB_TIMEOUT_MAX_SECONDS="3600",      # Fail after 1 hour
```
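A sketch of how the API might consume these variables; the variable names and logged messages come from this README, while the surrounding job-tracking code is assumed:

```python
import os
import time

WARNING_SECONDS = int(os.environ.get("JOB_TIMEOUT_WARNING_SECONDS", "600"))
MAX_SECONDS = int(os.environ.get("JOB_TIMEOUT_MAX_SECONDS", "3600"))

def check_job_timeout(job_id: str, started_at: float) -> None:
    """Warn past the soft threshold, raise past the hard one."""
    elapsed = time.time() - started_at
    if elapsed > MAX_SECONDS:
        raise TimeoutError(
            f"Job {job_id} exceeded maximum timeout: {elapsed:.1f}s elapsed (max: {MAX_SECONDS}s)"
        )
    if elapsed > WARNING_SECONDS:
        print(f"Job {job_id} is taking longer than expected: {elapsed:.1f}s elapsed (threshold: {WARNING_SECONDS}s)")
```
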
## Troubleshooting

### Still getting 504 errors?

1. **Check service is running**:
   ```bash
   sudo supervisorctl status transcriptor-api
   ```

2. **Check port is listening**:
   ```bash
   sudo netstat -tlnp | grep 33767
   ```

3. **Check logs for errors**:
   ```bash
   tail -100 /home/uad/agents/tools/mcp-transcriptor/logs/transcriptor-api.log
   ```

4. **Test direct connection** (bypass nginx):
   ```bash
   curl http://localhost:33767/health
   ```

5. **Verify GPU is working**:
   ```bash
   curl http://192.168.1.210:33767/health/gpu
   ```

### Job takes too long?

Consider:
- Using a smaller model (e.g., `medium` instead of `large-v3`)
- Splitting large audio files into smaller chunks
- Increasing `JOB_TIMEOUT_MAX_SECONDS` for very long audio files
nginx/transcriptor.conf (Normal file, 85 lines)
@@ -0,0 +1,85 @@
|
||||
# Nginx reverse proxy configuration for Whisper Transcriptor API
|
||||
# Place this file in /etc/nginx/sites-available/ and symlink to sites-enabled/
|
||||
|
||||
upstream transcriptor_backend {
|
||||
# Backend transcriptor API server
|
||||
server 127.0.0.1:33767;
|
||||
|
||||
# Connection pooling
|
||||
keepalive 32;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80;
|
||||
server_name transcriptor.local; # Change to your domain
|
||||
|
||||
# Increase client body size for large audio uploads (up to 500MB)
|
||||
client_max_body_size 500M;
|
||||
|
||||
# Timeouts for long-running transcription jobs
|
||||
proxy_connect_timeout 600s; # 10 minutes to establish connection
|
||||
proxy_send_timeout 600s; # 10 minutes to send request
|
||||
proxy_read_timeout 3600s; # 1 hour to read response (transcription can be slow)
|
||||
|
||||
# Buffer settings for large responses
|
||||
proxy_buffering on;
|
||||
proxy_buffer_size 4k;
|
||||
proxy_buffers 8 4k;
|
||||
proxy_busy_buffers_size 8k;
|
||||
|
||||
# API endpoints
|
||||
location / {
|
||||
proxy_pass http://transcriptor_backend;
|
||||
|
||||
# Forward client info
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# HTTP/1.1 for keepalive
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection "";
|
||||
|
||||
# Disable buffering for streaming endpoints
|
||||
proxy_request_buffering off;
|
||||
}
|
||||
|
||||
# Health check endpoint with shorter timeout
|
||||
location /health {
|
||||
proxy_pass http://transcriptor_backend;
|
||||
proxy_read_timeout 10s;
|
||||
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
}
|
||||
|
||||
# Access and error logs
|
||||
access_log /var/log/nginx/transcriptor_access.log;
|
||||
error_log /var/log/nginx/transcriptor_error.log warn;
|
||||
}
|
||||
|
||||
# HTTPS configuration (optional, recommended for production)
|
||||
# server {
|
||||
# listen 443 ssl http2;
|
||||
# server_name transcriptor.local;
|
||||
#
|
||||
# ssl_certificate /etc/ssl/certs/transcriptor.crt;
|
||||
# ssl_certificate_key /etc/ssl/private/transcriptor.key;
|
||||
#
|
||||
# # SSL settings
|
||||
# ssl_protocols TLSv1.2 TLSv1.3;
|
||||
# ssl_ciphers HIGH:!aNULL:!MD5;
|
||||
# ssl_prefer_server_ciphers on;
|
||||
#
|
||||
# # Same settings as HTTP above
|
||||
# client_max_body_size 500M;
|
||||
# proxy_connect_timeout 600s;
|
||||
# proxy_send_timeout 600s;
|
||||
# proxy_read_timeout 3600s;
|
||||
#
|
||||
# location / {
|
||||
# proxy_pass http://transcriptor_backend;
|
||||
# # ... (same proxy settings as above)
|
||||
# }
|
||||
# }
|
||||
@@ -1,9 +1,12 @@
[project]
name = "fast-whisper-mcp-server"
version = "0.1.0"
description = "Add your description here"
version = "0.1.1"
description = "High-performance speech recognition service with MCP and REST API servers"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "faster-whisper>=1.1.1",
]

[tool.setuptools]
packages = ["src"]
@@ -1,22 +1,37 @@
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu126
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu124
faster-whisper
torch==2.6.0+cu126
torchaudio==2.6.0+cu126
torch #==2.6.0+cu124
torchaudio #==2.6.0+cu124

# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# pip install faster-whisper>=0.9.0
# pip install mcp[cli]>=1.2.0
mcp[cli]

# PyTorch安装指南:
# 请根据您的CUDA版本安装适当版本的PyTorch:
# REST API dependencies
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
python-multipart>=0.0.9
aiofiles>=23.0.0  # Async file I/O

# Test audio generation dependencies
gTTS>=2.3.0
pyttsx3>=2.90
scipy>=1.10.0
numpy>=1.24.0

# PyTorch Installation Guide:
# Please install the appropriate version of PyTorch based on your CUDA version:
#
# • CUDA 12.6:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
#
# • CUDA 12.4:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
#
# • CUDA 12.1:
#   pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
#
# • CPU版本:
# • CPU version:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
#
# You can check your CUDA version with `nvcc --version` or `nvidia-smi`
reset_gpu.sh (Executable file, 98 lines)
@@ -0,0 +1,98 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to reset NVIDIA GPU drivers without rebooting
|
||||
# This reloads kernel modules and restarts nvidia-persistenced service
|
||||
# Also handles stopping/starting Ollama to release GPU resources
|
||||
|
||||
echo "============================================================"
|
||||
echo "NVIDIA GPU Driver Reset Script"
|
||||
echo "============================================================"
|
||||
echo ""
|
||||
|
||||
# Stop Ollama via supervisorctl
|
||||
echo "Stopping Ollama service..."
|
||||
sudo supervisorctl stop ollama 2>/dev/null
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ Ollama stopped via supervisorctl"
|
||||
OLLAMA_WAS_RUNNING=true
|
||||
else
|
||||
echo " Ollama not running or supervisorctl not available"
|
||||
OLLAMA_WAS_RUNNING=false
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Give Ollama time to release GPU resources
|
||||
sleep 2
|
||||
|
||||
# Stop nvidia-persistenced service
|
||||
echo "Stopping nvidia-persistenced service..."
|
||||
sudo systemctl stop nvidia-persistenced
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ nvidia-persistenced stopped"
|
||||
else
|
||||
echo "✗ Failed to stop nvidia-persistenced"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Unload NVIDIA kernel modules (in correct order)
|
||||
echo "Unloading NVIDIA kernel modules..."
|
||||
sudo rmmod nvidia_uvm 2>/dev/null && echo "✓ nvidia_uvm unloaded" || echo " nvidia_uvm not loaded or failed to unload"
|
||||
sudo rmmod nvidia_drm 2>/dev/null && echo "✓ nvidia_drm unloaded" || echo " nvidia_drm not loaded or failed to unload"
|
||||
sudo rmmod nvidia_modeset 2>/dev/null && echo "✓ nvidia_modeset unloaded" || echo " nvidia_modeset not loaded or failed to unload"
|
||||
sudo rmmod nvidia 2>/dev/null && echo "✓ nvidia unloaded" || echo " nvidia not loaded or failed to unload"
|
||||
echo ""
|
||||
|
||||
# Small delay to ensure clean unload
|
||||
sleep 1
|
||||
|
||||
# Reload NVIDIA kernel modules (in correct order)
|
||||
echo "Loading NVIDIA kernel modules..."
|
||||
sudo modprobe nvidia && echo "✓ nvidia loaded" || { echo "✗ Failed to load nvidia"; exit 1; }
|
||||
sudo modprobe nvidia_modeset && echo "✓ nvidia_modeset loaded" || { echo "✗ Failed to load nvidia_modeset"; exit 1; }
|
||||
sudo modprobe nvidia_uvm && echo "✓ nvidia_uvm loaded" || { echo "✗ Failed to load nvidia_uvm"; exit 1; }
|
||||
sudo modprobe nvidia_drm && echo "✓ nvidia_drm loaded" || { echo "✗ Failed to load nvidia_drm"; exit 1; }
|
||||
echo ""
|
||||
|
||||
# Restart nvidia-persistenced service
|
||||
echo "Starting nvidia-persistenced service..."
|
||||
sudo systemctl start nvidia-persistenced
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ nvidia-persistenced started"
|
||||
else
|
||||
echo "✗ Failed to start nvidia-persistenced"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Verify GPU is accessible
|
||||
echo "Verifying GPU accessibility..."
|
||||
if command -v nvidia-smi &> /dev/null; then
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ GPU reset successful"
|
||||
else
|
||||
echo "✗ GPU not accessible"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✗ nvidia-smi not found"
|
||||
exit 1
|
||||
fi
|
||||
echo ""
|
||||
|
||||
# Restart Ollama if it was running
|
||||
if [ "$OLLAMA_WAS_RUNNING" = true ]; then
|
||||
echo "Restarting Ollama service..."
|
||||
sudo supervisorctl start ollama
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✓ Ollama restarted"
|
||||
else
|
||||
echo "✗ Failed to restart Ollama"
|
||||
fi
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo "============================================================"
|
||||
echo "GPU driver reset completed successfully"
|
||||
echo "============================================================"
|
||||
66 run_api_server.sh Executable file
@@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
# Set Python path to include src directory
|
||||
export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"
|
||||
|
||||
# Set CUDA library path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Set proxy for model downloads
|
||||
export HTTP_PROXY=http://192.168.1.212:8080
|
||||
export HTTPS_PROXY=http://192.168.1.212:8080
|
||||
|
||||
# Set environment variables
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
|
||||
export TRANSCRIPTION_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs"
|
||||
export TRANSCRIPTION_BATCH_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/batch"
|
||||
export TRANSCRIPTION_MODEL="large-v3"
|
||||
export TRANSCRIPTION_DEVICE="cuda"
|
||||
export TRANSCRIPTION_COMPUTE_TYPE="float16"
|
||||
export TRANSCRIPTION_OUTPUT_FORMAT="txt"
|
||||
export TRANSCRIPTION_BEAM_SIZE="5"
|
||||
export TRANSCRIPTION_TEMPERATURE="0.0"
|
||||
export TRANSCRIPTION_USE_TIMESTAMP="false"
|
||||
export TRANSCRIPTION_FILENAME_PREFIX=""
|
||||
|
||||
# API server configuration
|
||||
export API_HOST="0.0.0.0"
|
||||
export API_PORT="33767"
|
||||
|
||||
# GPU Auto-Reset Configuration
|
||||
export GPU_RESET_COOLDOWN_MINUTES=5 # Minimum time between GPU reset attempts
|
||||
|
||||
# Job Queue Configuration
|
||||
export JOB_QUEUE_MAX_SIZE=5
|
||||
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||
export JOB_RETENTION_DAYS=7
|
||||
|
||||
# GPU Health Monitoring
|
||||
export GPU_HEALTH_CHECK_ENABLED=true
|
||||
export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10
|
||||
export GPU_HEALTH_TEST_MODEL="tiny"
|
||||
|
||||
# Log start of the script
|
||||
echo "$(datetime_prefix) Starting Whisper REST API server..."
|
||||
echo "$(datetime_prefix) Model directory: $WHISPER_MODEL_DIR"
|
||||
echo "$(datetime_prefix) API server: http://$API_HOST:$API_PORT"
|
||||
|
||||
# Optional: Verify required directories exist
|
||||
if [ ! -d "$WHISPER_MODEL_DIR" ]; then
|
||||
echo "$(datetime_prefix) Warning: Whisper model directory does not exist: $WHISPER_MODEL_DIR"
|
||||
echo "$(datetime_prefix) Models will be downloaded to default cache directory"
|
||||
fi
|
||||
|
||||
# Ensure output directories exist
|
||||
mkdir -p "$TRANSCRIPTION_OUTPUT_DIR"
|
||||
mkdir -p "$TRANSCRIPTION_BATCH_OUTPUT_DIR"
|
||||
mkdir -p "$JOB_METADATA_DIR"
|
||||
|
||||
# Run the API server
|
||||
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/src/servers/api_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/api.logs
|
||||
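The script above configures the server entirely through environment variables. How they are consumed is outside this excerpt; a minimal sketch of the expected pattern (variable names taken from the script, defaults illustrative):

```python
import os

# Sketch only - the real config loader lives elsewhere in this branch.
API_HOST = os.getenv("API_HOST", "0.0.0.0")
API_PORT = int(os.getenv("API_PORT", "33767"))
MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", "./models")
OUTPUT_DIR = os.getenv("TRANSCRIPTION_OUTPUT_DIR", "./outputs")
QUEUE_MAX_SIZE = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
GPU_HEALTH_ENABLED = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
```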
65 run_mcp_server.sh Executable file
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
datetime_prefix() {
|
||||
date "+[%Y-%m-%d %H:%M:%S]"
|
||||
}
|
||||
|
||||
# Get current user ID to avoid permission issues
|
||||
USER_ID=$(id -u)
|
||||
GROUP_ID=$(id -g)
|
||||
|
||||
# Set Python path to include src directory
|
||||
export PYTHONPATH="/home/uad/agents/tools/mcp-transcriptor/src:$PYTHONPATH"
|
||||
|
||||
# Set CUDA library path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda-12.4/targets/x86_64-linux/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Set proxy for model downloads
|
||||
export HTTP_PROXY=http://192.168.1.212:8080
|
||||
export HTTPS_PROXY=http://192.168.1.212:8080
|
||||
|
||||
# Set environment variables
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
|
||||
export TRANSCRIPTION_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs"
|
||||
export TRANSCRIPTION_BATCH_OUTPUT_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/batch"
|
||||
export TRANSCRIPTION_MODEL="large-v3"
|
||||
export TRANSCRIPTION_DEVICE="cuda"
|
||||
export TRANSCRIPTION_COMPUTE_TYPE="cuda"
|
||||
export TRANSCRIPTION_OUTPUT_FORMAT="txt"
|
||||
export TRANSCRIPTION_BEAM_SIZE="2"
|
||||
export TRANSCRIPTION_TEMPERATURE="0.0"
|
||||
export TRANSCRIPTION_USE_TIMESTAMP="false"
|
||||
export TRANSCRIPTION_FILENAME_PREFIX="test_"
|
||||
|
||||
# GPU Auto-Reset Configuration
|
||||
export GPU_RESET_COOLDOWN_MINUTES=5 # Minimum time between GPU reset attempts
|
||||
|
||||
# Job Queue Configuration
|
||||
export JOB_QUEUE_MAX_SIZE=100
|
||||
export JOB_METADATA_DIR="/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||
export JOB_RETENTION_DAYS=7
|
||||
|
||||
# GPU Health Monitoring
|
||||
export GPU_HEALTH_CHECK_ENABLED=true
|
||||
export GPU_HEALTH_CHECK_INTERVAL_MINUTES=10
|
||||
export GPU_HEALTH_TEST_MODEL="tiny"
|
||||
|
||||
# Log start of the script
|
||||
echo "$(datetime_prefix) Starting whisper server script..."
|
||||
echo "test: $WHISPER_MODEL_DIR"
|
||||
|
||||
# Optional: Verify required directories exist
|
||||
if [ ! -d "$WHISPER_MODEL_DIR" ]; then
|
||||
echo "$(datetime_prefix) Error: Whisper model directory does not exist: $WHISPER_MODEL_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure job metadata directory exists
|
||||
mkdir -p "$JOB_METADATA_DIR"
|
||||
|
||||
# Run the Python script with the defined environment variables
|
||||
#/home/uad/agents/tools/mcp-transcriptor/venv/bin/python /home/uad/agents/tools/mcp-transcriptor/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs
|
||||
/home/uad/agents/tools/mcp-transcriptor/venv/bin/python -u /home/uad/agents/tools/mcp-transcriptor/src/servers/whisper_server.py 2>&1 | tee /home/uad/agents/tools/mcp-transcriptor/mcp.logs
|
||||
|
||||
9 src/__init__.py Normal file
@@ -0,0 +1,9 @@
"""
Faster Whisper MCP Transcription Service

High-performance audio transcription service with dual-server architecture
(MCP and REST API) featuring async job queue and GPU health monitoring.
"""

__version__ = "0.2.0"
__author__ = "Whisper MCP Team"
6 src/core/__init__.py Normal file
@@ -0,0 +1,6 @@
"""
Core modules for Whisper transcription service.

Includes model management, transcription logic, job queue, GPU health monitoring,
and GPU reset functionality.
"""
491 src/core/gpu_health.py Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
GPU health monitoring for Whisper transcription service.
|
||||
|
||||
Performs real GPU health checks using actual model loading and transcription,
|
||||
with strict failure handling to prevent silent CPU fallbacks.
|
||||
Includes circuit breaker pattern to prevent repeated failed checks.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
import torch
|
||||
|
||||
from utils.circuit_breaker import CircuitBreaker, CircuitBreakerOpen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global circuit breaker for GPU health checks
|
||||
_gpu_health_circuit_breaker = CircuitBreaker(
|
||||
name="gpu_health_check",
|
||||
failure_threshold=3, # Open after 3 consecutive failures
|
||||
success_threshold=2, # Close after 2 consecutive successes
|
||||
timeout_seconds=60, # Try again after 60 seconds
|
||||
half_open_max_calls=1 # Only 1 test call in half-open state
|
||||
)
|
||||
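The CircuitBreaker itself lives in utils/circuit_breaker.py, which is not shown in this diff. Below is a minimal sketch consistent with how it is used here (constructor arguments, call(), and CircuitBreakerOpen), assuming the usual closed → open → half-open state machine; get_stats() and reset() are omitted and the real implementation may differ.

```python
import time

class CircuitBreakerOpen(Exception):
    """Raised when a call is rejected because the circuit is open."""

class CircuitBreaker:
    # Sketch only; not the actual utils/circuit_breaker.py from this branch.
    def __init__(self, name, failure_threshold, success_threshold,
                 timeout_seconds, half_open_max_calls):
        self.name = name
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout_seconds = timeout_seconds
        self.half_open_max_calls = half_open_max_calls
        self._state = "closed"
        self._failures = 0
        self._successes = 0
        self._opened_at = 0.0
        self._half_open_calls = 0

    def call(self, func, *args, **kwargs):
        # Open circuit: fail fast until the timeout elapses, then probe.
        if self._state == "open":
            if time.monotonic() - self._opened_at < self.timeout_seconds:
                raise CircuitBreakerOpen(f"{self.name} is open")
            self._state, self._half_open_calls, self._successes = "half_open", 0, 0
        # Half-open: allow only a limited number of probe calls.
        if self._state == "half_open":
            if self._half_open_calls >= self.half_open_max_calls:
                raise CircuitBreakerOpen(f"{self.name} is half-open and busy")
            self._half_open_calls += 1
        try:
            result = func(*args, **kwargs)
        except Exception:
            self._failures += 1
            self._successes = 0
            if self._failures >= self.failure_threshold or self._state == "half_open":
                self._state, self._opened_at = "open", time.monotonic()
            raise
        self._successes += 1
        self._failures = 0
        if self._state == "half_open" and self._successes >= self.success_threshold:
            self._state = "closed"
        return result
```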
|
||||
# Import reset functionality (after logger initialization)
|
||||
try:
|
||||
from core.gpu_reset import (
|
||||
reset_gpu_drivers,
|
||||
can_attempt_reset,
|
||||
record_reset_attempt
|
||||
)
|
||||
GPU_RESET_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"GPU reset functionality not available: {e}")
|
||||
GPU_RESET_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUHealthStatus:
|
||||
"""Data class for GPU health check results."""
|
||||
gpu_available: bool # torch.cuda.is_available()
|
||||
gpu_working: bool # Model actually loaded and ran on GPU
|
||||
device_used: str # "cuda" or "cpu"
|
||||
device_name: str # GPU name if available
|
||||
memory_total_gb: float # Total GPU memory
|
||||
memory_available_gb: float # Available GPU memory
|
||||
test_duration_seconds: float # How long test took
|
||||
timestamp: str # ISO timestamp
|
||||
error: Optional[str] = None # Error message if any
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def _check_gpu_health_internal(expected_device: str = "auto") -> GPUHealthStatus:
|
||||
"""
|
||||
Comprehensive GPU health check using real model + transcription.
|
||||
|
||||
Args:
|
||||
expected_device: Expected device ("auto", "cuda", "cpu")
|
||||
|
||||
Returns:
|
||||
GPUHealthStatus object
|
||||
|
||||
Raises:
|
||||
RuntimeError: If expected_device="cuda" but GPU test fails
|
||||
|
||||
Implementation:
|
||||
1. Generate test audio (1 second)
|
||||
2. Load tiny model with requested device
|
||||
3. Transcribe test audio
|
||||
4. Time the operation
|
||||
5. Verify model actually ran on GPU (check torch.cuda.memory_allocated)
|
||||
6. CRITICAL: If expected_device="cuda" but used="cpu" → raise RuntimeError
|
||||
7. Return detailed status
|
||||
|
||||
Performance Expectations:
|
||||
- GPU (tiny model): 0.3-1.0 seconds
|
||||
- CPU (tiny model): 3-10 seconds
|
||||
- If GPU test takes >2 seconds, likely running on CPU
|
||||
"""
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize status with defaults
|
||||
gpu_available = torch.cuda.is_available()
|
||||
device_name = ""
|
||||
memory_total_gb = 0.0
|
||||
memory_available_gb = 0.0
|
||||
actual_device = "cpu"
|
||||
gpu_working = False
|
||||
error_msg = None
|
||||
|
||||
# Get GPU info if available
|
||||
if gpu_available:
|
||||
try:
|
||||
device_name = torch.cuda.get_device_name(0)
|
||||
memory_total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
||||
memory_available_gb = (torch.cuda.get_device_properties(0).total_memory -
|
||||
torch.cuda.memory_allocated(0)) / (1024**3)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get GPU info: {e}")
|
||||
|
||||
try:
|
||||
# Get test audio path from environment variable
|
||||
test_audio_path = os.getenv("GPU_HEALTH_TEST_AUDIO")
|
||||
|
||||
if not test_audio_path:
|
||||
raise ValueError("GPU_HEALTH_TEST_AUDIO environment variable not set")
|
||||
|
||||
# Verify test audio file exists
|
||||
if not os.path.exists(test_audio_path):
|
||||
raise FileNotFoundError(
|
||||
f"Test audio file not found: {test_audio_path}. "
|
||||
f"Please ensure test audio exists before running GPU health checks."
|
||||
)
|
||||
|
||||
# Import here to avoid circular dependencies
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
# Determine device to use
|
||||
test_device = "cpu"
|
||||
test_compute_type = "int8"
|
||||
|
||||
if expected_device == "cuda" or (expected_device == "auto" and gpu_available):
|
||||
test_device = "cuda"
|
||||
test_compute_type = "float16"
|
||||
|
||||
# Record GPU memory before loading model
|
||||
gpu_memory_before = 0
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory_before = torch.cuda.memory_allocated(0)
|
||||
|
||||
# Load tiny model and transcribe
|
||||
model = None
|
||||
try:
|
||||
model = WhisperModel(
|
||||
"tiny",
|
||||
device=test_device,
|
||||
compute_type=test_compute_type
|
||||
)
|
||||
|
||||
# Transcribe test audio
|
||||
segments, info = model.transcribe(test_audio_path, beam_size=1)
|
||||
|
||||
# Consume segments (needed to actually run inference)
|
||||
segments_list = list(segments)
|
||||
|
||||
# Check if GPU was actually used
|
||||
# faster-whisper uses CTranslate2 which manages GPU memory separately
|
||||
# from PyTorch, so we check model.model.device instead of torch memory
|
||||
actual_device = model.model.device
|
||||
if actual_device == "cuda":
|
||||
gpu_working = True
|
||||
else:
|
||||
gpu_working = False
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Model loading/inference failed: {str(e)}"
|
||||
actual_device = "cpu"
|
||||
gpu_working = False
|
||||
logger.error(f"GPU health check failed: {error_msg}")
|
||||
finally:
|
||||
# Clean up model resources to prevent GPU memory leak
|
||||
if model is not None:
|
||||
try:
|
||||
del model
|
||||
segments_list = None
|
||||
# Force garbage collection and empty CUDA cache
|
||||
import gc
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error cleaning up GPU health check model: {cleanup_error}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Health check setup failed: {str(e)}"
|
||||
logger.error(f"GPU health check error: {error_msg}")
|
||||
|
||||
# Calculate test duration
|
||||
test_duration = time.time() - start_time
|
||||
|
||||
# Create status object
|
||||
status = GPUHealthStatus(
|
||||
gpu_available=gpu_available,
|
||||
gpu_working=gpu_working,
|
||||
device_used=actual_device,
|
||||
device_name=device_name,
|
||||
memory_total_gb=round(memory_total_gb, 2),
|
||||
memory_available_gb=round(memory_available_gb, 2),
|
||||
test_duration_seconds=round(test_duration, 2),
|
||||
timestamp=timestamp,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
# CRITICAL: Reject if expected CUDA but got CPU
|
||||
# This service is GPU-only, so also reject device="auto" that falls back to CPU
|
||||
if expected_device == "cuda" and actual_device == "cpu":
|
||||
error_message = (
|
||||
f"GPU device requested but model loaded on CPU. "
|
||||
f"This indicates GPU driver issues or insufficient memory. "
|
||||
f"Transcription would be 10-100x slower than expected. "
|
||||
f"Please check CUDA installation and GPU availability. "
|
||||
f"Details: {error_msg or 'Unknown cause'}"
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
# CRITICAL: Service requires GPU - reject CPU fallback even with device="auto"
|
||||
if expected_device == "auto" and actual_device == "cpu":
|
||||
error_message = (
|
||||
f"GPU required but not available. This service is configured for GPU-only operation. "
|
||||
f"CUDA is not available or GPU health check failed. "
|
||||
f"Please ensure CUDA is properly installed and GPU is accessible. "
|
||||
f"Details: {error_msg or 'CUDA not available'}"
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
# Log health check result
|
||||
if gpu_working:
|
||||
logger.info(
|
||||
f"GPU health check passed: {device_name}, "
|
||||
f"test duration: {test_duration:.2f}s"
|
||||
)
|
||||
elif gpu_available and not gpu_working:
|
||||
logger.warning(
|
||||
f"GPU available but health check failed. "
|
||||
f"Test duration: {test_duration:.2f}s, Error: {error_msg}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"GPU not available, using CPU. Test duration: {test_duration:.2f}s")
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def check_gpu_health(expected_device: str = "auto", use_circuit_breaker: bool = True) -> GPUHealthStatus:
|
||||
"""
|
||||
GPU health check with optional circuit breaker protection.
|
||||
|
||||
This is the main entry point for GPU health checks. By default, it uses
|
||||
circuit breaker pattern to prevent repeated failed checks.
|
||||
|
||||
Args:
|
||||
expected_device: Expected device ("auto", "cuda", "cpu")
|
||||
use_circuit_breaker: Enable circuit breaker protection (default: True)
|
||||
|
||||
Returns:
|
||||
GPUHealthStatus object
|
||||
|
||||
Raises:
|
||||
RuntimeError: If expected_device="cuda" but GPU test fails
|
||||
CircuitBreakerOpen: If circuit breaker is open (too many recent failures)
|
||||
"""
|
||||
if use_circuit_breaker:
|
||||
try:
|
||||
return _gpu_health_circuit_breaker.call(_check_gpu_health_internal, expected_device)
|
||||
except CircuitBreakerOpen as e:
|
||||
# Circuit is open, fail fast without attempting check
|
||||
logger.warning(f"GPU health check circuit breaker is OPEN: {e}")
|
||||
raise RuntimeError(f"GPU health check unavailable: {str(e)}")
|
||||
else:
|
||||
return _check_gpu_health_internal(expected_device)
|
||||
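Typical call pattern for the entry point above (a sketch; it assumes GPU_HEALTH_TEST_AUDIO points at a short local audio file and that `src` is on PYTHONPATH, as in the run scripts):

```python
from core.gpu_health import check_gpu_health

try:
    status = check_gpu_health(expected_device="cuda")
    print(status.to_dict())
except RuntimeError as e:
    # Raised when CUDA was requested but the model ran on CPU,
    # or when the circuit breaker is currently open.
    print(f"GPU unhealthy: {e}")
```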
|
||||
|
||||
def get_circuit_breaker_stats() -> dict:
|
||||
"""
|
||||
Get current circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with circuit state and failure/success counts
|
||||
"""
|
||||
return _gpu_health_circuit_breaker.get_stats()
|
||||
|
||||
|
||||
def reset_circuit_breaker():
|
||||
"""
|
||||
Manually reset the GPU health check circuit breaker.
|
||||
|
||||
Useful for:
|
||||
- Testing
|
||||
- Manual intervention after fixing GPU issues
|
||||
- Clearing persistent error state
|
||||
"""
|
||||
_gpu_health_circuit_breaker.reset()
|
||||
logger.info("GPU health check circuit breaker manually reset")
|
||||
|
||||
|
||||
def check_gpu_health_with_reset(
|
||||
expected_device: str = "cuda",
|
||||
auto_reset: bool = True
|
||||
) -> GPUHealthStatus:
|
||||
"""
|
||||
Check GPU health with automatic reset on failure.
|
||||
|
||||
This function wraps check_gpu_health() and adds automatic GPU driver
|
||||
reset capability with cooldown protection to prevent reset loops.
|
||||
|
||||
Behavior:
|
||||
1. Attempt GPU health check
|
||||
2. If fails and auto_reset=True:
|
||||
a. Check cooldown (prevents reset loops)
|
||||
b. If cooldown active → raise error immediately
|
||||
c. If cooldown OK → reset drivers → wait 3s → retry
|
||||
3. If retry fails → raise error (terminate service)
|
||||
|
||||
Args:
|
||||
expected_device: Expected device ("auto", "cuda", "cpu")
|
||||
auto_reset: Enable automatic reset on failure (default: True)
|
||||
|
||||
Returns:
|
||||
GPUHealthStatus object
|
||||
|
||||
Raises:
|
||||
RuntimeError: If GPU check fails and reset not available/allowed,
|
||||
or if retry after reset fails
|
||||
"""
|
||||
try:
|
||||
# First attempt
|
||||
return check_gpu_health(expected_device)
|
||||
|
||||
except (RuntimeError, Exception) as first_error:
|
||||
# GPU health check failed
|
||||
|
||||
if not auto_reset:
|
||||
logger.error(f"GPU health check failed (auto-reset disabled): {first_error}")
|
||||
raise
|
||||
|
||||
if not GPU_RESET_AVAILABLE:
|
||||
logger.error(
|
||||
f"GPU health check failed but reset not available: {first_error}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"GPU unavailable and reset functionality not available: {first_error}"
|
||||
)
|
||||
|
||||
# Check if reset is allowed (cooldown protection)
|
||||
if not can_attempt_reset():
|
||||
error_msg = (
|
||||
f"GPU health check failed but reset cooldown is active. "
|
||||
f"This prevents reset loops. Last error: {first_error}. "
|
||||
f"Service will terminate. If this happens after sleep/wake, "
|
||||
f"the cooldown should expire soon and next restart will work."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Attempt GPU reset
|
||||
logger.warning("=" * 70)
|
||||
logger.warning(f"GPU HEALTH CHECK FAILED: {first_error}")
|
||||
logger.warning("Attempting automatic GPU driver reset...")
|
||||
logger.warning("=" * 70)
|
||||
|
||||
try:
|
||||
reset_gpu_drivers()
|
||||
record_reset_attempt()
|
||||
|
||||
logger.info("GPU reset completed, waiting 3 seconds for stabilization...")
|
||||
time.sleep(3)
|
||||
|
||||
# Retry GPU health check
|
||||
logger.info("Retrying GPU health check after reset...")
|
||||
status = check_gpu_health(expected_device)
|
||||
|
||||
logger.warning("=" * 70)
|
||||
logger.warning("GPU HEALTH CHECK SUCCESS AFTER RESET")
|
||||
logger.warning(f"Device: {status.device_name}")
|
||||
logger.warning(f"Memory: {status.memory_available_gb:.2f} GB available")
|
||||
logger.warning("=" * 70)
|
||||
|
||||
return status
|
||||
|
||||
except Exception as reset_error:
|
||||
error_msg = (
|
||||
f"GPU health check failed after reset attempt. "
|
||||
f"Original error: {first_error}. "
|
||||
f"Reset error: {reset_error}. "
|
||||
f"Service will terminate."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
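A sketch of how a server startup path might use this wrapper; the actual wiring lives in the server modules, which are not part of this excerpt.

```python
from core.gpu_health import check_gpu_health_with_reset

try:
    status = check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
    print(f"GPU ready: {status.device_name}, {status.memory_available_gb} GB free")
except RuntimeError as e:
    # Either the cooldown blocked a reset or the retry after reset failed.
    raise SystemExit(f"Refusing to start without a healthy GPU: {e}")
```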
|
||||
|
||||
class HealthMonitor:
|
||||
"""Background thread for periodic GPU health monitoring."""
|
||||
|
||||
def __init__(self, check_interval_minutes: int = 10):
|
||||
"""
|
||||
Initialize health monitor.
|
||||
|
||||
Args:
|
||||
check_interval_minutes: Interval between health checks (default: 10)
|
||||
"""
|
||||
self._check_interval_seconds = check_interval_minutes * 60
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
self._latest_status: Optional[GPUHealthStatus] = None
|
||||
self._history: List[GPUHealthStatus] = []
|
||||
self._max_history = 100
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def start(self):
|
||||
"""Start background monitoring thread."""
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
logger.warning("Health monitor already running")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Starting GPU health monitor "
|
||||
f"(interval: {self._check_interval_seconds / 60:.1f} minutes)"
|
||||
)
|
||||
|
||||
# Run initial health check
|
||||
try:
|
||||
status = check_gpu_health(expected_device="auto")
|
||||
with self._lock:
|
||||
self._latest_status = status
|
||||
self._history.append(status)
|
||||
except Exception as e:
|
||||
logger.error(f"Initial GPU health check failed: {e}")
|
||||
|
||||
# Start background thread
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop background monitoring thread."""
|
||||
if self._thread is None:
|
||||
return
|
||||
|
||||
logger.info("Stopping GPU health monitor")
|
||||
self._stop_event.set()
|
||||
self._thread.join(timeout=5.0)
|
||||
self._thread = None
|
||||
|
||||
def get_latest_status(self) -> Optional[GPUHealthStatus]:
|
||||
"""Get most recent health check result."""
|
||||
with self._lock:
|
||||
return self._latest_status
|
||||
|
||||
def get_health_history(self, limit: int = 10) -> List[GPUHealthStatus]:
|
||||
"""
|
||||
Get recent health check history.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of GPUHealthStatus objects (most recent first)
|
||||
"""
|
||||
with self._lock:
|
||||
return list(reversed(self._history[-limit:]))
|
||||
|
||||
def _monitor_loop(self):
|
||||
"""Background monitoring loop."""
|
||||
while not self._stop_event.wait(timeout=self._check_interval_seconds):
|
||||
try:
|
||||
logger.debug("Running periodic GPU health check")
|
||||
status = check_gpu_health(expected_device="auto")
|
||||
|
||||
with self._lock:
|
||||
self._latest_status = status
|
||||
self._history.append(status)
|
||||
|
||||
# Trim history if too long
|
||||
if len(self._history) > self._max_history:
|
||||
self._history = self._history[-self._max_history:]
|
||||
|
||||
# Log warning if performance degraded
|
||||
if status.gpu_available and not status.gpu_working:
|
||||
logger.warning(
|
||||
f"GPU health degraded: {status.error}. "
|
||||
f"Test duration: {status.test_duration_seconds}s"
|
||||
)
|
||||
elif status.gpu_working and status.test_duration_seconds > 2.0:
|
||||
logger.warning(
|
||||
f"GPU health check slow: {status.test_duration_seconds}s "
|
||||
f"(expected <1s). May indicate performance issues."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Periodic GPU health check failed: {e}")
|
||||
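Usage sketch for the monitor above; the interval mirrors GPU_HEALTH_CHECK_INTERVAL_MINUTES from the run scripts, and the serving code is only indicated.

```python
from core.gpu_health import HealthMonitor

monitor = HealthMonitor(check_interval_minutes=10)
monitor.start()
try:
    ...  # serve requests while the monitor runs in the background
finally:
    monitor.stop()
    latest = monitor.get_latest_status()
    if latest is not None:
        print(latest.to_dict())
```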
241 src/core/gpu_reset.py Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
GPU driver reset functionality with cooldown protection.
|
||||
|
||||
Provides automatic GPU driver reset on CUDA errors with safeguards
|
||||
to prevent reset loops while allowing recovery from sleep/wake cycles.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cooldown file location (stores monotonic timestamp for drift protection)
|
||||
RESET_TIMESTAMP_FILE = "/tmp/whisper-gpu-last-reset"
|
||||
|
||||
# Default cooldown period (minutes)
|
||||
DEFAULT_COOLDOWN_MINUTES = 5
|
||||
|
||||
# Cooldown period in seconds (for monotonic comparison)
|
||||
def get_cooldown_seconds() -> float:
|
||||
"""Get cooldown period in seconds."""
|
||||
return get_cooldown_minutes() * 60.0
|
||||
|
||||
|
||||
def get_cooldown_minutes() -> int:
|
||||
"""
|
||||
Get cooldown period from environment variable.
|
||||
|
||||
Returns:
|
||||
Cooldown period in minutes (default: 5)
|
||||
"""
|
||||
try:
|
||||
return int(os.getenv("GPU_RESET_COOLDOWN_MINUTES", str(DEFAULT_COOLDOWN_MINUTES)))
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid GPU_RESET_COOLDOWN_MINUTES value, using default: {DEFAULT_COOLDOWN_MINUTES}"
|
||||
)
|
||||
return DEFAULT_COOLDOWN_MINUTES
|
||||
|
||||
|
||||
def get_last_reset_time() -> Optional[float]:
|
||||
"""
|
||||
Read monotonic timestamp of last GPU reset attempt.
|
||||
|
||||
Returns:
|
||||
Monotonic timestamp of last reset, or None if no previous reset
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(RESET_TIMESTAMP_FILE):
|
||||
return None
|
||||
|
||||
with open(RESET_TIMESTAMP_FILE, 'r') as f:
|
||||
timestamp_str = f.read().strip()
|
||||
|
||||
return float(timestamp_str)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read last reset timestamp: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def record_reset_attempt() -> None:
|
||||
"""
|
||||
Record current time as last GPU reset attempt.
|
||||
|
||||
Creates/updates timestamp file with monotonic time (drift-protected).
|
||||
"""
|
||||
try:
|
||||
# Use monotonic time to prevent NTP drift issues
|
||||
timestamp_monotonic = time.monotonic()
|
||||
timestamp_iso = datetime.utcnow().isoformat() # For logging only
|
||||
|
||||
with open(RESET_TIMESTAMP_FILE, 'w') as f:
|
||||
f.write(str(timestamp_monotonic))
|
||||
|
||||
logger.info(f"Recorded GPU reset timestamp: {timestamp_iso} (monotonic: {timestamp_monotonic:.2f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record reset timestamp: {e}")
|
||||
|
||||
|
||||
def can_attempt_reset() -> bool:
|
||||
"""
|
||||
Check if GPU reset can be attempted based on cooldown period.
|
||||
|
||||
Uses monotonic time to prevent NTP drift issues.
|
||||
|
||||
Returns:
|
||||
True if reset is allowed (no recent reset or cooldown expired),
|
||||
False if cooldown is still active
|
||||
"""
|
||||
last_reset_monotonic = get_last_reset_time()
|
||||
|
||||
if last_reset_monotonic is None:
|
||||
# No previous reset recorded
|
||||
logger.debug("No previous GPU reset found, reset allowed")
|
||||
return True
|
||||
|
||||
# Use monotonic time for drift-safe comparison
|
||||
current_monotonic = time.monotonic()
|
||||
time_since_reset_seconds = current_monotonic - last_reset_monotonic
|
||||
cooldown_seconds = get_cooldown_seconds()
|
||||
|
||||
if time_since_reset_seconds < cooldown_seconds:
|
||||
remaining_seconds = cooldown_seconds - time_since_reset_seconds
|
||||
logger.warning(
|
||||
f"GPU reset cooldown active. "
|
||||
f"Cooldown: {get_cooldown_minutes()} min, "
|
||||
f"Remaining: {remaining_seconds:.0f}s"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f"GPU reset cooldown expired. "
|
||||
f"Time since last reset: {time_since_reset_seconds:.0f}s"
|
||||
)
|
||||
return True
|
||||
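The helpers above are meant to be used as a gate around the actual reset, as in this sketch, which mirrors what check_gpu_health_with_reset() in gpu_health.py does:

```python
from core.gpu_reset import can_attempt_reset, record_reset_attempt, reset_gpu_drivers

if can_attempt_reset():
    reset_gpu_drivers()      # runs reset_gpu.sh via sudo; raises RuntimeError on failure
    record_reset_attempt()   # starts the cooldown window (GPU_RESET_COOLDOWN_MINUTES)
else:
    raise RuntimeError("GPU reset cooldown active; refusing to reset again")
```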
|
||||
|
||||
def reset_gpu_drivers() -> None:
|
||||
"""
|
||||
Execute GPU driver reset using reset_gpu.sh script.
|
||||
|
||||
This function:
|
||||
1. Locates reset_gpu.sh in project root
|
||||
2. Executes it with sudo
|
||||
3. Waits for completion
|
||||
4. Raises exception if reset fails
|
||||
|
||||
Raises:
|
||||
RuntimeError: If reset script fails or is not found
|
||||
PermissionError: If sudo permissions not configured
|
||||
|
||||
Note:
|
||||
Requires passwordless sudo for nvidia commands.
|
||||
See CLAUDE.md for setup instructions.
|
||||
"""
|
||||
logger.warning("=" * 60)
|
||||
logger.warning("INITIATING GPU DRIVER RESET")
|
||||
logger.warning("=" * 60)
|
||||
|
||||
# Find reset_gpu.sh script
|
||||
script_path = Path(__file__).parent.parent.parent / "reset_gpu.sh"
|
||||
|
||||
if not script_path.exists():
|
||||
error_msg = f"GPU reset script not found: {script_path}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
if not os.access(script_path, os.X_OK):
|
||||
error_msg = f"GPU reset script not executable: {script_path}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Resolve to absolute path and validate it's the expected script
|
||||
# This prevents path injection if script_path was somehow manipulated
|
||||
resolved_path = script_path.resolve()
|
||||
|
||||
# Security check: Ensure resolved path is still in expected location
|
||||
expected_parent = Path(__file__).parent.parent.parent.resolve()
|
||||
if resolved_path.parent != expected_parent:
|
||||
error_msg = f"Security check failed: Script path outside expected directory"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
logger.info(f"Executing GPU reset script: {resolved_path}")
|
||||
logger.warning("This will temporarily interrupt all GPU operations")
|
||||
|
||||
try:
|
||||
# Execute reset script with sudo
|
||||
# Using list form (not shell=True) prevents shell injection
|
||||
result = subprocess.run(
|
||||
['sudo', str(resolved_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30, # 30 second timeout
|
||||
shell=False # Explicitly disable shell to prevent injection
|
||||
)
|
||||
|
||||
# Log script output
|
||||
if result.stdout:
|
||||
logger.info(f"Reset script output:\n{result.stdout}")
|
||||
|
||||
if result.stderr:
|
||||
logger.warning(f"Reset script stderr:\n{result.stderr}")
|
||||
|
||||
# Check exit code
|
||||
if result.returncode != 0:
|
||||
error_msg = (
|
||||
f"GPU reset script failed with exit code {result.returncode}. "
|
||||
f"This may indicate:\n"
|
||||
f" 1. Sudo permissions not configured (see CLAUDE.md)\n"
|
||||
f" 2. GPU hardware issue\n"
|
||||
f" 3. Driver installation problem\n"
|
||||
f"Output: {result.stderr or result.stdout}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
logger.warning("=" * 60)
|
||||
logger.warning("GPU DRIVER RESET COMPLETED SUCCESSFULLY")
|
||||
logger.warning("=" * 60)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
error_msg = "GPU reset script timed out after 30 seconds"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
except FileNotFoundError:
|
||||
error_msg = (
|
||||
"sudo command not found. This service requires sudo access "
|
||||
"for GPU driver reset. Please ensure sudo is installed."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during GPU reset: {e}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def clear_reset_cooldown() -> None:
|
||||
"""
|
||||
Clear reset cooldown by removing timestamp file.
|
||||
|
||||
Useful for testing or manual intervention.
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(RESET_TIMESTAMP_FILE):
|
||||
os.remove(RESET_TIMESTAMP_FILE)
|
||||
logger.info("GPU reset cooldown cleared")
|
||||
else:
|
||||
logger.debug("No cooldown to clear")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear reset cooldown: {e}")
|
||||
773 src/core/job_queue.py Normal file
@@ -0,0 +1,773 @@
|
||||
"""
|
||||
Job queue manager for asynchronous transcription processing.
|
||||
|
||||
Provides FIFO job queue with background worker thread and disk persistence.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import queue
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Deque
|
||||
from collections import deque
|
||||
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
from core.transcriber import transcribe_audio
|
||||
from core.job_repository import JobRepository
|
||||
from utils.audio_processor import validate_audio_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
DEFAULT_JOB_TTL_HOURS = 24 # How long to keep completed jobs in memory
|
||||
GPU_HEALTH_CACHE_TTL_SECONDS = 30 # Cache GPU health check results
|
||||
CLEANUP_INTERVAL_SECONDS = 3600 # Run TTL cleanup every hour (1 hour)
|
||||
JOB_TIMEOUT_WARNING_SECONDS = 600 # Warn if job takes > 10 minutes
|
||||
JOB_TIMEOUT_MAX_SECONDS = 3600 # Maximum 1 hour per job
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
"""Job status enumeration."""
|
||||
QUEUED = "queued" # In queue, waiting
|
||||
RUNNING = "running" # Currently processing
|
||||
COMPLETED = "completed" # Successfully finished
|
||||
FAILED = "failed" # Error occurred
|
||||
|
||||
|
||||
@dataclass
|
||||
class Job:
|
||||
"""Represents a transcription job."""
|
||||
job_id: str
|
||||
status: JobStatus
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
queue_position: int
|
||||
|
||||
# Request parameters
|
||||
audio_path: str
|
||||
model_name: str
|
||||
device: str
|
||||
compute_type: str
|
||||
language: Optional[str]
|
||||
output_format: str
|
||||
beam_size: int
|
||||
temperature: float
|
||||
initial_prompt: Optional[str]
|
||||
output_directory: Optional[str]
|
||||
|
||||
# Results
|
||||
result_path: Optional[str]
|
||||
error: Optional[str]
|
||||
processing_time_seconds: Optional[float]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize to dictionary for JSON storage."""
|
||||
return {
|
||||
"job_id": self.job_id,
|
||||
"status": self.status.value,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"queue_position": self.queue_position,
|
||||
"request_params": {
|
||||
"audio_path": self.audio_path,
|
||||
"model_name": self.model_name,
|
||||
"device": self.device,
|
||||
"compute_type": self.compute_type,
|
||||
"language": self.language,
|
||||
"output_format": self.output_format,
|
||||
"beam_size": self.beam_size,
|
||||
"temperature": self.temperature,
|
||||
"initial_prompt": self.initial_prompt,
|
||||
"output_directory": self.output_directory,
|
||||
},
|
||||
"result_path": self.result_path,
|
||||
"error": self.error,
|
||||
"processing_time_seconds": self.processing_time_seconds,
|
||||
}
|
||||
|
||||
def mark_for_persistence(self, repository):
|
||||
"""Mark job as dirty for write-behind persistence."""
|
||||
repository.mark_dirty(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> 'Job':
|
||||
"""Deserialize from dictionary."""
|
||||
params = data.get("request_params", {})
|
||||
return cls(
|
||||
job_id=data["job_id"],
|
||||
status=JobStatus(data["status"]),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.utcnow(),
|
||||
started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None,
|
||||
completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None,
|
||||
queue_position=data.get("queue_position", 0),
|
||||
audio_path=params["audio_path"],
|
||||
model_name=params.get("model_name", "large-v3"),
|
||||
device=params.get("device", "auto"),
|
||||
compute_type=params.get("compute_type", "auto"),
|
||||
language=params.get("language"),
|
||||
output_format=params.get("output_format", "txt"),
|
||||
beam_size=params.get("beam_size", 5),
|
||||
temperature=params.get("temperature", 0.0),
|
||||
initial_prompt=params.get("initial_prompt"),
|
||||
output_directory=params.get("output_directory"),
|
||||
result_path=data.get("result_path"),
|
||||
error=data.get("error"),
|
||||
processing_time_seconds=data.get("processing_time_seconds"),
|
||||
)
|
||||
|
||||
|
||||
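A round-trip sketch for the persistence format defined by to_dict()/from_dict() above; field values are illustrative.

```python
import json
from datetime import datetime

job = Job(
    job_id="example-id", status=JobStatus.QUEUED, created_at=datetime.utcnow(),
    started_at=None, completed_at=None, queue_position=1,
    audio_path="/data/audio/meeting.wav", model_name="large-v3", device="auto",
    compute_type="auto", language=None, output_format="txt", beam_size=5,
    temperature=0.0, initial_prompt=None, output_directory=None,
    result_path=None, error=None, processing_time_seconds=None,
)
# Serialize to JSON (as the repository does) and restore.
restored = Job.from_dict(json.loads(json.dumps(job.to_dict())))
assert restored.job_id == job.job_id and restored.status is JobStatus.QUEUED
```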
class JobQueue:
|
||||
"""
|
||||
Manages job queue with background worker.
|
||||
|
||||
THREAD SAFETY & LOCK ORDERING
|
||||
==============================
|
||||
This class uses multiple locks to protect shared state. To prevent deadlocks,
|
||||
all code MUST follow this strict lock ordering:
|
||||
|
||||
LOCK HIERARCHY (acquire in this order):
|
||||
1. _jobs_lock - Protects _jobs dict and _current_job_id
|
||||
2. _queue_positions_lock - Protects _queued_job_ids deque
|
||||
|
||||
RULES:
|
||||
- NEVER acquire _jobs_lock while holding _queue_positions_lock
|
||||
- Always release locks in reverse order of acquisition
|
||||
- Keep lock hold time minimal - release before I/O operations
|
||||
- Use snapshot/copy pattern when data must cross lock boundaries
|
||||
|
||||
CRITICAL METHODS:
|
||||
- _calculate_queue_positions(): Uses snapshot pattern to avoid nested locks
|
||||
- submit_job(): Acquires locks separately, never nested
|
||||
- _worker_loop(): Acquires locks separately in correct order
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_queue_size: int = 100,
|
||||
metadata_dir: str = "/outputs/jobs",
|
||||
job_ttl_hours: int = 24):
|
||||
"""
|
||||
Initialize job queue.
|
||||
|
||||
Args:
|
||||
max_queue_size: Maximum number of jobs in queue
|
||||
metadata_dir: Directory to store job metadata JSON files
|
||||
job_ttl_hours: Hours to keep completed/failed jobs before cleanup
|
||||
"""
|
||||
self._queue = queue.Queue(maxsize=max_queue_size)
|
||||
self._jobs: Dict[str, Job] = {}
|
||||
self._repository = JobRepository(
|
||||
metadata_dir=metadata_dir,
|
||||
job_ttl_hours=job_ttl_hours
|
||||
)
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
self._current_job_id: Optional[str] = None
|
||||
self._jobs_lock = threading.Lock() # Lock for _jobs dict
|
||||
self._queue_positions_lock = threading.Lock() # Lock for position tracking
|
||||
self._max_queue_size = max_queue_size
|
||||
|
||||
# Maintain ordered queue for O(1) position lookups
|
||||
# Deque of job_ids in queue order (FIFO)
|
||||
self._queued_job_ids: Deque[str] = deque()
|
||||
|
||||
# TTL cleanup tracking
|
||||
self._last_cleanup_time = datetime.utcnow()
|
||||
|
||||
# GPU health check caching
|
||||
self._gpu_health_cache: Optional[any] = None
|
||||
self._gpu_health_cache_time: Optional[datetime] = None
|
||||
self._gpu_health_cache_ttl_seconds = GPU_HEALTH_CACHE_TTL_SECONDS
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start background worker thread.
|
||||
Load existing jobs from disk on startup.
|
||||
"""
|
||||
if self._worker_thread is not None and self._worker_thread.is_alive():
|
||||
logger.warning("Job queue worker already running")
|
||||
return
|
||||
|
||||
logger.info(f"Starting job queue (max size: {self._max_queue_size})")
|
||||
|
||||
# Start repository flush thread
|
||||
self._repository.start()
|
||||
|
||||
# Load existing jobs from disk
|
||||
self._load_jobs_from_disk()
|
||||
|
||||
# Start background worker thread
|
||||
self._stop_event.clear()
|
||||
self._worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
|
||||
self._worker_thread.start()
|
||||
|
||||
logger.info("Job queue worker started")
|
||||
|
||||
def stop(self, wait_for_current: bool = True):
|
||||
"""
|
||||
Stop background worker.
|
||||
|
||||
Args:
|
||||
wait_for_current: If True, wait for current job to complete
|
||||
"""
|
||||
if self._worker_thread is None:
|
||||
return
|
||||
|
||||
logger.info(f"Stopping job queue (wait_for_current={wait_for_current})")
|
||||
self._stop_event.set()
|
||||
|
||||
if wait_for_current:
|
||||
self._worker_thread.join(timeout=30.0)
|
||||
else:
|
||||
self._worker_thread.join(timeout=1.0)
|
||||
|
||||
self._worker_thread = None
|
||||
|
||||
# Stop repository and flush pending writes
|
||||
self._repository.stop(flush_pending=True)
|
||||
|
||||
logger.info("Job queue worker stopped")
|
||||
|
||||
def submit_job(self,
|
||||
audio_path: str,
|
||||
model_name: str = "large-v3",
|
||||
device: str = "auto",
|
||||
compute_type: str = "auto",
|
||||
language: Optional[str] = None,
|
||||
output_format: str = "txt",
|
||||
beam_size: int = 5,
|
||||
temperature: float = 0.0,
|
||||
initial_prompt: Optional[str] = None,
|
||||
output_directory: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Submit a new transcription job.
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"job_id": str,
|
||||
"status": str,
|
||||
"queue_position": int,
|
||||
"created_at": str
|
||||
}
|
||||
|
||||
Raises:
|
||||
queue.Full: If queue is at max capacity
|
||||
RuntimeError: If GPU health check fails (when device="cuda")
|
||||
FileNotFoundError: If audio file doesn't exist
|
||||
"""
|
||||
# 1. Validate audio file exists
|
||||
try:
|
||||
validate_audio_file(audio_path)
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"Audio file validation failed: {e}")
|
||||
|
||||
# 2. Check GPU health (GPU required for all devices since this is GPU-only service)
|
||||
# Both device="cuda" and device="auto" require GPU
|
||||
# Use cached health check result if available (30s TTL)
|
||||
if device == "cuda" or device == "auto":
|
||||
try:
|
||||
# Check cache first
|
||||
now = datetime.utcnow()
|
||||
cache_valid = (
|
||||
self._gpu_health_cache is not None and
|
||||
self._gpu_health_cache_time is not None and
|
||||
(now - self._gpu_health_cache_time).total_seconds() < self._gpu_health_cache_ttl_seconds
|
||||
)
|
||||
|
||||
if cache_valid:
|
||||
logger.debug("Using cached GPU health check result")
|
||||
health_status = self._gpu_health_cache
|
||||
else:
|
||||
logger.info("Running GPU health check before job submission")
|
||||
# Use expected_device to match what user requested
|
||||
expected = "cuda" if device == "cuda" else "auto"
|
||||
health_status = check_gpu_health_with_reset(expected_device=expected, auto_reset=True)
|
||||
|
||||
# Cache the result
|
||||
self._gpu_health_cache = health_status
|
||||
self._gpu_health_cache_time = now
|
||||
logger.info("GPU health check passed and cached")
|
||||
|
||||
if not health_status.gpu_working:
|
||||
# Invalidate cache on failure
|
||||
self._gpu_health_cache = None
|
||||
self._gpu_health_cache_time = None
|
||||
|
||||
raise RuntimeError(
|
||||
f"GPU device required but not available. "
|
||||
f"GPU check failed: {health_status.error}. "
|
||||
f"This service is configured for GPU-only operation."
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
# Re-raise GPU health errors
|
||||
logger.error(f"Job rejected due to GPU health check failure: {e}")
|
||||
raise RuntimeError(f"Job rejected: {e}")
|
||||
elif device == "cpu":
|
||||
# Reject CPU device explicitly
|
||||
raise ValueError(
|
||||
"CPU device requested but this service is configured for GPU-only operation. "
|
||||
"Please use device='cuda' or device='auto' with a GPU available."
|
||||
)
|
||||
|
||||
# 3. Generate job_id
|
||||
job_id = str(uuid.uuid4())
|
||||
logger.debug(f"Generated job_id: {job_id}")
|
||||
|
||||
# 4. Create Job object
|
||||
job = Job(
|
||||
job_id=job_id,
|
||||
status=JobStatus.QUEUED,
|
||||
created_at=datetime.utcnow(),
|
||||
started_at=None,
|
||||
completed_at=None,
|
||||
queue_position=0, # Will be calculated
|
||||
audio_path=audio_path,
|
||||
model_name=model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=language,
|
||||
output_format=output_format,
|
||||
beam_size=beam_size,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
output_directory=output_directory,
|
||||
result_path=None,
|
||||
error=None,
|
||||
processing_time_seconds=None,
|
||||
)
|
||||
logger.debug(f"Created Job object for {job_id}")
|
||||
|
||||
# 5. Add to queue (raises queue.Full if full)
|
||||
logger.debug(f"Attempting to add job {job_id} to queue (current size: {self._queue.qsize()})")
|
||||
try:
|
||||
self._queue.put_nowait(job)
|
||||
logger.debug(f"Successfully added job {job_id} to queue")
|
||||
except queue.Full:
|
||||
raise queue.Full(
|
||||
f"Job queue is full (max size: {self._max_queue_size}). "
|
||||
f"Please try again later."
|
||||
)
|
||||
|
||||
# 6. Add to jobs dict and update queue tracking
|
||||
# LOCK ORDERING: Always acquire _jobs_lock before _queue_positions_lock
|
||||
logger.debug(f"Acquiring _jobs_lock for job {job_id}")
|
||||
with self._jobs_lock:
|
||||
self._jobs[job_id] = job
|
||||
logger.debug(f"Added job {job_id} to _jobs dict")
|
||||
|
||||
# Update queue positions (separate lock to avoid deadlock)
|
||||
logger.debug(f"Acquiring _queue_positions_lock for job {job_id}")
|
||||
with self._queue_positions_lock:
|
||||
# Add to ordered queue for O(1) position tracking
|
||||
self._queued_job_ids.append(job_id)
|
||||
logger.debug(f"Added job {job_id} to _queued_job_ids, calling _calculate_queue_positions()")
|
||||
# Calculate positions - this will briefly acquire _jobs_lock internally
|
||||
self._calculate_queue_positions()
|
||||
logger.debug(f"Finished _calculate_queue_positions() for job {job_id}")
|
||||
|
||||
# Capture return data (need to re-acquire lock after position calculation)
|
||||
logger.debug(f"Re-acquiring _jobs_lock to capture return data for job {job_id}")
|
||||
with self._jobs_lock:
|
||||
return_data = {
|
||||
"job_id": job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position,
|
||||
"created_at": job.created_at.isoformat() + "Z"
|
||||
}
|
||||
queue_position = job.queue_position
|
||||
logger.debug(f"Captured return data for job {job_id}, queue_position={queue_position}")
|
||||
|
||||
# Mark for async persistence (outside lock to avoid blocking)
|
||||
logger.debug(f"Marking job {job_id} for persistence")
|
||||
job.mark_for_persistence(self._repository)
|
||||
logger.debug(f"Job {job_id} marked for persistence successfully")
|
||||
|
||||
logger.info(
|
||||
f"Job {job_id} submitted: {audio_path} "
|
||||
f"(queue position: {queue_position})"
|
||||
)
|
||||
|
||||
# Run periodic TTL cleanup (rate-limited inside _maybe_cleanup_old_jobs to once per CLEANUP_INTERVAL_SECONDS)
|
||||
self._maybe_cleanup_old_jobs()
|
||||
|
||||
# 7. Return job info
|
||||
return return_data
|
||||
|
||||
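Submission sketch for the method above; paths and queue size are illustrative, and in practice the REST/MCP server layer does this wiring.

```python
import queue

jq = JobQueue(max_queue_size=100, metadata_dir="/tmp/whisper-jobs")
jq.start()

try:
    info = jq.submit_job("/data/audio/meeting.wav", model_name="large-v3", device="auto")
    print(info["job_id"], "at queue position", info["queue_position"])
except queue.Full:
    print("Queue is full, retry later")
except (RuntimeError, FileNotFoundError, ValueError) as e:
    # GPU health failure, missing audio file, or explicit CPU request.
    print(f"Job rejected: {e}")
```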
def _maybe_cleanup_old_jobs(self):
|
||||
"""Periodically cleanup old completed/failed jobs based on TTL."""
|
||||
# Only run cleanup every hour
|
||||
now = datetime.utcnow()
|
||||
if (now - self._last_cleanup_time).total_seconds() < CLEANUP_INTERVAL_SECONDS:
|
||||
return
|
||||
|
||||
self._last_cleanup_time = now
|
||||
|
||||
# Get jobs snapshot
|
||||
with self._jobs_lock:
|
||||
jobs_snapshot = dict(self._jobs)
|
||||
|
||||
# Run cleanup (removes from disk)
|
||||
deleted_job_ids = self._repository.cleanup_old_jobs(jobs_snapshot)
|
||||
|
||||
# Remove from in-memory dict
|
||||
if deleted_job_ids:
|
||||
with self._jobs_lock:
|
||||
for job_id in deleted_job_ids:
|
||||
if job_id in self._jobs:
|
||||
del self._jobs[job_id]
|
||||
|
||||
logger.info(f"TTL cleanup removed {len(deleted_job_ids)} old jobs")
|
||||
|
||||
def get_job_status(self, job_id: str) -> dict:
|
||||
"""
|
||||
Get current status of a job.
|
||||
|
||||
Returns:
|
||||
dict: Job status information
|
||||
|
||||
Raises:
|
||||
KeyError: If job_id not found
|
||||
"""
|
||||
# Copy job data inside lock, release before building response
|
||||
with self._jobs_lock:
|
||||
if job_id not in self._jobs:
|
||||
raise KeyError(f"Job {job_id} not found")
|
||||
|
||||
job = self._jobs[job_id]
|
||||
# Copy all fields we need while holding lock
|
||||
job_data = {
|
||||
"job_id": job.job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position if job.status == JobStatus.QUEUED else None,
|
||||
"created_at": job.created_at,
|
||||
"started_at": job.started_at,
|
||||
"completed_at": job.completed_at,
|
||||
"result_path": job.result_path,
|
||||
"error": job.error,
|
||||
"processing_time_seconds": job.processing_time_seconds,
|
||||
}
|
||||
|
||||
# Format response outside lock
|
||||
return {
|
||||
"job_id": job_data["job_id"],
|
||||
"status": job_data["status"],
|
||||
"queue_position": job_data["queue_position"],
|
||||
"created_at": job_data["created_at"].isoformat() + "Z",
|
||||
"started_at": job_data["started_at"].isoformat() + "Z" if job_data["started_at"] else None,
|
||||
"completed_at": job_data["completed_at"].isoformat() + "Z" if job_data["completed_at"] else None,
|
||||
"result_path": job_data["result_path"],
|
||||
"error": job_data["error"],
|
||||
"processing_time_seconds": job_data["processing_time_seconds"],
|
||||
}
|
||||
|
||||
def get_job_result(self, job_id: str) -> str:
|
||||
"""
|
||||
Get transcription result text for completed job.
|
||||
|
||||
Returns:
|
||||
str: Content of transcription file
|
||||
|
||||
Raises:
|
||||
KeyError: If job_id not found
|
||||
ValueError: If job not completed
|
||||
FileNotFoundError: If result file missing
|
||||
"""
|
||||
# Copy necessary data inside lock, then release before file I/O
|
||||
with self._jobs_lock:
|
||||
if job_id not in self._jobs:
|
||||
raise KeyError(f"Job {job_id} not found")
|
||||
|
||||
job = self._jobs[job_id]
|
||||
|
||||
if job.status != JobStatus.COMPLETED:
|
||||
raise ValueError(
|
||||
f"Job {job_id} is not completed yet. "
|
||||
f"Current status: {job.status.value}"
|
||||
)
|
||||
|
||||
if not job.result_path:
|
||||
raise FileNotFoundError(f"Job {job_id} has no result path")
|
||||
|
||||
# Copy result_path while holding lock
|
||||
result_path = job.result_path
|
||||
|
||||
# Read result file (outside lock to avoid blocking)
|
||||
if not os.path.exists(result_path):
|
||||
raise FileNotFoundError(
|
||||
f"Result file not found: {result_path}"
|
||||
)
|
||||
|
||||
with open(result_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def list_jobs(self,
|
||||
status_filter: Optional[JobStatus] = None,
|
||||
limit: int = 100) -> List[dict]:
|
||||
"""
|
||||
List jobs with optional status filter.
|
||||
|
||||
Args:
|
||||
status_filter: Only return jobs with this status
|
||||
limit: Maximum number of jobs to return
|
||||
|
||||
Returns:
|
||||
List of job status dictionaries
|
||||
"""
|
||||
with self._jobs_lock:
|
||||
jobs = list(self._jobs.values())
|
||||
|
||||
# Filter by status
|
||||
if status_filter:
|
||||
jobs = [j for j in jobs if j.status == status_filter]
|
||||
|
||||
# Sort by created_at (newest first)
|
||||
jobs.sort(key=lambda j: j.created_at, reverse=True)
|
||||
|
||||
# Limit results
|
||||
jobs = jobs[:limit]
|
||||
|
||||
# Convert to dict directly (avoid N+1 by building response in single pass)
|
||||
# This eliminates the need to call get_job_status() for each job
|
||||
result = []
|
||||
for job in jobs:
|
||||
result.append({
|
||||
"job_id": job.job_id,
|
||||
"status": job.status.value,
|
||||
"queue_position": job.queue_position if job.status == JobStatus.QUEUED else None,
|
||||
"created_at": job.created_at.isoformat() + "Z",
|
||||
"started_at": job.started_at.isoformat() + "Z" if job.started_at else None,
|
||||
"completed_at": job.completed_at.isoformat() + "Z" if job.completed_at else None,
|
||||
"result_path": job.result_path,
|
||||
"error": job.error,
|
||||
"processing_time_seconds": job.processing_time_seconds,
|
||||
})
|
||||
|
||||
return result
|
||||
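Continuing the submission sketch above, a simple polling loop built on get_job_status() and get_job_result():

```python
import time

status = jq.get_job_status(info["job_id"])
while status["status"] in ("queued", "running"):
    time.sleep(2)
    status = jq.get_job_status(info["job_id"])

if status["status"] == "completed":
    print(jq.get_job_result(info["job_id"]))
else:
    print("Job failed:", status["error"])
```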
|
||||
def _worker_loop(self):
|
||||
"""
|
||||
Background worker thread function.
|
||||
Processes jobs from queue in FIFO order.
|
||||
"""
|
||||
logger.info("Job queue worker loop started")
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
# Get job from queue (with timeout to check stop_event)
|
||||
try:
|
||||
job = self._queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Update job status to running
|
||||
with self._jobs_lock:
|
||||
self._current_job_id = job.job_id
|
||||
job.status = JobStatus.RUNNING
|
||||
job.started_at = datetime.utcnow()
|
||||
job.queue_position = 0
|
||||
|
||||
with self._queue_positions_lock:
|
||||
# Remove from ordered queue when starting processing
|
||||
if job.job_id in self._queued_job_ids:
|
||||
self._queued_job_ids.remove(job.job_id)
|
||||
# Recalculate positions since we removed a job
|
||||
self._calculate_queue_positions()
|
||||
|
||||
# Mark for async persistence (outside lock)
|
||||
job.mark_for_persistence(self._repository)
|
||||
|
||||
logger.info(f"Job {job.job_id} started processing")
|
||||
|
||||
# Process job with timeout tracking
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Start a monitoring thread for timeout warnings
|
||||
timeout_event = threading.Event()
|
||||
|
||||
def timeout_monitor():
|
||||
"""Monitor job execution time and emit warnings."""
|
||||
# Wait for warning threshold
|
||||
if timeout_event.wait(JOB_TIMEOUT_WARNING_SECONDS):
|
||||
return # Job completed before warning threshold
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.warning(
|
||||
f"Job {job.job_id} is taking longer than expected: "
|
||||
f"{elapsed:.1f}s elapsed (threshold: {JOB_TIMEOUT_WARNING_SECONDS}s)"
|
||||
)
|
||||
|
||||
# Wait for max timeout
|
||||
remaining = JOB_TIMEOUT_MAX_SECONDS - elapsed
|
||||
if remaining > 0:
|
||||
if timeout_event.wait(remaining):
|
||||
return # Job completed before max timeout
|
||||
|
||||
# Job exceeded max timeout
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(
|
||||
f"Job {job.job_id} exceeded maximum timeout: "
|
||||
f"{elapsed:.1f}s elapsed (max: {JOB_TIMEOUT_MAX_SECONDS}s)"
|
||||
)
|
||||
|
||||
monitor_thread = threading.Thread(target=timeout_monitor, daemon=True)
|
||||
monitor_thread.start()
|
||||
|
||||
try:
|
||||
result = transcribe_audio(
|
||||
audio_path=job.audio_path,
|
||||
model_name=job.model_name,
|
||||
device=job.device,
|
||||
compute_type=job.compute_type,
|
||||
language=job.language,
|
||||
output_format=job.output_format,
|
||||
beam_size=job.beam_size,
|
||||
temperature=job.temperature,
|
||||
initial_prompt=job.initial_prompt,
|
||||
output_directory=job.output_directory
|
||||
)
|
||||
finally:
|
||||
# Signal timeout monitor to stop
|
||||
timeout_event.set()
|
||||
|
||||
# Check if job exceeded hard timeout
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > JOB_TIMEOUT_MAX_SECONDS:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = f"Job exceeded maximum timeout ({JOB_TIMEOUT_MAX_SECONDS}s): {elapsed:.1f}s elapsed"
|
||||
logger.error(f"Job {job.job_id} timed out: {job.error}")
|
||||
# Parse result
|
||||
elif "saved to:" in result:
|
||||
job.result_path = result.split("saved to:")[1].strip()
|
||||
job.status = JobStatus.COMPLETED
|
||||
logger.info(
|
||||
f"Job {job.job_id} completed successfully: {job.result_path}"
|
||||
)
|
||||
else:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = result
|
||||
logger.error(f"Job {job.job_id} failed: {result}")
|
||||
|
||||
except Exception as e:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = str(e)
|
||||
logger.error(f"Job {job.job_id} failed with exception: {e}")
|
||||
|
||||
finally:
|
||||
# Update job completion info
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.processing_time_seconds = time.time() - start_time
|
||||
|
||||
with self._jobs_lock:
|
||||
self._current_job_id = None
|
||||
|
||||
# No need to recalculate positions here - job already removed from queue
|
||||
|
||||
# Mark for async persistence (outside lock)
|
||||
job.mark_for_persistence(self._repository)
|
||||
|
||||
self._queue.task_done()
|
||||
|
||||
logger.info(
|
||||
f"Job {job.job_id} finished: "
|
||||
f"status={job.status.value}, "
|
||||
f"duration={job.processing_time_seconds:.1f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in worker loop: {e}")
|
||||
|
||||
logger.info("Job queue worker loop stopped")
|
||||
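# --- Illustrative sketch (not part of the diff above): the warn-then-fail
# timeout monitor used in _worker_loop, reduced to a standalone helper.
# WARN_AFTER / FAIL_AFTER are hypothetical stand-ins for
# JOB_TIMEOUT_WARNING_SECONDS / JOB_TIMEOUT_MAX_SECONDS.
import threading
import time
import logging

demo_logger = logging.getLogger(__name__)
WARN_AFTER = 5.0
FAIL_AFTER = 10.0

def run_with_timeout_warnings(work):
    done = threading.Event()
    start = time.time()

    def monitor():
        if done.wait(WARN_AFTER):
            return  # finished before the warning threshold
        demo_logger.warning("still running after %.1fs", time.time() - start)
        remaining = FAIL_AFTER - (time.time() - start)
        if remaining > 0 and done.wait(remaining):
            return  # finished before the hard limit
        demo_logger.error("exceeded hard limit of %.1fs", FAIL_AFTER)

    threading.Thread(target=monitor, daemon=True).start()
    try:
        return work()
    finally:
        done.set()  # always stop the monitor, even on exceptions

# e.g. run_with_timeout_warnings(lambda: time.sleep(1))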
|
||||
def _load_jobs_from_disk(self):
|
||||
"""Load existing job metadata from disk on startup."""
|
||||
logger.info("Loading jobs from disk...")
|
||||
|
||||
job_data_list = self._repository.load_all_jobs()
|
||||
|
||||
if not job_data_list:
|
||||
logger.info("No existing jobs found on disk")
|
||||
return
|
||||
|
||||
loaded_count = 0
|
||||
for data in job_data_list:
|
||||
try:
|
||||
job = Job.from_dict(data)
|
||||
|
||||
# Handle jobs that were running when server stopped
|
||||
if job.status == JobStatus.RUNNING:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = "Server restarted while job was running"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.mark_for_persistence(self._repository)
|
||||
logger.warning(
|
||||
f"Job {job.job_id} was running during shutdown, "
|
||||
f"marking as failed"
|
||||
)
|
||||
|
||||
# Re-queue queued jobs
|
||||
elif job.status == JobStatus.QUEUED:
|
||||
try:
|
||||
self._queue.put_nowait(job)
|
||||
# Add to ordered tracking deque
|
||||
with self._queue_positions_lock:
|
||||
self._queued_job_ids.append(job.job_id)
|
||||
logger.info(f"Re-queued job {job.job_id} from disk")
|
||||
except queue.Full:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = "Queue full on server restart"
|
||||
job.completed_at = datetime.utcnow()
|
||||
job.mark_for_persistence(self._repository)
|
||||
logger.warning(
|
||||
f"Job {job.job_id} could not be re-queued (queue full)"
|
||||
)
|
||||
|
||||
with self._jobs_lock:
|
||||
self._jobs[job.job_id] = job
|
||||
loaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load job: {e}")
|
||||
|
||||
logger.info(f"Loaded {loaded_count} jobs from disk")
|
||||
|
||||
with self._queue_positions_lock:
|
||||
self._calculate_queue_positions()
|
||||
|
||||
def _calculate_queue_positions(self):
|
||||
"""
|
||||
Update queue_position for all queued jobs.
|
||||
|
||||
Optimized O(n) implementation using deque. Only updates positions
|
||||
for jobs still in QUEUED status.
|
||||
|
||||
IMPORTANT: Must be called with _queue_positions_lock held.
|
||||
Acquires _jobs_lock only for short snapshots and updates, so it must NOT
be called while _jobs_lock is already held (the lock is not reentrant).
|
||||
"""
|
||||
# Step 1: Create snapshot of job statuses (acquire lock briefly)
|
||||
job_status_snapshot = {}
|
||||
with self._jobs_lock:
|
||||
for job_id in self._queued_job_ids:
|
||||
if job_id in self._jobs:
|
||||
job_status_snapshot[job_id] = self._jobs[job_id].status
|
||||
|
||||
# Step 2: Filter out jobs that are no longer queued (no lock needed)
|
||||
valid_queued_ids = [
|
||||
job_id for job_id in self._queued_job_ids
|
||||
if job_id in job_status_snapshot and job_status_snapshot[job_id] == JobStatus.QUEUED
|
||||
]
|
||||
|
||||
self._queued_job_ids = deque(valid_queued_ids)
|
||||
|
||||
# Step 3: Update positions (acquire lock briefly for each update)
|
||||
for i, job_id in enumerate(self._queued_job_ids, start=1):
|
||||
with self._jobs_lock:
|
||||
if job_id in self._jobs:
|
||||
self._jobs[job_id].queue_position = i
|
||||
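# --- Illustrative sketch (not part of the diff above): the snapshot approach
# used by _calculate_queue_positions - copy shared state under one short lock
# acquisition, filter on the copy with no lock held, then write back briefly.
# All names below are hypothetical.
import threading
from collections import deque

jobs_lock = threading.Lock()
jobs = {}             # job_id -> status string, e.g. "queued"
queued_ids = deque()  # FIFO order of queued job ids
positions = {}        # job_id -> 1-based queue position

def recalculate_positions():
    # Step 1: brief snapshot under the lock
    with jobs_lock:
        snapshot = {jid: jobs.get(jid) for jid in queued_ids}
    # Step 2: filter on the snapshot with no lock held
    still_queued = [jid for jid in queued_ids if snapshot.get(jid) == "queued"]
    queued_ids.clear()
    queued_ids.extend(still_queued)
    # Step 3: write positions back under a short lock acquisition
    with jobs_lock:
        for pos, jid in enumerate(queued_ids, start=1):
            positions[jid] = pos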
278
src/core/job_repository.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Job persistence layer with async I/O and write-behind caching.
|
||||
|
||||
Handles disk storage for job metadata with batched writes to reduce I/O overhead.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
DEFAULT_BATCH_INTERVAL_SECONDS = 1.0
|
||||
DEFAULT_JOB_TTL_HOURS = 24
|
||||
MAX_DIRTY_JOBS_BEFORE_FLUSH = 50
|
||||
|
||||
|
||||
class JobRepository:
|
||||
"""
|
||||
Manages job persistence with write-behind caching and TTL-based cleanup.
|
||||
|
||||
Features:
|
||||
- Async disk I/O to avoid blocking main thread
|
||||
- Batched writes (flush every N seconds or M jobs)
|
||||
- TTL-based job cleanup (removes old completed/failed jobs)
|
||||
- Thread-safe operation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata_dir: str = "/outputs/jobs",
|
||||
batch_interval_seconds: float = DEFAULT_BATCH_INTERVAL_SECONDS,
|
||||
job_ttl_hours: int = DEFAULT_JOB_TTL_HOURS,
|
||||
enable_ttl_cleanup: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize job repository.
|
||||
|
||||
Args:
|
||||
metadata_dir: Directory for job metadata JSON files
|
||||
batch_interval_seconds: How often to flush dirty jobs to disk
|
||||
job_ttl_hours: Hours to keep completed/failed jobs before cleanup
|
||||
enable_ttl_cleanup: Enable automatic TTL-based cleanup
|
||||
"""
|
||||
self._metadata_dir = Path(metadata_dir)
|
||||
self._batch_interval = batch_interval_seconds
|
||||
self._job_ttl = timedelta(hours=job_ttl_hours)
|
||||
self._enable_ttl_cleanup = enable_ttl_cleanup
|
||||
|
||||
# Dirty jobs pending flush (job_id -> Job)
|
||||
self._dirty_jobs: Dict[str, any] = {}
|
||||
self._dirty_lock = threading.Lock()
|
||||
|
||||
# Background flush thread
|
||||
self._flush_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Create metadata directory
|
||||
self._metadata_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(
|
||||
f"JobRepository initialized: dir={metadata_dir}, "
|
||||
f"batch_interval={batch_interval_seconds}s, ttl={job_ttl_hours}h"
|
||||
)
|
||||
|
||||
def start(self):
|
||||
"""Start background flush thread."""
|
||||
if self._flush_thread is not None and self._flush_thread.is_alive():
|
||||
logger.warning("JobRepository flush thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting JobRepository background flush thread")
|
||||
self._stop_event.clear()
|
||||
self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
|
||||
self._flush_thread.start()
|
||||
|
||||
def stop(self, flush_pending: bool = True):
|
||||
"""
|
||||
Stop background flush thread.
|
||||
|
||||
Args:
|
||||
flush_pending: If True, flush all pending writes before stopping
|
||||
"""
|
||||
if self._flush_thread is None:
|
||||
return
|
||||
|
||||
logger.info(f"Stopping JobRepository (flush_pending={flush_pending})")
|
||||
|
||||
if flush_pending:
|
||||
self.flush_dirty_jobs()
|
||||
|
||||
self._stop_event.set()
|
||||
self._flush_thread.join(timeout=5.0)
|
||||
self._flush_thread = None
|
||||
|
||||
logger.info("JobRepository stopped")
|
||||
|
||||
def mark_dirty(self, job: any):
|
||||
"""
|
||||
Mark a job as dirty (needs to be written to disk).
|
||||
|
||||
Args:
|
||||
job: Job object to persist
|
||||
"""
|
||||
with self._dirty_lock:
|
||||
self._dirty_jobs[job.job_id] = job
|
||||
|
||||
# Flush immediately if too many dirty jobs
|
||||
if len(self._dirty_jobs) >= MAX_DIRTY_JOBS_BEFORE_FLUSH:
|
||||
logger.debug(
|
||||
f"Dirty job threshold reached ({len(self._dirty_jobs)}), "
|
||||
f"triggering immediate flush"
|
||||
)
|
||||
self._flush_dirty_jobs_sync()
|
||||
|
||||
def flush_dirty_jobs(self):
|
||||
"""Flush all dirty jobs to disk (synchronous)."""
|
||||
with self._dirty_lock:
|
||||
self._flush_dirty_jobs_sync()
|
||||
|
||||
def _flush_dirty_jobs_sync(self):
|
||||
"""
|
||||
Internal: Flush dirty jobs to disk.
|
||||
Must be called with _dirty_lock held.
|
||||
"""
|
||||
if not self._dirty_jobs:
|
||||
return
|
||||
|
||||
jobs_to_flush = list(self._dirty_jobs.values())
|
||||
self._dirty_jobs.clear()
|
||||
|
||||
# Lock is already held by caller, do NOT re-acquire
|
||||
# Write jobs to disk (no lock needed for I/O)
|
||||
flush_count = 0
|
||||
for job in jobs_to_flush:
|
||||
try:
|
||||
self._write_job_to_disk(job)
|
||||
flush_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to flush job {job.job_id}: {e}")
|
||||
# Re-add to dirty queue for retry
|
||||
with self._dirty_lock:
|
||||
self._dirty_jobs[job.job_id] = job
|
||||
|
||||
if flush_count > 0:
|
||||
logger.debug(f"Flushed {flush_count} jobs to disk")
|
||||
|
||||
def _write_job_to_disk(self, job: any):
|
||||
"""Write single job to disk."""
|
||||
filepath = self._metadata_dir / f"{job.job_id}.json"
|
||||
|
||||
try:
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(job.to_dict(), f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write job {job.job_id} to {filepath}: {e}")
|
||||
raise
|
||||
|
||||
def load_job(self, job_id: str) -> Optional[Dict]:
|
||||
"""
|
||||
Load job from disk.
|
||||
|
||||
Args:
|
||||
job_id: Job ID to load
|
||||
|
||||
Returns:
|
||||
Job dictionary or None if not found
|
||||
"""
|
||||
filepath = self._metadata_dir / f"{job_id}.json"
|
||||
|
||||
if not filepath.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load job {job_id} from {filepath}: {e}")
|
||||
return None
|
||||
|
||||
def load_all_jobs(self) -> List[Dict]:
|
||||
"""
|
||||
Load all jobs from disk.
|
||||
|
||||
Returns:
|
||||
List of job dictionaries
|
||||
"""
|
||||
jobs = []
|
||||
|
||||
if not self._metadata_dir.exists():
|
||||
return jobs
|
||||
|
||||
for filepath in self._metadata_dir.glob("*.json"):
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
job_data = json.load(f)
|
||||
jobs.append(job_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load job from {filepath}: {e}")
|
||||
|
||||
logger.info(f"Loaded {len(jobs)} jobs from disk")
|
||||
return jobs
|
||||
|
||||
def delete_job(self, job_id: str):
|
||||
"""
|
||||
Delete job from disk.
|
||||
|
||||
Args:
|
||||
job_id: Job ID to delete
|
||||
"""
|
||||
filepath = self._metadata_dir / f"{job_id}.json"
|
||||
|
||||
try:
|
||||
if filepath.exists():
|
||||
filepath.unlink()
|
||||
logger.debug(f"Deleted job {job_id} from disk")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete job {job_id}: {e}")
|
||||
|
||||
def cleanup_old_jobs(self, jobs_dict: Dict[str, any]):
|
||||
"""
|
||||
Clean up old completed/failed jobs based on TTL.
|
||||
|
||||
Args:
|
||||
jobs_dict: Dictionary of job_id -> Job objects to check
|
||||
"""
|
||||
if not self._enable_ttl_cleanup:
|
||||
return
|
||||
|
||||
now = datetime.utcnow()
|
||||
jobs_to_delete = []
|
||||
|
||||
for job_id, job in jobs_dict.items():
|
||||
# Only cleanup completed/failed jobs
|
||||
if job.status.value not in ["completed", "failed"]:
|
||||
continue
|
||||
|
||||
# Check if job has exceeded TTL
|
||||
if job.completed_at is None:
|
||||
continue
|
||||
|
||||
age = now - job.completed_at
|
||||
if age > self._job_ttl:
|
||||
jobs_to_delete.append(job_id)
|
||||
|
||||
# Delete old jobs
|
||||
for job_id in jobs_to_delete:
|
||||
try:
|
||||
self.delete_job(job_id)
|
||||
logger.info(
|
||||
f"Cleaned up old job {job_id} "
|
||||
f"(age: {(now - jobs_dict[job_id].completed_at).total_seconds() / 3600:.1f}h)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup job {job_id}: {e}")
|
||||
|
||||
return jobs_to_delete
|
||||
|
||||
def _flush_loop(self):
|
||||
"""Background thread for periodic flush."""
|
||||
logger.info("JobRepository flush loop started")
|
||||
|
||||
while not self._stop_event.wait(timeout=self._batch_interval):
|
||||
try:
|
||||
with self._dirty_lock:
|
||||
if self._dirty_jobs:
|
||||
self._flush_dirty_jobs_sync()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in flush loop: {e}")
|
||||
|
||||
logger.info("JobRepository flush loop stopped")
|
||||
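# --- Illustrative usage sketch (not part of the diff above): how the
# write-behind repository is driven. FakeJob is a hypothetical stand-in for
# the real Job dataclass; only job_id and to_dict() matter for persistence.
if __name__ == "__main__":
    class FakeJob:
        def __init__(self, job_id):
            self.job_id = job_id
        def to_dict(self):
            return {"job_id": self.job_id, "status": "queued"}

    repo = JobRepository(metadata_dir="/tmp/jobs_demo", batch_interval_seconds=0.5)
    repo.start()                       # background flush thread
    repo.mark_dirty(FakeJob("job-1"))  # queued for the next flush
    repo.stop(flush_pending=True)      # flushes job-1.json before exiting
    print(repo.load_all_jobs())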
276
src/core/model_manager.py
Normal file
@@ -0,0 +1,276 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Model Management Module
|
||||
Responsible for loading, caching, and managing Whisper models
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
from faster_whisper import WhisperModel, BatchedInferencePipeline
|
||||
|
||||
# Log configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import GPU health check with reset capability
|
||||
try:
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("GPU health check with reset not available")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
# Global model instance cache with LRU eviction
|
||||
# Maximum number of models to keep in memory (prevents OOM)
|
||||
MAX_CACHED_MODELS = int(os.getenv("MAX_CACHED_MODELS", "3"))
|
||||
model_instances: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
|
||||
def test_gpu_driver():
|
||||
"""Simple GPU driver test"""
|
||||
try:
|
||||
if not torch.cuda.is_available():
|
||||
logger.error("CUDA not available in PyTorch")
|
||||
raise RuntimeError("CUDA not available")
|
||||
|
||||
gpu_count = torch.cuda.device_count()
|
||||
if gpu_count == 0:
|
||||
logger.error("No CUDA devices found")
|
||||
raise RuntimeError("No CUDA devices")
|
||||
|
||||
# Quick GPU test
|
||||
test_tensor = torch.randn(10, 10).cuda()
|
||||
_ = test_tensor @ test_tensor.T
|
||||
|
||||
device_name = torch.cuda.get_device_name(0)
|
||||
memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
||||
logger.info(f"GPU test passed: {device_name} ({memory_gb:.1f}GB)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GPU test failed: {e}")
|
||||
raise RuntimeError(f"GPU initialization failed: {e}")
|
||||
|
||||
def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get or create Whisper model instance
|
||||
|
||||
Args:
|
||||
model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
|
||||
device: Running device (cpu, cuda, auto)
|
||||
compute_type: Computation type (float16, int8, auto)
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing model instance and configuration
|
||||
"""
|
||||
global model_instances
|
||||
|
||||
# Validate model name
|
||||
valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
|
||||
if model_name not in valid_models:
|
||||
raise ValueError(f"Invalid model name: {model_name}. Valid models: {', '.join(valid_models)}")
|
||||
|
||||
# Auto-detect device - GPU REQUIRED (no CPU fallback)
|
||||
if device == "auto":
|
||||
if not torch.cuda.is_available():
|
||||
error_msg = (
|
||||
"GPU required but CUDA is not available. "
|
||||
"This service is configured for GPU-only operation. "
|
||||
"Please ensure CUDA is properly installed and GPU is accessible."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
device = "cuda"
|
||||
|
||||
# Auto-detect compute type
|
||||
if compute_type == "auto":
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
|
||||
# Validate device and compute type
|
||||
if device not in ["cpu", "cuda"]:
|
||||
raise ValueError(f"Invalid device: {device}. Valid devices: cpu, cuda")
|
||||
|
||||
# CRITICAL: Reject CPU device - this service is GPU-only
|
||||
if device == "cpu":
|
||||
error_msg = (
|
||||
"CPU device requested but this service is configured for GPU-only operation. "
|
||||
"Please use device='cuda' or device='auto' with a GPU available."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
logger.error("CUDA requested but not available")
|
||||
raise RuntimeError("CUDA not available but explicitly requested")
|
||||
|
||||
if compute_type not in ["float16", "int8"]:
|
||||
raise ValueError(f"Invalid compute type: {compute_type}. Valid compute types: float16, int8")
|
||||
|
||||
if device == "cpu" and compute_type == "float16":
|
||||
logger.warning("CPU device does not support float16 computation type, automatically switching to int8")
|
||||
compute_type = "int8"
|
||||
|
||||
# Generate model key
|
||||
model_key = f"{model_name}_{device}_{compute_type}"
|
||||
|
||||
# If model is already instantiated, move to end (mark as recently used) and return
|
||||
if model_key in model_instances:
|
||||
logger.info(f"Using cached model instance: {model_key}")
|
||||
# Move to end for LRU
|
||||
model_instances.move_to_end(model_key)
|
||||
return model_instances[model_key]
|
||||
|
||||
# Test GPU driver before loading model and clean
|
||||
if device == "cuda":
|
||||
# Use GPU health check with reset capability if available
|
||||
if GPU_HEALTH_CHECK_AVAILABLE:
|
||||
try:
|
||||
logger.info("Running GPU health check with auto-reset before model loading")
|
||||
check_gpu_health_with_reset(expected_device="cuda", auto_reset=True)
|
||||
except Exception as e:
|
||||
logger.error(f"GPU health check failed: {e}")
|
||||
raise RuntimeError(f"GPU not available for model loading: {e}")
|
||||
else:
|
||||
# Fallback to simple GPU test
|
||||
test_gpu_driver()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Instantiate model
|
||||
try:
|
||||
logger.info(f"Loading Whisper model: {model_name} device: {device} compute type: {compute_type}")
|
||||
|
||||
# Base model
|
||||
model = WhisperModel(
|
||||
model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
download_root=os.environ.get("WHISPER_MODEL_DIR", None) # Support custom model directory
|
||||
)
|
||||
|
||||
# Batch processing settings - batch processing enabled by default to improve speed
|
||||
batched_model = None
|
||||
batch_size = 0
|
||||
|
||||
if device == "cuda": # Only use batch processing on CUDA devices
|
||||
# Determine appropriate batch size based on available memory
|
||||
if torch.cuda.is_available():
|
||||
gpu_mem = torch.cuda.get_device_properties(0).total_memory
|
||||
free_mem = gpu_mem - torch.cuda.memory_allocated()
|
||||
# Dynamically adjust batch size based on GPU memory
|
||||
if free_mem > 16e9: # >16GB
|
||||
batch_size = 32
|
||||
elif free_mem > 12e9: # >12GB
|
||||
batch_size = 16
|
||||
elif free_mem > 8e9: # >8GB
|
||||
batch_size = 8
|
||||
elif free_mem > 4e9: # >4GB
|
||||
batch_size = 4
|
||||
else: # Smaller memory
|
||||
batch_size = 2
|
||||
|
||||
logger.info(f"Available GPU memory: {free_mem / 1e9:.2f} GB")
|
||||
else:
|
||||
batch_size = 8 # Default value
|
||||
|
||||
logger.info(f"Enabling batch processing acceleration, batch size: {batch_size}")
|
||||
batched_model = BatchedInferencePipeline(model=model)
|
||||
|
||||
# Create result object
|
||||
result = {
|
||||
'model': model,
|
||||
'device': device,
|
||||
'compute_type': compute_type,
|
||||
'batched_model': batched_model,
|
||||
'batch_size': batch_size,
|
||||
'load_time': time.time()
|
||||
}
|
||||
|
||||
# Implement LRU eviction before adding new model
|
||||
if len(model_instances) >= MAX_CACHED_MODELS:
|
||||
# Remove oldest (least recently used) model
|
||||
evicted_key, evicted_model = model_instances.popitem(last=False)
|
||||
logger.info(
|
||||
f"Evicting cached model (LRU): {evicted_key} "
|
||||
f"(cache limit: {MAX_CACHED_MODELS})"
|
||||
)
|
||||
|
||||
# Clean up GPU memory if it was a CUDA model
|
||||
if evicted_model['device'] == 'cuda':
|
||||
try:
|
||||
# Delete model references
|
||||
del evicted_model['model']
|
||||
if evicted_model['batched_model'] is not None:
|
||||
del evicted_model['batched_model']
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("GPU memory released for evicted model")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error cleaning up evicted model: {cleanup_error}")
|
||||
|
||||
# Cache instance (added to end of OrderedDict)
|
||||
model_instances[model_key] = result
|
||||
logger.info(
|
||||
f"Cached model: {model_key} "
|
||||
f"(cache size: {len(model_instances)}/{MAX_CACHED_MODELS})"
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {str(e)}")
|
||||
raise
|
||||
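# --- Illustrative sketch (not part of the diff above): the OrderedDict-based
# LRU eviction applied to model_instances in get_whisper_model, reduced to a
# generic cache. MAX_ENTRIES is a hypothetical stand-in for MAX_CACHED_MODELS.
from collections import OrderedDict

MAX_ENTRIES = 3
cache: OrderedDict[str, object] = OrderedDict()

def cache_get_or_create(key, factory):
    if key in cache:
        cache.move_to_end(key)  # mark as most recently used
        return cache[key]
    if len(cache) >= MAX_ENTRIES:
        evicted_key, _ = cache.popitem(last=False)  # drop least recently used
        print(f"evicted {evicted_key}")
    cache[key] = factory()
    return cache[key]

# Usage: cache_get_or_create("large-v3_cuda_float16", lambda: object())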
|
||||
def get_model_info() -> str:
|
||||
"""
|
||||
Get available Whisper model information
|
||||
|
||||
Returns:
|
||||
str: JSON string of model information
|
||||
"""
|
||||
import json
|
||||
|
||||
models = [
|
||||
"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
|
||||
]
|
||||
# GPU-only service - CUDA required
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
"GPU required but CUDA is not available. "
|
||||
"This service is configured for GPU-only operation."
|
||||
)
|
||||
devices = ["cuda"] # CPU not supported in GPU-only mode
|
||||
compute_types = ["float16", "int8"]
|
||||
|
||||
# Supported language list
|
||||
languages = {
|
||||
"zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
|
||||
"fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
|
||||
"pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
|
||||
"tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
|
||||
}
|
||||
|
||||
# Supported audio formats
|
||||
audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
|
||||
|
||||
info = {
|
||||
"available_models": models,
|
||||
"default_model": "large-v3",
|
||||
"available_devices": devices,
|
||||
"default_device": "cuda", # GPU-only service
|
||||
"available_compute_types": compute_types,
|
||||
"default_compute_type": "float16",
|
||||
"cuda_available": True, # Required for this service
|
||||
"gpu_only_mode": True, # Indicate this is GPU-only
|
||||
"supported_languages": languages,
|
||||
"supported_audio_formats": audio_formats,
|
||||
"version": "0.1.1"
|
||||
}
|
||||
|
||||
# GPU info (always present in GPU-only mode)
|
||||
info["gpu_info"] = {
|
||||
"name": torch.cuda.get_device_name(0),
|
||||
"memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
|
||||
"memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
|
||||
}
|
||||
|
||||
return json.dumps(info, indent=2)
|
||||
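# --- Illustrative sketch (not part of the diff above): the free-memory based
# batch size heuristic from get_whisper_model as a standalone helper, so the
# thresholds are easy to see and test without loading a model.
def pick_batch_size(free_bytes: float) -> int:
    if free_bytes > 16e9:   # >16GB
        return 32
    if free_bytes > 12e9:   # >12GB
        return 16
    if free_bytes > 8e9:    # >8GB
        return 8
    if free_bytes > 4e9:    # >4GB
        return 4
    return 2

# On a CUDA machine, for example:
#   free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
#   batch_size = pick_batch_size(free)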
424
src/core/transcriber.py
Normal file
@@ -0,0 +1,424 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Transcription Core Module with Environment Variable Support
|
||||
Contains core logic for audio transcription
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple, List, Optional, Union
|
||||
|
||||
from core.model_manager import get_whisper_model
|
||||
from utils.audio_processor import validate_audio_file, process_audio
|
||||
from utils.formatters import format_vtt, format_srt, format_json, format_txt, format_time
|
||||
|
||||
# Logging configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Environment variable defaults
|
||||
DEFAULT_OUTPUT_DIR = os.getenv('TRANSCRIPTION_OUTPUT_DIR')
|
||||
DEFAULT_BATCH_OUTPUT_DIR = os.getenv('TRANSCRIPTION_BATCH_OUTPUT_DIR')
|
||||
DEFAULT_MODEL = os.getenv('TRANSCRIPTION_MODEL', 'large-v3')
|
||||
DEFAULT_DEVICE = os.getenv('TRANSCRIPTION_DEVICE', 'auto')
|
||||
DEFAULT_COMPUTE_TYPE = os.getenv('TRANSCRIPTION_COMPUTE_TYPE', 'auto')
|
||||
DEFAULT_LANGUAGE = os.getenv('TRANSCRIPTION_LANGUAGE', None)
|
||||
DEFAULT_OUTPUT_FORMAT = os.getenv('TRANSCRIPTION_OUTPUT_FORMAT', 'txt')
|
||||
DEFAULT_BEAM_SIZE = int(os.getenv('TRANSCRIPTION_BEAM_SIZE', '5'))
|
||||
DEFAULT_TEMPERATURE = float(os.getenv('TRANSCRIPTION_TEMPERATURE', '0.0'))
|
||||
|
||||
# Model storage configuration
|
||||
WHISPER_MODEL_DIR = os.getenv('WHISPER_MODEL_DIR', None)
|
||||
|
||||
# File naming configuration
|
||||
USE_TIMESTAMP = os.getenv('TRANSCRIPTION_USE_TIMESTAMP', 'false').lower() == 'true'
|
||||
FILENAME_PREFIX = os.getenv('TRANSCRIPTION_FILENAME_PREFIX', '')
|
||||
FILENAME_SUFFIX = os.getenv('TRANSCRIPTION_FILENAME_SUFFIX', '')
|
||||
|
||||
def transcribe_audio(
|
||||
audio_path: str,
|
||||
model_name: str = None,
|
||||
device: str = None,
|
||||
compute_type: str = None,
|
||||
language: str = None,
|
||||
output_format: str = None,
|
||||
beam_size: int = None,
|
||||
temperature: float = None,
|
||||
initial_prompt: str = None,
|
||||
output_directory: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Transcribe audio file using Faster Whisper with ENV VAR support
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
model_name: Model name (defaults to TRANSCRIPTION_MODEL env var or "large-v3")
|
||||
device: Execution device (defaults to TRANSCRIPTION_DEVICE env var or "auto")
|
||||
compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var or "auto")
|
||||
language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var or auto-detect)
|
||||
output_format: Output format (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var or "txt")
|
||||
beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var or 5)
|
||||
temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var or 0.0)
|
||||
initial_prompt: Initial prompt text
|
||||
output_directory: Output directory (defaults to TRANSCRIPTION_OUTPUT_DIR env var or audio file directory)
|
||||
|
||||
Returns:
|
||||
str: Transcription result path or error message
|
||||
"""
|
||||
# Apply environment variable defaults
|
||||
model_name = model_name or DEFAULT_MODEL
|
||||
device = device or DEFAULT_DEVICE
|
||||
compute_type = compute_type or DEFAULT_COMPUTE_TYPE
|
||||
language = language or DEFAULT_LANGUAGE
|
||||
output_format = output_format or DEFAULT_OUTPUT_FORMAT
|
||||
beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
|
||||
temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE
|
||||
|
||||
# Validate audio file
|
||||
try:
|
||||
validate_audio_file(audio_path)
|
||||
except (FileNotFoundError, ValueError, OSError) as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
try:
|
||||
# Get model instance
|
||||
model_instance = get_whisper_model(model_name, device, compute_type)
|
||||
|
||||
# Validate language code
|
||||
supported_languages = {
|
||||
"zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
|
||||
"fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
|
||||
"pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
|
||||
"tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
|
||||
}
|
||||
|
||||
if language is not None and language not in supported_languages:
|
||||
logger.warning(f"Unknown language code: {language}, will use auto-detection")
|
||||
language = None
|
||||
|
||||
# Set transcription parameters
|
||||
options = {
|
||||
"language": language,
|
||||
"vad_filter": True,
|
||||
"vad_parameters": {"min_silence_duration_ms": 500},
|
||||
"beam_size": beam_size,
|
||||
"temperature": temperature,
|
||||
"initial_prompt": initial_prompt,
|
||||
"word_timestamps": True,
|
||||
"suppress_tokens": [-1],
|
||||
"condition_on_previous_text": True,
|
||||
"compression_ratio_threshold": 2.4,
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
logger.info(f"Starting transcription of file: {os.path.basename(audio_path)}")
|
||||
|
||||
# Process audio
|
||||
audio_source = process_audio(audio_path)
|
||||
|
||||
# Execute transcription
|
||||
if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
|
||||
logger.info("Using batch acceleration for transcription...")
|
||||
segments, info = model_instance['batched_model'].transcribe(
|
||||
audio_source,
|
||||
batch_size=model_instance['batch_size'],
|
||||
**options
|
||||
)
|
||||
else:
|
||||
logger.info("Using standard model for transcription...")
|
||||
segments, info = model_instance['model'].transcribe(audio_source, **options)
|
||||
|
||||
# Convert segments generator to list to release model resources
|
||||
segments = list(segments)
|
||||
|
||||
# Determine output directory and path early
|
||||
audio_dir = os.path.dirname(audio_path)
|
||||
audio_filename = os.path.splitext(os.path.basename(audio_path))[0]
|
||||
|
||||
# Priority: parameter > env var > audio directory
|
||||
if output_directory is not None:
|
||||
output_dir = output_directory
|
||||
elif DEFAULT_OUTPUT_DIR is not None:
|
||||
output_dir = DEFAULT_OUTPUT_DIR
|
||||
else:
|
||||
output_dir = audio_dir
|
||||
|
||||
# Ensure output directory exists
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Generate filename with customizable format
|
||||
filename_parts = []
|
||||
|
||||
# Add prefix if specified
|
||||
if FILENAME_PREFIX:
|
||||
filename_parts.append(FILENAME_PREFIX)
|
||||
|
||||
# Add base filename
|
||||
filename_parts.append(audio_filename)
|
||||
|
||||
# Add suffix if specified
|
||||
if FILENAME_SUFFIX:
|
||||
filename_parts.append(FILENAME_SUFFIX)
|
||||
|
||||
# Add timestamp if enabled
|
||||
if USE_TIMESTAMP:
|
||||
timestamp = time.strftime("%Y%m%d%H%M%S")
|
||||
filename_parts.append(timestamp)
|
||||
|
||||
# Join parts and add extension
|
||||
output_format_lower = output_format.lower()
|
||||
base_name = "_".join(filename_parts)
|
||||
output_filename = f"{base_name}.{output_format_lower}"
|
||||
output_path = os.path.join(output_dir, output_filename)
|
||||
|
||||
# Stream segments directly to file instead of loading all into memory
|
||||
# This prevents memory spikes with long audio files
|
||||
segment_count = 0
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
# Write format-specific header
|
||||
if output_format_lower == "vtt":
|
||||
f.write("WEBVTT\n\n")
|
||||
elif output_format_lower == "json":
|
||||
f.write('{"segments": [')
|
||||
|
||||
first_segment = True
|
||||
for segment in segments:
|
||||
segment_count += 1
|
||||
|
||||
# Format and write each segment immediately
|
||||
if output_format_lower == "vtt":
|
||||
start_time = format_time(segment.start)
|
||||
end_time = format_time(segment.end)
|
||||
f.write(f"{start_time} --> {end_time}\n{segment.text.strip()}\n\n")
|
||||
elif output_format_lower == "srt":
|
||||
start_time = format_time(segment.start).replace('.', ',')
|
||||
end_time = format_time(segment.end).replace('.', ',')
|
||||
f.write(f"{segment_count}\n{start_time} --> {end_time}\n{segment.text.strip()}\n\n")
|
||||
elif output_format_lower == "txt":
|
||||
f.write(segment.text.strip() + "\n")
|
||||
elif output_format_lower == "json":
|
||||
if not first_segment:
|
||||
f.write(',')
|
||||
import json as json_module
|
||||
segment_dict = {
|
||||
"start": segment.start,
|
||||
"end": segment.end,
|
||||
"text": segment.text.strip()
|
||||
}
|
||||
f.write(json_module.dumps(segment_dict))
|
||||
first_segment = False
|
||||
else:
|
||||
raise ValueError(f"Unsupported output format: {output_format}. Supported formats: vtt, srt, txt, json")
|
||||
|
||||
# Write format-specific footer
|
||||
if output_format_lower == "json":
|
||||
# Add metadata
|
||||
f.write(f'], "language": "{info.language}", "duration": {info.duration}}}')
|
||||
|
||||
except Exception as write_error:
|
||||
logger.error(f"Failed to write transcription during streaming: {str(write_error)}")
|
||||
# File handle automatically closed by context manager
|
||||
# Clean up partial file to prevent corrupted output
|
||||
if os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.info(f"Cleaned up partial file: {output_path}")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Failed to cleanup partial file {output_path}: {cleanup_error}")
|
||||
raise
|
||||
|
||||
if segment_count == 0:
|
||||
if info.duration < 1.0:
|
||||
logger.warning(f"No segments: audio too short ({info.duration:.2f}s)")
|
||||
return "Transcription failed: Audio too short (< 1 second)"
|
||||
else:
|
||||
logger.warning(
|
||||
f"No segments generated: duration={info.duration:.2f}s, "
|
||||
f"language={info.language}, vad_enabled=True"
|
||||
)
|
||||
return "Transcription failed: No speech detected (VAD filtered all segments)"
|
||||
|
||||
# Record transcription information
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Transcription completed, time used: {elapsed_time:.2f} seconds, "
|
||||
f"detected language: {info.language}, audio length: {info.duration:.2f} seconds, "
|
||||
f"segments: {segment_count}"
|
||||
)
|
||||
|
||||
# File already written via streaming above
|
||||
logger.info(f"Transcription results saved to: {output_path}")
|
||||
return f"Transcription successful, results saved to: {output_path}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {str(e)}")
|
||||
return f"Error occurred during transcription: {str(e)}"
|
||||
|
||||
finally:
|
||||
# Force GPU memory cleanup after transcription to prevent accumulation.
# Note: device may still be "auto" here (it is resolved to "cuda" inside
# get_whisper_model), so treat "auto" as CUDA for cleanup purposes.
if device in ("cuda", "auto"):
|
||||
import torch
|
||||
import gc
|
||||
# Clear segments list to free memory
|
||||
segments = None
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
# Empty CUDA cache
|
||||
torch.cuda.empty_cache()
|
||||
logger.debug("GPU memory cleaned up after transcription")
|
||||
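# --- Illustrative sketch (not part of the diff above): the env-var driven
# output filename scheme used in transcribe_audio. With
# TRANSCRIPTION_FILENAME_PREFIX=meeting, TRANSCRIPTION_FILENAME_SUFFIX=draft
# and TRANSCRIPTION_USE_TIMESTAMP=true, an input "call.mp3" with txt output
# would become e.g. "meeting_call_draft_20250101120000.txt".
def build_output_filename(audio_stem, fmt, prefix="", suffix="", use_timestamp=False):
    parts = []
    if prefix:
        parts.append(prefix)
    parts.append(audio_stem)
    if suffix:
        parts.append(suffix)
    if use_timestamp:
        parts.append(time.strftime("%Y%m%d%H%M%S"))
    return "_".join(parts) + "." + fmt.lower()

# build_output_filename("call", "txt", prefix="meeting", suffix="draft")
# -> "meeting_call_draft.txt"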
|
||||
|
||||
def batch_transcribe(
|
||||
audio_folder: str,
|
||||
output_folder: str = None,
|
||||
model_name: str = None,
|
||||
device: str = None,
|
||||
compute_type: str = None,
|
||||
language: str = None,
|
||||
output_format: str = None,
|
||||
beam_size: int = None,
|
||||
temperature: float = None,
|
||||
initial_prompt: str = None,
|
||||
parallel_files: int = 1
|
||||
) -> str:
|
||||
"""
|
||||
Batch transcribe audio files with ENV VAR support
|
||||
|
||||
Args:
|
||||
audio_folder: Path to folder containing audio files
|
||||
output_folder: Output folder (defaults to TRANSCRIPTION_BATCH_OUTPUT_DIR env var or "transcript" subfolder)
|
||||
model_name: Model name (defaults to TRANSCRIPTION_MODEL env var or "large-v3")
|
||||
device: Execution device (defaults to TRANSCRIPTION_DEVICE env var or "auto")
|
||||
compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var or "auto")
|
||||
language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var or auto-detect)
|
||||
output_format: Output format (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var or "txt")
|
||||
beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var or 5)
|
||||
temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var or 0.0)
|
||||
initial_prompt: Initial prompt text
|
||||
parallel_files: Number of files to process in parallel (not implemented yet)
|
||||
|
||||
Returns:
|
||||
str: Batch processing summary
|
||||
"""
|
||||
# Apply environment variable defaults
|
||||
model_name = model_name or DEFAULT_MODEL
|
||||
device = device or DEFAULT_DEVICE
|
||||
compute_type = compute_type or DEFAULT_COMPUTE_TYPE
|
||||
language = language or DEFAULT_LANGUAGE
|
||||
output_format = output_format or DEFAULT_OUTPUT_FORMAT
|
||||
beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
|
||||
temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE
|
||||
|
||||
if not os.path.isdir(audio_folder):
|
||||
return f"Error: Folder does not exist: {audio_folder}"
|
||||
|
||||
# Determine output folder with environment variable support
|
||||
if output_folder is not None:
|
||||
# Use provided parameter
|
||||
pass
|
||||
elif DEFAULT_BATCH_OUTPUT_DIR is not None:
|
||||
# Use environment variable
|
||||
output_folder = DEFAULT_BATCH_OUTPUT_DIR
|
||||
else:
|
||||
# Use default subfolder
|
||||
output_folder = os.path.join(audio_folder, "transcript")
|
||||
|
||||
# Ensure output directory exists
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Validate output format
|
||||
valid_formats = ["txt", "vtt", "srt", "json"]
|
||||
if output_format.lower() not in valid_formats:
|
||||
return f"Error: Unsupported output format: {output_format}. Supported formats: {', '.join(valid_formats)}"
|
||||
|
||||
# Get all audio files
|
||||
audio_files = []
|
||||
supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
|
||||
|
||||
for filename in os.listdir(audio_folder):
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
if file_ext in supported_formats:
|
||||
audio_files.append(os.path.join(audio_folder, filename))
|
||||
|
||||
if not audio_files:
|
||||
return f"No supported audio files found in {audio_folder}. Supported formats: {', '.join(supported_formats)}"
|
||||
|
||||
# Record start time
|
||||
start_time = time.time()
|
||||
total_files = len(audio_files)
|
||||
logger.info(f"Starting batch transcription of {total_files} files, output format: {output_format}")
|
||||
|
||||
# Preload model
|
||||
try:
|
||||
get_whisper_model(model_name, device, compute_type)
|
||||
logger.info(f"Model preloaded: {model_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to preload model: {str(e)}")
|
||||
return f"Batch processing failed: Cannot load model {model_name}: {str(e)}"
|
||||
|
||||
# Process files
|
||||
results = []
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for i, audio_path in enumerate(audio_files):
|
||||
file_name = os.path.basename(audio_path)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Report progress
|
||||
progress_msg = report_progress(i, total_files, elapsed)
|
||||
logger.info(f"{progress_msg} | Currently processing: {file_name}")
|
||||
|
||||
# Execute transcription
|
||||
try:
|
||||
result = transcribe_audio(
|
||||
audio_path=audio_path,
|
||||
model_name=model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=language,
|
||||
output_format=output_format,
|
||||
beam_size=beam_size,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
output_directory=output_folder
|
||||
)
|
||||
|
||||
if result.startswith("Error:") or result.startswith("Error occurred during transcription:"):
|
||||
logger.error(f"Transcription failed: {file_name} - {result}")
|
||||
results.append(f"❌ Failed: {file_name} - {result}")
|
||||
error_count += 1
|
||||
else:
|
||||
output_path = result.split(": ")[1] if ": " in result else "Unknown path"
|
||||
success_count += 1
|
||||
results.append(f"✅ Success: {file_name} -> {os.path.basename(output_path)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during transcription process: {file_name} - {str(e)}")
|
||||
results.append(f"❌ Failed: {file_name} - {str(e)}")
|
||||
error_count += 1
|
||||
|
||||
# Calculate total transcription time
|
||||
total_transcription_time = time.time() - start_time
|
||||
|
||||
# Generate summary
|
||||
summary = f"Batch processing completed, total transcription time: {format_time(total_transcription_time)}"
|
||||
summary += f" | Success: {success_count}/{total_files}"
|
||||
summary += f" | Failed: {error_count}/{total_files}"
|
||||
|
||||
# Output results
|
||||
for result in results:
|
||||
logger.info(result)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def report_progress(current: int, total: int, elapsed_time: float) -> str:
|
||||
"""Generate progress report"""
|
||||
progress = current / total * 100
|
||||
eta = (elapsed_time / current) * (total - current) if current > 0 else 0
|
||||
return (f"Progress: {current}/{total} ({progress:.1f}%)" +
|
||||
f" | Time used: {format_time(elapsed_time)}" +
|
||||
f" | Estimated remaining: {format_time(eta)}")
|
||||
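# --- Illustrative usage sketch (not part of the diff above): calling
# batch_transcribe with environment-variable defaults. The folder paths are
# hypothetical; anything not passed explicitly falls back to the
# TRANSCRIPTION_* environment variables or the built-in defaults.
if __name__ == "__main__":
    summary = batch_transcribe(
        audio_folder="/data/audio",          # hypothetical input folder
        output_folder="/data/transcripts",   # hypothetical output folder
        output_format="srt",
    )
    print(summary)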
5
src/servers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Server implementations for Whisper transcription service.
|
||||
|
||||
Includes MCP server (whisper_server.py) and REST API server (api_server.py).
|
||||
"""
|
||||
692
src/servers/api_server.py
Normal file
@@ -0,0 +1,692 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FastAPI REST API Server for Whisper Transcription
|
||||
Provides HTTP REST endpoints for audio transcription
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import queue as queue_module
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, List
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import json
|
||||
import aiofiles # Async file I/O
|
||||
|
||||
from core.model_manager import get_model_info
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health, get_circuit_breaker_stats, reset_circuit_breaker
|
||||
from utils.startup import startup_sequence, cleanup_on_shutdown
|
||||
from utils.input_validation import (
|
||||
ValidationError,
|
||||
PathTraversalError,
|
||||
InvalidFileTypeError,
|
||||
FileSizeError,
|
||||
validate_beam_size,
|
||||
validate_temperature,
|
||||
validate_model_name,
|
||||
validate_device,
|
||||
validate_compute_type,
|
||||
validate_output_format,
|
||||
validate_filename_safe
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
UPLOAD_CHUNK_SIZE_BYTES = 8192 # 8KB chunks for streaming uploads
|
||||
GPU_TEST_SLOW_THRESHOLD_SECONDS = 2.0 # GPU health check performance threshold
|
||||
DISK_SPACE_BUFFER_PERCENT = 0.10 # Require 10% extra free space as buffer
|
||||
|
||||
# Global instances
|
||||
job_queue: Optional[JobQueue] = None
|
||||
health_monitor: Optional[HealthMonitor] = None
|
||||
|
||||
|
||||
def check_disk_space(path: str, required_bytes: int) -> None:
|
||||
"""
|
||||
Check if sufficient disk space is available.
|
||||
|
||||
Args:
|
||||
path: Path to check disk space for
|
||||
required_bytes: Required bytes
|
||||
|
||||
Raises:
|
||||
IOError: If insufficient disk space
|
||||
"""
|
||||
try:
|
||||
stat = shutil.disk_usage(path)
|
||||
required_with_buffer = required_bytes * (1.0 + DISK_SPACE_BUFFER_PERCENT)
|
||||
|
||||
if stat.free < required_with_buffer:
|
||||
raise IOError(
|
||||
f"Insufficient disk space: {stat.free / 1e9:.1f}GB available, "
|
||||
f"need {required_with_buffer / 1e9:.1f}GB (including {DISK_SPACE_BUFFER_PERCENT*100:.0f}% buffer)"
|
||||
)
|
||||
except IOError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check disk space: {e}")
|
||||
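# --- Illustrative sketch (not part of the diff above): the buffer math in
# check_disk_space. For a 1 GB upload and the 10% buffer, at least 1.1 GB of
# free space must be available on the target filesystem. Values hypothetical.
import shutil as shutil_demo

required = 1_000_000_000                 # hypothetical 1 GB upload
required_with_buffer = required * 1.10   # DISK_SPACE_BUFFER_PERCENT = 0.10
free = shutil_demo.disk_usage("/tmp").free
print(free >= required_with_buffer)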
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""FastAPI lifespan context manager for startup/shutdown"""
|
||||
global job_queue, health_monitor
|
||||
|
||||
# Startup - use common startup logic (without GPU check, handled in main)
|
||||
logger.info("Initializing job queue and health monitor...")
|
||||
|
||||
from utils.startup import initialize_job_queue, initialize_health_monitor
|
||||
|
||||
job_queue = initialize_job_queue()
|
||||
health_monitor = initialize_health_monitor()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown - use common cleanup logic
|
||||
cleanup_on_shutdown(
|
||||
job_queue=job_queue,
|
||||
health_monitor=health_monitor,
|
||||
wait_for_current_job=True
|
||||
)
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Whisper Speech Recognition API",
|
||||
description="High-performance audio transcription API based on Faster Whisper with async job queue",
|
||||
version="0.2.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class SubmitJobRequest(BaseModel):
|
||||
audio_path: str = Field(..., description="Path to the audio file on the server")
|
||||
model_name: str = Field("large-v3", description="Whisper model name")
|
||||
device: str = Field("auto", description="Execution device (cuda, auto)")
|
||||
compute_type: str = Field("auto", description="Computation type (float16, int8, auto)")
|
||||
language: Optional[str] = Field(None, description="Language code (zh, en, ja, etc.)")
|
||||
output_format: str = Field("txt", description="Output format (vtt, srt, json, txt)")
|
||||
beam_size: int = Field(5, description="Beam search size")
|
||||
temperature: float = Field(0.0, description="Sampling temperature")
|
||||
initial_prompt: Optional[str] = Field(None, description="Initial prompt text")
|
||||
output_directory: Optional[str] = Field(None, description="Output directory path")
|
||||
|
||||
@field_validator('beam_size')
|
||||
@classmethod
|
||||
def check_beam_size(cls, v):
|
||||
return validate_beam_size(v)
|
||||
|
||||
@field_validator('temperature')
|
||||
@classmethod
|
||||
def check_temperature(cls, v):
|
||||
return validate_temperature(v)
|
||||
|
||||
@field_validator('model_name')
|
||||
@classmethod
|
||||
def check_model_name(cls, v):
|
||||
return validate_model_name(v)
|
||||
|
||||
@field_validator('device')
|
||||
@classmethod
|
||||
def check_device(cls, v):
|
||||
return validate_device(v)
|
||||
|
||||
@field_validator('compute_type')
|
||||
@classmethod
|
||||
def check_compute_type(cls, v):
|
||||
return validate_compute_type(v)
|
||||
|
||||
@field_validator('output_format')
|
||||
@classmethod
|
||||
def check_output_format(cls, v):
|
||||
return validate_output_format(v)
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information"""
|
||||
return {
|
||||
"name": "Whisper Speech Recognition API",
|
||||
"version": "0.2.0",
|
||||
"description": "Async job queue-based transcription service",
|
||||
"endpoints": {
|
||||
"GET /": "API information",
|
||||
"GET /health": "Health check",
|
||||
"GET /health/gpu": "GPU health check",
|
||||
"GET /health/circuit-breaker": "Get circuit breaker stats",
|
||||
"POST /health/circuit-breaker/reset": "Reset circuit breaker",
|
||||
"GET /models": "Get available models information",
|
||||
"POST /transcribe": "Upload audio file and submit transcription job",
|
||||
"POST /jobs": "Submit transcription job (async)",
|
||||
"GET /jobs/{job_id}": "Get job status",
|
||||
"GET /jobs/{job_id}/result": "Get job result",
|
||||
"GET /jobs": "List jobs with optional filtering"
|
||||
},
|
||||
"workflow": {
|
||||
"1": "Submit job via POST /jobs → receive job_id",
|
||||
"2": "Poll status via GET /jobs/{job_id} → wait for status='completed'",
|
||||
"3": "Get result via GET /jobs/{job_id}/result → retrieve transcription"
|
||||
}
|
||||
}
|
||||
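# --- Illustrative client-side sketch (not part of the diff above): the
# submit -> poll -> fetch-result workflow described in the endpoint list.
# The base URL and audio path are hypothetical; "requests" is a third-party
# HTTP client, not part of this service.
import time
import requests

BASE = "http://localhost:8000"  # hypothetical server address

job = requests.post(f"{BASE}/jobs", json={
    "audio_path": "/data/audio/call.mp3",   # hypothetical server-side path
    "output_format": "txt",
}).json()
job_id = job["job_id"]

while True:
    status = requests.get(f"{BASE}/jobs/{job_id}").json()
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(2)

if status["status"] == "completed":
    print(requests.get(f"{BASE}/jobs/{job_id}/result").json()["result"])
else:
    print("job failed:", status.get("error"))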
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "service": "whisper-transcription"}
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
async def get_models():
|
||||
"""Get available Whisper models and configuration information"""
|
||||
try:
|
||||
model_info = get_model_info()
|
||||
return JSONResponse(content=json.loads(model_info))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model info: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get model info: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/transcribe")
|
||||
async def transcribe_upload(
|
||||
file: UploadFile = File(...),
|
||||
model: str = Form("medium"),
|
||||
language: Optional[str] = Form(None),
|
||||
output_format: str = Form("txt"),
|
||||
beam_size: int = Form(5),
|
||||
temperature: float = Form(0.0),
|
||||
initial_prompt: Optional[str] = Form(None)
|
||||
):
|
||||
"""
|
||||
Upload audio file and submit transcription job in one request.
|
||||
|
||||
Returns immediately with job_id. Poll GET /jobs/{job_id} for status.
|
||||
"""
|
||||
temp_file_path = None
|
||||
try:
|
||||
# Validate form parameters early
|
||||
try:
|
||||
# Validate filename for security (basename-only, no path traversal)
|
||||
validate_filename_safe(file.filename)
|
||||
model = validate_model_name(model)
|
||||
output_format = validate_output_format(output_format)
|
||||
beam_size = validate_beam_size(beam_size)
|
||||
temperature = validate_temperature(temperature)
|
||||
except ValidationError as ve:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"error_type": type(ve).__name__,
|
||||
"message": str(ve)
|
||||
}
|
||||
)
|
||||
# Early queue capacity check (backpressure)
|
||||
if job_queue._queue.qsize() >= job_queue._max_queue_size:
|
||||
logger.warning("Job queue is full, rejecting upload before file transfer")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"error": "Queue full",
|
||||
"message": f"Job queue is full. Please try again later.",
|
||||
"queue_size": job_queue._queue.qsize(),
|
||||
"max_queue_size": job_queue._max_queue_size
|
||||
}
|
||||
)
|
||||
|
||||
# Save uploaded file to temp directory
|
||||
upload_dir = Path(os.getenv("TRANSCRIPTION_OUTPUT_DIR", "/tmp")) / "uploads"
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check disk space before accepting upload (estimate: file size * 2 for temp + output)
|
||||
if file.size:
|
||||
try:
|
||||
check_disk_space(str(upload_dir), file.size * 2)
|
||||
except IOError as disk_error:
|
||||
logger.error(f"Disk space check failed: {disk_error}")
|
||||
raise HTTPException(
|
||||
status_code=507, # Insufficient Storage
|
||||
detail={
|
||||
"error": "Insufficient disk space",
|
||||
"message": str(disk_error)
|
||||
}
|
||||
)
|
||||
|
||||
# Create temp file with original filename
|
||||
temp_file_path = upload_dir / file.filename
|
||||
|
||||
logger.info(f"Receiving upload: {file.filename} ({file.content_type})")
|
||||
|
||||
# Save uploaded file using async I/O to avoid blocking event loop
|
||||
async with aiofiles.open(temp_file_path, "wb") as f:
|
||||
# Read file in chunks to handle large files efficiently
|
||||
while chunk := await file.read(UPLOAD_CHUNK_SIZE_BYTES):
|
||||
await f.write(chunk)
|
||||
|
||||
logger.info(f"Saved upload to: {temp_file_path}")
|
||||
|
||||
# Submit transcription job
|
||||
job_info = job_queue.submit_job(
|
||||
audio_path=str(temp_file_path),
|
||||
model_name=model,
|
||||
device="auto",
|
||||
compute_type="auto",
|
||||
language=language,
|
||||
output_format=output_format,
|
||||
beam_size=beam_size,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
output_directory=None
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
**job_info,
|
||||
"message": f"File uploaded and job submitted. Poll /jobs/{job_info['job_id']} for status."
|
||||
}
|
||||
)
|
||||
|
||||
except queue_module.Full:
|
||||
# Clean up temp file if queue is full
|
||||
if temp_file_path is not None and temp_file_path.exists():
|
||||
temp_file_path.unlink()
|
||||
|
||||
logger.warning("Job queue is full, rejecting upload")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"error": "Queue full",
|
||||
"message": f"Job queue is full. Please try again later.",
|
||||
"queue_size": job_queue._max_queue_size,
|
||||
"max_queue_size": job_queue._max_queue_size
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up temp file on error
|
||||
if temp_file_path is not None and temp_file_path.exists():
|
||||
temp_file_path.unlink()
|
||||
|
||||
logger.error(f"Failed to process upload: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Upload failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
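# --- Illustrative client-side sketch (not part of the diff above): uploading
# a local file to POST /transcribe instead of referencing a server-side path.
# The base URL and file name are hypothetical; "requests" is a third-party
# HTTP client.
import requests

BASE = "http://localhost:8000"  # hypothetical server address

with open("call.mp3", "rb") as f:
    resp = requests.post(
        f"{BASE}/transcribe",
        files={"file": ("call.mp3", f, "audio/mpeg")},
        data={"model": "medium", "output_format": "srt"},
    )
resp.raise_for_status()
print(resp.json()["job_id"], "- poll /jobs/{job_id} for status")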
|
||||
|
||||
@app.post("/jobs")
|
||||
async def submit_job(request: SubmitJobRequest):
|
||||
"""
|
||||
Submit a transcription job for async processing.
|
||||
|
||||
Returns immediately with job_id. Poll GET /jobs/{job_id} for status.
|
||||
"""
|
||||
try:
|
||||
job_info = job_queue.submit_job(
|
||||
audio_path=request.audio_path,
|
||||
model_name=request.model_name,
|
||||
device=request.device,
|
||||
compute_type=request.compute_type,
|
||||
language=request.language,
|
||||
output_format=request.output_format,
|
||||
beam_size=request.beam_size,
|
||||
temperature=request.temperature,
|
||||
initial_prompt=request.initial_prompt,
|
||||
output_directory=request.output_directory
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
**job_info,
|
||||
"message": f"Job submitted successfully. Poll /jobs/{job_info['job_id']} for status."
|
||||
}
|
||||
)
|
||||
|
||||
except (ValidationError, PathTraversalError, InvalidFileTypeError, FileSizeError) as ve:
|
||||
# Input validation errors
|
||||
logger.error(f"Validation error: {ve}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"error_type": type(ve).__name__,
|
||||
"message": str(ve)
|
||||
}
|
||||
)
|
||||
|
||||
except queue_module.Full:
|
||||
# Queue is full
|
||||
logger.warning("Job queue is full, rejecting request")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={
|
||||
"error": "Queue full",
|
||||
"message": f"Job queue is full ({job_queue._max_queue_size}/{job_queue._max_queue_size}). Please try again later or contact administrator.",
|
||||
"queue_size": job_queue._max_queue_size,
|
||||
"max_queue_size": job_queue._max_queue_size
|
||||
}
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# Invalid audio file
|
||||
logger.error(f"Invalid audio file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid audio file",
|
||||
"message": str(e),
|
||||
"audio_path": request.audio_path
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# CPU device rejected
|
||||
logger.error(f"Invalid device parameter: {e}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid device",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
# GPU health check failed
|
||||
logger.error(f"GPU health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "GPU unavailable",
|
||||
"message": f"Job rejected: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to submit job: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Internal error",
|
||||
"message": f"Failed to submit job: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}")
|
||||
async def get_job_status_endpoint(job_id: str):
|
||||
"""
|
||||
Get the current status of a job.
|
||||
|
||||
Returns job status including queue position, timestamps, and result path when completed.
|
||||
"""
|
||||
try:
|
||||
status = job_queue.get_job_status(job_id)
|
||||
return JSONResponse(status_code=200, content=status)
|
||||
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"error": "Job not found",
|
||||
"message": f"Job ID '{job_id}' does not exist or has been cleaned up",
|
||||
"job_id": job_id
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get job status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Internal error",
|
||||
"message": f"Failed to get job status: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}/result")
|
||||
async def get_job_result_endpoint(job_id: str):
|
||||
"""
|
||||
Get the transcription result for a completed job.
|
||||
|
||||
Returns the transcription text. Only works for jobs with status='completed'.
|
||||
"""
|
||||
try:
|
||||
result_text = job_queue.get_job_result(job_id)
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"job_id": job_id, "result": result_text}
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"error": "Job not found",
|
||||
"message": f"Job ID '{job_id}' does not exist or has been cleaned up",
|
||||
"job_id": job_id
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# Job not completed yet
|
||||
# Extract current status from error message
|
||||
status_match = str(e).split("Current status: ")
|
||||
current_status = status_match[1] if len(status_match) > 1 else "unknown"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"error": "Job not completed",
|
||||
"message": f"Job is not completed yet. Current status: {current_status}. Please wait and poll again.",
|
||||
"job_id": job_id,
|
||||
"current_status": current_status
|
||||
}
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# Result file missing
|
||||
logger.error(f"Result file not found for job {job_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Result file not found",
|
||||
"message": str(e),
|
||||
"job_id": job_id
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get job result: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Internal error",
|
||||
"message": f"Failed to get job result: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/jobs")
|
||||
async def list_jobs_endpoint(
|
||||
status: Optional[str] = None,
|
||||
limit: int = 100
|
||||
):
|
||||
"""
|
||||
List jobs with optional filtering.
|
||||
|
||||
Query parameters:
|
||||
- status: Filter by status (queued, running, completed, failed)
|
||||
- limit: Maximum number of results (default: 100)
|
||||
"""
|
||||
try:
|
||||
# Parse status filter
|
||||
status_filter = None
|
||||
if status:
|
||||
try:
|
||||
status_filter = JobStatus(status)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid status",
|
||||
"message": f"Invalid status value: {status}. Must be one of: queued, running, completed, failed"
|
||||
}
|
||||
)
|
||||
|
||||
jobs = job_queue.list_jobs(status_filter=status_filter, limit=limit)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"jobs": jobs,
|
||||
"total": len(jobs),
|
||||
"filters": {
|
||||
"status": status,
|
||||
"limit": limit
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list jobs: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Internal error",
|
||||
"message": f"Failed to list jobs: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health/gpu")
|
||||
async def gpu_health_check_endpoint():
|
||||
"""
|
||||
Check GPU health by running a quick transcription test.
|
||||
|
||||
Returns detailed GPU status including device name, memory, and test duration.
|
||||
"""
|
||||
try:
|
||||
status = check_gpu_health(expected_device="auto")
|
||||
|
||||
# Add interpretation
|
||||
interpretation = "GPU is healthy and working correctly"
|
||||
if not status.gpu_available:
|
||||
interpretation = "GPU not available on this system"
|
||||
elif not status.gpu_working:
|
||||
interpretation = f"GPU available but not working correctly: {status.error}"
|
||||
elif status.test_duration_seconds > GPU_TEST_SLOW_THRESHOLD_SECONDS:
|
||||
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
**status.to_dict(),
|
||||
"interpretation": interpretation
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GPU health check failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=200, # Still return 200 but with error details
|
||||
content={
|
||||
"gpu_available": False,
|
||||
"gpu_working": False,
|
||||
"device_used": "unknown",
|
||||
"device_name": "",
|
||||
"memory_total_gb": 0.0,
|
||||
"memory_available_gb": 0.0,
|
||||
"test_duration_seconds": 0.0,
|
||||
"timestamp": "",
|
||||
"error": str(e),
|
||||
"interpretation": f"GPU health check failed: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health/circuit-breaker")
|
||||
async def get_circuit_breaker_status():
|
||||
"""
|
||||
Get GPU health check circuit breaker statistics.
|
||||
|
||||
Returns current state, failure/success counts, and last failure time.
|
||||
"""
|
||||
try:
|
||||
stats = get_circuit_breaker_stats()
|
||||
return JSONResponse(status_code=200, content=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get circuit breaker stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/health/circuit-breaker/reset")
|
||||
async def reset_circuit_breaker_endpoint():
|
||||
"""
|
||||
Manually reset the GPU health check circuit breaker.
|
||||
|
||||
Useful after fixing GPU issues or for testing purposes.
|
||||
"""
|
||||
try:
|
||||
reset_circuit_breaker()
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"message": "Circuit breaker reset successfully"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset circuit breaker: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# Perform startup GPU health check
|
||||
from utils.startup import perform_startup_gpu_check
|
||||
|
||||
# Disable auto_reset in Docker (sudo not available, GPU reset won't work)
|
||||
in_docker = os.path.exists('/.dockerenv')
|
||||
perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=not in_docker,
|
||||
exit_on_failure=True
|
||||
)
|
||||
|
||||
# Get configuration from environment variables
|
||||
host = os.getenv("API_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("API_PORT", "8000"))
|
||||
|
||||
logger.info(f"Starting Whisper REST API server on {host}:{port}")
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
timeout_keep_alive=3600, # 1 hour - for long transcription jobs
|
||||
timeout_graceful_shutdown=60,
|
||||
limit_concurrency=10, # Limit concurrent connections
|
||||
backlog=100 # Queue up to 100 pending connections
|
||||
)
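A minimal sketch of the submit-and-poll flow served by the endpoints above, assuming the server runs on the default host/port and that the requests library is available; the audio path is illustrative and the JSON fields mirror the SubmitJobRequest fields used in the POST /jobs handler.

import time
import requests  # assumed available

BASE = "http://localhost:8000"  # matches the API_HOST/API_PORT defaults above

# Submit a job for a file that already exists on the server
job = requests.post(f"{BASE}/jobs", json={
    "audio_path": "/data/audio/interview.mp3",  # illustrative path
    "model_name": "large-v3",
    "output_format": "txt",
}).json()
job_id = job["job_id"]

# Poll until the job finishes, then fetch the transcription
while True:
    status = requests.get(f"{BASE}/jobs/{job_id}").json()
    if status.get("status") in ("completed", "failed"):
        break
    time.sleep(5)

if status.get("status") == "completed":
    print(requests.get(f"{BASE}/jobs/{job_id}/result").json()["result"])
else:
    print("Job failed:", status.get("error"))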
|
||||
334 src/servers/whisper_server.py Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Faster Whisper-based Speech Recognition MCP Service
|
||||
Provides high-performance audio transcription with batch processing acceleration and multiple output formats
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from core.model_manager import get_model_info
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
from core.gpu_health import HealthMonitor, check_gpu_health as run_gpu_health_check  # aliased: the check_gpu_health MCP tool below would otherwise shadow this import
|
||||
from utils.startup import startup_sequence, cleanup_on_shutdown
|
||||
|
||||
# Log configuration
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global instances
|
||||
job_queue: Optional[JobQueue] = None
|
||||
health_monitor: Optional[HealthMonitor] = None
|
||||
|
||||
# Create FastMCP server instance
|
||||
mcp = FastMCP(
|
||||
name="fast-whisper-mcp-server",
|
||||
version="0.2.0",
|
||||
dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126", "numpy>=1.20.0"]
|
||||
)
|
||||
|
||||
@mcp.tool()
|
||||
def get_model_info_api() -> str:
|
||||
"""
|
||||
Get available Whisper model information and system configuration.
|
||||
|
||||
Returns available models, devices, languages, and GPU information.
|
||||
"""
|
||||
return get_model_info()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def transcribe_async(
|
||||
audio_path: str,
|
||||
model_name: str = "large-v3",
|
||||
device: str = "auto",
|
||||
compute_type: str = "auto",
|
||||
language: Optional[str] = None,
|
||||
output_format: str = "txt",
|
||||
beam_size: int = 5,
|
||||
temperature: float = 0.0,
|
||||
initial_prompt: Optional[str] = None,
|
||||
output_directory: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Submit an audio file for asynchronous transcription.
|
||||
|
||||
IMPORTANT: This tool returns immediately with a job_id. Use get_job_status()
|
||||
to check progress and get_job_result() to retrieve the transcription.
|
||||
|
||||
WORKFLOW FOR LLM AGENTS:
|
||||
1. Call this tool to submit the job
|
||||
2. You will receive a job_id and queue_position
|
||||
3. Poll get_job_status(job_id) every 5-10 seconds to check progress
|
||||
4. When status="completed", call get_job_result(job_id) to get transcription
|
||||
|
||||
For long audio files (>10 minutes), expect processing to take several minutes.
|
||||
You can check queue_position to estimate wait time (each job ~2-5 minutes).
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file on server
|
||||
model_name: Whisper model (tiny, base, small, medium, large-v3)
|
||||
device: Execution device (cuda, auto) - cpu is rejected
|
||||
compute_type: Computation type (float16, int8, auto)
|
||||
language: Language code (en, zh, ja, etc.) or auto-detect
|
||||
output_format: Output format (txt, vtt, srt, json)
|
||||
beam_size: Beam search size (larger=better quality, slower)
|
||||
temperature: Sampling temperature (0.0=greedy)
|
||||
initial_prompt: Optional prompt to guide transcription
|
||||
output_directory: Where to save result (uses default if not specified)
|
||||
|
||||
Returns:
|
||||
JSON string with job_id, status, queue_position, and instructions
|
||||
"""
|
||||
try:
|
||||
job_info = job_queue.submit_job(
|
||||
audio_path=audio_path,
|
||||
model_name=model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=language,
|
||||
output_format=output_format,
|
||||
beam_size=beam_size,
|
||||
temperature=temperature,
|
||||
initial_prompt=initial_prompt,
|
||||
output_directory=output_directory
|
||||
)
|
||||
return json.dumps({
|
||||
**job_info,
|
||||
"message": f"Job submitted successfully. Poll get_job_status('{job_info['job_id']}') for updates."
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
if "Full" in error_type or "queue is full" in str(e).lower():
|
||||
error_code = "QUEUE_FULL"
|
||||
message = f"Job queue is full. Please try again in a few minutes. Error: {str(e)}"
|
||||
elif "FileNotFoundError" in error_type or "not found" in str(e).lower():
|
||||
error_code = "INVALID_AUDIO_FILE"
|
||||
message = f"Audio file not found or invalid. Error: {str(e)}"
|
||||
elif "RuntimeError" in error_type or "GPU" in str(e):
|
||||
error_code = "GPU_UNAVAILABLE"
|
||||
message = f"GPU unavailable. Error: {str(e)}"
|
||||
elif "ValueError" in error_type or "CPU" in str(e):
|
||||
error_code = "INVALID_DEVICE"
|
||||
message = f"Invalid device parameter. Error: {str(e)}"
|
||||
else:
|
||||
error_code = "INTERNAL_ERROR"
|
||||
message = f"Failed to submit job. Error: {str(e)}"
|
||||
|
||||
return json.dumps({
|
||||
"error": error_code,
|
||||
"message": message
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_job_status(job_id: str) -> str:
|
||||
"""
|
||||
Check the status of a transcription job.
|
||||
|
||||
Status values:
|
||||
- "queued": Job is waiting in queue. Check queue_position.
|
||||
- "running": Job is currently being processed.
|
||||
- "completed": Transcription finished. Call get_job_result() to retrieve.
|
||||
- "failed": Job failed. Check error field for details.
|
||||
|
||||
Args:
|
||||
job_id: Job ID from transcribe_async()
|
||||
|
||||
Returns:
|
||||
JSON string with detailed job status including:
|
||||
- status, queue_position, timestamps, error (if any)
|
||||
"""
|
||||
try:
|
||||
status = job_queue.get_job_status(job_id)
|
||||
return json.dumps(status, indent=2)
|
||||
|
||||
except KeyError:
|
||||
return json.dumps({
|
||||
"error": "JOB_NOT_FOUND",
|
||||
"message": f"Job ID '{job_id}' does not exist. Please check the job_id."
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "INTERNAL_ERROR",
|
||||
"message": f"Failed to get job status. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_job_result(job_id: str) -> str:
|
||||
"""
|
||||
Retrieve the transcription result for a completed job.
|
||||
|
||||
IMPORTANT: Only call this when get_job_status() returns status="completed".
|
||||
If the job is not completed, this will return an error.
|
||||
|
||||
Args:
|
||||
job_id: Job ID from transcribe_async()
|
||||
|
||||
Returns:
|
||||
Transcription text as a string
|
||||
|
||||
Errors:
|
||||
- "Job not found" if invalid job_id
|
||||
- "Job not completed yet" if status is not "completed"
|
||||
- "Result file not found" if transcription file is missing
|
||||
"""
|
||||
try:
|
||||
result_text = job_queue.get_job_result(job_id)
|
||||
return result_text # Return raw text, not JSON
|
||||
|
||||
except KeyError:
|
||||
return json.dumps({
|
||||
"error": "JOB_NOT_FOUND",
|
||||
"message": f"Job ID '{job_id}' does not exist."
|
||||
}, indent=2)
|
||||
|
||||
except ValueError as e:
|
||||
# Extract status from error message
|
||||
status_match = str(e).split("Current status: ")
|
||||
current_status = status_match[1] if len(status_match) > 1 else "unknown"
|
||||
return json.dumps({
|
||||
"error": "JOB_NOT_COMPLETED",
|
||||
"message": f"Job is not completed yet. Current status: {current_status}. Please wait and check again.",
|
||||
"current_status": current_status
|
||||
}, indent=2)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return json.dumps({
|
||||
"error": "RESULT_FILE_NOT_FOUND",
|
||||
"message": f"Result file not found. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "INTERNAL_ERROR",
|
||||
"message": f"Failed to get job result. Error: {str(e)}"
|
||||
}, indent=2)
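The workflow documented in the transcribe_async docstring, sketched as direct in-process calls for illustration (over MCP an agent issues the same three tool calls). Assumes the job queue has already been started via startup_sequence(); the audio path is illustrative.

import json
import time

submitted = json.loads(transcribe_async("/data/audio/lecture.mp3"))
job_id = submitted["job_id"]

while True:
    status = json.loads(get_job_status(job_id))
    if status.get("status") in ("completed", "failed"):
        break
    time.sleep(5)  # docstring suggests polling every 5-10 seconds

if status.get("status") == "completed":
    transcript = get_job_result(job_id)  # raw text, not JSON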
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def list_transcription_jobs(
|
||||
status_filter: Optional[str] = None,
|
||||
limit: int = 20
|
||||
) -> str:
|
||||
"""
|
||||
List transcription jobs with optional filtering.
|
||||
|
||||
Useful for:
|
||||
- Checking all your submitted jobs
|
||||
- Finding completed jobs
|
||||
- Monitoring queue status
|
||||
|
||||
Args:
|
||||
status_filter: Filter by status (queued, running, completed, failed)
|
||||
limit: Maximum number of jobs to return (default: 20)
|
||||
|
||||
Returns:
|
||||
JSON string with list of jobs
|
||||
"""
|
||||
try:
|
||||
# Parse status filter
|
||||
status_obj = None
|
||||
if status_filter:
|
||||
try:
|
||||
status_obj = JobStatus(status_filter)
|
||||
except ValueError:
|
||||
return json.dumps({
|
||||
"error": "INVALID_STATUS",
|
||||
"message": f"Invalid status: {status_filter}. Must be one of: queued, running, completed, failed"
|
||||
}, indent=2)
|
||||
|
||||
jobs = job_queue.list_jobs(status_filter=status_obj, limit=limit)
|
||||
|
||||
return json.dumps({
|
||||
"jobs": jobs,
|
||||
"total": len(jobs),
|
||||
"filters": {
|
||||
"status": status_filter,
|
||||
"limit": limit
|
||||
}
|
||||
}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "INTERNAL_ERROR",
|
||||
"message": f"Failed to list jobs. Error: {str(e)}"
|
||||
}, indent=2)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def check_gpu_health() -> str:
|
||||
"""
|
||||
Test GPU availability and performance by running a quick transcription.
|
||||
|
||||
This tool loads the tiny model and transcribes a 1-second test audio file
|
||||
to verify the GPU is working correctly.
|
||||
|
||||
Use this when:
|
||||
- You want to verify GPU is available before submitting large jobs
|
||||
- You suspect GPU performance issues
|
||||
- For monitoring/debugging purposes
|
||||
|
||||
Returns:
|
||||
JSON string with detailed GPU status including:
|
||||
- gpu_available, gpu_working, device_name, memory_info
|
||||
- test_duration_seconds (GPU: <1s, CPU: 5-10s)
|
||||
- interpretation message
|
||||
|
||||
Note: If this returns gpu_working=false, transcriptions will be very slow.
|
||||
"""
|
||||
try:
|
||||
status = run_gpu_health_check(expected_device="auto")  # call the imported helper, not this tool function
|
||||
|
||||
# Add interpretation
|
||||
interpretation = "GPU is healthy and working correctly"
|
||||
if not status.gpu_available:
|
||||
interpretation = "GPU not available on this system"
|
||||
elif not status.gpu_working:
|
||||
interpretation = f"GPU available but not working correctly: {status.error}"
|
||||
elif status.test_duration_seconds > 2.0:
|
||||
interpretation = f"GPU working but performance degraded (test took {status.test_duration_seconds:.2f}s, expected <1s)"
|
||||
|
||||
result = status.to_dict()
|
||||
result["interpretation"] = interpretation
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"error": "GPU_CHECK_FAILED",
|
||||
"message": f"GPU health check failed. Error: {str(e)}",
|
||||
"gpu_available": False,
|
||||
"gpu_working": False
|
||||
}, indent=2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("starting mcp server for whisper stt transcriptor")
|
||||
|
||||
# Execute common startup sequence
|
||||
job_queue, health_monitor = startup_sequence(
|
||||
service_name="MCP Whisper Server",
|
||||
require_gpu=True,
|
||||
initialize_queue=True,
|
||||
initialize_monitoring=True
|
||||
)
|
||||
|
||||
try:
|
||||
mcp.run()
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
cleanup_on_shutdown(
|
||||
job_queue=job_queue,
|
||||
health_monitor=health_monitor,
|
||||
wait_for_current_job=True
|
||||
)
|
||||
6 src/utils/__init__.py Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Utility modules for Whisper transcription service.
|
||||
|
||||
Includes audio processing, formatters, test audio generation, input validation,
|
||||
circuit breaker, and startup logic.
|
||||
"""
|
||||
68 src/utils/audio_processor.py Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Audio Processing Module
|
||||
Responsible for audio file validation and preprocessing
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Any
|
||||
from pathlib import Path
|
||||
from faster_whisper import decode_audio
|
||||
|
||||
from utils.input_validation import (
|
||||
validate_audio_file as validate_audio_file_secure,
|
||||
sanitize_error_message
|
||||
)
|
||||
|
||||
# Log configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def validate_audio_file(audio_path: str, allowed_dirs: list = None) -> None:
|
||||
"""
|
||||
Validate if an audio file is valid (with security checks).
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If audio file doesn't exist
|
||||
ValueError: If audio file format is unsupported or file is empty
|
||||
OSError: If file size cannot be checked
|
||||
|
||||
Returns:
|
||||
None: If validation passes
|
||||
"""
|
||||
try:
|
||||
# Use secure validation
|
||||
validate_audio_file_secure(audio_path, allowed_dirs)
|
||||
except Exception as e:
|
||||
# Re-raise with sanitized error messages
|
||||
error_msg = sanitize_error_message(str(e))
|
||||
|
||||
if "not found" in str(e).lower():
|
||||
raise FileNotFoundError(error_msg)
|
||||
elif "size" in str(e).lower():
|
||||
raise OSError(error_msg)
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def process_audio(audio_path: str) -> Union[str, Any]:
|
||||
"""
|
||||
Process audio file, perform decoding and preprocessing
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
|
||||
Returns:
|
||||
Union[str, Any]: Processed audio data or original file path
|
||||
"""
|
||||
# Try to preprocess audio using decode_audio to handle more formats
|
||||
try:
|
||||
audio_data = decode_audio(audio_path)
|
||||
logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
|
||||
return audio_data
|
||||
except Exception as audio_error:
|
||||
logger.warning(f"Audio preprocessing failed, will use file path directly: {str(audio_error)}")
|
||||
return audio_path
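A small sketch tying the two helpers above together before handing the audio to faster-whisper. The WhisperModel call shows typical downstream usage and is not part of this module; the path is illustrative.

from faster_whisper import WhisperModel

from utils.audio_processor import validate_audio_file, process_audio

audio_path = "/data/audio/sample.wav"      # illustrative path
validate_audio_file(audio_path)            # raises on missing, unsupported, or empty files
audio = process_audio(audio_path)          # decoded samples on success, original path on fallback

# faster-whisper accepts either a file path or decoded audio samples
model = WhisperModel("tiny", device="auto")
segments, info = model.transcribe(audio)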
|
||||
304 src/utils/circuit_breaker.py Normal file
@@ -0,0 +1,304 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Circuit Breaker Pattern Implementation
|
||||
|
||||
Prevents repeated failed attempts and provides fail-fast behavior.
|
||||
Useful for GPU health checks and other operations that may fail repeatedly.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker states."""
|
||||
CLOSED = "closed" # Normal operation, requests pass through
|
||||
OPEN = "open" # Circuit is open, requests fail immediately
|
||||
HALF_OPEN = "half_open" # Testing if circuit can close
|
||||
|
||||
|
||||
@dataclass
|
||||
class CircuitBreakerConfig:
|
||||
"""Configuration for circuit breaker."""
|
||||
failure_threshold: int = 3 # Failures before opening circuit
|
||||
success_threshold: int = 2 # Successes before closing from half-open
|
||||
timeout_seconds: int = 60 # Time before attempting half-open
|
||||
half_open_max_calls: int = 1 # Max calls to test in half-open state
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""
|
||||
Circuit breaker implementation for preventing repeated failures.
|
||||
|
||||
Usage:
|
||||
breaker = CircuitBreaker(name="gpu_health", failure_threshold=3)
|
||||
|
||||
@breaker.call
|
||||
def check_gpu():
|
||||
# This function will be protected by circuit breaker
|
||||
return perform_gpu_check()
|
||||
|
||||
# Or use decorator:
|
||||
@breaker.decorator()
|
||||
def my_function():
|
||||
return "result"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
failure_threshold: int = 3,
|
||||
success_threshold: int = 2,
|
||||
timeout_seconds: int = 60,
|
||||
half_open_max_calls: int = 1
|
||||
):
|
||||
"""
|
||||
Initialize circuit breaker.
|
||||
|
||||
Args:
|
||||
name: Name of the circuit (for logging)
|
||||
failure_threshold: Number of failures before opening
|
||||
success_threshold: Number of successes to close from half-open
|
||||
timeout_seconds: Seconds before transitioning to half-open
|
||||
half_open_max_calls: Max concurrent calls in half-open state
|
||||
"""
|
||||
self.name = name
|
||||
self.config = CircuitBreakerConfig(
|
||||
failure_threshold=failure_threshold,
|
||||
success_threshold=success_threshold,
|
||||
timeout_seconds=timeout_seconds,
|
||||
half_open_max_calls=half_open_max_calls
|
||||
)
|
||||
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
# Use monotonic clock for time drift protection
|
||||
self._last_failure_time_monotonic: Optional[float] = None
|
||||
self._last_failure_time_iso: Optional[str] = None # For logging only
|
||||
self._half_open_calls = 0
|
||||
self._lock = threading.RLock() # RLock needed: properties call self.state which acquires lock
|
||||
|
||||
logger.info(
|
||||
f"Circuit breaker '{name}' initialized: "
|
||||
f"failure_threshold={failure_threshold}, "
|
||||
f"timeout={timeout_seconds}s"
|
||||
)
|
||||
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
"""Get current circuit state."""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if circuit is closed (normal operation)."""
|
||||
return self.state == CircuitState.CLOSED
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
"""Check if circuit is open (failing fast)."""
|
||||
return self.state == CircuitState.OPEN
|
||||
|
||||
@property
|
||||
def is_half_open(self) -> bool:
|
||||
"""Check if circuit is half-open (testing)."""
|
||||
return self.state == CircuitState.HALF_OPEN
|
||||
|
||||
def _update_state(self):
|
||||
"""
|
||||
Update state based on timeout and counters.
|
||||
|
||||
Uses monotonic clock to prevent issues with system time changes
|
||||
(e.g., NTP adjustments, daylight saving time, manual clock changes).
|
||||
"""
|
||||
if self._state == CircuitState.OPEN:
|
||||
# Check if timeout has passed using monotonic clock
|
||||
if self._last_failure_time_monotonic is not None:
|
||||
elapsed = time.monotonic() - self._last_failure_time_monotonic
|
||||
if elapsed >= self.config.timeout_seconds:
|
||||
logger.info(
|
||||
f"Circuit '{self.name}': Transitioning to HALF_OPEN "
|
||||
f"after {elapsed:.0f}s timeout"
|
||||
)
|
||||
self._state = CircuitState.HALF_OPEN
|
||||
self._half_open_calls = 0
|
||||
self._success_count = 0
|
||||
|
||||
def _on_success(self):
|
||||
"""Handle successful call."""
|
||||
with self._lock:
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
self._success_count += 1
|
||||
logger.debug(
|
||||
f"Circuit '{self.name}': Success in HALF_OPEN "
|
||||
f"({self._success_count}/{self.config.success_threshold})"
|
||||
)
|
||||
|
||||
if self._success_count >= self.config.success_threshold:
|
||||
logger.info(f"Circuit '{self.name}': Closing circuit after successful test")
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time_monotonic = None
|
||||
self._last_failure_time_iso = None
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
# Reset failure count on success
|
||||
self._failure_count = 0
|
||||
|
||||
def _on_failure(self, error: Exception):
|
||||
"""Handle failed call."""
|
||||
with self._lock:
|
||||
self._failure_count += 1
|
||||
# Record failure time using monotonic clock for accuracy
|
||||
self._last_failure_time_monotonic = time.monotonic()
|
||||
self._last_failure_time_iso = datetime.utcnow().isoformat()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
logger.warning(
|
||||
f"Circuit '{self.name}': Failure in HALF_OPEN, reopening circuit"
|
||||
)
|
||||
self._state = CircuitState.OPEN
|
||||
self._success_count = 0
|
||||
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
logger.debug(
|
||||
f"Circuit '{self.name}': Failure {self._failure_count}/"
|
||||
f"{self.config.failure_threshold}"
|
||||
)
|
||||
|
||||
if self._failure_count >= self.config.failure_threshold:
|
||||
logger.warning(
|
||||
f"Circuit '{self.name}': Opening circuit after "
|
||||
f"{self._failure_count} failures. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
self._state = CircuitState.OPEN
|
||||
self._success_count = 0
|
||||
|
||||
def call(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Execute function with circuit breaker protection.
|
||||
|
||||
Args:
|
||||
func: Function to execute
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Function result
|
||||
|
||||
Raises:
|
||||
CircuitBreakerOpen: If circuit is open
|
||||
Exception: Original exception from func if it fails
|
||||
"""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
|
||||
# Check if circuit is open
|
||||
if self._state == CircuitState.OPEN:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is OPEN. "
|
||||
f"Failing fast to prevent repeated failures. "
|
||||
f"Last failure: {self._last_failure_time_iso or 'unknown'}. "
|
||||
f"Will retry in {self.config.timeout_seconds}s"
|
||||
)
|
||||
|
||||
# Check half-open call limit
|
||||
half_open_incremented = False
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
if self._half_open_calls >= self.config.half_open_max_calls:
|
||||
raise CircuitBreakerOpen(
|
||||
f"Circuit '{self.name}' is HALF_OPEN with max calls reached. "
|
||||
f"Please wait for current test to complete."
|
||||
)
|
||||
self._half_open_calls += 1
|
||||
half_open_incremented = True
|
||||
|
||||
# Execute function
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
self._on_success()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._on_failure(e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Decrement half-open counter only if we incremented it
|
||||
if half_open_incremented:
|
||||
with self._lock:
|
||||
self._half_open_calls -= 1
|
||||
|
||||
def decorator(self):
|
||||
"""
|
||||
Decorator for protecting functions with circuit breaker.
|
||||
|
||||
Usage:
|
||||
breaker = CircuitBreaker("my_service")
|
||||
|
||||
@breaker.decorator()
|
||||
def my_function():
|
||||
return do_something()
|
||||
"""
|
||||
def wrapper(func):
|
||||
def decorated(*args, **kwargs):
|
||||
return self.call(func, *args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Manually reset circuit breaker to closed state.
|
||||
|
||||
Useful for:
|
||||
- Testing
|
||||
- Manual intervention
|
||||
- Clearing error state
|
||||
"""
|
||||
with self._lock:
|
||||
logger.info(f"Circuit '{self.name}': Manual reset to CLOSED state")
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._success_count = 0
|
||||
self._last_failure_time_monotonic = None
|
||||
self._last_failure_time_iso = None
|
||||
self._half_open_calls = 0
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""
|
||||
Get circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with current state and counters
|
||||
"""
|
||||
with self._lock:
|
||||
self._update_state()
|
||||
return {
|
||||
"name": self.name,
|
||||
"state": self._state.value,
|
||||
"failure_count": self._failure_count,
|
||||
"success_count": self._success_count,
|
||||
"last_failure_time": self._last_failure_time_iso,
|
||||
"config": {
|
||||
"failure_threshold": self.config.failure_threshold,
|
||||
"success_threshold": self.config.success_threshold,
|
||||
"timeout_seconds": self.config.timeout_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class CircuitBreakerOpen(Exception):
|
||||
"""Exception raised when circuit breaker is open."""
|
||||
pass
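A short usage sketch of the breaker defined above: a deliberately failing stand-in function trips the circuit, CircuitBreakerOpen is raised on subsequent calls, and reset() restores normal operation.

breaker = CircuitBreaker(name="gpu_health", failure_threshold=3, timeout_seconds=60)

@breaker.decorator()
def flaky_gpu_check():
    raise RuntimeError("CUDA device lost")   # stand-in for a real health probe

for _ in range(5):
    try:
        flaky_gpu_check()
    except CircuitBreakerOpen as e:
        print("failing fast:", e)            # after 3 failures the circuit opens
    except RuntimeError:
        pass                                 # counted as a failure by the breaker

print(breaker.get_stats()["state"])          # "open" until the 60s timeout elapses
breaker.reset()                              # manual recovery back to "closed"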
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
格式化输出模块
|
||||
负责将转录结果格式化为不同的输出格式(VTT、SRT、JSON)
|
||||
Formatting Output Module
|
||||
Responsible for formatting transcription results into different output formats (VTT, SRT, JSON, TXT)
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -9,13 +9,13 @@ from typing import List, Dict, Any
|
||||
|
||||
def format_vtt(segments: List) -> str:
|
||||
"""
|
||||
将转录结果格式化为VTT
|
||||
Format transcription results to VTT
|
||||
|
||||
Args:
|
||||
segments: 转录段落列表
|
||||
segments: List of transcription segments
|
||||
|
||||
Returns:
|
||||
str: VTT格式的字幕内容
|
||||
str: Subtitle content in VTT format
|
||||
"""
|
||||
vtt_content = "WEBVTT\n\n"
|
||||
|
||||
@@ -31,13 +31,13 @@ def format_vtt(segments: List) -> str:
|
||||
|
||||
def format_srt(segments: List) -> str:
|
||||
"""
|
||||
将转录结果格式化为SRT
|
||||
Format transcription results to SRT
|
||||
|
||||
Args:
|
||||
segments: 转录段落列表
|
||||
segments: List of transcription segments
|
||||
|
||||
Returns:
|
||||
str: SRT格式的字幕内容
|
||||
str: Subtitle content in SRT format
|
||||
"""
|
||||
srt_content = ""
|
||||
index = 1
|
||||
@@ -53,16 +53,51 @@ def format_srt(segments: List) -> str:
|
||||
|
||||
return srt_content
|
||||
|
||||
def format_json(segments: List, info: Any) -> str:
|
||||
def format_txt(segments: List) -> str:
|
||||
"""
|
||||
将转录结果格式化为JSON
|
||||
Format transcription results to plain text
|
||||
|
||||
Args:
|
||||
segments: 转录段落列表
|
||||
info: 转录信息对象
|
||||
segments: List of transcription segments
|
||||
|
||||
Returns:
|
||||
str: JSON格式的转录结果
|
||||
str: Plain text transcription content
|
||||
"""
|
||||
text_content = ""
|
||||
|
||||
for segment in segments:
|
||||
text = segment.text.strip()
|
||||
if text:
|
||||
# Add the text content
|
||||
text_content += text
|
||||
|
||||
# Add appropriate spacing between segments
|
||||
# If the text doesn't end with punctuation, add a space
|
||||
if not text.endswith(('.', '!', '?', ':', ';')):
|
||||
text_content += " "
|
||||
else:
|
||||
# If it ends with punctuation, add a space for natural flow
|
||||
text_content += " "
|
||||
|
||||
# Clean up any trailing whitespace and ensure single line breaks
|
||||
text_content = text_content.strip()
|
||||
|
||||
# Replace multiple spaces with single spaces
|
||||
while " " in text_content:
|
||||
text_content = text_content.replace(" ", " ")
|
||||
|
||||
return text_content
|
||||
|
||||
def format_json(segments: List, info: Any) -> str:
|
||||
"""
|
||||
Format transcription results to JSON
|
||||
|
||||
Args:
|
||||
segments: List of transcription segments
|
||||
info: Transcription information object
|
||||
|
||||
Returns:
|
||||
str: Transcription results in JSON format
|
||||
"""
|
||||
result = {
|
||||
"segments": [{
|
||||
@@ -86,13 +121,13 @@ def format_json(segments: List, info: Any) -> str:
|
||||
|
||||
def format_timestamp(seconds: float) -> str:
|
||||
"""
|
||||
格式化时间戳为VTT格式
|
||||
Format timestamp for VTT format
|
||||
|
||||
Args:
|
||||
seconds: 秒数
|
||||
seconds: Number of seconds
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间戳 (HH:MM:SS.mmm)
|
||||
str: Formatted timestamp (HH:MM:SS.mmm)
|
||||
"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
@@ -101,13 +136,13 @@ def format_timestamp(seconds: float) -> str:
|
||||
|
||||
def format_timestamp_srt(seconds: float) -> str:
|
||||
"""
|
||||
格式化时间戳为SRT格式
|
||||
Format timestamp for SRT format
|
||||
|
||||
Args:
|
||||
seconds: 秒数
|
||||
seconds: Number of seconds
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间戳 (HH:MM:SS,mmm)
|
||||
str: Formatted timestamp (HH:MM:SS,mmm)
|
||||
"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
@@ -117,13 +152,13 @@ def format_timestamp_srt(seconds: float) -> str:
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""
|
||||
格式化时间为可读格式
|
||||
Format time into readable format
|
||||
|
||||
Args:
|
||||
seconds: 秒数
|
||||
seconds: Number of seconds
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间 (HH:MM:SS)
|
||||
str: Formatted time (HH:MM:SS)
|
||||
"""
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
480 src/utils/input_validation.py Normal file
@@ -0,0 +1,480 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Input Validation and Path Sanitization Module
|
||||
|
||||
Provides robust validation for user inputs with security protections
|
||||
against path traversal, injection attacks, and other malicious inputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum file size (10GB)
|
||||
MAX_FILE_SIZE_BYTES = 10 * 1024 * 1024 * 1024
|
||||
|
||||
# Allowed audio file extensions
|
||||
ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"}
|
||||
|
||||
# Allowed output formats
|
||||
ALLOWED_OUTPUT_FORMATS = {"vtt", "srt", "txt", "json"}
|
||||
|
||||
# Model name validation (whitelist)
|
||||
ALLOWED_MODEL_NAMES = {"tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"}
|
||||
|
||||
# Device validation
|
||||
ALLOWED_DEVICES = {"cuda", "auto", "cpu"}
|
||||
|
||||
# Compute type validation
|
||||
ALLOWED_COMPUTE_TYPES = {"float16", "int8", "auto"}
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""Base exception for validation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class PathTraversalError(ValidationError):
|
||||
"""Exception for path traversal attempts."""
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFileTypeError(ValidationError):
|
||||
"""Exception for invalid file types."""
|
||||
pass
|
||||
|
||||
|
||||
class FileSizeError(ValidationError):
|
||||
"""Exception for file size issues."""
|
||||
pass
|
||||
|
||||
|
||||
def sanitize_error_message(error_msg: str, sanitize_paths: bool = True) -> str:
|
||||
"""
|
||||
Sanitize error messages to prevent information leakage.
|
||||
|
||||
Args:
|
||||
error_msg: Original error message
|
||||
sanitize_paths: Whether to sanitize file paths (default: True)
|
||||
|
||||
Returns:
|
||||
Sanitized error message
|
||||
"""
|
||||
if not sanitize_paths:
|
||||
return error_msg
|
||||
|
||||
# Replace absolute paths with relative paths
|
||||
# Pattern: /home/user/... or /media/... or /var/... or /tmp/...
|
||||
path_pattern = r'(/(?:home|media|var|tmp|opt|usr)/[^\s:,]+)'
|
||||
|
||||
def replace_path(match):
|
||||
full_path = match.group(1)
|
||||
try:
|
||||
# Try to get just the filename
|
||||
basename = os.path.basename(full_path)
|
||||
return f"<file:{basename}>"
|
||||
except:
|
||||
return "<file:redacted>"
|
||||
|
||||
sanitized = re.sub(path_pattern, replace_path, error_msg)
|
||||
|
||||
# Also sanitize user names if present
|
||||
sanitized = re.sub(r'/home/([^/]+)/', '/home/<user>/', sanitized)
|
||||
|
||||
return sanitized
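A quick worked example of the sanitizer above; the output follows directly from the path regex and the basename substitution shown.

print(sanitize_error_message("File not found: /home/alice/audio/test.mp3"))
# -> "File not found: <file:test.mp3>"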
|
||||
|
||||
|
||||
def validate_filename_safe(filename: str) -> str:
|
||||
"""
|
||||
Validate uploaded filename for security (basename-only validation).
|
||||
|
||||
This function is specifically for validating uploaded filenames to ensure
|
||||
they don't contain path traversal attempts. It enforces that the filename:
|
||||
- Contains no directory separators (/, \)
|
||||
- Has no path components (must be basename only)
|
||||
- Contains no null bytes
|
||||
- Has a valid audio file extension
|
||||
|
||||
Args:
|
||||
filename: Filename to validate (should be basename only, not full path)
|
||||
|
||||
Returns:
|
||||
Validated filename (unchanged if valid)
|
||||
|
||||
Raises:
|
||||
ValidationError: If filename is invalid or empty
|
||||
PathTraversalError: If filename contains path components or traversal attempts
|
||||
InvalidFileTypeError: If file extension is not allowed
|
||||
|
||||
Examples:
|
||||
validate_filename_safe("video.mp4") # ✓ PASS
|
||||
validate_filename_safe("audio...mp3") # ✓ PASS (ellipsis OK)
|
||||
validate_filename_safe("Wait... what.m4a") # ✓ PASS
|
||||
validate_filename_safe("../../../etc/passwd") # ✗ FAIL (traversal)
|
||||
validate_filename_safe("dir/file.mp4") # ✗ FAIL (path separator)
|
||||
validate_filename_safe("/etc/passwd") # ✗ FAIL (absolute path)
|
||||
"""
|
||||
if not filename:
|
||||
raise ValidationError("Filename cannot be empty")
|
||||
|
||||
# Check for null bytes
|
||||
if "\x00" in filename:
|
||||
logger.warning(f"Null byte in filename detected: {filename}")
|
||||
raise PathTraversalError("Null bytes in filename are not allowed")
|
||||
|
||||
# Extract basename - if it differs from original, filename contained path components
|
||||
basename = os.path.basename(filename)
|
||||
if basename != filename:
|
||||
logger.warning(f"Filename contains path components: {filename}")
|
||||
raise PathTraversalError(
|
||||
"Filename must not contain path components. "
|
||||
f"Use only the filename: {basename}"
|
||||
)
|
||||
|
||||
# Additional check: explicitly reject any path separators
|
||||
if "/" in filename or "\\" in filename:
|
||||
logger.warning(f"Path separators in filename: {filename}")
|
||||
raise PathTraversalError("Path separators (/ or \\) are not allowed in filename")
|
||||
|
||||
# Check file extension (case-insensitive)
|
||||
file_ext = Path(filename).suffix.lower()
|
||||
if not file_ext:
|
||||
raise InvalidFileTypeError("Filename must have a file extension")
|
||||
|
||||
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
|
||||
raise InvalidFileTypeError(
|
||||
f"Unsupported audio format: {file_ext}. "
|
||||
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
|
||||
)
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def validate_path_safe(file_path: str, allowed_dirs: Optional[List[str]] = None) -> Path:
|
||||
"""
|
||||
Validate and sanitize a file path to prevent directory traversal attacks.
|
||||
|
||||
Args:
|
||||
file_path: Path to validate
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
|
||||
Returns:
|
||||
Resolved Path object
|
||||
|
||||
Raises:
|
||||
PathTraversalError: If path contains traversal attempts
|
||||
ValidationError: If path is invalid
|
||||
"""
|
||||
if not file_path:
|
||||
raise ValidationError("File path cannot be empty")
|
||||
|
||||
# Convert to Path object
|
||||
try:
|
||||
path = Path(file_path)
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Invalid path format: {sanitize_error_message(str(e))}")
|
||||
|
||||
# Check for path traversal attempts in path components
|
||||
# This allows filenames with ellipsis (e.g., "Wait...mp3", "file...audio.m4a")
|
||||
# while blocking actual path traversal (e.g., "../../../etc/passwd")
|
||||
path_str = str(path)
|
||||
path_parts = path.parts
|
||||
if any(part == ".." for part in path_parts):
|
||||
logger.warning(f"Path traversal attempt detected in components: {path_str}")
|
||||
raise PathTraversalError("Path traversal (..) is not allowed")
|
||||
|
||||
# Check for null bytes
|
||||
if "\x00" in path_str:
|
||||
logger.warning(f"Null byte in path detected: {path_str}")
|
||||
raise PathTraversalError("Null bytes in path are not allowed")
|
||||
|
||||
# Resolve to absolute path (but don't follow symlinks yet)
|
||||
try:
|
||||
resolved_path = path.resolve()
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Cannot resolve path: {sanitize_error_message(str(e))}")
|
||||
|
||||
# If allowed_dirs specified, ensure path is within one of them
|
||||
if allowed_dirs:
|
||||
allowed = False
|
||||
for allowed_dir in allowed_dirs:
|
||||
try:
|
||||
allowed_dir_path = Path(allowed_dir).resolve()
|
||||
# Check if resolved_path is under allowed_dir
|
||||
resolved_path.relative_to(allowed_dir_path)
|
||||
allowed = True
|
||||
break
|
||||
except ValueError:
|
||||
# Not relative to this allowed_dir
|
||||
continue
|
||||
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
f"Path outside allowed directories: {path_str}, "
|
||||
f"allowed: {allowed_dirs}"
|
||||
)
|
||||
raise PathTraversalError(
|
||||
f"Path must be within allowed directories. "
|
||||
f"Allowed: {[os.path.basename(d) for d in allowed_dirs]}"
|
||||
)
|
||||
|
||||
return resolved_path
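Two worked examples of the checks above, using illustrative paths: a parent-directory component is rejected outright, while a path under an allowed base directory resolves normally.

validate_path_safe("../../etc/passwd")                 # raises PathTraversalError
validate_path_safe("/data/audio/a.mp3", ["/data"])     # returns Path("/data/audio/a.mp3")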
|
||||
|
||||
|
||||
def validate_audio_file(
|
||||
file_path: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
max_size_bytes: int = MAX_FILE_SIZE_BYTES
|
||||
) -> Path:
|
||||
"""
|
||||
Validate audio file path with security checks.
|
||||
|
||||
Args:
|
||||
file_path: Path to audio file
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
max_size_bytes: Maximum allowed file size
|
||||
|
||||
Returns:
|
||||
Validated Path object
|
||||
|
||||
Raises:
|
||||
ValidationError: If validation fails
|
||||
PathTraversalError: If path traversal detected
|
||||
FileNotFoundError: If file doesn't exist
|
||||
InvalidFileTypeError: If file type not allowed
|
||||
FileSizeError: If file too large
|
||||
"""
|
||||
# Validate and sanitize path
|
||||
validated_path = validate_path_safe(file_path, allowed_dirs)
|
||||
|
||||
# Check file exists
|
||||
if not validated_path.exists():
|
||||
raise FileNotFoundError(f"Audio file not found: {validated_path.name}")
|
||||
|
||||
# Check it's a file (not directory)
|
||||
if not validated_path.is_file():
|
||||
raise ValidationError(f"Path is not a file: {validated_path.name}")
|
||||
|
||||
# Check file extension
|
||||
file_ext = validated_path.suffix.lower()
|
||||
if file_ext not in ALLOWED_AUDIO_EXTENSIONS:
|
||||
raise InvalidFileTypeError(
|
||||
f"Unsupported audio format: {file_ext}. "
|
||||
f"Supported: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
|
||||
)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
file_size = validated_path.stat().st_size
|
||||
except Exception as e:
|
||||
raise ValidationError(f"Cannot check file size: {sanitize_error_message(str(e))}")
|
||||
|
||||
if file_size == 0:
|
||||
raise FileSizeError(f"Audio file is empty: {validated_path.name}")
|
||||
|
||||
if file_size > max_size_bytes:
|
||||
raise FileSizeError(
|
||||
f"File too large: {file_size / (1024**3):.2f}GB. "
|
||||
f"Maximum: {max_size_bytes / (1024**3):.2f}GB"
|
||||
)
|
||||
|
||||
# Warn for large files (>1GB)
|
||||
if file_size > 1024 * 1024 * 1024:
|
||||
logger.warning(
|
||||
f"Large file: {file_size / (1024**3):.2f}GB, "
|
||||
f"may require extended processing time"
|
||||
)
|
||||
|
||||
return validated_path
|
||||
|
||||
|
||||
def validate_output_directory(
|
||||
dir_path: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
create_if_missing: bool = True
|
||||
) -> Path:
|
||||
"""
|
||||
Validate output directory path.
|
||||
|
||||
Args:
|
||||
dir_path: Directory path
|
||||
allowed_dirs: Optional list of allowed base directories
|
||||
create_if_missing: Create directory if it doesn't exist
|
||||
|
||||
Returns:
|
||||
Validated Path object
|
||||
|
||||
Raises:
|
||||
ValidationError: If validation fails
|
||||
PathTraversalError: If path traversal detected
|
||||
"""
|
||||
# Validate and sanitize path
|
||||
validated_path = validate_path_safe(dir_path, allowed_dirs)
|
||||
|
||||
# Create if requested and doesn't exist
|
||||
if create_if_missing and not validated_path.exists():
|
||||
try:
|
||||
validated_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created output directory: {validated_path}")
|
||||
except Exception as e:
|
||||
raise ValidationError(
|
||||
f"Cannot create output directory: {sanitize_error_message(str(e))}"
|
||||
)
|
||||
|
||||
# Check it's a directory
|
||||
if validated_path.exists() and not validated_path.is_dir():
|
||||
raise ValidationError(f"Path exists but is not a directory: {validated_path.name}")
|
||||
|
||||
return validated_path
|
||||
|
||||
|
||||
def validate_model_name(model_name: str) -> str:
|
||||
"""
|
||||
Validate Whisper model name.
|
||||
|
||||
Args:
|
||||
model_name: Model name to validate
|
||||
|
||||
Returns:
|
||||
Validated model name
|
||||
|
||||
Raises:
|
||||
ValidationError: If model name invalid
|
||||
"""
|
||||
if not model_name:
|
||||
raise ValidationError("Model name cannot be empty")
|
||||
|
||||
model_name = model_name.strip().lower()
|
||||
|
||||
if model_name not in ALLOWED_MODEL_NAMES:
|
||||
raise ValidationError(
|
||||
f"Invalid model name: {model_name}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_MODEL_NAMES))}"
|
||||
)
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def validate_device(device: str) -> str:
|
||||
"""
|
||||
Validate device parameter.
|
||||
|
||||
Args:
|
||||
device: Device name to validate
|
||||
|
||||
Returns:
|
||||
Validated device name
|
||||
|
||||
Raises:
|
||||
ValidationError: If device invalid
|
||||
"""
|
||||
if not device:
|
||||
raise ValidationError("Device cannot be empty")
|
||||
|
||||
device = device.strip().lower()
|
||||
|
||||
if device not in ALLOWED_DEVICES:
|
||||
raise ValidationError(
|
||||
f"Invalid device: {device}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_DEVICES))}"
|
||||
)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def validate_compute_type(compute_type: str) -> str:
|
||||
"""
|
||||
Validate compute type parameter.
|
||||
|
||||
Args:
|
||||
compute_type: Compute type to validate
|
||||
|
||||
Returns:
|
||||
Validated compute type
|
||||
|
||||
Raises:
|
||||
ValidationError: If compute type invalid
|
||||
"""
|
||||
if not compute_type:
|
||||
raise ValidationError("Compute type cannot be empty")
|
||||
|
||||
compute_type = compute_type.strip().lower()
|
||||
|
||||
if compute_type not in ALLOWED_COMPUTE_TYPES:
|
||||
raise ValidationError(
|
||||
f"Invalid compute type: {compute_type}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_COMPUTE_TYPES))}"
|
||||
)
|
||||
|
||||
return compute_type
|
||||
|
||||
|
||||
def validate_output_format(output_format: str) -> str:
|
||||
"""
|
||||
Validate output format parameter.
|
||||
|
||||
Args:
|
||||
output_format: Output format to validate
|
||||
|
||||
Returns:
|
||||
Validated output format
|
||||
|
||||
Raises:
|
||||
ValidationError: If output format invalid
|
||||
"""
|
||||
if not output_format:
|
||||
raise ValidationError("Output format cannot be empty")
|
||||
|
||||
output_format = output_format.strip().lower()
|
||||
|
||||
if output_format not in ALLOWED_OUTPUT_FORMATS:
|
||||
raise ValidationError(
|
||||
f"Invalid output format: {output_format}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_OUTPUT_FORMATS))}"
|
||||
)
|
||||
|
||||
return output_format
|
||||
|
||||
|
||||
def validate_numeric_range(
|
||||
value: float,
|
||||
min_value: float,
|
||||
max_value: float,
|
||||
param_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Validate numeric parameter is within range.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
min_value: Minimum allowed value
|
||||
max_value: Maximum allowed value
|
||||
param_name: Parameter name for error messages
|
||||
|
||||
Returns:
|
||||
Validated value
|
||||
|
||||
Raises:
|
||||
ValidationError: If value out of range
|
||||
"""
|
||||
if value < min_value or value > max_value:
|
||||
raise ValidationError(
|
||||
f"{param_name} must be between {min_value} and {max_value}, "
|
||||
f"got {value}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_beam_size(beam_size: int) -> int:
|
||||
"""Validate beam size parameter."""
|
||||
return int(validate_numeric_range(beam_size, 1, 20, "beam_size"))
|
||||
|
||||
|
||||
def validate_temperature(temperature: float) -> float:
|
||||
"""Validate temperature parameter."""
|
||||
return validate_numeric_range(temperature, 0.0, 1.0, "temperature")
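A minimal sketch of validating one request's parameters with the helpers in this module; the upload directory and file path are illustrative.

from utils.input_validation import (
    ValidationError,
    sanitize_error_message,
    validate_audio_file,
    validate_beam_size,
    validate_device,
    validate_model_name,
    validate_output_format,
)

try:
    path = validate_audio_file("/data/uploads/talk.mp3", allowed_dirs=["/data/uploads"])
    model = validate_model_name("large-v3")
    device = validate_device("auto")
    fmt = validate_output_format("txt")
    beam = validate_beam_size(5)
except (ValidationError, FileNotFoundError) as e:
    # ValidationError also covers PathTraversalError, InvalidFileTypeError and FileSizeError
    print("rejected:", sanitize_error_message(str(e)))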
|
||||
237 src/utils/startup.py Normal file
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Common Startup Logic Module
|
||||
|
||||
Centralizes startup procedures shared between MCP and API servers,
|
||||
including GPU health checks, job queue initialization, and health monitoring.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import GPU health check with reset
|
||||
try:
|
||||
from core.gpu_health import check_gpu_health_with_reset
|
||||
GPU_HEALTH_CHECK_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"GPU health check with reset not available: {e}")
|
||||
GPU_HEALTH_CHECK_AVAILABLE = False
|
||||
|
||||
|
||||
def perform_startup_gpu_check(
|
||||
required_device: str = "cuda",
|
||||
auto_reset: bool = True,
|
||||
exit_on_failure: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
Perform startup GPU health check with optional auto-reset.
|
||||
|
||||
This function:
|
||||
1. Checks if GPU health check is available
|
||||
2. Runs comprehensive GPU health check
|
||||
3. Attempts auto-reset if check fails and auto_reset=True
|
||||
4. Optionally exits process if check fails
|
||||
|
||||
Args:
|
||||
required_device: Required device ("cuda", "auto")
|
||||
auto_reset: Enable automatic GPU driver reset on failure
|
||||
exit_on_failure: Exit process if GPU check fails
|
||||
|
||||
Returns:
|
||||
True if GPU check passed, False otherwise
|
||||
|
||||
Side effects:
|
||||
May exit process if exit_on_failure=True and check fails
|
||||
"""
|
||||
if not GPU_HEALTH_CHECK_AVAILABLE:
|
||||
logger.warning("GPU health check not available, starting without GPU validation")
|
||||
if exit_on_failure:
|
||||
logger.error("GPU health check required but not available. Exiting.")
|
||||
sys.exit(1)
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info("=" * 70)
|
||||
logger.info("PERFORMING STARTUP GPU HEALTH CHECK")
|
||||
logger.info("=" * 70)
|
||||
|
||||
status = check_gpu_health_with_reset(
|
||||
expected_device=required_device,
|
||||
auto_reset=auto_reset
|
||||
)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("STARTUP GPU CHECK SUCCESSFUL")
|
||||
logger.info(f"GPU Device: {status.device_name}")
|
||||
logger.info(f"Memory Available: {status.memory_available_gb:.2f} GB")
|
||||
logger.info(f"Test Duration: {status.test_duration_seconds:.2f}s")
|
||||
logger.info("=" * 70)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("=" * 70)
|
||||
logger.error("STARTUP GPU CHECK FAILED")
|
||||
logger.error(f"Error: {e}")
|
||||
|
||||
if exit_on_failure:
|
||||
logger.error("This service requires GPU. Terminating.")
|
||||
logger.error("=" * 70)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("Continuing without GPU (may have reduced functionality)")
|
||||
logger.error("=" * 70)
|
||||
return False
|
||||
|
||||
|
||||
def initialize_job_queue(
|
||||
max_queue_size: Optional[int] = None,
|
||||
metadata_dir: Optional[str] = None
|
||||
) -> 'JobQueue':
|
||||
"""
|
||||
Initialize job queue with environment variable configuration.
|
||||
|
||||
Args:
|
||||
max_queue_size: Override for max queue size (uses env var if None)
|
||||
metadata_dir: Override for metadata directory (uses env var if None)
|
||||
|
||||
Returns:
|
||||
Initialized JobQueue instance (started)
|
||||
"""
|
||||
from core.job_queue import JobQueue
|
||||
|
||||
# Get configuration from environment
|
||||
if max_queue_size is None:
|
||||
max_queue_size = int(os.getenv("JOB_QUEUE_MAX_SIZE", "100"))
|
||||
|
||||
if metadata_dir is None:
|
||||
metadata_dir = os.getenv(
|
||||
"JOB_METADATA_DIR",
|
||||
"/media/raid/agents/tools/mcp-transcriptor/outputs/jobs"
|
||||
)
|
||||
|
||||
logger.info("Initializing job queue...")
|
||||
job_queue = JobQueue(max_queue_size=max_queue_size, metadata_dir=metadata_dir)
|
||||
job_queue.start()
|
||||
logger.info(f"Job queue started (max_size={max_queue_size}, metadata_dir={metadata_dir})")
|
||||
|
||||
return job_queue
|
||||
|
||||
|
||||
def initialize_health_monitor(
|
||||
check_interval_minutes: Optional[int] = None,
|
||||
enabled: Optional[bool] = None
|
||||
) -> Optional['HealthMonitor']:
|
||||
"""
|
||||
Initialize GPU health monitor with environment variable configuration.
|
||||
|
||||
Args:
|
||||
check_interval_minutes: Override for check interval (uses env var if None)
|
||||
enabled: Override for enabled status (uses env var if None)
|
||||
|
||||
Returns:
|
||||
Initialized HealthMonitor instance (started), or None if disabled
|
||||
"""
|
||||
from core.gpu_health import HealthMonitor
|
||||
|
||||
# Get configuration from environment
|
||||
if enabled is None:
|
||||
enabled = os.getenv("GPU_HEALTH_CHECK_ENABLED", "true").lower() == "true"
|
||||
|
||||
if not enabled:
|
||||
logger.info("GPU health monitoring disabled")
|
||||
return None
|
||||
|
||||
if check_interval_minutes is None:
|
||||
check_interval_minutes = int(os.getenv("GPU_HEALTH_CHECK_INTERVAL_MINUTES", "10"))
|
||||
|
||||
health_monitor = HealthMonitor(check_interval_minutes=check_interval_minutes)
|
||||
health_monitor.start()
|
||||
logger.info(f"GPU health monitor started (interval={check_interval_minutes} minutes)")
|
||||
|
||||
return health_monitor
|
||||
|
||||
|
||||
def startup_sequence(
|
||||
service_name: str = "whisper-transcription",
|
||||
require_gpu: bool = True,
|
||||
initialize_queue: bool = True,
|
||||
initialize_monitoring: bool = True
|
||||
) -> Tuple[Optional['JobQueue'], Optional['HealthMonitor']]:
|
||||
"""
|
||||
Execute complete startup sequence for a Whisper transcription server.
|
||||
|
||||
This function performs all common startup tasks:
|
||||
1. GPU health check with auto-reset
|
||||
2. Job queue initialization
|
||||
3. Health monitor initialization
|
||||
|
||||
Args:
|
||||
service_name: Name of the service (for logging)
|
||||
require_gpu: Whether GPU is required (exit if not available)
|
||||
initialize_queue: Whether to initialize job queue
|
||||
initialize_monitoring: Whether to initialize health monitoring
|
||||
|
||||
Returns:
|
||||
Tuple of (job_queue, health_monitor) - either may be None
|
||||
|
||||
Side effects:
|
||||
May exit process if GPU required but unavailable
|
||||
"""
|
||||
logger.info(f"Starting {service_name}...")
|
||||
|
||||
# Step 1: GPU health check
|
||||
gpu_ok = perform_startup_gpu_check(
|
||||
required_device="cuda",
|
||||
auto_reset=True,
|
||||
exit_on_failure=require_gpu
|
||||
)
|
||||
|
||||
if not gpu_ok and require_gpu:
|
||||
# Should not reach here (exit_on_failure should have exited)
|
||||
logger.error("GPU check failed and GPU is required")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 2: Initialize job queue
|
||||
job_queue = None
|
||||
if initialize_queue:
|
||||
job_queue = initialize_job_queue()
|
||||
|
||||
# Step 3: Initialize health monitor
|
||||
health_monitor = None
|
||||
if initialize_monitoring:
|
||||
health_monitor = initialize_health_monitor()
|
||||
|
||||
logger.info(f"{service_name} startup sequence completed")
|
||||
|
||||
return job_queue, health_monitor
|
||||
|
||||
|
||||
def cleanup_on_shutdown(
|
||||
job_queue: Optional['JobQueue'] = None,
|
||||
health_monitor: Optional['HealthMonitor'] = None,
|
||||
wait_for_current_job: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Perform cleanup on server shutdown.
|
||||
|
||||
Args:
|
||||
job_queue: JobQueue instance to stop (if any)
|
||||
health_monitor: HealthMonitor instance to stop (if any)
|
||||
wait_for_current_job: Wait for current job to complete before stopping
|
||||
"""
|
||||
logger.info("Shutting down...")
|
||||
|
||||
if job_queue:
|
||||
job_queue.stop(wait_for_current=wait_for_current_job)
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
if health_monitor:
|
||||
health_monitor.stop()
|
||||
logger.info("Health monitor stopped")
|
||||
|
||||
logger.info("Shutdown complete")
|
||||
61
src/utils/test_audio_generator.py
Normal file
61
src/utils/test_audio_generator.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Test audio generator for GPU health checks.
|
||||
|
||||
Returns path to existing test audio file - NO GENERATION, NO INTERNET.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
def generate_test_audio(duration_seconds: float = 3.0, frequency: int = 440) -> str:
|
||||
"""
|
||||
Return path to existing test audio file for GPU health checks.
|
||||
|
||||
NO AUDIO GENERATION - just returns path to pre-existing test file.
|
||||
NO INTERNET CONNECTION REQUIRED.
|
||||
|
||||
Args:
|
||||
duration_seconds: Duration hint (default: 3.0) - used for cache lookup
|
||||
frequency: Legacy parameter, ignored
|
||||
|
||||
Returns:
|
||||
str: Path to test audio file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If test audio file doesn't exist
|
||||
"""
|
||||
# Check for existing test audio in temp directory
|
||||
temp_dir = tempfile.gettempdir()
|
||||
audio_path = os.path.join(temp_dir, f"whisper_test_voice_{int(duration_seconds)}s.mp3")
|
||||
|
||||
# Return cached file if it exists and is valid
|
||||
if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
|
||||
return audio_path
|
||||
|
||||
# If no cached file, raise error - we don't generate anything
|
||||
raise RuntimeError(
|
||||
f"Test audio file not found: {audio_path}. "
|
||||
f"Please ensure test audio exists before running GPU health checks. "
|
||||
f"Expected file location: {audio_path}"
|
||||
)
|
||||
|
||||
|
||||
def cleanup_test_audio() -> None:
|
||||
"""
|
||||
Remove all cached test audio files from temp directory.
|
||||
|
||||
Useful for cleanup after testing or to force regeneration.
|
||||
"""
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
# Find all test audio files
|
||||
for filename in os.listdir(temp_dir):
|
||||
if (filename.startswith("whisper_test_") and
|
||||
(filename.endswith(".wav") or filename.endswith(".mp3"))):
|
||||
filepath = os.path.join(temp_dir, filename)
|
||||
try:
|
||||
os.remove(filepath)
|
||||
except Exception:
|
||||
# Ignore errors during cleanup
|
||||
pass
|
||||
@@ -1,16 +0,0 @@
|
||||
@echo off
|
||||
echo 启动Whisper语音识别MCP服务器...
|
||||
|
||||
:: 激活虚拟环境(如果存在)
|
||||
if exist "..\venv\Scripts\activate.bat" (
|
||||
call ..\venv\Scripts\activate.bat
|
||||
)
|
||||
|
||||
:: 运行MCP服务器
|
||||
python whisper_server.py
|
||||
|
||||
:: 如果出错,暂停以查看错误信息
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo 服务器启动失败,错误代码: %ERRORLEVEL%
|
||||
pause
|
||||
)
|
||||
60
test_filename_fix.py
Normal file
60
test_filename_fix.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick manual test to verify the filename validation fix.
|
||||
Tests the exact case from the bug report.
|
||||
"""
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, 'src')
|
||||
|
||||
from utils.input_validation import validate_filename_safe, PathTraversalError
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("FILENAME VALIDATION FIX - MANUAL TEST")
|
||||
print("="*70 + "\n")
|
||||
|
||||
# Bug report case
|
||||
bug_report_filename = "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
|
||||
|
||||
print(f"Testing bug report filename:")
|
||||
print(f" '{bug_report_filename}'")
|
||||
print()
|
||||
|
||||
try:
|
||||
result = validate_filename_safe(bug_report_filename)
|
||||
print(f"✅ SUCCESS: Filename accepted!")
|
||||
print(f" Returned: '{result}'")
|
||||
except PathTraversalError as e:
|
||||
print(f"❌ FAILED: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# Test that security still works
|
||||
print("Verifying security (path traversal should still be blocked):")
|
||||
dangerous_filenames = [
|
||||
"../../../etc/passwd",
|
||||
"../../secrets.txt",
|
||||
"dir/file.m4a",
|
||||
]
|
||||
|
||||
for dangerous in dangerous_filenames:
|
||||
try:
|
||||
validate_filename_safe(dangerous)
|
||||
print(f"❌ SECURITY ISSUE: '{dangerous}' was accepted (should be blocked!)")
|
||||
sys.exit(1)
|
||||
except PathTraversalError:
|
||||
print(f"✅ '{dangerous}' correctly blocked")
|
||||
|
||||
print()
|
||||
print("="*70)
|
||||
print("ALL TESTS PASSED! ✅")
|
||||
print("="*70)
|
||||
print()
|
||||
print("The fix is working correctly:")
|
||||
print(" ✓ Filenames with ellipsis (...) are now accepted")
|
||||
print(" ✓ Path traversal attacks are still blocked")
|
||||
print()
|
||||
537
tests/test_async_api_integration.py
Executable file
537
tests/test_async_api_integration.py
Executable file
@@ -0,0 +1,537 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Phase 2: Async Job Queue Integration
|
||||
|
||||
Tests the async job queue system for both API and MCP servers.
|
||||
Validates all new endpoints and error handling.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
# Color codes for terminal output
|
||||
class Colors:
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
END = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
|
||||
def print_success(msg):
|
||||
print(f"{Colors.GREEN}✓ {msg}{Colors.END}")
|
||||
|
||||
def print_error(msg):
|
||||
print(f"{Colors.RED}✗ {msg}{Colors.END}")
|
||||
|
||||
def print_info(msg):
|
||||
print(f"{Colors.BLUE}ℹ {msg}{Colors.END}")
|
||||
|
||||
def print_section(msg):
|
||||
print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
|
||||
print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
|
||||
print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
|
||||
|
||||
|
||||
class Phase2Tester:
|
||||
def __init__(self, api_url="http://localhost:8000"):
|
||||
self.api_url = api_url
|
||||
self.test_results = []
|
||||
|
||||
def test(self, name, func):
|
||||
"""Run a test and record result"""
|
||||
try:
|
||||
logger.info(f"Testing: {name}")
|
||||
print_info(f"Testing: {name}")
|
||||
func()
|
||||
logger.info(f"PASSED: {name}")
|
||||
print_success(f"PASSED: {name}")
|
||||
self.test_results.append((name, True, None))
|
||||
return True
|
||||
except AssertionError as e:
|
||||
logger.error(f"FAILED: {name} - {str(e)}")
|
||||
print_error(f"FAILED: {name}")
|
||||
print_error(f" Reason: {str(e)}")
|
||||
self.test_results.append((name, False, str(e)))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR: {name} - {str(e)}")
|
||||
print_error(f"ERROR: {name}")
|
||||
print_error(f" Exception: {str(e)}")
|
||||
self.test_results.append((name, False, f"Exception: {str(e)}"))
|
||||
return False
|
||||
|
||||
def print_summary(self):
|
||||
"""Print test summary"""
|
||||
print_section("TEST SUMMARY")
|
||||
|
||||
passed = sum(1 for _, result, _ in self.test_results if result)
|
||||
failed = len(self.test_results) - passed
|
||||
|
||||
for name, result, error in self.test_results:
|
||||
if result:
|
||||
print_success(f"{name}")
|
||||
else:
|
||||
print_error(f"{name}")
|
||||
if error:
|
||||
print(f" {error}")
|
||||
|
||||
print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
|
||||
print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
|
||||
print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")
|
||||
|
||||
return failed == 0
|
||||
|
||||
# ========================================================================
|
||||
# API Server Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_api_root_endpoint(self):
|
||||
"""Test GET / returns new API information"""
|
||||
logger.info(f"GET {self.api_url}/")
|
||||
resp = requests.get(f"{self.api_url}/")
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
logger.info(f"Response data: {json.dumps(data, indent=2)}")
|
||||
assert data["version"] == "0.2.0", "Version should be 0.2.0"
|
||||
assert "POST /jobs" in str(data["endpoints"]), "Should have POST /jobs endpoint"
|
||||
assert "workflow" in data, "Should have workflow documentation"
|
||||
|
||||
def test_api_health_endpoint(self):
|
||||
"""Test GET /health still works"""
|
||||
logger.info(f"GET {self.api_url}/health")
|
||||
resp = requests.get(f"{self.api_url}/health")
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
logger.info(f"Response data: {data}")
|
||||
assert data["status"] == "healthy", "Health check should return healthy"
|
||||
|
||||
def test_api_models_endpoint(self):
|
||||
"""Test GET /models still works"""
|
||||
logger.info(f"GET {self.api_url}/models")
|
||||
resp = requests.get(f"{self.api_url}/models")
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
logger.info(f"Available models: {data.get('available_models', [])}")
|
||||
assert "available_models" in data, "Should return available models"
|
||||
|
||||
def test_api_gpu_health_endpoint(self):
|
||||
"""Test GET /health/gpu returns GPU status"""
|
||||
logger.info(f"GET {self.api_url}/health/gpu")
|
||||
resp = requests.get(f"{self.api_url}/health/gpu")
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
logger.info(f"GPU health: {json.dumps(data, indent=2)}")
|
||||
assert "gpu_available" in data, "Should have gpu_available field"
|
||||
assert "gpu_working" in data, "Should have gpu_working field"
|
||||
assert "interpretation" in data, "Should have interpretation field"
|
||||
|
||||
print_info(f" GPU Status: {data.get('interpretation', 'unknown')}")
|
||||
|
||||
def test_api_submit_job_invalid_audio(self):
|
||||
"""Test POST /jobs with invalid audio path returns 400"""
|
||||
payload = {
|
||||
"audio_path": "/nonexistent/file.mp3",
|
||||
"model_name": "tiny",
|
||||
"output_format": "txt"
|
||||
}
|
||||
|
||||
logger.info(f"POST {self.api_url}/jobs with invalid audio path")
|
||||
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
|
||||
resp = requests.post(f"{self.api_url}/jobs", json=payload)
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
logger.info(f"Response: {resp.json()}")
|
||||
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "error" in data["detail"], "Should have error field"
|
||||
assert data["detail"]["error"] == "Invalid audio file", f"Wrong error type: {data['detail']['error']}"
|
||||
|
||||
print_info(f" Error message: {data['detail']['message'][:50]}...")
|
||||
|
||||
def test_api_submit_job_cpu_device_rejected(self):
|
||||
"""Test POST /jobs with device=cpu is rejected (400)"""
|
||||
# Create a test audio file first
|
||||
logger.info("Creating test audio file...")
|
||||
test_audio = self._create_test_audio_file()
|
||||
logger.info(f"Test audio created at: {test_audio}")
|
||||
|
||||
payload = {
|
||||
"audio_path": test_audio,
|
||||
"model_name": "tiny",
|
||||
"device": "cpu",
|
||||
"output_format": "txt"
|
||||
}
|
||||
|
||||
logger.info(f"POST {self.api_url}/jobs with device=cpu")
|
||||
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
|
||||
resp = requests.post(f"{self.api_url}/jobs", json=payload)
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
logger.info(f"Response: {resp.json()}")
|
||||
assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "error" in data["detail"], "Should have error field"
|
||||
assert "Invalid device" in data["detail"]["error"] or "CPU" in data["detail"]["message"], \
|
||||
"Should reject CPU device"
|
||||
|
||||
def test_api_submit_job_success(self):
|
||||
"""Test POST /jobs with valid audio returns job_id"""
|
||||
logger.info("Creating test audio file...")
|
||||
test_audio = self._create_test_audio_file()
|
||||
logger.info(f"Test audio created at: {test_audio}")
|
||||
|
||||
payload = {
|
||||
"audio_path": test_audio,
|
||||
"model_name": "tiny",
|
||||
"device": "auto",
|
||||
"output_format": "txt"
|
||||
}
|
||||
|
||||
logger.info(f"POST {self.api_url}/jobs with valid audio")
|
||||
logger.info(f"Payload: {json.dumps(payload, indent=2)}")
|
||||
resp = requests.post(f"{self.api_url}/jobs", json=payload)
|
||||
logger.info(f"Response status: {resp.status_code}")
|
||||
logger.info(f"Response: {json.dumps(resp.json(), indent=2)}")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "job_id" in data, "Should return job_id"
|
||||
assert "status" in data, "Should return status"
|
||||
assert data["status"] == "queued", f"Status should be queued, got {data['status']}"
|
||||
assert "queue_position" in data, "Should return queue_position"
|
||||
assert "message" in data, "Should return message"
|
||||
|
||||
logger.info(f"Job submitted successfully: {data['job_id']}")
|
||||
print_info(f" Job ID: {data['job_id']}")
|
||||
print_info(f" Queue position: {data['queue_position']}")
|
||||
|
||||
# Store job_id for later tests
|
||||
self.test_job_id = data["job_id"]
|
||||
|
||||
def test_api_get_job_status(self):
|
||||
"""Test GET /jobs/{job_id} returns job status"""
|
||||
if not hasattr(self, 'test_job_id'):
|
||||
print_info(" Skipping (no test_job_id from previous test)")
|
||||
return
|
||||
|
||||
resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "job_id" in data, "Should return job_id"
|
||||
assert "status" in data, "Should return status"
|
||||
assert data["status"] in ["queued", "running", "completed", "failed"], \
|
||||
f"Invalid status: {data['status']}"
|
||||
|
||||
print_info(f" Status: {data['status']}")
|
||||
|
||||
def test_api_get_job_status_not_found(self):
|
||||
"""Test GET /jobs/{job_id} with invalid ID returns 404"""
|
||||
fake_job_id = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
resp = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
|
||||
assert resp.status_code == 404, f"Expected 404, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "error" in data["detail"], "Should have error field"
|
||||
assert data["detail"]["error"] == "Job not found", f"Wrong error: {data['detail']['error']}"
|
||||
|
||||
def test_api_get_job_result_not_completed(self):
|
||||
"""Test GET /jobs/{job_id}/result when job not completed returns 409"""
|
||||
if not hasattr(self, 'test_job_id'):
|
||||
print_info(" Skipping (no test_job_id from previous test)")
|
||||
return
|
||||
|
||||
# Check current status
|
||||
status_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
|
||||
current_status = status_resp.json()["status"]
|
||||
|
||||
if current_status == "completed":
|
||||
print_info(" Skipping (job already completed)")
|
||||
return
|
||||
|
||||
resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
|
||||
assert resp.status_code == 409, f"Expected 409, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "error" in data["detail"], "Should have error field"
|
||||
assert data["detail"]["error"] == "Job not completed", f"Wrong error: {data['detail']['error']}"
|
||||
assert "current_status" in data["detail"], "Should include current_status"
|
||||
|
||||
def test_api_list_jobs(self):
|
||||
"""Test GET /jobs returns job list"""
|
||||
resp = requests.get(f"{self.api_url}/jobs")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "jobs" in data, "Should have jobs field"
|
||||
assert "total" in data, "Should have total field"
|
||||
assert isinstance(data["jobs"], list), "Jobs should be a list"
|
||||
|
||||
print_info(f" Total jobs: {data['total']}")
|
||||
|
||||
def test_api_list_jobs_with_filter(self):
|
||||
"""Test GET /jobs?status=queued filters by status"""
|
||||
resp = requests.get(f"{self.api_url}/jobs?status=queued&limit=10")
|
||||
assert resp.status_code == 200, f"Expected 200, got {resp.status_code}"
|
||||
|
||||
data = resp.json()
|
||||
assert "jobs" in data, "Should have jobs field"
|
||||
assert "filters" in data, "Should have filters field"
|
||||
assert data["filters"]["status"] == "queued", "Filter should be applied"
|
||||
|
||||
# All returned jobs should be queued
|
||||
for job in data["jobs"]:
|
||||
assert job["status"] == "queued", f"Job {job['job_id']} has wrong status: {job['status']}"
|
||||
|
||||
def test_api_wait_for_job_completion(self):
|
||||
"""Test waiting for job to complete and retrieving result"""
|
||||
if not hasattr(self, 'test_job_id'):
|
||||
logger.warning("Skipping - no test_job_id from previous test")
|
||||
print_info(" Skipping (no test_job_id from previous test)")
|
||||
return
|
||||
|
||||
logger.info(f"Waiting for job {self.test_job_id} to complete (max 60s)...")
|
||||
print_info(" Waiting for job to complete (max 60s)...")
|
||||
|
||||
max_wait = 60
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}")
|
||||
data = resp.json()
|
||||
status = data["status"]
|
||||
elapsed = int(time.time() - start_time)
|
||||
|
||||
logger.info(f"Job status: {status} (elapsed: {elapsed}s)")
|
||||
print_info(f" Status: {status} (elapsed: {elapsed}s)")
|
||||
|
||||
if status == "completed":
|
||||
logger.info("Job completed successfully!")
|
||||
print_success(" Job completed!")
|
||||
|
||||
# Now get the result
|
||||
logger.info("Fetching job result...")
|
||||
result_resp = requests.get(f"{self.api_url}/jobs/{self.test_job_id}/result")
|
||||
logger.info(f"Result response status: {result_resp.status_code}")
|
||||
assert result_resp.status_code == 200, f"Expected 200, got {result_resp.status_code}"
|
||||
|
||||
result_data = result_resp.json()
|
||||
logger.info(f"Result data keys: {result_data.keys()}")
|
||||
assert "result" in result_data, "Should have result field"
|
||||
assert len(result_data["result"]) > 0, "Result should not be empty"
|
||||
|
||||
actual_text = result_data["result"].strip()
|
||||
logger.info(f"Transcription result: '{actual_text}'")
|
||||
print_info(f" Transcription: '{actual_text}'")
|
||||
return
|
||||
|
||||
elif status == "failed":
|
||||
error_msg = f"Job failed: {data.get('error', 'unknown error')}"
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
error_msg = f"Job did not complete within {max_wait}s"
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
# ========================================================================
|
||||
# MCP Server Tests (Import-based)
|
||||
# ========================================================================
|
||||
|
||||
def test_mcp_imports(self):
|
||||
"""Test MCP server modules can be imported"""
|
||||
try:
|
||||
logger.info("Importing MCP server module...")
|
||||
from servers import whisper_server
|
||||
|
||||
logger.info("Checking for new async tools...")
|
||||
assert hasattr(whisper_server, 'transcribe_async'), "Should have transcribe_async tool"
|
||||
assert hasattr(whisper_server, 'get_job_status'), "Should have get_job_status tool"
|
||||
assert hasattr(whisper_server, 'get_job_result'), "Should have get_job_result tool"
|
||||
assert hasattr(whisper_server, 'list_transcription_jobs'), "Should have list_transcription_jobs tool"
|
||||
assert hasattr(whisper_server, 'check_gpu_health'), "Should have check_gpu_health tool"
|
||||
assert hasattr(whisper_server, 'get_model_info_api'), "Should have get_model_info_api tool"
|
||||
logger.info("All new tools found!")
|
||||
|
||||
# Verify old tools are removed
|
||||
logger.info("Verifying old tools are removed...")
|
||||
assert not hasattr(whisper_server, 'transcribe'), "Old transcribe tool should be removed"
|
||||
assert not hasattr(whisper_server, 'batch_transcribe_audio'), "Old batch_transcribe_audio tool should be removed"
|
||||
logger.info("Old tools successfully removed!")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import MCP server: {e}")
|
||||
raise AssertionError(f"Failed to import MCP server: {e}")
|
||||
|
||||
def test_job_queue_integration(self):
|
||||
"""Test JobQueue integration is working"""
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
|
||||
# Create a test queue
|
||||
test_queue = JobQueue(max_queue_size=5, metadata_dir="/tmp/test_job_queue")
|
||||
|
||||
try:
|
||||
# Verify it can be started
|
||||
test_queue.start()
|
||||
assert test_queue._worker_thread is not None, "Worker thread should be created"
|
||||
assert test_queue._worker_thread.is_alive(), "Worker thread should be running"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
test_queue.stop(wait_for_current=False)
|
||||
|
||||
def test_health_monitor_integration(self):
|
||||
"""Test HealthMonitor integration is working"""
|
||||
from core.gpu_health import HealthMonitor
|
||||
|
||||
# Create a test monitor
|
||||
test_monitor = HealthMonitor(check_interval_minutes=60) # Long interval
|
||||
|
||||
try:
|
||||
# Verify it can be started
|
||||
test_monitor.start()
|
||||
assert test_monitor._thread is not None, "Monitor thread should be created"
|
||||
assert test_monitor._thread.is_alive(), "Monitor thread should be running"
|
||||
|
||||
# Check we can get status
|
||||
status = test_monitor.get_latest_status()
|
||||
assert status is not None, "Should have initial status"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
test_monitor.stop()
|
||||
|
||||
# ========================================================================
|
||||
# Helper Methods
|
||||
# ========================================================================
|
||||
|
||||
def _create_test_audio_file(self):
|
||||
"""Get the path to the test audio file"""
|
||||
# Use relative path from project root
|
||||
project_root = Path(__file__).parent.parent
|
||||
test_audio_path = str(project_root / "data" / "test.mp3")
|
||||
if not os.path.exists(test_audio_path):
|
||||
raise FileNotFoundError(f"Test audio file not found: {test_audio_path}")
|
||||
return test_audio_path
|
||||
|
||||
|
||||
def main():
|
||||
print_section("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
|
||||
logger.info("=" * 70)
|
||||
logger.info("PHASE 2: ASYNC JOB QUEUE INTEGRATION TESTS")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Check if API server is running
|
||||
api_url = os.getenv("API_URL", "http://localhost:8000")
|
||||
logger.info(f"Testing API server at: {api_url}")
|
||||
print_info(f"Testing API server at: {api_url}")
|
||||
|
||||
try:
|
||||
logger.info("Checking API server health...")
|
||||
resp = requests.get(f"{api_url}/health", timeout=2)
|
||||
logger.info(f"Health check status: {resp.status_code}")
|
||||
if resp.status_code != 200:
|
||||
logger.error(f"API server not responding correctly at {api_url}")
|
||||
print_error(f"API server not responding correctly at {api_url}")
|
||||
print_error("Please start the API server with: ./run_api_server.sh")
|
||||
return 1
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Cannot connect to API server: {e}")
|
||||
print_error(f"Cannot connect to API server at {api_url}")
|
||||
print_error("Please start the API server with: ./run_api_server.sh")
|
||||
return 1
|
||||
|
||||
logger.info(f"API server is running at {api_url}")
|
||||
print_success(f"API server is running at {api_url}")
|
||||
|
||||
# Create tester
|
||||
tester = Phase2Tester(api_url=api_url)
|
||||
|
||||
# ========================================================================
|
||||
# Run API Tests
|
||||
# ========================================================================
|
||||
print_section("API SERVER TESTS")
|
||||
logger.info("Starting API server tests...")
|
||||
|
||||
tester.test("API Root Endpoint", tester.test_api_root_endpoint)
|
||||
tester.test("API Health Endpoint", tester.test_api_health_endpoint)
|
||||
tester.test("API Models Endpoint", tester.test_api_models_endpoint)
|
||||
tester.test("API GPU Health Endpoint", tester.test_api_gpu_health_endpoint)
|
||||
|
||||
print_section("API JOB SUBMISSION TESTS")
|
||||
|
||||
tester.test("Submit Job - Invalid Audio (400)", tester.test_api_submit_job_invalid_audio)
|
||||
tester.test("Submit Job - CPU Device Rejected (400)", tester.test_api_submit_job_cpu_device_rejected)
|
||||
tester.test("Submit Job - Success (200)", tester.test_api_submit_job_success)
|
||||
|
||||
print_section("API JOB STATUS TESTS")
|
||||
|
||||
tester.test("Get Job Status - Success", tester.test_api_get_job_status)
|
||||
tester.test("Get Job Status - Not Found (404)", tester.test_api_get_job_status_not_found)
|
||||
tester.test("Get Job Result - Not Completed (409)", tester.test_api_get_job_result_not_completed)
|
||||
|
||||
print_section("API JOB LISTING TESTS")
|
||||
|
||||
tester.test("List Jobs", tester.test_api_list_jobs)
|
||||
tester.test("List Jobs - With Filter", tester.test_api_list_jobs_with_filter)
|
||||
|
||||
print_section("API JOB COMPLETION TEST")
|
||||
|
||||
tester.test("Wait for Job Completion & Get Result", tester.test_api_wait_for_job_completion)
|
||||
|
||||
# ========================================================================
|
||||
# Run MCP Tests
|
||||
# ========================================================================
|
||||
print_section("MCP SERVER TESTS")
|
||||
logger.info("Starting MCP server tests...")
|
||||
|
||||
tester.test("MCP Module Imports", tester.test_mcp_imports)
|
||||
tester.test("JobQueue Integration", tester.test_job_queue_integration)
|
||||
tester.test("HealthMonitor Integration", tester.test_health_monitor_integration)
|
||||
|
||||
# ========================================================================
|
||||
# Print Summary
|
||||
# ========================================================================
|
||||
logger.info("All tests completed, generating summary...")
|
||||
success = tester.print_summary()
|
||||
|
||||
if success:
|
||||
logger.info("ALL TESTS PASSED!")
|
||||
print_section("ALL TESTS PASSED! ✓")
|
||||
return 0
|
||||
else:
|
||||
logger.error("SOME TESTS FAILED!")
|
||||
print_section("SOME TESTS FAILED! ✗")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
287
tests/test_core_components.py
Executable file
287
tests/test_core_components.py
Executable file
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Test script for Phase 1 components.
|
||||
|
||||
Tests:
|
||||
1. Test audio file validation
|
||||
2. GPU health check
|
||||
3. Job queue operations
|
||||
|
||||
IMPORTANT: This service requires GPU. Tests will fail if GPU is not available.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
|
||||
from core.gpu_health import check_gpu_health, HealthMonitor
|
||||
from core.job_queue import JobQueue, JobStatus
|
||||
|
||||
|
||||
def check_gpu_available():
|
||||
"""
|
||||
Check if GPU is available. Exit if not.
|
||||
This service requires GPU and will not run on CPU.
|
||||
"""
|
||||
print("\n" + "="*60)
|
||||
print("GPU REQUIREMENT CHECK")
|
||||
print("="*60)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("✗ CUDA not available - GPU is required for this service")
|
||||
print(" This service is configured for GPU-only operation")
|
||||
print(" Please ensure CUDA is properly installed and GPU is accessible")
|
||||
print("="*60)
|
||||
sys.exit(1)
|
||||
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
||||
print(f"✓ GPU available: {gpu_name}")
|
||||
print(f"✓ GPU memory: {gpu_memory:.2f} GB")
|
||||
print("="*60)
|
||||
|
||||
|
||||
def test_audio_file():
|
||||
"""Test audio file existence and validity."""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 1: Test Audio File")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Use the actual test audio file (relative to project root)
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..')
|
||||
audio_path = os.path.join(project_root, "data/test.mp3")
|
||||
|
||||
# Verify file exists
|
||||
assert os.path.exists(audio_path), "Audio file not found"
|
||||
print(f"✓ Test audio file exists: {audio_path}")
|
||||
|
||||
# Verify file is not empty
|
||||
file_size = os.path.getsize(audio_path)
|
||||
assert file_size > 0, "Audio file is empty"
|
||||
print(f"✓ Audio file size: {file_size} bytes")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Audio file test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_gpu_health():
|
||||
"""Test GPU health check."""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 2: GPU Health Check")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Test with cuda device (enforcing GPU requirement)
|
||||
print("\nRunning health check with device='cuda'...")
|
||||
logging.info("Starting GPU health check...")
|
||||
status = check_gpu_health(expected_device="cuda")
|
||||
logging.info("GPU health check completed")
|
||||
|
||||
print(f"✓ Health check completed")
|
||||
print(f" - GPU available: {status.gpu_available}")
|
||||
print(f" - GPU working: {status.gpu_working}")
|
||||
print(f" - Device used: {status.device_used}")
|
||||
print(f" - Device name: {status.device_name}")
|
||||
print(f" - Memory total: {status.memory_total_gb:.2f} GB")
|
||||
print(f" - Memory available: {status.memory_available_gb:.2f} GB")
|
||||
print(f" - Test duration: {status.test_duration_seconds:.2f}s")
|
||||
print(f" - Error: {status.error}")
|
||||
|
||||
# Test health monitor
|
||||
print("\nTesting HealthMonitor...")
|
||||
monitor = HealthMonitor(check_interval_minutes=1)
|
||||
monitor.start()
|
||||
print("✓ Health monitor started")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
latest = monitor.get_latest_status()
|
||||
assert latest is not None, "No status available from monitor"
|
||||
print(f"✓ Latest status retrieved: {latest.device_used}")
|
||||
|
||||
monitor.stop()
|
||||
print("✓ Health monitor stopped")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ GPU health test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_job_queue():
|
||||
"""Test job queue operations."""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 3: Job Queue")
|
||||
print("="*60)
|
||||
|
||||
# Create temp directory for testing
|
||||
import tempfile
|
||||
temp_dir = tempfile.mkdtemp(prefix="job_queue_test_")
|
||||
print(f"Using temp directory: {temp_dir}")
|
||||
|
||||
try:
|
||||
# Initialize job queue
|
||||
print("\nInitializing job queue...")
|
||||
job_queue = JobQueue(max_queue_size=10, metadata_dir=temp_dir)
|
||||
job_queue.start()
|
||||
print("✓ Job queue started")
|
||||
|
||||
# Use the actual test audio file (relative to project root)
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..')
|
||||
audio_path = os.path.join(project_root, "data/test.mp3")
|
||||
|
||||
# Test job submission
|
||||
print("\nSubmitting test job...")
|
||||
logging.info("Submitting transcription job to queue...")
|
||||
job_info = job_queue.submit_job(
|
||||
audio_path=audio_path,
|
||||
model_name="tiny",
|
||||
device="cuda", # Enforcing GPU requirement
|
||||
output_format="txt"
|
||||
)
|
||||
job_id = job_info["job_id"]
|
||||
logging.info(f"Job submitted: {job_id}")
|
||||
print(f"✓ Job submitted: {job_id}")
|
||||
print(f" - Status: {job_info['status']}")
|
||||
print(f" - Queue position: {job_info['queue_position']}")
|
||||
|
||||
# Test job status retrieval
|
||||
print("\nRetrieving job status...")
|
||||
logging.info("About to call get_job_status()...")
|
||||
status = job_queue.get_job_status(job_id)
|
||||
logging.info(f"get_job_status() returned: {status['status']}")
|
||||
print(f"✓ Job status retrieved")
|
||||
print(f" - Status: {status['status']}")
|
||||
print(f" - Queue position: {status['queue_position']}")
|
||||
|
||||
# Wait for job to process
|
||||
print("\nWaiting for job to process (max 30 seconds)...", flush=True)
|
||||
logging.info("Waiting for transcription to complete...")
|
||||
max_wait = 30
|
||||
start = time.time()
|
||||
while time.time() - start < max_wait:
|
||||
logging.info("Calling get_job_status()...")
|
||||
status = job_queue.get_job_status(job_id)
|
||||
print(f" Status: {status['status']}", flush=True)
|
||||
logging.info(f"Job status: {status['status']}")
|
||||
|
||||
if status['status'] in ['completed', 'failed']:
|
||||
logging.info("Job completed or failed, breaking out of loop")
|
||||
break
|
||||
|
||||
logging.info("Job still running, sleeping 2 seconds...")
|
||||
time.sleep(2)
|
||||
|
||||
final_status = job_queue.get_job_status(job_id)
|
||||
print(f"\nFinal job status: {final_status['status']}")
|
||||
|
||||
if final_status['status'] == 'completed':
|
||||
print(f"✓ Job completed successfully")
|
||||
print(f" - Result path: {final_status['result_path']}")
|
||||
print(f" - Processing time: {final_status['processing_time_seconds']:.2f}s")
|
||||
|
||||
# Test result retrieval
|
||||
print("\nRetrieving job result...")
|
||||
logging.info("Calling get_job_result()...")
|
||||
result = job_queue.get_job_result(job_id)
|
||||
logging.info(f"Result retrieved: {len(result)} characters")
|
||||
print(f"✓ Result retrieved ({len(result)} characters)")
|
||||
print(f" Preview: {result[:100]}...")
|
||||
|
||||
elif final_status['status'] == 'failed':
|
||||
print(f"✗ Job failed: {final_status['error']}")
|
||||
|
||||
# Test persistence by stopping and restarting
|
||||
print("\nTesting persistence...")
|
||||
logging.info("Stopping job queue...")
|
||||
job_queue.stop(wait_for_current=False)
|
||||
print("✓ Job queue stopped")
|
||||
logging.info("Job queue stopped")
|
||||
|
||||
logging.info("Restarting job queue...")
|
||||
job_queue2 = JobQueue(max_queue_size=10, metadata_dir=temp_dir)
|
||||
job_queue2.start()
|
||||
print("✓ Job queue restarted")
|
||||
logging.info("Job queue restarted")
|
||||
|
||||
logging.info("Checking job status after restart...")
|
||||
status_after_restart = job_queue2.get_job_status(job_id)
|
||||
print(f"✓ Job still exists after restart: {status_after_restart['status']}")
|
||||
logging.info(f"Job status after restart: {status_after_restart['status']}")
|
||||
|
||||
logging.info("Stopping job queue 2...")
|
||||
job_queue2.stop()
|
||||
logging.info("Job queue 2 stopped")
|
||||
|
||||
# Cleanup
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f"✓ Cleaned up temp directory")
|
||||
|
||||
return final_status['status'] == 'completed'
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Job queue test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
print("\n" + "="*60)
|
||||
print("PHASE 1 COMPONENT TESTS")
|
||||
print("="*60)
|
||||
|
||||
# Check GPU availability first - exit if no GPU
|
||||
check_gpu_available()
|
||||
|
||||
results = {
|
||||
"Test Audio File": test_audio_file(),
|
||||
"GPU Health Check": test_gpu_health(),
|
||||
"Job Queue": test_job_queue(),
|
||||
}
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("TEST SUMMARY")
|
||||
print("="*60)
|
||||
|
||||
for test_name, passed in results.items():
|
||||
status = "✓ PASSED" if passed else "✗ FAILED"
|
||||
print(f"{test_name:.<40} {status}")
|
||||
|
||||
all_passed = all(results.values())
|
||||
print("\n" + "="*60)
|
||||
if all_passed:
|
||||
print("ALL TESTS PASSED ✓")
|
||||
else:
|
||||
print("SOME TESTS FAILED ✗")
|
||||
print("="*60)
|
||||
|
||||
return 0 if all_passed else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
523
tests/test_e2e_integration.py
Executable file
523
tests/test_e2e_integration.py
Executable file
@@ -0,0 +1,523 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Phase 4: End-to-End Integration Testing
|
||||
|
||||
Comprehensive integration tests for the async job queue system.
|
||||
Tests all scenarios from the DEV_PLAN.md Phase 4 checklist.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
import subprocess
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add src to path (go up one level from tests/ to root)
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
# Color codes for terminal output
|
||||
class Colors:
|
||||
GREEN = '\033[92m'
|
||||
RED = '\033[91m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
CYAN = '\033[96m'
|
||||
END = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
|
||||
def print_success(msg):
|
||||
print(f"{Colors.GREEN}✓ {msg}{Colors.END}")
|
||||
|
||||
def print_error(msg):
|
||||
print(f"{Colors.RED}✗ {msg}{Colors.END}")
|
||||
|
||||
def print_info(msg):
|
||||
print(f"{Colors.BLUE}ℹ {msg}{Colors.END}")
|
||||
|
||||
def print_warning(msg):
|
||||
print(f"{Colors.YELLOW}⚠ {msg}{Colors.END}")
|
||||
|
||||
def print_section(msg):
|
||||
print(f"\n{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}")
|
||||
print(f"{Colors.BOLD}{Colors.YELLOW}{msg}{Colors.END}")
|
||||
print(f"{Colors.BOLD}{Colors.YELLOW}{'='*70}{Colors.END}\n")
|
||||
|
||||
|
||||
class Phase4Tester:
|
||||
def __init__(self, api_url="http://localhost:8000", test_audio=None):
|
||||
self.api_url = api_url
|
||||
# Use relative path from project root if not provided
|
||||
if test_audio is None:
|
||||
project_root = Path(__file__).parent.parent
|
||||
test_audio = str(project_root / "data" / "test.mp3")
|
||||
self.test_audio = test_audio
|
||||
self.test_results = []
|
||||
self.server_process = None
|
||||
|
||||
# Verify test audio exists
|
||||
if not os.path.exists(test_audio):
|
||||
raise FileNotFoundError(f"Test audio file not found: {test_audio}")
|
||||
|
||||
def test(self, name, func):
|
||||
"""Run a test and record result"""
|
||||
try:
|
||||
logger.info(f"Testing: {name}")
|
||||
print_info(f"Testing: {name}")
|
||||
func()
|
||||
logger.info(f"PASSED: {name}")
|
||||
print_success(f"PASSED: {name}")
|
||||
self.test_results.append((name, True, None))
|
||||
return True
|
||||
except AssertionError as e:
|
||||
logger.error(f"FAILED: {name} - {str(e)}")
|
||||
print_error(f"FAILED: {name}")
|
||||
print_error(f" Reason: {str(e)}")
|
||||
self.test_results.append((name, False, str(e)))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR: {name} - {str(e)}")
|
||||
print_error(f"ERROR: {name}")
|
||||
print_error(f" Exception: {str(e)}")
|
||||
self.test_results.append((name, False, f"Exception: {str(e)}"))
|
||||
return False
|
||||
|
||||
def start_api_server(self, wait_time=5):
|
||||
"""Start the API server in background"""
|
||||
print_info("Starting API server...")
|
||||
# Script is in project root, one level up from tests/
|
||||
script_path = Path(__file__).parent.parent / "run_api_server.sh"
|
||||
|
||||
# Start server in background
|
||||
self.server_process = subprocess.Popen(
|
||||
[str(script_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
preexec_fn=os.setsid
|
||||
)
|
||||
|
||||
# Wait for server to start
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Verify server is running
|
||||
try:
|
||||
response = requests.get(f"{self.api_url}/health", timeout=5)
|
||||
if response.status_code == 200:
|
||||
print_success("API server started successfully")
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
print_error("API server failed to start")
|
||||
return False
|
||||
|
||||
def stop_api_server(self):
|
||||
"""Stop the API server"""
|
||||
if self.server_process:
|
||||
print_info("Stopping API server...")
|
||||
os.killpg(os.getpgid(self.server_process.pid), signal.SIGTERM)
|
||||
self.server_process.wait(timeout=10)
|
||||
print_success("API server stopped")
|
||||
|
||||
def wait_for_job_completion(self, job_id, timeout=60, poll_interval=2):
|
||||
"""Poll job status until completed or failed"""
|
||||
start_time = time.time()
|
||||
last_status = None
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(f"{self.api_url}/jobs/{job_id}")
|
||||
assert response.status_code == 200, f"Failed to get job status: {response.status_code}"
|
||||
|
||||
status_data = response.json()
|
||||
current_status = status_data['status']
|
||||
|
||||
# Print status changes
|
||||
if current_status != last_status:
|
||||
if status_data.get('queue_position') is not None:
|
||||
print_info(f" Job status: {current_status}, queue position: {status_data['queue_position']}")
|
||||
else:
|
||||
print_info(f" Job status: {current_status}")
|
||||
last_status = current_status
|
||||
|
||||
if current_status == "completed":
|
||||
return status_data
|
||||
elif current_status == "failed":
|
||||
raise AssertionError(f"Job failed: {status_data.get('error', 'Unknown error')}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise AssertionError(f"Request failed: {e}")
|
||||
|
||||
raise AssertionError(f"Job did not complete within {timeout} seconds")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 1: Single Job Submission and Completion
|
||||
# ========================================================================
|
||||
def test_single_job_flow(self):
|
||||
"""Test complete job flow: submit → poll → get result"""
|
||||
# Submit job
|
||||
print_info(" Submitting job...")
|
||||
response = requests.post(f"{self.api_url}/jobs", json={
|
||||
"audio_path": self.test_audio,
|
||||
"model_name": "large-v3",
|
||||
"output_format": "txt"
|
||||
})
|
||||
assert response.status_code == 200, f"Job submission failed: {response.status_code}"
|
||||
|
||||
job_data = response.json()
|
||||
assert 'job_id' in job_data, "No job_id in response"
|
||||
# Status can be 'queued' or 'running' (if queue is empty and job starts immediately)
|
||||
assert job_data['status'] in ['queued', 'running'], f"Expected status 'queued' or 'running', got '{job_data['status']}'"
|
||||
|
||||
job_id = job_data['job_id']
|
||||
print_success(f" Job submitted: {job_id}")
|
||||
|
||||
# Wait for completion
|
||||
print_info(" Waiting for job completion...")
|
||||
final_status = self.wait_for_job_completion(job_id)
|
||||
|
||||
assert final_status['status'] == 'completed', "Job did not complete"
|
||||
assert final_status['result_path'] is not None, "No result_path in completed job"
|
||||
assert final_status['processing_time_seconds'] is not None, "No processing time"
|
||||
print_success(f" Job completed in {final_status['processing_time_seconds']:.2f}s")
|
||||
|
||||
# Get result
|
||||
print_info(" Retrieving result...")
|
||||
response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
|
||||
assert response.status_code == 200, f"Failed to get result: {response.status_code}"
|
||||
|
||||
result_text = response.text
|
||||
assert len(result_text) > 0, "Empty result"
|
||||
print_success(f" Got result: {len(result_text)} characters")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 2: Multiple Jobs in Queue (FIFO)
|
||||
# ========================================================================
|
||||
def test_multiple_jobs_fifo(self):
|
||||
"""Test multiple jobs are processed in FIFO order"""
|
||||
job_ids = []
|
||||
|
||||
# Submit 3 jobs
|
||||
print_info(" Submitting 3 jobs...")
|
||||
for i in range(3):
|
||||
response = requests.post(f"{self.api_url}/jobs", json={
|
||||
"audio_path": self.test_audio,
|
||||
"model_name": "tiny", # Use tiny model for faster processing
|
||||
"output_format": "txt"
|
||||
})
|
||||
assert response.status_code == 200, f"Job {i+1} submission failed"
|
||||
|
||||
job_data = response.json()
|
||||
job_ids.append(job_data['job_id'])
|
||||
print_info(f" Job {i+1} submitted: {job_data['job_id']}, queue_position: {job_data.get('queue_position', 0)}")
|
||||
|
||||
# Wait for all jobs to complete
|
||||
print_info(" Waiting for all jobs to complete...")
|
||||
for i, job_id in enumerate(job_ids):
|
||||
print_info(f" Waiting for job {i+1}/{len(job_ids)}...")
|
||||
final_status = self.wait_for_job_completion(job_id, timeout=120)
|
||||
assert final_status['status'] == 'completed', f"Job {i+1} failed"
|
||||
|
||||
print_success(f" All {len(job_ids)} jobs completed in FIFO order")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 3: GPU Health Check
|
||||
# ========================================================================
|
||||
def test_gpu_health_check(self):
|
||||
"""Test GPU health check endpoint"""
|
||||
print_info(" Checking GPU health...")
|
||||
response = requests.get(f"{self.api_url}/health/gpu")
|
||||
assert response.status_code == 200, f"GPU health check failed: {response.status_code}"
|
||||
|
||||
health_data = response.json()
|
||||
assert 'gpu_available' in health_data, "Missing gpu_available field"
|
||||
assert 'gpu_working' in health_data, "Missing gpu_working field"
|
||||
assert 'device_used' in health_data, "Missing device_used field"
|
||||
|
||||
print_info(f" GPU Available: {health_data['gpu_available']}")
|
||||
print_info(f" GPU Working: {health_data['gpu_working']}")
|
||||
print_info(f" Device: {health_data['device_used']}")
|
||||
|
||||
if health_data['gpu_available']:
|
||||
assert health_data['device_name'], "GPU available but no device_name"
|
||||
assert health_data['test_duration_seconds'] < 3, "GPU test took too long (might be using CPU)"
|
||||
print_success(f" GPU is healthy: {health_data['device_name']}")
|
||||
else:
|
||||
print_warning(" GPU not available on this system")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 4: Invalid Audio Path
|
||||
# ========================================================================
|
||||
def test_invalid_audio_path(self):
|
||||
"""Test job submission with invalid audio path"""
|
||||
print_info(" Submitting job with invalid path...")
|
||||
response = requests.post(f"{self.api_url}/jobs", json={
|
||||
"audio_path": "/invalid/path/does/not/exist.mp3",
|
||||
"model_name": "large-v3"
|
||||
})
|
||||
|
||||
# Should return 400 Bad Request
|
||||
assert response.status_code == 400, f"Expected 400, got {response.status_code}"
|
||||
|
||||
error_data = response.json()
|
||||
assert 'detail' in error_data or 'error' in error_data, "No error message in response"
|
||||
print_success(" Invalid path rejected correctly")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 5: Job Not Found
|
||||
# ========================================================================
|
||||
def test_job_not_found(self):
|
||||
"""Test retrieving non-existent job"""
|
||||
print_info(" Requesting non-existent job...")
|
||||
fake_job_id = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
response = requests.get(f"{self.api_url}/jobs/{fake_job_id}")
|
||||
assert response.status_code == 404, f"Expected 404, got {response.status_code}"
|
||||
print_success(" Non-existent job handled correctly")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 6: Result Before Completion
|
||||
# ========================================================================
|
||||
def test_result_before_completion(self):
|
||||
"""Test getting result for job that hasn't completed"""
|
||||
print_info(" Submitting job and trying to get result immediately...")
|
||||
|
||||
# Submit job
|
||||
response = requests.post(f"{self.api_url}/jobs", json={
|
||||
"audio_path": self.test_audio,
|
||||
"model_name": "large-v3"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
job_id = response.json()['job_id']
|
||||
|
||||
# Try to get result immediately (job is still queued/running)
|
||||
time.sleep(0.5)
|
||||
response = requests.get(f"{self.api_url}/jobs/{job_id}/result")
|
||||
|
||||
# Should return 409 Conflict or similar
|
||||
assert response.status_code in [409, 400, 404], f"Expected 4xx error, got {response.status_code}"
|
||||
print_success(" Result request before completion handled correctly")
|
||||
|
||||
# Clean up: wait for job to complete
|
||||
self.wait_for_job_completion(job_id)
|
||||
|
||||
# ========================================================================
|
||||
# TEST 7: List Jobs
|
||||
# ========================================================================
|
||||
def test_list_jobs(self):
|
||||
"""Test listing jobs with filters"""
|
||||
print_info(" Testing job listing...")
|
||||
|
||||
# List all jobs
|
||||
response = requests.get(f"{self.api_url}/jobs")
|
||||
assert response.status_code == 200, f"List jobs failed: {response.status_code}"
|
||||
|
||||
jobs_data = response.json()
|
||||
assert 'jobs' in jobs_data, "No jobs array in response"
|
||||
assert isinstance(jobs_data['jobs'], list), "Jobs is not a list"
|
||||
print_info(f" Found {len(jobs_data['jobs'])} jobs")
|
||||
|
||||
# List only completed jobs
|
||||
response = requests.get(f"{self.api_url}/jobs?status=completed")
|
||||
assert response.status_code == 200
|
||||
completed_jobs = response.json()['jobs']
|
||||
print_info(f" Found {len(completed_jobs)} completed jobs")
|
||||
|
||||
# List with limit
|
||||
response = requests.get(f"{self.api_url}/jobs?limit=5")
|
||||
assert response.status_code == 200
|
||||
limited_jobs = response.json()['jobs']
|
||||
assert len(limited_jobs) <= 5, "Limit not respected"
|
||||
print_success(" Job listing works correctly")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 8: Server Restart with Job Persistence
|
||||
# ========================================================================
|
||||
def test_server_restart_persistence(self):
|
||||
"""Test that jobs persist across server restarts"""
|
||||
print_info(" Testing job persistence across restart...")
|
||||
|
||||
# Submit a job
|
||||
response = requests.post(f"{self.api_url}/jobs", json={
|
||||
"audio_path": self.test_audio,
|
||||
"model_name": "tiny"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
job_id = response.json()['job_id']
|
||||
print_info(f" Submitted job: {job_id}")
|
||||
|
||||
# Get job count before restart
|
||||
response = requests.get(f"{self.api_url}/jobs")
|
||||
jobs_before = len(response.json()['jobs'])
|
||||
print_info(f" Jobs before restart: {jobs_before}")
|
||||
|
||||
# Restart server
|
||||
print_info(" Restarting server...")
|
||||
self.stop_api_server()
|
||||
time.sleep(2)
|
||||
assert self.start_api_server(wait_time=8), "Server failed to restart"
|
||||
|
||||
# Check jobs after restart
|
||||
response = requests.get(f"{self.api_url}/jobs")
|
||||
assert response.status_code == 200
|
||||
jobs_after = len(response.json()['jobs'])
|
||||
print_info(f" Jobs after restart: {jobs_after}")
|
||||
|
||||
# Check our specific job is still there (this is the key test)
|
||||
response = requests.get(f"{self.api_url}/jobs/{job_id}")
|
||||
assert response.status_code == 200, "Job not found after restart"
|
||||
|
||||
# Note: Total count may differ due to job retention/cleanup, but persistence works if we can find the job
|
||||
if jobs_after < jobs_before:
|
||||
print_warning(f" Job count decreased ({jobs_before} -> {jobs_after}), may be due to cleanup")
|
||||
|
||||
print_success(" Jobs persisted correctly across restart")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 9: Health Endpoint
|
||||
# ========================================================================
|
||||
def test_health_endpoint(self):
|
||||
"""Test basic health endpoint"""
|
||||
print_info(" Checking health endpoint...")
|
||||
response = requests.get(f"{self.api_url}/health")
|
||||
assert response.status_code == 200, f"Health check failed: {response.status_code}"
|
||||
|
||||
health_data = response.json()
|
||||
assert health_data['status'] == 'healthy', "Server not healthy"
|
||||
print_success(" Health endpoint OK")
|
||||
|
||||
# ========================================================================
|
||||
# TEST 10: Models Endpoint
|
||||
# ========================================================================
|
||||
def test_models_endpoint(self):
|
||||
"""Test models information endpoint"""
|
||||
print_info(" Checking models endpoint...")
|
||||
response = requests.get(f"{self.api_url}/models")
|
||||
assert response.status_code == 200, f"Models endpoint failed: {response.status_code}"
|
||||
|
||||
models_data = response.json()
|
||||
assert 'available_models' in models_data, "No available_models field"
|
||||
assert 'available_devices' in models_data, "No available_devices field"
|
||||
assert len(models_data['available_models']) > 0, "No models listed"
|
||||
print_info(f" Available models: {len(models_data['available_models'])}")
|
||||
print_success(" Models endpoint OK")
|
||||
|
||||
def print_summary(self):
|
||||
"""Print test summary"""
|
||||
print_section("TEST SUMMARY")
|
||||
|
||||
passed = sum(1 for _, result, _ in self.test_results if result)
|
||||
failed = len(self.test_results) - passed
|
||||
|
||||
for name, result, error in self.test_results:
|
||||
if result:
|
||||
print_success(f"{name}")
|
||||
else:
|
||||
print_error(f"{name}")
|
||||
if error:
|
||||
print(f" {error}")
|
||||
|
||||
print(f"\n{Colors.BOLD}Total: {len(self.test_results)} | ", end="")
|
||||
print(f"{Colors.GREEN}Passed: {passed}{Colors.END} | ", end="")
|
||||
print(f"{Colors.RED}Failed: {failed}{Colors.END}\n")
|
||||
|
||||
return failed == 0
|
||||
|
||||
def run_all_tests(self, start_server=True):
|
||||
"""Run all Phase 4 integration tests"""
|
||||
print_section("PHASE 4: END-TO-END INTEGRATION TESTING")
|
||||
|
||||
try:
|
||||
# Start server if requested
|
||||
if start_server:
|
||||
if not self.start_api_server():
|
||||
print_error("Failed to start API server. Aborting tests.")
|
||||
return False
|
||||
else:
|
||||
# Verify server is already running
|
||||
try:
|
||||
response = requests.get(f"{self.api_url}/health", timeout=5)
|
||||
if response.status_code != 200:
|
||||
print_error("Server is not responding. Please start it first.")
|
||||
return False
|
||||
print_info("Using existing API server")
|
||||
except:
|
||||
print_error("Cannot connect to API server. Please start it first.")
|
||||
return False
|
||||
|
||||
# Run tests
|
||||
print_section("TEST 1: Single Job Submission and Completion")
|
||||
self.test("Single job flow (submit → poll → get result)", self.test_single_job_flow)
|
||||
|
||||
print_section("TEST 2: Multiple Jobs (FIFO Order)")
|
||||
self.test("Multiple jobs in queue (FIFO)", self.test_multiple_jobs_fifo)
|
||||
|
||||
print_section("TEST 3: GPU Health Check")
|
||||
self.test("GPU health check endpoint", self.test_gpu_health_check)
|
||||
|
||||
print_section("TEST 4: Error Handling - Invalid Path")
|
||||
self.test("Invalid audio path rejection", self.test_invalid_audio_path)
|
||||
|
||||
print_section("TEST 5: Error Handling - Job Not Found")
|
||||
self.test("Non-existent job handling", self.test_job_not_found)
|
||||
|
||||
print_section("TEST 6: Error Handling - Result Before Completion")
|
||||
self.test("Result request before completion", self.test_result_before_completion)
|
||||
|
||||
print_section("TEST 7: Job Listing")
|
||||
self.test("List jobs with filters", self.test_list_jobs)
|
||||
|
||||
print_section("TEST 8: Health Endpoint")
|
||||
self.test("Basic health endpoint", self.test_health_endpoint)
|
||||
|
||||
print_section("TEST 9: Models Endpoint")
|
||||
self.test("Models information endpoint", self.test_models_endpoint)
|
||||
|
||||
print_section("TEST 10: Server Restart Persistence")
|
||||
self.test("Job persistence across server restart", self.test_server_restart_persistence)
|
||||
|
||||
# Print summary
|
||||
success = self.print_summary()
|
||||
|
||||
return success
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if start_server and self.server_process:
|
||||
self.stop_api_server()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main test runner"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Phase 4 Integration Tests')
|
||||
parser.add_argument('--url', default='http://localhost:8000', help='API server URL')
|
||||
# Default to None so Phase4Tester uses relative path
|
||||
parser.add_argument('--audio', default=None,
|
||||
help='Path to test audio file (default: <project_root>/data/test.mp3)')
|
||||
parser.add_argument('--no-start-server', action='store_true',
|
||||
help='Do not start server (assume it is already running)')
|
||||
args = parser.parse_args()
|
||||
|
||||
tester = Phase4Tester(api_url=args.url, test_audio=args.audio)
|
||||
success = tester.run_all_tests(start_server=not args.no_start_server)
|
||||
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
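For orientation, this is the client-side flow the integration tests above exercise, reduced to a minimal sketch. The endpoint paths and the job_id field come from the tests; the per-job "status" field name and the "completed"/"failed" values are assumptions inferred from the status filter used above, and the audio path is a placeholder.

# Minimal sketch of the submit -> poll -> result flow (assumptions noted above).
import time
import requests

API = "http://localhost:8000"

# Submit a transcription job (placeholder audio path).
resp = requests.post(f"{API}/jobs", json={"audio_path": "/path/to/audio.mp3", "model_name": "tiny"})
job_id = resp.json()["job_id"]

# Poll the job until it reaches a terminal state (field/value names assumed).
while requests.get(f"{API}/jobs/{job_id}").json().get("status") not in ("completed", "failed"):
    time.sleep(1)

# Fetch the result; before completion this endpoint returns a 4xx, as tested above.
result = requests.get(f"{API}/jobs/{job_id}/result")
print(result.status_code, result.json() if result.ok else result.text)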
281
tests/test_input_validation.py
Normal file
@@ -0,0 +1,281 @@
#!/usr/bin/env python3
"""
Tests for input validation module, specifically filename validation.

Tests the security-critical validate_filename_safe() function to ensure
it correctly blocks path traversal attacks while allowing legitimate filenames.
"""

import sys
import os
import pytest

# Add src to path (go up one level from tests/ to root)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from src.utils.input_validation import (
    validate_filename_safe,
    ValidationError,
    PathTraversalError,
    InvalidFileTypeError,
    ALLOWED_AUDIO_EXTENSIONS
)


class TestValidFilenameSafe:
    """Test validate_filename_safe() function with various inputs."""

    def test_simple_valid_filenames(self):
        """Test that simple, valid filenames are accepted."""
        valid_names = [
            "audio.m4a",
            "song.wav",
            "podcast.mp3",
            "recording.flac",
            "music.ogg",
            "voice.aac",
        ]

        for filename in valid_names:
            result = validate_filename_safe(filename)
            assert result == filename, f"Should accept: {filename}"

    def test_filenames_with_ellipsis(self):
        """Test filenames with ellipsis (multiple dots) are accepted."""
        # This is the key test case from the bug report
        ellipsis_names = [
            "audio...mp3",
            "This is... a test.m4a",
            "Part 1... Part 2.wav",
            "Wait... what.m4a",
            "video...multiple...dots.mp3",
            "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a",  # Bug report case
        ]

        for filename in ellipsis_names:
            result = validate_filename_safe(filename)
            assert result == filename, f"Should accept filename with ellipsis: {filename}"

    def test_filenames_with_special_chars(self):
        """Test filenames with various special characters."""
        special_char_names = [
            "My-Video_2024.m4a",
            "song (remix).m4a",
            "audio [final].wav",
            "test file with spaces.mp3",
            "file-name_with-symbols.flac",
        ]

        for filename in special_char_names:
            result = validate_filename_safe(filename)
            assert result == filename, f"Should accept: {filename}"

    def test_multiple_extensions(self):
        """Test filenames that look like they have multiple extensions."""
        multi_ext_names = [
            "backup.tar.gz.mp3",  # .mp3 is valid
            "file.old.wav",  # .wav is valid
            "audio.2024.m4a",  # .m4a is valid
        ]

        for filename in multi_ext_names:
            result = validate_filename_safe(filename)
            assert result == filename, f"Should accept: {filename}"

    def test_path_traversal_attempts(self):
        """Test that path traversal attempts are rejected."""
        dangerous_names = [
            "../../../etc/passwd",
            "../../secrets.txt",
            "../file.mp4",
            "dir/../file.mp4",
            "file/../../etc/passwd",
        ]

        for filename in dangerous_names:
            with pytest.raises(PathTraversalError) as exc_info:
                validate_filename_safe(filename)
            assert "path" in str(exc_info.value).lower(), f"Should reject path traversal: {filename}"

    def test_absolute_paths(self):
        """Test that absolute paths are rejected."""
        absolute_paths = [
            "/etc/passwd",
            "/tmp/file.mp4",
            "/home/user/audio.wav",
            "C:\\Windows\\System32\\file.mp3",  # Windows path
            "\\\\server\\share\\file.m4a",  # UNC path
        ]

        for filename in absolute_paths:
            with pytest.raises(PathTraversalError) as exc_info:
                validate_filename_safe(filename)
            assert "path" in str(exc_info.value).lower(), f"Should reject absolute path: {filename}"

    def test_path_separators(self):
        """Test that filenames with path separators are rejected."""
        paths_with_separators = [
            "dir/file.mp4",
            "folder\\file.wav",
            "path/to/audio.m4a",
            "a/b/c/d.mp3",
        ]

        for filename in paths_with_separators:
            with pytest.raises(PathTraversalError) as exc_info:
                validate_filename_safe(filename)
            assert "separator" in str(exc_info.value).lower() or "path" in str(exc_info.value).lower(), \
                f"Should reject path with separators: {filename}"

    def test_null_bytes(self):
        """Test that filenames with null bytes are rejected."""
        null_byte_names = [
            "file\x00.mp4",
            "\x00malicious.wav",
            "audio\x00evil.m4a",
        ]

        for filename in null_byte_names:
            with pytest.raises(PathTraversalError) as exc_info:
                validate_filename_safe(filename)
            assert "null" in str(exc_info.value).lower(), f"Should reject null bytes: {repr(filename)}"

    def test_empty_filename(self):
        """Test that empty filename is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            validate_filename_safe("")
        assert "empty" in str(exc_info.value).lower()

    def test_no_extension(self):
        """Test that filenames without extensions are rejected."""
        no_ext_names = [
            "filename",
            "noextension",
        ]

        for filename in no_ext_names:
            with pytest.raises(InvalidFileTypeError) as exc_info:
                validate_filename_safe(filename)
            assert "extension" in str(exc_info.value).lower(), f"Should reject no extension: {filename}"

    def test_invalid_extensions(self):
        """Test that unsupported file extensions are rejected."""
        invalid_ext_names = [
            "document.pdf",
            "image.png",
            "video.avi",
            "script.sh",
            "executable.exe",
            "text.txt",
        ]

        for filename in invalid_ext_names:
            with pytest.raises(InvalidFileTypeError) as exc_info:
                validate_filename_safe(filename)
            assert "unsupported" in str(exc_info.value).lower() or "format" in str(exc_info.value).lower(), \
                f"Should reject invalid extension: {filename}"

    def test_case_insensitive_extensions(self):
        """Test that file extensions are case-insensitive."""
        case_variations = [
            "audio.MP3",
            "sound.WAV",
            "music.M4A",
            "podcast.FLAC",
            "voice.AAC",
        ]

        for filename in case_variations:
            # Should not raise exception
            result = validate_filename_safe(filename)
            assert result == filename, f"Should accept case variation: {filename}"

    def test_edge_cases(self):
        """Test various edge cases."""
        # Just dots (but with valid extension) - should pass
        assert validate_filename_safe("...mp3") == "...mp3"
        assert validate_filename_safe("....wav") == "....wav"

        # Filenames starting with dot (hidden files on Unix)
        assert validate_filename_safe(".hidden.m4a") == ".hidden.m4a"

        # Very long filename (but valid)
        long_name = "a" * 200 + ".mp3"
        assert validate_filename_safe(long_name) == long_name

    def test_allowed_extensions_comprehensive(self):
        """Test all allowed extensions from ALLOWED_AUDIO_EXTENSIONS."""
        for ext in ALLOWED_AUDIO_EXTENSIONS:
            filename = f"test{ext}"
            result = validate_filename_safe(filename)
            assert result == filename, f"Should accept allowed extension: {ext}"


class TestBugReportCase:
    """Specific test for the bug report case."""

    def test_bug_report_filename(self):
        """
        Test the exact filename from the bug report that was failing.

        Bug: "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
        was being rejected due to "..." being parsed as ".."
        """
        filename = "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"

        # Should NOT raise any exception
        result = validate_filename_safe(filename)
        assert result == filename

    def test_various_ellipsis_patterns(self):
        """Test various ellipsis patterns that should all be accepted."""
        patterns = [
            "...",  # Three dots
            "....",  # Four dots
            ".....",  # Five dots
            "file...end.mp3",
            "start...middle...end.wav",
        ]

        for pattern in patterns:
            if not pattern.endswith(tuple(f"{ext}" for ext in ALLOWED_AUDIO_EXTENSIONS)):
                pattern += ".mp3"  # Add valid extension
            result = validate_filename_safe(pattern)
            assert result == pattern


class TestSecurityBoundary:
    """Test the security boundary between safe and dangerous filenames."""

    def test_just_two_dots_vs_path_separator(self):
        """
        Test the critical distinction:
        - "file..mp3" (two dots in filename) = SAFE
        - "../file.mp3" (two dots as path component) = DANGEROUS
        """
        # Safe: dots within filename
        safe_filenames = [
            "file..mp3",
            "..file.mp3",
            "file...mp3",
            "f..i..l..e.mp3",
        ]

        for filename in safe_filenames:
            result = validate_filename_safe(filename)
            assert result == filename, f"Should be safe: {filename}"

        # Dangerous: dots as directory reference
        dangerous_filenames = [
            "../file.mp3",
            "../../file.mp3",
            "dir/../file.mp3",
        ]

        for filename in dangerous_filenames:
            with pytest.raises(PathTraversalError):
                validate_filename_safe(filename)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
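Taken together, these tests pin down the contract of validate_filename_safe(): accept only bare filenames ending in an allowed audio extension, reject anything resembling a path, and never reject a name just because it contains runs of dots. The following is a sketch consistent with that contract, not the shipped implementation in src/utils/input_validation.py; it reuses the exception classes and ALLOWED_AUDIO_EXTENSIONS imported above, and it assumes the allowed extensions are stored lowercase with a leading dot.

# Sketch only: the contract the tests above encode, not the actual implementation.
import os

def _sketch_validate_filename_safe(filename: str) -> str:
    if not filename:
        raise ValidationError("Filename is empty")
    if "\x00" in filename:
        raise PathTraversalError("Filename contains a null byte")
    # Anything with separators or an absolute/UNC/drive prefix is a path, not a filename.
    if "/" in filename or "\\" in filename or os.path.isabs(filename):
        raise PathTraversalError("Filename must not contain path separators or be an absolute path")
    # A bare filename has exactly one path component, so ".." can never be a
    # component here; ellipsis runs like "Wait... what.m4a" are therefore safe.
    lowered = filename.lower()
    if not any(lowered.endswith(ext) for ext in ALLOWED_AUDIO_EXTENSIONS):
        if "." not in filename:
            raise InvalidFileTypeError("Filename has no file extension")
        raise InvalidFileTypeError(f"Unsupported file format: {filename}")
    return filename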
208
tests/test_path_traversal_fix.py
Normal file
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""
Test path traversal detection with ellipsis support.

Tests the fix for false positives where filenames containing ellipsis (...)
were incorrectly flagged as path traversal attempts.
"""

import pytest
import sys
import os
from pathlib import Path
import tempfile

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from utils.input_validation import (
    validate_path_safe,
    validate_audio_file,
    PathTraversalError,
    ValidationError,
    InvalidFileTypeError
)


class TestPathTraversalWithEllipsis:
    """Test that ellipsis in filenames is allowed while blocking real attacks."""

    def test_filename_with_ellipsis_allowed(self, tmp_path):
        """Filenames with ellipsis (...) should be allowed."""
        test_cases = [
            "Wait... what.mp3",
            "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a",
            "file...mp3",
            "test....audio.wav",
            "a..b..c.mp3",
            "dots.........everywhere.m4a"
        ]

        for filename in test_cases:
            # Create test file
            test_file = tmp_path / filename
            test_file.write_text("fake audio data")

            # Should NOT raise PathTraversalError
            try:
                result = validate_path_safe(str(test_file), [str(tmp_path)])
                assert result.exists(), f"File should exist: {filename}"
                print(f"✓ PASS: {filename}")
            except PathTraversalError as e:
                pytest.fail(f"False positive for filename: {filename}. Error: {e}")

    def test_actual_path_traversal_blocked(self, tmp_path):
        """Actual path traversal attempts should be blocked."""
        attack_cases = [
            "../../../etc/passwd",
            "..\\..\\..\\windows\\system32",
            "legitimate/../../../etc/passwd",
            "dir/../../secret",
            "../",
            "..",
            "subdir/../../../etc/hosts"
        ]

        for attack_path in attack_cases:
            with pytest.raises(PathTraversalError):
                validate_path_safe(attack_path, [str(tmp_path)])
                print(f"✗ FAIL: Should have blocked: {attack_path}")
            print(f"✓ PASS: Blocked attack: {attack_path}")

    def test_ellipsis_in_full_path_allowed(self, tmp_path):
        """Full paths with ellipsis in filename should be allowed."""
        # Create nested directory
        subdir = tmp_path / "uploads"
        subdir.mkdir()

        filename = "Wait... what.mp3"
        test_file = subdir / filename
        test_file.write_text("fake audio data")

        # Full path should be allowed when directory is in allowed_dirs
        result = validate_path_safe(str(test_file), [str(tmp_path)])
        assert result.exists()
        print(f"✓ PASS: Full path with ellipsis: {test_file}")

    def test_mixed_dots_edge_cases(self, tmp_path):
        """Test edge cases with various dot patterns."""
        edge_cases = [
            ("single.dot.mp3", True),  # Normal filename
            ("..two.dots.mp3", True),  # Starts with two dots (filename)
            ("three...dots.mp3", True),  # Three consecutive dots
            ("many.....dots.mp3", True),  # Many consecutive dots
            (".", False),  # Current directory (should fail)
            ("..", False),  # Parent directory (should fail)
        ]

        for filename, should_pass in edge_cases:
            if should_pass:
                # Create test file
                test_file = tmp_path / filename
                test_file.write_text("fake audio data")

                try:
                    result = validate_path_safe(str(test_file), [str(tmp_path)])
                    assert result.exists(), f"File should exist: {filename}"
                    print(f"✓ PASS: Allowed: {filename}")
                except PathTraversalError:
                    pytest.fail(f"Should have allowed: {filename}")
            else:
                with pytest.raises((PathTraversalError, ValidationError)):
                    validate_path_safe(filename, [str(tmp_path)])
                print(f"✓ PASS: Blocked: {filename}")


class TestAudioFileValidationWithEllipsis:
    """Test full audio file validation with ellipsis support."""

    def test_audio_file_with_ellipsis(self, tmp_path):
        """Audio files with ellipsis should pass validation."""
        filename = "This Weird FPV Drone only takes one kind of Battery... Rekon 35 V2.m4a"
        test_file = tmp_path / filename
        test_file.write_bytes(b"fake audio data" * 100)  # Non-empty file

        # Should pass validation
        result = validate_audio_file(str(test_file), [str(tmp_path)])
        assert result.exists()
        print(f"✓ PASS: Audio validation with ellipsis: {filename}")

    def test_audio_file_traversal_attack_blocked(self, tmp_path):
        """Audio file validation should block path traversal."""
        attack_path = "../../../etc/passwd"

        with pytest.raises(PathTraversalError):
            validate_audio_file(attack_path, [str(tmp_path)])
        print(f"✓ PASS: Audio validation blocked attack: {attack_path}")


class TestComponentBasedDetection:
    """Test that detection is based on path components, not string matching."""

    def test_component_analysis(self, tmp_path):
        """Verify that we're analyzing components, not doing string matching."""
        # These should PASS (ellipsis is in the filename component)
        safe_cases = [
            tmp_path / "file...mp3",
            tmp_path / "subdir" / "Wait...what.m4a",
        ]

        for test_path in safe_cases:
            test_path.parent.mkdir(parents=True, exist_ok=True)
            test_path.write_text("data")

            # Check that ".." is not in any component
            parts = Path(test_path).parts
            assert not any(part == ".." for part in parts), \
                f"Should not have '..' as a component: {test_path}"

            # Validation should pass
            result = validate_path_safe(str(test_path), [str(tmp_path)])
            assert result.exists()
            print(f"✓ PASS: Component analysis correct: {test_path}")

    def test_component_attack_detection(self):
        """Verify that actual '..' components are detected."""
        # These should FAIL ('..' is a path component)
        attack_cases = [
            "../etc/passwd",
            "dir/../secret",
            "../../file.mp3",
        ]

        for attack_path in attack_cases:
            path = Path(attack_path)
            parts = path.parts

            # Verify that ".." IS in components
            assert any(part == ".." for part in parts), \
                f"Should have '..' as a component: {attack_path}"
            print(f"✓ PASS: Attack has '..' component: {attack_path}")


def run_tests():
    """Run all tests with verbose output."""
    print("=" * 70)
    print("Running Path Traversal Detection Tests")
    print("=" * 70)

    # Run pytest with verbose output
    exit_code = pytest.main([
        __file__,
        "-v",
        "--tb=short",
        "-p", "no:warnings"
    ])

    print("=" * 70)
    if exit_code == 0:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed!")
    print("=" * 70)

    return exit_code


if __name__ == "__main__":
    sys.exit(run_tests())
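The component-based idea that TestComponentBasedDetection encodes fits in a few lines. Below is a sketch of the check these tests assume; the shipped validate_path_safe() may add more (existence checks, nicer error messages), and is_relative_to() requires Python 3.9+.

# Sketch of the component-based traversal check assumed by the tests above (not the shipped code).
from pathlib import Path

def _sketch_is_traversal(candidate: str, allowed_dirs: list[str]) -> bool:
    # ".." is only an attack when it is a whole path component:
    # "file...mp3" has no ".." component, "../file.mp3" does.
    if any(part == ".." for part in Path(candidate).parts):
        return True
    # Defence in depth: after resolving, the path must still sit under an allowed directory.
    resolved = Path(candidate).resolve()
    return not any(resolved.is_relative_to(Path(d).resolve()) for d in allowed_dirs)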
326
transcriber.py
@@ -1,326 +0,0 @@
#!/usr/bin/env python3
"""
Transcription core module
Contains the core logic for audio transcription
"""

import os
import time
import logging
from typing import Dict, Any, Tuple, List, Optional, Union

from model_manager import get_whisper_model
from audio_processor import validate_audio_file, process_audio
from formatters import format_vtt, format_srt, format_json, format_time

# Logging configuration
logger = logging.getLogger(__name__)

def transcribe_audio(
    audio_path: str,
    model_name: str = "large-v3",
    device: str = "auto",
    compute_type: str = "auto",
    language: str = None,
    output_format: str = "vtt",
    beam_size: int = 5,
    temperature: float = 0.0,
    initial_prompt: str = None,
    output_directory: str = None
) -> str:
    """
    Transcribe an audio file using Faster Whisper

    Args:
        audio_path: Path to the audio file
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Compute type (float16, int8, auto)
        language: Language code (e.g. zh, en, ja; auto-detected by default)
        output_format: Output format (vtt, srt or json)
        beam_size: Beam search size; larger values may improve accuracy but reduce speed
        temperature: Sampling temperature; 0 means greedy decoding
        initial_prompt: Initial prompt text to help the model better understand context
        output_directory: Output directory path; defaults to the audio file's directory

    Returns:
        str: Transcription result, formatted as VTT subtitles or JSON
    """
    # Validate the audio file
    validation_result = validate_audio_file(audio_path)
    if validation_result != "ok":
        return validation_result

    try:
        # Get the model instance
        model_instance = get_whisper_model(model_name, device, compute_type)

        # Validate the language code
        supported_languages = {
            "zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
            "fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
            "pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
            "tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
        }

        if language is not None and language not in supported_languages:
            logger.warning(f"Unknown language code: {language}, falling back to automatic detection")
            language = None

        # Set transcription parameters
        options = {
            "language": language,
            "vad_filter": True,  # Use voice activity detection
            "vad_parameters": {"min_silence_duration_ms": 500},  # Tuned VAD parameters
            "beam_size": beam_size,
            "temperature": temperature,
            "initial_prompt": initial_prompt,
            "word_timestamps": True,  # Enable word-level timestamps
            "suppress_tokens": [-1],  # Suppress special tokens
            "condition_on_previous_text": True,  # Condition generation on previous text
            "compression_ratio_threshold": 2.4  # Compression ratio threshold to filter repetitive content
        }

        start_time = time.time()
        logger.info(f"Starting transcription of file: {os.path.basename(audio_path)}")

        # Process the audio
        audio_source = process_audio(audio_path)

        # Run transcription - prefer the batched model
        if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
            logger.info("Transcribing with batched acceleration...")
            # The batched model needs the batch_size parameter set separately
            segments, info = model_instance['batched_model'].transcribe(
                audio_source,
                batch_size=model_instance['batch_size'],
                **options
            )
        else:
            logger.info("Transcribing with the standard model...")
            segments, info = model_instance['model'].transcribe(audio_source, **options)

        # Convert the generator to a list
        segment_list = list(segments)

        if not segment_list:
            return "Transcription failed, no result obtained"

        # Log transcription info
        elapsed_time = time.time() - start_time
        logger.info(f"Transcription finished in {elapsed_time:.2f}s, detected language: {info.language}, audio duration: {info.duration:.2f}s")

        # Format the transcription result
        if output_format.lower() == "vtt":
            transcription_result = format_vtt(segment_list)
        elif output_format.lower() == "srt":
            transcription_result = format_srt(segment_list)
        else:
            transcription_result = format_json(segment_list, info)

        # Get the audio file's directory and base name
        audio_dir = os.path.dirname(audio_path)
        audio_filename = os.path.splitext(os.path.basename(audio_path))[0]

        # Set the output directory
        if output_directory is None:
            output_dir = audio_dir
        else:
            output_dir = output_directory
            # Make sure the output directory exists
            os.makedirs(output_dir, exist_ok=True)

        # Generate a timestamped output filename
        timestamp = time.strftime("%Y%m%d%H%M%S")
        output_filename = f"{audio_filename}_{timestamp}.{output_format.lower()}"
        output_path = os.path.join(output_dir, output_filename)

        # Write the transcription result to file
        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(transcription_result)
            logger.info(f"Transcription result saved to: {output_path}")
            return f"Transcription succeeded, result saved to: {output_path}"
        except Exception as e:
            logger.error(f"Failed to save transcription result: {str(e)}")
            return f"Transcription succeeded, but saving the result failed: {str(e)}"

    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return f"Error during transcription: {str(e)}"


def report_progress(current: int, total: int, elapsed_time: float) -> str:
    """
    Generate a progress report

    Args:
        current: Number of items processed so far
        total: Total number of items
        elapsed_time: Elapsed time in seconds

    Returns:
        str: Formatted progress report
    """
    progress = current / total * 100
    eta = (elapsed_time / current) * (total - current) if current > 0 else 0
    return (f"Progress: {current}/{total} ({progress:.1f}%)" +
            f" | Elapsed: {format_time(elapsed_time)}" +
            f" | ETA: {format_time(eta)}")


def batch_transcribe(
    audio_folder: str,
    output_folder: str = None,
    model_name: str = "large-v3",
    device: str = "auto",
    compute_type: str = "auto",
    language: str = None,
    output_format: str = "vtt",
    beam_size: int = 5,
    temperature: float = 0.0,
    initial_prompt: str = None,
    parallel_files: int = 1
) -> str:
    """
    Batch-transcribe all audio files in a folder

    Args:
        audio_folder: Path to the folder containing audio files
        output_folder: Output folder path; defaults to the transcript subfolder of audio_folder
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Compute type (float16, int8, auto)
        language: Language code (e.g. zh, en, ja; auto-detected by default)
        output_format: Output format (vtt, srt or json)
        beam_size: Beam search size; larger values may improve accuracy but reduce speed
        temperature: Sampling temperature; 0 means greedy decoding
        initial_prompt: Initial prompt text to help the model better understand context
        parallel_files: Number of files to process in parallel (only effective in CPU mode)

    Returns:
        str: Batch processing summary with processing time and success rate
    """
    if not os.path.isdir(audio_folder):
        return f"Error: folder does not exist: {audio_folder}"

    # Set the output folder
    if output_folder is None:
        output_folder = os.path.join(audio_folder, "transcript")

    # Make sure the output directory exists
    os.makedirs(output_folder, exist_ok=True)

    # Validate the output format
    valid_formats = ["vtt", "srt", "json"]
    if output_format.lower() not in valid_formats:
        return f"Error: unsupported output format: {output_format}. Supported formats: {', '.join(valid_formats)}"

    # Collect all audio files
    audio_files = []
    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]

    for filename in os.listdir(audio_folder):
        file_ext = os.path.splitext(filename)[1].lower()
        if file_ext in supported_formats:
            audio_files.append(os.path.join(audio_folder, filename))

    if not audio_files:
        return f"No supported audio files found in {audio_folder}. Supported formats: {', '.join(supported_formats)}"

    # Record the start time
    start_time = time.time()
    total_files = len(audio_files)
    logger.info(f"Starting batch transcription of {total_files} files, output format: {output_format}")

    # Preload the model to avoid repeated loading
    try:
        get_whisper_model(model_name, device, compute_type)
        logger.info(f"Preloaded model: {model_name}")
    except Exception as e:
        logger.error(f"Failed to preload model: {str(e)}")
        return f"Batch processing failed: could not load model {model_name}: {str(e)}"

    # Process each file
    results = []
    success_count = 0
    error_count = 0
    total_audio_duration = 0

    # Process each file
    for i, audio_path in enumerate(audio_files):
        file_name = os.path.basename(audio_path)
        elapsed = time.time() - start_time

        # Report progress
        progress_msg = report_progress(i, total_files, elapsed)
        logger.info(f"{progress_msg} | Currently processing: {file_name}")

        # Run transcription
        try:
            result = transcribe_audio(
                audio_path=audio_path,
                model_name=model_name,
                device=device,
                compute_type=compute_type,
                language=language,
                output_format=output_format,
                beam_size=beam_size,
                temperature=temperature,
                initial_prompt=initial_prompt,
                output_directory=output_folder
            )

            # Check whether the result contains an error message
            if result.startswith("Error:") or result.startswith("Error during transcription:"):
                logger.error(f"Transcription failed: {file_name} - {result}")
                results.append(f"❌ Failed: {file_name} - {result}")
                error_count += 1
                continue

            # If transcription succeeded, extract the output path
            if result.startswith("Transcription succeeded"):
                # Extract the output path from the returned message
                output_path = result.split(": ")[1] if ": " in result else "unknown path"
                success_count += 1
                results.append(f"✅ Success: {file_name} -> {os.path.basename(output_path)}")

                # Extract the audio duration
                audio_duration = 0
                if output_format.lower() == "json":
                    # Try to parse the audio duration from the output file
                    try:
                        import json
                        # Read the JSON content from the output file
                        with open(output_path, "r", encoding="utf-8") as json_file:
                            json_content = json_file.read()
                        json_data = json.loads(json_content)
                        audio_duration = json_data.get("duration", 0)
                    except Exception as e:
                        logger.warning(f"Could not extract audio duration from JSON file: {str(e)}")
                        audio_duration = 0
                else:
                    # Try to derive audio information from the file name
                    try:
                        # The info object is scoped to transcribe_audio and not accessible here,
                        # so use a conservative default rather than parsing the result string.
                        audio_duration = 0  # Default to 0
                    except Exception as e:
                        logger.warning(f"Could not extract audio duration from file name: {str(e)}")
                        audio_duration = 0

                # Accumulate audio duration
                total_audio_duration += audio_duration
        except Exception as e:
            logger.error(f"Error during transcription: {file_name} - {str(e)}")
            results.append(f"❌ Failed: {file_name} - {str(e)}")
            error_count += 1
    # Compute the total transcription time
    total_transcription_time = time.time() - start_time
    # Generate the batch processing summary
    summary = f"Batch processing finished, total transcription time: {format_time(total_transcription_time)}"
    summary += f" | Succeeded: {success_count}/{total_files}"
    summary += f" | Failed: {error_count}/{total_files}"
    # Output the results
    for result in results:
        logger.info(result)
    return summary
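For context, this is how the module removed here was typically driven; a minimal call sketch against the old API, with a placeholder audio path.

# Old-API usage sketch for the removed transcriber.py (placeholder path).
from transcriber import transcribe_audio

print(transcribe_audio(
    audio_path="/path/to/audio.mp3",  # placeholder
    model_name="large-v3",
    output_format="srt",
))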
@@ -1,108 +0,0 @@
#!/usr/bin/env python3
"""
Faster Whisper based speech recognition MCP service
Provides high-performance audio transcription with batched acceleration and multiple output formats
"""

import os
import logging
from mcp.server.fastmcp import FastMCP

from model_manager import get_model_info
from transcriber import transcribe_audio, batch_transcribe

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the FastMCP server instance
mcp = FastMCP(
    name="fast-whisper-mcp-server",
    version="0.1.1",
    dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126", "numpy>=1.20.0"]
)

@mcp.tool()
def get_model_info_api() -> str:
    """
    Get information about the available Whisper models
    """
    return get_model_info()

@mcp.tool()
def transcribe(audio_path: str, model_name: str = "large-v3", device: str = "auto",
               compute_type: str = "auto", language: str = None, output_format: str = "vtt",
               beam_size: int = 5, temperature: float = 0.0, initial_prompt: str = None,
               output_directory: str = None) -> str:
    """
    Transcribe an audio file using Faster Whisper

    Args:
        audio_path: Path to the audio file
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Compute type (float16, int8, auto)
        language: Language code (e.g. zh, en, ja; auto-detected by default)
        output_format: Output format (vtt, srt or json)
        beam_size: Beam search size; larger values may improve accuracy but reduce speed
        temperature: Sampling temperature; 0 means greedy decoding
        initial_prompt: Initial prompt text to help the model better understand context
        output_directory: Output directory path; defaults to the audio file's directory

    Returns:
        str: Transcription result, formatted as VTT subtitles or JSON
    """
    return transcribe_audio(
        audio_path=audio_path,
        model_name=model_name,
        device=device,
        compute_type=compute_type,
        language=language,
        output_format=output_format,
        beam_size=beam_size,
        temperature=temperature,
        initial_prompt=initial_prompt,
        output_directory=output_directory
    )

@mcp.tool()
def batch_transcribe_audio(audio_folder: str, output_folder: str = None, model_name: str = "large-v3",
                           device: str = "auto", compute_type: str = "auto", language: str = None,
                           output_format: str = "vtt", beam_size: int = 5, temperature: float = 0.0,
                           initial_prompt: str = None, parallel_files: int = 1) -> str:
    """
    Batch-transcribe all audio files in a folder

    Args:
        audio_folder: Path to the folder containing audio files
        output_folder: Output folder path; defaults to the transcript subfolder of audio_folder
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Compute type (float16, int8, auto)
        language: Language code (e.g. zh, en, ja; auto-detected by default)
        output_format: Output format (vtt, srt or json)
        beam_size: Beam search size; larger values may improve accuracy but reduce speed
        temperature: Sampling temperature; 0 means greedy decoding
        initial_prompt: Initial prompt text to help the model better understand context
        parallel_files: Number of files to process in parallel (only effective in CPU mode)

    Returns:
        str: Batch processing summary with processing time and success rate
    """
    return batch_transcribe(
        audio_folder=audio_folder,
        output_folder=output_folder,
        model_name=model_name,
        device=device,
        compute_type=compute_type,
        language=language,
        output_format=output_format,
        beam_size=beam_size,
        temperature=temperature,
        initial_prompt=initial_prompt,
        parallel_files=parallel_files
    )

if __name__ == "__main__":
    # Run the server
    mcp.run()