diff --git a/README.md b/README.md
index 03a7160..b42360f 100644
--- a/README.md
+++ b/README.md
@@ -110,7 +110,7 @@ from neuttsair.neutts import NeuTTSAir
 import soundfile as sf
 
 tts = NeuTTSAir(
-    backbone_repo="neuphonic/neutts-air", # or 'neutts-air-q4-gguf' wit llama-cpp-python installed
+    backbone_repo="neuphonic/neutts-air", # or 'neutts-air-q4-gguf' with llama-cpp-python installed
     backbone_device="cpu",
     codec_repo="neuphonic/neucodec",
     codec_device="cpu"
diff --git a/examples/README.md b/examples/README.md
index 231e3a0..facdab9 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -35,4 +35,42 @@ python -m examples.onnx_example \
     --ref_codes samples/dave.pt \
     --ref_text samples/dave.txt \
     --backbone neuphonic/neutts-air-q4-gguf
-```
\ No newline at end of file
+```
+
+### Streaming Support
+
+To stream the model output in chunks, try out the `basic_streaming_example.py` example. Streaming is currently supported only with the GGUF backbones. Ensure you have `llama-cpp-python`, `onnxruntime`, and `pyaudio` installed to run this example.
+
+```bash
+python -m examples.basic_streaming_example \
+    --input_text "My name is Dave, and um, I'm from London" \
+    --ref_codes samples/dave.pt \
+    --ref_text samples/dave.txt \
+    --backbone neuphonic/neutts-air-q4-gguf
+```
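+
+The same chunks can also be consumed directly from Python via `NeuTTSAir.infer_stream`, which yields NumPy waveform chunks at 24 kHz. A minimal sketch of the same flow as `basic_streaming_example.py` (assumes a GGUF backbone and the bundled `samples/dave.pt` / `samples/dave.txt` reference files):
+
+```python
+import torch
+import pyaudio
+from neuttsair.neutts import NeuTTSAir
+
+tts = NeuTTSAir(
+    backbone_repo="neuphonic/neutts-air-q4-gguf",   # streaming requires a GGUF backbone
+    backbone_device="cpu",
+    codec_repo="neuphonic/neucodec-onnx-decoder",   # ONNX decoder, as in the example script
+    codec_device="cpu",
+)
+
+ref_codes = torch.load("samples/dave.pt")           # pre-encoded reference audio
+ref_text = open("samples/dave.txt").read().strip()  # matching reference transcript
+
+p = pyaudio.PyAudio()
+out = p.open(format=pyaudio.paInt16, channels=1, rate=24_000, output=True)
+for chunk in tts.infer_stream("Hello from NeuTTS Air.", ref_codes, ref_text):
+    out.write((chunk * 32767).astype("int16").tobytes())  # play each chunk as it arrives
+out.stop_stream()
+out.close()
+p.terminate()
+```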
diff --git a/examples/basic_streaming_example.py b/examples/basic_streaming_example.py
new file mode 100644
index 0000000..dd9e640
--- /dev/null
+++ b/examples/basic_streaming_example.py
@@ -0,0 +1,87 @@
+import os
+import soundfile as sf
+import torch
+import numpy as np
+from neuttsair.neutts import NeuTTSAir
+import pyaudio
+
+
+def main(input_text, ref_codes_path, ref_text, backbone):
+    assert backbone in ["neuphonic/neutts-air-q4-gguf", "neuphonic/neutts-air-q8-gguf"], "Must be a GGUF ckpt as streaming is currently only supported by llama-cpp."
+
+    # Initialize NeuTTSAir with the desired model and codec
+    tts = NeuTTSAir(
+        backbone_repo=backbone,
+        backbone_device="cpu",
+        codec_repo="neuphonic/neucodec-onnx-decoder",
+        codec_device="cpu"
+    )
+
+    # If ref_text is a path, read the transcript from the file; otherwise use the string as-is
+    if ref_text and os.path.exists(ref_text):
+        with open(ref_text, "r") as f:
+            ref_text = f.read().strip()
+
+    if ref_codes_path and os.path.exists(ref_codes_path):
+        ref_codes = torch.load(ref_codes_path)
+
+    print(f"Generating audio for input text: {input_text}")
+    p = pyaudio.PyAudio()
+    stream = p.open(
+        format=pyaudio.paInt16,
+        channels=1,
+        rate=24_000,
+        output=True
+    )
+    print("Streaming...")
+    for chunk in tts.infer_stream(input_text, ref_codes, ref_text):
+        audio = (chunk * 32767).astype(np.int16)
+        print(audio.shape)
+        stream.write(audio.tobytes())
+
+    stream.stop_stream()
+    stream.close()
+    p.terminate()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="NeuTTSAir Example")
+    parser.add_argument(
+        "--input_text",
+        type=str,
+        required=True,
+        help="Input text to be converted to speech"
+    )
+    parser.add_argument(
+        "--ref_codes",
+        type=str,
+        default="./samples/dave.pt",
+        help="Path to pre-encoded reference audio"
+    )
+    parser.add_argument(
+        "--ref_text",
+        type=str,
+        default="./samples/dave.txt",
+        help="Reference text corresponding to the reference audio",
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default="output.wav",
+        help="Path to save the output audio"
+    )
+    parser.add_argument(
+        "--backbone",
+        type=str,
+        default="neuphonic/neutts-air-q8-gguf",
+        help="Huggingface repo containing the backbone checkpoint. Must be GGUF."
+    )
+    args = parser.parse_args()
+    main(
+        input_text=args.input_text,
+        ref_codes_path=args.ref_codes,
+        ref_text=args.ref_text,
+        backbone=args.backbone,
+    )
diff --git a/neuttsair/neutts.py b/neuttsair/neutts.py
index f58829c..3bbc1ba 100644
--- a/neuttsair/neutts.py
+++ b/neuttsair/neutts.py
@@ -1,3 +1,4 @@
+from typing import Generator
 from pathlib import Path
 import librosa
 import numpy as np
@@ -6,7 +7,35 @@
 import re
 import perth
 from neucodec import NeuCodec, DistillNeuCodec
 from phonemizer.backend import EspeakBackend
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
+
+
+def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
+    assert len(frames)
+    dtype = frames[0].dtype
+    shape = frames[0].shape[:-1]
+
+    total_size = 0
+    for i, frame in enumerate(frames):
+        frame_end = stride * i + frame.shape[-1]
+        total_size = max(total_size, frame_end)
+
+    sum_weight = np.zeros(total_size, dtype=dtype)
+    out = np.zeros((*shape, total_size), dtype=dtype)
+
+    offset: int = 0
+    for frame in frames:
+        frame_length = frame.shape[-1]
+        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+        weight = np.abs(0.5 - (t - 0.5))
+
+        out[..., offset : offset + frame_length] += weight * frame
+        sum_weight[offset : offset + frame_length] += weight
+        offset += stride
+    assert sum_weight.min() > 0
+    return out / sum_weight
 
 class NeuTTSAir:
@@ -22,9 +51,14 @@
         # Consts
         self.sample_rate = 24_000
         self.max_context = 2048
+        self.hop_length = 480  # output samples per codec frame
+        self.streaming_overlap_frames = 1
+        self.streaming_frames_per_chunk = 25
+        self.streaming_lookforward = 5
+        self.streaming_lookback = 50
+        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
 
         # ggml & onnx flags
-        self._grammar = None  # set with a ggml model
         self._is_quantized_model = False
         self._is_onnx_codec = False
 
@@ -133,6 +167,24 @@
 
         watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=24_000)
         return watermarked_wav
+
+    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+        """
+        Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+        Args:
+            text (str): Input text to be converted to speech.
+            ref_codes (np.ndarray | torch.Tensor): Encoded reference audio codes.
+            ref_text (str): Reference text corresponding to the reference audio.
+        Yields:
+            np.ndarray: Generated speech waveform chunks.
+ """ + + if self._is_quantized_model: + return self._infer_stream_ggml(ref_codes, ref_text, text) + + else: + raise NotImplementedError("Streaming is not implemented for the torch backend!") def encode_reference(self, ref_audio_path: str | Path): wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True) @@ -221,7 +273,7 @@ class NeuTTSAir: output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False ) return output_str - + def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str: ref_text = self._to_phones(ref_text) input_text = self._to_phones(input_text) @@ -240,3 +292,93 @@ class NeuTTSAir: ) output_str = output["choices"][0]["text"] return output_str + + def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]: + ref_text = self._to_phones(ref_text) + input_text = self._to_phones(input_text) + + codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes]) + prompt = ( + f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}" + f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}" + ) + + audio_cache: list[np.ndarray] = [] + token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes] + n_decoded_samples: int = 0 + n_decoded_tokens: int = len(ref_codes) + + for item in self.backbone( + prompt, + max_tokens=self.max_context, + temperature=1.0, + top_k=50, + stop=["<|SPEECH_GENERATION_END|>"], + stream=True + ): + output_str = item["choices"][0]["text"] + token_cache.append(output_str) + + if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward: + + # decode chunk + tokens_start = max( + n_decoded_tokens + - self.streaming_lookback + - self.streaming_overlap_frames, + 0 + ) + tokens_end = ( + n_decoded_tokens + + self.streaming_frames_per_chunk + + self.streaming_lookforward + + self.streaming_overlap_frames + ) + sample_start = ( + n_decoded_tokens - tokens_start + ) * self.hop_length + sample_end = ( + sample_start + + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length + ) + curr_codes = token_cache[tokens_start:tokens_end] + recon = self._decode("".join(curr_codes)) + recon = self.watermarker.apply_watermark(recon, sample_rate=24_000) + recon = recon[sample_start:sample_end] + audio_cache.append(recon) + + # postprocess + processed_recon = _linear_overlap_add( + audio_cache, stride=self.streaming_stride_samples + ) + new_samples_end = len(audio_cache) * self.streaming_stride_samples + processed_recon = processed_recon[ + n_decoded_samples:new_samples_end + ] + n_decoded_samples = new_samples_end + n_decoded_tokens += self.streaming_frames_per_chunk + yield processed_recon + + # final decoding handled seperately as non-constant chunk size + remaining_tokens = len(token_cache) - n_decoded_tokens + if len(token_cache) > n_decoded_tokens: + tokens_start = max( + len(token_cache) + - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens), + 0 + ) + sample_start = ( + len(token_cache) + - tokens_start + - remaining_tokens + - self.streaming_overlap_frames + ) * self.hop_length + curr_codes = token_cache[tokens_start:] + recon = self._decode("".join(curr_codes)) + recon = self.watermarker.apply_watermark(recon, sample_rate=24_000) + recon = recon[sample_start:] + audio_cache.append(recon) + + processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples) + processed_recon = 
processed_recon[n_decoded_samples:] + yield processed_recon \ No newline at end of file