Merge pull request #60 from huggingface/multi-language

Add support for multiple languages
This commit is contained in:
Andrés Marafioti
2024-09-04 13:54:08 +02:00
committed by GitHub
5 changed files with 113 additions and 17 deletions

View File

@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
@@ -69,7 +78,7 @@ class LanguageModelHandler(BaseHandler):
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
dummy_input_text = "Write me a poem about Machine Learning."
dummy_input_text = "Repeat the word 'home'."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["min_new_tokens"],
@@ -103,6 +112,10 @@ class LanguageModelHandler(BaseHandler):
def process(self, prompt):
logger.debug("infering language model...")
language_code = None
if isinstance(prompt, tuple):
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
@@ -122,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0])
yield (sentences[0], language_code)
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text
yield (printable_text, language_code)

View File

@@ -1,10 +1,10 @@
from time import perf_counter
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoModelForSpeechSeq2Seq
)
import torch
from copy import copy
from baseHandler import BaseHandler
from rich.console import Console
import logging
@@ -12,6 +12,15 @@ import logging
logger = logging.getLogger(__name__)
console = Console()
SUPPORTED_LANGUAGES = [
"en",
"fr",
"es",
"zh",
"ja",
"ko",
]
class WhisperSTTHandler(BaseHandler):
"""
@@ -24,12 +33,18 @@ class WhisperSTTHandler(BaseHandler):
device="cuda",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs
if language == 'auto':
language = None
self.last_language = language
if self.last_language is not None:
self.gen_kwargs["language"] = self.last_language
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -102,11 +117,24 @@ class WhisperSTTHandler(BaseHandler):
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
logger.warning("Whisper detected unsupported language:", language_code)
gen_kwargs = copy(self.gen_kwargs)
gen_kwargs['language'] = self.last_language
language_code = self.last_language
pred_ids = self.model.generate(input_features, **gen_kwargs)
else:
self.last_language = language_code
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")
yield pred_text
yield (pred_text, language_code)

View File

@@ -10,21 +10,44 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
"en": "EN_NEWEST",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
"en": "EN-Newest",
"fr": "FR",
"es": "ES",
"zh": "ZH",
"ja": "JP",
"ko": "KR",
}
class MeloTTSHandler(BaseHandler):
def setup(
self,
should_listen,
device="mps",
language="EN_NEWEST",
speaker_to_id="EN-Newest",
language="en",
speaker_to_id="en",
gen_kwargs={}, # Unused
blocksize=512,
):
self.should_listen = should_listen
self.device = device
self.model = TTS(language=language, device=device)
self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
self.language = language
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
]
self.blocksize = blocksize
self.warmup()
@@ -33,7 +56,28 @@ class MeloTTSHandler(BaseHandler):
_ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
def process(self, llm_sentence):
language_code = None
if isinstance(llm_sentence, tuple):
llm_sentence, language_code = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
if language_code is not None and self.language != language_code:
try:
self.model = TTS(
language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
device=self.device,
)
self.speaker_id = self.model.hps.data.spk2id[
WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
]
self.language = language_code
except KeyError:
console.print(
f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
)
if self.device == "mps":
import time
@@ -44,7 +88,13 @@ class MeloTTSHandler(BaseHandler):
time.time() - start
) # Removing this line makes it fail more often. I'm looking into it.
audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
try:
audio_chunk = self.model.tts_to_file(
llm_sentence, self.speaker_id, quiet=True
)
except (AssertionError, RuntimeError) as e:
logger.error(f"Error in MeloTTSHandler: {e}")
audio_chunk = np.array([])
if len(audio_chunk) == 0:
self.should_listen.set()
return

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
@dataclass
class MeloTTSHandlerArguments:
melo_language: str = field(
default="EN_NEWEST",
default="en",
metadata={
"help": "The language of the text to be synthesized. Default is 'en'."
},
@@ -16,7 +16,7 @@ class MeloTTSHandlerArguments:
},
)
melo_speaker_to_id: str = field(
default="EN-Newest",
default="en",
metadata={
"help": "Mapping of speaker names to speaker IDs. Default is 'en'."
},

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
@@ -51,9 +52,13 @@ class WhisperSTTHandlerArguments:
"help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
},
)
stt_gen_language: str = field(
default="en",
language: Optional[str] = field(
default='en',
metadata={
"help": "The language of the speech to transcribe. Default is 'en' for English."
"help": """The language for the conversation.
Choose between 'en' (english), 'fr' (french), 'es' (spanish),
'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'auto'.
If using 'auto', the language is automatically detected and can
change during the conversation. Default is 'en'."""
},
)
)