10 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Andres Marafioti | d8afa799ac | Add support for a docker container for Jetson boards | 2024-09-10 14:06:05 +02:00 |
| Andrés Marafioti | 3338f88c7a | Merge pull request #98 from huggingface/mac-multi-language: Mac multi language | 2024-09-10 11:32:09 +02:00 |
| Andres Marafioti | ac1dd41260 | improve readme (not according to cursor :( ) | 2024-09-09 16:32:25 +02:00 |
| Andres Marafioti | 129cd11bf0 | Fixes to bugs from original PR | 2024-09-09 16:23:20 +02:00 |
| Andres Marafioti | 1d3b5bfc40 | avoid debug logging for melo | 2024-09-09 15:39:08 +02:00 |
| Andrés Marafioti | b3ace711b0 | Merge pull request #93 from ybm911/main: Update: Added multi-language support for macOS. | 2024-09-09 10:01:19 +02:00 |
| Andrés Marafioti | 86c08082f8 | Merge pull request #91 from BrutalCoding/patch-1: fix: Changed [True] to [False] in help text for audio_enhancement to align with actual default | 2024-09-09 09:35:36 +02:00 |
| Elli0t | 29fedd0720 | Update: Added multi-language support for macOS | 2024-09-08 02:51:23 +08:00 |
| Daniel Breedeveld | 6daa9baf68 | Update vad_arguments.py: Fixed help comment mismatch with actual default value. | 2024-09-07 12:09:46 +08:00 |
| Andrés Marafioti | d98e252769 | Merge pull request #87 from huggingface/hotfix-parler: fix | 2024-09-05 17:07:11 +02:00 |
8 changed files with 142 additions and 25 deletions

Dockerfile.arm64 (new file, 13 lines)

@@ -0,0 +1,13 @@
+FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
+
+ENV PYTHONUNBUFFERED 1
+
+WORKDIR /usr/src/app
+
+# Install packages
+RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .


@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
 console = Console()
 
+WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
+    "en": "english",
+    "fr": "french",
+    "es": "spanish",
+    "zh": "chinese",
+    "ja": "japanese",
+    "ko": "korean",
+}
 
 class MLXLanguageModelHandler(BaseHandler):
     """
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
     def process(self, prompt):
         logger.debug("infering language model...")
+        language_code = None
+        if isinstance(prompt, tuple):
+            prompt, language_code = prompt
+            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
+
         self.chat.append({"role": self.user_role, "content": prompt})
@@ -86,9 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
             output += t
             curr_output += t
             if curr_output.endswith((".", "?", "!", "<|end|>")):
-                yield curr_output.replace("<|end|>", "")
+                yield (curr_output.replace("<|end|>", ""), language_code)
                 curr_output = ""
         generated_text = output.replace("<|end|>", "")
         torch.mps.empty_cache()
 
         self.chat.append({"role": "assistant", "content": generated_text})

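The MLX LM change above threads a language code through the pipeline by passing `(text, language_code)` tuples between handlers instead of bare strings, and prepends an instruction so the model answers in the detected language. A minimal standalone sketch of that convention (the `consume` helper is hypothetical, not the repo's actual wiring):

```python
# Sketch of the (text, language_code) hand-off introduced above.
# WHISPER_LANGUAGE_TO_LLM_LANGUAGE mirrors the dict in the diff;
# consume() is a hypothetical stand-in for a handler's process().
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english",
    "fr": "french",
    "es": "spanish",
    "zh": "chinese",
    "ja": "japanese",
    "ko": "korean",
}

def consume(item):
    language_code = None
    if isinstance(item, tuple):  # upstream attached a language code
        text, language_code = item
    else:  # plain strings keep working (backwards compatible)
        text = item
    if language_code is not None:
        text = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + text
    return text

print(consume(("你好", "zh")))  # Please reply to my message in chinese. 你好
print(consume("hello"))         # hello
```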

@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
 
 ### Server/Client Approach
 
-To run the pipeline on the server:
-```bash
-python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
-```
+1. Run the pipeline on the server:
+   ```bash
+   python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
+   ```
 
-Then run the client locally to handle sending microphone input and receiving generated audio:
-```bash
-python listen_and_play.py --host <IP address of your server>
-```
+2. Run the client locally to handle microphone input and receive generated audio:
+   ```bash
+   python listen_and_play.py --host <IP address of your server>
+   ```
 
-### Local approach (running on Mac)
-To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
-```bash
-python s2s_pipeline.py --local_mac_optimal_settings
-```
-
-You can also pass `--device mps` to have all the models set to device mps.
-The local mac optimal settings set the mode to be local as explained above and change the models to:
-- LightningWhisperMLX
-- MLX LM
-- MeloTTS
+### Local Approach (Mac)
+
+1. For optimal settings on Mac:
+   ```bash
+   python s2s_pipeline.py --local_mac_optimal_settings
+   ```
+
+This setting:
+- Adds `--device mps` to use MPS for all models.
+- Sets LightningWhisperMLX for STT
+- Sets MLX LM for language model
+- Sets MeloTTS for TTS
 
 ### Recommended usage with Cuda
@@ -117,6 +118,57 @@ python s2s_pipeline.py \
 
 For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
 
+### Multi-language Support
+
+The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
+
+#### With the server version:
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
+```
+
+Or for one language in particular, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct
+```
+
+#### Local Mac Setup
+
+For automatic language detection:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language auto \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
+```
+
+Or for one language in particular, Chinese in this example:
+
+```bash
+python s2s_pipeline.py \
+    --local_mac_optimal_settings \
+    --device mps \
+    --stt_model_name large-v3 \
+    --language zh \
+    --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit
+```
 
 ## Command-line Usage
 
 ### Model Parameters
### Model Parameters


@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
 from lightning_whisper_mlx import LightningWhisperMLX
 import numpy as np
 from rich.console import Console
+from copy import copy
+import torch
 
 logger = logging.getLogger(__name__)
 console = Console()
 
+SUPPORTED_LANGUAGES = [
+    "en",
+    "fr",
+    "es",
+    "zh",
+    "ja",
+    "ko",
+]
 
 class LightningWhisperSTTHandler(BaseHandler):
     """
@@ -19,7 +29,7 @@ class LightningWhisperSTTHandler(BaseHandler):
     def setup(
         self,
         model_name="distil-large-v3",
-        device="cuda",
+        device="mps",
         torch_dtype="float16",
         compile_mode=None,
         language=None,
@@ -29,6 +39,9 @@ class LightningWhisperSTTHandler(BaseHandler):
         model_name = model_name.split("/")[-1]
         self.device = device
         self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
+        self.start_language = language
+        self.last_language = language
+
         self.warmup()
 
     def warmup(self):
@@ -47,10 +60,26 @@ class LightningWhisperSTTHandler(BaseHandler):
         global pipeline_start
         pipeline_start = perf_counter()
 
-        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
+        if self.start_language != 'auto':
+            transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
+        else:
+            transcription_dict = self.model.transcribe(spoken_prompt)
+            language_code = transcription_dict["language"]
+            if language_code not in SUPPORTED_LANGUAGES:
+                logger.warning(f"Whisper detected unsupported language: {language_code}")
+                if self.last_language in SUPPORTED_LANGUAGES:  # reprocess with the last language
+                    transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
+                else:
+                    transcription_dict = {"text": "", "language": "en"}
+            else:
+                self.last_language = language_code
+
+        pred_text = transcription_dict["text"].strip()
+        language_code = transcription_dict["language"]
+        torch.mps.empty_cache()
 
         logger.debug("finished whisper inference")
         console.print(f"[yellow]USER: {pred_text}")
+        logger.debug(f"Language Code Whisper: {language_code}")
 
-        yield pred_text
+        yield (pred_text, language_code)

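Read in isolation, the fallback policy above is: in `auto` mode, trust Whisper's detected language only if it is in `SUPPORTED_LANGUAGES`; otherwise retry with the last language that worked, and if there is none, return an empty transcription. A standalone sketch under that reading (`transcribe` here is a stub standing in for `LightningWhisperMLX.transcribe`):

```python
# Sketch of the language-fallback policy from the diff above.
SUPPORTED_LANGUAGES = ["en", "fr", "es", "zh", "ja", "ko"]

def transcribe(audio, language=None):
    # Stub: the real model returns {"text": ..., "language": ...};
    # here it pretends Whisper detected Catalan when auto-detecting.
    return {"text": "hola", "language": language or "ca"}

def transcribe_with_fallback(audio, start_language="auto", last_language="en"):
    if start_language != "auto":
        result = transcribe(audio, language=start_language)
    else:
        result = transcribe(audio)
        detected = result["language"]
        if detected not in SUPPORTED_LANGUAGES:
            # Unsupported detection: retry with the last good language,
            # or give up with an empty utterance.
            if last_language in SUPPORTED_LANGUAGES:
                result = transcribe(audio, language=last_language)
            else:
                result = {"text": "", "language": "en"}
        else:
            last_language = detected
    return result["text"].strip(), result["language"], last_language

print(transcribe_with_fallback(b"..."))  # ('hola', 'en', 'en'): Catalan fell back to English
```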

@@ -86,3 +86,7 @@ class VADHandler(BaseHandler):
             )
             array = enhanced.numpy().squeeze()
         yield array
+
+    @property
+    def min_time_to_debug(self):
+        return 0.00001


@@ -42,6 +42,6 @@ class VADHandlerArguments:
     audio_enhancement: bool = field(
         default=False,
         metadata={
-            "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is True."
+            "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
        },
    )


@@ -36,7 +36,8 @@ class BaseHandler:
         start_time = perf_counter()
         for output in self.process(input):
             self._times.append(perf_counter() - start_time)
-            logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
+            if self.last_time > self.min_time_to_debug:
+                logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
             self.queue_out.put(output)
             start_time = perf_counter()
@@ -46,6 +47,10 @@ class BaseHandler:
     @property
     def last_time(self):
         return self._times[-1]
+
+    @property
+    def min_time_to_debug(self):
+        return 0.001
 
     def cleanup(self):
         pass

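The two diffs above work together as a property-override pattern: `BaseHandler` now debug-logs only iterations slower than `min_time_to_debug` (1 ms by default; per the commit message, this is what avoids debug-log spam from Melo's rapid per-chunk yields), while the VAD opts back in by overriding the property with a much lower threshold. A minimal sketch of the pattern (`run_once` is a simplified stand-in for the real `run` loop):

```python
import logging
from time import perf_counter, sleep

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

class BaseHandler:
    @property
    def min_time_to_debug(self):
        # Iterations faster than 1 ms are not logged, silencing
        # handlers that yield many tiny chunks.
        return 0.001

    def run_once(self, fn, *args):
        start = perf_counter()
        out = fn(*args)
        elapsed = perf_counter() - start
        if elapsed > self.min_time_to_debug:
            logger.debug(f"{self.__class__.__name__}: {elapsed: .3f} s")
        return out

class VADHandler(BaseHandler):
    @property
    def min_time_to_debug(self):
        # The VAD stays verbose by lowering its own threshold.
        return 0.00001

BaseHandler().run_once(sleep, 0.0005)  # ~0.5 ms: typically under the 1 ms threshold, not logged
VADHandler().run_once(sleep, 0.0005)   # same cost, but logged: threshold is 0.01 ms
```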

@@ -4,6 +4,7 @@ services:
   pipeline:
     build:
       context: .
+      dockerfile: ${DOCKERFILE:-Dockerfile}
     command:
       - python3
       - s2s_pipeline.py