mirror of https://github.com/neuphonic/neutts-air.git (synced 2025-10-10 02:44:44 +03:00)

working streaming example
@@ -36,6 +36,7 @@ def main(input_text, ref_codes_path, ref_text, backbone):
    print("Streaming...")
    for chunk in tts.infer_stream(input_text, ref_codes, ref_text):
        audio = (chunk * 32767).astype(np.int16)
        print(audio)
        stream.write(audio.tobytes())

    stream.stop_stream()
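The loop above writes int16 PCM into a `stream` object that the example opens outside this hunk. A minimal sketch of how such a playback stream might be opened with PyAudio; the library choice, channel count, and 24 kHz rate are assumptions for illustration, not part of the commit:

import pyaudio

p = pyaudio.PyAudio()
stream = p.open(
    format=pyaudio.paInt16,  # matches the int16 conversion in the loop above
    channels=1,              # assumed mono output
    rate=24_000,             # assumed sample rate of the decoded audio
    output=True,
)
# ... run the streaming loop from the hunk above ...
stream.stop_stream()
stream.close()
p.terminate()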
@@ -84,5 +85,4 @@ if __name__ == "__main__":
        ref_codes_path=args.ref_codes,
        ref_text=args.ref_text,
        backbone=args.backbone,
        output_path=args.output_path,
    )
@@ -11,10 +11,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStream
from threading import Thread


-def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
+def _linear_overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    # original impl --> https://github.com/facebookresearch/encodec/blob/main/encodec/utils.py
    assert len(frames)
-    device = frames[0].device
    dtype = frames[0].dtype
    shape = frames[0].shape[:-1]

@@ -23,14 +22,14 @@ def _linear_overlap_add(frames: list[torch.Tensor], stride: int):
        frame_end = stride * i + frame.shape[-1]
        total_size = max(total_size, frame_end)

-    sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
-    out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
+    sum_weight = np.zeros(total_size, dtype=dtype)
+    out = np.zeros(*shape, total_size, dtype=dtype)

    offset: int = 0
    for frame in frames:
        frame_length = frame.shape[-1]
-        t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1:-1]
-        weight = 0.5 - (t - 0.5).abs()
+        t = np.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
+        weight = np.abs(0.5 - (t - 0.5))

        out[..., offset : offset + frame_length] += weight * frame
        sum_weight[offset : offset + frame_length] += weight
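For readers unfamiliar with the helper being ported from torch to NumPy above, here is a simplified, self-contained sketch of linear overlap-add using the triangular weighting from the referenced encodec implementation; it is an illustration, not the repo's exact code:

import numpy as np

def overlap_add(frames: list[np.ndarray], stride: int) -> np.ndarray:
    # Each frame gets a triangular weight, is summed into the output at its
    # offset, and the result is normalized by the accumulated weights.
    total = stride * (len(frames) - 1) + frames[-1].shape[-1]
    out = np.zeros(total, dtype=frames[0].dtype)
    sum_w = np.zeros(total, dtype=frames[0].dtype)
    for i, frame in enumerate(frames):
        n = frame.shape[-1]
        t = np.linspace(0, 1, n + 2, dtype=frame.dtype)[1:-1]
        w = 0.5 - np.abs(t - 0.5)  # triangular window, peaks mid-frame
        out[i * stride : i * stride + n] += w * frame
        sum_w[i * stride : i * stride + n] += w
    return out / np.maximum(sum_w, 1e-8)  # guard against division by zero

# Two 24 000-sample chunks overlapping by 4 000 samples cross-fade smoothly.
a = np.random.randn(24_000).astype(np.float32)
b = np.random.randn(24_000).astype(np.float32)
print(overlap_add([a, b], stride=20_000).shape)  # (44000,)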
@@ -169,7 +168,7 @@ class NeuTTSAir:

        return watermarked_wav

-    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray]:
+    def infer_stream(self, text: str, ref_codes: np.ndarray | torch.Tensor, ref_text: str) -> Generator[np.ndarray, None, None]:
        """
        Perform streaming inference to generate speech from text using the TTS model and reference audio.
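A hedged sketch of how the updated generator might be consumed end to end, assuming `tts`, `ref_codes`, and `ref_text` are prepared as in the example script above and that soundfile is available; the input string and output path are made up for illustration:

import numpy as np
import soundfile as sf

chunks = []
for chunk in tts.infer_stream("Hello from the streaming path.", ref_codes, ref_text):
    chunks.append(chunk)  # each chunk is a float waveform segment

sf.write("streamed.wav", np.concatenate(chunks), 24_000)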
@@ -182,7 +181,7 @@ class NeuTTSAir:
        """

        if self._is_quantized_model:
-            yield self._infer_stream_ggml(ref_codes, ref_text, text)
+            return self._infer_stream_ggml(ref_codes, ref_text, text)

        else:
            raise NotImplementedError("Streaming is not implemented for the torch backend!")
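Replacing `yield` with `return` changes what callers of `infer_stream` receive: with `yield`, the method is itself a generator whose single item is the inner generator, so consumers would have to unwrap it; with `return`, callers get the chunk generator from `_infer_stream_ggml` directly. A small self-contained illustration with made-up helper names:

import numpy as np

def _inner():
    for _ in range(3):
        yield np.zeros(4)

def stream_with_yield():
    yield _inner()   # the caller's first item is a generator object

def stream_with_return():
    return _inner()  # the caller iterates the chunks directly

print(next(iter(stream_with_yield())))         # <generator object _inner at ...>
print(next(iter(stream_with_return())).shape)  # (4,)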
@@ -294,7 +293,7 @@ class NeuTTSAir:
        output_str = output["choices"][0]["text"]
        return output_str

-    def _infer_stream_ggml(self, ref_codes: list[int], ref_text: str, input_text: str) -> Generator[np.ndarray]:
+    def _infer_stream_ggml(self, ref_codes: torch.Tensor, ref_text: str, input_text: str) -> Generator[np.ndarray, None, None]:
        ref_text = self._to_phones(ref_text)
        input_text = self._to_phones(input_text)
@@ -304,10 +303,10 @@ class NeuTTSAir:
            f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}"
        )

-        audio_cache = []
-        token_cache = ref_codes
-        n_decoded_samples = 0
-        n_decoded_tokens = len(ref_codes)
+        audio_cache: list[np.ndarray] = []
+        token_cache: list[str] = [f"<|speech_{idx}|>" for idx in ref_codes]
+        n_decoded_samples: int = 0
+        n_decoded_tokens: int = len(ref_codes)

        for item in self.backbone(
            prompt,
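To illustrate the `token_cache` change above: the reference codes are now stored as speech-token strings, so newly generated tokens can be appended and the whole window re-joined before decoding (indices below are hypothetical):

ref_codes = [17, 403, 92]  # hypothetical codec indices
token_cache = [f"<|speech_{idx}|>" for idx in ref_codes]
token_cache.append("<|speech_512|>")  # a newly generated token
print("".join(token_cache))
# <|speech_17|><|speech_403|><|speech_92|><|speech_512|>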
@@ -343,7 +342,7 @@ class NeuTTSAir:
                + (self.streaming_frames_per_chunk + 2 * self.streaming_overlap_frames) * self.hop_length
            )
            curr_codes = token_cache[tokens_start:tokens_end]
-            recon = self._decode(curr_codes)
+            recon = self._decode("".join(curr_codes))
            recon = self.watermarker.apply_watermark(recon, sample_rate=24_000)
            recon = recon[sample_start:sample_end]
            audio_cache.append(recon)
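The slice bounds above decode each chunk together with extra overlap frames on both sides and trim the result back afterwards. A rough worked example of that sample arithmetic, with every value below assumed purely for illustration:

# Every value here is assumed for illustration; the real ones come from the
# model and codec configuration.
hop_length = 480                 # samples produced per codec frame
streaming_frames_per_chunk = 25  # frames emitted per streamed chunk
streaming_overlap_frames = 5     # extra frames decoded on each side

chunk_samples = streaming_frames_per_chunk * hop_length
decoded_samples = (streaming_frames_per_chunk + 2 * streaming_overlap_frames) * hop_length
overlap_samples = streaming_overlap_frames * hop_length
print(chunk_samples, decoded_samples, overlap_samples)  # 12000 16800 2400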
@@ -375,7 +374,7 @@ class NeuTTSAir:
                - self.streaming_overlap_frames
            ) * self.hop_length
            curr_codes = token_cache[tokens_start:]
-            recon = self._decode(curr_codes)
+            recon = self._decode("".join(curr_codes))
            recon = self.watermarker.apply_watermark(recon, sample_rate=24_000)
            recon = recon[sample_start:]
            audio_cache.append(recon)