Merge pull request #93 from ybm911/main
Update: Added multi-language support for macOS.
This commit is contained in:
@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
console = Console()
|
||||
|
||||
# Maps a two-letter language code to the language name that is inserted into
# the "Please reply to my message in ..." prefix added to the user prompt.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}
|
||||
|
||||
class MLXLanguageModelHandler(BaseHandler):
|
||||
"""
|
||||
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
|
||||
|
||||
def process(self, prompt):
|
||||
logger.debug("infering language model...")
|
||||
language_code = None
|
||||
|
||||
if isinstance(prompt, tuple):
|
||||
prompt, language_code = prompt
|
||||
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
|
||||
|
||||
self.chat.append({"role": self.user_role, "content": prompt})
|
||||
|
||||
@@ -89,6 +102,10 @@ class MLXLanguageModelHandler(BaseHandler):
|
||||
yield curr_output.replace("<|end|>", "")
|
||||
curr_output = ""
|
||||
generated_text = output.replace("<|end|>", "")
|
||||
printable_text = generated_text
|
||||
torch.mps.empty_cache()
|
||||
|
||||
self.chat.append({"role": "assistant", "content": generated_text})
|
||||
|
||||
# don't forget last sentence
|
||||
yield (printable_text, language_code)
|
||||
@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
|
||||
from lightning_whisper_mlx import LightningWhisperMLX
|
||||
import numpy as np
|
||||
from rich.console import Console
|
||||
from copy import copy
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
console = Console()
|
||||
|
||||
# Language codes the handler accepts from Whisper's language detection; a
# detected code outside this list is treated as unsupported.
SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]
|
||||
|
||||
|
||||
class LightningWhisperSTTHandler(BaseHandler):
|
||||
"""
|
||||
@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
|
||||
def setup(
|
||||
self,
|
||||
model_name="distil-large-v3",
|
||||
device="cuda",
|
||||
device="mps",
|
||||
torch_dtype="float16",
|
||||
compile_mode=None,
|
||||
language=None,
|
||||
@@ -29,6 +39,12 @@ class LightningWhisperSTTHandler(BaseHandler):
|
||||
model_name = model_name.split("/")[-1]
|
||||
self.device = device
|
||||
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
|
||||
if language == 'auto':
|
||||
language = None
|
||||
self.last_language = language
|
||||
if self.last_language is not None:
|
||||
self.gen_kwargs["language"] = self.last_language
|
||||
|
||||
self.warmup()
|
||||
|
||||
def warmup(self):
|
||||
@@ -47,10 +63,27 @@ class LightningWhisperSTTHandler(BaseHandler):
|
||||
global pipeline_start
|
||||
pipeline_start = perf_counter()
|
||||
|
||||
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
|
||||
|
||||
language_code = self.model.transcribe(spoken_prompt)["language"]
|
||||
|
||||
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
|
||||
# Pass the detected code as a lazy %-argument: without a "%s" placeholder the
# logging module cannot merge the extra argument into the message.
logger.warning("Whisper detected unsupported language: %s", language_code)
|
||||
gen_kwargs = copy(self.gen_kwargs) #####
|
||||
gen_kwargs['language'] = self.last_language
|
||||
language_code = self.last_language
|
||||
# pred_ids = self.model.generate(input_features, **gen_kwargs)
|
||||
else:
|
||||
self.last_language = language_code
|
||||
|
||||
pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
|
||||
torch.mps.empty_cache()
|
||||
|
||||
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
|
||||
language_code = self.model.transcribe(spoken_prompt)["language"]
|
||||
|
||||
logger.debug("finished whisper inference")
|
||||
console.print(f"[yellow]USER: {pred_text}")
|
||||
logger.debug(f"Language Code Whisper: {language_code}")
|
||||
|
||||
yield pred_text
|
||||
yield (pred_text, language_code)
|
||||
|
||||
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
|
||||
"en": "EN_NEWEST",
|
||||
"en": "EN",
|
||||
"fr": "FR",
|
||||
"es": "ES",
|
||||
"zh": "ZH",
|
||||
@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
|
||||
}
|
||||
|
||||
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
|
||||
"en": "EN-Newest",
|
||||
"en": "EN-BR",
|
||||
"fr": "FR",
|
||||
"es": "ES",
|
||||
"zh": "ZH",
|
||||
|
||||
Reference in New Issue
Block a user