Merge pull request #33 from harryjulian/feat/streaming
GGUF Streaming Support
@@ -110,7 +110,7 @@ from neuttsair.neutts import NeuTTSAir
 import soundfile as sf
 
 tts = NeuTTSAir(
-    backbone_repo="neuphonic/neutts-air", # or 'neutts-air-q4-gguf' wit llama-cpp-python installed
+    backbone_repo="neuphonic/neutts-air", # or 'neutts-air-q4-gguf' with llama-cpp-python installed
     backbone_device="cpu",
     codec_repo="neuphonic/neucodec",
     codec_device="cpu"
@@ -35,4 +35,16 @@ python -m examples.onnx_example \
   --ref_codes samples/dave.pt \
   --ref_text samples/dave.txt \
   --backbone neuphonic/neutts-air-q4-gguf
 ```
+
+### Streaming Support
+
+To stream the model output in chunks, try out the `basic_streaming_example.py` example. Only the GGUF backbones currently support streaming. Ensure you have `llama-cpp-python`, `onnxruntime` and `pyaudio` installed to run this example.
+
+```bash
+python -m examples.basic_streaming_example \
+  --input_text "My name is Dave, and um, I'm from London" \
+  --ref_codes samples/dave.pt \
+  --ref_text samples/dave.txt \
+  --backbone neuphonic/neutts-air-q4-gguf
+```
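If you want to capture the stream rather than play it live, the same generator can feed a file writer. A minimal sketch, assuming the q4 GGUF backbone, the ONNX decoder used by the streaming example, and the bundled `samples/dave.*` reference pair:

```python
import numpy as np
import soundfile as sf
import torch

from neuttsair.neutts import NeuTTSAir

tts = NeuTTSAir(
    backbone_repo="neuphonic/neutts-air-q4-gguf",
    backbone_device="cpu",
    codec_repo="neuphonic/neucodec-onnx-decoder",
    codec_device="cpu",
)

ref_codes = torch.load("samples/dave.pt")
with open("samples/dave.txt") as f:
    ref_text = f.read().strip()

# Collect the 24 kHz float chunks as they are yielded, then write one wav.
chunks = list(tts.infer_stream("My name is Dave, and um, I'm from London", ref_codes, ref_text))
sf.write("output.wav", np.concatenate(chunks), 24_000)
```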
examples/basic_streaming_example.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import os

import soundfile as sf
import torch
import numpy as np
from neuttsair.neutts import NeuTTSAir
import pyaudio


def main(input_text, ref_codes_path, ref_text, backbone):
    assert backbone in ["neuphonic/neutts-air-q4-gguf", "neuphonic/neutts-air-q8-gguf"], "Must be a GGUF ckpt as streaming is only currently supported by llama-cpp."

    # Initialize NeuTTSAir with the desired model and codec
    tts = NeuTTSAir(
        backbone_repo=backbone,
        backbone_device="cpu",
        codec_repo="neuphonic/neucodec-onnx-decoder",
        codec_device="cpu"
    )

    # If ref_text is a path, read the file; otherwise use the string as-is
    if ref_text and os.path.exists(ref_text):
        with open(ref_text, "r") as f:
            ref_text = f.read().strip()

    if ref_codes_path and os.path.exists(ref_codes_path):
        ref_codes = torch.load(ref_codes_path)

    print(f"Generating audio for input text: {input_text}")
    p = pyaudio.PyAudio()
    stream = p.open(
        format=pyaudio.paInt16,
        channels=1,
        rate=24_000,
        output=True
    )
    print("Streaming...")
    for chunk in tts.infer_stream(input_text, ref_codes, ref_text):
        audio = (chunk * 32767).astype(np.int16)
        print(audio.shape)
        stream.write(audio.tobytes())

    stream.stop_stream()
    stream.close()
    p.terminate()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="NeuTTSAir Example")
    parser.add_argument(
        "--input_text",
        type=str,
        required=True,
        help="Input text to be converted to speech"
    )
    parser.add_argument(
        "--ref_codes",
        type=str,
        default="./samples/dave.pt",
        help="Path to pre-encoded reference audio"
    )
    parser.add_argument(
        "--ref_text",
        type=str,
        default="./samples/dave.txt",
        help="Reference text corresponding to the reference audio",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="output.wav",
        help="Path to save the output audio"
    )
    parser.add_argument(
        "--backbone",
        type=str,
        default="neuphonic/neutts-air-q8-gguf",
        help="Huggingface repo containing the backbone checkpoint. Must be GGUF."
    )
    args = parser.parse_args()
    main(
        input_text=args.input_text,
        ref_codes_path=args.ref_codes,
        ref_text=args.ref_text,
        backbone=args.backbone,
    )
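One detail worth noting in the playback loop: `(chunk * 32767).astype(np.int16)` converts the float waveform, nominally in [-1, 1], to the 16-bit PCM that PyAudio expects, and any sample outside that range would wrap around after the cast. A clip-first variant (a hypothetical helper, not part of the example) avoids that:

```python
import numpy as np


def float_to_pcm16(chunk: np.ndarray) -> np.ndarray:
    # Clamp to [-1, 1] before scaling so out-of-range samples saturate
    # instead of wrapping around when cast to int16.
    return (np.clip(chunk, -1.0, 1.0) * 32767).astype(np.int16)
```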
@@ -1,3 +1,4 @@
+from typing import Generator
 from pathlib import Path
 import librosa
 import numpy as np
@@ -6,7 +7,35 @@ import re
 import perth
 from neucodec import NeuCodec, DistillNeuCodec
 from phonemizer.backend import EspeakBackend
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
+
+
+def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
+    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
+    assert len(frames)
+    dtype = frames[0].dtype
+    shape = frames[0].shape[:-1]
+
+    total_size = 0
+    for i, frame in enumerate(frames):
+        frame_end = stride * i + frame.shape[-1]
+        total_size = max(total_size, frame_end)
+
+    sum_weight = np.zeros(total_size, dtype=dtype)
+    out = np.zeros((*shape, total_size), dtype=dtype)
+
+    offset: int = 0
+    for frame in frames:
+        frame_length = frame.shape[-1]
+        # triangular weight: zero at the frame edges, 0.5 at the centre
+        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+        weight = 0.5 - np.abs(t - 0.5)
+
+        out[..., offset : offset + frame_length] += weight * frame
+        sum_weight[offset : offset + frame_length] += weight
+        offset += stride
+    assert sum_weight.min() > 0
+    return out / sum_weight
+
+
 class NeuTTSAir:
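`_linear_overlap_add` blends successive decoded chunks with a triangular cross-fade and renormalises by the summed weights, so regions where overlapping frames agree pass through unchanged. A quick sanity check, assuming the function above is in scope:

```python
import numpy as np

# Two identical 4-sample frames overlap-added at stride 2 (2 samples of overlap).
frames = [np.ones(4, dtype=np.float32), np.ones(4, dtype=np.float32)]
out = _linear_overlap_add(frames, stride=2)

# The weights cancel in the normalisation, so a constant signal is reconstructed exactly.
assert out.shape == (6,)
assert np.allclose(out, 1.0)
```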
@@ -22,9 +51,14 @@ class NeuTTSAir:
         # Consts
         self.sample_rate = 24_000
         self.max_context = 2048
+        self.hop_length = 480
+        self.streaming_overlap_frames = 1
+        self.streaming_frames_per_chunk = 25
+        self.streaming_lookforward = 5
+        self.streaming_lookback = 50
+        self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length
 
         # ggml & onnx flags
-        self._grammar = None  # set with a ggml model
         self._is_quantized_model = False
         self._is_onnx_codec = False
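These constants set the streaming granularity: each speech token covers one 480-sample hop at 24 kHz, so a 25-frame chunk is half a second of audio, decoded with 50 frames of lookback and 5 frames of lookforward context. The arithmetic, spelled out:

```python
hop_length = 480        # samples per speech token (frame)
sample_rate = 24_000    # Hz
frames_per_chunk = 25

stride_samples = frames_per_chunk * hop_length  # 12_000 samples per chunk
chunk_seconds = stride_samples / sample_rate    # 0.5 s of audio per yielded chunk
```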
@@ -133,6 +167,24 @@ class NeuTTSAir:
         watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=24_000)
 
         return watermarked_wav
 
+    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
+        """
+        Perform streaming inference to generate speech from text using the TTS model and reference audio.
+
+        Args:
+            text (str): Input text to be converted to speech.
+            ref_codes (np.ndarray | torch.Tensor): Encoded reference.
+            ref_text (str): Reference text corresponding to the reference audio.
+
+        Yields:
+            np.ndarray: Generated speech waveform, one chunk at a time.
+        """
+        if self._is_quantized_model:
+            return self._infer_stream_ggml(ref_codes, ref_text, text)
+        else:
+            raise NotImplementedError("Streaming is not implemented for the torch backend!")
+
     def encode_reference(self, ref_audio_path: str | Path):
         wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True)
@@ -221,7 +273,7 @@ class NeuTTSAir:
             output_tokens[0, input_length:].cpu().numpy().tolist(), add_special_tokens=False
         )
         return output_str
 
     def _infer_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> str:
         ref_text = self._to_phones(ref_text)
         input_text = self._to_phones(input_text)
@@ -240,3 +292,93 @@ class NeuTTSAir:
         )
         output_str = output["choices"][0]["text"]
         return output_str
+
+    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
+        ref_text = self._to_phones(ref_text)
+        input_text = self._to_phones(input_text)
+
+        codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes])
+        prompt = (
+            f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{ref_text} {input_text}"
+            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
+        )
+
+        audio_cache: list[np.ndarray] = []
+        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+        n_decoded_samples: int = 0
+        n_decoded_tokens: int = len(ref_codes)
+
+        for item in self.backbone(
+            prompt,
+            max_tokens=self.max_context,
+            temperature=1.0,
+            top_k=50,
+            stop=["<|SPEECH_GENERATION_END|>"],
+            stream=True
+        ):
+            output_str = item["choices"][0]["text"]
+            token_cache.append(output_str)
+
+            if len(token_cache[n_decoded_tokens:]) >= self.streaming_frames_per_chunk + self.streaming_lookforward:
+                # decode chunk, with lookback/lookforward context around it
+                tokens_start = max(
+                    n_decoded_tokens
+                    - self.streaming_lookback
+                    - self.streaming_overlap_frames,
+                    0
+                )
+                tokens_end = (
+                    n_decoded_tokens
+                    + self.streaming_frames_per_chunk
+                    + self.streaming_lookforward
+                    + self.streaming_overlap_frames
+                )
+                sample_start = (n_decoded_tokens - tokens_start) * self.hop_length
+                sample_end = (
+                    sample_start
+                    + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
+                )
+                curr_codes = token_cache[tokens_start:tokens_end]
+                recon = self._decode("".join(curr_codes))
+                recon = self.watermarker.apply_watermark(recon, sample_rate=24_000)
+                recon = recon[sample_start:sample_end]
+                audio_cache.append(recon)
+
+                # postprocess: cross-fade the cached chunks, emit only the new samples
+                processed_recon = _linear_overlap_add(
+                    audio_cache, stride=self.streaming_stride_samples
+                )
+                new_samples_end = len(audio_cache) * self.streaming_stride_samples
+                processed_recon = processed_recon[n_decoded_samples:new_samples_end]
+                n_decoded_samples = new_samples_end
+                n_decoded_tokens += self.streaming_frames_per_chunk
+                yield processed_recon
+
+        # final decoding handled separately as it has a non-constant chunk size
+        remaining_tokens = len(token_cache) - n_decoded_tokens
+        if len(token_cache) > n_decoded_tokens:
+            tokens_start = max(
+                len(token_cache)
+                - (self.streaming_lookback + self.streaming_overlap_frames + remaining_tokens),
+                0
+            )
+            sample_start = (
+                len(token_cache)
+                - tokens_start
+                - remaining_tokens
+                - self.streaming_overlap_frames
+            ) * self.hop_length
+            curr_codes = token_cache[tokens_start:]
+            recon = self._decode("".join(curr_codes))
+            recon = self.watermarker.apply_watermark(recon, sample_rate=24_000)
+            recon = recon[sample_start:]
+            audio_cache.append(recon)
+
+            processed_recon = _linear_overlap_add(audio_cache, stride=self.streaming_stride_samples)
+            processed_recon = processed_recon[n_decoded_samples:]
+            yield processed_recon
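To make the windowing concrete: a chunk is decoded once 30 new tokens (25 per chunk plus 5 lookforward) have accumulated past what is already emitted, the decode window is padded with lookback and overlap frames, and only the centre slice is kept for the cross-fade. A worked pass through the first chunk's indices, assuming a hypothetical reference of 120 codes:

```python
hop = 480
lookback, lookforward, overlap, chunk = 50, 5, 1, 25
n_decoded_tokens = 120  # = len(ref_codes), nothing emitted yet

tokens_start = max(n_decoded_tokens - lookback - overlap, 0)   # 69
tokens_end = n_decoded_tokens + chunk + lookforward + overlap  # 151
sample_start = (n_decoded_tokens - tokens_start) * hop         # 24_480
sample_end = sample_start + (chunk + 2 * overlap) * hop        # 37_440

# The kept slice is chunk + 2*overlap = 27 frames wide; the extra
# overlap frame on each side is what _linear_overlap_add cross-fades.
assert (sample_end - sample_start) // hop == chunk + 2 * overlap
```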