38 Commits

SHA1        Author             Date                        Message
d8afa799ac  Andres Marafioti   2024-09-10 14:06:05 +02:00  Add support for a docker container for Jetson boards
3338f88c7a  Andrés Marafioti   2024-09-10 11:32:09 +02:00  Merge pull request #98 from huggingface/mac-multi-language (Mac multi language)
ac1dd41260  Andres Marafioti   2024-09-09 16:32:25 +02:00  improve readme (not according to cursor :( )
129cd11bf0  Andres Marafioti   2024-09-09 16:23:20 +02:00  Fixes to bugs from original PR
1d3b5bfc40  Andres Marafioti   2024-09-09 15:39:08 +02:00  avoid debug logging for melo
b3ace711b0  Andrés Marafioti   2024-09-09 10:01:19 +02:00  Merge pull request #93 from ybm911/main (Update: Added multi-language support for macOS.)
86c08082f8  Andrés Marafioti   2024-09-09 09:35:36 +02:00  Merge pull request #91 from BrutalCoding/patch-1 (fix: Changed [True] to [False] in help text for audio_enhancement to align with actual default)
29fedd0720  Elli0t             2024-09-08 02:51:23 +08:00  Update: Added multi-language support for macOS
6daa9baf68  Daniel Breedeveld  2024-09-07 12:09:46 +08:00  Update vad_arguments.py (Fixed help comment mismatch with actual default value.)
d98e252769  Andrés Marafioti   2024-09-05 17:07:11 +02:00  Merge pull request #87 from huggingface/hotfix-parler (fix)
c88eec7e26  Andres Marafioti   2024-09-05 17:02:34 +02:00  fix
274f27a7a8  Andrés Marafioti   2024-09-04 20:23:39 +02:00  Merge pull request #85 from rchan26/main (Fix relative link in README)
144ef43cd1  Andrés Marafioti   2024-09-04 18:26:31 +02:00  Merge pull request #84 from rchan26/main (Add language arg to lightning whisper handler)
4662204eae  rchan              2024-09-04 17:18:18 +01:00  fix relative link in README
1afd2445d3  rchan              2024-09-04 17:12:09 +01:00  add language arg to lightning whisper handler
8afd078ab4  Andrés Marafioti   2024-09-04 13:55:09 +02:00  Merge pull request #78 from AgainstEntropy/patch-1 (Update module_arguments.py)
b915e58b76  Andrés Marafioti   2024-09-04 13:54:08 +02:00  Merge pull request #60 from huggingface/multi-language (Add support for multiple languages)
4e6055f1a9  Andres Marafioti   2024-09-04 13:53:47 +02:00  pass auto for auto language detection
65f779de83  Andres Marafioti   2024-09-04 13:39:58 +02:00  review from eustache
65bef760b4  Andres Marafioti   2024-09-04 13:24:38 +02:00  this param was there twice
61e0d7c32c  andimarafioti      2024-09-04 13:24:38 +02:00  set default back to english
f09cf64f5a  andimarafioti      2024-09-04 13:24:38 +02:00  pass language as a parameter to avoid multi-language detection
77894a7a5b  andimarafioti      2024-09-04 13:24:38 +02:00  working
669bdbf94d  Andres Marafioti   2024-09-04 13:24:38 +02:00  catch a few exceptions from melo
3fff1d1da0  Andres Marafioti   2024-09-04 13:24:38 +02:00  catch exception if language is not supported
712005aff0  Andres Marafioti   2024-09-04 13:24:38 +02:00  set language
dc024ab6b7  Andres Marafioti   2024-09-04 13:24:38 +02:00  speed up warm up time for llm
d8f7b5fef6  Andres Marafioti   2024-09-04 13:24:38 +02:00  debug
7b0da0971c  Andres Marafioti   2024-09-04 13:24:38 +02:00  ...
901d9a1402  Andres Marafioti   2024-09-04 13:24:38 +02:00  handle language better
2f266b2d95  Andres Marafioti   2024-09-04 13:24:38 +02:00  fixy
2cb9464b8f  Andres Marafioti   2024-09-04 13:24:38 +02:00  try to pass the language_id in the queue
0555d4dc75  Andres Marafioti   2024-09-04 13:24:38 +02:00  pass the current language around
e2c9d96824  Andres Marafioti   2024-09-04 13:24:38 +02:00  let whisper determine the language
832ee799d1  Andres Marafioti   2024-09-04 13:24:38 +02:00  update melos arguments
cfd3065e34  Andres Marafioti   2024-09-04 13:24:38 +02:00  language fun
dfed6263d0  Yihao Wang         2024-09-03 14:49:02 -04:00  Update module_arguments.py (fix typo in the help message of `mode`)
d2d33d1035  Andrés Marafioti   2024-09-03 16:26:47 +02:00  Merge pull request #77 from huggingface/doc-improvement (improve documentation)
16 changed files with 261 additions and 45 deletions

Dockerfile.arm64 (new file, +13)

@@ -0,0 +1,13 @@
+FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
+
+ENV PYTHONUNBUFFERED 1
+
+WORKDIR /usr/src/app
+
+# Install packages
+RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .

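The new image is a straightforward L4T/PyTorch base plus the repo's requirements. A hedged usage sketch (image tag and run flags are illustrative, not from the repo):

```bash
# Build the Jetson (arm64) image and run the pipeline with GPU access.
# --runtime nvidia assumes the NVIDIA container runtime that ships with L4T.
docker build -f Dockerfile.arm64 -t speech-to-speech:jetson .
docker run --runtime nvidia -it --rm speech-to-speech:jetson python3 s2s_pipeline.py
```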
LLM/language_model.py

@@ -18,6 +18,15 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
+
 class LanguageModelHandler(BaseHandler):
     """
     Handles the language model part.
@@ -69,7 +78,7 @@ class LanguageModelHandler(BaseHandler):
     def warmup(self):
         logger.info(f"Warming up {self.__class__.__name__}")
 
-        dummy_input_text = "Write me a poem about Machine Learning."
+        dummy_input_text = "Repeat the word 'home'."
         dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
         warmup_gen_kwargs = {
             "min_new_tokens": self.gen_kwargs["min_new_tokens"],
@@ -103,6 +112,10 @@ class LanguageModelHandler(BaseHandler):
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
 
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
         thread = Thread(
@@ -122,10 +135,10 @@ class LanguageModelHandler(BaseHandler):
                 printable_text += new_text
                 sentences = sent_tokenize(printable_text)
                 if len(sentences) > 1:
-                    yield (sentences[0])
+                    yield (sentences[0], language_code)
                     printable_text = new_text
 
         self.chat.append({"role": "assistant", "content": generated_text})
 
         # don't forget last sentence
-        yield printable_text
+        yield (printable_text, language_code)

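Together with the STT and TTS diffs below, this changes the inter-handler protocol: every stage now passes a `(text, language_code)` tuple through its queue instead of a bare string. A minimal sketch of that contract (queue names and sample texts are illustrative, not the repo's):

```python
from queue import Queue

stt_to_llm, llm_to_tts = Queue(), Queue()

# STT yields (transcript, detected language code)
stt_to_llm.put(("quel temps fait-il ?", "fr"))

# The LLM handler unpacks the tuple, steers the reply language via the
# prompt prefix, and forwards the code so TTS can pick a matching voice.
prompt, language_code = stt_to_llm.get()
prompt = f"Please reply to my message in french. {prompt}"
llm_to_tts.put(("Il fait beau aujourd'hui.", language_code))

sentence, language_code = llm_to_tts.get()  # consumed by the TTS handler
```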
LLM/mlx_language_model.py

@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
 
 class MLXLanguageModelHandler(BaseHandler):
     """
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
 
         self.chat.append({"role": self.user_role, "content": prompt})
@@ -86,9 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
             output += t
             curr_output += t
             if curr_output.endswith((".", "?", "!", "<|end|>")):
-                yield curr_output.replace("<|end|>", "")
+                yield (curr_output.replace("<|end|>", ""), language_code)
                 curr_output = ""
 
         generated_text = output.replace("<|end|>", "")
         torch.mps.empty_cache()
         self.chat.append({"role": "assistant", "content": generated_text})

README.md

@@ -14,7 +14,7 @@
 * [Usage](#usage)
   - [Docker Server approach](#docker-server)
   - [Server/Client approach](#serverclient-approach)
-  - [Local approach](#local-approach)
+  - [Local approach](#local-approach-running-on-mac)
 * [Command-line usage](#command-line-usage)
   - [Model parameters](#model-parameters)
   - [Generation parameters](#generation-parameters)
@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
 
 ### Server/Client Approach
 
-To run the pipeline on the server:
-```bash
-python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
-```
+1. Run the pipeline on the server:
+   ```bash
+   python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
+   ```
 
-Then run the client locally to handle sending microphone input and receiving generated audio:
-```bash
-python listen_and_play.py --host <IP address of your server>
-```
+2. Run the client locally to handle microphone input and receive generated audio:
+   ```bash
+   python listen_and_play.py --host <IP address of your server>
+   ```
 
-### Running on Mac
-To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
-```bash
-python s2s_pipeline.py --local_mac_optimal_settings
-```
+### Local Approach (Mac)
 
-You can also pass `--device mps` to have all the models set to device mps.
-The local mac optimal settings set the mode to be local as explained above and change the models to:
-- LightningWhisperMLX
-- MLX LM
-- MeloTTS
+1. For optimal settings on Mac:
+   ```bash
+   python s2s_pipeline.py --local_mac_optimal_settings
+   ```
+
+This setting:
+- Adds `--device mps` to use MPS for all models.
+- Sets LightningWhisperMLX for STT
+- Sets MLX LM for the language model
+- Sets MeloTTS for TTS
 
 ### Recommended usage with Cuda
@@ -117,6 +118,57 @@ python s2s_pipeline.py \
 
 For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
 
+### Multi-language Support
+
+The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
+
+#### With the server version:
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
+```
+
+Or for one language in particular, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
+```
+
+#### Local Mac Setup
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
+```
+
+Or for one language in particular, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
+```
+
 ## Command-line Usage
 
 ### Model Parameters

STT/lightning_whisper_mlx_handler.py

@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
 from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
+from copy import copy
+import torch
 
 logger = logging.getLogger(__name__)
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
 
 class LightningWhisperSTTHandler(BaseHandler):
     """
@@ -19,15 +29,19 @@ class LightningWhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-large-v3",
-        device="cuda",
+        device="mps",
         torch_dtype="float16",
         compile_mode=None,
+        language=None,
         gen_kwargs={},
     ):
         if len(model_name.split("/")) > 1:
             model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+        self.start_language = language
+        self.last_language = language
 
         self.warmup()
 
     def warmup(self):
@@ -46,10 +60,26 @@ class LightningWhisperSTTHandler(BaseHandler):
         global pipeline_start
         pipeline_start = perf_counter()
 
-        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+        if self.start_language != 'auto':
+            transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
+        else:
+            transcription_dict = self.model.transcribe(spoken_prompt)
+            language_code = transcription_dict["language"]
+            if language_code not in SUPPORTED_LANGUAGES:
+                logger.warning(f"Whisper detected unsupported language: {language_code}")
+                if self.last_language in SUPPORTED_LANGUAGES:  # reprocess with the last language
+                    transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
+                else:
+                    transcription_dict = {"text": "", "language": "en"}
+            else:
+                self.last_language = language_code
+
+        pred_text = transcription_dict["text"].strip()
+        language_code = transcription_dict["language"]
+        torch.mps.empty_cache()
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)

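For reference, the library call this handler wraps; a sketch assuming `lightning-whisper-mlx` is installed and a local `sample.wav` exists:

```python
from lightning_whisper_mlx import LightningWhisperMLX

# Same constructor arguments as the handler above
whisper = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)

result = whisper.transcribe("sample.wav")  # language auto-detected
print(result["text"], result["language"])

# Pinning the language, as the handler does when start_language != 'auto'
result_fr = whisper.transcribe("sample.wav", language="fr")
```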
STT/whisper_stt_handler.py

@@ -1,10 +1,10 @@
 from time import perf_counter
 from transformers import (
-    AutoModelForSpeechSeq2Seq,
     AutoProcessor,
+    AutoModelForSpeechSeq2Seq
 )
 import torch
+from copy import copy
 from baseHandler import BaseHandler
 from rich.console import Console
 import logging
@@ -12,6 +12,15 @@ import logging
 logger = logging.getLogger(__name__)
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
+
 class WhisperSTTHandler(BaseHandler):
     """
@@ -24,12 +33,18 @@ class WhisperSTTHandler(BaseHandler):
         device="cuda",
         torch_dtype="float16",
         compile_mode=None,
+        language=None,
         gen_kwargs={},
     ):
         self.device = device
         self.torch_dtype = getattr(torch, torch_dtype)
         self.compile_mode = compile_mode
         self.gen_kwargs = gen_kwargs
+        if language == 'auto':
+            language = None
+        self.last_language = language
+        if self.last_language is not None:
+            self.gen_kwargs["language"] = self.last_language
 
         self.processor = AutoProcessor.from_pretrained(model_name)
         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
@@ -102,11 +117,24 @@ class WhisperSTTHandler(BaseHandler):
         input_features = self.prepare_model_inputs(spoken_prompt)
         pred_ids = self.model.generate(input_features, **self.gen_kwargs)
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"
+
+        if language_code not in SUPPORTED_LANGUAGES:  # reprocess with the last language
+            logger.warning(f"Whisper detected unsupported language: {language_code}")
+            gen_kwargs = copy(self.gen_kwargs)
+            gen_kwargs['language'] = self.last_language
+            language_code = self.last_language
+            pred_ids = self.model.generate(input_features, **gen_kwargs)
+        else:
+            self.last_language = language_code
+
         pred_text = self.processor.batch_decode(
             pred_ids, skip_special_tokens=True, decode_with_timestamps=False
         )[0]
+        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)

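The `pred_ids[0, 1]` trick works because Whisper generations start with `<|startoftranscript|><|xx|><|transcribe|>…`, so token index 1 is always the language tag. A quick sketch of the same slicing (the checkpoint name is an assumption; any Whisper checkpoint behaves alike):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
lang_id = processor.tokenizer.convert_tokens_to_ids("<|fr|>")
token = processor.tokenizer.decode([lang_id])  # '<|fr|>'
print(token[2:-2])  # 'fr' -- the same [2:-2] slice used in the handler
```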
TTS/melo_handler.py

@@ -10,21 +10,44 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
+    "en": "EN_NEWEST",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
+WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
+    "en": "EN-Newest",
+    "fr": "FR",
+    "es": "ES",
+    "zh": "ZH",
+    "ja": "JP",
+    "ko": "KR",
+}
+
 class MeloTTSHandler(BaseHandler):
     def setup(
         self,
         should_listen,
         device="mps",
-        language="EN_NEWEST",
-        speaker_to_id="EN-Newest",
+        language="en",
+        speaker_to_id="en",
         gen_kwargs={},  # Unused
         blocksize=512,
     ):
         self.should_listen = should_listen
         self.device = device
-        self.model = TTS(language=language, device=device)
-        self.speaker_id = self.model.hps.data.spk2id[speaker_to_id]
+        self.language = language
+        self.model = TTS(
+            language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
+        )
+        self.speaker_id = self.model.hps.data.spk2id[
+            WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
+        ]
         self.blocksize = blocksize
         self.warmup()
@@ -33,7 +56,28 @@ class MeloTTSHandler(BaseHandler):
         _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
 
     def process(self, llm_sentence):
+        language_code = None
+
+        if isinstance(llm_sentence, tuple):
+            llm_sentence, language_code = llm_sentence
+
         console.print(f"[green]ASSISTANT: {llm_sentence}")
+
+        if language_code is not None and self.language != language_code:
+            try:
+                self.model = TTS(
+                    language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
+                    device=self.device,
+                )
+                self.speaker_id = self.model.hps.data.spk2id[
+                    WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
+                ]
+                self.language = language_code
+            except KeyError:
+                console.print(
+                    f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
+                )
+
         if self.device == "mps":
             import time
@@ -44,7 +88,13 @@ class MeloTTSHandler(BaseHandler):
             time.time() - start
         )  # Removing this line makes it fail more often. I'm looking into it.
 
-        audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
+        try:
+            audio_chunk = self.model.tts_to_file(
+                llm_sentence, self.speaker_id, quiet=True
+            )
+        except (AssertionError, RuntimeError) as e:
+            logger.error(f"Error in MeloTTSHandler: {e}")
+            audio_chunk = np.array([])
         if len(audio_chunk) == 0:
             self.should_listen.set()
             return

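Melo loads one voice per `TTS` instance, so a language switch means re-instantiating the model; that is why the handler caches `self.language` and only rebuilds on a change. The underlying calls, sketched with the mapping tables above (assumes MeloTTS is installed):

```python
from melo.api import TTS

# "fr" maps to language "FR" and speaker "FR" in the tables above
model = TTS(language="FR", device="mps")
speaker_id = model.hps.data.spk2id["FR"]

# With no output path, tts_to_file returns the audio array directly,
# matching how the handler consumes it.
audio = model.tts_to_file("Bonjour !", speaker_id, quiet=True)
```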
TTS/parler_handler.py

@@ -147,6 +147,9 @@ class ParlerTTSHandler(BaseHandler):
         )
 
     def process(self, llm_sentence):
+        if isinstance(llm_sentence, tuple):
+            llm_sentence, _ = llm_sentence
+
         console.print(f"[green]ASSISTANT: {llm_sentence}")
         nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)

VAD/vad_handler.py

@@ -86,3 +86,7 @@ class VADHandler(BaseHandler):
                 )
                 array = enhanced.numpy().squeeze()
             yield array
+
+    @property
+    def min_time_to_debug(self):
+        return 0.00001

arguments_classes/melo_tts_arguments.py

@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
 @dataclass
 class MeloTTSHandlerArguments:
     melo_language: str = field(
-        default="EN_NEWEST",
+        default="en",
         metadata={
             "help": "The language of the text to be synthesized. Default is 'en'."
         },
@@ -16,7 +16,7 @@ class MeloTTSHandlerArguments:
         },
     )
     melo_speaker_to_id: str = field(
-        default="EN-Newest",
+        default="en",
         metadata={
             "help": "Mapping of speaker names to speaker IDs. Default is 'en'."
         },

arguments_classes/module_arguments.py

@@ -11,7 +11,7 @@ class ModuleArguments:
     mode: Optional[str] = field(
         default="socket",
         metadata={
-            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
+            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
         },
     )
     local_mac_optimal_settings: bool = field(

arguments_classes/vad_arguments.py

@@ -42,6 +42,6 @@ class VADHandlerArguments:
     audio_enhancement: bool = field(
         default=False,
         metadata={
-            "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is True."
+            "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
         },
     )

arguments_classes/whisper_stt_arguments.py

@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import Optional
 
 
 @dataclass
@@ -51,9 +52,13 @@ class WhisperSTTHandlerArguments:
             "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
         },
     )
-    stt_gen_language: str = field(
-        default="en",
+    language: Optional[str] = field(
+        default='en',
         metadata={
-            "help": "The language of the speech to transcribe. Default is 'en' for English."
+            "help": """The language for the conversation.
+            Choose between 'en' (english), 'fr' (french), 'es' (spanish),
+            'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'auto'.
+            If using 'auto', the language is automatically detected and can
+            change during the conversation. Default is 'en'."""
         },
     )

baseHandler.py

@@ -36,7 +36,8 @@ class BaseHandler:
         start_time = perf_counter()
         for output in self.process(input):
             self._times.append(perf_counter() - start_time)
-            logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
+            if self.last_time > self.min_time_to_debug:
+                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
             self.queue_out.put(output)
             start_time = perf_counter()
@@ -46,6 +47,10 @@ class BaseHandler:
     @property
     def last_time(self):
         return self._times[-1]
 
+    @property
+    def min_time_to_debug(self):
+        return 0.001
+
     def cleanup(self):
         pass

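Handlers that emit many near-instant chunks can now opt out of per-chunk timing logs by overriding the property, which is exactly what the VAD handler above does. A minimal sketch of the pattern (subclass name and threshold are hypothetical):

```python
from baseHandler import BaseHandler  # assumes running inside the repo


class QuietHandler(BaseHandler):  # hypothetical subclass
    @property
    def min_time_to_debug(self):
        return 0.05  # only log steps slower than 50 ms
```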
docker-compose.yml

@@ -4,6 +4,7 @@ services:
   pipeline:
     build:
       context: .
+      dockerfile: ${DOCKERFILE:-Dockerfile}
     command:
       - python3
       - s2s_pipeline.py

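With the `${DOCKERFILE:-Dockerfile}` substitution, the stock x86 image remains the default and the Jetson image becomes opt-in through an environment variable. A hedged usage example:

```bash
# Default build (uses Dockerfile)
docker compose up --build

# Jetson build (uses the new Dockerfile.arm64)
DOCKERFILE=Dockerfile.arm64 docker compose up --build
```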
s2s_pipeline.py

@@ -299,7 +299,6 @@ def main():
             setup_args=(should_listen,),
             setup_kwargs=vars(parler_tts_handler_kwargs),
         )
-
     elif module_kwargs.tts == "melo":
         try:
             from TTS.melo_handler import MeloTTSHandler