import logging
from time import perf_counter
from baseHandler import BaseHandler
from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np
from rich.console import Console
from copy import copy
import torch

logger = logging.getLogger(__name__)

console = Console()

SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]


class LightningWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-large-v3",
        device="mps",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs={},
    ):
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
        # Keep the generation kwargs around: process() copies them when it has to
        # fall back to the last detected language.
        self.gen_kwargs = gen_kwargs
        if language == "auto":
            language = None
        self.last_language = language
        if self.last_language is not None:
            self.gen_kwargs["language"] = self.last_language

        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # A single warmup pass: there is no compile mode or CUDA graph capture for
        # the MLX backend, we only run one dummy transcription so the first real
        # request is not slowed down by model loading.
        n_steps = 1
        dummy_input = np.array([0] * 512)

        for _ in range(n_steps):
            _ = self.model.transcribe(dummy_input)["text"].strip()

    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")

        # Shared with downstream handlers to measure end-to-end pipeline latency.
        global pipeline_start
        pipeline_start = perf_counter()

        # A single transcription pass returns both the text and the detected language.
        transcription = self.model.transcribe(spoken_prompt)
        language_code = transcription["language"]

        if language_code not in SUPPORTED_LANGUAGES:
            # Fall back to the last supported language rather than trusting the
            # detection (a forced-language re-decode is not implemented for the
            # MLX backend).
            logger.warning(f"Whisper detected unsupported language: {language_code}")
            gen_kwargs = copy(self.gen_kwargs)
            gen_kwargs["language"] = self.last_language
            language_code = self.last_language
        else:
            self.last_language = language_code

        pred_text = transcription["text"].strip()
        torch.mps.empty_cache()

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        logger.debug(f"Language Code Whisper: {language_code}")

        yield (pred_text, language_code)

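
# Minimal local smoke-test sketch (not used by the pipeline). It exercises
# setup() and process() directly and bypasses BaseHandler.__init__ so nothing
# is assumed about its constructor signature. It assumes the model accepts a
# 1-D 16 kHz numpy buffer, as warmup() already suggests; the silent input
# below should simply yield an (almost) empty transcript.
if __name__ == "__main__":
    handler = LightningWhisperSTTHandler.__new__(LightningWhisperSTTHandler)
    handler.setup(model_name="distil-large-v3", language="en")
    silence = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
    for text, lang in handler.process(silence):
        console.print(f"[green]STT ({lang}): {text!r}")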