refactor arguments folder + run ruff

Andres Marafioti
2024-08-22 18:02:08 +02:00
parent 0c53fda7dd
commit 696bf85628
14 changed files with 378 additions and 353 deletions

View File

@@ -1,6 +1,3 @@
class Chat:
"""
Handles the chat history to avoid OOM issues.

View File

@@ -4,6 +4,7 @@ from baseHandler import BaseHandler
from mlx_lm import load, stream_generate, generate
from rich.console import Console
import torch
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
@@ -11,6 +12,7 @@ logger = logging.getLogger(__name__)
console = Console()
class MLXLanguageModelHandler(BaseHandler):
"""
Handles the language model part.
@@ -28,7 +30,7 @@ class MLXLanguageModelHandler(BaseHandler):
init_chat_prompt="You are a helpful AI assistant.",
):
self.model_name = model_name
model_id = 'microsoft/Phi-3-mini-4k-instruct'
model_id = "microsoft/Phi-3-mini-4k-instruct"
self.model, self.tokenizer = load(model_id)
self.gen_kwargs = gen_kwargs
@@ -48,28 +50,40 @@ class MLXLanguageModelHandler(BaseHandler):
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
n_steps = 2
for _ in range(n_steps):
prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
generate(self.model, self.tokenizer, prompt=prompt, max_tokens=self.gen_kwargs["max_new_tokens"], verbose=False)
generate(
self.model,
self.tokenizer,
prompt=prompt,
max_tokens=self.gen_kwargs["max_new_tokens"],
verbose=False,
)
def process(self, prompt):
logger.debug("inferring language model...")
self.chat.append({"role": self.user_role, "content": prompt})
prompt = self.tokenizer.apply_chat_template(self.chat.to_list(), tokenize=False, add_generation_prompt=True)
prompt = self.tokenizer.apply_chat_template(
self.chat.to_list(), tokenize=False, add_generation_prompt=True
)
output = ""
curr_output = ""
for t in stream_generate(self.model, self.tokenizer, prompt, max_tokens=self.gen_kwargs["max_new_tokens"]):
for t in stream_generate(
self.model,
self.tokenizer,
prompt,
max_tokens=self.gen_kwargs["max_new_tokens"],
):
output += t
curr_output += t
if curr_output.endswith(('.', '?', '!', '<|end|>')):
yield curr_output.replace('<|end|>', '')
if curr_output.endswith((".", "?", "!", "<|end|>")):
yield curr_output.replace("<|end|>", "")
curr_output = ""
generated_text = output.replace('<|end|>', '')
generated_text = output.replace("<|end|>", "")
torch.mps.empty_cache()
self.chat.append({"role": "assistant", "content": generated_text})
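
The sentence-boundary chunking in process() is what lets the TTS stage start speaking before the full completion has been generated. A self-contained sketch of the same pattern, with a plain list standing in for the stream_generate token stream:

def chunk_on_sentence_end(tokens, stop_token="<|end|>"):
    # Accumulate streamed text pieces and yield a chunk whenever a
    # sentence boundary (or the model's stop token) is reached.
    curr = ""
    for t in tokens:
        curr += t
        if curr.endswith((".", "?", "!", stop_token)):
            yield curr.replace(stop_token, "")
            curr = ""
    if curr:  # flush a trailing partial sentence, which process() never yields
        yield curr

print(list(chunk_on_sentence_end(["Hi", ".", " Bye", ".<|end|>"])))
# ['Hi.', ' Bye.']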

View File

@@ -5,6 +5,7 @@ from lightning_whisper_mlx import LightningWhisperMLX
import numpy as np
from rich.console import Console
import torch
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
@@ -26,12 +27,10 @@ class LightningWhisperSTTHandler(BaseHandler):
compile_mode=None,
gen_kwargs={},
):
if len(model_name.split('/')) > 1:
model_name = model_name.split('/')[-1]
if len(model_name.split("/")) > 1:
model_name = model_name.split("/")[-1]
self.device = device
self.model = LightningWhisperMLX(
model=model_name, batch_size=6, quant=None
)
self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
self.warmup()
def warmup(self):

View File

@@ -40,10 +40,13 @@ class MeloTTSHandler(BaseHandler):
console.print(f"[green]ASSISTANT: {llm_sentence}")
if self.device == "mps":
import time
start = time.time()
torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
time_it_took = time.time()-start # Removing this line makes it fail more often. I'm looking into it.
_ = (
time.time() - start
) # Removing this line makes it fail more often. I'm looking into it.
audio_chunk = self.model.tts_to_file(llm_sentence, self.speaker_id, quiet=True)
if len(audio_chunk) == 0:

View File

@@ -0,0 +1,65 @@
from dataclasses import dataclass, field
@dataclass
class LanguageModelHandlerArguments:
lm_model_name: str = field(
default="HuggingFaceTB/SmolLM-360M-Instruct",
metadata={
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
},
)
lm_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
lm_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
user_role: str = field(
default="user",
metadata={
"help": "Role assigned to the user in the chat context. Default is 'user'."
},
)
init_chat_role: str = field(
default="system",
metadata={
"help": "Initial role for setting up the chat context. Default is 'system'."
},
)
init_chat_prompt: str = field(
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
metadata={
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
},
)
lm_gen_max_new_tokens: int = field(
default=128,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 128."
},
)
lm_gen_temperature: float = field(
default=0.0,
metadata={
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
},
)
lm_gen_do_sample: bool = field(
default=False,
metadata={
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
},
)
chat_size: int = field(
default=2,
metadata={
"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
},
)
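
The model-level flags here carry the lm_ prefix so they cannot collide with stt_* or tts_* flags in the shared CLI namespace, which means the prefix has to come off again before the values reach a handler's setup(). A hypothetical helper sketching that renaming (illustrative; the pipeline's actual renaming logic may differ):

def strip_prefix(args_dict, prefix):
    # {"lm_device": "cuda", "lm_torch_dtype": "float16"} -> {"device": "cuda", ...}
    return {
        key[len(prefix):]: value
        for key, value in args_dict.items()
        if key.startswith(prefix)
    }

setup_kwargs = strip_prefix({"lm_device": "cuda", "lm_torch_dtype": "float16"}, "lm_")
print(setup_kwargs)  # {'device': 'cuda', 'torch_dtype': 'float16'}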

View File

@@ -1,6 +1,4 @@
from dataclasses import dataclass, field
from typing import List
@dataclass
@@ -23,4 +21,3 @@ class MeloTTSHandlerArguments:
"help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
},
)

View File

@@ -0,0 +1,46 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModuleArguments:
device: Optional[str] = field(
default=None,
metadata={"help": "If specified, overrides the device for all handlers."},
)
mode: Optional[str] = field(
default="socket",
metadata={
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
},
)
local_mac_optimal_settings: bool = field(
default=False,
metadata={
"help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
},
)
stt: Optional[str] = field(
default="whisper",
metadata={
"help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
},
)
llm: Optional[str] = field(
default="transformers",
metadata={
"help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
},
)
tts: Optional[str] = field(
default="parler",
metadata={
"help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
},
)
log_level: str = field(
default="info",
metadata={
"help": "Provide logging level. Example --log_level debug, default=warning."
},
)
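
These dataclasses follow the field(metadata={"help": ...}) convention of transformers' HfArgumentParser, which turns each field into a command-line flag. A minimal sketch of parsing ModuleArguments that way (assuming HfArgumentParser is the consumer, as the metadata style suggests):

from transformers import HfArgumentParser

parser = HfArgumentParser(ModuleArguments)
(module_kwargs,) = parser.parse_args_into_dataclasses(
    ["--mode", "local", "--stt", "whisper-mlx", "--llm", "mlx-lm", "--tts", "melo"]
)
print(module_kwargs.tts)  # melo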

View File

@@ -0,0 +1,62 @@
from dataclasses import dataclass, field
@dataclass
class ParlerTTSHandlerArguments:
tts_model_name: str = field(
default="ylacombe/parler-tts-mini-jenny-30H",
metadata={
"help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
},
)
tts_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
tts_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
tts_compile_mode: str = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
},
)
tts_gen_min_new_tokens: int = field(
default=64,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
},
)
tts_gen_max_new_tokens: int = field(
default=512,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
},
)
description: str = field(
default=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast."
),
metadata={
"help": "Description of the speaker's voice and speaking style to guide the TTS model."
},
)
play_steps_s: float = field(
default=1.0,
metadata={
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
},
)
max_prompt_pad_length: int = field(
default=8,
metadata={
"help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
},
)
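
max_prompt_pad_length interacts with tts_compile_mode: padding prompt lengths to powers of two keeps the set of input shapes, and therefore the number of torch.compile recompilations, small. A standalone sketch of the padding arithmetic (the main script imports next_power_of_2 from utils; this version is only illustrative):

import math

def next_power_of_2(x: int) -> int:
    # Smallest power of two >= x, e.g. 37 -> 64, 64 -> 64.
    return 1 if x == 0 else 2 ** math.ceil(math.log2(x))

# With max_prompt_pad_length=8, padded lengths fall in {1, 2, 4, ..., 2**8},
# so at most nine distinct prompt shapes ever reach the compiled model.
print(next_power_of_2(37))  # 64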

View File

@@ -0,0 +1,24 @@
from dataclasses import dataclass, field
@dataclass
class SocketReceiverArguments:
recv_host: str = field(
default="localhost",
metadata={
"help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
"available interfaces on the host machine."
},
)
recv_port: int = field(
default=12345,
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
},
)
chunk_size: int = field(
default=1024,
metadata={
"help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
},
)
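
A hypothetical sketch of the receive loop these three fields configure (illustrative only, not the repository's SocketReceiver class):

import socket

def receive_chunks(recv_host="localhost", recv_port=12345, chunk_size=1024):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.bind((recv_host, recv_port))
        server.listen()
        conn, _ = server.accept()  # wait for the audio client
        with conn:
            while True:
                data = conn.recv(chunk_size)  # at most chunk_size bytes per read
                if not data:  # client closed the connection
                    break
                yield data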

View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass, field
@dataclass
class SocketSenderArguments:
send_host: str = field(
default="localhost",
metadata={
"help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
"available interfaces on the host machine."
},
)
send_port: int = field(
default=12346,
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
},
)

View File

@@ -0,0 +1,41 @@
from dataclasses import dataclass, field
@dataclass
class VADHandlerArguments:
thresh: float = field(
default=0.3,
metadata={
"help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
},
)
sample_rate: int = field(
default=16000,
metadata={
"help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
},
)
min_silence_ms: int = field(
default=250,
metadata={
"help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
},
)
min_speech_ms: int = field(
default=500,
metadata={
"help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
},
)
max_speech_ms: float = field(
default=float("inf"),
metadata={
"help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
},
)
speech_pad_ms: int = field(
default=250,
metadata={
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
},
)
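
All the millisecond fields are eventually compared against audio measured in samples, so sample_rate is the conversion factor. Worked with the defaults (illustrative arithmetic, not handler code):

sample_rate = 16000                        # samples per second
min_silence_samples = 16000 * 250 // 1000  # min_silence_ms -> 4000 samples
speech_pad_samples = 16000 * 250 // 1000   # speech_pad_ms  -> 4000 samples
min_speech_samples = 16000 * 500 // 1000   # min_speech_ms  -> 8000 samples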

View File

@@ -0,0 +1,59 @@
from dataclasses import dataclass, field
@dataclass
class WhisperSTTHandlerArguments:
stt_model_name: str = field(
default="distil-whisper/distil-large-v3",
metadata={
"help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
},
)
stt_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
stt_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
stt_compile_mode: str = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
},
)
stt_gen_max_new_tokens: int = field(
default=128,
metadata={
"help": "The maximum number of new tokens to generate. Default is 128."
},
)
stt_gen_num_beams: int = field(
default=1,
metadata={
"help": "The number of beams for beam search. Default is 1, implying greedy decoding."
},
)
stt_gen_return_timestamps: bool = field(
default=False,
metadata={
"help": "Whether to return timestamps with transcriptions. Default is False."
},
)
stt_gen_task: str = field(
default="transcribe",
metadata={
"help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
},
)
stt_gen_language: str = field(
default="en",
metadata={
"help": "The language of the speech to transcribe. Default is 'en' for English."
},
)
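
Once the stt_gen_ prefix is stripped, these fields become the keyword arguments of a standard transformers Whisper generate call. A hedged sketch of the call they configure (model, processor and input_features are assumed to already exist):

gen_kwargs = {
    "max_new_tokens": 128,       # stt_gen_max_new_tokens
    "num_beams": 1,              # stt_gen_num_beams: 1 means greedy decoding
    "return_timestamps": False,  # stt_gen_return_timestamps
    "task": "transcribe",        # stt_gen_task
    "language": "en",            # stt_gen_language
}
# pred_ids = model.generate(input_features, **gen_kwargs)
# pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]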

View File

@@ -4,7 +4,6 @@ import socket
import sys
import threading
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue
from threading import Event, Thread
@@ -12,9 +11,16 @@ from time import perf_counter
from typing import Optional
from sys import platform
from LLM.mlx_lm import MLXLanguageModelHandler
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
from arguments_classes.socket_sender_arguments import SocketSenderArguments
from arguments_classes.vad_arguments import VADHandlerArguments
from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
from baseHandler import BaseHandler
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
from handlers.melo_tts_handler import MeloTTSHandlerArguments
from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
import numpy as np
import torch
import nltk
@@ -37,9 +43,9 @@ from utils import VADIterator, int2float, next_power_of_2
# Ensure that the necessary NLTK resources are available
try:
nltk.data.find('tokenizers/punkt_tab')
nltk.data.find("tokenizers/punkt_tab")
except (LookupError, OSError):
nltk.download('punkt_tab')
nltk.download("punkt_tab")
# caching allows ~50% compilation time reduction
# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
@@ -50,50 +56,6 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
console = Console()
@dataclass
class ModuleArguments:
device: Optional[str] = field(
default=None,
metadata={"help": "If specified, overrides the device for all handlers."},
)
mode: Optional[str] = field(
default="socket",
metadata={
"help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'local'."
},
)
local_mac_optimal_settings: bool = field(
default=False,
metadata={
"help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
},
)
stt: Optional[str] = field(
default="whisper",
metadata={
"help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
},
)
llm: Optional[str] = field(
default="transformers",
metadata={
"help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
},
)
tts: Optional[str] = field(
default="parler",
metadata={
"help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
},
)
log_level: str = field(
default="info",
metadata={
"help": "Provide logging level. Example --log_level debug, default=warning."
},
)
class ThreadManager:
"""
Manages multiple threads used to execute given handler tasks.
@@ -116,29 +78,6 @@ class ThreadManager:
thread.join()
@dataclass
class SocketReceiverArguments:
recv_host: str = field(
default="localhost",
metadata={
"help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
"available interfaces on the host machine."
},
)
recv_port: int = field(
default=12345,
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
},
)
chunk_size: int = field(
default=1024,
metadata={
"help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
},
)
class SocketReceiver:
"""
Handles reception of the audio packets from the client.
@@ -192,23 +131,6 @@ class SocketReceiver:
logger.info("Receiver closed")
@dataclass
class SocketSenderArguments:
send_host: str = field(
default="localhost",
metadata={
"help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
"available interfaces on the host machine."
},
)
send_port: int = field(
default=12346,
metadata={
"help": "The port number on which the socket server listens. Default is 12346."
},
)
class SocketSender:
"""
Handles sending generated audio packets to the clients.
@@ -238,46 +160,6 @@ class SocketSender:
logger.info("Sender closed")
@dataclass
class VADHandlerArguments:
thresh: float = field(
default=0.3,
metadata={
"help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
},
)
sample_rate: int = field(
default=16000,
metadata={
"help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
},
)
min_silence_ms: int = field(
default=250,
metadata={
"help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
},
)
min_speech_ms: int = field(
default=500,
metadata={
"help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
},
)
max_speech_ms: float = field(
default=float("inf"),
metadata={
"help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
},
)
speech_pad_ms: int = field(
default=250,
metadata={
"help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
},
)
class VADHandler(BaseHandler):
"""
Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
@@ -326,64 +208,6 @@ class VADHandler(BaseHandler):
yield array
@dataclass
class WhisperSTTHandlerArguments:
stt_model_name: str = field(
default="distil-whisper/distil-large-v3",
metadata={
"help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
},
)
stt_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
stt_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
stt_compile_mode: str = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
},
)
stt_gen_max_new_tokens: int = field(
default=128,
metadata={
"help": "The maximum number of new tokens to generate. Default is 128."
},
)
stt_gen_num_beams: int = field(
default=1,
metadata={
"help": "The number of beams for beam search. Default is 1, implying greedy decoding."
},
)
stt_gen_return_timestamps: bool = field(
default=False,
metadata={
"help": "Whether to return timestamps with transcriptions. Default is False."
},
)
# stt_gen_task: str = field(
# default="transcribe",
# metadata={
# "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
# },
# )
# stt_gen_language: str = field(
# default="en",
# metadata={
# "help": "The language of the speech to transcribe. Default is 'en' for English."
# },
# )
class WhisperSTTHandler(BaseHandler):
"""
Handles the Speech To Text generation using a Whisper model.
@@ -480,70 +304,6 @@ class WhisperSTTHandler(BaseHandler):
yield pred_text
@dataclass
class LanguageModelHandlerArguments:
lm_model_name: str = field(
default="HuggingFaceTB/SmolLM-360M-Instruct",
metadata={
"help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
},
)
lm_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
lm_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
user_role: str = field(
default="user",
metadata={
"help": "Role assigned to the user in the chat context. Default is 'user'."
},
)
init_chat_role: str = field(
default='system',
metadata={
"help": "Initial role for setting up the chat context. Default is 'system'."
},
)
init_chat_prompt: str = field(
default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
metadata={
"help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
},
)
lm_gen_max_new_tokens: int = field(
default=128,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 128."
},
)
lm_gen_temperature: float = field(
default=0.0,
metadata={
"help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
},
)
lm_gen_do_sample: bool = field(
default=False,
metadata={
"help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
},
)
chat_size: int = field(
default=2,
metadata={
"help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
},
)
class Chat:
"""
Handles the chat history to avoid OOM issues.
@@ -684,67 +444,6 @@ class LanguageModelHandler(BaseHandler):
yield printable_text
@dataclass
class ParlerTTSHandlerArguments:
tts_model_name: str = field(
default="ylacombe/parler-tts-mini-jenny-30H",
metadata={
"help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
},
)
tts_device: str = field(
default="cuda",
metadata={
"help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
},
)
tts_torch_dtype: str = field(
default="float16",
metadata={
"help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
},
)
tts_compile_mode: str = field(
default=None,
metadata={
"help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
},
)
tts_gen_min_new_tokens: int = field(
default=64,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
},
)
tts_gen_max_new_tokens: int = field(
default=512,
metadata={
"help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
},
)
description: str = field(
default=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast."
),
metadata={
"help": "Description of the speaker's voice and speaking style to guide the TTS model."
},
)
play_steps_s: float = field(
default=1.0,
metadata={
"help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
},
)
max_prompt_pad_length: int = field(
default=8,
metadata={
"help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
},
)
class ParlerTTSHandler(BaseHandler):
def setup(
self,
@@ -886,7 +585,7 @@ class ParlerTTSHandler(BaseHandler):
thread.start()
for i, audio_chunk in enumerate(streamer):
if i == 0 and 'pipeline_start' in globals():
if i == 0 and "pipeline_start" in globals():
logger.info(
f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
)
@@ -971,7 +670,6 @@ def main():
if module_kwargs.log_level == "debug":
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
def optimal_mac_settings(mac_optimal_settings: Optional[str], *handler_kwargs):
if mac_optimal_settings:
for kwargs in handler_kwargs:
@@ -1070,14 +768,14 @@ def main():
setup_args=(should_listen,),
setup_kwargs=vars(vad_handler_kwargs),
)
if module_kwargs.stt == 'whisper':
if module_kwargs.stt == "whisper":
stt = WhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
queue_out=text_prompt_queue,
setup_kwargs=vars(whisper_stt_handler_kwargs),
)
elif module_kwargs.stt == 'whisper-mlx':
elif module_kwargs.stt == "whisper-mlx":
stt = LightningWhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
@@ -1086,14 +784,14 @@ def main():
)
else:
raise ValueError("The STT should be either whisper or whisper-mlx")
if module_kwargs.llm == 'transformers':
if module_kwargs.llm == "transformers":
lm = LanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
queue_out=lm_response_queue,
setup_kwargs=vars(language_model_handler_kwargs),
)
elif module_kwargs.llm == 'mlx-lm':
elif module_kwargs.llm == "mlx-lm":
lm = MLXLanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
@@ -1102,7 +800,7 @@ def main():
)
else:
raise ValueError("The LLM should be either transformers or mlx-lm")
if module_kwargs.tts == 'parler':
if module_kwargs.tts == "parler":
torch._inductor.config.fx_graph_cache = True
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
torch._dynamo.config.cache_size_limit = 15
@@ -1113,12 +811,14 @@ def main():
setup_args=(should_listen,),
setup_kwargs=vars(parler_tts_handler_kwargs),
)
elif module_kwargs.tts == 'melo':
elif module_kwargs.tts == "melo":
try:
from TTS.melotts import MeloTTSHandler
except RuntimeError as e:
logger.error(f"Error importing MeloTTSHandler. You might need to run: python -m unidic download")
logger.error(
"Error importing MeloTTSHandler. You might need to run: python -m unidic download"
)
raise e
tts = MeloTTSHandler(
stop_event,

View File

@@ -84,7 +84,7 @@ class VADIterator:
if not torch.is_tensor(x):
try:
x = torch.Tensor(x)
except:
except Exception:
raise TypeError("Audio cannot be cast to tensor. Cast it manually")
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)