upgraded output format

Oliver Guhr
2021-07-08 16:52:58 +02:00
parent 6ad598cbfe
commit ace16f99b2
2 changed files with 32 additions and 8 deletions

View File

@@ -1,10 +1,13 @@
 import collections, queue
 from wav2vec2_inference import Wave2Vec2Inference
 import numpy as np
 import pyaudio
 import webrtcvad
 from halo import Halo
 import torch
 import torchaudio
+from rx.subject import BehaviorSubject
+import time

 class Audio(object):
     """Streams raw audio from microphone. Data is received in a separate thread, and stored in a buffer, to be read from."""
@@ -109,6 +112,12 @@ class VADAudio(Audio):
                     ring_buffer.clear()

 def main(ARGS):
     model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-german"
+    wave_buffer = BehaviorSubject(np.array([]))
     wave2vec_asr = Wave2Vec2Inference(model_name)
+    wave_buffer.subscribe(on_next=lambda x: asr_output_formatter(wave2vec_asr,x))

     # Start audio with VAD
     vad_audio = VADAudio(aggressiveness=ARGS.webRTC_aggressiveness,
                          device=ARGS.device,
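
The interesting change in this hunk is the RxPY wiring: detected speech segments are pushed into a BehaviorSubject, and the subscribed formatter runs ASR decoupled from the VAD loop. A toy demonstration of that behavior (RxPY 3.x, which provides rx.subject):

import numpy as np
from rx.subject import BehaviorSubject

wave_buffer = BehaviorSubject(np.array([]))  # seeded with an empty buffer

# A BehaviorSubject replays its current value, so this callback fires once
# immediately with the empty seed, then again on every on_next().
wave_buffer.subscribe(on_next=lambda x: print(f"received {len(x)} samples"))

wave_buffer.on_next(np.zeros(16000, dtype=np.float32))  # one second of silence

That immediate replay of the empty seed array is presumably why this commit also adds the len(audio_buffer)==0 guard to buffer_to_text in the second file.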
@@ -125,7 +134,8 @@ def main(ARGS):
     (get_speech_ts,_,_, _,_, _, _) = utils

-    # Stream from microphone to DeepSpeech using VAD
+    # Stream from microphone to Wav2Vec 2.0 using VAD
+    print("audio length\tinference time\ttext")
     spinner = None
     if not ARGS.nospinner:
         spinner = Halo(spinner='line')
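
The get_speech_ts unpacking above matches the early silero-vad torch.hub API, which returned the model plus a 7-tuple of utilities. Loading typically looks like this (the exact silero-vad version this project pins is not shown, so treat the call as an assumption):

import torch

# downloads and caches the silero VAD model via torch.hub
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=False)
(get_speech_ts, _, _, _, _, _, _) = utils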
@@ -137,23 +147,34 @@ def main(ARGS):
             wav_data.extend(frame)
         else:
             if spinner: spinner.stop()
-            print("webRTC has detected a possible speech")
+            #print("webRTC has detected a possible speech")
             newsound= np.frombuffer(wav_data,np.int16)
-            audio_float32=Int2Float(newsound)
+            audio_float32=Int2FloatSimple(newsound)
             time_stamps =get_speech_ts(audio_float32, model,num_steps=ARGS.num_steps,trig_sum=ARGS.trig_sum,neg_trig_sum=ARGS.neg_trig_sum,
                                        num_samples_per_window=ARGS.num_samples_per_window,min_speech_samples=ARGS.min_speech_samples,
                                        min_silence_samples=ARGS.min_silence_samples)
             if(len(time_stamps)>0):
-                print("silero VAD has detected a possible speech")
+                #print("silero VAD has detected a possible speech")
+                #float64_buffer = np.frombuffer(wav_data, dtype=np.int16) / 32767 -> hacky version
+                wave_buffer.on_next(audio_float32.numpy())
             else:
-                print("silero VAD has detected a noise")
-                print()
+                print("VAD detected noise")
             wav_data = bytearray()

+def asr_output_formatter(asr,audio):
+    start = time.perf_counter()
+    text = asr.buffer_to_text(audio)
+    inference_time = time.perf_counter()-start
+    sample_length = len(audio) / DEFAULT_SAMPLE_RATE
+    print(f"{sample_length:.3f}s\t{inference_time:.3f}s\t{text}")

+def Int2FloatSimple(sound):
+    return torch.from_numpy(np.frombuffer(sound, dtype=np.int16).astype('float32') / 32767)

 def Int2Float(sound):
     """converts the format and normalizes the data"""
     _sound = np.copy(sound)  #
     abs_max = np.abs(_sound).max()
     _sound = _sound.astype('float32')

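The hunk cuts Int2Float off mid-body. For comparison with the new Int2FloatSimple, a typical completion of this peak-normalizing variant (as seen in silero-vad's examples; the tail below is reconstructed, not part of the diff):

import numpy as np
import torch

def Int2Float(sound):
    """Converts int16 samples to float32, scaling the loudest sample to +/-1.0."""
    _sound = np.copy(sound)
    abs_max = np.abs(_sound).max()
    _sound = _sound.astype('float32')
    if abs_max > 0:
        _sound *= 1 / abs_max  # peak normalization
    return torch.from_numpy(_sound.squeeze())

Int2FloatSimple instead divides by the fixed int16 full scale (32767), so quiet recordings stay quiet rather than being amplified to full scale. Note also that asr_output_formatter relies on a DEFAULT_SAMPLE_RATE constant defined elsewhere in this file (16000 in the DeepSpeech mic-streaming example this script is derived from).
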
View File: wav2vec2_inference.py

@@ -13,6 +13,9 @@ class Wave2Vec2Inference():
         self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

     def buffer_to_text(self,audio_buffer):
+        if(len(audio_buffer)==0):
+            return ""
+
         inputs = self.processor([audio_buffer], sampling_rate=16_000, return_tensors="pt", padding=True)

         with torch.no_grad():
@@ -20,7 +23,7 @@ class Wave2Vec2Inference():
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = self.processor.batch_decode(predicted_ids)[0]
-        return transcription
+        return transcription.lower()

     def file_to_text(self,filename):
         audio_input, samplerate = sf.read(filename)
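
Putting this hunk together with the empty-buffer guard above, the whole inference path is short enough to sketch end to end (using transformers' Wav2Vec2Processor/Wav2Vec2ForCTC as the class does; the attention-mask detail is an assumption about the elided lines):

import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

model_name = "maxidl/wav2vec2-large-xlsr-german"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

def buffer_to_text(audio_buffer) -> str:
    if len(audio_buffer) == 0:
        return ""  # the guard added above: skip empty buffers
    inputs = processor([audio_buffer], sampling_rate=16_000,
                       return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values,
                       attention_mask=inputs.attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decode
    return processor.batch_decode(predicted_ids)[0].lower()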
@@ -30,5 +33,5 @@ class Wave2Vec2Inference():
 if __name__ == "__main__":
     print("Model test")
     asr = Wave2Vec2Inference("maxidl/wav2vec2-large-xlsr-german")
-    text = asr.file_to_text("some.wav")
+    text = asr.file_to_text("test.wav")
     print(text)
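
One caveat for this smoke test: the processor is called with sampling_rate=16_000, so test.wav should be 16 kHz mono. If it isn't, resampling first avoids nonsense transcriptions; a quick way with torchaudio (an illustration, not part of the repo):

import torchaudio

waveform, sr = torchaudio.load("test.wav")  # -> (channels, samples) tensor
if sr != 16000:
    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
mono = waveform.mean(dim=0).numpy()  # downmix to a mono numpy buffer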