#!/usr/bin/env python3
"""
Model Management Module

Responsible for loading, caching, and managing Whisper models
"""

import os
import time
import logging
from typing import Dict, Any

import torch
from faster_whisper import WhisperModel, BatchedInferencePipeline

# Logging configuration
logger = logging.getLogger(__name__)

# Global model instance cache
model_instances = {}


def test_gpu_driver():
    """Simple GPU driver test"""
    try:
        if not torch.cuda.is_available():
            logger.error("CUDA not available in PyTorch")
            raise RuntimeError("CUDA not available")

        gpu_count = torch.cuda.device_count()
        if gpu_count == 0:
            logger.error("No CUDA devices found")
            raise RuntimeError("No CUDA devices")

        # Quick GPU test
        test_tensor = torch.randn(10, 10).cuda()
        _ = test_tensor @ test_tensor.T

        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"GPU test passed: {device_name} ({memory_gb:.1f}GB)")

    except Exception as e:
        logger.error(f"GPU test failed: {e}")
        raise RuntimeError(f"GPU initialization failed: {e}")


def get_whisper_model(model_name: str, device: str, compute_type: str) -> Dict[str, Any]:
    """
    Get or create a Whisper model instance

    Args:
        model_name: Model name (tiny, base, small, medium, large-v1, large-v2, large-v3)
        device: Device to run on (cpu, cuda, auto)
        compute_type: Computation type (float16, int8, auto)

    Returns:
        dict: Dictionary containing the model instance and its configuration
    """
    global model_instances

    # Validate model name
    valid_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
    if model_name not in valid_models:
        raise ValueError(f"Invalid model name: {model_name}. Valid models: {', '.join(valid_models)}")

    # Auto-detect device and compute type
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "int8"
    elif compute_type == "auto":
        # Pick a sensible default for an explicitly chosen device
        compute_type = "float16" if device == "cuda" else "int8"

    # Validate device and compute type
    if device not in ["cpu", "cuda"]:
        raise ValueError(f"Invalid device: {device}. Valid devices: cpu, cuda")

    if device == "cuda" and not torch.cuda.is_available():
        logger.error("CUDA requested but not available")
        raise RuntimeError("CUDA not available but explicitly requested")

    if compute_type not in ["float16", "int8"]:
        raise ValueError(f"Invalid compute type: {compute_type}. Valid compute types: float16, int8")

    if device == "cpu" and compute_type == "float16":
        logger.warning("CPU device does not support float16 computation type, automatically switching to int8")
        compute_type = "int8"

    # Generate model key
    model_key = f"{model_name}_{device}_{compute_type}"

    # If the model is already instantiated, return it directly
    if model_key in model_instances:
        logger.info(f"Using cached model instance: {model_key}")
        return model_instances[model_key]

    # Test the GPU driver and clear the CUDA cache before loading the model
    if device == "cuda":
        test_gpu_driver()
        torch.cuda.empty_cache()

    # Instantiate model
    try:
        logger.info(f"Loading Whisper model: {model_name} device: {device} compute type: {compute_type}")

        # Base model
        model = WhisperModel(
            model_name,
            device=device,
            compute_type=compute_type,
            download_root=os.environ.get("WHISPER_MODEL_DIR", None)  # Support custom model directory
        )

        # Batched inference settings - enabled by default on CUDA to improve speed
        batched_model = None
        batch_size = 0

        if device == "cuda":  # Only use batch processing on CUDA devices
            # Determine an appropriate batch size based on available GPU memory
            if torch.cuda.is_available():
                gpu_mem = torch.cuda.get_device_properties(0).total_memory
                free_mem = gpu_mem - torch.cuda.memory_allocated()
                # Dynamically adjust batch size based on GPU memory
                if free_mem > 16e9:  # >16GB
                    batch_size = 32
                elif free_mem > 12e9:  # >12GB
                    batch_size = 16
                elif free_mem > 8e9:  # >8GB
                    batch_size = 8
                elif free_mem > 4e9:  # >4GB
                    batch_size = 4
                else:  # Smaller memory
                    batch_size = 2

                logger.info(f"Available GPU memory: {free_mem / 1e9:.2f} GB")
            else:
                batch_size = 8  # Default value

            logger.info(f"Enabling batch processing acceleration, batch size: {batch_size}")
            batched_model = BatchedInferencePipeline(model=model)

|
result = {
|
|
'model': model,
|
|
'device': device,
|
|
'compute_type': compute_type,
|
|
'batched_model': batched_model,
|
|
'batch_size': batch_size,
|
|
'load_time': time.time()
|
|
}
|
|
|
|
# Cache instance
|
|
model_instances[model_key] = result
|
|
return result
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load model: {str(e)}")
|
|
raise
|
|
|
|
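# Illustrative use of get_whisper_model (kept as a comment so it is not executed
# at import time; "audio.mp3" is a placeholder path, and the batched pipeline is
# only present when the model was loaded on CUDA):
#
#     bundle = get_whisper_model("large-v3", "auto", "auto")
#     asr = bundle["batched_model"] or bundle["model"]
#     segments, info = asr.transcribe("audio.mp3")
#     text = " ".join(segment.text for segment in segments)

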
def get_model_info() -> str:
    """
    Get available Whisper model information

    Returns:
        str: JSON string of model information
    """
    import json

    models = [
        "tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"
    ]
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    compute_types = ["float16", "int8"] if torch.cuda.is_available() else ["int8"]

    # Supported language list
    languages = {
        "zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
        "fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
        "pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
        "tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
    }

    # Supported audio formats
    audio_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]

    info = {
        "available_models": models,
        "default_model": "large-v3",
        "available_devices": devices,
        "default_device": "cuda" if torch.cuda.is_available() else "cpu",
        "available_compute_types": compute_types,
        "default_compute_type": "float16" if torch.cuda.is_available() else "int8",
        "cuda_available": torch.cuda.is_available(),
        "supported_languages": languages,
        "supported_audio_formats": audio_formats,
        "version": "0.1.1"
    }

    if torch.cuda.is_available():
        info["gpu_info"] = {
            "name": torch.cuda.get_device_name(0),
            "memory_total": f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
            "memory_available": f"{torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB"
        }

    return json.dumps(info, indent=2)
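

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the module's public API).
    # Assumptions: faster-whisper weights can be downloaded or already live under
    # WHISPER_MODEL_DIR; "tiny" is chosen only to keep the demo lightweight.
    logging.basicConfig(level=logging.INFO)

    print(get_model_info())

    bundle = get_whisper_model("tiny", "auto", "auto")
    logger.info(
        "Loaded model on %s (%s), batch_size=%d",
        bundle["device"], bundle["compute_type"], bundle["batch_size"],
    )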