Refactor voice management and enhance TTS demo functionality

- Updated voice file retrieval to store voices in a local 'voices' directory, improving organization and accessibility.
- Implemented automatic creation of the voices directory if it doesn't exist, ensuring a smoother user experience.
- Enhanced the load_voice function to download a missing voice file into the local voices directory on demand, and changed its default voice from 'af' to 'af_bella' for better usability.
- Switched the tqdm import to tqdm.auto for improved compatibility with Windows consoles, and set tqdm.monitor_interval to 0 to disable the monitor thread and prevent encoding issues.
- Updated the default of the --voice command-line argument in tts_demo.py from 'af' to 'af_bella' to reflect the changes in voice management.
This commit is contained in:
Pierre Bruno
2025-01-16 16:40:24 +01:00
parent ff9c696065
commit 49379b98e5
2 changed files with 46 additions and 27 deletions

View File

@@ -3,30 +3,34 @@ import os
import sys
import platform
import glob
import warnings
from huggingface_hub import hf_hub_download, list_repo_files
import espeakng_loader
from phonemizer.backend.espeak.wrapper import EspeakWrapper
from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
# Filter out specific warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.nn.utils.weight_norm")
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.rnn")
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
__all__ = ['list_available_voices', 'build_model', 'load_voice', 'generate_speech']
def get_voices_path():
    """Get the path where voice files are stored.

    Voices live in a ``voices`` directory next to this module rather than in
    the platform-specific Hugging Face cache, so the location is the same on
    Windows, Linux and macOS.

    Returns:
        str: Path of the ``voices`` directory beside this file. The directory
            is not created here; callers create it when needed.
    """
    # Store voices in a 'voices' directory in the project root
    return str(Path(__file__).parent / "voices")
def list_available_voices():
"""List all available voices from the official voicepacks."""
voices_path = get_voices_path()
try:
# Create voices directory if it doesn't exist
os.makedirs(voices_path, exist_ok=True)
# Download voices if they don't exist
if not os.path.exists(voices_path):
if not any(f.endswith('.pt') for f in os.listdir(voices_path)):
print("Downloading voice files...")
repo_id = "hexgrad/Kokoro-82M"
repo_files = list_repo_files(repo_id)
@@ -35,25 +39,25 @@ def list_available_voices():
for voice_file in voice_files:
try:
voice_name = os.path.splitext(os.path.basename(voice_file))[0]
print(f"Downloading voice: {voice_name}")
hf_hub_download(
repo_id=repo_id,
filename=voice_file,
local_dir=str(Path(voices_path).parent.parent.parent)
)
target_path = os.path.join(voices_path, f"{voice_name}.pt")
if not os.path.exists(target_path):
print(f"Downloading voice: {voice_name}")
hf_hub_download(
repo_id=repo_id,
filename=voice_file,
local_dir=str(Path(voices_path).parent),
local_dir_use_symlinks=False
)
except Exception as e:
print(f"Error downloading voice {voice_file}: {e}")
# List available voice files
if os.path.exists(voices_path):
voices = [os.path.splitext(f)[0] for f in os.listdir(voices_path) if f.endswith('.pt')]
return sorted(voices)
voices = [os.path.splitext(f)[0] for f in os.listdir(voices_path) if f.endswith('.pt')]
return sorted(voices)
except Exception as e:
print(f"Error accessing voices: {e}")
# Fallback to default voices if everything fails
return ["af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael",
"bf_emma", "bf_isabella", "bm_george", "bm_lewis"]
return ["af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael",
"bf_emma", "bf_isabella", "bm_george", "bm_lewis"]
def get_platform_paths():
"""Get platform-specific paths for espeak-ng"""
@@ -168,11 +172,23 @@ def build_model(model_file, device='cpu'):
traceback.print_exc()
raise e
def load_voice(voice_name='af', device='cpu'):
"""Load a voice from the official voicepacks."""
def load_voice(voice_name='af_bella', device='cpu'):
"""Load a voice from the local voices directory."""
try:
repo_id = "hexgrad/Kokoro-82M"
voice_path = hf_hub_download(repo_id=repo_id, filename=f"voices/{voice_name}.pt")
voices_path = get_voices_path()
voice_path = os.path.join(voices_path, f"{voice_name}.pt")
# Download voice if it doesn't exist locally
if not os.path.exists(voice_path):
print(f"Downloading voice: {voice_name}")
repo_id = "hexgrad/Kokoro-82M"
hf_hub_download(
repo_id=repo_id,
filename=f"voices/{voice_name}.pt",
local_dir=str(Path(voices_path).parent),
local_dir_use_symlinks=False
)
voice = torch.load(voice_path, weights_only=True).to(device)
print(f"Loaded voice: {voice_name}")
return voice

View File

@@ -2,7 +2,7 @@ import torch
from typing import Optional, Tuple, List
from models import build_model, load_voice, generate_speech, list_available_voices
import argparse
from tqdm import tqdm
from tqdm.auto import tqdm
import soundfile as sf
from pathlib import Path
@@ -13,6 +13,9 @@ DEFAULT_OUTPUT_FILE = 'output.wav'
DEFAULT_LANGUAGE = 'a' # TODO: Document why this is 'a' or make configurable
DEFAULT_TEXT = "Hello, welcome to this text-to-speech test."
# Configure tqdm for better Windows console support
tqdm.monitor_interval = 0 # Disable monitor thread to prevent encoding issues
def load_and_validate_voice(voice_name: str, device: str) -> torch.Tensor:
"""Load and validate the requested voice.
@@ -36,7 +39,7 @@ def main() -> None:
# Parse command line arguments
parser = argparse.ArgumentParser(description='Kokoro TTS Demo')
parser.add_argument('--text', type=str, help='Text to synthesize (optional)')
parser.add_argument('--voice', type=str, default='af', help='Voice to use (default: af)')
parser.add_argument('--voice', type=str, default='af_bella', help='Voice to use (default: af_bella)')
parser.add_argument('--list-voices', action='store_true', help='List all available voices')
parser.add_argument('--model', type=str, default=DEFAULT_MODEL_PATH, help=f'Path to model file (default: {DEFAULT_MODEL_PATH})')
parser.add_argument('--output', type=str, default=DEFAULT_OUTPUT_FILE, help=f'Output WAV file (default: {DEFAULT_OUTPUT_FILE})')