360 lines
14 KiB
Python
360 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Transcription Core Module with Environment Variable Support
|
|
Contains core logic for audio transcription
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import logging
|
|
from typing import Dict, Any, Tuple, List, Optional, Union
|
|
|
|
from model_manager import get_whisper_model
|
|
from audio_processor import validate_audio_file, process_audio
|
|
from formatters import format_vtt, format_srt, format_json, format_txt, format_time
|
|
|
|
# Logging configuration
# Module-level logger; handlers/levels are expected to be configured by the application.
logger = logging.getLogger(__name__)


# Environment variable defaults
# Each TRANSCRIPTION_* variable overrides the hard-coded fallback; function
# arguments in transcribe_audio()/batch_transcribe() override both.

# Directory for single-file transcripts; None -> fall back to the audio file's directory.
DEFAULT_OUTPUT_DIR = os.getenv('TRANSCRIPTION_OUTPUT_DIR', None)
# Directory for batch transcripts; None -> fall back to a "transcript" subfolder.
DEFAULT_BATCH_OUTPUT_DIR = os.getenv('TRANSCRIPTION_BATCH_OUTPUT_DIR', None)
# Whisper model size/name passed to get_whisper_model().
DEFAULT_MODEL = os.getenv('TRANSCRIPTION_MODEL', 'base')
# Execution device passed to get_whisper_model() (batched path requires 'cuda').
DEFAULT_DEVICE = os.getenv('TRANSCRIPTION_DEVICE', 'cuda')
# NOTE(review): 'base' looks like a model size, not a faster-whisper compute
# type (e.g. 'auto', 'int8', 'float16') — confirm what get_whisper_model()
# accepts here before relying on this fallback.
DEFAULT_COMPUTE_TYPE = os.getenv('TRANSCRIPTION_COMPUTE_TYPE', 'base')
# Language code; None -> automatic language detection.
DEFAULT_LANGUAGE = os.getenv('TRANSCRIPTION_LANGUAGE', None)
# One of: txt, vtt, srt, json (validated at call time).
DEFAULT_OUTPUT_FORMAT = os.getenv('TRANSCRIPTION_OUTPUT_FORMAT', 'txt')
# Beam search width; int() raises ValueError at import time on a non-integer env value.
DEFAULT_BEAM_SIZE = int(os.getenv('TRANSCRIPTION_BEAM_SIZE', '5'))
# Sampling temperature; float() raises ValueError at import time on a bad env value.
DEFAULT_TEMPERATURE = float(os.getenv('TRANSCRIPTION_TEMPERATURE', '0.0'))


# Model storage configuration
# Optional model cache directory; not read in this module — presumably
# consumed by model_manager (TODO confirm).
WHISPER_MODEL_DIR = os.getenv('WHISPER_MODEL_DIR', None)


# File naming configuration
# Output filenames are assembled as: [prefix_]basename[_suffix][_timestamp].ext
# When true, append a YYYYmmddHHMMSS timestamp to output filenames.
USE_TIMESTAMP = os.getenv('TRANSCRIPTION_USE_TIMESTAMP', 'false').lower() == 'true'
# Optional filename prefix (joined with '_').
FILENAME_PREFIX = os.getenv('TRANSCRIPTION_FILENAME_PREFIX', '')
# Optional filename suffix (joined with '_').
FILENAME_SUFFIX = os.getenv('TRANSCRIPTION_FILENAME_SUFFIX', '')
|
|
|
|
# Language codes accepted for explicit selection; any other code falls back
# to Whisper's automatic language detection (with a warning).
_SUPPORTED_LANGUAGES = {
    "zh": "Chinese", "en": "English", "ja": "Japanese", "ko": "Korean", "de": "German",
    "fr": "French", "es": "Spanish", "ru": "Russian", "it": "Italian",
    "pt": "Portuguese", "nl": "Dutch", "ar": "Arabic", "hi": "Hindi",
    "tr": "Turkish", "vi": "Vietnamese", "th": "Thai", "id": "Indonesian"
}

# Output formats this module can render.
_VALID_OUTPUT_FORMATS = ("vtt", "srt", "txt", "json")


def _format_segments(segment_list, info, output_format_lower: str) -> str:
    """Render transcribed segments in the requested (lower-cased) output format.

    Raises:
        ValueError: if the format is not one of vtt/srt/txt/json.
    """
    if output_format_lower == "vtt":
        return format_vtt(segment_list)
    if output_format_lower == "srt":
        return format_srt(segment_list)
    if output_format_lower == "txt":
        return format_txt(segment_list)
    if output_format_lower == "json":
        # JSON output also embeds transcription metadata from `info`.
        return format_json(segment_list, info)
    raise ValueError(
        f"Unsupported output format: {output_format_lower}. Supported formats: vtt, srt, txt, json"
    )


def _build_output_filename(base_name: str, extension: str) -> str:
    """Assemble [prefix_]base_name[_suffix][_timestamp].extension from env config."""
    parts = []
    if FILENAME_PREFIX:
        parts.append(FILENAME_PREFIX)
    parts.append(base_name)
    if FILENAME_SUFFIX:
        parts.append(FILENAME_SUFFIX)
    if USE_TIMESTAMP:
        parts.append(time.strftime("%Y%m%d%H%M%S"))
    return f"{'_'.join(parts)}.{extension}"


def transcribe_audio(
    audio_path: str,
    model_name: str = None,
    device: str = None,
    compute_type: str = None,
    language: str = None,
    output_format: str = None,
    beam_size: int = None,
    temperature: float = None,
    initial_prompt: str = None,
    output_directory: str = None
) -> str:
    """
    Transcribe an audio file using Faster Whisper, with environment-variable defaults.

    Args:
        audio_path: Path to the audio file.
        model_name: Model name (defaults to TRANSCRIPTION_MODEL env var, or "base").
        device: Execution device (defaults to TRANSCRIPTION_DEVICE env var, or "cuda").
        compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var, or "base").
        language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var, or auto-detect).
        output_format: One of vtt/srt/txt/json, case-insensitive
            (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var, or "txt").
        beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var, or 5).
        temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var, or 0.0).
        initial_prompt: Optional initial prompt text.
        output_directory: Output directory (defaults to TRANSCRIPTION_OUTPUT_DIR env var,
            or the audio file's own directory).

    Returns:
        str: Human-readable status message; on success it contains the saved transcript path.
            Errors are reported as message strings, not exceptions.
    """
    # Apply environment variable defaults (explicit arguments win).
    model_name = model_name or DEFAULT_MODEL
    device = device or DEFAULT_DEVICE
    compute_type = compute_type or DEFAULT_COMPUTE_TYPE
    language = language or DEFAULT_LANGUAGE
    output_format = output_format or DEFAULT_OUTPUT_FORMAT
    # beam_size/temperature use `is not None` so legitimate 0 values survive.
    beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
    temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE

    # Validate audio file; validate_audio_file returns "ok" or an error message.
    validation_result = validate_audio_file(audio_path)
    if validation_result != "ok":
        return validation_result

    # Fail fast on an unsupported output format BEFORE paying for model
    # loading and transcription (previously this was only detected afterwards).
    output_format_lower = output_format.lower()
    if output_format_lower not in _VALID_OUTPUT_FORMATS:
        return (f"Error occurred during transcription: Unsupported output format: "
                f"{output_format}. Supported formats: vtt, srt, txt, json")

    try:
        # Get (cached) model instance from the model manager.
        model_instance = get_whisper_model(model_name, device, compute_type)

        # Unknown language codes degrade to auto-detection instead of failing.
        if language is not None and language not in _SUPPORTED_LANGUAGES:
            logger.warning(f"Unknown language code: {language}, will use auto-detection")
            language = None

        # Transcription parameters.
        # NOTE: the openai-whisper-style "verbose" flag was removed here —
        # faster-whisper's transcribe() has no such parameter and would raise
        # TypeError, making every transcription fail.
        options = {
            "language": language,
            "vad_filter": True,  # voice-activity detection to skip long silences
            "vad_parameters": {"min_silence_duration_ms": 500},
            "beam_size": beam_size,
            "temperature": temperature,
            "initial_prompt": initial_prompt,
            "word_timestamps": True,
            "suppress_tokens": [-1],
            "condition_on_previous_text": True,
            "compression_ratio_threshold": 2.4,
        }

        start_time = time.time()
        logger.info(f"Starting transcription of file: {os.path.basename(audio_path)}")

        # Pre-process audio into the representation the model expects.
        audio_source = process_audio(audio_path)

        # Prefer the batched pipeline when available and running on CUDA.
        if model_instance['batched_model'] is not None and model_instance['device'] == 'cuda':
            logger.info("Using batch acceleration for transcription...")
            segments, info = model_instance['batched_model'].transcribe(
                audio_source,
                batch_size=model_instance['batch_size'],
                **options
            )
        else:
            logger.info("Using standard model for transcription...")
            segments, info = model_instance['model'].transcribe(audio_source, **options)

        # transcribe() returns a lazy generator; materialize it so the actual
        # decoding work happens (and can fail) inside this try block.
        segment_list = list(segments)
        if not segment_list:
            return "Transcription failed, no results obtained"

        elapsed_time = time.time() - start_time
        logger.info(f"Transcription completed, time used: {elapsed_time:.2f} seconds, detected language: {info.language}, audio length: {info.duration:.2f} seconds")

        # Render segments in the requested format.
        transcription_result = _format_segments(segment_list, info, output_format_lower)

        # Output directory priority: explicit parameter > env var > audio file's directory.
        audio_filename = os.path.splitext(os.path.basename(audio_path))[0]
        if output_directory is not None:
            output_dir = output_directory
        elif DEFAULT_OUTPUT_DIR is not None:
            output_dir = DEFAULT_OUTPUT_DIR
        else:
            output_dir = os.path.dirname(audio_path)
        os.makedirs(output_dir, exist_ok=True)

        output_path = os.path.join(
            output_dir, _build_output_filename(audio_filename, output_format_lower)
        )

        # Write transcription results to file; a save failure is reported but
        # does not mask the successful transcription.
        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(transcription_result)
            logger.info(f"Transcription results saved to: {output_path}")
            return f"Transcription successful, results saved to: {output_path}"
        except Exception as e:
            logger.error(f"Failed to save transcription results: {str(e)}")
            return f"Transcription successful, but failed to save results: {str(e)}"

    except Exception as e:
        # Public contract: errors come back as message strings, never exceptions.
        logger.error(f"Transcription failed: {str(e)}")
        return f"Error occurred during transcription: {str(e)}"
|
|
|
|
|
|
def batch_transcribe(
    audio_folder: str,
    output_folder: str = None,
    model_name: str = None,
    device: str = None,
    compute_type: str = None,
    language: str = None,
    output_format: str = None,
    beam_size: int = None,
    temperature: float = None,
    initial_prompt: str = None,
    parallel_files: int = 1
) -> str:
    """
    Batch-transcribe every supported audio file in a folder (non-recursive).

    Args:
        audio_folder: Path to the folder containing audio files.
        output_folder: Output folder (defaults to TRANSCRIPTION_BATCH_OUTPUT_DIR env var,
            or a "transcript" subfolder of audio_folder).
        model_name: Model name (defaults to TRANSCRIPTION_MODEL env var, or "base").
        device: Execution device (defaults to TRANSCRIPTION_DEVICE env var, or "cuda").
        compute_type: Computation type (defaults to TRANSCRIPTION_COMPUTE_TYPE env var, or "base").
        language: Language code (defaults to TRANSCRIPTION_LANGUAGE env var, or auto-detect).
        output_format: One of txt/vtt/srt/json (defaults to TRANSCRIPTION_OUTPUT_FORMAT env var, or "txt").
        beam_size: Beam search size (defaults to TRANSCRIPTION_BEAM_SIZE env var, or 5).
        temperature: Sampling temperature (defaults to TRANSCRIPTION_TEMPERATURE env var, or 0.0).
        initial_prompt: Optional initial prompt text.
        parallel_files: Accepted for API compatibility only — processing is
            currently sequential; parallelism is not implemented.

    Returns:
        str: One-line batch processing summary (totals, success and failure counts).
    """
    # Apply environment variable defaults (explicit arguments win).
    model_name = model_name or DEFAULT_MODEL
    device = device or DEFAULT_DEVICE
    compute_type = compute_type or DEFAULT_COMPUTE_TYPE
    language = language or DEFAULT_LANGUAGE
    output_format = output_format or DEFAULT_OUTPUT_FORMAT
    beam_size = beam_size if beam_size is not None else DEFAULT_BEAM_SIZE
    temperature = temperature if temperature is not None else DEFAULT_TEMPERATURE

    if not os.path.isdir(audio_folder):
        return f"Error: Folder does not exist: {audio_folder}"

    # Output folder priority: explicit parameter > env var > "transcript" subfolder.
    if output_folder is None:
        if DEFAULT_BATCH_OUTPUT_DIR is not None:
            output_folder = DEFAULT_BATCH_OUTPUT_DIR
        else:
            output_folder = os.path.join(audio_folder, "transcript")

    # Ensure output directory exists
    os.makedirs(output_folder, exist_ok=True)

    # Validate output format before doing any expensive work.
    valid_formats = ["txt", "vtt", "srt", "json"]
    if output_format.lower() not in valid_formats:
        return f"Error: Unsupported output format: {output_format}. Supported formats: {', '.join(valid_formats)}"

    # Collect supported audio files in the top level of the folder.
    supported_formats = [".mp3", ".wav", ".m4a", ".flac", ".ogg", ".aac"]
    audio_files = [
        os.path.join(audio_folder, filename)
        for filename in os.listdir(audio_folder)
        if os.path.splitext(filename)[1].lower() in supported_formats
    ]

    if not audio_files:
        return f"No supported audio files found in {audio_folder}. Supported formats: {', '.join(supported_formats)}"

    start_time = time.time()
    total_files = len(audio_files)
    logger.info(f"Starting batch transcription of {total_files} files, output format: {output_format}")

    # Preload the model once so the per-file calls reuse the cached instance.
    try:
        get_whisper_model(model_name, device, compute_type)
        logger.info(f"Model preloaded: {model_name}")
    except Exception as e:
        logger.error(f"Failed to preload model: {str(e)}")
        return f"Batch processing failed: Cannot load model {model_name}: {str(e)}"

    results = []
    success_count = 0
    error_count = 0

    for i, audio_path in enumerate(audio_files):
        file_name = os.path.basename(audio_path)
        elapsed = time.time() - start_time

        # Report progress (i files completed so far out of total_files).
        progress_msg = report_progress(i, total_files, elapsed)
        logger.info(f"{progress_msg} | Currently processing: {file_name}")

        try:
            result = transcribe_audio(
                audio_path=audio_path,
                model_name=model_name,
                device=device,
                compute_type=compute_type,
                language=language,
                output_format=output_format,
                beam_size=beam_size,
                temperature=temperature,
                initial_prompt=initial_prompt,
                output_directory=output_folder
            )

            # transcribe_audio reports status via its message string.  Match
            # the success prefix instead of a list of error prefixes: the old
            # check only matched "Error..." messages and silently counted
            # "Transcription failed, no results obtained" as a success.
            if result.startswith("Transcription successful"):
                # The saved path is the text after the LAST ": " separator
                # (rsplit is robust to ":" appearing earlier in the message).
                output_path = result.rsplit(": ", 1)[1] if ": " in result else "Unknown path"
                success_count += 1
                results.append(f"✅ Success: {file_name} -> {os.path.basename(output_path)}")
            else:
                logger.error(f"Transcription failed: {file_name} - {result}")
                results.append(f"❌ Failed: {file_name} - {result}")
                error_count += 1

        except Exception as e:
            # Defensive: transcribe_audio normally returns error strings, but
            # keep the batch alive if it ever raises.
            logger.error(f"Error occurred during transcription process: {file_name} - {str(e)}")
            results.append(f"❌ Failed: {file_name} - {str(e)}")
            error_count += 1

    total_transcription_time = time.time() - start_time

    # Generate summary line.
    summary = (
        f"Batch processing completed, total transcription time: {format_time(total_transcription_time)}"
        f" | Success: {success_count}/{total_files}"
        f" | Failed: {error_count}/{total_files}"
    )

    # Log per-file outcomes after the summary is built.
    for result in results:
        logger.info(result)

    return summary
|
|
|
|
|
|
def report_progress(current: int, total: int, elapsed_time: float) -> str:
    """Build a one-line progress report with percentage, elapsed time, and ETA.

    `current` is the number of files already completed; the ETA is a simple
    linear extrapolation of the elapsed time (0 until at least one file is done).
    """
    progress = current / total * 100
    if current > 0:
        eta = (elapsed_time / current) * (total - current)
    else:
        eta = 0
    return (
        f"Progress: {current}/{total} ({progress:.1f}%)"
        f" | Time used: {format_time(elapsed_time)}"
        f" | Estimated remaining: {format_time(eta)}"
    )