refactor(whisper_server): 重构代码以模块化转录功能
将转录核心逻辑拆分为独立模块(transcriber.py、model_manager.py、audio_processor.py、formatters.py),提升代码可维护性和复用性。删除main.py文件,优化依赖管理并更新requirements.txt和pyproject.toml。
This commit is contained in:
176
model_manager.py
Normal file
176
model_manager.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
模型管理模块
|
||||
负责Whisper模型的加载、缓存和管理
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
import torch
|
||||
from faster_whisper import WhisperModel, BatchedInferencePipeline
|
||||
|
||||
# 日志配置
|
||||
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局模型实例缓存
|
||||
# Process-wide cache of loaded models. Keys are "<model_name>_<device>_<compute_type>"
# strings; values are the result dicts built by get_whisper_model().
model_instances = {}
|
||||
|
||||
def _select_batch_size(free_mem: float) -> int:
    """Map the amount of free GPU memory (in bytes) to an inference batch size."""
    # Thresholds are checked from largest to smallest; the first one exceeded wins.
    for threshold, size in ((16e9, 32), (12e9, 16), (8e9, 8), (4e9, 4)):
        if free_mem > threshold:
            return size
    return 2  # fallback for GPUs with 4 GB or less free


def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
    """
    Return a Whisper model bundle, loading and caching it on first use.

    Args:
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2,
            large-v3).
        device: Target device: "cpu", "cuda" or "auto". "auto" selects CUDA
            when it is available and also overrides ``compute_type``
            (float16 on CUDA, int8 on CPU).
        compute_type: "float16", "int8" or "auto". "auto" resolves to float16
            on CUDA and int8 on CPU.

    Returns:
        dict: Keys ``model``, ``device``, ``compute_type``, ``batched_model``,
        ``batch_size`` and ``load_time``. ``batched_model`` is ``None`` and
        ``batch_size`` is 0 when running on CPU. The same dict object is
        returned for repeated calls that resolve to the same configuration.

    Raises:
        ValueError: If the model name, device or compute type is not supported.
        Exception: Any error raised while loading the model is logged and
            re-raised unchanged.
    """
    # Validate the model name
    valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
    if model_name not in valid_models:
        raise ValueError(f"无效的模型名称: {model_name}。有效的模型: {', '.join(valid_models)}")

    # Auto-detect the device; this path intentionally also picks the compute type
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"

    # Validate the device
    if device not in ["cpu", "cuda"]:
        raise ValueError(f"无效的设备: {device}。有效的设备: cpu, cuda")

    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA不可用,自动切换到CPU")
        device = "cpu"
        compute_type = "int8"

    # Resolve an "auto" compute type for an explicitly chosen device. Without
    # this, the documented "auto" value was rejected by the check below.
    if compute_type == "auto":
        compute_type = "float16" if device == "cuda" else "int8"

    # Validate the compute type
    if compute_type not in ["float16", "int8"]:
        raise ValueError(f"无效的计算类型: {compute_type}。有效的计算类型: float16, int8")

    if device == "cpu" and compute_type == "float16":
        logger.warning("CPU设备不支持float16计算类型,自动切换到int8")
        compute_type = "int8"

    # Cache key built from the fully resolved configuration
    model_key = f"{model_name}_{device}_{compute_type}"

    # Reuse an already loaded instance
    if model_key in model_instances:
        logger.info(f"使用缓存的模型实例: {model_key}")
        return model_instances[model_key]

    # Release cached PyTorch GPU memory before loading (CUDA only)
    if device == "cuda":
        torch.cuda.empty_cache()

    try:
        logger.info(f"加载Whisper模型: {model_name} 设备: {device} 计算类型: {compute_type}")

        # Base model
        model = WhisperModel(
            model_name,
            device=device,
            compute_type=compute_type,
            download_root=os.environ.get("WHISPER_MODEL_DIR", None),  # optional custom model directory
        )

        # Batched inference is only enabled on CUDA
        batched_model = None
        batch_size = 0

        if device == "cuda":
            # Ask the driver for the free memory on device 0.
            # torch.cuda.memory_allocated() only counts this process's PyTorch
            # tensor allocations, so subtracting it from the total overstates
            # what is actually free.
            free_mem, _total_mem = torch.cuda.mem_get_info(0)
            batch_size = _select_batch_size(free_mem)

            logger.info(f"可用GPU显存: {free_mem / 1e9:.2f} GB")
            logger.info(f"启用批处理加速,批大小: {batch_size}")
            batched_model = BatchedInferencePipeline(model=model)

        # Bundle the model together with its resolved configuration
        result = {
            'model': model,
            'device': device,
            'compute_type': compute_type,
            'batched_model': batched_model,
            'batch_size': batch_size,
            'load_time': time.time(),
        }

        # Cache the instance for later calls
        model_instances[model_key] = result
        return result

    except Exception as e:
        logger.error(f"加载模型失败: {str(e)}")
        raise
|
||||
|
||||
def get_model_info() -> str:
    """
    Describe the available Whisper models and the runtime environment.

    Returns:
        str: A JSON string (2-space indent) listing the available models,
        devices and compute types with their defaults, CUDA availability,
        supported languages and audio formats, and a version string. When
        CUDA is available a ``gpu_info`` entry with the name, total memory
        and available memory of device 0 is included.
    """
    import json

    # Query CUDA availability once and reuse the answer everywhere below
    cuda_available = torch.cuda.is_available()

    models = [
        "tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
    ]
    devices = ["cpu", "cuda"] if cuda_available else ["cpu"]
    compute_types = ["float16", "int8"] if cuda_available else ["int8"]

    # Supported languages (language code -> display name)
    languages = {
        "zh": "中文", "en": "英语", "ja": "日语", "ko": "韩语", "de": "德语",
        "fr": "法语", "es": "西班牙语", "ru": "俄语", "it": "意大利语",
        "pt": "葡萄牙语", "nl": "荷兰语", "ar": "阿拉伯语", "hi": "印地语",
        "tr": "土耳其语", "vi": "越南语", "th": "泰语", "id": "印尼语"
    }

    # Supported audio file extensions
    audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]

    info = {
        "available_models": models,
        "default_model": "large-v3",
        "available_devices": devices,
        "default_device": "cuda" if cuda_available else "cpu",
        "available_compute_types": compute_types,
        "default_compute_type": "float16" if cuda_available else "int8",
        "cuda_available": cuda_available,
        "supported_languages": languages,
        "supported_audio_formats": audio_formats,
        "version": "0.1.1"
    }

    if cuda_available:
        total_mem = torch.cuda.get_device_properties(0).total_memory
        # Ask the driver for the free memory on device 0.
        # torch.cuda.memory_allocated() only counts this process's PyTorch
        # tensor allocations, so "total - allocated" overstates what is free.
        free_mem, _ = torch.cuda.mem_get_info(0)
        info["gpu_info"] = {
            "name": torch.cuda.get_device_name(0),
            "memory_total": f"{total_mem / 1e9:.2f} GB",
            "memory_available": f"{free_mem / 1e9:.2f} GB"
        }

    return json.dumps(info, indent=2)
|
||||
Reference in New Issue
Block a user