refactor all the handlers - folder structure

This commit is contained in:
Andres Marafioti
2024-08-23 16:57:38 +02:00
parent f72806da5a
commit d50687a0c3
13 changed files with 660 additions and 589 deletions

134
LLM/language_model.py Normal file
View File

@@ -0,0 +1,134 @@
from threading import Thread
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
TextIteratorStreamer,
)
import torch
from LLM.chat import Chat
from baseHandler import BaseHandler
from rich.console import Console
import logging
from nltk import sent_tokenize
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
"""
def setup(
self,
model_name="microsoft/Phi-3-mini-4k-instruct",
device="cuda",
torch_dtype="float16",
gen_kwargs={},
user_role="user",
chat_size=1,
init_chat_role=None,
init_chat_prompt="You are a helpful AI assistant.",
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
self.pipe = pipeline(
"text-generation", model=self.model, tokenizer=self.tokenizer, device=device
)
self.streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
self.gen_kwargs = {
"streamer": self.streamer,
"return_full_text": False,
**gen_kwargs,
}
self.chat = Chat(chat_size)
if init_chat_role:
if not init_chat_prompt:
raise ValueError(
"An initial promt needs to be specified when setting init_chat_role."
)
self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
self.user_role = user_role
self.warmup()
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
n_steps = 2
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
thread = Thread(
target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
)
thread.start()
for _ in self.streamer:
pass
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, prompt):
logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
)
thread.start()
if self.device == "mps":
generated_text = ""
for new_text in self.streamer:
generated_text += new_text
printable_text = generated_text
torch.mps.empty_cache()
else:
generated_text, printable_text = "", ""
for new_text in self.streamer:
generated_text += new_text
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0])
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text

View File

@@ -66,13 +66,15 @@ class MLXLanguageModelHandler(BaseHandler):
logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt})
# Remove system messages if using a Gemma model
if "gemma" in self.model_name.lower():
chat_messages = [msg for msg in self.chat.to_list() if msg["role"] != "system"]
chat_messages = [
msg for msg in self.chat.to_list() if msg["role"] != "system"
]
else:
chat_messages = self.chat.to_list()
prompt = self.tokenizer.apply_chat_template(
chat_messages, tokenize=False, add_generation_prompt=True
)

113
STT/whisper_stt_handler.py Normal file
View File

@@ -0,0 +1,113 @@
from time import perf_counter
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
)
import torch
from baseHandler import BaseHandler
from rich.console import Console
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class WhisperSTTHandler(BaseHandler):
"""
Handles the Speech To Text generation using a Whisper model.
"""
def setup(
self,
model_name="distil-whisper/distil-large-v3",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={},
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
).to(device)
# compile
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode=self.compile_mode, fullgraph=True
)
self.warmup()
def prepare_model_inputs(self, spoken_prompt):
input_features = self.processor(
spoken_prompt, sampling_rate=16000, return_tensors="pt"
).input_features
input_features = input_features.to(self.device, dtype=self.torch_dtype)
return input_features
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
n_steps = 1 if self.compile_mode == "default" else 2
dummy_input = torch.randn(
(1, self.model.config.num_mel_bins, 3000),
dtype=self.torch_dtype,
device=self.device,
)
if self.compile_mode not in (None, "default"):
# generating more tokens than previously will trigger CUDA graphs capture
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
else:
warmup_gen_kwargs = self.gen_kwargs
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
_ = self.model.generate(dummy_input, **warmup_gen_kwargs)
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, spoken_prompt):
logger.debug("infering whisper...")
global pipeline_start
pipeline_start = perf_counter()
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
yield pred_text

View File

@@ -24,7 +24,6 @@ class MeloTTSHandler(BaseHandler):
gen_kwargs={}, # Unused
blocksize=512,
):
print(device)
self.should_listen = should_listen
self.device = device
self.model = TTS(language=language, device=device)

181
TTS/parler_handler.py Normal file
View File

@@ -0,0 +1,181 @@
from threading import Thread
from time import perf_counter
from baseHandler import BaseHandler
import numpy as np
import torch
from transformers import (
AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
import librosa
import logging
from rich.console import Console
from utils.utils import next_power_of_2
torch._inductor.config.fx_graph_cache = True
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
torch._dynamo.config.cache_size_limit = 15
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class ParlerTTSHandler(BaseHandler):
def setup(
self,
should_listen,
model_name="ylacombe/parler-tts-mini-jenny-30H",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={},
max_prompt_pad_length=8,
description=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast."
),
play_steps_s=1,
blocksize=512,
):
self.should_listen = should_listen
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.gen_kwargs = gen_kwargs
self.compile_mode = compile_mode
self.max_prompt_pad_length = max_prompt_pad_length
self.description = description
self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = ParlerTTSForConditionalGeneration.from_pretrained(
model_name, torch_dtype=self.torch_dtype
).to(device)
framerate = self.model.audio_encoder.config.frame_rate
self.play_steps = int(framerate * play_steps_s)
self.blocksize = blocksize
if self.compile_mode not in (None, "default"):
logger.warning(
"Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
)
self.compile_mode = "default"
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode=self.compile_mode, fullgraph=True
)
self.warmup()
def prepare_model_inputs(
self,
prompt,
max_length_prompt=50,
pad=False,
):
pad_args_prompt = (
{"padding": "max_length", "max_length": max_length_prompt} if pad else {}
)
tokenized_description = self.description_tokenizer(
self.description, return_tensors="pt"
)
input_ids = tokenized_description.input_ids.to(self.device)
attention_mask = tokenized_description.attention_mask.to(self.device)
tokenized_prompt = self.prompt_tokenizer(
prompt, return_tensors="pt", **pad_args_prompt
)
prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
gen_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"prompt_input_ids": prompt_input_ids,
"prompt_attention_mask": prompt_attention_mask,
**self.gen_kwargs,
}
return gen_kwargs
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
n_steps = 1 if self.compile_mode == "default" else 2
if self.device == "cuda":
torch.cuda.synchronize()
start_event.record()
if self.compile_mode:
pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
for pad_length in pad_lengths[::-1]:
model_kwargs = self.prepare_model_inputs(
"dummy prompt", max_length_prompt=pad_length, pad=True
)
for _ in range(n_steps):
_ = self.model.generate(**model_kwargs)
logger.info(f"Warmed up length {pad_length} tokens!")
else:
model_kwargs = self.prepare_model_inputs("dummy prompt")
for _ in range(n_steps):
_ = self.model.generate(**model_kwargs)
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
pad_args = {}
if self.compile_mode:
# pad to closest upper power of two
pad_length = next_power_of_2(nb_tokens)
logger.debug(f"padding to {pad_length}")
pad_args["pad"] = True
pad_args["max_length_prompt"] = pad_length
tts_gen_kwargs = self.prepare_model_inputs(
llm_sentence,
**pad_args,
)
streamer = ParlerTTSStreamer(
self.model, device=self.device, play_steps=self.play_steps
)
tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
torch.manual_seed(0)
thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
thread.start()
for i, audio_chunk in enumerate(streamer):
global pipeline_start
if i == 0 and "pipeline_start" in globals():
logger.info(
f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
)
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
for i in range(0, len(audio_chunk), self.blocksize):
yield np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
)
self.should_listen.set()

64
VAD/vad_handler.py Normal file
View File

@@ -0,0 +1,64 @@
from VAD.vad_iterator import VADIterator
from baseHandler import BaseHandler
import numpy as np
import torch
from rich.console import Console
from utils.utils import int2float
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class VADHandler(BaseHandler):
"""
Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
to the following part.
"""
def setup(
self,
should_listen,
thresh=0.3,
sample_rate=16000,
min_silence_ms=1000,
min_speech_ms=500,
max_speech_ms=float("inf"),
speech_pad_ms=30,
):
self.should_listen = should_listen
self.sample_rate = sample_rate
self.min_silence_ms = min_silence_ms
self.min_speech_ms = min_speech_ms
self.max_speech_ms = max_speech_ms
self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
self.iterator = VADIterator(
self.model,
threshold=thresh,
sampling_rate=sample_rate,
min_silence_duration_ms=min_silence_ms,
speech_pad_ms=speech_pad_ms,
)
def process(self, audio_chunk):
audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
audio_float32 = int2float(audio_int16)
vad_output = self.iterator(torch.from_numpy(audio_float32))
if vad_output is not None and len(vad_output) != 0:
logger.debug("VAD: end of speech detected")
array = torch.cat(vad_output).cpu().numpy()
duration_ms = len(array) / self.sample_rate * 1000
if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
logger.debug(
f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
)
else:
self.should_listen.clear()
logger.debug("Stop listening")
yield array

View File

@@ -1,24 +1,6 @@
import numpy as np
import torch
def next_power_of_2(x):
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def int2float(sound):
"""
Taken from https://github.com/snakers4/silero-vad
"""
abs_max = np.abs(sound).max()
sound = sound.astype("float32")
if abs_max > 0:
sound *= 1 / 32768
sound = sound.squeeze() # depends on the use case
return sound
class VADIterator:
def __init__(
self,

View File

@@ -0,0 +1,63 @@
import socket
from rich.console import Console
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class SocketReceiver:
"""
Handles reception of the audio packets from the client.
"""
def __init__(
self,
stop_event,
queue_out,
should_listen,
host="0.0.0.0",
port=12345,
chunk_size=1024,
):
self.stop_event = stop_event
self.queue_out = queue_out
self.should_listen = should_listen
self.chunk_size = chunk_size
self.host = host
self.port = port
def receive_full_chunk(self, conn, chunk_size):
data = b""
while len(data) < chunk_size:
packet = conn.recv(chunk_size - len(data))
if not packet:
# connection closed
return None
data += packet
return data
def run(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((self.host, self.port))
self.socket.listen(1)
logger.info("Receiver waiting to be connected...")
self.conn, _ = self.socket.accept()
logger.info("receiver connected")
self.should_listen.set()
while not self.stop_event.is_set():
audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
if audio_chunk is None:
# connection closed
self.queue_out.put(b"END")
break
if self.should_listen.is_set():
self.queue_out.put(audio_chunk)
self.conn.close()
logger.info("Receiver closed")

View File

@@ -0,0 +1,39 @@
import socket
from rich.console import Console
import logging
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
console = Console()
class SocketSender:
"""
Handles sending generated audio packets to the clients.
"""
def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
self.stop_event = stop_event
self.queue_in = queue_in
self.host = host
self.port = port
def run(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((self.host, self.port))
self.socket.listen(1)
logger.info("Sender waiting to be connected...")
self.conn, _ = self.socket.accept()
logger.info("sender connected")
while not self.stop_event.is_set():
audio_chunk = self.queue_in.get()
self.conn.sendall(audio_chunk)
if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
break
self.conn.close()
logger.info("Sender closed")

View File

@@ -1,43 +1,32 @@
import logging
import os
import socket
import sys
import threading
from copy import copy
from pathlib import Path
from queue import Queue
from threading import Event, Thread
from time import perf_counter
from threading import Event
from typing import Optional
from sys import platform
from VAD.vad_handler import VADHandler
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import MLXLanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import (
MLXLanguageModelHandlerArguments,
)
from arguments_classes.module_arguments import ModuleArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
from arguments_classes.socket_sender_arguments import SocketSenderArguments
from arguments_classes.vad_arguments import VADHandlerArguments
from arguments_classes.whisper_stt_arguments import WhisperSTTHandlerArguments
from baseHandler import BaseHandler
from arguments_classes.melo_tts_arguments import MeloTTSHandlerArguments
import numpy as np
import torch
import nltk
from nltk.tokenize import sent_tokenize
from rich.console import Console
from transformers import (
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoTokenizer,
HfArgumentParser,
pipeline,
TextIteratorStreamer,
)
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
import librosa
from utils import VADIterator, int2float, next_power_of_2
from utils.thread_manager import ThreadManager
# Ensure that the necessary NLTK resources are available
try:
@@ -58,550 +47,6 @@ os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
console = Console()
class ThreadManager:
"""
Manages multiple threads used to execute given handler tasks.
"""
def __init__(self, handlers):
self.handlers = handlers
self.threads = []
def start(self):
for handler in self.handlers:
thread = threading.Thread(target=handler.run)
self.threads.append(thread)
thread.start()
def stop(self):
for handler in self.handlers:
handler.stop_event.set()
for thread in self.threads:
thread.join()
class SocketReceiver:
"""
Handles reception of the audio packets from the client.
"""
def __init__(
self,
stop_event,
queue_out,
should_listen,
host="0.0.0.0",
port=12345,
chunk_size=1024,
):
self.stop_event = stop_event
self.queue_out = queue_out
self.should_listen = should_listen
self.chunk_size = chunk_size
self.host = host
self.port = port
def receive_full_chunk(self, conn, chunk_size):
data = b""
while len(data) < chunk_size:
packet = conn.recv(chunk_size - len(data))
if not packet:
# connection closed
return None
data += packet
return data
def run(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((self.host, self.port))
self.socket.listen(1)
logger.info("Receiver waiting to be connected...")
self.conn, _ = self.socket.accept()
logger.info("receiver connected")
self.should_listen.set()
while not self.stop_event.is_set():
audio_chunk = self.receive_full_chunk(self.conn, self.chunk_size)
if audio_chunk is None:
# connection closed
self.queue_out.put(b"END")
break
if self.should_listen.is_set():
self.queue_out.put(audio_chunk)
self.conn.close()
logger.info("Receiver closed")
class SocketSender:
"""
Handles sending generated audio packets to the clients.
"""
def __init__(self, stop_event, queue_in, host="0.0.0.0", port=12346):
self.stop_event = stop_event
self.queue_in = queue_in
self.host = host
self.port = port
def run(self):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind((self.host, self.port))
self.socket.listen(1)
logger.info("Sender waiting to be connected...")
self.conn, _ = self.socket.accept()
logger.info("sender connected")
while not self.stop_event.is_set():
audio_chunk = self.queue_in.get()
self.conn.sendall(audio_chunk)
if isinstance(audio_chunk, bytes) and audio_chunk == b"END":
break
self.conn.close()
logger.info("Sender closed")
class VADHandler(BaseHandler):
"""
Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
to the following part.
"""
def setup(
self,
should_listen,
thresh=0.3,
sample_rate=16000,
min_silence_ms=1000,
min_speech_ms=500,
max_speech_ms=float("inf"),
speech_pad_ms=30,
):
self.should_listen = should_listen
self.sample_rate = sample_rate
self.min_silence_ms = min_silence_ms
self.min_speech_ms = min_speech_ms
self.max_speech_ms = max_speech_ms
self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
self.iterator = VADIterator(
self.model,
threshold=thresh,
sampling_rate=sample_rate,
min_silence_duration_ms=min_silence_ms,
speech_pad_ms=speech_pad_ms,
)
def process(self, audio_chunk):
audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
audio_float32 = int2float(audio_int16)
vad_output = self.iterator(torch.from_numpy(audio_float32))
if vad_output is not None and len(vad_output) != 0:
logger.debug("VAD: end of speech detected")
array = torch.cat(vad_output).cpu().numpy()
duration_ms = len(array) / self.sample_rate * 1000
if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
logger.debug(
f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
)
else:
self.should_listen.clear()
logger.debug("Stop listening")
yield array
class WhisperSTTHandler(BaseHandler):
"""
Handles the Speech To Text generation using a Whisper model.
"""
def setup(
self,
model_name="distil-whisper/distil-large-v3",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={},
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.compile_mode = compile_mode
self.gen_kwargs = gen_kwargs
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
).to(device)
# compile
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode=self.compile_mode, fullgraph=True
)
self.warmup()
def prepare_model_inputs(self, spoken_prompt):
input_features = self.processor(
spoken_prompt, sampling_rate=16000, return_tensors="pt"
).input_features
input_features = input_features.to(self.device, dtype=self.torch_dtype)
return input_features
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
n_steps = 1 if self.compile_mode == "default" else 2
dummy_input = torch.randn(
(1, self.model.config.num_mel_bins, 3000),
dtype=self.torch_dtype,
device=self.device,
)
if self.compile_mode not in (None, "default"):
# generating more tokens than previously will trigger CUDA graphs capture
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
else:
warmup_gen_kwargs = self.gen_kwargs
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
_ = self.model.generate(dummy_input, **warmup_gen_kwargs)
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, spoken_prompt):
logger.debug("infering whisper...")
global pipeline_start
pipeline_start = perf_counter()
input_features = self.prepare_model_inputs(spoken_prompt)
pred_ids = self.model.generate(input_features, **self.gen_kwargs)
pred_text = self.processor.batch_decode(
pred_ids, skip_special_tokens=True, decode_with_timestamps=False
)[0]
logger.debug("finished whisper inference")
console.print(f"[yellow]USER: {pred_text}")
yield pred_text
class Chat:
"""
Handles the chat using to avoid OOM issues.
"""
def __init__(self, size):
self.size = size
self.init_chat_message = None
# maxlen is necessary pair, since a each new step we add an prompt and assitant answer
self.buffer = []
def append(self, item):
self.buffer.append(item)
if len(self.buffer) == 2 * (self.size + 1):
self.buffer.pop(0)
self.buffer.pop(0)
def init_chat(self, init_chat_message):
self.init_chat_message = init_chat_message
def to_list(self):
if self.init_chat_message:
return [self.init_chat_message] + self.buffer
else:
return self.buffer
class LanguageModelHandler(BaseHandler):
"""
Handles the language model part.
"""
def setup(
self,
model_name="microsoft/Phi-3-mini-4k-instruct",
device="cuda",
torch_dtype="float16",
gen_kwargs={},
user_role="user",
chat_size=1,
init_chat_role=None,
init_chat_prompt="You are a helpful AI assistant.",
):
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
self.pipe = pipeline(
"text-generation", model=self.model, tokenizer=self.tokenizer, device=device
)
self.streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
self.gen_kwargs = {
"streamer": self.streamer,
"return_full_text": False,
**gen_kwargs,
}
self.chat = Chat(chat_size)
if init_chat_role:
if not init_chat_prompt:
raise ValueError(
"An initial promt needs to be specified when setting init_chat_role."
)
self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
self.user_role = user_role
self.warmup()
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
dummy_input_text = "Write me a poem about Machine Learning."
dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"],
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}
n_steps = 2
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(n_steps):
thread = Thread(
target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
)
thread.start()
for _ in self.streamer:
pass
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, prompt):
logger.debug("infering language model...")
self.chat.append({"role": self.user_role, "content": prompt})
thread = Thread(
target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
)
thread.start()
if self.device == "mps":
generated_text = ""
for new_text in self.streamer:
generated_text += new_text
printable_text = generated_text
torch.mps.empty_cache()
else:
generated_text, printable_text = "", ""
for new_text in self.streamer:
generated_text += new_text
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
yield (sentences[0])
printable_text = new_text
self.chat.append({"role": "assistant", "content": generated_text})
# don't forget last sentence
yield printable_text
class ParlerTTSHandler(BaseHandler):
def setup(
self,
should_listen,
model_name="ylacombe/parler-tts-mini-jenny-30H",
device="cuda",
torch_dtype="float16",
compile_mode=None,
gen_kwargs={},
max_prompt_pad_length=8,
description=(
"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
"She speaks very fast."
),
play_steps_s=1,
blocksize=512,
):
self.should_listen = should_listen
self.device = device
self.torch_dtype = getattr(torch, torch_dtype)
self.gen_kwargs = gen_kwargs
self.compile_mode = compile_mode
self.max_prompt_pad_length = max_prompt_pad_length
self.description = description
self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = ParlerTTSForConditionalGeneration.from_pretrained(
model_name, torch_dtype=self.torch_dtype
).to(device)
framerate = self.model.audio_encoder.config.frame_rate
self.play_steps = int(framerate * play_steps_s)
self.blocksize = blocksize
if self.compile_mode not in (None, "default"):
logger.warning(
"Torch compilation modes that captures CUDA graphs are not yet compatible with the STT part. Reverting to 'default'"
)
self.compile_mode = "default"
if self.compile_mode:
self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode=self.compile_mode, fullgraph=True
)
self.warmup()
def prepare_model_inputs(
self,
prompt,
max_length_prompt=50,
pad=False,
):
pad_args_prompt = (
{"padding": "max_length", "max_length": max_length_prompt} if pad else {}
)
tokenized_description = self.description_tokenizer(
self.description, return_tensors="pt"
)
input_ids = tokenized_description.input_ids.to(self.device)
attention_mask = tokenized_description.attention_mask.to(self.device)
tokenized_prompt = self.prompt_tokenizer(
prompt, return_tensors="pt", **pad_args_prompt
)
prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
gen_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"prompt_input_ids": prompt_input_ids,
"prompt_attention_mask": prompt_attention_mask,
**self.gen_kwargs,
}
return gen_kwargs
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
if self.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# 2 warmup steps for no compile or compile mode with CUDA graphs capture
n_steps = 1 if self.compile_mode == "default" else 2
if self.device == "cuda":
torch.cuda.synchronize()
start_event.record()
if self.compile_mode:
pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
for pad_length in pad_lengths[::-1]:
model_kwargs = self.prepare_model_inputs(
"dummy prompt", max_length_prompt=pad_length, pad=True
)
for _ in range(n_steps):
_ = self.model.generate(**model_kwargs)
logger.info(f"Warmed up length {pad_length} tokens!")
else:
model_kwargs = self.prepare_model_inputs("dummy prompt")
for _ in range(n_steps):
_ = self.model.generate(**model_kwargs)
if self.device == "cuda":
end_event.record()
torch.cuda.synchronize()
logger.info(
f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
)
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
pad_args = {}
if self.compile_mode:
# pad to closest upper power of two
pad_length = next_power_of_2(nb_tokens)
logger.debug(f"padding to {pad_length}")
pad_args["pad"] = True
pad_args["max_length_prompt"] = pad_length
tts_gen_kwargs = self.prepare_model_inputs(
llm_sentence,
**pad_args,
)
streamer = ParlerTTSStreamer(
self.model, device=self.device, play_steps=self.play_steps
)
tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
torch.manual_seed(0)
thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
thread.start()
for i, audio_chunk in enumerate(streamer):
if i == 0 and "pipeline_start" in globals():
logger.info(
f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
)
audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
for i in range(0, len(audio_chunk), self.blocksize):
yield np.pad(
audio_chunk[i : i + self.blocksize],
(0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
)
self.should_listen.set()
def prepare_args(args, prefix):
"""
Rename arguments by removing the prefix and prepares the gen_kwargs.
@@ -745,7 +190,7 @@ def main():
lm_response_queue = Queue()
if module_kwargs.mode == "local":
from local_audio_streamer import LocalAudioStreamer
from connections.local_audio_streamer import LocalAudioStreamer
local_audio_streamer = LocalAudioStreamer(
input_queue=recv_audio_chunks_queue, output_queue=send_audio_chunks_queue
@@ -753,6 +198,9 @@ def main():
comms_handlers = [local_audio_streamer]
should_listen.set()
else:
from connections.socket_receiver import SocketReceiver
from connections.socket_sender import SocketSender
comms_handlers = [
SocketReceiver(
stop_event,
@@ -778,6 +226,8 @@ def main():
setup_kwargs=vars(vad_handler_kwargs),
)
if module_kwargs.stt == "whisper":
from STT.whisper_stt_handler import WhisperSTTHandler
stt = WhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
@@ -786,6 +236,7 @@ def main():
)
elif module_kwargs.stt == "whisper-mlx":
from STT.lightning_whisper_mlx_handler import LightningWhisperSTTHandler
stt = LightningWhisperSTTHandler(
stop_event,
queue_in=spoken_prompt_queue,
@@ -795,6 +246,8 @@ def main():
else:
raise ValueError("The STT should be either whisper or whisper-mlx")
if module_kwargs.llm == "transformers":
from LLM.language_model import LanguageModelHandler
lm = LanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
@@ -802,7 +255,8 @@ def main():
setup_kwargs=vars(language_model_handler_kwargs),
)
elif module_kwargs.llm == "mlx-lm":
from LLM.mlx_lm import MLXLanguageModelHandler
from LLM.mlx_language_model import MLXLanguageModelHandler
lm = MLXLanguageModelHandler(
stop_event,
queue_in=text_prompt_queue,
@@ -812,9 +266,8 @@ def main():
else:
raise ValueError("The LLM should be either transformers or mlx-lm")
if module_kwargs.tts == "parler":
torch._inductor.config.fx_graph_cache = True
# mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
torch._dynamo.config.cache_size_limit = 15
from TTS.parler_handler import ParlerTTSHandler
tts = ParlerTTSHandler(
stop_event,
queue_in=lm_response_queue,
@@ -825,7 +278,7 @@ def main():
elif module_kwargs.tts == "melo":
try:
from TTS.melotts import MeloTTSHandler
from TTS.melo_handler import MeloTTSHandler
except RuntimeError as e:
logger.error(
"Error importing MeloTTSHandler. You might need to run: python -m unidic download"

23
utils/thread_manager.py Normal file
View File

@@ -0,0 +1,23 @@
import threading
class ThreadManager:
"""
Manages multiple threads used to execute given handler tasks.
"""
def __init__(self, handlers):
self.handlers = handlers
self.threads = []
def start(self):
for handler in self.handlers:
thread = threading.Thread(target=handler.run)
self.threads.append(thread)
thread.start()
def stop(self):
for handler in self.handlers:
handler.stop_event.set()
for thread in self.threads:
thread.join()

18
utils/utils.py Normal file
View File

@@ -0,0 +1,18 @@
import numpy as np
def next_power_of_2(x):
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def int2float(sound):
"""
Taken from https://github.com/snakers4/silero-vad
"""
abs_max = np.abs(sound).max()
sound = sound.astype("float32")
if abs_max > 0:
sound *= 1 / 32768
sound = sound.squeeze() # depends on the use case
return sound