implemented end-to-end transcription pipeline
scripts/generate_transcript.py (new file)
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
CLI tool for testing the Whisper STT engine directly.
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path

from loguru import logger

# Add the project root to sys.path so that `src.engine.stt` can be imported
# when this script is run directly (the script lives in scripts/, so the
# project root is two levels up from this file).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.engine.stt import WhisperEngine  # noqa: E402


def main():
    parser = argparse.ArgumentParser(description="Transcribe audio files using Whisper")
    parser.add_argument("audio_file", help="Path to audio file")
    parser.add_argument("--model", help="Whisper model to use (tiny, base, small, medium, large, large_v3)")
    parser.add_argument("--output", "-o", help="Output file path (default: stdout)")
    parser.add_argument("--format", "-f", choices=["text", "json"], default="text",
                        help="Output format (default: text)")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = parser.parse_args()

    # Configure logger
    log_level = "INFO" if args.verbose else "WARNING"
    logger.remove()  # Remove default handler
    logger.add(sys.stderr, level=log_level,
               format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>")

    # Set model from args or use default
    if args.model:
        os.environ["WHISPER_MODEL"] = args.model

    # Initialize engine
    engine = WhisperEngine()

    logger.info(f"Using Whisper model: {engine.model_name} on {engine.device}")
    logger.info(f"Transcribing: {args.audio_file}")

    try:
        # Transcribe
        start_time = time.time()
        result = engine.transcribe(args.audio_file)
        elapsed = time.time() - start_time

        # Format output
        if args.format == "json":
            output = json.dumps({
                "text": result.text,
                "language": result.language,
                "segments": result.segments,
                "duration": result.duration,
                "processing_time": result.processing_time,
                "model": engine.model_name,
                "device": engine.device
            }, indent=2)
        else:
            output = result.text

        # Write output
        if args.output:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(output)
            logger.info(f"Output written to: {args.output}")
        else:
            # For the actual transcription output, we still use print (not logger)
            # so it goes to stdout and can be piped or redirected
            print(output)

        logger.info(f"Transcription completed in {elapsed:.2f}s")
        logger.info(f"Audio duration: {result.duration:.2f}s")
        if result.duration > 0:  # guard against zero-length audio
            logger.info(f"Real-time factor: {elapsed / result.duration:.2f}x")
        logger.info(f"Detected language: {result.language}")

        return 0

    except Exception as e:
        logger.error(f"Error: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())
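Note: a quick way to exercise the new CLI end to end is to drive it as a subprocess and parse its JSON output. The sketch below is only illustrative: it assumes it is run from the repository root, that the chosen model's weights are available in the configured models directory, and that sample.wav stands in for a real recording.

# Illustrative only: run the CLI as a subprocess and parse its JSON output.
# "sample.wav" is a placeholder; the model must be available locally.
import json
import subprocess

proc = subprocess.run(
    ["python", "scripts/generate_transcript.py", "sample.wav",
     "--model", "tiny", "--format", "json"],
    capture_output=True, text=True, check=True,
)

payload = json.loads(proc.stdout)  # transcript JSON goes to stdout, logs go to stderr
print(payload["language"], f'{payload["duration"]:.1f}s')
print(payload["text"][:80])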
src/engine/stt.py (new file)
@@ -0,0 +1,131 @@
"""
Whisper-based speech-to-text transcription engine.
"""
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Dict, Any, Union

import torch
import whisper
import librosa
from loguru import logger


@dataclass
class TranscriptionResult:
    """Result of a transcription operation."""
    text: str
    language: str
    segments: list
    duration: float
    processing_time: float


class WhisperEngine:
    """
    Speech-to-text engine using OpenAI's Whisper model.

    Environment variables:
        WHISPER_MODEL: Model size to use (tiny, base, small, medium, large, large_v3)
            Default: "tiny" (for development)
        WHISPER_DEVICE: Device to use for inference (cpu, cuda, mps)
            Default: first available GPU (cuda or mps), otherwise cpu
        WHISPER_MODELS_DIR: Existing directory where model weights are stored/downloaded
    """

    def __init__(self):
        # Get model name from environment variable or use default
        self.model_name = os.environ.get("WHISPER_MODEL", "tiny")

        # Determine device (CPU or GPU)
        self.device = os.environ.get("WHISPER_DEVICE")
        if not self.device:
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.debug("Using CUDA GPU for inference")
            elif torch.backends.mps.is_available():
                self.device = "mps"
                logger.debug("Using Apple Silicon MPS GPU for inference")
            else:
                self.device = "cpu"
                logger.debug("Using CPU for inference (no GPU detected)")

        self._model = None
        self._model_loaded = False

    @property
    def model(self):
        """Lazy-load the Whisper model on first use."""
        if not self._model_loaded:
            models_dir = os.environ.get("WHISPER_MODELS_DIR")

            # Check that the models directory is configured and exists
            if not models_dir or not os.path.exists(models_dir):
                logger.debug(f"Models directory '{models_dir}' not found")
                raise RuntimeError(f"Models directory '{models_dir}' not found. Please create it first.")

            try:
                logger.debug(f"Loading model '{self.model_name}' from {models_dir}")
                self._model = whisper.load_model(self.model_name, download_root=models_dir, device=self.device)
                self._model_loaded = True
                logger.debug(f"Successfully loaded model '{self.model_name}' on {self.device}")
            except Exception as e:
                raise RuntimeError(f"Failed to load Whisper model '{self.model_name}': {e}") from e

        return self._model

    def transcribe(self, audio_path: Union[str, Path], **kwargs) -> TranscriptionResult:
        """
        Transcribe an audio file to text.

        Args:
            audio_path: Path to the audio file
            **kwargs: Additional arguments to pass to whisper.transcribe()

        Returns:
            TranscriptionResult object with the transcription and metadata
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        start_time = time.time()

        # Set default options if not provided and add additional options
        options = {
            "language": kwargs.pop("language", None),  # Auto-detect if None
            "task": kwargs.pop("task", "transcribe"),
            "verbose": kwargs.pop("verbose", True)
        }
        options.update(kwargs)

        # Perform transcription
        result = self.model.transcribe(str(audio_path), **options)

        processing_time = time.time() - start_time

        # Extract audio duration if possible (or estimate)
        try:
            duration = librosa.get_duration(path=str(audio_path))
        except Exception:
            # If librosa fails, use the last segment's end time or 0
            segments = result.get("segments", [])
            duration = segments[-1]["end"] if segments else 0

        return TranscriptionResult(
            text=result["text"],
            language=result["language"],
            segments=result["segments"],
            duration=duration,
            processing_time=processing_time
        )

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current model."""
        return {
            "name": self.model_name,
            "device": self.device,
            "loaded": self._model_loaded,
            "is_multilingual": True  # All Whisper models are multilingual
        }
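Note: the engine can also be used directly, without the CLI or the background worker. A minimal sketch, assuming WHISPER_MODELS_DIR points at an existing directory that holds the weights and uploads/sample.wav stands in for a real file (both values are placeholders, not part of this commit):

# Direct use of WhisperEngine (illustrative; env values and the audio path are placeholders).
import os

os.environ.setdefault("WHISPER_MODEL", "tiny")
os.environ.setdefault("WHISPER_MODELS_DIR", "models")  # must already exist on disk

from src.engine.stt import WhisperEngine

engine = WhisperEngine()
result = engine.transcribe("uploads/sample.wav", language="en")  # extra kwargs are forwarded to whisper
print(f"[{result.language}] {result.text}")
print(f"{result.duration:.1f}s of audio processed in {result.processing_time:.1f}s on {engine.device}")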
src/main.py
@@ -7,16 +7,23 @@ from fastapi import FastAPI, UploadFile, File, HTTPException, Request, WebSocket
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from loguru import logger

from config import UPLOAD_DIR, ALLOWED_EXTENSIONS, STATIC_DIR, TEMPLATES_DIR, TRANSCRIPT_DIR
from worker import audio_processor, FileStatus
from worker import FileStatus, AudioProcessor

audio_processor = AudioProcessor()


@asynccontextmanager
async def lifespan(app: FastAPI):
    audio_processor.start()  # Startup code
    # Startup code
    logger.debug("initializing audio processor")
    audio_processor.start()
    yield
    audio_processor.stop()  # Shutdown code
    # Shutdown code
    audio_processor.stop()


app = FastAPI(title="Transcriptor", lifespan=lifespan)
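Note: because the processor is now started and stopped through the lifespan handler, the wiring can be smoke-tested without running a server. A rough sketch, assuming src/ is on the import path and that the existing upload page is served at "/" (both are assumptions, not shown in this diff):

# Illustrative lifespan check; the module path and the "/" route are assumptions.
from fastapi.testclient import TestClient

from main import app

# Entering the context runs lifespan startup (audio_processor.start());
# leaving it runs shutdown (audio_processor.stop()).
with TestClient(app) as client:
    response = client.get("/")
    print(response.status_code)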
@@ -3,7 +3,7 @@
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Audio File Upload</title>
    <title>Transcriptor Agent</title>
    <link rel="icon" href="/static/favicon.ico" type="image/x-icon">
    <style>
        body {
@@ -120,9 +120,8 @@
    </style>
</head>
<body>
    <h1>Audio File Upload</h1>
    <p>Upload your MP3 or WAV audio files</p>

    <h1>Transcriptor Agent</h1>

    <div class="upload-form">
        <form id="uploadForm">
            <div class="form-group">
src/worker.py
@@ -1,5 +1,6 @@
# src/worker.py
import asyncio
import json
import time
from pathlib import Path
from typing import Dict, List, Set, Optional
@@ -8,6 +9,7 @@ from enum import Enum
from loguru import logger

from config import UPLOAD_DIR, TRANSCRIPT_DIR
from engine.stt import WhisperEngine, TranscriptionResult


class FileStatus(Enum):
@@ -22,12 +24,36 @@ class AudioProcessor:
        self.file_status: Dict[str, FileStatus] = {}
        self.is_running = False
        self._task: Optional[asyncio.Task] = None
        self._transcription_tasks: Dict[str, asyncio.Task] = {}

        # Initialize Whisper engine
        self.stt_engine = WhisperEngine()

        # Load any existing statuses
        self._load_initial_status()

    def _load_initial_status(self):
        """Load initial status based on existing transcript files."""
        # Find all transcript files and mark their source files as completed
        for transcript_path in TRANSCRIPT_DIR.iterdir():
            if transcript_path.suffix.lower() == '.txt':
                base_name = transcript_path.stem

                # Check for corresponding audio files
                for ext in ['.mp3', '.wav']:
                    audio_file = f"{base_name}{ext}"
                    audio_path = UPLOAD_DIR / audio_file

                    if audio_path.exists():
                        self.file_status[audio_file] = FileStatus.COMPLETED
                        break

    def start(self):
        """Start the background worker"""
        if not self._task or self._task.done():
            self.is_running = True
            self._task = asyncio.create_task(self._monitor_files())
            logger.info(f"STT Engine started with model: {self.stt_engine.model_name}")

    def stop(self):
        """Stop the background worker"""
@@ -35,6 +61,12 @@ class AudioProcessor:
        if self._task and not self._task.done():
            self._task.cancel()

        # Cancel any running transcription tasks
        for filename, task in self._transcription_tasks.items():
            if not task.done():
                task.cancel()
                logger.info(f"Cancelled transcription task for {filename}")

    async def _monitor_files(self):
        """Monitor the uploads folder for new files"""
        while self.is_running:
@@ -55,34 +87,71 @@ class AudioProcessor:

        # Check if file exists and is valid
        if not file_path.exists():
            logger.error(f"File not found: {file_path}")
            return False

        # Check if already processing
        if filename in self._transcription_tasks and not self._transcription_tasks[filename].done():
            logger.info(f"File {filename} is already being processed")
            return True

        # Update status
        self.file_status[filename] = FileStatus.PROCESSING

        # Create transcription task
        self._transcription_tasks[filename] = asyncio.create_task(
            self._transcribe_file(filename)
        )

        return True

    async def _transcribe_file(self, filename: str) -> bool:
        """Run the actual transcription in a separate task."""
        file_path = UPLOAD_DIR / filename
        base_name = Path(filename).stem
        transcript_path = TRANSCRIPT_DIR / f"{base_name}.txt"
        json_path = TRANSCRIPT_DIR / f"{base_name}.json"

        try:
            ### mock starts here
            # TODO: implement real transcription logic here (as stt engine)
            # Mock transcription process by creating a text file with the same name
            base_name = Path(filename).stem  # Get filename without extension
            transcript_path = TRANSCRIPT_DIR / f"{base_name}.txt"
            logger.info(f"Starting transcription for {filename}")

            # Create a simple mock transcript file
            with open(transcript_path, "w") as f:
                f.write(f"Mock transcript for {filename}\n")
                f.write("This is a placeholder for the actual transcription.\n")
                f.write(f"Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}")
            # Run transcription (in a thread pool to avoid blocking the event loop)
            result = await asyncio.to_thread(
                self.stt_engine.transcribe,
                file_path
            )

            # Simulate some processing time
            await asyncio.sleep(3)
            ### mock ends here
            # Save plain text transcript
            with open(transcript_path, "w", encoding="utf-8") as f:
                f.write(result.text)

            # Save detailed JSON result
            with open(json_path, "w", encoding="utf-8") as f:
                # Create a serializable version of the result
                serializable_result = {
                    "text": result.text,
                    "language": result.language,
                    "segments": result.segments,
                    "duration": result.duration,
                    "processing_time": result.processing_time,
                    "metadata": {
                        "model": self.stt_engine.model_name,
                        "device": self.stt_engine.device,
                        "timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
                    }
                }
                json.dump(serializable_result, f, indent=2)

            logger.success(f"Transcription completed for {filename} "
                           f"({result.duration:.1f}s audio, {result.processing_time:.1f}s processing)")

            # Mark as completed
            self.file_status[filename] = FileStatus.COMPLETED
            return True

        except Exception as e:
            logger.error(f"Error transcribing {filename}: {e}")
            self.file_status[filename] = FileStatus.FAILED
            logger.error(f"Error processing {filename}: {e}")
            return False

    def get_status(self, filename: str) -> FileStatus:
@@ -94,5 +163,3 @@ class AudioProcessor:
        return {name: status.value for name, status in self.file_status.items()}


# Create a global instance of the processor
audio_processor = AudioProcessor()
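Note: the important change in _transcribe_file is that the blocking WhisperEngine.transcribe call is handed to asyncio.to_thread, so the event loop (and with it the upload endpoints and the file monitor) stays responsive while Whisper runs. A stripped-down sketch of that pattern, with a stand-in blocking function instead of the real engine:

# Minimal illustration of offloading a blocking call with asyncio.to_thread.
# slow_transcribe() is only a stand-in for WhisperEngine.transcribe.
import asyncio
import time


def slow_transcribe(path: str) -> str:
    time.sleep(2)  # simulates CPU/GPU-bound Whisper inference
    return f"transcript of {path}"


async def main():
    # The blocking call runs in a worker thread; the loop keeps servicing other tasks.
    task = asyncio.create_task(asyncio.to_thread(slow_transcribe, "sample.wav"))
    while not task.done():
        print("event loop is still responsive...")
        await asyncio.sleep(0.5)
    print(task.result())


asyncio.run(main())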