Merge pull request #53 from huggingface/add-paraformer
Add paraformer - Chinese STT
STT/paraformer_handler.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import logging
from time import perf_counter

from baseHandler import BaseHandler
from funasr import AutoModel
import numpy as np
from rich.console import Console
import torch

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class ParaformerSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Paraformer model.
    The default for this model is set to Chinese.
    This model was contributed by @wuhongsheng.
    """

    def setup(
        self,
        model_name="paraformer-zh",
        device="cuda",
        gen_kwargs={},
    ):
        logger.info(f"Loading Paraformer model '{model_name}'")
        # funasr expects the bare model name, so strip any namespace prefix (e.g. "org/paraformer-zh").
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = AutoModel(model=model_name, device=device)
        self.warmup()

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # A single dummy pass is enough to warm this model up (no compilation involved).
        n_steps = 1
        dummy_input = np.array([0] * 512, dtype=np.float32)
        for _ in range(n_steps):
            _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")

    def process(self, spoken_prompt):
        logger.debug("inferring paraformer...")

        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = (
            self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
        )
        if self.device == "mps":
            # Only clear the MPS cache when actually running on Apple Silicon.
            torch.mps.empty_cache()

        logger.debug("finished paraformer inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text

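A quick standalone check of the same funasr calls the handler relies on (a hypothetical smoke test, not part of this commit; assumes a 16 kHz mono float32 buffer as input):

import numpy as np
from funasr import AutoModel

model = AutoModel(model="paraformer-zh", device="cuda")
# Silence yields an empty transcript; feed real 16 kHz mono speech for a meaningful result.
audio = np.zeros(16000, dtype=np.float32)
text = model.generate(audio)[0]["text"].strip().replace(" ", "")
print(text)
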
@@ -23,7 +23,7 @@ class ModuleArguments:
    stt: Optional[str] = field(
        default="whisper",
        metadata={
-            "help": "The STT to use. Either 'whisper' or 'whisper-mlx'. Default is 'whisper'."
+            "help": "The STT to use. Either 'whisper', 'whisper-mlx' or 'paraformer'. Default is 'whisper'."
        },
    )
    llm: Optional[str] = field(
arguments_classes/paraformer_stt_arguments.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from dataclasses import dataclass, field


@dataclass
class ParaformerSTTHandlerArguments:
    paraformer_stt_model_name: str = field(
        default="paraformer-zh",
        metadata={
            "help": "The pretrained model to use. Default is 'paraformer-zh'. Other models can be chosen from https://github.com/modelscope/FunASR."
        },
    )
    paraformer_stt_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        },
    )
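If these fields follow the same prefix-stripping convention the pipeline already uses for the whisper_stt_* arguments (an assumption about prepare_args, whose body is not shown in this diff), they would reach ParaformerSTTHandler.setup roughly like this:

# Hypothetical sketch of the prefix-stripping convention; not the repository's actual prepare_args code.
prefix = "paraformer_stt_"
cli_args = {"paraformer_stt_model_name": "paraformer-zh", "paraformer_stt_device": "cuda"}
setup_kwargs = {k[len(prefix):]: v for k, v in cli_args.items() if k.startswith(prefix)}
# -> {"model_name": "paraformer-zh", "device": "cuda"}, matching setup()'s parameter names
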
@@ -2,4 +2,6 @@ nltk==3.9.1
parler_tts @ git+https://github.com/huggingface/parler-tts.git
melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a copy of MeloTTS to have compatible versions of transformers
torch==2.4.0
sounddevice==0.5.0
+funasr
+modelscope
@@ -4,4 +4,6 @@ melotts @ git+https://github.com/andimarafioti/MeloTTS.git#egg=MeloTTS # made a
torch==2.4.0
sounddevice==0.5.0
lightning-whisper-mlx>=0.0.10
mlx-lm>=0.14.0
+funasr>=1.1.6
+modelscope>=1.17.1
@@ -13,6 +13,7 @@ from arguments_classes.mlx_language_model_arguments import (
    MLXLanguageModelHandlerArguments,
)
from arguments_classes.module_arguments import ModuleArguments
+from arguments_classes.paraformer_stt_arguments import ParaformerSTTHandlerArguments
from arguments_classes.parler_tts_arguments import ParlerTTSHandlerArguments
from arguments_classes.socket_receiver_arguments import SocketReceiverArguments
from arguments_classes.socket_sender_arguments import SocketSenderArguments
@@ -73,6 +74,7 @@ def main():
            SocketSenderArguments,
            VADHandlerArguments,
            WhisperSTTHandlerArguments,
+            ParaformerSTTHandlerArguments,
            LanguageModelHandlerArguments,
            MLXLanguageModelHandlerArguments,
            ParlerTTSHandlerArguments,
@@ -89,6 +91,7 @@ def main():
        socket_sender_kwargs,
        vad_handler_kwargs,
        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
@@ -102,6 +105,7 @@ def main():
        socket_sender_kwargs,
        vad_handler_kwargs,
        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
        language_model_handler_kwargs,
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
@@ -163,6 +167,8 @@ def main():
            kwargs.tts_device = common_device
        if hasattr(kwargs, "stt_device"):
            kwargs.stt_device = common_device
+        if hasattr(kwargs, "paraformer_stt_device"):
+            kwargs.paraformer_stt_device = common_device

    # Call this function with the common device and all the handlers
    overwrite_device_argument(
@@ -171,9 +177,11 @@ def main():
        mlx_language_model_handler_kwargs,
        parler_tts_handler_kwargs,
        whisper_stt_handler_kwargs,
+        paraformer_stt_handler_kwargs,
    )

    prepare_args(whisper_stt_handler_kwargs, "stt")
+    prepare_args(paraformer_stt_handler_kwargs, "paraformer_stt")
    prepare_args(language_model_handler_kwargs, "lm")
    prepare_args(mlx_language_model_handler_kwargs, "mlx_lm")
    prepare_args(parler_tts_handler_kwargs, "tts")
@@ -243,8 +251,19 @@ def main():
            queue_out=text_prompt_queue,
            setup_kwargs=vars(whisper_stt_handler_kwargs),
        )
+    elif module_kwargs.stt == "paraformer":
+        from STT.paraformer_handler import ParaformerSTTHandler
+
+        stt = ParaformerSTTHandler(
+            stop_event,
+            queue_in=spoken_prompt_queue,
+            queue_out=text_prompt_queue,
+            setup_kwargs=vars(paraformer_stt_handler_kwargs),
+        )
    else:
-        raise ValueError("The STT should be either whisper or whisper-mlx")
+        raise ValueError(
+            "The STT should be either whisper, whisper-mlx, or paraformer."
+        )
    if module_kwargs.llm == "transformers":
        from LLM.language_model import LanguageModelHandler
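Putting it together, the new backend would presumably be selected with the --stt flag plus the prefixed handler options (illustrative invocation based on the dataclass field names above, not a command taken from this commit):

python s2s_pipeline.py --stt paraformer --paraformer_stt_model_name paraformer-zh --paraformer_stt_device cuda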