mirror of https://github.com/oliverguhr/wav2vec2-live.git
synced 2021-10-25 02:25:02 +03:00

upgraded output format
@@ -1,10 +1,13 @@
import collections, queue
from wav2vec2_inference import Wave2Vec2Inference
import numpy as np
import pyaudio
import webrtcvad
from halo import Halo
import torch
import torchaudio
from rx.subject import BehaviorSubject
import time

class Audio(object):
    """Streams raw audio from the microphone. Data is received in a separate thread and stored in a buffer to be read from."""
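The docstring describes PyAudio's callback mode: the capture thread pushes frames into a queue that the consumer drains at its own pace. A minimal sketch of that pattern (parameter values are illustrative, not necessarily the ones this class uses):

import queue
import pyaudio

frames = queue.Queue()

def _callback(in_data, frame_count, time_info, status):
    frames.put(in_data)                  # runs on PyAudio's capture thread
    return None, pyaudio.paContinue

pa = pyaudio.PyAudio()
stream = pa.open(format=pyaudio.paInt16, channels=1, rate=16000,
                 input=True, frames_per_buffer=320, stream_callback=_callback)
chunk = frames.get()                     # consumer blocks until audio arrives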
@@ -109,6 +112,12 @@ class VADAudio(Audio):
            ring_buffer.clear()

def main(ARGS):
    model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-german"

    wave_buffer = BehaviorSubject(np.array([]))
    wave2vec_asr = Wave2Vec2Inference(model_name)
    wave_buffer.subscribe(on_next=lambda x: asr_output_formatter(wave2vec_asr, x))

    # Start audio with VAD
    vad_audio = VADAudio(aggressiveness=ARGS.webRTC_aggressiveness,
                         device=ARGS.device,
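The new output path is reactive: main pushes each detected utterance into wave_buffer with on_next, and the subscription runs asr_output_formatter on it. Note that a BehaviorSubject replays its latest value on subscribe, so the lambda fires once with the empty seed array; buffer_to_text's empty-buffer guard (further down) makes that first call a no-op. A minimal sketch (illustrative values):

import numpy as np
from rx.subject import BehaviorSubject

wave_buffer = BehaviorSubject(np.array([]))              # seeded with an empty buffer
wave_buffer.subscribe(on_next=lambda x: print(len(x)))   # prints 0 immediately (the seed)
wave_buffer.on_next(np.zeros(16000))                     # then prints 16000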
@@ -125,7 +134,8 @@ def main(ARGS):
    (get_speech_ts, _, _, _, _, _, _) = utils

    # Stream from microphone to DeepSpeech using VAD
    # Stream from microphone to Wav2Vec 2.0 using VAD
    print("audio length\tinference time\ttext")
    spinner = None
    if not ARGS.nospinner:
        spinner = Halo(spinner='line')
@@ -137,23 +147,34 @@ def main(ARGS):
            wav_data.extend(frame)
        else:
            if spinner: spinner.stop()
            print("webRTC has detected a possible speech")
            #print("webRTC has detected a possible speech")

            newsound = np.frombuffer(wav_data, np.int16)
            audio_float32 = Int2Float(newsound)
            audio_float32 = Int2FloatSimple(newsound)
            time_stamps = get_speech_ts(audio_float32, model,
                                        num_steps=ARGS.num_steps, trig_sum=ARGS.trig_sum, neg_trig_sum=ARGS.neg_trig_sum,
                                        num_samples_per_window=ARGS.num_samples_per_window, min_speech_samples=ARGS.min_speech_samples,
                                        min_silence_samples=ARGS.min_silence_samples)
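            # hedged note: silero's get_speech_ts returns a list of sample-index
            # spans, e.g. [{'start': 1024, 'end': 48000}]; an empty list means
            # the chunk webRTC flagged contained no actual speech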
            if len(time_stamps) > 0:
                print("silero VAD has detected a possible speech")
                #print("silero VAD has detected a possible speech")
                #float64_buffer = np.frombuffer(wav_data, dtype=np.int16) / 32767 -> hacky version
                wave_buffer.on_next(audio_float32.numpy())
            else:
                print("silero VAD has detected a noise")
                print()
                print("VAD detected noise")
            wav_data = bytearray()
def asr_output_formatter(asr, audio):
    start = time.perf_counter()
    text = asr.buffer_to_text(audio)
    inference_time = time.perf_counter() - start
    sample_length = len(audio) / DEFAULT_SAMPLE_RATE
    print(f"{sample_length:.3f}s\t{inference_time:.3f}s\t{text}")
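The formatter prints one tab-separated row per utterance, matching the "audio length\tinference time\ttext" header emitted at startup. Illustrative output (values invented for this sketch):

# audio length    inference time    text
# 0.960s          0.412s            guten tag
# 2.133s          0.887s            wie geht es dir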
def Int2FloatSimple(sound):
    return torch.from_numpy(np.frombuffer(sound, dtype=np.int16).astype('float32') / 32767)
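Int2FloatSimple maps raw int16 PCM bytes onto float32 in roughly [-1, 1] by dividing by the int16 full scale. A quick sanity check (a minimal sketch, assuming the buffer holds native-endian int16 samples):

import numpy as np
import torch

pcm = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
t = torch.from_numpy(np.frombuffer(pcm, dtype=np.int16).astype('float32') / 32767)
print(t)  # tensor([ 0.0000,  0.5000, -1.0000]) up to rounding; -32768/32767 dips just below -1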
def Int2Float(sound):
    """Converts the format and normalizes the data."""
    _sound = np.copy(sound)
    abs_max = np.abs(_sound).max()
    _sound = _sound.astype('float32')
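The hunk cuts Int2Float off here, but the abs_max it computes suggests per-buffer peak normalization; the commit's switch to Int2FloatSimple trades that for the fixed 32767 divisor, so a mostly-silent buffer is no longer amplified to full scale. A hedged comparison (Int2Float's tail is assumed, not shown in the diff):

# Int2Float (assumed):  _sound /= abs_max   -> gain depends on the buffer's own peak
# Int2FloatSimple:      sound / 32767       -> fixed int16 full scale, level-preserving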
wav2vec2_inference.py
@@ -13,6 +13,9 @@ class Wave2Vec2Inference():
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

    def buffer_to_text(self, audio_buffer):
        if len(audio_buffer) == 0:
            return ""

        inputs = self.processor([audio_buffer], sampling_rate=16_000, return_tensors="pt", padding=True)

        with torch.no_grad():
@@ -20,7 +23,7 @@ class Wave2Vec2Inference():
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription
        return transcription.lower()

    def file_to_text(self, filename):
        audio_input, samplerate = sf.read(filename)
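buffer_to_text is greedy CTC decoding: argmax picks the most likely token per audio frame, and batch_decode collapses repeats and strips the blank token; the commit additionally lower-cases the result. A minimal standalone sketch of the same path (model name taken from the test below; `audio` is assumed to be a 16 kHz mono float32 array):

import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("maxidl/wav2vec2-large-xlsr-german")
model = Wav2Vec2ForCTC.from_pretrained("maxidl/wav2vec2-large-xlsr-german")

inputs = processor(audio, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values).logits       # (batch, frames, vocab)
predicted_ids = torch.argmax(logits, dim=-1)         # greedy: best token per frame
text = processor.batch_decode(predicted_ids)[0].lower()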
@@ -30,5 +33,5 @@ class Wave2Vec2Inference():
if __name__ == "__main__":
    print("Model test")
    asr = Wave2Vec2Inference("maxidl/wav2vec2-large-xlsr-german")
    text = asr.file_to_text("some.wav")
    text = asr.file_to_text("test.wav")
    print(text)