linting and a few fixes for chattts

This commit is contained in:
Andres Marafioti
2024-08-28 15:43:33 +02:00
parent c40bf05df7
commit a26395c5f3
7 changed files with 26 additions and 33 deletions

View File

@@ -68,7 +68,9 @@ class WhisperSTTHandler(BaseHandler):
# one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
# hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
warmup_gen_kwargs = {
"min_new_tokens": self.gen_kwargs["max_new_tokens"], # Yes, assign max_new_tokens to min_new_tokens
"min_new_tokens": self.gen_kwargs[
"max_new_tokens"
], # Yes, assign max_new_tokens to min_new_tokens
"max_new_tokens": self.gen_kwargs["max_new_tokens"],
**self.gen_kwargs,
}

View File

@@ -18,15 +18,15 @@ class ChatTTSHandler(BaseHandler):
def setup(
self,
should_listen,
device="mps",
device="cuda",
gen_kwargs={}, # Unused
stream=True,
chunk_size=512,
):
self.should_listen = should_listen
self.device = device
self.model = ChatTTS.Chat()
self.model.load(compile=True) # Set to True for better performance
self.model = ChatTTS.Chat()
self.model.load(compile=False) # Doesn't work for me with True
self.chunk_size = chunk_size
self.stream = stream
rnd_spk_emb = self.model.sample_random_speaker()
@@ -37,8 +37,7 @@ class ChatTTSHandler(BaseHandler):
def warmup(self):
logger.info(f"Warming up {self.__class__.__name__}")
_= self.model.infer("text")
_ = self.model.infer("text")
def process(self, llm_sentence):
console.print(f"[green]ASSISTANT: {llm_sentence}")
@@ -52,36 +51,32 @@ class ChatTTSHandler(BaseHandler):
time.time() - start
) # Removing this line makes it fail more often. I'm looking into it.
wavs_gen = self.model.infer(llm_sentence,params_infer_code=self.params_infer_code, stream=self.stream)
wavs_gen = self.model.infer(
llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
)
if self.stream:
wavs = [np.array([])]
for gen in wavs_gen:
print('new chunk gen', len(gen[0]))
if len(gen[0]) == 0:
self.should_listen.set()
return
audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
print('audio_chunk:', audio_chunk.shape)
audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
while len(audio_chunk) > self.chunk_size:
yield audio_chunk[:self.chunk_size] # 返回前 chunk_size 字节的数据
audio_chunk = audio_chunk[self.chunk_size:] # 移除已返回的数据
yield np.pad(audio_chunk, (0,self.chunk_size-len(audio_chunk)))
yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据
audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据
yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
else:
print('check result', wavs_gen)
wavs = wavs_gen
if len(wavs[0]) == 0:
self.should_listen.set()
return
audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
audio_chunk = (audio_chunk * 32768).astype(np.int16)
print('audio_chunk:', audio_chunk.shape)
for i in range(0, len(audio_chunk), self.chunk_size):
yield np.pad(
audio_chunk[i : i + self.chunk_size],
(0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
)
self.should_listen.set()

View File

@@ -5,14 +5,12 @@ from dataclasses import dataclass, field
class ChatTTSHandlerArguments:
chat_tts_stream: bool = field(
default=True,
metadata={
"help": "The tts mode is stream Default is 'stream'."
},
metadata={"help": "The tts mode is stream Default is 'stream'."},
)
chat_tts_device: str = field(
default="mps",
default="cuda",
metadata={
"help": "The device to be used for speech synthesis. Default is 'mps'."
"help": "The device to be used for speech synthesis. Default is 'cuda'."
},
)
chat_tts_chunk_size: int = field(

View File

@@ -35,7 +35,7 @@ class ModuleArguments:
tts: Optional[str] = field(
default="parler",
metadata={
"help": "The TTS to use. Either 'parler' or 'melo'. Default is 'parler'"
"help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
},
)
log_level: str = field(

View File

@@ -3,6 +3,6 @@ parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0
sounddevice==0.5.0
ChatTTS
funasr
modelscope
ChatTTS>=0.1.1
funasr>=1.1.6
modelscope>=1.17.1

View File

@@ -5,6 +5,6 @@ torch==2.4.0
sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0
ChatTTS
ChatTTS>=0.1.1
funasr>=1.1.6
modelscope>=1.17.1

View File

@@ -80,7 +80,7 @@ def main():
MLXLanguageModelHandlerArguments,
ParlerTTSHandlerArguments,
MeloTTSHandlerArguments,
ChatTTSHandlerArguments
ChatTTSHandlerArguments,
)
)
@@ -190,7 +190,7 @@ def main():
prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
prepare_args(parler_tts_handler_kwargs, "tts")
prepare_args(melo_tts_handler_kwargs, "melo")
prepare_args(chat_tts_handler_kwargs,"chat_tts")
prepare_args(chat_tts_handler_kwargs, "chat_tts")
# 3. Build the pipeline
stop_event = Event()
@@ -319,9 +319,7 @@ def main():
try:
from TTS.chatTTS_handler import ChatTTSHandler
except RuntimeError as e:
logger.error(
"Error importing ChatTTSHandler"
)
logger.error("Error importing ChatTTSHandler")
raise e
tts = ChatTTSHandler(
stop_event,
@@ -331,7 +329,7 @@ def main():
setup_kwargs=vars(chat_tts_handler_kwargs),
)
else:
raise ValueError("The TTS should be either parler or melo")
raise ValueError("The TTS should be either parler, melo or chatTTS")
# 4. Run the pipeline
try: