implemented end-to-end transcription pipeline

This commit is contained in:
ALIHAN DIKEL
2025-03-10 15:29:25 +03:00
parent 54f7a14cc6
commit 7f5af1cb85
5 changed files with 318 additions and 23 deletions


@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
CLI tool for testing the Whisper STT engine directly.
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
from loguru import logger
# Add the project root to sys.path before importing the engine package
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
from src.engine.stt import WhisperEngine
def main():
parser = argparse.ArgumentParser(description="Transcribe audio files using Whisper")
parser.add_argument("audio_file", help="Path to audio file")
parser.add_argument("--model", help="Whisper model to use (tiny, base, small, medium, large, large_v3)")
parser.add_argument("--output", "-o", help="Output file path (default: stdout)")
parser.add_argument("--format", "-f", choices=["text", "json"], default="text",
help="Output format (default: text)")
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
args = parser.parse_args()
# Configure logger
log_level = "INFO" if args.verbose else "WARNING"
logger.remove() # Remove default handler
logger.add(sys.stderr, level=log_level,
format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>")
# Set model from args or use default
if args.model:
os.environ["WHISPER_MODEL"] = args.model
# Initialize engine
engine = WhisperEngine()
logger.info(f"Using Whisper model: {engine.model_name} on {engine.device}")
logger.info(f"Transcribing: {args.audio_file}")
try:
# Transcribe
start_time = time.time()
result = engine.transcribe(args.audio_file)
elapsed = time.time() - start_time
# Format output
if args.format == "json":
output = json.dumps({
"text": result.text,
"language": result.language,
"segments": result.segments,
"duration": result.duration,
"processing_time": result.processing_time,
"model": engine.model_name,
"device": engine.device
}, indent=2)
else:
output = result.text
# Write output
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
f.write(output)
logger.info(f"Output written to: {args.output}")
else:
# For the actual transcription output, we still use print (not logger)
# so it goes to stdout and can be piped or redirected
print(output)
logger.info(f"Transcription completed in {elapsed:.2f}s")
logger.info(f"Audio duration: {result.duration:.2f}s")
logger.info(f"Real-time factor: {elapsed / result.duration:.2f}x")
logger.info(f"Detected language: {result.language}")
return 0
except Exception as e:
logger.error(f"Error: {e}")
return 1
if __name__ == "__main__":
sys.exit(main())
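As a quick smoke test, the tool can also be driven from Python via subprocess. This is an illustrative sketch, not part of the commit; the script path is hypothetical since the diff does not show where this file lives in the repo.

import subprocess
import sys

# Hypothetical script path; adjust to the actual location of this CLI file.
proc = subprocess.run(
    [sys.executable, "src/engine/tools/transcribe_cli.py",
     "sample.wav", "--model", "tiny", "--format", "json"],
    capture_output=True, text=True, check=True,
)
print(proc.stdout)  # the tool prints the JSON transcription to stdout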

src/engine/stt.py

@@ -0,0 +1,131 @@
"""
Whisper-based speech-to-text transcription engine.
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Any, Union
import torch
import whisper
import librosa
from loguru import logger
@dataclass
class TranscriptionResult:
"""Result of a transcription operation."""
text: str
language: str
segments: list
duration: float
processing_time: float
class WhisperEngine:
"""
Speech-to-text engine using OpenAI's Whisper model.
Environment variables:
WHISPER_MODEL: Model size to use (tiny, base, small, medium, large, large-v3)
Default: "tiny" (for development)
WHISPER_DEVICE: Device to use for inference (cpu, cuda, mps)
Default: auto-detected (cuda if available, then mps, then cpu)
WHISPER_MODELS_DIR: Directory where model weights are stored/downloaded (required; no default)
"""
def __init__(self):
# Get model name from environment variable or use default
self.model_name = os.environ.get("WHISPER_MODEL", "tiny")
# Determine device (CPU or GPU)
self.device = os.environ.get("WHISPER_DEVICE")
if not self.device:
if torch.cuda.is_available():
self.device = "cuda"
logger.debug("Using CUDA GPU for inference")
elif torch.backends.mps.is_available():
self.device = "mps"
logger.debug("Using Apple Silicon MPS GPU for inference")
else:
self.device = "cpu"
logger.debug("Using CPU for inference (no GPU detected)")
self._model = None
self._model_loaded = False
@property
def model(self):
"""Lazy-load the Whisper model on first use."""
if not self._model_loaded:
models_dir = os.environ.get("WHISPER_MODELS_DIR")
# Check that the models directory is configured and exists
if not models_dir or not os.path.exists(models_dir):
    logger.debug(f"Models directory '{models_dir}' not found")
    raise RuntimeError(f"Models directory '{models_dir}' not found. Set WHISPER_MODELS_DIR and create the directory first.")
try:
logger.debug(f"Loading model '{self.model_name}' from {models_dir}")
self._model = whisper.load_model(self.model_name, download_root=models_dir, device=self.device)
self._model_loaded = True
logger.debug(f"Successfully loaded model '{self.model_name}' on {self.device}")
except Exception as e:
raise RuntimeError(f"Failed to load Whisper model '{self.model_name}': {e}")
return self._model
def transcribe(self, audio_path: Union[str, Path], **kwargs) -> TranscriptionResult:
"""
Transcribe an audio file to text.
Args:
audio_path: Path to the audio file
**kwargs: Additional arguments to pass to whisper.transcribe()
Returns:
TranscriptionResult object with the transcription and metadata
"""
audio_path = Path(audio_path)
if not audio_path.exists():
raise FileNotFoundError(f"Audio file not found: {audio_path}")
start_time = time.time()
# Set default options if not provided and add additional options
options = {
"language": kwargs.pop("language", None), # Auto-detect if None
"task": kwargs.pop("task", "transcribe"),
"verbose": kwargs.pop("verbose", True)
}
options.update(kwargs)
# Perform transcription
result = self.model.transcribe(str(audio_path), **options)
processing_time = time.time() - start_time
# Extract audio duration if possible (or estimate)
try:
duration = librosa.get_duration(path=str(audio_path))
except Exception:
# If librosa fails, use the last segment's end time or 0
segments = result.get("segments", [])
duration = segments[-1]["end"] if segments else 0
return TranscriptionResult(
text=result["text"],
language=result["language"],
segments=result["segments"],
duration=duration,
processing_time=processing_time
)
def get_model_info(self) -> Dict[str, Any]:
"""Get information about the current model."""
return {
"name": self.model_name,
"device": self.device,
"loaded": self._model_loaded,
"is_multilingual": True # All Whisper models are multilingual
}
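A minimal usage sketch for the engine, run from the repository root so the src package is importable. The environment variable values and the audio file name are illustrative, not part of the commit.

import os

os.environ.setdefault("WHISPER_MODEL", "tiny")            # illustrative value
os.environ.setdefault("WHISPER_MODELS_DIR", "./models")   # hypothetical directory

from src.engine.stt import WhisperEngine

engine = WhisperEngine()
result = engine.transcribe("sample.wav")  # language is auto-detected by default
print(result.language, f"{result.processing_time:.2f}s")
print(result.text)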


@@ -7,16 +7,23 @@ from fastapi import FastAPI, UploadFile, File, HTTPException, Request, WebSocket
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from loguru import logger
from config import UPLOAD_DIR, ALLOWED_EXTENSIONS, STATIC_DIR, TEMPLATES_DIR, TRANSCRIPT_DIR
from worker import audio_processor, FileStatus
from worker import FileStatus, AudioProcessor
audio_processor = AudioProcessor()
@asynccontextmanager
async def lifespan(app: FastAPI):
audio_processor.start() # Startup code
# Startup code
logger.debug("initializing audio processor")
audio_processor.start()
yield
audio_processor.stop() # Shutdown code
# Shutdown code
audio_processor.stop()
app = FastAPI(title="Transcriptor", lifespan=lifespan)
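For illustration only (the real routes are outside this hunk): handlers can read worker state through the shared audio_processor instance created above. The route path below is hypothetical, and handling of unknown filenames depends on how get_status is implemented.

@app.get("/status/{filename}")
async def file_status(filename: str):
    # get_status returns a FileStatus enum member (see worker.py below)
    status = audio_processor.get_status(filename)
    return {"filename": filename, "status": status.value if status else "unknown"}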


@@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Audio File Upload</title>
<title>Transcriptor Agent</title>
<link rel="icon" href="/static/favicon.ico" type="image/x-icon">
<style>
body {
@@ -120,9 +120,8 @@
</style>
</head>
<body>
<h1>Audio File Upload</h1>
<p>Upload your MP3 or WAV audio files</p>
<h1>Transcriptor Agent</h1>
<div class="upload-form">
<form id="uploadForm">
<div class="form-group">


@@ -1,5 +1,6 @@
# src/worker.py
import asyncio
import json
import time
from pathlib import Path
from typing import Dict, List, Set, Optional
@@ -8,6 +9,7 @@ from enum import Enum
from loguru import logger
from config import UPLOAD_DIR, TRANSCRIPT_DIR
from engine.stt import WhisperEngine, TranscriptionResult
class FileStatus(Enum):
@@ -22,12 +24,36 @@ class AudioProcessor:
self.file_status: Dict[str, FileStatus] = {}
self.is_running = False
self._task: Optional[asyncio.Task] = None
self._transcription_tasks: Dict[str, asyncio.Task] = {}
# Initialize Whisper engine
self.stt_engine = WhisperEngine()
# Load any existing statuses
self._load_initial_status()
def _load_initial_status(self):
"""Load initial status based on existing transcript files."""
# Find all transcript files and mark their source files as completed
for transcript_path in TRANSCRIPT_DIR.iterdir():
if transcript_path.suffix.lower() == '.txt':
base_name = transcript_path.stem
# Check for corresponding audio files
for ext in ['.mp3', '.wav']:
audio_file = f"{base_name}{ext}"
audio_path = UPLOAD_DIR / audio_file
if audio_path.exists():
self.file_status[audio_file] = FileStatus.COMPLETED
break
def start(self):
"""Start the background worker"""
if not self._task or self._task.done():
self.is_running = True
self._task = asyncio.create_task(self._monitor_files())
logger.info(f"STT Engine started with model: {self.stt_engine.model_name}")
def stop(self):
"""Stop the background worker"""
@@ -35,6 +61,12 @@ class AudioProcessor:
if self._task and not self._task.done():
self._task.cancel()
# Cancel any running transcription tasks
for filename, task in self._transcription_tasks.items():
if not task.done():
task.cancel()
logger.info(f"Cancelled transcription task for {filename}")
async def _monitor_files(self):
"""Monitor the uploads folder for new files"""
while self.is_running:
@@ -55,34 +87,71 @@ class AudioProcessor:
# Check if file exists and is valid
if not file_path.exists():
logger.error(f"File not found: {file_path}")
return False
# Check if already processing
if filename in self._transcription_tasks and not self._transcription_tasks[filename].done():
logger.info(f"File {filename} is already being processed")
return True
# Update status
self.file_status[filename] = FileStatus.PROCESSING
# Create transcription task
self._transcription_tasks[filename] = asyncio.create_task(
self._transcribe_file(filename)
)
return True
async def _transcribe_file(self, filename: str) -> bool:
"""Run the actual transcription in a separate task."""
file_path = UPLOAD_DIR / filename
base_name = Path(filename).stem
transcript_path = TRANSCRIPT_DIR / f"{base_name}.txt"
json_path = TRANSCRIPT_DIR / f"{base_name}.json"
try:
### mock starts here
# TODO: implement real transcription logic here (as stt engine)
# Mock transcription process by creating a text file with the same name
base_name = Path(filename).stem # Get filename without extension
transcript_path = TRANSCRIPT_DIR / f"{base_name}.txt"
logger.info(f"Starting transcription for {filename}")
# Create a simple mock transcript file
with open(transcript_path, "w") as f:
f.write(f"Mock transcript for {filename}\n")
f.write("This is a placeholder for the actual transcription.\n")
f.write(f"Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
# Run transcription (in a thread pool to avoid blocking the event loop)
result = await asyncio.to_thread(
self.stt_engine.transcribe,
file_path
)
# Simulate some processing time
await asyncio.sleep(3)
### mock ends here
# Save plain text transcript
with open(transcript_path, "w", encoding="utf-8") as f:
f.write(result.text)
# Save detailed JSON result
with open(json_path, "w", encoding="utf-8") as f:
# Create a serializable version of the result
serializable_result = {
"text": result.text,
"language": result.language,
"segments": result.segments,
"duration": result.duration,
"processing_time": result.processing_time,
"metadata": {
"model": self.stt_engine.model_name,
"device": self.stt_engine.device,
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
}
}
json.dump(serializable_result, f, indent=2)
logger.success(f"Transcription completed for {filename} "
f"({result.duration:.1f}s audio, {result.processing_time:.1f}s processing)")
# Mark as completed
self.file_status[filename] = FileStatus.COMPLETED
return True
except Exception as e:
logger.error(f"Error transcribing {filename}: {e}")
self.file_status[filename] = FileStatus.FAILED
logger.error(f"Error processing {filename}: {e}")
return False
def get_status(self, filename: str) -> FileStatus:
@@ -94,5 +163,3 @@ class AudioProcessor:
return {name: status.value for name, status in self.file_status.items()}
# Create a global instance of the processor
audio_processor = AudioProcessor()
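Downstream code can read a transcript pair back using the same field names the worker writes. A hedged sketch follows; the directory and base name are illustrative, and the real location comes from config.TRANSCRIPT_DIR.

import json
from pathlib import Path

transcript_dir = Path("transcripts")  # illustrative; the app uses TRANSCRIPT_DIR from config
base = "sample"                       # hypothetical base name of an uploaded audio file

text = (transcript_dir / f"{base}.txt").read_text(encoding="utf-8")
meta = json.loads((transcript_dir / f"{base}.json").read_text(encoding="utf-8"))
print(meta["metadata"]["model"], f'{meta["duration"]:.1f}s of audio')
print(text[:120])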