25 Commits

Author SHA1 Message Date
Andres Marafioti
8e9b069f3a push 2024-09-11 14:28:00 +02:00
Andres Marafioti
97341ed914 lower the volume 2024-09-11 14:23:01 +02:00
Andres Marafioti
c08a02159d remove debug 2024-09-11 13:32:39 +02:00
Andres Marafioti
1a42b69c09 try with device 2024-09-11 13:30:00 +02:00
Andres Marafioti
cf9f7b41aa adding debugs 2024-09-11 13:24:41 +02:00
Andres Marafioti
5fd9615b5f print the exception 2024-09-11 13:20:46 +02:00
Andres Marafioti
702c14553f test 2024-09-11 13:08:19 +02:00
Andres Marafioti
d8afa799ac Add support for a docker container for Jetson boards 2024-09-10 14:06:05 +02:00
Andrés Marafioti
3338f88c7a Merge pull request #98 from huggingface/mac-multi-language
Mac multi language
2024-09-10 11:32:09 +02:00
Andres Marafioti
ac1dd41260 improve readme (not according to cursor :( ) 2024-09-09 16:32:25 +02:00
Andres Marafioti
129cd11bf0 Fixes to bugs from original PR 2024-09-09 16:23:20 +02:00
Andres Marafioti
1d3b5bfc40 avoid debug logging for melo 2024-09-09 15:39:08 +02:00
Andrés Marafioti
b3ace711b0 Merge pull request #93 from ybm911/main
Update: Added multi-language support for macOS.
2024-09-09 10:01:19 +02:00
Andrés Marafioti
86c08082f8 Merge pull request #91 from BrutalCoding/patch-1
fix: Changed [True] to [False] in help text for audio_enhancement to align with actual default
2024-09-09 09:35:36 +02:00
Elli0t
29fedd0720 Update: Added multi-language support for macOS 2024-09-08 02:51:23 +08:00
Daniel Breedeveld
6daa9baf68 Update vad_arguments.py
Fixed help comment mismatch with actual default value.
2024-09-07 12:09:46 +08:00
Andrés Marafioti
d98e252769 Merge pull request #87 from huggingface/hotfix-parler
fix
2024-09-05 17:07:11 +02:00
Andres Marafioti
c88eec7e26 fix 2024-09-05 17:02:34 +02:00
Andrés Marafioti
274f27a7a8 Merge pull request #85 from rchan26/main
Fix relative link in README
2024-09-04 20:23:39 +02:00
Andrés Marafioti
144ef43cd1 Merge pull request #84 from rchan26/main
Add language arg to lightning whisper handler
2024-09-04 18:26:31 +02:00
rchan
4662204eae fix relative link in README 2024-09-04 17:18:18 +01:00
rchan
1afd2445d3 add language arg to lightning whisper handler 2024-09-04 17:12:09 +01:00
Andrés Marafioti
8afd078ab4 Merge pull request #78 from AgainstEntropy/patch-1
Update module_arguments.py
2024-09-04 13:55:09 +02:00
Andrés Marafioti
b915e58b76 Merge pull request #60 from huggingface/multi-language
Add support for multiple languages
2024-09-04 13:54:08 +02:00
Yihao Wang
dfed6263d0 Update module_arguments.py
fix typo in the help message of `mode`
2024-09-03 14:49:02 -04:00
12 changed files with 175 additions and 39 deletions

Dockerfile.arm64 (new file)

@@ -0,0 +1,13 @@
FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
ENV PYTHONUNBUFFERED 1
WORKDIR /usr/src/app
# Install packages
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY . .


@@ -9,6 +9,14 @@ logger = logging.getLogger(__name__)
console = Console()
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
"en": "english",
"fr": "french",
"es": "spanish",
"zh": "chinese",
"ja": "japanese",
"ko": "korean",
}
class MLXLanguageModelHandler(BaseHandler):
"""
@@ -61,6 +69,11 @@ class MLXLanguageModelHandler(BaseHandler):
def process(self, prompt):
logger.debug("infering language model...")
language_code = None
if isinstance(prompt, tuple):
prompt, language_code = prompt
prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
self.chat.append({"role": self.user_role, "content": prompt})
@@ -86,9 +99,9 @@ class MLXLanguageModelHandler(BaseHandler):
output += t
curr_output += t
if curr_output.endswith((".", "?", "!", "<|end|>")):
yield curr_output.replace("<|end|>", "")
yield (curr_output.replace("<|end|>", ""), language_code)
curr_output = ""
generated_text = output.replace("<|end|>", "")
torch.mps.empty_cache()
self.chat.append({"role": "assistant", "content": generated_text})
self.chat.append({"role": "assistant", "content": generated_text})
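
The change above threads a language code through the pipeline: the STT stage may now hand the LLM a `(prompt, language_code)` tuple, the LLM prefixes the prompt so the model answers in that language, and it yields `(sentence, language_code)` for the TTS stage. A minimal sketch of the prompt-prefixing step (`build_prompt` is a hypothetical helper for illustration, not the project's API):

```python
# Illustrative sketch of the prompt prefix added in MLXLanguageModelHandler.process.
# `build_prompt` is a hypothetical helper, not part of the repository.
WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
    "en": "english", "fr": "french", "es": "spanish",
    "zh": "chinese", "ja": "japanese", "ko": "korean",
}

def build_prompt(prompt: str, language_code: str | None = None) -> str:
    # When the previous stage sends a (text, language) tuple, ask the model
    # to reply in the detected language; otherwise pass the prompt through.
    if language_code is not None:
        language = WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]
        prompt = f"Please reply to my message in {language}. " + prompt
    return prompt

print(build_prompt("What's the weather like?", "fr"))
# Please reply to my message in french. What's the weather like?
```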


@@ -14,7 +14,7 @@
* [Usage](#usage)
- [Docker Server approach](#docker-server)
- [Server/Client approach](#serverclient-approach)
- [Local approach](#local-approach)
- [Local approach](#local-approach-running-on-mac)
* [Command-line usage](#command-line-usage)
- [Model parameters](#model-parameters)
- [Generation parameters](#generation-parameters)
@@ -79,27 +79,28 @@ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install
### Server/Client Approach
To run the pipeline on the server:
```bash
python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
```
1. Run the pipeline on the server:
```bash
python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
```
Then run the client locally to handle sending microphone input and receiving generated audio:
```bash
python listen_and_play.py --host <IP address of your server>
```
2. Run the client locally to handle microphone input and receive generated audio:
```bash
python listen_and_play.py --host <IP address of your server>
```
### Running on Mac
To run on mac, we recommend setting the flag `--local_mac_optimal_settings`:
```bash
python s2s_pipeline.py --local_mac_optimal_settings
```
### Local Approach (Mac)
You can also pass `--device mps` to have all the models set to device mps.
The local mac optimal settings set the mode to be local as explained above and change the models to:
- LightningWhisperMLX
- MLX LM
- MeloTTS
1. For optimal settings on Mac:
```bash
python s2s_pipeline.py --local_mac_optimal_settings
```
This setting:
- Adds `--device mps` to use MPS for all models.
- Sets LightningWhisperMLX for STT
- Sets MLX LM for language model
- Sets MeloTTS for TTS
### Recommended usage with Cuda
@@ -117,6 +118,57 @@ python s2s_pipeline.py \
For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
### Multi-language Support
The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
#### With the server version:
For automatic language detection:
```bash
python s2s_pipeline.py \
--stt_model_name large-v3 \
--language auto \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
```
Or for one language in particular, Chinese in this example:
```bash
python s2s_pipeline.py \
--stt_model_name large-v3 \
--language zh \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
```
#### Local Mac Setup
For automatic language detection:
```bash
python s2s_pipeline.py \
--local_mac_optimal_settings \
--device mps \
--stt_model_name large-v3 \
--language auto \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
```
Or for one language in particular, Chinese in this example:
```bash
python s2s_pipeline.py \
--local_mac_optimal_settings \
--device mps \
--stt_model_name large-v3 \
--language zh \
--mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
```
## Command-line Usage
### Model Parameters


@@ -4,12 +4,22 @@ from baseHandler import BaseHandler
from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np
from rich.console import Console
from copy import copy
import torch
logger = logging.getLogger(__name__)
console = Console()
SUPPORTED_LANGUAGES = [
"en",
"fr",
"es",
"zh",
"ja",
"ko",
]
class LightningWhisperSTTHandler(BaseHandler):
"""
@@ -19,15 +29,19 @@ class LightningWhisperSTTHandler(BaseHandler):
def setup(
self,
model_name="distil-large-v3",
device="cuda",
device="mps",
torch_dtype="float16",
compile_mode=None,
language=None,
gen_kwargs={},
):
if len(model_name.split("/")) > 1:
model_name = model_name.split("/")[-1]
self.device = device
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
self.start_language = language
self.last_language = language
self.warmup()
def warmup(self):
@@ -46,10 +60,26 @@ class LightningWhisperSTTHandler(BaseHandler):
global pipeline_start
pipeline_start = perf_counter()
pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
if self.start_language != 'auto':
transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
else:
transcription_dict = self.model.transcribe(spoken_prompt)
language_code = transcription_dict["language"]
if language_code not in SUPPORTED_LANGUAGES:
logger.warning(f"Whisper detected unsupported language: {language_code}")
if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language
transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
else:
transcription_dict = {"text": "", "language": "en"}
else:
self.last_language = language_code
pred_text = transcription_dict["text"].strip()
language_code = transcription_dict["language"]
torch.mps.empty_cache()
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
logger.debug(f"Language Code Whisper: {language_code}")
yield pred_text
yield (pred_text, language_code)
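
The detection logic above keeps the conversation in a supported language: if Whisper reports an unsupported language, the handler re-transcribes with the last supported language it saw, and if there is none it emits an empty transcript. A compact sketch of that policy (the function and its return convention are illustrative, not the handler's actual API):

```python
# Illustrative sketch of the language-fallback policy in the STT handler above.
# `resolve_language` is a hypothetical helper, not part of the repository.
SUPPORTED_LANGUAGES = ["en", "fr", "es", "zh", "ja", "ko"]

def resolve_language(detected: str, last_language: str | None):
    """Return (language_to_transcribe_with, new_last_language, drop_segment)."""
    if detected in SUPPORTED_LANGUAGES:
        return detected, detected, False            # accept and remember it
    if last_language in SUPPORTED_LANGUAGES:
        return last_language, last_language, False  # re-transcribe with the last good language
    return "en", last_language, True                # nothing usable: emit an empty transcript

print(resolve_language("de", "fr"))   # ('fr', 'fr', False)
print(resolve_language("de", None))   # ('en', None, True)
```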


@@ -147,6 +147,9 @@ class ParlerTTSHandler(BaseHandler):
)
def process(self, llm_sentence):
if isinstance(llm_sentence, tuple):
llm_sentence, _ = llm_sentence
console.print(f"[green]ASSISTANT: {llm_sentence}")
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)


@@ -86,3 +86,7 @@ class VADHandler(BaseHandler):
)
array = enhanced.numpy().squeeze()
yield array
@property
def min_time_to_debug(self):
return 0.00001


@@ -11,7 +11,7 @@ class ModuleArguments:
mode: Optional[str] = field(
default="socket",
metadata={
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
},
)
local_mac_optimal_settings: bool = field(


@@ -42,6 +42,6 @@ class VADHandlerArguments:
audio_enhancement: bool = field(
default=False,
metadata={
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is True."
"help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
},
)


@@ -36,7 +36,8 @@ class BaseHandler:
start_time = perf_counter()
for output in self.process(input):
self._times.append(perf_counter() - start_time)
logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
if self.last_time > self.min_time_to_debug:
logger.debug(f"{self.__class__.__name__}: {self.last_time: .3f} s")
self.queue_out.put(output)
start_time = perf_counter()
@@ -46,6 +47,10 @@ class BaseHandler:
@property
def last_time(self):
return self._times[-1]
@property
def min_time_to_debug(self):
return 0.001
def cleanup(self):
pass


@@ -4,6 +4,7 @@ services:
pipeline:
build:
context: .
dockerfile: ${DOCKERFILE:-Dockerfile}
command:
- python3
- s2s_pipeline.py


@@ -3,8 +3,8 @@ import threading
from queue import Queue
from dataclasses import dataclass, field
import sounddevice as sd
from transformers import HfArgumentParser
import argparse
import numpy as np
@dataclass
class ListenAndPlayArguments:
@@ -29,7 +29,6 @@ class ListenAndPlayArguments:
metadata={"help": "The network port for receiving data. Default is 12346."},
)
def listen_and_play(
send_rate=16000,
recv_rate=44100,
@@ -53,8 +52,14 @@ def listen_and_play(
def callback_recv(outdata, frames, time, status):
if not recv_queue.empty():
data = recv_queue.get()
outdata[: len(data)] = data
outdata[len(data) :] = b"\x00" * (len(outdata) - len(data))
# Convert bytes to numpy array
audio_array = np.frombuffer(data, dtype=np.int16)
# Reduce volume to 30%
audio_array = (audio_array * 0.3).astype(np.int16)
# Convert back to bytes
reduced_data = audio_array.tobytes()
outdata[: len(reduced_data)] = reduced_data
outdata[len(reduced_data) :] = b"\x00" * (len(outdata) - len(reduced_data))
else:
outdata[:] = b"\x00" * len(outdata)
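
The `callback_recv` change converts the received bytes to int16 samples, scales them to 30% amplitude (the "lower the volume" commit), and converts back before filling the output buffer. A standalone sketch of that operation (the helper name and the added clip are ours; the original code does not clip, which is safe here since the gain is below 1):

```python
import numpy as np

def attenuate_pcm16(data: bytes, gain: float = 0.3) -> bytes:
    """Scale little-endian int16 PCM by `gain` (0.3 is roughly -10.5 dB)."""
    samples = np.frombuffer(data, dtype=np.int16).astype(np.float32)
    samples = np.clip(samples * gain, -32768, 32767)  # guards against overflow for gains > 1
    return samples.astype(np.int16).tobytes()

# Example: full-scale samples come back at ~30% amplitude.
loud = np.array([32767, -32768, 16384, -16384], dtype=np.int16).tobytes()
print(np.frombuffer(attenuate_pcm16(loud), dtype=np.int16))  # [ 9830 -9830  4915 -4915]
```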
@@ -74,7 +79,7 @@ def listen_and_play(
while len(data) < chunk_size:
packet = conn.recv(chunk_size - len(data))
if not packet:
return None # Connection has been closed
return None
data += packet
return data
@@ -91,13 +96,16 @@ def listen_and_play(
blocksize=list_play_chunk_size,
callback=callback_send,
)
recv_stream = sd.RawOutputStream(
samplerate=recv_rate,
channels=1,
dtype="int16",
blocksize=list_play_chunk_size,
callback=callback_recv,
device=0,
)
threading.Thread(target=send_stream.start).start()
threading.Thread(target=recv_stream.start).start()
@@ -109,7 +117,9 @@ def listen_and_play(
input("Press Enter to stop...")
except KeyboardInterrupt:
print("Finished streaming.")
print("\nProgram interrupted by user. Exiting...")
except Exception as e:
print(f"An error occurred: {e}")
finally:
stop_event.set()
@@ -119,8 +129,14 @@ def listen_and_play(
recv_socket.close()
print("Connection closed.")
if __name__ == "__main__":
parser = HfArgumentParser((ListenAndPlayArguments,))
(listen_and_play_kwargs,) = parser.parse_args_into_dataclasses()
listen_and_play(**vars(listen_and_play_kwargs))
parser = argparse.ArgumentParser(description="Listen and Play Audio")
parser.add_argument("--send_rate", type=int, default=16000, help="In Hz. Default is 16000.")
parser.add_argument("--recv_rate", type=int, default=16000, help="In Hz. Default is 16000.")
parser.add_argument("--list_play_chunk_size", type=int, default=1024, help="The size of data chunks (in bytes). Default is 1024.")
parser.add_argument("--host", type=str, default="localhost", help="The hostname or IP address for listening and playing. Default is 'localhost'.")
parser.add_argument("--send_port", type=int, default=12345, help="The network port for sending data. Default is 12345.")
parser.add_argument("--recv_port", type=int, default=12346, help="The network port for receiving data. Default is 12346.")
args = parser.parse_args()
listen_and_play(**vars(args))


@@ -299,7 +299,6 @@ def main():
setup_args=(should_listen,),
setup_kwargs=vars(parler_tts_handler_kwargs),
)
elif module_kwargs.tts == "melo":
try:
from TTS.melo_handler import MeloTTSHandler