Merge branch 'main' into DeepFilterNet
@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler):
        # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
        # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
        warmup_gen_kwargs = {
            "min_new_tokens": self.gen_kwargs["max_new_tokens"],  # Yes, assign max_new_tokens to min_new_tokens
            "min_new_tokens": self.gen_kwargs[
                "max_new_tokens"
            ],  # Yes, assign max_new_tokens to min_new_tokens
            "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            **self.gen_kwargs,
        }
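For context (not part of the diff above): the comments say warmup must generate at least as many tokens as any later request, so min_new_tokens is pinned to max_new_tokens. A minimal standalone sketch of that idea, assuming a Hugging Face Whisper checkpoint (the model name and dummy input are illustrative assumptions):

# Warmup sketch: force generation of the full max_new_tokens budget once,
# so subsequent calls never request more tokens than were already exercised.
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

dummy_audio = torch.zeros(16000).numpy()  # one second of silence at 16 kHz
inputs = processor(dummy_audio, sampling_rate=16000, return_tensors="pt")

max_new_tokens = 128
warmup_gen_kwargs = {
    "min_new_tokens": max_new_tokens,  # yes, min == max: generate the full budget
    "max_new_tokens": max_new_tokens,
}
_ = model.generate(inputs.input_features, **warmup_gen_kwargs)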
82  TTS/chatTTS_handler.py  Normal file
@@ -0,0 +1,82 @@
import ChatTTS
import logging
from baseHandler import BaseHandler
import librosa
import numpy as np
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class ChatTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        device="cuda",
        gen_kwargs={},  # Unused
        stream=True,
        chunk_size=512,
    ):
        self.should_listen = should_listen
        self.device = device
        self.model = ChatTTS.Chat()
        self.model.load(compile=False)  # Doesn't work for me with True
        self.chunk_size = chunk_size
        self.stream = stream
        rnd_spk_emb = self.model.sample_random_speaker()
        self.params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_emb=rnd_spk_emb,
        )
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")
        _ = self.model.infer("text")

    def process(self, llm_sentence):
        console.print(f"[green]ASSISTANT: {llm_sentence}")
        if self.device == "mps":
            import time

            start = time.time()
            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
            torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
            _ = (
                time.time() - start
            )  # Removing this line makes it fail more often. I'm looking into it.

        wavs_gen = self.model.infer(
            llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
        )

        if self.stream:
            wavs = [np.array([])]
            for gen in wavs_gen:
                if gen[0] is None or len(gen[0]) == 0:
                    self.should_listen.set()
                    return
                audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
                while len(audio_chunk) > self.chunk_size:
                    yield audio_chunk[: self.chunk_size]  # yield the first chunk_size samples
                    audio_chunk = audio_chunk[self.chunk_size :]  # drop the samples already yielded
                yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
        else:
            wavs = wavs_gen
            if len(wavs[0]) == 0:
                self.should_listen.set()
                return
            audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            for i in range(0, len(audio_chunk), self.chunk_size):
                yield np.pad(
                    audio_chunk[i : i + self.chunk_size],
                    (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
                )
        self.should_listen.set()
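The streaming branch above resamples each 24 kHz float chunk coming out of ChatTTS to 16 kHz, converts it to int16, and slices it into fixed-size pieces, zero-padding the last one. A minimal standalone sketch of that chunking pattern on synthetic audio (no ChatTTS model needed; the sine-wave input is an illustrative stand-in for gen[0]):

import numpy as np
import librosa

chunk_size = 512
# Stand-in for a ChatTTS chunk: 0.3 s of a 440 Hz tone at 24 kHz, shaped (1, n_samples).
t = np.linspace(0, 0.3, int(24000 * 0.3), endpoint=False)
wav_24k = np.sin(2 * np.pi * 440 * t)[np.newaxis, :].astype(np.float32)

audio_chunk = librosa.resample(wav_24k, orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]  # mono int16, like the handler emits

pieces = []
while len(audio_chunk) > chunk_size:
    pieces.append(audio_chunk[:chunk_size])
    audio_chunk = audio_chunk[chunk_size:]
pieces.append(np.pad(audio_chunk, (0, chunk_size - len(audio_chunk))))  # zero-pad the tail

assert all(len(p) == chunk_size for p in pieces)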
21  arguments_classes/chat_tts_arguments.py  Normal file
@@ -0,0 +1,21 @@
from dataclasses import dataclass, field


@dataclass
class ChatTTSHandlerArguments:
    chat_tts_stream: bool = field(
        default=True,
        metadata={"help": "Whether to stream the synthesized audio. Default is True."},
    )
    chat_tts_device: str = field(
        default="cuda",
        metadata={
            "help": "The device to be used for speech synthesis. Default is 'cuda'."
        },
    )
    chat_tts_chunk_size: int = field(
        default=512,
        metadata={
            "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load. Default is 512."
        },
    )
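Like the other arguments_classes dataclasses, these fields are meant for HfArgumentParser, so each one surfaces as a --chat_tts_* flag. A minimal sketch of parsing just this dataclass in isolation (the real pipeline parses it together with all the other dataclasses, as the hunks below show):

from transformers import HfArgumentParser
from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments

parser = HfArgumentParser((ChatTTSHandlerArguments,))
# Parse an explicit argv list instead of sys.argv, for the sake of the example.
(chat_tts_kwargs,) = parser.parse_args_into_dataclasses(
    args=["--chat_tts_device", "cuda", "--chat_tts_chunk_size", "1024"]
)
print(chat_tts_kwargs.chat_tts_device, chat_tts_kwargs.chat_tts_chunk_size)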
@@ -35,7 +35,7 @@ class ModuleArguments:
    tts: Optional[str] = field(
        default="parler",
        metadata={
            "help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
            "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
        },
    )
    log_level: str = field(
@@ -3,6 +3,7 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0
sounddevice==0.5.0
funasr
modelscope
deepfilternet
ChatTTS>=0.1.1
funasr>=1.1.6
modelscope>=1.17.1
deepfilternet>=0.5.6
@@ -5,6 +5,8 @@ torch==2.4.0
sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0
ChatTTS>=0.1.1
funasr>=1.1.6
modelscope>=1.17.1
deepfilternet
deepfilternet>=0.5.6
@@ -8,6 +8,7 @@ from threading import Event
from typing import Optional
from sys import platform
from VAD.vad_handler import VADHandler
from arguments_classes.chat_tts_arguments import ChatTTSHandlerArguments
from arguments_classes.language_model_arguments import LanguageModelHandlerArguments
from arguments_classes.mlx_language_model_arguments import (
    MLXLanguageModelHandlerArguments,
@@ -79,6 +80,7 @@ def main():
            MLXLanguageModelHandlerArguments,
            ParlerTTSHandlerArguments,
            MeloTTSHandlerArguments,
            ChatTTSHandlerArguments,
        )
    )
@@ -96,6 +98,7 @@ def main():
            mlx_language_model_handler_kwargs,
            parler_tts_handler_kwargs,
            melo_tts_handler_kwargs,
            chat_tts_handler_kwargs,
        ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        # Parse arguments from command line if no JSON file is provided
@@ -110,6 +113,7 @@ def main():
            mlx_language_model_handler_kwargs,
            parler_tts_handler_kwargs,
            melo_tts_handler_kwargs,
            chat_tts_handler_kwargs,
        ) = parser.parse_args_into_dataclasses()

    # 1. Handle logger
@@ -186,6 +190,7 @@ def main():
    prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
    prepare_args(parler_tts_handler_kwargs, "tts")
    prepare_args(melo_tts_handler_kwargs, "melo")
    prepare_args(chat_tts_handler_kwargs, "chat_tts")

    # 3. Build the pipeline
    stop_event = Event()
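prepare_args itself is not shown in this diff, but the naming makes the intent visible: the chat_tts_* fields need to end up as the device/stream/chunk_size keywords that ChatTTSHandler.setup() expects once vars(chat_tts_handler_kwargs) is passed as setup_kwargs in the next hunk. A hypothetical sketch of that prefix-stripping idea (illustrative only, not the repository's actual prepare_args):

def strip_handler_prefix(args_obj, prefix):
    # Rename e.g. chat_tts_chunk_size -> chunk_size so vars(args_obj) can be
    # passed directly to the handler's setup() as keyword arguments.
    for name in list(vars(args_obj)):
        if name.startswith(prefix + "_"):
            setattr(args_obj, name[len(prefix) + 1 :], getattr(args_obj, name))
            delattr(args_obj, name)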
@@ -310,8 +315,21 @@ def main():
            setup_args=(should_listen,),
            setup_kwargs=vars(melo_tts_handler_kwargs),
        )
    elif module_kwargs.tts == "chatTTS":
        try:
            from TTS.chatTTS_handler import ChatTTSHandler
        except RuntimeError as e:
            logger.error("Error importing ChatTTSHandler")
            raise e
        tts = ChatTTSHandler(
            stop_event,
            queue_in=lm_response_queue,
            queue_out=send_audio_chunks_queue,
            setup_args=(should_listen,),
            setup_kwargs=vars(chat_tts_handler_kwargs),
        )
    else:
        raise ValueError("The TTS should be either parler or melo")
        raise ValueError("The TTS should be either parler, melo or chatTTS")

    # 4. Run the pipeline
    try: