review from eustache

This commit is contained in:
Andres Marafioti
2024-09-04 13:39:58 +02:00
parent 65bef760b4
commit 65f779de83
3 changed files with 46 additions and 48 deletions

View File

@@ -19,12 +19,12 @@ console = Console()
# Maps a detected-language key to the English language name that `process`
# interpolates into the "Please reply to my message in ..." prompt prefix.
# NOTE(review): both the Whisper token form ("<|en|>") and the bare code form
# ("en") are listed; this looks like removed and added diff lines shown
# together -- confirm which key form the file is meant to keep.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"<|en|>": "english",
"<|fr|>": "french",
"<|es|>": "spanish",
"<|zh|>": "chinese",
"<|ja|>": "japanese",
"<|ko|>": "korean",
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}
class LanguageModelHandler(BaseHandler):
@@ -112,10 +112,10 @@ class LanguageModelHandler(BaseHandler):
def process(self, prompt):
logger.debug("infering language model...")
language_id = None
language_code = None
if isinstance(prompt, tuple):
prompt, language_id = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_id]}. " + prompt
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
@@ -135,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0], language_id)
yield (sentences[0], language_code)
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield (printable_text, language_id)
yield (printable_text, language_code)

View File

@@ -13,12 +13,12 @@ logger = logging.getLogger(__name__)
console = Console()
# Language keys accepted from Whisper's detection. A detected language that is
# not in this list makes the handler log a warning and run generation again
# with `self.last_language`.
# NOTE(review): both the token form ("<|en|>") and the bare code form ("en")
# are listed; this looks like removed and added diff lines shown together. The
# `[2:-2]` slice applied to the decoded token yields the bare form -- confirm
# that only the bare codes should remain.
SUPPORTED_LANGUAGES = [
"<|en|>",
"<|fr|>",
"<|es|>",
"<|zh|>",
"<|ja|>",
"<|ko|>",
"en",
"fr",
"es",
"zh",
"ja",
"ko",
]
@@ -117,24 +117,24 @@ class WhisperSTTHandler(BaseHandler):
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
if language_id not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_id)
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
# logging %-formats the message with its extra arguments, so the message needs
# a placeholder; without one the record fails to format and the warning is lost.
logger.warning("Whisper detected unsupported language: %s", language_code)
gen_kwargs = copy(self.gen_kwargs)
gen_kwargs['language'] = self.last_language
language_id = self.last_language
language_code = self.last_language
pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
self.last_language = language_id
self.last_language = language_code
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
language_id = self.processor.tokenizer.decode(pred_ids[0, 1])
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language ID Whisper: {language_id}")
logger.debug(f"Language Code Whisper: {language_code}")
yield (pred_text, language_id)
yield (pred_text, language_code)

View File

@@ -11,21 +11,21 @@ logger = logging.getLogger(__name__)
console = Console()
# Maps a detected-language key to the value passed as `language=` when the
# handler constructs `TTS(...)`.
# NOTE(review): both the token form ("<|en|>") and the bare code form ("en")
# are listed; this looks like removed and added diff lines shown together --
# confirm which key form the file is meant to keep.
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"<|en|>": "EN_NEWEST",
"<|fr|>": "FR",
"<|es|>": "ES",
"<|zh|>": "ZH",
"<|ja|>": "JP",
"<|ko|>": "KR",
"en": "EN_NEWEST",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
# Maps a detected-language key to the speaker name used to index
# `self.model.hps.data.spk2id` when selecting `self.speaker_id`.
# NOTE(review): both the token form ("<|en|>") and the bare code form ("en")
# are listed; this looks like removed and added diff lines shown together --
# confirm which key form the file is meant to keep.
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"<|en|>": "EN-Newest",
"<|fr|>": "FR",
"<|es|>": "ES",
"<|zh|>": "ZH",
"<|ja|>": "JP",
"<|ko|>": "KR",
"en": "EN-Newest",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
@@ -41,14 +41,12 @@ class MeloTTSHandler(BaseHandler):
):
self.should_listen = should_listen
self.device = device
self.language = (
"<|" + language + "|>"
) # 'Tokenize' the language code to do less operations
self.language = language
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER["<|" + speaker_to_id + "|>"]
WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
]
self.blocksize = blocksize
self.warmup()
@@ -58,26 +56,26 @@ class MeloTTSHandler(BaseHandler):
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
def process(self, llm_sentence):
language_id = None
language_code = None
if isinstance(llm_sentence, tuple):
llm_sentence, language_id = llm_sentence
llm_sentence, language_code = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
if language_id is not None and self.language != language_id:
if language_code is not None and self.language != language_code:
try:
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_id],
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
device=self.device,
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_id]
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
]
self.language = language_id
self.language = language_code
except KeyError:
console.print(
f"[red]Language {language_id} not supported by Melo. Using {self.language} instead."
f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
)
if self.device == "mps":