initial commit

This commit is contained in:
Oliver Guhr
2021-04-15 13:34:40 +02:00
parent 7e6b0706bf
commit 8248ff0966
5 changed files with 226 additions and 1 deletion

README.md

@@ -1,2 +1,57 @@
# wav2vec-live
Automatic speech recognition with wav2vec2: use any wav2vec2 model with a microphone.
![demo gif](./docs/wav2veclive.gif)
## Setup
I recommend installing this project in a virtual environment.
```
python3 -m venv ./venv
source ./venv/bin/activate
pip install -r requirements.txt
```
Depending on your Linux distribution, you might encounter an **error that portaudio was not found** when installing pyaudio. On Ubuntu you can fix this by installing the `portaudio19-dev` package:
```
sudo apt install portaudio19-dev
```
Finally, you can test the speech recognition:
```
python live_asr.py
```
### Possible Issues:
* The code uses the system's default audio device. Please make sure your system's default audio device is set correctly, or pass a different `device_name` to `LiveWav2Vec2` (see the snippet below for listing available input devices).
* "*attempt to connect to server failed*": you can safely ignore this message from pyaudio. It just means that pyaudio can't connect to the JACK audio server.
## Usage
You can use any **wav2vec2** model from the [huggingface model hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&search=wav2vec2). Just set the model name; all required files will be downloaded on the first run.
```python
from live_asr import LiveWav2Vec2

english_model = "facebook/wav2vec2-large-960h-lv60-self"
german_model = "maxidl/wav2vec2-large-xlsr-german"

asr = LiveWav2Vec2(german_model, device_name="default")
asr.start()

try:
    while True:
        text, sample_length, inference_time = asr.get_last_text()
        print(f"{sample_length:.3f}s"
              + f"\t{inference_time:.3f}s"
              + f"\t{text}")
except KeyboardInterrupt:
    asr.stop()
```
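
If you only want to transcribe an existing recording instead of the live microphone stream, you can use the `Wave2Vec2Inference` class from `wav2vec2_inference.py` directly. Note that `file_to_text` expects the wav file to be sampled at 16 kHz:
```python
from wav2vec2_inference import Wave2Vec2Inference

asr = Wave2Vec2Inference("maxidl/wav2vec2-large-xlsr-german")
text = asr.file_to_text("some.wav")  # 16 kHz wav file
print(text)
```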

BIN
docs/wav2veclive.gif Normal file (binary, 2.9 MiB)


131
live_asr.py Normal file

@@ -0,0 +1,131 @@
import pyaudio
import webrtcvad
from wav2vec2_inference import Wave2Vec2Inference
import numpy as np
from multiprocessing import Process, Queue, Event
import time
from sys import exit


class LiveWav2Vec2():
    def __init__(self, model_name, device_name="default"):
        self.model_name = model_name
        self.device_name = device_name
        self.asr_output_queue = Queue()
        self.asr_input_queue = Queue()
        # shared flag used to signal the worker processes to stop
        self.is_listening = Event()

    def stop(self):
        """stop the asr process"""
        self.is_listening.clear()
        self.asr_input_queue.close()
        self.asr_output_queue.close()
        print("asr stopped")

    def start(self):
        """start the asr process"""
        self.is_listening.set()
        self.asr_process = Process(target=LiveWav2Vec2.asr_process, args=(
            self.model_name, self.is_listening, self.asr_input_queue, self.asr_output_queue,))
        self.asr_process.daemon = True
        self.asr_process.start()
        time.sleep(5)  # start vad after the asr model is loaded
        self.vad_process = Process(target=LiveWav2Vec2.vad_process, args=(
            self.device_name, self.is_listening, self.asr_input_queue,))
        self.vad_process.daemon = True
        self.vad_process.start()

    @staticmethod
    def vad_process(device_name, is_listening, asr_input_queue):
        vad = webrtcvad.Vad()
        vad.set_mode(1)

        audio = pyaudio.PyAudio()
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 16000
        # a frame must be either 10, 20, or 30 ms in duration for webrtcvad
        FRAME_DURATION = 30
        CHUNK = int(RATE * FRAME_DURATION / 1000)

        microphones = LiveWav2Vec2.list_microphones(audio)
        selected_input_device_id = LiveWav2Vec2.get_input_device_id(
            device_name, microphones)

        stream = audio.open(input_device_index=selected_input_device_id,
                            format=FORMAT,
                            channels=CHANNELS,
                            rate=RATE,
                            input=True,
                            frames_per_buffer=CHUNK)

        frames = b''
        while is_listening.is_set():
            frame = stream.read(CHUNK)
            if vad.is_speech(frame, RATE):
                frames += frame
            else:
                # end of an utterance: hand the collected frames to the asr process
                if len(frames) > 1:
                    asr_input_queue.put(frames)
                frames = b''

        stream.stop_stream()
        stream.close()
        audio.terminate()

    @staticmethod
    def asr_process(model_name, is_listening, in_queue, output_queue):
        wave2vec_asr = Wave2Vec2Inference(model_name)

        print("\nlistening to your voice\n")
        while is_listening.is_set():
            audio_frames = in_queue.get()
            # normalize the int16 byte buffer to floats in [-1, 1]
            float64_buffer = np.frombuffer(
                audio_frames, dtype=np.int16) / 32767

            start = time.perf_counter()
            text = wave2vec_asr.buffer_to_text(float64_buffer).lower()
            inference_time = time.perf_counter() - start
            sample_length = len(float64_buffer) / 16000  # length in seconds
            if text != "":
                output_queue.put([text, sample_length, inference_time])

    @staticmethod
    def get_input_device_id(device_name, microphones):
        for device in microphones:
            if device_name in device[1]:
                return device[0]

    @staticmethod
    def list_microphones(pyaudio_instance):
        info = pyaudio_instance.get_host_api_info_by_index(0)
        numdevices = info.get('deviceCount')

        result = []
        for i in range(0, numdevices):
            if (pyaudio_instance.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
                name = pyaudio_instance.get_device_info_by_host_api_device_index(
                    0, i).get('name')
                result += [[i, name]]
        return result

    def get_last_text(self):
        """returns the text, sample length and inference time in seconds."""
        return self.asr_output_queue.get()


if __name__ == "__main__":
    print("Live ASR")

    asr = LiveWav2Vec2("facebook/wav2vec2-large-960h-lv60-self")
    asr.start()

    try:
        while True:
            text, sample_length, inference_time = asr.get_last_text()
            print(f"{sample_length:.3f}s\t{inference_time:.3f}s\t{text}")
    except KeyboardInterrupt:
        print("stopping")
        asr.stop()
        exit()

5
requirements.txt Normal file

@@ -0,0 +1,5 @@
soundfile
torch
transformers
pyaudio
webrtcvad

34
wav2vec2_inference.py Normal file

@@ -0,0 +1,34 @@
import soundfile as sf
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Improvements:
# - gpu / cpu flag
# - convert non 16 khz sample rates
# - inference time log


class Wave2Vec2Inference():
    def __init__(self, model_name):
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

    def buffer_to_text(self, audio_buffer):
        inputs = self.processor([audio_buffer], sampling_rate=16_000,
                                return_tensors="pt", padding=True)

        # not every wav2vec2 feature extractor returns an attention mask
        attention_mask = inputs.attention_mask if "attention_mask" in inputs else None

        with torch.no_grad():
            logits = self.model(inputs.input_values,
                                attention_mask=attention_mask).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription

    def file_to_text(self, filename):
        audio_input, samplerate = sf.read(filename)
        assert samplerate == 16000
        return self.buffer_to_text(audio_input)


if __name__ == "__main__":
    print("Model test")
    asr = Wave2Vec2Inference("maxidl/wav2vec2-large-xlsr-german")
    text = asr.file_to_text("some.wav")
    print(text)