refactor arguments folder + run ruff
This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
|
||||
|
||||
|
||||
class Chat:
|
||||
"""
|
||||
Handles the chat using to avoid OOM issues.
|
||||
|
||||
@@ -4,6 +4,7 @@ from baseHandler import BaseHandler
|
||||
from mlx_lm import load, stream_generate, generate
|
||||
from rich.console import Console
|
||||
import torch
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
@@ -11,6 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class MLXLanguageModelHandler(BaseHandler):
|
||||
"""
|
||||
Handles the language model part.
|
||||
@@ -28,7 +30,7 @@ class MLXLanguageModelHandler(BaseHandler):
|
||||
init_chat_prompt="You are a helpful AI assistant.",
|
||||
):
|
||||
self.model_name = model_name
|
||||
model_id = 'microsoft/Phi-3-mini-4k-instruct'
|
||||
model_id = "microsoft/Phi-3-mini-4k-instruct"
|
||||
self.model, self.tokenizer = load(model_id)
|
||||
self.gen_kwargs = gen_kwargs
|
||||
|
||||
@@ -48,28 +50,40 @@ class MLXLanguageModelHandler(BaseHandler):
|
||||
|
||||
dummy_input_text = "Write me a poem about Machine Learning."
|
||||
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
|
||||
|
||||
|
||||
n_steps = 2
|
||||
|
||||
for _ in range(n_steps):
|
||||
prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
|
||||
generate(self.model, self.tokenizer, prompt=prompt, max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False)
|
||||
|
||||
generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=self.gen_kwargs["max_new_tokens"],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def process(self, prompt):
|
||||
logger.debug("infering language model...")
|
||||
|
||||
self.chat.append({"role": self.user_role, "content": prompt})
|
||||
prompt = self.tokenizer.apply_chat_template(self.chat.to_list(), tokenize=False, add_generation_prompt=True)
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
self.chat.to_list(), tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
output = ""
|
||||
curr_output = ""
|
||||
for t in stream_generate(self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]):
|
||||
for t in stream_generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
prompt,
|
||||
max_tokens=self.gen_kwargs["max_new_tokens"],
|
||||
):
|
||||
output += t
|
||||
curr_output += t
|
||||
if curr_output.endswith(('.', '?', '!', '<|end|>')):
|
||||
yield curr_output.replace('<|end|>', '')
|
||||
if curr_output.endswith((".", "?", "!", "<|end|>")):
|
||||
yield curr_output.replace("<|end|>", "")
|
||||
curr_output = ""
|
||||
generated_text = output.replace('<|end|>', '')
|
||||
generated_text = output.replace("<|end|>", "")
|
||||
torch.mps.empty_cache()
|
||||
|
||||
self.chat.append({"role": "assistant", "content": generated_text})
|
||||
|
||||
@@ -5,6 +5,7 @@ from lightning_whisper_mlx import LightningWhisperMLX
|
||||
import numpy as np
|
||||
from rich.console import Console
|
||||
import torch
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
@@ -26,12 +27,10 @@ class LightningWhisperSTTHandler(BaseHandler):
|
||||
compile_mode=None,
|
||||
gen_kwargs={},
|
||||
):
|
||||
if len(model_name.split('/')) > 1:
|
||||
model_name = model_name.split('/')[-1]
|
||||
if len(model_name.split("/")) > 1:
|
||||
model_name = model_name.split("/")[-1]
|
||||
self.device = device
|
||||
self.model = LightningWhisperMLX(
|
||||
model=model_name, batch_size=6, quant=None
|
||||
)
|
||||
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
|
||||
self.warmup()
|
||||
|
||||
def warmup(self):
|
||||
|
||||
@@ -40,10 +40,13 @@ class MeloTTSHandler(BaseHandler):
|
||||
console.print(f"[green]ASSISTANT: {llm_sentence}")
|
||||
if self.device == "mps":
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
|
||||
torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
|
||||
time_it_took = time.time()-start # Removing this line makes it fail more often. I'm looking into it.
|
||||
_ = (
|
||||
time.time() - start
|
||||
) # Removing this line makes it fail more often. I'm looking into it.
|
||||
|
||||
audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
|
||||
if len(audio_chunk) == 0:
|
||||
|
||||
65
arguments_classes/language_model_arguments.py
Normal file
65
arguments_classes/language_model_arguments.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class LanguageModelHandlerArguments:
    """Settings for the language model handler.

    Each field stores a user-facing description under the ``help`` key of its
    metadata. The descriptions quote the same default values as the
    ``default=`` arguments next to them.
    """

    lm_model_name: str = field(
        default="HuggingFaceTB/SmolLM-360M-Instruct",
        metadata={
            # The help text previously named a different model than the default.
            "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'."
        },
    )
    lm_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        },
    )
    lm_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        },
    )
    user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            # The help text previously quoted a prompt that is not the default.
            "help": "The initial chat prompt to establish context for the language model. Default asks for a helpful, friendly and polite assistant giving concise responses of less than 20 words."
        },
    )
    lm_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
        },
    )
    lm_gen_temperature: float = field(
        default=0.0,
        metadata={
            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
        },
    )
    lm_gen_do_sample: bool = field(
        default=False,
        metadata={
            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
        },
    )
    # NOTE(review): the help text allows None, yet the annotation is `int`;
    # confirm whether Optional[int] is intended before changing the type.
    chat_size: int = field(
        default=2,
        metadata={
            "help": "Number of interactions assistant-user to keep for the chat. None for no limitations. Default is 2."
        },
    )
|
||||
@@ -1,6 +1,4 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,4 +21,3 @@ class MeloTTSHandlerArguments:
|
||||
"help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
|
||||
},
|
||||
)
|
||||
|
||||
46
arguments_classes/module_arguments.py
Normal file
46
arguments_classes/module_arguments.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class ModuleArguments:
    """Pipeline-wide settings: device override, run mode, backend choice and log level.

    Each field stores a user-facing description under the ``help`` key of its
    metadata. The descriptions quote the same default values as the
    ``default=`` arguments next to them.
    """

    device: Optional[str] = field(
        default=None,
        metadata={"help": "If specified, overrides the device for all handlers."},
    )
    mode: Optional[str] = field(
        default="socket",
        metadata={
            # The help text previously claimed the default was 'local'.
            "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
        },
    )
    local_mac_optimal_settings: bool = field(
        default=False,
        metadata={
            "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
        },
    )
    stt: Optional[str] = field(
        default="whisper",
        metadata={
            "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
        },
    )
    llm: Optional[str] = field(
        default="transformers",
        metadata={
            "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
        },
    )
    tts: Optional[str] = field(
        default="parler",
        metadata={
            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
        },
    )
    log_level: str = field(
        default="info",
        metadata={
            # The help text previously claimed default=warning.
            "help": "Provide logging level. Example --log_level debug, default=info."
        },
    )
|
||||
62
arguments_classes/parler_tts_arguments.py
Normal file
62
arguments_classes/parler_tts_arguments.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class ParlerTTSHandlerArguments:
    """Settings for the Parler TTS handler.

    Each field stores a user-facing description under the ``help`` key of its
    metadata. The descriptions quote the same default values as the
    ``default=`` arguments next to them.
    """

    tts_model_name: str = field(
        default="ylacombe/parler-tts-mini-jenny-30H",
        metadata={
            "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
        },
    )
    tts_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        },
    )
    tts_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        },
    )
    # NOTE(review): annotated `str` but defaults to None; confirm whether
    # Optional[str] is intended before changing the type.
    tts_compile_mode: str = field(
        default=None,
        metadata={
            "help": "Compile mode for torch compile. One of 'default', 'reduce-overhead' or 'max-autotune'. Default is None (no compilation)"
        },
    )
    tts_gen_min_new_tokens: int = field(
        default=64,
        metadata={
            # Previously described as a "Maximum" with a stale default of 10.
            "help": "Minimum number of new tokens to generate in a single completion. Default is 64."
        },
    )
    tts_gen_max_new_tokens: int = field(
        default=512,
        metadata={
            # Previously quoted a stale default of 256.
            "help": "Maximum number of new tokens to generate in a single completion. Default is 512."
        },
    )
    description: str = field(
        default=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        metadata={
            "help": "Description of the speaker's voice and speaking style to guide the TTS model."
        },
    )
    play_steps_s: float = field(
        default=1.0,
        metadata={
            # Previously quoted a stale default of 0.5 seconds.
            "help": "The time interval in seconds for playing back the generated speech in steps. Default is 1.0 seconds."
        },
    )
    max_prompt_pad_length: int = field(
        default=8,
        metadata={
            "help": "When using compilation, the prompt has to be padded to the closest power of 2. This parameter sets the maximum power of 2 possible. Default is 8."
        },
    )
|
||||
24
arguments_classes/socket_receiver_arguments.py
Normal file
24
arguments_classes/socket_receiver_arguments.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class SocketReceiverArguments:
    """Host, port and chunk size settings for the receiving socket.

    Each field stores a user-facing description under the ``help`` key of its
    metadata. The descriptions quote the same default values as the
    ``default=`` arguments next to them.
    """

    recv_host: str = field(
        default="localhost",
        metadata={
            # Previously claimed the default was '0.0.0.0'.
            "help": "The host IP address for the socket connection. Default is 'localhost'. Use '0.0.0.0' to bind to all "
            "available interfaces on the host machine."
        },
    )
    recv_port: int = field(
        default=12345,
        metadata={
            # Previously quoted 12346 as the default port.
            "help": "The port number on which the socket server listens. Default is 12345."
        },
    )
    chunk_size: int = field(
        default=1024,
        metadata={
            "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
        },
    )
|
||||
18
arguments_classes/socket_sender_arguments.py
Normal file
18
arguments_classes/socket_sender_arguments.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class SocketSenderArguments:
    """Host and port settings for the sending socket.

    Each field stores a user-facing description under the ``help`` key of its
    metadata. The descriptions quote the same default values as the
    ``default=`` arguments next to them.
    """

    send_host: str = field(
        default="localhost",
        metadata={
            # Previously claimed the default was '0.0.0.0'.
            "help": "The host IP address for the socket connection. Default is 'localhost'. Use '0.0.0.0' to bind to all "
            "available interfaces on the host machine."
        },
    )
    send_port: int = field(
        default=12346,
        metadata={
            "help": "The port number on which the socket server listens. Default is 12346."
        },
    )
|
||||
41
arguments_classes/vad_arguments.py
Normal file
41
arguments_classes/vad_arguments.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
def _vad_option(default, help_text):
    """Build a dataclass field with ``default`` and a ``help`` metadata entry."""
    return field(default=default, metadata={"help": help_text})


@dataclass
class VADHandlerArguments:
    """Tunable settings for voice activity detection (VAD).

    Each field stores its user-facing description under the ``help`` key of
    its metadata.
    """

    # Detection confidence threshold.
    thresh: float = _vad_option(
        0.3,
        "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection.",
    )
    # Audio sample rate in Hz.
    sample_rate: int = _vad_option(
        16000,
        "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio.",
    )
    # Duration limits and padding, all expressed in milliseconds.
    min_silence_ms: int = _vad_option(
        250,
        "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms.",
    )
    min_speech_ms: int = _vad_option(
        500,
        "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms.",
    )
    max_speech_ms: float = _vad_option(
        float("inf"),
        "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments.",
    )
    speech_pad_ms: int = _vad_option(
        250,
        "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms.",
    )
|
||||
59
arguments_classes/whisper_stt_arguments.py
Normal file
59
arguments_classes/whisper_stt_arguments.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
def _stt_option(default, help_text):
    """Build a dataclass field with ``default`` and a ``help`` metadata entry."""
    return field(default=default, metadata={"help": help_text})


@dataclass
class WhisperSTTHandlerArguments:
    """Settings for the Whisper speech-to-text handler.

    Each field stores its user-facing description under the ``help`` key of
    its metadata.
    """

    # Model selection and placement.
    stt_model_name: str = _stt_option(
        "distil-whisper/distil-large-v3",
        "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'.",
    )
    stt_device: str = _stt_option(
        "cuda",
        "The device type on which the model will run. Default is 'cuda' for GPU acceleration.",
    )
    stt_torch_dtype: str = _stt_option(
        "float16",
        "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision).",
    )
    # NOTE(review): annotated `str` but defaults to None; confirm whether
    # Optional[str] is intended.
    stt_compile_mode: str = _stt_option(
        None,
        "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)",
    )
    # Generation options (the `stt_gen_` prefixed fields).
    stt_gen_max_new_tokens: int = _stt_option(
        128,
        "The maximum number of new tokens to generate. Default is 128.",
    )
    stt_gen_num_beams: int = _stt_option(
        1,
        "The number of beams for beam search. Default is 1, implying greedy decoding.",
    )
    stt_gen_return_timestamps: bool = _stt_option(
        False,
        "Whether to return timestamps with transcriptions. Default is False.",
    )
    stt_gen_task: str = _stt_option(
        "transcribe",
        "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'.",
    )
    stt_gen_language: str = _stt_option(
        "en",
        "The language of the speech to transcribe. Default is 'en' for English.",
    )
|
||||
362
s2s_pipeline.py
362
s2s_pipeline.py
@@ -4,7 +4,6 @@ import socket
|
||||
import sys
|
||||
import threading
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from threading import Event, Thread
|
||||
@@ -12,9 +11,16 @@ from time import perf_counter
|
||||
from typing import Optional
|
||||
from sys import platform
|
||||
from LLM.mlx_lm import MLXLanguageModelHandler
|
||||
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
|
||||
from arguments_classes.module_arguments import ModuleArguments
|
||||
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
|
||||
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
|
||||
from arguments_classes.socket_sender_arguments import SocketSenderArguments
|
||||
from arguments_classes.vad_arguments import VADHandlerArguments
|
||||
from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
|
||||
from baseHandler import BaseHandler
|
||||
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
|
||||
from handlers.melo_tts_handler import MeloTTSHandlerArguments
|
||||
from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
|
||||
import numpy as np
|
||||
import torch
|
||||
import nltk
|
||||
@@ -37,9 +43,9 @@ from utils import VADIterator, int2float, next_power_of_2
|
||||
|
||||
# Ensure that the necessary NLTK resources are available
|
||||
try:
|
||||
nltk.data.find('tokenizers/punkt_tab')
|
||||
nltk.data.find("tokenizers/punkt_tab")
|
||||
except (LookupError, OSError):
|
||||
nltk.download('punkt_tab')
|
||||
nltk.download("punkt_tab")
|
||||
|
||||
# caching allows ~50% compilation time reduction
|
||||
# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
|
||||
@@ -50,50 +56,6 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
|
||||
console = Console()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleArguments:
|
||||
device: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "If specified, overrides the device for all handlers."},
|
||||
)
|
||||
mode: Optional[str] = field(
|
||||
default="socket",
|
||||
metadata={
|
||||
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
|
||||
},
|
||||
)
|
||||
local_mac_optimal_settings: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
|
||||
},
|
||||
)
|
||||
stt: Optional[str] = field(
|
||||
default="whisper",
|
||||
metadata={
|
||||
"help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
|
||||
},
|
||||
)
|
||||
llm: Optional[str] = field(
|
||||
default="transformers",
|
||||
metadata={
|
||||
"help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
|
||||
},
|
||||
)
|
||||
tts: Optional[str] = field(
|
||||
default="parler",
|
||||
metadata={
|
||||
"help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
|
||||
},
|
||||
)
|
||||
log_level: str = field(
|
||||
default="info",
|
||||
metadata={
|
||||
"help": "Provide logging level. Example --log_level debug, default=warning."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ThreadManager:
|
||||
"""
|
||||
Manages multiple threads used to execute given handler tasks.
|
||||
@@ -116,29 +78,6 @@ class ThreadManager:
|
||||
thread.join()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SocketReceiverArguments:
|
||||
recv_host: str = field(
|
||||
default="localhost",
|
||||
metadata={
|
||||
"help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
|
||||
"available interfaces on the host machine."
|
||||
},
|
||||
)
|
||||
recv_port: int = field(
|
||||
default=12345,
|
||||
metadata={
|
||||
"help": "The port number on which the socket server listens. Default is 12346."
|
||||
},
|
||||
)
|
||||
chunk_size: int = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SocketReceiver:
|
||||
"""
|
||||
Handles reception of the audio packets from the client.
|
||||
@@ -192,23 +131,6 @@ class SocketReceiver:
|
||||
logger.info("Receiver closed")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SocketSenderArguments:
|
||||
send_host: str = field(
|
||||
default="localhost",
|
||||
metadata={
|
||||
"help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
|
||||
"available interfaces on the host machine."
|
||||
},
|
||||
)
|
||||
send_port: int = field(
|
||||
default=12346,
|
||||
metadata={
|
||||
"help": "The port number on which the socket server listens. Default is 12346."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SocketSender:
|
||||
"""
|
||||
Handles sending generated audio packets to the clients.
|
||||
@@ -238,46 +160,6 @@ class SocketSender:
|
||||
logger.info("Sender closed")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VADHandlerArguments:
|
||||
thresh: float = field(
|
||||
default=0.3,
|
||||
metadata={
|
||||
"help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
|
||||
},
|
||||
)
|
||||
sample_rate: int = field(
|
||||
default=16000,
|
||||
metadata={
|
||||
"help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
|
||||
},
|
||||
)
|
||||
min_silence_ms: int = field(
|
||||
default=250,
|
||||
metadata={
|
||||
"help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
|
||||
},
|
||||
)
|
||||
min_speech_ms: int = field(
|
||||
default=500,
|
||||
metadata={
|
||||
"help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
|
||||
},
|
||||
)
|
||||
max_speech_ms: float = field(
|
||||
default=float("inf"),
|
||||
metadata={
|
||||
"help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
|
||||
},
|
||||
)
|
||||
speech_pad_ms: int = field(
|
||||
default=250,
|
||||
metadata={
|
||||
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class VADHandler(BaseHandler):
|
||||
"""
|
||||
Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
|
||||
@@ -326,64 +208,6 @@ class VADHandler(BaseHandler):
|
||||
yield array
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhisperSTTHandlerArguments:
|
||||
stt_model_name: str = field(
|
||||
default="distil-whisper/distil-large-v3",
|
||||
metadata={
|
||||
"help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
|
||||
},
|
||||
)
|
||||
stt_device: str = field(
|
||||
default="cuda",
|
||||
metadata={
|
||||
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
||||
},
|
||||
)
|
||||
stt_torch_dtype: str = field(
|
||||
default="float16",
|
||||
metadata={
|
||||
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
||||
},
|
||||
)
|
||||
stt_compile_mode: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
|
||||
},
|
||||
)
|
||||
stt_gen_max_new_tokens: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum number of new tokens to generate. Default is 128."
|
||||
},
|
||||
)
|
||||
stt_gen_num_beams: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "The number of beams for beam search. Default is 1, implying greedy decoding."
|
||||
},
|
||||
)
|
||||
stt_gen_return_timestamps: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to return timestamps with transcriptions. Default is False."
|
||||
},
|
||||
)
|
||||
# stt_gen_task: str = field(
|
||||
# default="transcribe",
|
||||
# metadata={
|
||||
# "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
|
||||
# },
|
||||
# )
|
||||
# stt_gen_language: str = field(
|
||||
# default="en",
|
||||
# metadata={
|
||||
# "help": "The language of the speech to transcribe. Default is 'en' for English."
|
||||
# },
|
||||
# )
|
||||
|
||||
|
||||
class WhisperSTTHandler(BaseHandler):
|
||||
"""
|
||||
Handles the Speech To Text generation using a Whisper model.
|
||||
@@ -480,70 +304,6 @@ class WhisperSTTHandler(BaseHandler):
|
||||
yield pred_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageModelHandlerArguments:
|
||||
lm_model_name: str = field(
|
||||
default="HuggingFaceTB/SmolLM-360M-Instruct",
|
||||
metadata={
|
||||
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
|
||||
},
|
||||
)
|
||||
lm_device: str = field(
|
||||
default="cuda",
|
||||
metadata={
|
||||
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
||||
},
|
||||
)
|
||||
lm_torch_dtype: str = field(
|
||||
default="float16",
|
||||
metadata={
|
||||
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
||||
},
|
||||
)
|
||||
user_role: str = field(
|
||||
default="user",
|
||||
metadata={
|
||||
"help": "Role assigned to the user in the chat context. Default is 'user'."
|
||||
},
|
||||
)
|
||||
init_chat_role: str = field(
|
||||
default='system',
|
||||
metadata={
|
||||
"help": "Initial role for setting up the chat context. Default is 'system'."
|
||||
},
|
||||
)
|
||||
init_chat_prompt: str = field(
|
||||
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
|
||||
metadata={
|
||||
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
|
||||
},
|
||||
)
|
||||
lm_gen_max_new_tokens: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "Maximum number of new tokens to generate in a single completion. Default is 128."
|
||||
},
|
||||
)
|
||||
lm_gen_temperature: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
|
||||
},
|
||||
)
|
||||
lm_gen_do_sample: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
|
||||
},
|
||||
)
|
||||
chat_size: int = field(
|
||||
default=2,
|
||||
metadata={
|
||||
"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Chat:
|
||||
"""
|
||||
Handles the chat using to avoid OOM issues.
|
||||
@@ -684,67 +444,6 @@ class LanguageModelHandler(BaseHandler):
|
||||
yield printable_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParlerTTSHandlerArguments:
|
||||
tts_model_name: str = field(
|
||||
default="ylacombe/parler-tts-mini-jenny-30H",
|
||||
metadata={
|
||||
"help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
|
||||
},
|
||||
)
|
||||
tts_device: str = field(
|
||||
default="cuda",
|
||||
metadata={
|
||||
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
|
||||
},
|
||||
)
|
||||
tts_torch_dtype: str = field(
|
||||
default="float16",
|
||||
metadata={
|
||||
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
|
||||
},
|
||||
)
|
||||
tts_compile_mode: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
|
||||
},
|
||||
)
|
||||
tts_gen_min_new_tokens: int = field(
|
||||
default=64,
|
||||
metadata={
|
||||
"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
|
||||
},
|
||||
)
|
||||
tts_gen_max_new_tokens: int = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
|
||||
},
|
||||
)
|
||||
description: str = field(
|
||||
default=(
|
||||
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
|
||||
"She speaks very fast."
|
||||
),
|
||||
metadata={
|
||||
"help": "Description of the speaker's voice and speaking style to guide the TTS model."
|
||||
},
|
||||
)
|
||||
play_steps_s: float = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
|
||||
},
|
||||
)
|
||||
max_prompt_pad_length: int = field(
|
||||
default=8,
|
||||
metadata={
|
||||
"help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ParlerTTSHandler(BaseHandler):
|
||||
def setup(
|
||||
self,
|
||||
@@ -886,7 +585,7 @@ class ParlerTTSHandler(BaseHandler):
|
||||
thread.start()
|
||||
|
||||
for i, audio_chunk in enumerate(streamer):
|
||||
if i == 0 and 'pipeline_start' in globals():
|
||||
if i == 0 and "pipeline_start" in globals():
|
||||
logger.info(
|
||||
f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
|
||||
)
|
||||
@@ -971,7 +670,6 @@ def main():
|
||||
if module_kwargs.log_level == "debug":
|
||||
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
|
||||
|
||||
|
||||
def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
|
||||
if mac_optimal_settings:
|
||||
for kwargs in handler_kwargs:
|
||||
@@ -1070,14 +768,14 @@ def main():
|
||||
setup_args=(should_listen,),
|
||||
setup_kwargs=vars(vad_handler_kwargs),
|
||||
)
|
||||
if module_kwargs.stt == 'whisper':
|
||||
if module_kwargs.stt == "whisper":
|
||||
stt = WhisperSTTHandler(
|
||||
stop_event,
|
||||
queue_in=spoken_prompt_queue,
|
||||
queue_out=text_prompt_queue,
|
||||
setup_kwargs=vars(whisper_stt_handler_kwargs),
|
||||
)
|
||||
elif module_kwargs.stt == 'whisper-mlx':
|
||||
stop_event,
|
||||
queue_in=spoken_prompt_queue,
|
||||
queue_out=text_prompt_queue,
|
||||
setup_kwargs=vars(whisper_stt_handler_kwargs),
|
||||
)
|
||||
elif module_kwargs.stt == "whisper-mlx":
|
||||
stt = LightningWhisperSTTHandler(
|
||||
stop_event,
|
||||
queue_in=spoken_prompt_queue,
|
||||
@@ -1086,14 +784,14 @@ def main():
|
||||
)
|
||||
else:
|
||||
raise ValueError("The STT should be either whisper or whisper-mlx")
|
||||
if module_kwargs.llm == 'transformers':
|
||||
if module_kwargs.llm == "transformers":
|
||||
lm = LanguageModelHandler(
|
||||
stop_event,
|
||||
queue_in=text_prompt_queue,
|
||||
queue_out=lm_response_queue,
|
||||
setup_kwargs=vars(language_model_handler_kwargs),
|
||||
)
|
||||
elif module_kwargs.llm == 'mlx-lm':
|
||||
stop_event,
|
||||
queue_in=text_prompt_queue,
|
||||
queue_out=lm_response_queue,
|
||||
setup_kwargs=vars(language_model_handler_kwargs),
|
||||
)
|
||||
elif module_kwargs.llm == "mlx-lm":
|
||||
lm = MLXLanguageModelHandler(
|
||||
stop_event,
|
||||
queue_in=text_prompt_queue,
|
||||
@@ -1102,7 +800,7 @@ def main():
|
||||
)
|
||||
else:
|
||||
raise ValueError("The LLM should be either transformers or mlx-lm")
|
||||
if module_kwargs.tts == 'parler':
|
||||
if module_kwargs.tts == "parler":
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
|
||||
torch._dynamo.config.cache_size_limit = 15
|
||||
@@ -1113,12 +811,14 @@ def main():
|
||||
setup_args=(should_listen,),
|
||||
setup_kwargs=vars(parler_tts_handler_kwargs),
|
||||
)
|
||||
|
||||
elif module_kwargs.tts == 'melo':
|
||||
|
||||
elif module_kwargs.tts == "melo":
|
||||
try:
|
||||
from TTS.melotts import MeloTTSHandler
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Error importing MeloTTSHandler. You might need to run: python -m unidic download")
|
||||
logger.error(
|
||||
"Error importing MeloTTSHandler. You might need to run: python -m unidic download"
|
||||
)
|
||||
raise e
|
||||
tts = MeloTTSHandler(
|
||||
stop_event,
|
||||
|
||||
Reference in New Issue
Block a user