This commit is contained in:
andimarafioti
2024-08-30 09:38:44 +00:00
committed by Andres Marafioti
parent 669bdbf94d
commit 77894a7a5b
3 changed files with 37 additions and 5 deletions

View File

@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
console = Console()
# Maps a Whisper language token (as decoded from the STT output) to the
# plain language name that is inserted into the LLM prompt.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "<|en|>": "english",
    "<|fr|>": "french",
    "<|es|>": "spanish",
    "<|zh|>": "chinese",
    "<|ja|>": "japanese",
    "<|ko|>": "korean",
}
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
@@ -106,6 +115,7 @@ class LanguageModelHandler(BaseHandler):
language_id = None
if isinstance(prompt, tuple):
prompt, language_id = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
@@ -125,7 +135,7 @@ class LanguageModelHandler(BaseHandler):
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0])
yield (sentences[0], language_id)
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})

View File

@@ -1,10 +1,10 @@
from time import perf_counter
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForSpeechSeq2Seq
)
import torch
from copy import copy
from baseHandler import BaseHandler
from rich.console import Console
import logging
@@ -12,6 +12,15 @@ import logging
logger = logging.getLogger(__name__)
console = Console()
# Whisper language tokens that are accepted as a detection result; any
# other detected token triggers a re-run with the last accepted language.
SUPPORTED_LANGUAGES = [
    "<|en|>",
    "<|fr|>",
    "<|es|>",
    "<|zh|>",
    "<|ja|>",
    "<|ko|>",
]
class WhisperSTTHandler(BaseHandler):
"""
@@ -24,13 +33,16 @@ class WhisperSTTHandler(BaseHandler):
device="cuda",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs
del self.gen_kwargs["language"]
self.last_language = language
if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -103,6 +115,17 @@ class WhisperSTTHandler(BaseHandler):
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
if language_id not in SUPPORTED_LANGUAGES:
    # Whisper detected a language outside the supported set: run
    # generation again with the last accepted language forced.
    # The detected token is passed as a lazy %-style argument; a message
    # without a placeholder makes the logging module fail to format the
    # record, so the detected language would never reach the log.
    logger.warning("Whisper detected unsupported language: %s", language_id)
    # Shallow copy so the forced language does not leak into
    # self.gen_kwargs for later calls.
    gen_kwargs = copy(self.gen_kwargs)
    gen_kwargs["language"] = self.last_language
    # NOTE(review): last_language is initialised from the `language`
    # argument, which defaults to None; in that case this fallback
    # forwards None as the language id. Confirm consumers handle it.
    language_id = self.last_language
    pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
    # Remember the accepted language as the fallback for later inputs.
    self.last_language = language_id
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]

View File

@@ -1 +0,0 @@
current_language = "en"