mirror of https://github.com/oliverguhr/wav2vec2-live.git

initial commit
README.md
@@ -1,2 +1,57 @@
# wav2vec-live

# automatic speech recognition with wav2vec2

Use any wav2vec2 model with a microphone.

![live demo](docs/wav2veclive.gif)

## Setup

I recommend installing this project in a virtual environment.

```
python3 -m venv ./venv
source ./venv/bin/activate
pip install -r requirements.txt
```

Depending on your Linux distribution, you might encounter an **error that portaudio was not found** when installing PyAudio. On Ubuntu you can solve that issue by installing the "portaudio19-dev" package.

```
sudo apt install portaudio19-dev
```

Finally you can test the speech recognition:

```
python live_asr.py
```

### Possible Issues:

* The code uses the system's default audio device. Please make sure that your system's default audio device is set correctly (a device-listing sketch follows this list).

* "*attempt to connect to server failed*": you can safely ignore this message from PyAudio. It just means that PyAudio can't connect to the JACK audio server.

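If you are not sure which audio device is being used, you can print every input device PyAudio sees with the `list_microphones` helper from `live_asr.py`. A minimal sketch:

```python
# list all available input devices as (id, name) pairs,
# using the list_microphones helper from live_asr.py
import pyaudio
from live_asr import LiveWav2Vec2

audio = pyaudio.PyAudio()
for device_id, name in LiveWav2Vec2.list_microphones(audio):
    print(device_id, name)
audio.terminate()
```

You can then pass a matching name via the `device_name` parameter of `LiveWav2Vec2`.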
## Usage

You can use any **wav2vec2** model from the [huggingface model hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&search=wav2vec2). Just set the model name; all required files will be downloaded on first execution.

```python
from live_asr import LiveWav2Vec2

english_model = "facebook/wav2vec2-large-960h-lv60-self"
german_model = "maxidl/wav2vec2-large-xlsr-german"
asr = LiveWav2Vec2(german_model, device_name="default")
asr.start()

try:
    while True:
        text, sample_length, inference_time = asr.get_last_text()
        print(f"{sample_length:.3f}s"
              + f"\t{inference_time:.3f}s"
              + f"\t{text}")

except KeyboardInterrupt:
    asr.stop()
```
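If you only want to transcribe an existing recording instead of the live microphone stream, you can use the `Wave2Vec2Inference` class from `wav2vec2_inference.py` directly. A sketch, assuming a 16 kHz mono wav file (`some.wav` is a placeholder path):

```python
from wav2vec2_inference import Wave2Vec2Inference

# file_to_text asserts that the input file is sampled at 16 kHz
asr = Wave2Vec2Inference("maxidl/wav2vec2-large-xlsr-german")
print(asr.file_to_text("some.wav"))
```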
docs/wav2veclive.gif
BIN (new binary file, 2.9 MiB; not shown)
live_asr.py
@@ -0,0 +1,131 @@
import pyaudio
import webrtcvad
from wav2vec2_inference import Wave2Vec2Inference
import numpy as np
from multiprocessing import Process, Queue
import time
from sys import exit
import contextvars


class LiveWav2Vec2():
    # flag that keeps the worker loops running
    is_listening = contextvars.ContextVar('global_is_listening_state', default=None)

    def __init__(self, model_name, device_name="default"):
        self.model_name = model_name
        self.device_name = device_name
        self.asr_output_queue = Queue()
        self.asr_input_queue = Queue()

    def stop(self):
        """stop the asr process"""
        LiveWav2Vec2.is_listening.set(False)
        self.asr_input_queue.close()
        self.asr_output_queue.close()
        print("asr stopped")

    def start(self):
        """start the asr process"""
        LiveWav2Vec2.is_listening.set(True)
        self.asr_process = Process(target=LiveWav2Vec2.asr_process, args=(
            self.model_name, self.asr_input_queue, self.asr_output_queue,))
        self.asr_process.daemon = True
        self.asr_process.start()
        time.sleep(5)  # start vad after asr model is loaded
        self.vad_process = Process(target=LiveWav2Vec2.vad_process, args=(
            self.device_name, self.asr_input_queue,))
        self.vad_process.daemon = True
        self.vad_process.start()

    @staticmethod
    def vad_process(device_name, asr_input_queue):
        """record from the microphone and forward voiced audio to the asr process"""
        vad = webrtcvad.Vad()
        vad.set_mode(1)

        audio = pyaudio.PyAudio()
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 16000
        # a frame must be either 10, 20, or 30 ms in duration for webrtcvad
        FRAME_DURATION = 30
        CHUNK = int(RATE * FRAME_DURATION / 1000)  # 480 samples per frame at 16 kHz

        microphones = LiveWav2Vec2.list_microphones(audio)
        selected_input_device_id = LiveWav2Vec2.get_input_device_id(
            device_name, microphones)

        stream = audio.open(input_device_index=selected_input_device_id,
                            format=FORMAT,
                            channels=CHANNELS,
                            rate=RATE,
                            input=True,
                            frames_per_buffer=CHUNK)

        frames = b''
        while LiveWav2Vec2.is_listening.get():
            frame = stream.read(CHUNK)
            is_speech = vad.is_speech(frame, RATE)
            if is_speech:
                frames += frame
            else:
                # silence ends an utterance: hand the buffered audio to the asr process
                if len(frames) > 1:
                    asr_input_queue.put(frames)
                frames = b''
        stream.stop_stream()
        stream.close()
        audio.terminate()

    @staticmethod
    def asr_process(model_name, in_queue, output_queue):
        wave2vec_asr = Wave2Vec2Inference(model_name)

        print("\nlistening to your voice\n")
        while LiveWav2Vec2.is_listening.get():
            audio_frames = in_queue.get()
            # convert the 16-bit pcm bytes to a float buffer in [-1, 1]
            float64_buffer = np.frombuffer(
                audio_frames, dtype=np.int16) / 32767
            start = time.perf_counter()
            text = wave2vec_asr.buffer_to_text(float64_buffer).lower()
            inference_time = time.perf_counter() - start
            sample_length = len(float64_buffer) / 16000  # length in sec
            if text != "":
                output_queue.put([text, sample_length, inference_time])

    @staticmethod
    def get_input_device_id(device_name, microphones):
        """return the id of the first input device whose name contains device_name"""
        for device in microphones:
            if device_name in device[1]:
                return device[0]

    @staticmethod
    def list_microphones(pyaudio_instance):
        """list all available input devices as [id, name] pairs"""
        info = pyaudio_instance.get_host_api_info_by_index(0)
        numdevices = info.get('deviceCount')

        result = []
        for i in range(numdevices):
            if pyaudio_instance.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels') > 0:
                name = pyaudio_instance.get_device_info_by_host_api_device_index(
                    0, i).get('name')
                result += [[i, name]]
        return result

    def get_last_text(self):
        """returns the text, sample length and inference time in seconds."""
        return self.asr_output_queue.get()


if __name__ == "__main__":
    print("Live ASR")

    asr = LiveWav2Vec2("facebook/wav2vec2-large-960h-lv60-self")
    asr.start()

    try:
        while True:
            text, sample_length, inference_time = asr.get_last_text()
            print(f"{sample_length:.3f}s\t{inference_time:.3f}s\t{text}")

    except KeyboardInterrupt:
        print("stopping")
        asr.stop()
        exit()
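To make the webrtcvad frame-duration rule in `vad_process` concrete: at 16 kHz, a 30 ms frame is 480 samples, i.e. 960 bytes of 16-bit PCM. A standalone check (a sketch, not part of the repo):

```python
import webrtcvad

vad = webrtcvad.Vad(1)                # same aggressiveness mode as vad_process
silence = b"\x00\x00" * 480           # 30 ms of 16-bit silence at 16 kHz
print(vad.is_speech(silence, 16000))  # -> False
```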
requirements.txt
@@ -0,0 +1,5 @@
soundfile
torch
transformers
pyaudio
webrtcvad
wav2vec2_inference.py
@@ -0,0 +1,34 @@
import soundfile as sf
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Improvements:
# - gpu / cpu flag
# - convert non 16 khz sample rates
# - inference time log


class Wave2Vec2Inference():
    def __init__(self, model_name):
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

    def buffer_to_text(self, audio_buffer):
        inputs = self.processor([audio_buffer], sampling_rate=16_000,
                                return_tensors="pt", padding=True)

        with torch.no_grad():
            logits = self.model(inputs.input_values,
                                attention_mask=torch.ones(len(inputs.input_values[0]))).logits

        # greedy (argmax) CTC decoding
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription

    def file_to_text(self, filename):
        audio_input, samplerate = sf.read(filename)
        assert samplerate == 16000
        return self.buffer_to_text(audio_input)


if __name__ == "__main__":
    print("Model test")
    asr = Wave2Vec2Inference("maxidl/wav2vec2-large-xlsr-german")
    text = asr.file_to_text("some.wav")
    print(text)
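The "convert non 16 khz sample rates" item in the improvement notes above could be handled by resampling before transcription. A sketch using torchaudio, which is an assumption here (it is not listed in requirements.txt):

```python
import soundfile as sf
import torch
import torchaudio.functional as F  # extra dependency, not in requirements.txt

def file_to_text_any_rate(asr, filename):
    """Resample a file to the 16 kHz the model expects, then transcribe it."""
    audio_input, samplerate = sf.read(filename)
    if samplerate != 16000:
        audio_input = F.resample(
            torch.from_numpy(audio_input).float(), samplerate, 16000).numpy()
    return asr.buffer_to_text(audio_input)
```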