Fixes to bugs from original PR
This commit is contained in:
@@ -99,13 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
|
||||
output += t
|
||||
curr_output += t
|
||||
if curr_output.endswith((".", "?", "!", "<|end|>")):
|
||||
yield curr_output.replace("<|end|>", "")
|
||||
yield (curr_output.replace("<|end|>", ""), language_code)
|
||||
curr_output = ""
|
||||
generated_text = output.replace("<|end|>", "")
|
||||
printable_text = generated_text
|
||||
torch.mps.empty_cache()
|
||||
|
||||
self.chat.append({"role": "assistant", "content": generated_text})
|
||||
|
||||
# don't forget last sentence
|
||||
yield (printable_text, language_code)
|
||||
self.chat.append({"role": "assistant", "content": generated_text})
|
||||
@@ -39,11 +39,8 @@ class LightningWhisperSTTHandler(BaseHandler):
|
||||
model_name = model_name.split("/")[-1]
|
||||
self.device = device
|
||||
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
|
||||
if language == 'auto':
|
||||
language = None
|
||||
self.start_language = language
|
||||
self.last_language = language
|
||||
if self.last_language is not None:
|
||||
self.gen_kwargs["language"] = self.last_language
|
||||
|
||||
self.warmup()
|
||||
|
||||
@@ -63,25 +60,24 @@ class LightningWhisperSTTHandler(BaseHandler):
|
||||
global pipeline_start
|
||||
pipeline_start = perf_counter()
|
||||
|
||||
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
|
||||
|
||||
language_code = self.model.transcribe(spoken_prompt)["language"]
|
||||
|
||||
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
|
||||
logger.warning("Whisper detected unsupported language:", language_code)
|
||||
gen_kwargs = copy(self.gen_kwargs) #####
|
||||
gen_kwargs['language'] = self.last_language
|
||||
language_code = self.last_language
|
||||
# pred_ids = self.model.generate(input_features, **gen_kwargs)
|
||||
if self.start_language != 'auto':
|
||||
transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
|
||||
else:
|
||||
self.last_language = language_code
|
||||
transcription_dict = self.model.transcribe(spoken_prompt)
|
||||
language_code = transcription_dict["language"]
|
||||
if language_code not in SUPPORTED_LANGUAGES:
|
||||
logger.warning(f"Whisper detected unsupported language: {language_code}")
|
||||
if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language
|
||||
transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
|
||||
else:
|
||||
transcription_dict = {"text": "", "language": "en"}
|
||||
else:
|
||||
self.last_language = language_code
|
||||
|
||||
pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
|
||||
pred_text = transcription_dict["text"].strip()
|
||||
language_code = transcription_dict["language"]
|
||||
torch.mps.empty_cache()
|
||||
|
||||
# language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
|
||||
language_code = self.model.transcribe(spoken_prompt)["language"]
|
||||
|
||||
logger.debug("finished whisper inference")
|
||||
console.print(f"[yellow]USER: {pred_text}")
|
||||
logger.debug(f"Language Code Whisper: {language_code}")
|
||||
|
||||
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
|
||||
"en": "EN",
|
||||
"en": "EN_NEWEST",
|
||||
"fr": "FR",
|
||||
"es": "ES",
|
||||
"zh": "ZH",
|
||||
@@ -20,7 +20,7 @@ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
|
||||
}
|
||||
|
||||
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
|
||||
"en": "EN-BR",
|
||||
"en": "EN-Newest",
|
||||
"fr": "FR",
|
||||
"es": "ES",
|
||||
"zh": "ZH",
|
||||
|
||||
Reference in New Issue
Block a user