working
This commit is contained in:
committed by
Andres Marafioti
parent
669bdbf94d
commit
77894a7a5b
@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
|
||||
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
|
||||
"<|en|>": "english",
|
||||
"<|fr|>": "french",
|
||||
"<|es|>": "spanish",
|
||||
"<|zh|>": "chinese",
|
||||
"<|ja|>": "japanese",
|
||||
"<|ko|>": "korean",
|
||||
}
|
||||
|
||||
class LanguageModelHandler(BaseHandler):
|
||||
"""
|
||||
Handles the language model part.
|
||||
@@ -106,6 +115,7 @@ class LanguageModelHandler(BaseHandler):
|
||||
language_id = None
|
||||
if isinstance(prompt, tuple):
|
||||
prompt, language_id = prompt
|
||||
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
|
||||
|
||||
self.chat.append({"role": self.user_role, "content": prompt})
|
||||
thread = Thread(
|
||||
@@ -125,7 +135,7 @@ class LanguageModelHandler(BaseHandler):
|
||||
printable_text += new_text
|
||||
sentences = sent_tokenize(printable_text)
|
||||
if len(sentences) > 1:
|
||||
yield (sentences[0])
|
||||
yield (sentences[0], language_id)
|
||||
printable_text = new_text
|
||||
|
||||
self.chat.append({"role": "assistant", "content": generated_text})
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from time import perf_counter
|
||||
from transformers import (
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoProcessor,
|
||||
AutoModelForSpeechSeq2Seq
|
||||
)
|
||||
import torch
|
||||
|
||||
from copy import copy
|
||||
from baseHandler import BaseHandler
|
||||
from rich.console import Console
|
||||
import logging
|
||||
@@ -12,6 +12,15 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
console = Console()
|
||||
|
||||
SUPPORTED_LANGUAGES = [
|
||||
"<|en|>",
|
||||
"<|fr|>",
|
||||
"<|es|>",
|
||||
"<|zh|>",
|
||||
"<|ja|>",
|
||||
"<|ko|>",
|
||||
]
|
||||
|
||||
|
||||
class WhisperSTTHandler(BaseHandler):
|
||||
"""
|
||||
@@ -24,13 +33,16 @@ class WhisperSTTHandler(BaseHandler):
|
||||
device="cuda",
|
||||
torch_dtype="float16",
|
||||
compile_mode=None,
|
||||
language=None,
|
||||
gen_kwargs={},
|
||||
):
|
||||
self.device = device
|
||||
self.torch_dtype = getattr(torch, torch_dtype)
|
||||
self.compile_mode = compile_mode
|
||||
self.gen_kwargs = gen_kwargs
|
||||
del self.gen_kwargs["language"]
|
||||
self.last_language = language
|
||||
if self.last_language is not None:
|
||||
self.gen_kwargs["language"] = self.last_language
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
@@ -103,6 +115,17 @@ class WhisperSTTHandler(BaseHandler):
|
||||
|
||||
input_features = self.prepare_model_inputs(spoken_prompt)
|
||||
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
|
||||
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
|
||||
|
||||
if language_id not in SUPPORTED_LANGUAGES: # reprocess with the last language
|
||||
logger.warning("Whisper detected unsupported language:", language_id)
|
||||
gen_kwargs = copy(self.gen_kwargs)
|
||||
gen_kwargs['language'] = self.last_language
|
||||
language_id = self.last_language
|
||||
pred_ids = self.model.generate(input_features, **gen_kwargs)
|
||||
else:
|
||||
self.last_language = language_id
|
||||
|
||||
pred_text = self.processor.batch_decode(
|
||||
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
|
||||
)[0]
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
current_language = "en"
|
||||
Reference in New Issue
Block a user