Compare commits
4 commits: main...9c020f947b

| Author | SHA1 | Date |
|---|---|---|
|  | 9c020f947b |  |
|  | 4936684db4 |  |
|  | 8e30a8812c |  |
|  | 37935066ad |  |
Dockerfile (new file, 51 lines)

@@ -0,0 +1,51 @@
# Use NVIDIA CUDA base image with Python
FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04

# Install Python 3.12
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update && apt-get install -y \
    python3.12 \
    python3.12-venv \
    python3.12-dev \
    python3-pip \
    ffmpeg \
    git \
    && rm -rf /var/lib/apt/lists/*

# Make python3.12 the default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1

# Upgrade pip
RUN python -m pip install --upgrade pip

# Set working directory
WORKDIR /app

# Copy requirements first for better caching
COPY fast-whisper-mcp-server/requirements.txt .

# Install Python dependencies with CUDA support
RUN pip install --no-cache-dir \
    faster-whisper \
    torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121 \
    torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121 \
    mcp[cli]

# Copy application code
COPY fast-whisper-mcp-server/ .

# Create directories for models and outputs
RUN mkdir -p /models /outputs

# Set environment variables for GPU
ENV WHISPER_MODEL_DIR=/models
ENV TRANSCRIPTION_OUTPUT_DIR=/outputs
ENV TRANSCRIPTION_MODEL=large-v3
ENV TRANSCRIPTION_DEVICE=cuda
ENV TRANSCRIPTION_COMPUTE_TYPE=float16

# Run the server
CMD ["python", "whisper_server.py"]
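For reference, one way to build and run this image (the image tag and host paths below are illustrative assumptions, not taken from the repository):

```bash
# Build from the repository root so the fast-whisper-mcp-server/ COPY paths resolve
docker build -t fast-whisper-mcp -f Dockerfile .

# Run with GPU access, mounting host directories over /models and /outputs
docker run --rm --gpus all \
  -v "$HOME/whisper-models:/models" \
  -v "$HOME/whisper-outputs:/outputs" \
  fast-whisper-mcp
```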
README-CN.md (deleted, 184 lines)

@@ -1,184 +0,0 @@
# Whisper Speech Recognition MCP Server

A speech recognition MCP server based on Faster Whisper, providing high-performance audio transcription.

## Features

- Integrates Faster Whisper for efficient speech recognition
- Supports batched inference to speed up transcription
- Automatically uses CUDA acceleration when available
- Supports multiple model sizes (tiny through large-v3)
- Outputs VTT subtitles and JSON
- Supports batch transcription of audio files in a folder
- Caches model instances to avoid repeated loading

## Installation

### Dependencies

- Python 3.10+
- faster-whisper>=0.9.0
- torch==2.6.0+cu126
- torchaudio==2.6.0+cu126
- mcp[cli]>=1.2.0

### Installation Steps

1. Clone or download this repository
2. Create and activate a virtual environment (recommended)
3. Install the dependencies:

```bash
pip install -r requirements.txt
```

## Usage

### Starting the Server

On Windows, simply run `start_server.bat`.

On other platforms, run:

```bash
python whisper_server.py
```

### Configuring Claude Desktop

1. Open the Claude Desktop configuration file:
   - Windows: `%APPDATA%\Claude\claude_desktop_config.json`
   - macOS: `~/Library/Application Support/Claude/claude_desktop_config.json`

2. Add the Whisper server configuration:

```json
{
  "mcpServers": {
    "whisper": {
      "command": "python",
      "args": ["D:/path/to/whisper_server.py"],
      "env": {}
    }
  }
}
```

3. Restart Claude Desktop

### Available Tools

The server provides the following tools:

1. **get_model_info** - Get information about the available Whisper models
2. **transcribe** - Transcribe a single audio file
3. **batch_transcribe** - Batch transcribe audio files in a folder

## Performance Tuning Tips

- CUDA acceleration significantly speeds up transcription
- For many short audio files, batch mode is more efficient
- The batch size is adjusted automatically to the available GPU memory
- VAD filtering improves accuracy on long audio
- Specifying the correct language improves transcription quality (see the sketch below)
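As a sketch of how these tips map onto the server's API (the function signature is as defined in transcriber.py; the file path and prompt are illustrative):

```python
from transcriber import transcribe_audio

# Long recording: pin the language, pass a domain hint, and rely on the
# VAD filtering that transcribe_audio enables internally by default.
result = transcribe_audio(
    audio_path="/data/interview.mp3",   # illustrative path
    model_name="large-v3",
    device="auto",
    language="en",
    output_format="vtt",
    beam_size=5,
    initial_prompt="Technical interview about GPU computing.",
)
print(result)  # "Transcription successful, results saved to: ..."
```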

## Local Testing

1. Quick test with MCP Inspector:

```bash
mcp dev whisper_server.py
```

2. Integration testing with Claude Desktop

3. Direct invocation from the command line (requires mcp[cli]):

```bash
mcp run whisper_server.py
```

## Error Handling

The server implements the following error-handling mechanisms:

- Audio file existence checks
- Model loading failure handling
- Exception capture during transcription
- GPU memory management
- Adaptive batch-size adjustment

## License

MIT

## Acknowledgements

This project was developed with the help of the following excellent AI tools and models:

- [GitHub Copilot](https://github.com/features/copilot) - AI pair-programming assistant
- [Trae](https://trae.ai/) - Intelligent AI coding assistant
- [Cline](https://cline.ai/) - AI-driven terminal
- [DeepSeek](https://www.deepseek.com/) - Advanced AI models
- [Claude-3.7-Sonnet](https://www.anthropic.com/claude) - Anthropic's powerful AI assistant
- [Gemini-2.0-Flash](https://ai.google/gemini/) - Google's multimodal AI model
- [VS Code](https://code.visualstudio.com/) - Powerful code editor
- [Whisper](https://github.com/openai/whisper) - OpenAI's speech recognition model
- [Faster Whisper](https://github.com/guillaumekln/faster-whisper) - Optimized Whisper implementation

Special thanks to these outstanding tools and the teams behind them.

---

# Whisper Speech Recognition MCP Server (notes written by cline + Claude Sonnet 3.7 after completing all tasks)

I have successfully created a Python-based speech recognition MCP server that integrates Faster Whisper and supports batched inference. The server has the following characteristics:

## Core Features

- **Faster Whisper integration**: uses the latest version (>=0.9.0) for efficient speech recognition
- **CUDA acceleration**: automatically detects and uses CUDA, defaulting to torch==2.6.0+cu126
- **Batch optimization**: uses BatchedInferencePipeline, with the batch size adjusted automatically to GPU memory
- **Model caching**: caches model instances to avoid repeated loading
- **Multiple output formats**: supports VTT subtitle and JSON transcription results
- **Batch processing**: supports batch transcription of an entire folder of audio files

## Main Tools

The server provides three main tools:

1. **get_model_info**: get available Whisper model information and system configuration
2. **transcribe**: transcribe a single audio file, with configurable parameters
3. **batch_transcribe**: batch transcribe audio files in a folder

## Error Handling

- Audio file existence validation
- Model loading exception capture and logging
- Exception handling during transcription
- GPU memory management and cleanup
- Adaptive batch-size adjustment

## Performance Optimization

- Batch size adjusted dynamically to GPU memory (4-32)
- VAD (voice activity detection) filtering to improve accuracy
- Model instance caching to avoid repeated loading
- Automatic selection of the best device and compute type

## Local Testing

Several testing methods are provided:

- Quick test with MCP Inspector: `mcp dev whisper_server.py`
- Integration testing with Claude Desktop
- Direct command-line invocation: `mcp run whisper_server.py`

All files are ready, including:

- whisper_server.py: main server code
- requirements.txt: dependency list
- start_server.bat: Windows startup script
- README.md: detailed documentation

You can start the server by running start_server.bat or by executing `python whisper_server.py` directly.
audio_processor.py

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Audio Processing Module
Responsible for audio file validation and preprocessing
"""

import os

@@ -9,59 +9,59 @@ import logging
from typing import Union, Any
from faster_whisper import decode_audio

# Log configuration
logger = logging.getLogger(__name__)

def validate_audio_file(audio_path: str) -> str:
    """
    Validate if an audio file is valid

    Args:
        audio_path: Path to the audio file

    Returns:
        str: Validation result, "ok" indicates validation passed, otherwise returns an error message
    """
    # Validate parameters
    if not os.path.exists(audio_path):
        return f"Error: Audio file does not exist: {audio_path}"

    # Validate file format
    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
    file_ext = os.path.splitext(audio_path)[1].lower()
    if file_ext not in supported_formats:
        return f"Error: Unsupported audio format: {file_ext}. Supported formats: {', '.join(supported_formats)}"

    # Validate file size
    try:
        file_size = os.path.getsize(audio_path)
        if file_size == 0:
            return f"Error: Audio file is empty: {audio_path}"

        # Warning for large files (over 1GB)
        if file_size > 1024 * 1024 * 1024:
            logger.warning(f"Warning: File size exceeds 1GB, may require longer processing time: {audio_path}")
    except Exception as e:
        logger.error(f"Failed to check file size: {str(e)}")
        return f"Error: Failed to check file size: {str(e)}"

    return "ok"

def process_audio(audio_path: str) -> Union[str, Any]:
    """
    Process audio file, perform decoding and preprocessing

    Args:
        audio_path: Path to the audio file

    Returns:
        Union[str, Any]: Processed audio data or the original file path
    """
    # Try to preprocess audio using decode_audio to handle more formats
    try:
        audio_data = decode_audio(audio_path)
        logger.info(f"Successfully preprocessed audio: {os.path.basename(audio_path)}")
        return audio_data
    except Exception as audio_error:
        logger.warning(f"Audio preprocessing failed, will use file path directly: {str(audio_error)}")
        return audio_path
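A minimal usage sketch of these two helpers (the file name is illustrative):

```python
from audio_processor import validate_audio_file, process_audio

status = validate_audio_file("meeting.wav")  # illustrative file
if status == "ok":
    # Returns a decoded waveform on success, or the original path as a fallback
    audio_source = process_audio("meeting.wav")
else:
    print(status)  # human-readable error string, e.g. unsupported format
```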
formatters.py

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Formatting Output Module
Responsible for formatting transcription results into different output formats (VTT, SRT, JSON, TXT)
"""

import json

@@ -9,13 +9,13 @@ from typing import List, Dict, Any

def format_vtt(segments: List) -> str:
    """
    Format transcription results to VTT

    Args:
        segments: List of transcription segments

    Returns:
        str: Subtitle content in VTT format
    """
    vtt_content = "WEBVTT\n\n"

@@ -31,13 +31,13 @@ def format_vtt(segments: List) -> str:

def format_srt(segments: List) -> str:
    """
    Format transcription results to SRT

    Args:
        segments: List of transcription segments

    Returns:
        str: Subtitle content in SRT format
    """
    srt_content = ""
    index = 1

@@ -53,16 +53,51 @@ def format_srt(segments: List) -> str:

    return srt_content

def format_txt(segments: List) -> str:
    """
    Format transcription results to plain text

    Args:
        segments: List of transcription segments

    Returns:
        str: Plain text transcription content
    """
    text_content = ""

    for segment in segments:
        text = segment.text.strip()
        if text:
            # Add the text content
            text_content += text

            # Add appropriate spacing between segments
            # If the text doesn't end with punctuation, add a space
            if not text.endswith(('.', '!', '?', ':', ';')):
                text_content += " "
            else:
                # If it ends with punctuation, add a space for natural flow
                text_content += " "

    # Clean up any trailing whitespace and ensure single line breaks
    text_content = text_content.strip()

    # Replace multiple spaces with single spaces
    while "  " in text_content:
        text_content = text_content.replace("  ", " ")

    return text_content

def format_json(segments: List, info: Any) -> str:
    """
    Format transcription results to JSON

    Args:
        segments: List of transcription segments
        info: Transcription information object

    Returns:
        str: Transcription results in JSON format
    """
    result = {
        "segments": [{

@@ -86,13 +121,13 @@ def format_json(segments: List, info: Any) -> str:

def format_timestamp(seconds: float) -> str:
    """
    Format timestamp for VTT format

    Args:
        seconds: Number of seconds

    Returns:
        str: Formatted timestamp (HH:MM:SS.mmm)
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)

@@ -101,13 +136,13 @@ def format_timestamp(seconds: float) -> str:

def format_timestamp_srt(seconds: float) -> str:
    """
    Format timestamp for SRT format

    Args:
        seconds: Number of seconds

    Returns:
        str: Formatted timestamp (HH:MM:SS,mmm)
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)

@@ -117,13 +152,13 @@ def format_timestamp_srt(seconds: float) -> str:

def format_time(seconds: float) -> str:
    """
    Format time into readable format

    Args:
        seconds: Number of seconds

    Returns:
        str: Formatted time (HH:MM:SS)
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
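To illustrate the formatters above with stand-in segment objects (real segments come from faster-whisper; the expected outputs follow the code and docstrings shown here):

```python
from types import SimpleNamespace
from formatters import format_txt, format_timestamp_srt

segments = [
    SimpleNamespace(start=0.0, end=2.5, text=" Hello there."),
    SimpleNamespace(start=2.5, end=4.0, text=" A second segment"),
]

print(format_txt(segments))         # -> "Hello there. A second segment"
print(format_timestamp_srt(125.5))  # -> "00:02:05,500" per the docstring
```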
model_manager.py

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Model Management Module
Responsible for loading, caching, and managing Whisper models
"""

import os

@@ -11,86 +11,86 @@ from typing import Dict, Any
import torch
from faster_whisper import WhisperModel, BatchedInferencePipeline

# Log configuration
logger = logging.getLogger(__name__)

# Global model instance cache
model_instances = {}

def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
    """
    Get or create Whisper model instance

    Args:
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Running device (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)

    Returns:
        dict: Dictionary containing model instance and configuration
    """
    global model_instances

    # Validate model name
    valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
    if model_name not in valid_models:
        raise ValueError(f"Invalid model name: {model_name}. Valid models: {', '.join(valid_models)}")

    # Auto-detect device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

    # Validate device and compute type
    if device not in ["cpu", "cuda"]:
        raise ValueError(f"Invalid device: {device}. Valid devices: cpu, cuda")

    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA not available, automatically switching to CPU")
        device = "cpu"
        compute_type = "int8"

    if compute_type not in ["float16", "int8"]:
        raise ValueError(f"Invalid compute type: {compute_type}. Valid compute types: float16, int8")

    if device == "cpu" and compute_type == "float16":
        logger.warning("CPU device does not support float16 computation type, automatically switching to int8")
        compute_type = "int8"

    # Generate model key
    model_key = f"{model_name}_{device}_{compute_type}"

    # If model is already instantiated, return directly
    if model_key in model_instances:
        logger.info(f"Using cached model instance: {model_key}")
        return model_instances[model_key]

    # Clean GPU memory (if using CUDA)
    if device == "cuda":
        torch.cuda.empty_cache()

    # Instantiate model
    try:
        logger.info(f"Loading Whisper model: {model_name} device: {device} compute type: {compute_type}")

        # Base model
        model = WhisperModel(
            model_name,
            device=device,
            compute_type=compute_type,
            download_root=os.environ.get("WHISPER_MODEL_DIR", None)  # Support custom model directory
        )

        # Batch processing settings - batch processing enabled by default to improve speed
        batched_model = None
        batch_size = 0

        if device == "cuda":  # Only use batch processing on CUDA devices
            # Determine appropriate batch size based on available memory
            if torch.cuda.is_available():
                gpu_mem = torch.cuda.get_device_properties(0).total_memory
                free_mem = gpu_mem - torch.cuda.memory_allocated()
                # Dynamically adjust batch size based on GPU memory
                if free_mem > 16e9:  # >16GB
                    batch_size = 32
                elif free_mem > 12e9:  # >12GB

@@ -99,17 +99,17 @@ def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[s
                    batch_size = 8
                elif free_mem > 4e9:  # >4GB
                    batch_size = 4
                else:  # Smaller memory
                    batch_size = 2

                logger.info(f"Available GPU memory: {free_mem / 1e9:.2f} GB")
            else:
                batch_size = 8  # Default value

            logger.info(f"Enabling batch processing acceleration, batch size: {batch_size}")
            batched_model = BatchedInferencePipeline(model=model)

        # Create result object
        result = {
            'model': model,
            'device': device,

@@ -119,20 +119,20 @@ def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[s
            'load_time': time.time()
        }

        # Cache instance
        model_instances[model_key] = result
        return result

    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise

def get_model_info() -> str:
    """
    Get available Whisper model information

    Returns:
        str: JSON string of model information
    """
    import json

@@ -142,15 +142,15 @@ def get_model_info() -> str:
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    compute_types = ["float16", "int8"] if torch.cuda.is_available() else ["int8"]

    # Supported language list
    languages = {
        "zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
        "fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
        "pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
        "tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
    }

    # Supported audio formats
    audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]

    info = {
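A usage sketch for the cache above (the dictionary keys are the ones used elsewhere in this change; the values shown in the comment are examples):

```python
from model_manager import get_whisper_model

inst = get_whisper_model("base", device="auto", compute_type="auto")
print(inst["device"], inst["batch_size"])  # e.g. "cuda" 8, or "cpu" 0

# A second call with the same arguments returns the cached instance
same = get_whisper_model("base", device="auto", compute_type="auto")
assert same is inst
```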
requirements.txt

@@ -1,14 +1,15 @@
# uv pip install -r ./requirements.txt --index-url https://download.pytorch.org/whl/cu126
faster-whisper
torch #==2.6.0+cu126
torchaudio #==2.6.0+cu126

# uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
# pip install faster-whisper>=0.9.0
# pip install mcp[cli]>=1.2.0
mcp[cli]

# PyTorch Installation Guide:
# Please install the appropriate version of PyTorch based on your CUDA version:
#
# • CUDA 12.6:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126

@@ -16,7 +17,6 @@ torchaudio==2.6.0+cu126
# • CUDA 12.1:
#   pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
#
# • CPU version:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu
#
# Use `nvcc --version` or `nvidia-smi` to check your CUDA version
run_server.sh (new executable file, 36 lines)

@@ -0,0 +1,36 @@
#!/bin/bash
set -e

datetime_prefix() {
    date "+[%Y-%m-%d %H:%M:%S]"
}

# Get current user ID to avoid permission issues
USER_ID=$(id -u)
GROUP_ID=$(id -g)

# Set environment variables
export WHISPER_MODEL_DIR="/home/uad/agents/tools/mcp-transcriptor/data/models"
export TRANSCRIPTION_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/data/transcripts"
export TRANSCRIPTION_BATCH_OUTPUT_DIR="/home/uad/agents/tools/mcp-transcriptor/data/transcripts/batch"
export TRANSCRIPTION_MODEL="base"
export TRANSCRIPTION_DEVICE="auto"
export TRANSCRIPTION_COMPUTE_TYPE="auto"
export TRANSCRIPTION_OUTPUT_FORMAT="txt"
export TRANSCRIPTION_BEAM_SIZE="2"
export TRANSCRIPTION_TEMPERATURE="0.0"
export TRANSCRIPTION_USE_TIMESTAMP="false"
export TRANSCRIPTION_FILENAME_PREFIX="test_"

# Log start of the script
echo "$(datetime_prefix) Starting whisper server script..."

# Optional: Verify required directories exist
if [ ! -d "$WHISPER_MODEL_DIR" ]; then
    echo "$(datetime_prefix) Error: Whisper model directory does not exist: $WHISPER_MODEL_DIR"
    exit 1
fi

# Run the Python script with the defined environment variables
sudo /home/uad/agents/tools/mcp-transcriptor/venv/bin/python \
    /home/uad/agents/tools/mcp-transcriptor/whisper_server.py
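To use the script (the absolute paths inside it are machine-specific and would need adjusting). Note that a plain `sudo` resets the environment by default, so the exported TRANSCRIPTION_* variables may not reach the Python process unless `sudo -E` is used or `sudo` is dropped:

```bash
chmod +x run_server.sh   # if not already executable
./run_server.sh
```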
start_server.bat (deleted, 16 lines)

@@ -1,16 +0,0 @@
@echo off
echo Starting Whisper speech recognition MCP server...

:: Activate the virtual environment (if it exists)
if exist "..\venv\Scripts\activate.bat" (
    call ..\venv\Scripts\activate.bat
)

:: Run the MCP server
python whisper_server.py

:: If an error occurs, pause to view the error message
if %ERRORLEVEL% neq 0 (
    echo Server startup failed, error code: %ERRORLEVEL%
    pause
)
transcriber.py (375 lines)

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Transcription Core Module with Environment Variable Support
Contains core logic for audio transcription
"""

import os

@@ -11,211 +11,264 @@ from typing import Dict, Any, Tuple, List, Optional, Union

from model_manager import get_whisper_model
from audio_processor import validate_audio_file, process_audio
from formatters import format_vtt, format_srt, format_json, format_txt, format_time

# Logging configuration
logger = logging.getLogger(__name__)

# Environment variable defaults
DEFAULT_OUTPUT_DIR = os.getenv('TRANSCRIPTION_OUTPUT_DIR', None)
DEFAULT_BATCH_OUTPUT_DIR = os.getenv('TRANSCRIPTION_BATCH_OUTPUT_DIR', None)
DEFAULT_MODEL = os.getenv('TRANSCRIPTION_MODEL', 'large-v3')
DEFAULT_DEVICE = os.getenv('TRANSCRIPTION_DEVICE', 'auto')
DEFAULT_COMPUTE_TYPE = os.getenv('TRANSCRIPTION_COMPUTE_TYPE', 'auto')
DEFAULT_LANGUAGE = os.getenv('TRANSCRIPTION_LANGUAGE', None)
DEFAULT_OUTPUT_FORMAT = os.getenv('TRANSCRIPTION_OUTPUT_FORMAT', 'txt')
DEFAULT_BEAM_SIZE = int(os.getenv('TRANSCRIPTION_BEAM_SIZE', '5'))
DEFAULT_TEMPERATURE = float(os.getenv('TRANSCRIPTION_TEMPERATURE', '0.0'))

# Model storage configuration
WHISPER_MODEL_DIR = os.getenv('WHISPER_MODEL_DIR', None)

# File naming configuration
USE_TIMESTAMP = os.getenv('TRANSCRIPTION_USE_TIMESTAMP', 'true').lower() == 'true'
FILENAME_PREFIX = os.getenv('TRANSCRIPTION_FILENAME_PREFIX', '')
FILENAME_SUFFIX = os.getenv('TRANSCRIPTION_FILENAME_SUFFIX', '')

def transcribe_audio(
    audio_path: str,
    model_name: str = None,
    device: str = None,
    compute_type: str = None,
    language: str = None,
    output_format: str = None,
    beam_size: int = None,
    temperature: float = None,
    initial_prompt: str = None,
    output_directory: str = None
) -> str:
    """
    Transcribe audio file using Faster Whisper with ENV VAR support

    Args:
        audio_path: Path to audio file
        model_name: Model name (defaults to TRANSCRIPTION_MODEL env var or "large-v3")
        device: Execution device (defaults to TRANSCRIPTION_DEVICE env var or "auto")
        compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var or "auto")
        language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var or auto-detect)
        output_format: Output format (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var or "vtt")
        beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var or 5)
        temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var or 0.0)
        initial_prompt: Initial prompt text
        output_directory: Output directory (defaults to TRANSCRIPTION_OUTPUT_DIR env var or audio file directory)

    Returns:
        str: Transcription result path or error message
    """
    # Apply environment variable defaults
    model_name = model_name or DEFAULT_MODEL
    device = device or DEFAULT_DEVICE
    compute_type = compute_type or DEFAULT_COMPUTE_TYPE
    language = language or DEFAULT_LANGUAGE
    output_format = output_format or DEFAULT_OUTPUT_FORMAT
    beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
    temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE

    # Validate audio file
    validation_result = validate_audio_file(audio_path)
    if validation_result != "ok":
        return validation_result

    try:
        # Get model instance
        model_instance = get_whisper_model(model_name, device, compute_type)

        # Validate language code
        supported_languages = {
            "zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
            "fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
            "pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
            "tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
        }

        if language is not None and language not in supported_languages:
            logger.warning(f"Unknown language code: {language}, will use auto-detection")
            language = None

        # Set transcription parameters
        options = {
            "language": language,
            "vad_filter": True,
            "vad_parameters": {"min_silence_duration_ms": 500},
            "beam_size": beam_size,
            "temperature": temperature,
            "initial_prompt": initial_prompt,
            "word_timestamps": True,
            "suppress_tokens": [-1],
            "condition_on_previous_text": True,
            "compression_ratio_threshold": 2.4
        }

        start_time = time.time()
        logger.info(f"Starting transcription of file: {os.path.basename(audio_path)}")

        # Process audio
        audio_source = process_audio(audio_path)

        # Execute transcription
        if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
            logger.info("Using batch acceleration for transcription...")
            segments, info = model_instance['batched_model'].transcribe(
                audio_source,
                batch_size=model_instance['batch_size'],
                **options
            )
        else:
            logger.info("Using standard model for transcription...")
            segments, info = model_instance['model'].transcribe(audio_source, **options)

        # Convert generator to list
        segment_list = list(segments)

        if not segment_list:
            return "Transcription failed, no results obtained"

        # Record transcription information
        elapsed_time = time.time() - start_time
        logger.info(f"Transcription completed, time used: {elapsed_time:.2f} seconds, detected language: {info.language}, audio length: {info.duration:.2f} seconds")

        # Format transcription results based on output format
        output_format_lower = output_format.lower()

        if output_format_lower == "vtt":
            transcription_result = format_vtt(segment_list)
        elif output_format_lower == "srt":
            transcription_result = format_srt(segment_list)
        elif output_format_lower == "txt":
            transcription_result = format_txt(segment_list)
        elif output_format_lower == "json":
            transcription_result = format_json(segment_list, info)
        else:
            raise ValueError(f"Unsupported output format: {output_format}. Supported formats: vtt, srt, txt, json")

        # Determine output directory
        audio_dir = os.path.dirname(audio_path)
        audio_filename = os.path.splitext(os.path.basename(audio_path))[0]

        # Priority: parameter > env var > audio directory
        if output_directory is not None:
            output_dir = output_directory
        elif DEFAULT_OUTPUT_DIR is not None:
            output_dir = DEFAULT_OUTPUT_DIR
        else:
            output_dir = audio_dir

        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Generate filename with customizable format
        filename_parts = []

        # Add prefix if specified
        if FILENAME_PREFIX:
            filename_parts.append(FILENAME_PREFIX)

        # Add base filename
        filename_parts.append(audio_filename)

        # Add suffix if specified
        if FILENAME_SUFFIX:
            filename_parts.append(FILENAME_SUFFIX)

        # Add timestamp if enabled
        if USE_TIMESTAMP:
            timestamp = time.strftime("%Y%m%d%H%M%S")
            filename_parts.append(timestamp)

        # Join parts and add extension
        base_name = "_".join(filename_parts)
        output_filename = f"{base_name}.{output_format_lower}"
        output_path = os.path.join(output_dir, output_filename)

        # Write transcription results to file
        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(transcription_result)
            logger.info(f"Transcription results saved to: {output_path}")
            return f"Transcription successful, results saved to: {output_path}"
        except Exception as e:
            logger.error(f"Failed to save transcription results: {str(e)}")
            return f"Transcription successful, but failed to save results: {str(e)}"

    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return f"Error occurred during transcription: {str(e)}"


def batch_transcribe(
    audio_folder: str,
    output_folder: str = None,
    model_name: str = None,
    device: str = None,
    compute_type: str = None,
    language: str = None,
    output_format: str = None,
    beam_size: int = None,
    temperature: float = None,
    initial_prompt: str = None,
    parallel_files: int = 1
) -> str:
    """
    Batch transcribe audio files with ENV VAR support

    Args:
        audio_folder: Path to folder containing audio files
        output_folder: Output folder (defaults to TRANSCRIPTION_BATCH_OUTPUT_DIR env var or "transcript" subfolder)
        model_name: Model name (defaults to TRANSCRIPTION_MODEL env var or "large-v3")
        device: Execution device (defaults to TRANSCRIPTION_DEVICE env var or "auto")
        compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var or "auto")
        language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var or auto-detect)
        output_format: Output format (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var or "vtt")
        beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var or 5)
        temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var or 0.0)
        initial_prompt: Initial prompt text
        parallel_files: Number of files to process in parallel (not implemented yet)

    Returns:
        str: Batch processing summary
    """
    # Apply environment variable defaults
    model_name = model_name or DEFAULT_MODEL
    device = device or DEFAULT_DEVICE
    compute_type = compute_type or DEFAULT_COMPUTE_TYPE
    language = language or DEFAULT_LANGUAGE
    output_format = output_format or DEFAULT_OUTPUT_FORMAT
    beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
    temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE

    if not os.path.isdir(audio_folder):
        return f"Error: Folder does not exist: {audio_folder}"

    # Determine output folder with environment variable support
    if output_folder is not None:
        # Use provided parameter
        pass
    elif DEFAULT_BATCH_OUTPUT_DIR is not None:
        # Use environment variable
        output_folder = DEFAULT_BATCH_OUTPUT_DIR
    else:
        # Use default subfolder
        output_folder = os.path.join(audio_folder, "transcript")

    # Ensure output directory exists
    os.makedirs(output_folder, exist_ok=True)

    # Validate output format
    valid_formats = ["txt", "vtt", "srt", "json"]
    if output_format.lower() not in valid_formats:
        return f"Error: Unsupported output format: {output_format}. Supported formats: {', '.join(valid_formats)}"

    # Get all audio files
    audio_files = []
    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]

@@ -225,37 +278,35 @@ def batch_transcribe(
            audio_files.append(os.path.join(audio_folder, filename))

    if not audio_files:
        return f"No supported audio files found in {audio_folder}. Supported formats: {', '.join(supported_formats)}"

    # Record start time
    start_time = time.time()
    total_files = len(audio_files)
    logger.info(f"Starting batch transcription of {total_files} files, output format: {output_format}")

    # Preload model
    try:
        get_whisper_model(model_name, device, compute_type)
        logger.info(f"Model preloaded: {model_name}")
    except Exception as e:
        logger.error(f"Failed to preload model: {str(e)}")
        return f"Batch processing failed: Cannot load model {model_name}: {str(e)}"

    # Process files
    results = []
    success_count = 0
    error_count = 0
    total_audio_duration = 0

    for i, audio_path in enumerate(audio_files):
        file_name = os.path.basename(audio_path)
        elapsed = time.time() - start_time

        # Report progress
        progress_msg = report_progress(i, total_files, elapsed)
        logger.info(f"{progress_msg} | Currently processing: {file_name}")

        # Execute transcription
        try:
            result = transcribe_audio(
                audio_path=audio_path,

@@ -270,57 +321,39 @@ def batch_transcribe(
                output_directory=output_folder
            )

            # Check whether the result contains an error message
            if result.startswith("Error:") or result.startswith("Error occurred during transcription:"):
                logger.error(f"Transcription failed: {file_name} - {result}")
                results.append(f"❌ Failed: {file_name} - {result}")
                error_count += 1
                continue

            # On success, extract the output path from the returned message
            output_path = result.split(": ")[1] if ": " in result else "Unknown path"
            success_count += 1
            results.append(f"✅ Success: {file_name} -> {os.path.basename(output_path)}")

        except Exception as e:
            logger.error(f"Error occurred during transcription process: {file_name} - {str(e)}")
            results.append(f"❌ Failed: {file_name} - {str(e)}")
            error_count += 1

    # Calculate total transcription time
    total_transcription_time = time.time() - start_time

    # Generate summary
    summary = f"Batch processing completed, total transcription time: {format_time(total_transcription_time)}"
    summary += f" | Success: {success_count}/{total_files}"
    summary += f" | Failed: {error_count}/{total_files}"

    # Output results
    for result in results:
        logger.info(result)

    return summary


def report_progress(current: int, total: int, elapsed_time: float) -> str:
    """Generate progress report"""
    progress = current / total * 100
    eta = (elapsed_time / current) * (total - current) if current > 0 else 0
    return (f"Progress: {current}/{total} ({progress:.1f}%)" +
            f" | Time used: {format_time(elapsed_time)}" +
            f" | Estimated remaining: {format_time(eta)}")
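A batch-mode usage sketch (the folder path is illustrative; any parameter left unset falls back to the TRANSCRIPTION_* environment defaults above):

```python
from transcriber import batch_transcribe

summary = batch_transcribe(
    audio_folder="/data/podcasts",  # illustrative folder
    output_format="txt",
    language="en",
)
print(summary)  # e.g. "Batch processing completed, ... | Success: 3/3 | Failed: 0/3"
```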
whisper_server.py

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Faster Whisper-based Speech Recognition MCP Service
Provides high-performance audio transcription with batch processing acceleration and multiple output formats
"""

import os

@@ -11,11 +11,11 @@ from mcp.server.fastmcp import FastMCP
from model_manager import get_model_info
from transcriber import transcribe_audio, batch_transcribe

# Log configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create FastMCP server instance
mcp = FastMCP(
    name="fast-whisper-mcp-server",
    version="0.1.1",

@@ -25,7 +25,7 @@ mcp = FastMCP(
@mcp.tool()
def get_model_info_api() -> str:
    """
    Get available Whisper model information
    """
    return get_model_info()

@@ -35,22 +35,22 @@ def transcribe(audio_path: str, model_name: str = "large-v3", device: str = "aut
              beam_size: int = 5, temperature: float = 0.0, initial_prompt: str = None,
              output_directory: str = None) -> str:
    """
    Transcribe audio files using Faster Whisper

    Args:
        audio_path: Path to the audio file
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)
        language: Language code (such as zh, en, ja, etc., auto-detect by default)
        output_format: Output format (vtt, srt, json or txt)
        beam_size: Beam search size, larger values may improve accuracy but reduce speed
        temperature: Sampling temperature, greedy decoding
        initial_prompt: Initial prompt text, can help the model better understand context
        output_directory: Output directory path, defaults to the audio file's directory

    Returns:
        str: Transcription result, in VTT subtitle or JSON format
    """
    return transcribe_audio(
        audio_path=audio_path,

@@ -71,23 +71,23 @@ def batch_transcribe_audio(audio_folder: str, output_folder: str = None, model_n
                           output_format: str = "vtt", beam_size: int = 5, temperature: float = 0.0,
                           initial_prompt: str = None, parallel_files: int = 1) -> str:
    """
    Batch transcribe audio files in a folder

    Args:
        audio_folder: Path to the folder containing audio files
        output_folder: Output folder path, defaults to a 'transcript' subfolder in audio_folder
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Execution device (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)
        language: Language code (such as zh, en, ja, etc., auto-detect by default)
        output_format: Output format (vtt, srt, json or txt)
        beam_size: Beam search size, larger values may improve accuracy but reduce speed
        temperature: Sampling temperature, 0 means greedy decoding
        initial_prompt: Initial prompt text, can help the model better understand context
        parallel_files: Number of files to process in parallel (only effective in CPU mode)

    Returns:
        str: Batch processing summary, including processing time and success rate
    """
    return batch_transcribe(
        audio_folder=audio_folder,

@@ -104,5 +104,5 @@ def batch_transcribe_audio(audio_folder: str, output_folder: str = None, model_n
    )

if __name__ == "__main__":
    print("starting mcp server for whisper stt transcriptor")
    mcp.run()