Fast-Whisper-MCP-Server/whisper_server.py
BigUncleHomePC 5b5b952382 feat: initialize a Faster Whisper-based speech recognition MCP server
Added the core server code, startup scripts, dependency configuration, and documentation; supports batched-inference acceleration, CUDA optimization, and multi-format output for easy integration into Claude Desktop.
2025-03-22 03:23:54 +08:00

#!/usr/bin/env python3
"""
Speech recognition MCP server based on Faster Whisper
"""
import os
import json
import logging
from typing import Optional, Dict

import torch
from faster_whisper import WhisperModel, BatchedInferencePipeline
from mcp.server.fastmcp import FastMCP

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the FastMCP server instance
mcp = FastMCP(
    name="whisper-server",
    version="0.1.0",
    dependencies=["faster-whisper>=0.9.0", "torch==2.6.0+cu126", "torchaudio==2.6.0+cu126"]
)

# Global cache of model instances
model_instances = {}


@mcp.tool()
def get_model_info() -> str:
    """Get information about the available Whisper models."""
    models = [
        "tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
    ]
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    compute_types = ["float16", "int8"] if torch.cuda.is_available() else ["int8"]

    info = {
        "available_models": models,
        "default_model": "large-v3",
        "available_devices": devices,
        "default_device": "cuda" if torch.cuda.is_available() else "cpu",
        "available_compute_types": compute_types,
        "default_compute_type": "float16" if torch.cuda.is_available() else "int8",
        "cuda_available": torch.cuda.is_available()
    }

    if torch.cuda.is_available():
        info["gpu_info"] = {
            "name": torch.cuda.get_device_name(0),
            "memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
        }

    return json.dumps(info, indent=2)
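
# For reference, the JSON returned above has roughly this shape on a CUDA machine
# (the GPU name and memory figure are illustrative placeholders, not real output):
#
#   {
#     "available_models": ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"],
#     "default_model": "large-v3",
#     "available_devices": ["cpu", "cuda"],
#     "default_device": "cuda",
#     "available_compute_types": ["float16", "int8"],
#     "default_compute_type": "float16",
#     "cuda_available": true,
#     "gpu_info": {"name": "<GPU name>", "memory_total": "24.00 GB"}
#   }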


def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict:
    """
    Get or create a Whisper model instance.

    Args:
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Device to run on (cpu, cuda, auto)
        compute_type: Compute type (float16, int8, auto)

    Returns:
        dict: Dictionary containing the model instance and its configuration
    """
    global model_instances

    # Resolve "auto" settings before building the cache key so that
    # equivalent configurations share a single model instance
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if compute_type == "auto":
        compute_type = "float16" if device == "cuda" else "int8"

    # Build the cache key
    model_key = f"{model_name}_{device}_{compute_type}"

    # Return the cached instance if this model has already been loaded
    if model_key in model_instances:
        return model_instances[model_key]

    # Free cached GPU memory before loading onto CUDA
    if device == "cuda":
        torch.cuda.empty_cache()

    # Instantiate the model
    try:
        logger.info(f"Loading Whisper model: {model_name}, device: {device}, compute type: {compute_type}")

        # Base model
        model = WhisperModel(
            model_name,
            device=device,
            compute_type=compute_type
        )

        # Batching setup - enabled by default on CUDA to speed up transcription
        batched_model = None
        batch_size = 0
        if device == "cuda":  # Only use batching on CUDA devices
            # Pick a batch size that fits the available GPU memory
            if torch.cuda.is_available():
                gpu_mem = torch.cuda.get_device_properties(0).total_memory
                # Scale the batch size with GPU memory
                if gpu_mem > 16e9:    # > 16 GB
                    batch_size = 32
                elif gpu_mem > 12e9:  # > 12 GB
                    batch_size = 16
                elif gpu_mem > 8e9:   # > 8 GB
                    batch_size = 8
                else:                 # Smaller GPUs
                    batch_size = 4
            else:
                batch_size = 8  # Fallback default

            logger.info(f"Batched inference enabled, batch size: {batch_size}")
            batched_model = BatchedInferencePipeline(model=model)

        # Assemble the result object
        result = {
            'model': model,
            'device': device,
            'compute_type': compute_type,
            'batched_model': batched_model,
            'batch_size': batch_size
        }

        # Cache the instance
        model_instances[model_key] = result
        return result
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise
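
# Example: on a machine with a CUDA GPU, get_whisper_model("large-v3", "auto", "auto")
# resolves to the cache key "large-v3_cuda_float16", so repeated calls reuse one model.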


@mcp.tool()
def transcribe(audio_path: str, model_name: str = "large-v3", device: str = "auto",
               compute_type: str = "auto", language: Optional[str] = None,
               output_format: str = "vtt") -> str:
    """
    Transcribe an audio file with Faster Whisper.

    Args:
        audio_path: Path to the audio file
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Device to run on (cpu, cuda, auto)
        compute_type: Compute type (float16, int8, auto)
        language: Language code (e.g. zh, en, ja; auto-detected by default)
        output_format: Output format (vtt or json)

    Returns:
        str: Transcription result as VTT subtitles or JSON
    """
    # Validate arguments
    if not os.path.exists(audio_path):
        return f"Error: audio file does not exist: {audio_path}"

    try:
        # Get the model instance
        model_instance = get_whisper_model(model_name, device, compute_type)

        # Transcription options
        options = {
            "language": language,
            "vad_filter": True,  # Use voice activity detection
            "vad_parameters": {"min_silence_duration_ms": 500},  # Tuned VAD parameters
        }

        # Run transcription - prefer the batched model when available
        if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
            logger.info("Transcribing with batched inference...")
            # The batched pipeline takes an explicit batch_size argument
            segments, info = model_instance['batched_model'].transcribe(
                audio_path,
                batch_size=model_instance['batch_size'],
                **options
            )
        else:
            logger.info("Transcribing with the standard model...")
            segments, info = model_instance['model'].transcribe(audio_path, **options)

        # Materialize the generator into a list
        segment_list = list(segments)
        if not segment_list:
            return "Transcription failed: no segments returned"

        # Return the result in the requested format
        if output_format.lower() == "vtt":
            return format_vtt(segment_list)
        else:
            return format_json(segment_list, info)
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return f"An error occurred during transcription: {str(e)}"


def format_vtt(segments) -> str:
    """Format transcription segments as WebVTT."""
    vtt_content = "WEBVTT\n\n"

    for segment in segments:
        start = format_timestamp(segment.start)
        end = format_timestamp(segment.end)
        text = segment.text.strip()
        if text:
            vtt_content += f"{start} --> {end}\n{text}\n\n"

    return vtt_content
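
# For reference, the output produced above looks like this (timestamps and text are illustrative):
#
#   WEBVTT
#
#   00:00:00.000 --> 00:00:02.500
#   Hello, and welcome to the show.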


def format_json(segments, info) -> str:
    """Format transcription segments as JSON."""
    result = {
        "segments": [{
            "start": segment.start,
            "end": segment.end,
            "text": segment.text
        } for segment in segments],
        "language": info.language,
        "duration": info.duration
    }
    return json.dumps(result, indent=2, ensure_ascii=False)


def format_timestamp(seconds: float) -> str:
    """Format a timestamp in seconds as a VTT timestamp (HH:MM:SS.mmm)."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"


@mcp.tool()
def batch_transcribe(audio_folder: str, output_folder: Optional[str] = None, model_name: str = "large-v3",
                     device: str = "auto", compute_type: str = "auto") -> str:
    """
    Batch-transcribe all audio files in a folder.

    Args:
        audio_folder: Path to the folder containing audio files
        output_folder: Output folder path (defaults to a "transcript" subfolder of audio_folder)
        model_name: Model name
        device: Device to run on
        compute_type: Compute type

    Returns:
        str: Summary of the batch run
    """
    if not os.path.isdir(audio_folder):
        return f"Error: folder does not exist: {audio_folder}"

    # Resolve the output folder
    if output_folder is None:
        output_folder = os.path.join(audio_folder, "transcript")

    # Make sure the output directory exists
    os.makedirs(output_folder, exist_ok=True)

    # Collect all audio files
    audio_files = []
    for filename in os.listdir(audio_folder):
        if filename.lower().endswith(('.mp3', '.wav', '.m4a', '.flac')):
            audio_files.append(os.path.join(audio_folder, filename))

    if not audio_files:
        return f"No audio files found in {audio_folder}"

    # Process each file
    results = []
    success_count = 0
    for i, audio_path in enumerate(audio_files):
        logger.info(f"Processing file {i+1}/{len(audio_files)}: {os.path.basename(audio_path)}")

        # Build the output file path
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        vtt_path = os.path.join(output_folder, f"{base_name}.vtt")

        # Run transcription
        result = transcribe(
            audio_path=audio_path,
            model_name=model_name,
            device=device,
            compute_type=compute_type,
            output_format="vtt"
        )

        # Save the result to a file
        with open(vtt_path, 'w', encoding='utf-8') as f:
            f.write(result)

        # Only count results that are valid VTT output as successes
        if result.startswith("WEBVTT"):
            success_count += 1
            results.append(f"Transcribed: {os.path.basename(audio_path)} -> {os.path.basename(vtt_path)}")
        else:
            results.append(f"Failed: {os.path.basename(audio_path)} ({result})")

    summary = f"Batch processing complete: {success_count}/{len(audio_files)} files transcribed successfully\n\n"
    summary += "\n".join(results)
    return summary
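
# Illustrative call (the folder path is a hypothetical example):
#   batch_transcribe("/data/podcasts")
# writes one .vtt file per audio file into /data/podcasts/transcript and returns a summary string.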


if __name__ == "__main__":
    # Run the server
    mcp.run()
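
# The commit message notes that this server is intended to be used from Claude Desktop.
# A minimal claude_desktop_config.json entry might look like the sketch below; the server
# name "whisper" and the script path are illustrative assumptions, not part of this file:
#
#   {
#     "mcpServers": {
#       "whisper": {
#         "command": "python",
#         "args": ["/path/to/Fast-Whisper-MCP-Server/whisper_server.py"]
#       }
#     }
#   }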